165行で実装するProtocol Buffersデコーダ(ミニマム版)

この記事は Go2 Advent Calendar 2018の11日目の記事です。

今年の後半くらいに Protocol Buffers の仕様を読み始めたら、とてもシンプルかつコンパクトな仕様なのにcompatibilityへの考慮が凄まじくて、2018年後半に書いた記事の大半がProtocol Buffersに関するものでした。 仕様とバイナリを睨めっこしていたら、自分でもバイナリをデコードしたくなったので、実装してみました。

本内容は、あくまでProtocol Buffersの勉強を目的としたもので、仕様には完璧に添っていません。 というか、わかりやすさ(と実装のしやすさ)を優先して、コンパクトな仕様のさらにミニマム版な内容となっています。 当然ですが、実運用する際にはofficialの実装を利用してください。

どこまで実装するか

上述の通り、ミニマム版として、以下を実装範囲とします。

  1. バイナリからGoの構造体へUnmarshalする
  2. Unmarshalする構造体は既知とする
  3. wire typeは、値によって長さが変化する wire type 0(Varint)とwire type 2(Length-delimited)を対象とする

Length-delimitedVarintは後ほど説明するので、今は一旦飛ばしてOKです。

3.について、後述のLexerは、他のwire typeにも適用できるように設計してるので、長さが固定なことに注意すれば実装可能です。Future Workです。

ミニマム版ですが、Protocol Buffersでも最初にハマるVarintも含めているので、Protocol Buffersよくわからないという方にも、お伝えできることがあるんじゃないかと思います。 Varintの仕様がわかっていると値は127以下にするのが望ましい(バイナリが短くなる)とかも理解できるのうれしいです。

バイナリの生成

まずはofficalのエンコーダでバイナリを生成していきましょう。 正しく生成されたバイナリをデコードすることで、もとの構造体が復元できることを目標にします。

まず、すべての大元になる.protoファイルを以下のように定義します。
Person型にNameAgeのフィールドをもたせています。 wire type 2(Length-delimited)を表現したかったので、stringint32のままでよさそうなNameAgeをwrapしています。 (NameFirstNameLastNameとかにしたくなるかもしれないし......)

syntax = "proto3";

package person;

message Person {
  Name name = 1;
  Age age = 2;
}

message Name {
  string value = 1;
}

message Age {
  int32 value = 1;
}

protocして、上記.protoファイルに対応したライブラリ.pb.goをgenerateします。

❯ protoc -I=./ --go_out=./ person.proto

以下表の値を設定したバイナリを生成します。

フィールド
Name Alice
Age 20
package main

import (
    "io/ioutil"
    "log"

    pb "github.com/cipepser/protobufDecoder/Person"
    "github.com/golang/protobuf/proto"
)

func main() {
    p := &pb.Person{
        Name: &pb.Name{
            Value: "Alice",
        },
        Age: &pb.Age{
            Value: 20,
        },
    }

    if err := write("./person/alice.bin", p); err != nil {
        log.Fatal(err)
    }
}

func write(file string, p *pb.Person) error {
    out, err := proto.Marshal(p)
    if err != nil {
        return err
    }

    if err := ioutil.WriteFile(file, out, 0644); err != nil {
        return err
    }

    return nil
}

生成したperson/alice.binを適当なバイナリエディタで見てみると以下のようになります。 なお、vimなら:%!xxdで見れます。 hexdump person/alice.binをターミナル上で実行するのでもよいと思います。

0a07 0a05 416c 6963 6512 0208 14         ....Alice....

この時点で、バイナリではあるものの暗号化されているわけではない(Aliceが見えている)ので、ちょっと安心感を覚えます。

Protocol Buffersのバイナリの読み方

さて、実装を始める前にProtocol Buffersのバイナリがどのようなフォーマットなのか説明します。
Protocol Buffers のエンコーディング仕様の解説でも述べられているように以下が基本です。

key = タグナンバー * 8 + タイプ値

タグナンバーは.protoで定義した値です。 例えば、今回のAge age = 2;であれば、2がタグナンバーです。

タイプ値は、メッセージタイプを表す値で、 公式ドキュメントwire typesとして、以下のように定義されています。 冒頭からLength-delimitedVarintと言っていたやつです。

Type Meaning Used For
0 Varint int32, int64, uint32, uint64, sint32, sint64, bool, enum
1 64-bit fixed64, sfixed64, double
2 Length-delimited string, bytes, embedded messages, packed repeated fields
3 Start group groups (deprecated)
4 End group groups (deprecated)
5 32-bit fixed32, sfixed32, float

これだけだとよくわからないと思うので、例として、上で生成したバイナリをデコードしていきましょう。
改めて生成したバイナリを記載します。

