diff --git a/protocol_listener.go b/listener.go similarity index 73% rename from protocol_listener.go rename to listener.go index 0ff0154d..a63bc4c8 100644 --- a/protocol_listener.go +++ b/listener.go @@ -11,6 +11,7 @@ import ( "time" syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/relaysrv/protocol" ) @@ -21,14 +22,16 @@ var ( numConnections int64 ) -func protocolListener(addr string, config *tls.Config) { - listener, err := net.Listen("tcp", addr) +func listener(addr string, config *tls.Config) { + tcpListener, err := net.Listen("tcp", addr) if err != nil { log.Fatalln(err) } + listener := tlsutil.DowngradingListener{tcpListener, nil} + for { - conn, err := listener.Accept() + conn, isTLS, err := listener.AcceptNoWrap() if err != nil { if debug { log.Println(err) @@ -39,10 +42,15 @@ func protocolListener(addr string, config *tls.Config) { setTCPOptions(conn) if debug { - log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) + log.Println("Listener accepted connection from", conn.RemoteAddr(), "tls", isTLS) + } + + if isTLS { + go protocolConnectionHandler(conn, config) + } else { + go sessionConnectionHandler(conn) } - go protocolConnectionHandler(conn, config) } } @@ -220,6 +228,65 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } } +func sessionConnectionHandler(conn net.Conn) { + if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return + } + + switch msg := message.(type) { + case protocol.JoinSessionRequest: + ses := findSession(string(msg.Key)) + if debug { + log.Println(conn.RemoteAddr(), "session lookup", ses) + } + + if ses == nil { + protocol.WriteMessage(conn, protocol.ResponseNotFound) + conn.Close() + return + } + + if !ses.AddConnection(conn) { + if debug { + log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) + } + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + conn.Close() + return + } + + if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { + if debug { + log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) + } + return + } + + if err := conn.SetDeadline(time.Time{}); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + conn.Close() + return + } + + default: + if debug { + log.Println("Unexpected message from", conn.RemoteAddr(), message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) + conn.Close() + } +} + func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { atomic.AddInt64(&numConnections, 1) defer atomic.AddInt64(&numConnections, -1) diff --git a/main.go b/main.go index 2defed87..c050c355 100644 --- a/main.go +++ b/main.go @@ -17,9 +17,8 @@ import ( ) var ( - listenProtocol string - listenSession string - debug bool + listen string + debug bool sessionAddress []byte sessionPort uint16 @@ -39,9 +38,7 @@ var ( func main() { var dir, extAddress string - flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") - flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") - flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") + flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") @@ -54,7 +51,7 @@ func main() { flag.Parse() if extAddress == "" { - extAddress = listenSession + extAddress = listen } addr, err := net.ResolveTCPAddr("tcp", extAddress) @@ -100,11 +97,9 @@ func main() { globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps)) } - go sessionListener(listenSession) - if statusAddr != "" { go statusService(statusAddr) } - protocolListener(listenProtocol, tlsCfg) + listener(listen, tlsCfg) } diff --git a/protocol/packets.go b/protocol/packets.go index 7ff02011..1b21eba2 100644 --- a/protocol/packets.go +++ b/protocol/packets.go @@ -7,8 +7,9 @@ package protocol import ( "fmt" - syncthingprotocol "github.com/syncthing/protocol" "net" + + syncthingprotocol "github.com/syncthing/protocol" ) const ( diff --git a/session_listener.go b/session_listener.go deleted file mode 100644 index 82d2ec73..00000000 --- a/session_listener.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). - -package main - -import ( - "log" - "net" - "time" - - "github.com/syncthing/relaysrv/protocol" -) - -func sessionListener(addr string) { - listener, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalln(err) - } - - for { - conn, err := listener.Accept() - if err != nil { - if debug { - log.Println(err) - } - continue - } - - setTCPOptions(conn) - - if debug { - log.Println("Session listener accepted connection from", conn.RemoteAddr()) - } - - go sessionConnectionHandler(conn) - } -} - -func sessionConnectionHandler(conn net.Conn) { - if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { - if debug { - log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) - } - return - } - - message, err := protocol.ReadMessage(conn) - if err != nil { - return - } - - switch msg := message.(type) { - case protocol.JoinSessionRequest: - ses := findSession(string(msg.Key)) - if debug { - log.Println(conn.RemoteAddr(), "session lookup", ses) - } - - if ses == nil { - protocol.WriteMessage(conn, protocol.ResponseNotFound) - conn.Close() - return - } - - if !ses.AddConnection(conn) { - if debug { - log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) - } - protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) - conn.Close() - return - } - - if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { - if debug { - log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) - } - return - } - - if err := conn.SetDeadline(time.Time{}); err != nil { - if debug { - log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) - } - conn.Close() - return - } - - default: - if debug { - log.Println("Unexpected message from", conn.RemoteAddr(), message) - } - protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) - conn.Close() - } -}