From e1ac740ac4827609c5c534686d690d3aca0cc473 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Tue, 2 Feb 2016 12:48:09 +0100 Subject: [PATCH] Use v2 of XDR package (actual changes) --- lib/db/leveldb_dbinstance.go | 12 +++++-- lib/protocol/header.go | 15 ++++---- lib/protocol/protocol.go | 59 +++++++++++++++++++------------ lib/protocol/protocol_test.go | 63 ++++++++++++++++++---------------- lib/relay/protocol/protocol.go | 29 +++++++++++----- 5 files changed, 109 insertions(+), 69 deletions(-) diff --git a/lib/db/leveldb_dbinstance.go b/lib/db/leveldb_dbinstance.go index f432d99e..a0a8ebbf 100644 --- a/lib/db/leveldb_dbinstance.go +++ b/lib/db/leveldb_dbinstance.go @@ -267,7 +267,11 @@ func (db *Instance) withHave(folder, device []byte, truncate bool, fn Iterator) defer dbi.Release() for dbi.Next() { - f, err := unmarshalTrunc(dbi.Value(), truncate) + // The iterator function may keep a reference to the unmarshalled + // struct, which in turn references the buffer it was unmarshalled + // from. dbi.Value() just returns an internal slice that it reuses, so + // we need to copy it. + f, err := unmarshalTrunc(append([]byte{}, dbi.Value()...), truncate) if err != nil { panic(err) } @@ -287,7 +291,11 @@ func (db *Instance) withAllFolderTruncated(folder []byte, fn func(device []byte, for dbi.Next() { device := db.deviceKeyDevice(dbi.Key()) var f FileInfoTruncated - err := f.UnmarshalXDR(dbi.Value()) + // The iterator function may keep a reference to the unmarshalled + // struct, which in turn references the buffer it was unmarshalled + // from. dbi.Value() just returns an internal slice that it reuses, so + // we need to copy it. + err := f.UnmarshalXDR(append([]byte{}, dbi.Value()...)) if err != nil { panic(err) } diff --git a/lib/protocol/header.go b/lib/protocol/header.go index 846ee48c..184d165d 100644 --- a/lib/protocol/header.go +++ b/lib/protocol/header.go @@ -11,15 +11,16 @@ type header struct { compression bool } -func (h header) encodeXDR(xw *xdr.Writer) (int, error) { - u := encodeHeader(h) - return xw.WriteUint32(u) +func (h header) MarshalXDRInto(m *xdr.Marshaller) error { + v := encodeHeader(h) + m.MarshalUint32(v) + return m.Error } -func (h *header) decodeXDR(xr *xdr.Reader) error { - u := xr.ReadUint32() - *h = decodeHeader(u) - return xr.Error() +func (h *header) UnmarshalXDRFrom(u *xdr.Unmarshaller) error { + v := u.UnmarshalUint32() + *h = decodeHeader(v) + return u.Error } func encodeHeader(h header) uint32 { diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 334ba28a..c2412230 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -12,6 +12,7 @@ import ( "time" lz4 "github.com/bkaradzic/go-lz4" + "github.com/calmh/xdr" ) const ( @@ -130,8 +131,7 @@ type rawConnection struct { pool sync.Pool compression Compression - rdbuf0 []byte // used & reused by readMessage - rdbuf1 []byte // used & reused by readMessage + readerBuf []byte // used & reused by readMessage } type asyncResult struct { @@ -146,7 +146,8 @@ type hdrMsg struct { } type encodable interface { - AppendXDR([]byte) ([]byte, error) + MarshalXDRInto(m *xdr.Marshaller) error + XDRSize() int } type isEofer interface { @@ -374,18 +375,14 @@ func (c *rawConnection) readerLoop() (err error) { } func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { - if cap(c.rdbuf0) < 8 { - c.rdbuf0 = make([]byte, 8) - } else { - c.rdbuf0 = c.rdbuf0[:8] - } - _, err = io.ReadFull(c.cr, c.rdbuf0) + hdrBuf := make([]byte, 8) + _, err = io.ReadFull(c.cr, hdrBuf) if err != nil { return } - hdr = decodeHeader(binary.BigEndian.Uint32(c.rdbuf0[0:4])) - msglen := int(binary.BigEndian.Uint32(c.rdbuf0[4:8])) + hdr = decodeHeader(binary.BigEndian.Uint32(hdrBuf[:4])) + msglen := int(binary.BigEndian.Uint32(hdrBuf[4:])) l.Debugf("read header %v (msglen=%d)", hdr, msglen) @@ -399,27 +396,40 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { return } - if cap(c.rdbuf0) < msglen { - c.rdbuf0 = make([]byte, msglen) + // c.readerBuf contains a buffer we can reuse. But once we've unmarshalled + // a message from the buffer we can't reuse it again as the unmarshalled + // message refers to the contents of the buffer. The only case we a buffer + // ends up in readerBuf for reuse is when the message is compressed, as we + // then decompress into a new buffer instead. + + var msgBuf []byte + if cap(c.readerBuf) >= msglen { + // If we have a buffer ready in rdbuf we just use that. + msgBuf = c.readerBuf[:msglen] } else { - c.rdbuf0 = c.rdbuf0[:msglen] + // Otherwise we allocate a new buffer. + msgBuf = make([]byte, msglen) } - _, err = io.ReadFull(c.cr, c.rdbuf0) + + _, err = io.ReadFull(c.cr, msgBuf) if err != nil { return } - l.Debugf("read %d bytes", len(c.rdbuf0)) + l.Debugf("read %d bytes", len(msgBuf)) - msgBuf := c.rdbuf0 if hdr.compression && msglen > 0 { - c.rdbuf1 = c.rdbuf1[:cap(c.rdbuf1)] - c.rdbuf1, err = lz4.Decode(c.rdbuf1, c.rdbuf0) + // We're going to decompress msgBuf into a different newly allocated + // buffer, so keep msgBuf around for reuse on the next message. + c.readerBuf = msgBuf + + msgBuf, err = lz4.Decode(nil, msgBuf) if err != nil { return } - msgBuf = c.rdbuf1 l.Debugf("decompressed to %d bytes", len(msgBuf)) + } else { + c.readerBuf = nil } if shouldDebug() { @@ -601,7 +611,14 @@ func (c *rawConnection) writerLoop() { case hm := <-c.outbox: if hm.msg != nil { // Uncompressed message in uncBuf - uncBuf, err = hm.msg.AppendXDR(uncBuf[:0]) + msgLen := hm.msg.XDRSize() + if cap(uncBuf) >= msgLen { + uncBuf = uncBuf[:msgLen] + } else { + uncBuf = make([]byte, msgLen) + } + m := &xdr.Marshaller{Data: uncBuf} + err = hm.msg.MarshalXDRInto(m) if hm.done != nil { close(hm.done) } diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index 96510768..4aa8c670 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -3,7 +3,6 @@ package protocol import ( - "bytes" "encoding/binary" "encoding/hex" "encoding/json" @@ -55,14 +54,13 @@ func TestHeaderMarshalUnmarshal(t *testing.T) { ver = int(uint(ver) % 16) id = int(uint(id) % 4096) typ = int(uint(typ) % 256) - buf := new(bytes.Buffer) - xw := xdr.NewWriter(buf) - h0 := header{version: ver, msgID: id, msgType: typ} - h0.encodeXDR(xw) + buf := make([]byte, 4) + + h0 := header{version: ver, msgID: id, msgType: typ} + h0.MarshalXDRInto(&xdr.Marshaller{Data: buf}) - xr := xdr.NewReader(buf) var h1 header - h1.decodeXDR(xr) + h1.UnmarshalXDRFrom(&xdr.Unmarshaller{Data: buf}) return h0 == h1 } if err := quick.Check(f, nil); err != nil { @@ -128,8 +126,7 @@ func TestVersionErr(t *testing.T) { c0.ClusterConfig(ClusterConfigMessage{}) c1.ClusterConfig(ClusterConfigMessage{}) - w := xdr.NewWriter(c0.cw) - timeoutWriteHeader(w, header{ + timeoutWriteHeader(c0.cw, header{ version: 2, // higher than supported msgID: 0, msgType: messageTypeIndex, @@ -154,8 +151,7 @@ func TestTypeErr(t *testing.T) { c0.ClusterConfig(ClusterConfigMessage{}) c1.ClusterConfig(ClusterConfigMessage{}) - w := xdr.NewWriter(c0.cw) - timeoutWriteHeader(w, header{ + timeoutWriteHeader(c0.cw, header{ version: 0, msgID: 0, msgType: 42, // unknown type @@ -205,7 +201,7 @@ func TestElementSizeExceededNested(t *testing.T) { m := ClusterConfigMessage{ ClientName: "longstringlongstringlongstringinglongstringlongstringlonlongstringlongstringlon", } - _, err := m.EncodeXDR(ioutil.Discard) + _, err := m.MarshalXDR() if err == nil { t.Errorf("ID length %d > max 64, but no error", len(m.Folders[0].ID)) } @@ -213,12 +209,19 @@ func TestElementSizeExceededNested(t *testing.T) { func TestMarshalIndexMessage(t *testing.T) { f := func(m1 IndexMessage) bool { + if len(m1.Options) == 0 { + m1.Options = nil + } for i, f := range m1.Files { m1.Files[i].CachedSize = 0 - for j := range f.Blocks { - f.Blocks[j].Offset = 0 - if len(f.Blocks[j].Hash) == 0 { - f.Blocks[j].Hash = nil + if len(f.Blocks) == 0 { + m1.Files[i].Blocks = nil + } else { + for j := range f.Blocks { + f.Blocks[j].Offset = 0 + if len(f.Blocks[j].Hash) == 0 { + f.Blocks[j].Hash = nil + } } } } @@ -233,6 +236,9 @@ func TestMarshalIndexMessage(t *testing.T) { func TestMarshalRequestMessage(t *testing.T) { f := func(m1 RequestMessage) bool { + if len(m1.Options) == 0 { + m1.Options = nil + } return testMarshal(t, "request", &m1, &RequestMessage{}) } @@ -256,6 +262,9 @@ func TestMarshalResponseMessage(t *testing.T) { func TestMarshalClusterConfigMessage(t *testing.T) { f := func(m1 ClusterConfigMessage) bool { + if len(m1.Options) == 0 { + m1.Options = nil + } return testMarshal(t, "clusterconfig", &m1, &ClusterConfigMessage{}) } @@ -275,13 +284,11 @@ func TestMarshalCloseMessage(t *testing.T) { } type message interface { - EncodeXDR(io.Writer) (int, error) - DecodeXDR(io.Reader) error + MarshalXDR() ([]byte, error) + UnmarshalXDR([]byte) error } func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { - var buf bytes.Buffer - failed := func(bc []byte) { bs, _ := json.MarshalIndent(m1, "", " ") ioutil.WriteFile(prefix+"-1.txt", bs, 0644) @@ -294,7 +301,7 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { } } - _, err := m1.EncodeXDR(&buf) + buf, err := m1.MarshalXDR() if err != nil && strings.Contains(err.Error(), "exceeds size") { return true } @@ -303,23 +310,20 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { t.Fatal(err) } - bc := make([]byte, len(buf.Bytes())) - copy(bc, buf.Bytes()) - - err = m2.DecodeXDR(&buf) + err = m2.UnmarshalXDR(buf) if err != nil { - failed(bc) + failed(buf) t.Fatal(err) } ok := reflect.DeepEqual(m1, m2) if !ok { - failed(bc) + failed(buf) } return ok } -func timeoutWriteHeader(w *xdr.Writer, hdr header) { +func timeoutWriteHeader(w io.Writer, hdr header) { // This tries to write a message header to w, but times out after a while. // This is useful because in testing, with a PipeWriter, it will block // forever if the other side isn't reading any more. On the other hand we @@ -332,8 +336,7 @@ func timeoutWriteHeader(w *xdr.Writer, hdr header) { done := make(chan struct{}) go func() { - w.WriteRaw(buf[:]) - l.Infoln("write completed") + w.Write(buf[:]) close(done) }() select { diff --git a/lib/relay/protocol/protocol.go b/lib/relay/protocol/protocol.go index dad76d94..2a7339e1 100644 --- a/lib/relay/protocol/protocol.go +++ b/lib/relay/protocol/protocol.go @@ -74,7 +74,13 @@ func WriteMessage(w io.Writer, message interface{}) error { func ReadMessage(r io.Reader) (interface{}, error) { var header header - if err := header.DecodeXDR(r); err != nil { + + buf := make([]byte, header.XDRSize()) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if err := header.UnmarshalXDR(buf); err != nil { return nil, err } @@ -82,38 +88,43 @@ func ReadMessage(r io.Reader) (interface{}, error) { return nil, fmt.Errorf("magic mismatch") } + buf = make([]byte, int(header.messageLength)) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + switch header.messageType { case messageTypePing: var msg Ping - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypePong: var msg Pong - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeJoinRelayRequest: var msg JoinRelayRequest - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeJoinSessionRequest: var msg JoinSessionRequest - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeResponse: var msg Response - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeConnectRequest: var msg ConnectRequest - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeSessionInvitation: var msg SessionInvitation - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeRelayFull: var msg RelayFull - err := msg.DecodeXDR(r) + err := msg.UnmarshalXDR(buf) return msg, err }