0a 07 0a 05 41 6c 69 63 65 12 02 08 14

わかりやすいように色分けしました。

f:id:cipepser:20181208001319p:plain

一つずつ読んでいきます。

まず初めの0x0aは、
0d10 = タグname(1) * 8 + Length-delimited(type 2)
であることからタグナンバーとタイプ値がわかります。
Nameは自身で定義したmessageなので、表中のembedded messageが該当し、Length-delimitedとなります。
Length-delimited(type 2)だったので、lengthを得るために続く0x07を読みます。これがタグname(1)のlengthとなるので、後続の7バイト0a 05 41 6c 69 63 65Nameとしてデコードします。

Name(赤色の7バイト)の初め0x0aは、
0d10 = タグvalue(1) * 8 + Length-delimited(type 2)
です。
Length-delimited(type 2)だったので、lengthを得るために続く0x05を読みます。これがタグvalue(1)のlengthとなるので、後続の5バイト41 6c 69 63 65stringとして読んでいきます。 この5バイトをASCIIとして読むと41 6c 69 63 65Aliceとなります。

いい感じです。この調子で残りの12 02 08 14も読んでいきましょう。

0x12は、
0d18 = タグage(2) * 8 + Length-delimited(type 2)
です。
またまた、Length-delimited(type 2)だったので、lengthを得るために続く0x02を読みます。これがタグage(2)のlengthとなるので、後続の2バイト08 14Ageとしてデコードします。

0x08は、
0d08 = タグvalue(1) * 8 + Varint(type 0)
です。
やっとVarint(type 0)が登場しましたね。 Varint(type 0)はちょっとトリッキーなので、このあとすぐ説明します。 値が128未満であれば、そのままデコードしてあげることができるので、 今回の例では0x14int32として読んで0x14 = 0d20が得られます。

以上から、もともと定義したAlice20をデコードできました。

フィールド
Name Alice
Age 20

Varintの仕様について

Varint(type 0)について、もう少し詳しく見ていきます。 値が128未満であれば、そのままデコードできると述べたので、128以上の値として131としてみましょう。(128だと1が立つbitが一つしかないのでもう少しわかりやすく131にしました)

上記例で最後にVarintとして読み込んだ08 14(緑色の箇所)を思い出しましょう。

0x08 = 0d08 = タグvalue(1) * 8 + Varint(type 0)
からVarint(type 0)であることがわかり、 0x14int32としてデコードし、0d20を得ました。

.protoAge: 20,Age: 131,に変更してバイナリを生成し直してみます。 コードは上述と同じなので省略しますが、実行して得られるバイナリは、08 83 01となります。

先頭1バイトは変わらず0x08なので、value(tag 1),Varint(type 0)となるので83をVarintとして読み込みます。

ここで0x83を2進数で表記すると0b1000 0011です。
Varintでは、先頭1bitが1のとき、次の1バイトに数値が続いていることを表します。
次の1バイトを読み込むと0x01 = 0b0000 0001となり、先頭1bitが0となったため、読み込みはここで終了です。

あとは0x830x01を組み合わせてint32にデコードしてあげればいいのですが、仕様に以下のように書かれているので、リトルエンディアンで読んでいく必要があります。

varints store numbers with the least significant group first

また、先頭1bitは無視することにも注意です。

Varintの値
= 0x01(0b0000 0001)から先頭1bit落としたもの ++ 0x83(0b1000 0011)から先頭1bit落としたもの // リトルエンディアンなので0x01が先
= 000 0001 ++ 000 0011
= 0b1000 0011 // 先頭の6bitは0なので省略
= 0d131

以上より、131にデコードできます。

++演算子はバイナリを結合する操作を表します。

実装

バイナリの読み方がわかったところで実装に入っていきます。

Lexer

バイナリを読む部分を実装します。 ちょうど最近、Go言語でつくるインタプリタを読んでいるので、こちらの実装を大きく参考にさせて頂いています。

まず、Lexerを以下のように定義します。バイナリのバイト列bを1バイトずつ読んでいくため、 現在の位置positionと次のバイトの位置readPositionを保持します。

type Lexer struct {
    b            []byte
    position     int
    readPosition int
}

コンストラクタはバイト列inputを受け取って*Lexerを返します。

func New(input []byte) *Lexer {
    l := &Lexer{b: input}
    l.readPosition = l.position + 1
    return l
}

以下は補助的な役割を果たすメソッドですが、実装しておくと便利です。 バイト列がまだ読み込めるかをhasNext()で判断します。 また、next()で1バイト先に進められるようにします。

func (l *Lexer) hasNext() bool {
    return l.readPosition < len(l.b)
}

