Rework XDR encoding

This commit is contained in:
Jakob Borg
2014-02-20 17:40:15 +01:00
parent 87d473dc8f
commit 5837277f8d
27 changed files with 1843 additions and 1029 deletions

393
xdr/cmd/coder/main.go Normal file
View File

@@ -0,0 +1,393 @@
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"regexp"
"strconv"
"strings"
"text/template"
)
var output string
type field struct {
Name string
IsBasic bool
IsSlice bool
IsMap bool
FieldType string
KeyType string
Encoder string
Convert string
Max int
}
var headerTpl = template.Must(template.New("header").Parse(`package {{.Package}}
import (
"bytes"
"io"
"github.com/calmh/syncthing/xdr"
)
`))
var encodeTpl = template.Must(template.New("encoder").Parse(`
func (o {{.TypeName}}) EncodeXDR(w io.Writer) (int, error) {
var xw = xdr.NewWriter(w)
return o.encodeXDR(xw)
}//+n
func (o {{.TypeName}}) MarshalXDR() []byte {
var buf bytes.Buffer
var xw = xdr.NewWriter(&buf)
o.encodeXDR(xw)
return buf.Bytes()
}//+n
func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
{{range $field := .Fields}}
{{if not $field.IsSlice}}
{{if ne $field.Convert ""}}
xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}))
{{else if $field.IsBasic}}
{{if ge $field.Max 1}}
if len(o.{{$field.Name}}) > {{$field.Max}} {
return xw.Tot(), xdr.ErrElementSizeExceeded
}
{{end}}
xw.Write{{$field.Encoder}}(o.{{$field.Name}})
{{else}}
o.{{$field.Name}}.encodeXDR(xw)
{{end}}
{{else}}
{{if ge $field.Max 1}}
if len(o.{{$field.Name}}) > {{$field.Max}} {
return xw.Tot(), xdr.ErrElementSizeExceeded
}
{{end}}
xw.WriteUint32(uint32(len(o.{{$field.Name}})))
for i := range o.{{$field.Name}} {
{{if ne $field.Convert ""}}
xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}[i]))
{{else if $field.IsBasic}}
xw.Write{{$field.Encoder}}(o.{{$field.Name}}[i])
{{else}}
o.{{$field.Name}}[i].encodeXDR(xw)
{{end}}
}
{{end}}
{{end}}
return xw.Tot(), xw.Error()
}//+n
func (o *{{.TypeName}}) DecodeXDR(r io.Reader) error {
xr := xdr.NewReader(r)
return o.decodeXDR(xr)
}//+n
func (o *{{.TypeName}}) UnmarshalXDR(bs []byte) error {
var buf = bytes.NewBuffer(bs)
var xr = xdr.NewReader(buf)
return o.decodeXDR(xr)
}//+n
func (o *{{.TypeName}}) decodeXDR(xr *xdr.Reader) error {
{{range $field := .Fields}}
{{if not $field.IsSlice}}
{{if ne $field.Convert ""}}
o.{{$field.Name}} = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
{{else if $field.IsBasic}}
{{if ge $field.Max 1}}
o.{{$field.Name}} = xr.Read{{$field.Encoder}}Max({{$field.Max}})
{{else}}
o.{{$field.Name}} = xr.Read{{$field.Encoder}}()
{{end}}
{{else}}
(&o.{{$field.Name}}).decodeXDR(xr)
{{end}}
{{else}}
_{{$field.Name}}Size := int(xr.ReadUint32())
{{if ge $field.Max 1}}
if _{{$field.Name}}Size > {{$field.Max}} {
return xdr.ErrElementSizeExceeded
}
{{end}}
o.{{$field.Name}} = make([]{{$field.FieldType}}, _{{$field.Name}}Size)
for i := range o.{{$field.Name}} {
{{if ne $field.Convert ""}}
o.{{$field.Name}}[i] = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
{{else if $field.IsBasic}}
o.{{$field.Name}}[i] = xr.Read{{$field.Encoder}}()
{{else}}
(&o.{{$field.Name}}[i]).decodeXDR(xr)
{{end}}
}
{{end}}
{{end}}
return xr.Error()
}`))
var maxRe = regexp.MustCompile(`\Wmax:(\d+)`)
type typeSet struct {
Type string
Encoder string
}
var xdrEncoders = map[string]typeSet{
"int16": typeSet{"uint16", "Uint16"},
"uint16": typeSet{"", "Uint16"},
"int32": typeSet{"uint32", "Uint32"},
"uint32": typeSet{"", "Uint32"},
"int64": typeSet{"uint64", "Uint64"},
"uint64": typeSet{"", "Uint64"},
"int": typeSet{"uint64", "Uint64"},
"string": typeSet{"", "String"},
"[]byte": typeSet{"", "Bytes"},
"bool": typeSet{"", "Bool"},
}
func handleStruct(name string, t *ast.StructType) {
var fs []field
for _, sf := range t.Fields.List {
if len(sf.Names) == 0 {
// We don't handle anonymous fields
continue
}
fn := sf.Names[0].Name
var max = 0
if sf.Comment != nil {
c := sf.Comment.List[0].Text
if m := maxRe.FindStringSubmatch(c); m != nil {
max, _ = strconv.Atoi(m[1])
}
}
var f field
switch ft := sf.Type.(type) {
case *ast.Ident:
tn := ft.Name
if enc, ok := xdrEncoders[tn]; ok {
f = field{
Name: fn,
IsBasic: true,
FieldType: tn,
Encoder: enc.Encoder,
Convert: enc.Type,
Max: max,
}
} else {
f = field{
Name: fn,
IsBasic: false,
FieldType: tn,
Max: max,
}
}
case *ast.ArrayType:
if ft.Len != nil {
// We don't handle arrays
continue
}
tn := ft.Elt.(*ast.Ident).Name
if enc, ok := xdrEncoders["[]"+tn]; ok {
f = field{
Name: fn,
IsBasic: true,
FieldType: tn,
Encoder: enc.Encoder,
Convert: enc.Type,
Max: max,
}
} else if enc, ok := xdrEncoders[tn]; ok {
f = field{
Name: fn,
IsBasic: true,
IsSlice: true,
FieldType: tn,
Encoder: enc.Encoder,
Convert: enc.Type,
Max: max,
}
} else {
f = field{
Name: fn,
IsBasic: false,
IsSlice: true,
FieldType: tn,
Max: max,
}
}
}
fs = append(fs, f)
}
switch output {
case "code":
generateCode(name, fs)
case "diagram":
generateDiagram(name, fs)
case "xdr":
generateXdr(name, fs)
}
}
func generateCode(name string, fs []field) {
var buf bytes.Buffer
err := encodeTpl.Execute(&buf, map[string]interface{}{"TypeName": name, "Fields": fs})
if err != nil {
panic(err)
}
bs := regexp.MustCompile(`(\s*\n)+`).ReplaceAll(buf.Bytes(), []byte("\n"))
bs = bytes.Replace(bs, []byte("//+n"), []byte("\n"), -1)
bs, err = format.Source(bs)
if err != nil {
panic(err)
}
fmt.Println(string(bs))
}
func generateDiagram(sn string, fs []field) {
fmt.Println(sn + " Structure:")
fmt.Println()
fmt.Println(" 0 1 2 3")
fmt.Println(" 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1")
line := "+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+"
fmt.Println(line)
for _, f := range fs {
tn := f.FieldType
sl := f.IsSlice
if sl {
fmt.Printf("| %s |\n", center("Number of "+f.Name, 61))
fmt.Println(line)
}
switch tn {
case "uint16":
fmt.Printf("| %s | %s |\n", center(f.Name, 29), center("0x0000", 29))
fmt.Println(line)
case "uint32":
fmt.Printf("| %s |\n", center(f.Name, 61))
fmt.Println(line)
case "int64", "uint64":
fmt.Printf("| %-61s |\n", "")
fmt.Printf("+ %s +\n", center(f.Name+" (64 bits)", 61))
fmt.Printf("| %-61s |\n", "")
fmt.Println(line)
case "string", "byte": // XXX We assume slice of byte!
fmt.Printf("| %s |\n", center("Length of "+f.Name, 61))
fmt.Println(line)
fmt.Printf("/ %61s /\n", "")
fmt.Printf("\\ %s \\\n", center(f.Name+" (variable length)", 61))
fmt.Printf("/ %61s /\n", "")
fmt.Println(line)
default:
if sl {
tn = "Zero or more " + tn + " Structures"
fmt.Printf("/ %s /\n", center("", 61))
fmt.Printf("\\ %s \\\n", center(tn, 61))
fmt.Printf("/ %s /\n", center("", 61))
} else {
fmt.Printf("| %s |\n", center(tn, 61))
}
fmt.Println(line)
}
}
fmt.Println()
fmt.Println()
}
func generateXdr(sn string, fs []field) {
fmt.Printf("struct %s {\n", sn)
for _, f := range fs {
tn := f.FieldType
fn := f.Name
suf := ""
if f.IsSlice {
suf = "<>"
}
switch tn {
case "uint16":
fmt.Printf("\tunsigned short %s%s;\n", fn, suf)
case "uint32":
fmt.Printf("\tunsigned int %s%s;\n", fn, suf)
case "int64":
fmt.Printf("\thyper %s%s;\n", fn, suf)
case "uint64":
fmt.Printf("\tunsigned hyper %s%s;\n", fn, suf)
case "string":
fmt.Printf("\tstring %s<>;\n", fn)
case "byte":
fmt.Printf("\topaque %s<>;\n", fn)
default:
fmt.Printf("\t%s %s%s;\n", tn, fn, suf)
}
}
fmt.Println("}")
fmt.Println()
}
func center(s string, w int) string {
w -= len(s)
l := w / 2
r := l
if l+r < w {
r++
}
return strings.Repeat(" ", l) + s + strings.Repeat(" ", r)
}
func inspector(fset *token.FileSet) func(ast.Node) bool {
return func(n ast.Node) bool {
switch n := n.(type) {
case *ast.TypeSpec:
switch t := n.Type.(type) {
case *ast.StructType:
name := n.Name.Name
handleStruct(name, t)
}
return false
default:
return true
}
}
}
func main() {
flag.StringVar(&output, "output", "code", "code,xdr,diagram")
flag.Parse()
fname := flag.Arg(0)
// Create the AST by parsing src.
fset := token.NewFileSet() // positions are relative to fset
f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments)
if err != nil {
panic(err)
}
//ast.Print(fset, f)
if output == "code" {
headerTpl.Execute(os.Stdout, map[string]string{"Package": f.Name.Name})
}
i := inspector(fset)
ast.Inspect(f, i)
}

