diff --git a/lib/model/model_test.go b/lib/model/model_test.go index 76b41102..caa91352 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -3266,6 +3266,12 @@ func TestSanitizePath(t *testing.T) { // on a protocol connection that has a blocking reader (blocking writer can't // be done as the test requires clusterconfigs to go through). func TestConnCloseOnRestart(t *testing.T) { + oldCloseTimeout := protocol.CloseTimeout + protocol.CloseTimeout = 100 * time.Millisecond + defer func() { + protocol.CloseTimeout = oldCloseTimeout + }() + w, fcfg := tmpDefaultWrapper() m := setupModel(w) defer cleanupModelAndRemoveDir(m, fcfg.Filesystem().URI()) diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 116807f5..a081ec74 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -184,6 +184,7 @@ type rawConnection struct { inbox chan message outbox chan asyncMessage + closeBox chan asyncMessage clusterConfigBox chan *ClusterConfig dispatcherLoopStopped chan struct{} closed chan struct{} @@ -218,6 +219,11 @@ const ( ReceiveTimeout = 300 * time.Second ) +// CloseTimeout is the longest we'll wait when trying to send the close +// message before just closing the connection. +// Should not be modified in production code, just for testing. +var CloseTimeout = 10 * time.Second + func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { cr := &countingReader{Reader: reader} cw := &countingWriter{Writer: writer} @@ -231,6 +237,7 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv awaiting: make(map[int32]chan asyncResult), inbox: make(chan message), outbox: make(chan asyncMessage), + closeBox: make(chan asyncMessage), clusterConfigBox: make(chan *ClusterConfig), dispatcherLoopStopped: make(chan struct{}), closed: make(chan struct{}), @@ -671,6 +678,10 @@ func (c *rawConnection) writerLoop() { c.internalClose(err) return } + case hm := <-c.closeBox: + _ = c.writeMessage(hm.msg) + close(hm.done) + return case <-c.closed: return } @@ -686,6 +697,11 @@ func (c *rawConnection) writerLoop() { return } + case hm := <-c.closeBox: + _ = c.writeMessage(hm.msg) + close(hm.done) + return + case <-c.closed: return } @@ -853,17 +869,20 @@ func (c *rawConnection) shouldCompressMessage(msg message) bool { func (c *rawConnection) Close(err error) { c.sendCloseOnce.Do(func() { done := make(chan struct{}) - c.send(&Close{err.Error()}, done) + timeout := time.NewTimer(CloseTimeout) select { - case <-done: + case c.closeBox <- asyncMessage{&Close{err.Error()}, done}: + select { + case <-done: + case <-timeout.C: + case <-c.closed: + } + case <-timeout.C: case <-c.closed: } }) - // No more sends are necessary, therefore further steps to close the - // connection outside of this package can proceed immediately. - // And this prevents a potential deadlock due to calling c.receiver.Closed - go c.internalClose(err) + c.internalClose(err) } // internalClose is called if there is an unexpected error during normal operation. diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index 38660097..cba8dd3c 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -86,6 +86,12 @@ func TestClose(t *testing.T) { // Close is called while the underlying connection is broken (send blocks). // https://github.com/syncthing/syncthing/pull/5442 func TestCloseOnBlockingSend(t *testing.T) { + oldCloseTimeout := CloseTimeout + CloseTimeout = 100 * time.Millisecond + defer func() { + CloseTimeout = oldCloseTimeout + }() + m := newTestModel() c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection) @@ -214,6 +220,33 @@ func TestClusterConfigFirst(t *testing.T) { } } +// TestCloseTimeout checks that calling Close times out and proceeds, if sending +// the close message does not succeed. +func TestCloseTimeout(t *testing.T) { + oldCloseTimeout := CloseTimeout + CloseTimeout = 100 * time.Millisecond + defer func() { + CloseTimeout = oldCloseTimeout + }() + + m := newTestModel() + + c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection) + c.Start() + + done := make(chan struct{}) + go func() { + c.Close(errManual) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * CloseTimeout): + t.Fatal("timed out before Close returned") + } +} + func TestMarshalIndexMessage(t *testing.T) { if testing.Short() { quickCfg.MaxCount = 10