func (l *Lexer) next() {
    l.position++
    l.readPosition = l.position + 1
}

readCurByte()は現在の位置の1バイトを読み込み、位置を1つ進めます。 Varintを読み込む際はこちらを使います。

func (l *Lexer) readCurByte() byte {
    b := l.b[l.position]
    l.next()
    return b
}

readBytes(n int)readCurByte()のn文字版です。 Length-delimitedを読み込む際に利用します。
今回は省略しましたが、hasNext()EOFの判定を入れたほうがいいですね。

func (l *Lexer) readBytes(n int) []byte {
    bs := l.b[l.position : l.position+n]
    for i := 0; i < n; i++ {
        l.next()
    }
    return bs
}

Varintのデコード

Varintデコーダを実装します。
readCurByte()で1バイト読んできて、先頭1bitが1である限り(0になるまで)、1バイトずつ読み込みます。 先頭1bitは数値としてデコードする際には不要なので、0x7fとANDで論理積を取ります。 また、リトルエンディアンとしてデコードする必要があるので、スタックに積んでおきます。 スタックといいつつ、Goのcontainer/listはあんまり使われている印象がない(おそらくsliceのほうが速い?)ので、 bs[]byteappendすることにしました。
あとは先頭1bitを取り除いて、++で結合してあげればよいので、7bitほど左シフトさせて足しこんでいきます。

func (l *Lexer) decodeVarint() (uint64, error) {
    if len(l.b) == l.position {
        return 0, errors.New("unexpected EOF")
    }

    var bs []byte
    b := l.readCurByte()
    for bits.LeadingZeros8(b) == 0 { // 最上位bitが1のとき
        bs = append(bs, b&0x7f)
        b = l.readCurByte()
    }

    // 最上位bitが0のとき = 最後の1byte
    x := uint64(b)
    for i := 0; i < len(bs); i++ {
        x = x<<7 + uint64(bs[len(bs)-1-i])
    }

    return x, nil
}

Unmarshalの実装

PersonNameAgeそれぞれについて、Unmarshalの実装します。 どの型も流れは共通で、以下の流れになります。

  1. 1バイト読み込み、タグナンバーtagとタイプ値wireを計算する
  2. wireごとに後続の何バイト読み込むかがわかるので、その数だけ読み込む
  3. 読み込んだ値をtagに応じて評価する

1.のtagwireの計算は、

key = タグナンバー * 8 + タイプ値

であることを思い出すと、以下のように実装できます。

tag := key >> 3       // 下位4bit目以上を抜き出す
wire := int(key) & 7  // 下位3bitのみ抜き出す

tag0にしたとき

コラム的な話になりますが、tag0になるようなバイナリを与えると以下のようにpanicします。

panic: proto: person.Person: illegal tag 0 (wire type 2)

逆に0でないtagではpanic(どころかエラーにもならない)になりません。 今回の例でいうとPerson型はtag12しか定義していませんが、tagが3になるようなバイナリを読み込ませてもエラーにはなりません。 これはcompatibilityを考慮してのことで、tag3となるフィールドが増えた際に、 tag2までしかない古い.protoしか知らないクライアントでもエラーが起きないようにするためだと思われます。
今回の例(tag12しかない状態)で、tag3になるようにバイナリを作り、デコードすると、そのフィールドはnilになります。

panicさせる動作については、table_unmarshal.goに以下のように書かれています。

Explicitly disallow tag 0. This will ensure we flag an error when decoding a buffer of all zeros. Without this code, we would decode and skip an all-zero buffer of even length. [0 0] is [tag=0/wiretype=varint varint-encoded-0].

PersonのUnmarshal

今回の.protoで定義したPersonは以下のような構造体となります。

type Person struct {
    Name *Name // tag: 1
    Age  *Age  // tag: 2
}

このPerson型に対してUnmarshalを実装します。 今回Person型には、Length-delimited(type 2)になるフィールドしかないため、wireが意味を持つのはcase 2のときのみです。 tagName(1)Age(2)があるので、場合分けします。

func (p *Person) Unmarshal(b []byte) error {
    l := New(b)
    for l.hasNext() {
        key := uint64(l.readCurByte())
        tag := key >> 3
        wire := int(key) & 7

        switch wire {
        case 2:
            length := int(l.readCurByte())
            v := l.readBytes(length)

            switch tag {
            case 0:
                return errors.New("illegal tag 0")
            case 1:
                p.Name = &Name{}
                if err := p.Name.Unmarshal(v); err != nil {
                    return err
                }
            case 2:
                p.Age = &Age{}
                if err := p.Age.Unmarshal(v); err != nil {
                    return err
                }
            }
        default: // Person型はwire type 2以外は存在しない
            return fmt.Errorf("unexpected wire type: %d", wire)
        }
    }

    return nil
}

