From ec7c88ca55bcc6827abab339defe774a1cf6cea1 Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Thu, 2 May 2019 10:21:07 +0200 Subject: [PATCH] lib/protocol: Fix yet another deadlock (fixes #5678) (#5679) * lib/protocol: Fix yet another deadlock (fixes #5678) * more consistency * read deadlock * naming * more naming --- lib/protocol/protocol.go | 84 +++++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 96b41bd6..c460739a 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -240,14 +240,21 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv // Start creates the goroutines for sending and receiving of messages. It must // be called exactly once after creating a connection. func (c *rawConnection) Start() { - c.wg.Add(4) + c.startGoroutine(c.readerLoop) + c.startGoroutine(c.writerLoop) + c.startGoroutine(c.pingSender) + c.startGoroutine(c.pingReceiver) +} + +func (c *rawConnection) startGoroutine(loop func() error) { + c.wg.Add(1) go func() { - err := c.readerLoop() - c.internalClose(err) + err := loop() + c.wg.Done() + if err != nil && err != ErrClosed { + c.internalClose(err) + } }() - go c.writerLoop() - go c.pingSender() - go c.pingReceiver() } func (c *rawConnection) ID() DeviceID { @@ -363,25 +370,44 @@ func (c *rawConnection) ping() bool { return c.send(&Ping{}, nil) } -func (c *rawConnection) readerLoop() (err error) { - defer c.wg.Done() +type messageWithError struct { + msg message + err error +} + +func (c *rawConnection) readerLoop() error { fourByteBuf := make([]byte, 4) + inbox := make(chan messageWithError) + + // Reading from the wire may block until the underlying connection is closed. + go func() { + for { + msg, err := c.readMessage(fourByteBuf) + select { + case inbox <- messageWithError{msg: msg, err: err}: + case <-c.closed: + return + } + } + }() + state := stateInitial + var msgWithErr messageWithError for { - if c.Closed() { + select { + case msgWithErr = <-inbox: + case <-c.closed: return ErrClosed } - - msg, err := c.readMessage(fourByteBuf) - if err == errUnknownMessage { - // Unknown message types are skipped, for future extensibility. - continue - } - if err != nil { - return err + if msgWithErr.err != nil { + if msgWithErr.err == errUnknownMessage { + // Unknown message types are skipped, for future extensibility. + continue + } + return msgWithErr.err } - switch msg := msg.(type) { + switch msg := msgWithErr.msg.(type) { case *ClusterConfig: l.Debugln("read ClusterConfig message") if state != stateInitial { @@ -660,8 +686,7 @@ func (c *rawConnection) send(msg message, done chan struct{}) (sent bool) { } } -func (c *rawConnection) writerLoop() { - defer c.wg.Done() +func (c *rawConnection) writerLoop() error { for { select { case hm := <-c.outbox: @@ -670,12 +695,11 @@ func (c *rawConnection) writerLoop() { close(hm.done) } if err != nil { - c.internalClose(err) - return + return err } case <-c.closed: - return + return ErrClosed } } } @@ -882,9 +906,7 @@ func (c *rawConnection) internalClose(err error) { // PingSendInterval/2, we do nothing. Otherwise we send a ping message. This // results in an effecting ping interval of somewhere between // PingSendInterval/2 and PingSendInterval. -func (c *rawConnection) pingSender() { - defer c.wg.Done() - +func (c *rawConnection) pingSender() error { ticker := time.NewTicker(PingSendInterval / 2) defer ticker.Stop() @@ -901,7 +923,7 @@ func (c *rawConnection) pingSender() { c.ping() case <-c.closed: - return + return ErrClosed } } } @@ -909,9 +931,7 @@ func (c *rawConnection) pingSender() { // The pingReceiver checks that we've received a message (any message will do, // but we expect pings in the absence of other messages) within the last // ReceiveTimeout. If not, we close the connection with an ErrTimeout. -func (c *rawConnection) pingReceiver() { - defer c.wg.Done() - +func (c *rawConnection) pingReceiver() error { ticker := time.NewTicker(ReceiveTimeout / 2) defer ticker.Stop() @@ -921,13 +941,13 @@ func (c *rawConnection) pingReceiver() { d := time.Since(c.cr.Last()) if d > ReceiveTimeout { l.Debugln(c.id, "ping timeout", d) - c.internalClose(ErrTimeout) + return ErrTimeout } l.Debugln(c.id, "last read within", d) case <-c.closed: - return + return ErrClosed } } }