diff --git a/lib/connections/relay_dial.go b/lib/connections/relay_dial.go index d18625ae..81d29da8 100644 --- a/lib/connections/relay_dial.go +++ b/lib/connections/relay_dial.go @@ -8,7 +8,6 @@ package connections import ( "crypto/tls" - "net" "net/url" "time" @@ -40,7 +39,7 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConn return IntermediateConnection{}, err } - err = dialer.SetTCPOptions(conn.(*net.TCPConn)) + err = dialer.SetTCPOptions(conn) if err != nil { conn.Close() return IntermediateConnection{}, err diff --git a/lib/connections/relay_listen.go b/lib/connections/relay_listen.go index 68f8b62e..a679acc5 100644 --- a/lib/connections/relay_listen.go +++ b/lib/connections/relay_listen.go @@ -8,7 +8,6 @@ package connections import ( "crypto/tls" - "net" "net/url" "sync" "time" @@ -74,7 +73,7 @@ func (t *relayListener) Serve() { continue } - err = dialer.SetTCPOptions(conn.(*net.TCPConn)) + err = dialer.SetTCPOptions(conn) if err != nil { l.Infoln(err) } diff --git a/lib/connections/tcp_listen.go b/lib/connections/tcp_listen.go index 58febccc..35e95c95 100644 --- a/lib/connections/tcp_listen.go +++ b/lib/connections/tcp_listen.go @@ -102,7 +102,7 @@ func (t *tcpListener) Serve() { l.Debugln("connect from", conn.RemoteAddr()) - err = dialer.SetTCPOptions(conn.(*net.TCPConn)) + err = dialer.SetTCPOptions(conn) if err != nil { l.Infoln(err) } diff --git a/lib/dialer/internal.go b/lib/dialer/internal.go index 9cf071fb..9a326354 100644 --- a/lib/dialer/internal.go +++ b/lib/dialer/internal.go @@ -57,9 +57,7 @@ func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network conn, err := proxyDialFunc(network, addr) if err == nil { l.Debugf("Dialing %s address %s via proxy - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr()) - if tcpconn, ok := conn.(*net.TCPConn); ok { - SetTCPOptions(tcpconn) - } + SetTCPOptions(conn) return dialerConn{ conn, newDialerAddr(network, addr), }, nil @@ -73,9 +71,7 @@ func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network conn, err = fallbackDialFunc(network, addr) if err == nil { l.Debugf("Dialing %s address %s via fallback - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr()) - if tcpconn, ok := conn.(*net.TCPConn); ok { - SetTCPOptions(tcpconn) - } + SetTCPOptions(conn) } else { l.Debugf("Dialing %s address %s via fallback - error %s", network, addr, err) } diff --git a/lib/dialer/public.go b/lib/dialer/public.go index f64ce196..4a19f831 100644 --- a/lib/dialer/public.go +++ b/lib/dialer/public.go @@ -7,6 +7,7 @@ package dialer import ( + "fmt" "net" "time" ) @@ -47,20 +48,30 @@ func DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) return net.DialTimeout(network, addr, timeout) } -// SetTCPOptions sets syncthings default TCP options on a TCP connection -func SetTCPOptions(conn *net.TCPConn) error { - var err error - if err = conn.SetLinger(0); err != nil { - return err +// SetTCPOptions sets our default TCP options on a TCP connection, possibly +// digging through dialerConn to extract the *net.TCPConn +func SetTCPOptions(conn net.Conn) error { + switch conn := conn.(type) { + case *net.TCPConn: + var err error + if err = conn.SetLinger(0); err != nil { + return err + } + if err = conn.SetNoDelay(false); err != nil { + return err + } + if err = conn.SetKeepAlivePeriod(60 * time.Second); err != nil { + return err + } + if err = conn.SetKeepAlive(true); err != nil { + return err + } + return nil + + case dialerConn: + return SetTCPOptions(conn.Conn) + + default: + return fmt.Errorf("unknown connection type %T", conn) } - if err = conn.SetNoDelay(false); err != nil { - return err - } - if err = conn.SetKeepAlivePeriod(60 * time.Second); err != nil { - return err - } - if err = conn.SetKeepAlive(true); err != nil { - return err - } - return nil } diff --git a/lib/protocol/benchmark_test.go b/lib/protocol/benchmark_test.go index ca2483fa..ee79f93c 100644 --- a/lib/protocol/benchmark_test.go +++ b/lib/protocol/benchmark_test.go @@ -131,8 +131,8 @@ func getTCPConnectionPair() (net.Conn, net.Conn, error) { } // Set the buffer sizes etc as usual - dialer.SetTCPOptions(conn0.(*net.TCPConn)) - dialer.SetTCPOptions(conn1.(*net.TCPConn)) + dialer.SetTCPOptions(conn0) + dialer.SetTCPOptions(conn1) return conn0, conn1, nil } diff --git a/lib/relay/client/static.go b/lib/relay/client/static.go index d9487bc2..cebc5f68 100644 --- a/lib/relay/client/static.go +++ b/lib/relay/client/static.go @@ -122,7 +122,8 @@ func (c *staticClient) Serve() { case protocol.SessionInvitation: ip := net.IP(msg.Address) if len(ip) == 0 || ip.IsUnspecified() { - msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] + ip := net.ParseIP(c.conn.RemoteAddr().String()) + msg.Address = ip[:] } c.invitations <- msg