NameのUnmarshal

Personと同じように.protoで定義したNameは以下のような構造体となります。

type Name struct {
    Value string // tag: 1
}

このName型に対してUnmarshalを実装します。 今回Name型には、Length-delimited(type 2)になるフィールドしかないため、wireが意味を持つのはPersonと同じくcase 2のときのみです。 tagValue(1)のみなので、読み込んだバイト列をstringとしてデコードします。

func (n *Name) Unmarshal(b []byte) error {
    l := New(b)
    for l.hasNext() {
        key := uint64(l.readCurByte())
        tag := key >> 3
        wire := int(key) & 7

        switch wire {
        case 2:
            length := int(l.readCurByte())
            v := l.readBytes(length)

            switch tag {
            case 0:
                return errors.New("illegal tag 0")
            case 1:
                n.Value = string(v)
            }
        default: // Name型はwire type 2以外は存在しない
            return fmt.Errorf("unexpected wire type: %d", wire)
        }

    }
    return nil
}

AgeのUnmarshal

PersonNameと同じように.protoで定義したAgeは以下のような構造体となります。

type Age struct {
    Value int32 // tag: 1
}

このAge型に対してUnmarshalを実装します。 今回Age型には、Varint(type 0)になるフィールドしかないため、wireが意味を持つのはcase 0のときのみです。 tagValue(1)のみなので、読み込んだバイト列をdecodeVarint()int32としてデコードします。

func (a *Age) Unmarshal(b []byte) error {
    l := New(b)

    for l.hasNext() {
        key := uint64(l.readCurByte())
        tag := key >> 3
        wire := int(key) & 7

        switch wire {
        case 0:
            switch tag {
            case 0:
                return errors.New("illegal tag 0")
            case 1:
                i, err := l.decodeVarint()
                if err != nil {
                    return err
                }
                a.Value = int32(i)
            }
        default: // Age型はwire type 1以外は存在しない
            return fmt.Errorf("unexpected wire type: %d", wire)
        }

    }
    return nil
}

以上でデコーダ本体の実装は完了です。お疲れ様でした。

テスト

最後にテストを書いていきます。 今回の例で使ったバイナリ、ゼロ値(nil)になるパターン、マルチバイトになるVarintをテストパターンとします。

package decoder

import (
    "encoding/hex"
    "testing"

    "github.com/google/go-cmp/cmp"
)

func atob(s string) []byte {
    b, _ := hex.DecodeString(s)
    return b
}

func TestUnmarshalPerson(t *testing.T) {
    tests := []struct {
        b      []byte
        expect Person
    }{
        {
            // 今回の例
            b: atob("0a070a05416c69636512020814"),
            expect: Person{
                Name: &Name{Value: "Alice"},
                Age:  &Age{Value: 20},
            },
        },
        {
            // ゼロ値
            b:      atob(""),
            expect: Person{},
        },
        {
            // Ageのみゼロ値
            b: atob("0a070a05416c696365"),
            expect: Person{
                Name: &Name{Value: "Alice"},
            },
        },
        {
            // Nameのみゼロ値
            b: atob("12020814"),
            expect: Person{
                Age: &Age{Value: 20},
            },
        },
        {
            // Varintが2バイトになる場合
            b: atob("1203088301"),
            expect: Person{
                Age: &Age{Value: 131},
            },
        },
        {
            // Varintが3バイトになる場合
            b: atob("120408928002"),
            expect: Person{
                Age: &Age{Value: 32786},
            },
        },
    }

    for i, tt := range tests {
        p := Person{}
        if err := p.Unmarshal(tt.b); err != nil {
            t.Fatalf("test[%d - failed to Unmarshal. got err:%q", i, err)
        }
        if diff := cmp.Diff(p, tt.expect); diff != "" {
            t.Fatalf("test[%d - failed to Unmarshal. expected=%q, got=%q", i, tt.expect, p)
        }
    }
}

ちゃんとテストが通ります。

❯ go test ./decoder
ok      github.com/cipepser/protobufDecoder/decoder 0.006s

最後に

というわけで、Protocol Buffersのバイナリをデコードしてみました。 全体を表示しても以下のように165行程度なのでだいぶコンパクトに実装できたと思います。

f:id:cipepser:20181209140919p:plain

今回はPersonNameAgeとそれぞれUnmarshalを実装しましたが、同じような処理になるのでコード生成したいですね。 あと.protoからバイナリを生成するMarshalのほうとかも実装していきたいです。

ご指摘等あればtwitterまでご連絡ください。

References

Go言語でつくるインタプリタ

Go言語でつくるインタプリタ