diff --git a/protocol.go b/protocol.go index 8e73afea..4c1364ea 100644 --- a/protocol.go +++ b/protocol.go @@ -28,7 +28,6 @@ const ( messageTypeRequest = 2 messageTypeResponse = 3 messageTypePing = 4 - messageTypePong = 5 messageTypeIndexUpdate = 6 messageTypeClose = 7 ) @@ -71,13 +70,12 @@ const ( ) var ( - ErrClusterHash = fmt.Errorf("configuration error: mismatched cluster hash") - ErrClosed = errors.New("connection closed") + ErrClosed = errors.New("connection closed") + ErrTimeout = errors.New("read timeout") ) // Specific variants of empty messages... type pingMessage struct{ EmptyMessage } -type pongMessage struct{ EmptyMessage } type Model interface { // An index was received from the peer device @@ -146,9 +144,11 @@ type isEofer interface { IsEOF() bool } -var ( - PingTimeout = 30 * time.Second - PingIdleTime = 60 * time.Second +const ( + // We make sure to send a message at least this often, by triggering pings. + PingSendInterval = 90 * time.Second + // If we haven't received a message from the other side for this long, close the connection. + ReceiveTimeout = 300 * time.Second ) func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { @@ -180,7 +180,8 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv func (c *rawConnection) Start() { go c.readerLoop() go c.writerLoop() - go c.pingerLoop() + go c.pingSender() + go c.pingReceiver() go c.idGenerator() } @@ -278,18 +279,7 @@ func (c *rawConnection) ping() bool { return false } - rc := make(chan asyncResult, 1) - c.awaitingMut.Lock() - c.awaiting[id] = rc - c.awaitingMut.Unlock() - - ok := c.send(id, messageTypePing, nil, nil) - if !ok { - return false - } - - res, ok := <-rc - return ok && res.err == nil + return c.send(id, messageTypePing, nil, nil) } func (c *rawConnection) readerLoop() (err error) { @@ -352,13 +342,7 @@ func (c *rawConnection) readerLoop() (err error) { if state != stateReady { return fmt.Errorf("protocol error: ping message in state %d", state) } - c.send(hdr.msgID, messageTypePong, pongMessage{}, nil) - - case pongMessage: - if state != stateReady { - return fmt.Errorf("protocol error: pong message in state %d", state) - } - c.handlePong(hdr.msgID) + // Nothing case CloseMessage: return errors.New(msg.Reason) @@ -467,9 +451,6 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { case messageTypePing: msg = pingMessage{} - case messageTypePong: - msg = pongMessage{} - case messageTypeClusterConfig: var cc ClusterConfigMessage err = cc.UnmarshalXDR(msgBuf) @@ -729,42 +710,55 @@ func (c *rawConnection) idGenerator() { } } -func (c *rawConnection) pingerLoop() { - var rc = make(chan bool, 1) - ticker := time.Tick(PingIdleTime / 2) +// The pingSender makes sure that we've sent a message within the last +// PingSendInterval. If we already have something sent in the last +// PingSendInterval/2, we do nothing. Otherwise we send a ping message. This +// results in an effecting ping interval of somewhere between +// PingSendInterval/2 and PingSendInterval. +func (c *rawConnection) pingSender() { + ticker := time.Tick(PingSendInterval / 2) + for { select { case <-ticker: - if d := time.Since(c.cr.Last()); d < PingIdleTime { - if debug { - l.Debugln(c.id, "ping skipped after rd", d) - } - continue - } - if d := time.Since(c.cw.Last()); d < PingIdleTime { + d := time.Since(c.cw.Last()) + if d < PingSendInterval/2 { if debug { l.Debugln(c.id, "ping skipped after wr", d) } continue } - go func() { + + if debug { + l.Debugln(c.id, "ping -> after", d) + } + c.ping() + + case <-c.closed: + return + } + } +} + +// The pingReciever checks that we've received a message (any message will do, +// but we expect pings in the absence of other messages) within the last +// ReceiveTimeout. If not, we close the connection with an ErrTimeout. +func (c *rawConnection) pingReceiver() { + ticker := time.Tick(ReceiveTimeout / 2) + + for { + select { + case <-ticker: + d := time.Since(c.cr.Last()) + if d > ReceiveTimeout { if debug { - l.Debugln(c.id, "ping ->") + l.Debugln(c.id, "ping timeout", d) } - rc <- c.ping() - }() - select { - case ok := <-rc: - if debug { - l.Debugln(c.id, "<- pong") - } - if !ok { - c.close(fmt.Errorf("ping failure")) - } - case <-time.After(PingTimeout): - c.close(fmt.Errorf("ping timeout")) - case <-c.closed: - return + c.close(ErrTimeout) + } + + if debug { + l.Debugln(c.id, "last read within", d) } case <-c.closed: diff --git a/protocol_test.go b/protocol_test.go index 2c64a9b3..8a470884 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -6,7 +6,6 @@ import ( "bytes" "encoding/hex" "encoding/json" - "errors" "fmt" "io" "io/ioutil" @@ -82,94 +81,6 @@ func TestPing(t *testing.T) { } } -func TestPingErr(t *testing.T) { - e := errors.New("something broke") - - for i := 0; i < 32; i++ { - for j := 0; j < 32; j++ { - m0 := newTestModel() - m1 := newTestModel() - - ar, aw := io.Pipe() - br, bw := io.Pipe() - eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} - ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} - - c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - c0.Start() - c1 := NewConnection(c1ID, br, eaw, m1, "name", CompressAlways) - c1.Start() - c0.ClusterConfig(ClusterConfigMessage{}) - c1.ClusterConfig(ClusterConfigMessage{}) - - res := c0.ping() - if (i < 8 || j < 8) && res { - // This should have resulted in failure, as there is no way an empty ClusterConfig plus a Ping message fits in eight bytes. - t.Errorf("Unexpected ping success; i=%d, j=%d", i, j) - } else if (i >= 28 && j >= 28) && !res { - // This should have worked though, as 28 bytes is plenty for both. - t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j) - } - } - } -} - -// func TestRequestResponseErr(t *testing.T) { -// e := errors.New("something broke") - -// var pass bool -// for i := 0; i < 48; i++ { -// for j := 0; j < 38; j++ { -// m0 := newTestModel() -// m0.data = []byte("response data") -// m1 := newTestModel() - -// ar, aw := io.Pipe() -// br, bw := io.Pipe() -// eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} -// ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} - -// NewConnection(c0ID, ar, ebw, m0, nil) -// c1 := NewConnection(c1ID, br, eaw, m1, nil).(wireFormatConnection).next.(*rawConnection) - -// d, err := c1.Request("default", "tn", 1234, 5678) -// if err == e || err == ErrClosed { -// t.Logf("Error at %d+%d bytes", i, j) -// if !m1.isClosed() { -// t.Fatal("c1 not closed") -// } -// if !m0.isClosed() { -// t.Fatal("c0 not closed") -// } -// continue -// } -// if err != nil { -// t.Fatal(err) -// } -// if string(d) != "response data" { -// t.Fatalf("Incorrect response data %q", string(d)) -// } -// if m0.folder != "default" { -// t.Fatalf("Incorrect folder %q", m0.folder) -// } -// if m0.name != "tn" { -// t.Fatalf("Incorrect name %q", m0.name) -// } -// if m0.offset != 1234 { -// t.Fatalf("Incorrect offset %d", m0.offset) -// } -// if m0.size != 5678 { -// t.Fatalf("Incorrect size %d", m0.size) -// } -// t.Logf("Pass at %d+%d bytes", i, j) -// pass = true -// } -// } -// if !pass { -// t.Fatal("Never passed") -// } -// } - func TestVersionErr(t *testing.T) { m0 := newTestModel() m1 := newTestModel()