View File

@@ -1,10 +1,15 @@
package xdr
import "io"
import (
"errors"
"io"
)
var ErrElementSizeExceeded = errors.New("Element size exceeded")
type Reader struct {
r io.Reader
tot uint64
tot int
err error
b [8]byte
}
@@ -16,10 +21,26 @@ func NewReader(r io.Reader) *Reader {
}
func (r *Reader) ReadString() string {
return string(r.ReadBytes(nil))
return string(r.ReadBytes())
}
func (r *Reader) ReadBytes(dst []byte) []byte {
func (r *Reader) ReadStringMax(max int) string {
return string(r.ReadBytesMax(max))
}
func (r *Reader) ReadBytes() []byte {
return r.ReadBytesInto(nil)
}
func (r *Reader) ReadBytesMax(max int) []byte {
return r.ReadBytesMaxInto(max, nil)
}
func (r *Reader) ReadBytesInto(dst []byte) []byte {
return r.ReadBytesMaxInto(0, dst)
}
func (r *Reader) ReadBytesMaxInto(max int, dst []byte) []byte {
if r.err != nil {
return nil
}
@@ -27,22 +48,35 @@ func (r *Reader) ReadBytes(dst []byte) []byte {
if r.err != nil {
return nil
}
if max > 0 && l > max {
r.err = ErrElementSizeExceeded
return nil
}
if l+pad(l) > len(dst) {
dst = make([]byte, l+pad(l))
} else {
dst = dst[:l+pad(l)]
}
_, r.err = io.ReadFull(r.r, dst)
r.tot += uint64(l + pad(l))
r.tot += l + pad(l)
return dst[:l]
}
func (r *Reader) ReadUint16() uint16 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:4])
r.tot += 4
return uint16(r.b[1]) | uint16(r.b[0])<<8
}
func (r *Reader) ReadUint32() uint32 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:4])
r.tot += 8
r.tot += 4
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
}
@@ -56,10 +90,10 @@ func (r *Reader) ReadUint64() uint64 {
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
}
func (r *Reader) Tot() uint64 {
func (r *Reader) Tot() int {
return r.tot
}
func (r *Reader) Err() error {
func (r *Reader) Error() error {
return r.err
}

View File

@@ -14,7 +14,7 @@ var padBytes = []byte{0, 0, 0}
type Writer struct {
w io.Writer
tot uint64
tot int
err error
b [8]byte
}
@@ -48,7 +48,22 @@ func (w *Writer) WriteBytes(bs []byte) (int, error) {
l += n
}
w.tot += uint64(l)
w.tot += l
return l, w.err
}
func (w *Writer) WriteUint16(v uint16) (int, error) {
if w.err != nil {
return 0, w.err
}
w.b[0] = byte(v >> 8)
w.b[1] = byte(v)
w.b[2] = 0
w.b[3] = 0
var l int
l, w.err = w.w.Write(w.b[:4])
w.tot += l
return l, w.err
}
@@ -63,7 +78,7 @@ func (w *Writer) WriteUint32(v uint32) (int, error) {
var l int
l, w.err = w.w.Write(w.b[:4])
w.tot += uint64(l)
w.tot += l
return l, w.err
}
@@ -82,14 +97,14 @@ func (w *Writer) WriteUint64(v uint64) (int, error) {
var l int
l, w.err = w.w.Write(w.b[:8])
w.tot += uint64(l)
w.tot += l
return l, w.err
}
func (w *Writer) Tot() uint64 {
func (w *Writer) Tot() int {
return w.tot
}
func (w *Writer) Err() error {
func (w *Writer) Error() error {
return w.err
}

View File

@@ -30,8 +30,8 @@ func TestBytesNil(t *testing.T) {
var r = NewReader(b)
w.WriteBytes(bs)
w.WriteBytes(bs)
r.ReadBytes(nil)
res := r.ReadBytes(nil)
r.ReadBytes()
res := r.ReadBytes()
return bytes.Compare(bs, res) == 0
}
if err := quick.Check(fn, nil); err != nil {
@@ -47,8 +47,8 @@ func TestBytesGiven(t *testing.T) {
w.WriteBytes(bs)
w.WriteBytes(bs)
res := make([]byte, 12)
res = r.ReadBytes(res)
res = r.ReadBytes(res)
res = r.ReadBytesInto(res)
res = r.ReadBytesInto(res)
return bytes.Compare(bs, res) == 0
}
if err := quick.Check(fn, nil); err != nil {