Use a single socket for relaying

This commit is contained in:
AudriusButkevicius 2015-09-02 21:35:52 +01:00
parent 7fe1fdd8c7
commit c0554c9fbf
4 changed files with 79 additions and 111 deletions

View File

@ -11,6 +11,7 @@ import (
"time" "time"
syncthingprotocol "github.com/syncthing/protocol" syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
) )
@ -21,14 +22,16 @@ var (
numConnections int64 numConnections int64
) )
func protocolListener(addr string, config *tls.Config) { func listener(addr string, config *tls.Config) {
listener, err := net.Listen("tcp", addr) tcpListener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
listener := tlsutil.DowngradingListener{tcpListener, nil}
for { for {
conn, err := listener.Accept() conn, isTLS, err := listener.AcceptNoWrap()
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println(err)
@ -39,10 +42,15 @@ func protocolListener(addr string, config *tls.Config) {
setTCPOptions(conn) setTCPOptions(conn)
if debug { if debug {
log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) log.Println("Listener accepted connection from", conn.RemoteAddr(), "tls", isTLS)
} }
if isTLS {
go protocolConnectionHandler(conn, config) go protocolConnectionHandler(conn, config)
} else {
go sessionConnectionHandler(conn)
}
} }
} }
@ -220,6 +228,65 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
} }
} }
func sessionConnectionHandler(conn net.Conn) {
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return
}
switch msg := message.(type) {
case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key))
if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses)
}
if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return
}
if !ses.AddConnection(conn) {
if debug {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
}
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return
}
if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
if debug {
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
}
return
}
if err := conn.SetDeadline(time.Time{}); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
conn.Close()
return
}
default:
if debug {
log.Println("Unexpected message from", conn.RemoteAddr(), message)
}
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
}
}
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
atomic.AddInt64(&numConnections, 1) atomic.AddInt64(&numConnections, 1)
defer atomic.AddInt64(&numConnections, -1) defer atomic.AddInt64(&numConnections, -1)

13
main.go
View File

@ -17,8 +17,7 @@ import (
) )
var ( var (
listenProtocol string listen string
listenSession string
debug bool debug bool
sessionAddress []byte sessionAddress []byte
@ -39,9 +38,7 @@ var (
func main() { func main() {
var dir, extAddress string var dir, extAddress string
flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address")
flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection")
flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored")
flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
@ -54,7 +51,7 @@ func main() {
flag.Parse() flag.Parse()
if extAddress == "" { if extAddress == "" {
extAddress = listenSession extAddress = listen
} }
addr, err := net.ResolveTCPAddr("tcp", extAddress) addr, err := net.ResolveTCPAddr("tcp", extAddress)
@ -100,11 +97,9 @@ func main() {
globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps)) globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps))
} }
go sessionListener(listenSession)
if statusAddr != "" { if statusAddr != "" {
go statusService(statusAddr) go statusService(statusAddr)
} }
protocolListener(listenProtocol, tlsCfg) listener(listen, tlsCfg)
} }

View File

@ -7,8 +7,9 @@ package protocol
import ( import (
"fmt" "fmt"
syncthingprotocol "github.com/syncthing/protocol"
"net" "net"
syncthingprotocol "github.com/syncthing/protocol"
) )
const ( const (

View File

@ -1,95 +0,0 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package main
import (
"log"
"net"
"time"
"github.com/syncthing/relaysrv/protocol"
)
func sessionListener(addr string) {
listener, err := net.Listen("tcp", addr)
if err != nil {
log.Fatalln(err)
}
for {
conn, err := listener.Accept()
if err != nil {
if debug {
log.Println(err)
}
continue
}
setTCPOptions(conn)
if debug {
log.Println("Session listener accepted connection from", conn.RemoteAddr())
}
go sessionConnectionHandler(conn)
}
}
func sessionConnectionHandler(conn net.Conn) {
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return
}
switch msg := message.(type) {
case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key))
if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses)
}
if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return
}
if !ses.AddConnection(conn) {
if debug {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
}
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return
}
if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
if debug {
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
}
return
}
if err := conn.SetDeadline(time.Time{}); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
conn.Close()
return
}
default:
if debug {
log.Println("Unexpected message from", conn.RemoteAddr(), message)
}
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
}
}