From 19d742b9e4b1f4f338c70c94547ed163406e66ae Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Wed, 24 Jun 2015 00:34:16 +0100 Subject: [PATCH 01/30] Initial commit --- .gitignore | 24 ++++++++++++++++++++++++ LICENSE | 22 ++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..daf913b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..581a1705 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 The Syncthing Project + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + From 8e191c8e6bd16847cd5599c2625d685a4c1bef8e Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Wed, 24 Jun 2015 12:39:46 +0100 Subject: [PATCH 02/30] Add initial code --- CONTRIBUTORS | 0 main.go | 88 +++++++++ protocol/packets.go | 45 +++++ protocol/packets_xdr.go | 415 ++++++++++++++++++++++++++++++++++++++++ protocol_listener.go | 230 ++++++++++++++++++++++ session.go | 173 +++++++++++++++++ session_listener.go | 59 ++++++ utils.go | 53 +++++ 8 files changed, 1063 insertions(+) create mode 100644 CONTRIBUTORS create mode 100644 main.go create mode 100644 protocol/packets.go create mode 100644 protocol/packets_xdr.go create mode 100644 protocol_listener.go create mode 100644 session.go create mode 100644 session_listener.go create mode 100644 utils.go diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 00000000..e69de29b diff --git a/main.go b/main.go new file mode 100644 index 00000000..3c4d533e --- /dev/null +++ b/main.go @@ -0,0 +1,88 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/tls" + "flag" + "log" + "os" + "path/filepath" + "sync" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +var ( + listenProtocol string + listenSession string + debug bool + + sessionAddress []byte + sessionPort uint16 + + networkTimeout time.Duration + pingInterval time.Duration + messageTimeout time.Duration + + pingMessage message + + mut = sync.RWMutex{} + outbox = make(map[syncthingprotocol.DeviceID]chan message) +) + +func main() { + var dir, extAddress string + + pingPayload := protocol.Ping{}.MustMarshalXDR() + pingMessage = message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypePing, + MessageLength: int32(len(pingPayload)), + }, + payload: pingPayload, + } + + flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") + flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") + flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") + flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") + flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") + flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") + flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + + flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.Parse() + + certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } + + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{protocol.ProtocolName}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } + + log.SetOutput(os.Stdout) + + go sessionListener(listenSession) + + protocolListener(listenProtocol, tlsCfg) +} diff --git a/protocol/packets.go b/protocol/packets.go new file mode 100644 index 00000000..4675d1cf --- /dev/null +++ b/protocol/packets.go @@ -0,0 +1,45 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +//go:generate -command genxdr go run ../../syncthing/Godeps/_workspace/src/github.com/calmh/xdr/cmd/genxdr/main.go +//go:generate genxdr -o packets_xdr.go packets.go + +package protocol + +import ( + "unsafe" +) + +const ( + Magic = 0x9E79BC40 + HeaderSize = unsafe.Sizeof(&Header{}) + ProtocolName = "bep-relay" +) + +const ( + MessageTypePing int32 = iota + MessageTypePong + MessageTypeJoinRequest + MessageTypeConnectRequest + MessageTypeSessionInvitation +) + +type Header struct { + Magic uint32 + MessageType int32 + MessageLength int32 +} + +type Ping struct{} +type Pong struct{} +type JoinRequest struct{} + +type ConnectRequest struct { + ID []byte // max:32 +} + +type SessionInvitation struct { + Key []byte // max:32 + Address []byte // max:32 + Port uint16 + ServerSocket bool +} diff --git a/protocol/packets_xdr.go b/protocol/packets_xdr.go new file mode 100644 index 00000000..ca547e00 --- /dev/null +++ b/protocol/packets_xdr.go @@ -0,0 +1,415 @@ +// ************************************************************ +// This file is automatically generated by genxdr. Do not edit. +// ************************************************************ + +package protocol + +import ( + "bytes" + "io" + + "github.com/calmh/xdr" +) + +/* + +Header Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Magic | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Type | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Header { + unsigned int Magic; + int MessageType; + int MessageLength; +} + +*/ + +func (o Header) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Header) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Header) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Header) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Header) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(o.Magic) + xw.WriteUint32(uint32(o.MessageType)) + xw.WriteUint32(uint32(o.MessageLength)) + return xw.Tot(), xw.Error() +} + +func (o *Header) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Header) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Header) DecodeXDRFrom(xr *xdr.Reader) error { + o.Magic = xr.ReadUint32() + o.MessageType = int32(xr.ReadUint32()) + o.MessageLength = int32(xr.ReadUint32()) + return xr.Error() +} + +/* + +Ping Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Ping { +} + +*/ + +func (o Ping) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Ping) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Ping) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Ping) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Ping) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *Ping) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Ping) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Ping) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +Pong Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Pong { +} + +*/ + +func (o Pong) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Pong) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Pong) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Pong) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Pong) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *Pong) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Pong) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Pong) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +JoinRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct JoinRequest { +} + +*/ + +func (o JoinRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o JoinRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o JoinRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o JoinRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o JoinRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *JoinRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinRequest) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +ConnectRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ ID (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct ConnectRequest { + opaque ID<32>; +} + +*/ + +func (o ConnectRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o ConnectRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o ConnectRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o ConnectRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o ConnectRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.ID); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("ID", l, 32) + } + xw.WriteBytes(o.ID) + return xw.Tot(), xw.Error() +} + +func (o *ConnectRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *ConnectRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *ConnectRequest) DecodeXDRFrom(xr *xdr.Reader) error { + o.ID = xr.ReadBytesMax(32) + return xr.Error() +} + +/* + +SessionInvitation Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Key | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Key (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Address | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Address (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| 0x0000 | Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Server Socket (V=0 or 1) |V| ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct SessionInvitation { + opaque Key<32>; + opaque Address<32>; + unsigned int Port; + bool ServerSocket; +} + +*/ + +func (o SessionInvitation) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o SessionInvitation) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o SessionInvitation) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o SessionInvitation) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.Key); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) + } + xw.WriteBytes(o.Key) + if l := len(o.Address); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Address", l, 32) + } + xw.WriteBytes(o.Address) + xw.WriteUint16(o.Port) + xw.WriteBool(o.ServerSocket) + return xw.Tot(), xw.Error() +} + +func (o *SessionInvitation) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *SessionInvitation) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error { + o.Key = xr.ReadBytesMax(32) + o.Address = xr.ReadBytesMax(32) + o.Port = xr.ReadUint16() + o.ServerSocket = xr.ReadBool() + return xr.Error() +} diff --git a/protocol_listener.go b/protocol_listener.go new file mode 100644 index 00000000..b6d89b22 --- /dev/null +++ b/protocol_listener.go @@ -0,0 +1,230 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/tls" + "io" + "log" + "net" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + + "github.com/syncthing/relaysrv/protocol" +) + +type message struct { + header protocol.Header + payload []byte +} + +func protocolListener(addr string, config *tls.Config) { + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalln(err) + } + + for { + conn, err := listener.Accept() + if err != nil { + if debug { + log.Println(err) + } + continue + } + + if debug { + log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) + } + + go protocolConnectionHandler(conn, config) + } +} + +func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { + err := setTCPOptions(tcpConn) + if err != nil && debug { + log.Println("Failed to set TCP options on protocol connection", tcpConn.RemoteAddr(), err) + } + + conn := tls.Server(tcpConn, config) + err = conn.Handshake() + if err != nil { + log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + conn.Close() + return + } + + state := conn.ConnectionState() + if (!state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol != protocol.ProtocolName) && debug { + log.Println("Protocol negotiation error") + } + + certs := state.PeerCertificates + if len(certs) != 1 { + log.Println("Certificate list error") + conn.Close() + return + } + + deviceId := syncthingprotocol.NewDeviceID(certs[0].Raw) + + mut.RLock() + _, ok := outbox[deviceId] + mut.RUnlock() + if ok { + log.Println("Already have a peer with the same ID", deviceId, conn.RemoteAddr()) + conn.Close() + return + } + + errorChannel := make(chan error) + messageChannel := make(chan message) + outboxChannel := make(chan message) + + go readerLoop(conn, messageChannel, errorChannel) + + pingTicker := time.NewTicker(pingInterval) + timeoutTicker := time.NewTimer(messageTimeout * 2) + joined := false + + for { + select { + case msg := <-messageChannel: + switch msg.header.MessageType { + case protocol.MessageTypeJoinRequest: + mut.Lock() + outbox[deviceId] = outboxChannel + mut.Unlock() + joined = true + case protocol.MessageTypeConnectRequest: + // We will disconnect after this message, no matter what, + // because, we've either sent out an invitation, or we don't + // have the peer available. + var fmsg protocol.ConnectRequest + err := fmsg.UnmarshalXDR(msg.payload) + if err != nil { + log.Println(err) + conn.Close() + continue + } + + requestedPeer := syncthingprotocol.DeviceIDFromBytes(fmsg.ID) + mut.RLock() + peerOutbox, ok := outbox[requestedPeer] + mut.RUnlock() + if !ok { + if debug { + log.Println("Do not have", requestedPeer) + } + conn.Close() + continue + } + + ses := newSession() + + smsg, err := ses.GetServerInvitationMessage() + if err != nil { + log.Println("Error getting server invitation", requestedPeer) + conn.Close() + continue + } + cmsg, err := ses.GetClientInvitationMessage() + if err != nil { + log.Println("Error getting client invitation", requestedPeer) + conn.Close() + continue + } + + go ses.Serve() + + if err := sendMessage(cmsg, conn); err != nil { + log.Println("Failed to send invitation message", err) + } else { + peerOutbox <- smsg + if debug { + log.Println("Sent invitation from", deviceId, "to", requestedPeer) + } + } + conn.Close() + case protocol.MessageTypePong: + timeoutTicker.Reset(messageTimeout) + } + case err := <-errorChannel: + log.Println("Closing connection:", err) + return + case <-pingTicker.C: + if !joined { + log.Println(deviceId, "didn't join within", messageTimeout) + conn.Close() + continue + } + + if err := sendMessage(pingMessage, conn); err != nil { + log.Println(err) + conn.Close() + continue + } + case <-timeoutTicker.C: + // We should receive a error, which will cause us to quit the + // loop. + conn.Close() + case msg := <-outboxChannel: + if debug { + log.Println("Sending message to", deviceId, msg) + } + if err := sendMessage(msg, conn); err == nil { + log.Println(err) + conn.Close() + continue + } + } + } +} + +func readerLoop(conn *tls.Conn, messages chan<- message, errors chan<- error) { + header := make([]byte, protocol.HeaderSize) + data := make([]byte, 0, 0) + for { + _, err := io.ReadFull(conn, header) + if err != nil { + errors <- err + conn.Close() + return + } + + var hdr protocol.Header + err = hdr.UnmarshalXDR(header) + if err != nil { + conn.Close() + return + } + + if hdr.Magic != protocol.Magic { + conn.Close() + return + } + + if hdr.MessageLength > int32(cap(data)) { + data = make([]byte, 0, hdr.MessageLength) + } else { + data = data[:hdr.MessageLength] + } + + _, err = io.ReadFull(conn, data) + if err != nil { + errors <- err + conn.Close() + return + } + + msg := message{ + header: hdr, + payload: make([]byte, hdr.MessageLength), + } + copy(msg.payload, data[:hdr.MessageLength]) + + messages <- msg + } +} diff --git a/session.go b/session.go new file mode 100644 index 00000000..3466bd53 --- /dev/null +++ b/session.go @@ -0,0 +1,173 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/rand" + "net" + "sync" + "time" + + "github.com/syncthing/relaysrv/protocol" +) + +var ( + sessionmut = sync.Mutex{} + sessions = make(map[string]*session, 0) +) + +type session struct { + serverkey string + clientkey string + + mut sync.RWMutex + conns chan net.Conn +} + +func newSession() *session { + serverkey := make([]byte, 32) + _, err := rand.Read(serverkey) + if err != nil { + return nil + } + + clientkey := make([]byte, 32) + _, err = rand.Read(clientkey) + if err != nil { + return nil + } + + return &session{ + serverkey: string(serverkey), + clientkey: string(clientkey), + conns: make(chan net.Conn), + } +} + +func findSession(key string) *session { + sessionmut.Lock() + defer sessionmut.Unlock() + lob, ok := sessions[key] + if !ok { + return nil + + } + delete(sessions, key) + return lob +} + +func (l *session) AddConnection(conn net.Conn) { + select { + case l.conns <- conn: + default: + } +} + +func (l *session) Serve() { + + timedout := time.After(messageTimeout) + + sessionmut.Lock() + sessions[l.serverkey] = l + sessions[l.clientkey] = l + sessionmut.Unlock() + + conns := make([]net.Conn, 0, 2) + for { + select { + case conn := <-l.conns: + conns = append(conns, conn) + if len(conns) < 2 { + continue + } + + close(l.conns) + + wg := sync.WaitGroup{} + + wg.Add(2) + + go proxy(conns[0], conns[1], wg) + go proxy(conns[1], conns[0], wg) + + wg.Wait() + + break + case <-timedout: + sessionmut.Lock() + delete(sessions, l.serverkey) + delete(sessions, l.clientkey) + sessionmut.Unlock() + + for _, conn := range conns { + conn.Close() + } + + break + } + } +} + +func (l *session) GetClientInvitationMessage() (message, error) { + invitation := protocol.SessionInvitation{ + Key: []byte(l.clientkey), + Address: nil, + Port: 123, + ServerSocket: false, + } + data, err := invitation.MarshalXDR() + if err != nil { + return message{}, err + } + + return message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypeSessionInvitation, + MessageLength: int32(len(data)), + }, + payload: data, + }, nil +} + +func (l *session) GetServerInvitationMessage() (message, error) { + invitation := protocol.SessionInvitation{ + Key: []byte(l.serverkey), + Address: nil, + Port: 123, + ServerSocket: true, + } + data, err := invitation.MarshalXDR() + if err != nil { + return message{}, err + } + + return message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypeSessionInvitation, + MessageLength: int32(len(data)), + }, + payload: data, + }, nil +} + +func proxy(c1, c2 net.Conn, wg sync.WaitGroup) { + for { + buf := make([]byte, 1024) + c1.SetReadDeadline(time.Now().Add(networkTimeout)) + n, err := c1.Read(buf) + if err != nil { + break + } + + c2.SetWriteDeadline(time.Now().Add(networkTimeout)) + _, err = c2.Write(buf[:n]) + if err != nil { + break + } + } + c1.Close() + c2.Close() + wg.Done() +} diff --git a/session_listener.go b/session_listener.go new file mode 100644 index 00000000..b78c4f4b --- /dev/null +++ b/session_listener.go @@ -0,0 +1,59 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "io" + "log" + "net" + "time" +) + +func sessionListener(addr string) { + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalln(err) + } + + for { + conn, err := listener.Accept() + if err != nil { + if debug { + log.Println(err) + } + continue + } + + if debug { + log.Println("Session listener accepted connection from", conn.RemoteAddr()) + } + + go sessionConnectionHandler(conn) + } +} + +func sessionConnectionHandler(conn net.Conn) { + conn.SetReadDeadline(time.Now().Add(messageTimeout)) + key := make([]byte, 32) + + _, err := io.ReadFull(conn, key) + if err != nil { + if debug { + log.Println("Failed to read key", err, conn.RemoteAddr()) + } + conn.Close() + return + } + + ses := findSession(string(key)) + if debug { + log.Println("Key", key, "by", conn.RemoteAddr(), "session", ses) + } + + if ses != nil { + ses.AddConnection(conn) + } else { + conn.Close() + return + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 00000000..5388ba32 --- /dev/null +++ b/utils.go @@ -0,0 +1,53 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "errors" + "net" + "time" +) + +func setTCPOptions(conn net.Conn) error { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return errors.New("Not a TCP connection") + } + if err := tcpConn.SetLinger(0); err != nil { + return err + } + if err := tcpConn.SetNoDelay(true); err != nil { + return err + } + if err := tcpConn.SetKeepAlivePeriod(60 * time.Second); err != nil { + return err + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return nil +} + +func sendMessage(msg message, conn net.Conn) error { + header, err := msg.header.MarshalXDR() + if err != nil { + return err + } + + err = conn.SetWriteDeadline(time.Now().Add(networkTimeout)) + if err != nil { + return err + } + + _, err = conn.Write(header) + if err != nil { + return err + } + + _, err = conn.Write(msg.payload) + if err != nil { + return err + } + + return nil +} From b72d31f87fa34fa77a0a5a73712ec019bdf2bf61 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Sun, 28 Jun 2015 01:52:01 +0100 Subject: [PATCH 03/30] Progress --- README.md | 6 + client/client.go | 249 ++++++++++++++++++++++++++++++++++++++++ client/debug.go | 15 +++ client/methods.go | 113 ++++++++++++++++++ main.go | 39 ++++--- protocol/packets.go | 42 +++---- protocol/packets_xdr.go | 216 ++++++++++++++++++++++++++++------ protocol/protocol.go | 114 ++++++++++++++++++ protocol_listener.go | 233 +++++++++++++++++-------------------- session.go | 179 ++++++++++++++++------------- session_listener.go | 58 +++++++--- testutil/main.go | 142 +++++++++++++++++++++++ utils.go | 27 +---- 13 files changed, 1114 insertions(+), 319 deletions(-) create mode 100644 README.md create mode 100644 client/client.go create mode 100644 client/debug.go create mode 100644 client/methods.go create mode 100644 protocol/protocol.go create mode 100644 testutil/main.go diff --git a/README.md b/README.md new file mode 100644 index 00000000..e8892928 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +relaysrv +======== + +This is the relay server for the `syncthing` project. + +`go get github.com/syncthing/relaysrv` diff --git a/client/client.go b/client/client.go new file mode 100644 index 00000000..b48320fd --- /dev/null +++ b/client/client.go @@ -0,0 +1,249 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/url" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient { + closeInvitationsOnFinish := false + if invitations == nil { + closeInvitationsOnFinish = true + invitations = make(chan protocol.SessionInvitation) + } + return ProtocolClient{ + URI: uri, + Invitations: invitations, + + closeInvitationsOnFinish: closeInvitationsOnFinish, + + config: configForCerts(certs), + + timeout: time.Minute * 2, + + stop: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +type ProtocolClient struct { + URI *url.URL + Invitations chan protocol.SessionInvitation + + closeInvitationsOnFinish bool + + config *tls.Config + + timeout time.Duration + + stop chan struct{} + stopped chan struct{} + + conn *tls.Conn +} + +func (c *ProtocolClient) connect() error { + conn, err := tls.Dial("tcp", c.URI.Host, c.config) + if err != nil { + return err + } + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + + if err := performHandshakeAndValidation(conn, c.URI); err != nil { + return err + } + + c.conn = conn + return nil +} + +func (c *ProtocolClient) Serve() { + if err := c.connect(); err != nil { + panic(err) + } + + if debug { + l.Debugln(c, "connected", c.conn.RemoteAddr()) + } + + if err := c.join(); err != nil { + c.conn.Close() + panic(err) + } + + c.conn.SetDeadline(time.Time{}) + + if debug { + l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) + } + + c.stop = make(chan struct{}) + c.stopped = make(chan struct{}) + + defer c.cleanup() + + messages := make(chan interface{}) + errors := make(chan error, 1) + + go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } + }(c.conn, messages, errors) + + timeout := time.NewTimer(c.timeout) + for { + select { + case message := <-messages: + timeout.Reset(c.timeout) + if debug { + log.Printf("%s received message %T", c, message) + } + switch msg := message.(type) { + case protocol.Ping: + if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { + panic(err) + } + if debug { + l.Debugln(c, "sent pong") + } + case protocol.SessionInvitation: + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + c.Invitations <- msg + default: + panic(fmt.Errorf("protocol error: unexpected message %v", msg)) + } + case <-c.stop: + if debug { + l.Debugln(c, "stopping") + } + break + case err := <-errors: + panic(err) + case <-timeout.C: + if debug { + l.Debugln(c, "timed out") + } + return + } + } + + c.stopped <- struct{}{} +} + +func (c *ProtocolClient) Stop() { + if c.stop == nil { + return + } + + c.stop <- struct{}{} + <-c.stopped +} + +func (c *ProtocolClient) String() string { + return fmt.Sprintf("ProtocolClient@%p", c) +} + +func (c *ProtocolClient) cleanup() { + if c.closeInvitationsOnFinish { + close(c.Invitations) + c.Invitations = make(chan protocol.SessionInvitation) + } + + if debug { + l.Debugln(c, "cleaning up") + } + + if c.stop != nil { + close(c.stop) + c.stop = nil + } + + if c.stopped != nil { + close(c.stopped) + c.stopped = nil + } + + if c.conn != nil { + c.conn.Close() + } +} + +func (c *ProtocolClient) join() error { + err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}) + if err != nil { + return err + } + + message, err := protocol.ReadMessage(c.conn) + if err != nil { + return err + } + + switch msg := message.(type) { + case protocol.Response: + if msg.Code != 0 { + return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + } + default: + return fmt.Errorf("protocol error: expecting response got %v", msg) + } + + return nil +} + +func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { + err := conn.Handshake() + if err != nil { + conn.Close() + return err + } + + cs := conn.ConnectionState() + if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName { + conn.Close() + return fmt.Errorf("protocol negotiation error") + } + + q := uri.Query() + relayIDs := q.Get("id") + if relayIDs != "" { + relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs) + if err != nil { + conn.Close() + return fmt.Errorf("relay address contains invalid verification id: %s", err) + } + + certs := cs.PeerCertificates + if cl := len(certs); cl != 1 { + conn.Close() + return fmt.Errorf("unexpected certificate count: %d", cl) + } + + remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw) + if remoteID != relayID { + conn.Close() + return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID) + } + } + + return nil +} diff --git a/client/debug.go b/client/debug.go new file mode 100644 index 00000000..4a3608de --- /dev/null +++ b/client/debug.go @@ -0,0 +1,15 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "os" + "strings" + + "github.com/calmh/logger" +) + +var ( + debug = strings.Contains(os.Getenv("STTRACE"), "relay") || os.Getenv("STTRACE") == "all" + l = logger.DefaultLogger +) diff --git a/client/methods.go b/client/methods.go new file mode 100644 index 00000000..1d457e29 --- /dev/null +++ b/client/methods.go @@ -0,0 +1,113 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "strconv" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate) (protocol.SessionInvitation, error) { + conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs)) + conn.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + return protocol.SessionInvitation{}, err + } + + if err := performHandshakeAndValidation(conn, uri); err != nil { + return protocol.SessionInvitation{}, err + } + + defer conn.Close() + + request := protocol.ConnectRequest{ + ID: id[:], + } + + if err := protocol.WriteMessage(conn, request); err != nil { + return protocol.SessionInvitation{}, err + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return protocol.SessionInvitation{}, err + } + + switch msg := message.(type) { + case protocol.Response: + return protocol.SessionInvitation{}, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + case protocol.SessionInvitation: + if debug { + l.Debugln("Received invitation via", conn.LocalAddr()) + } + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + return msg, nil + default: + return protocol.SessionInvitation{}, fmt.Errorf("protocol error: unexpected message %v", msg) + } +} + +func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { + addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port))) + + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + request := protocol.JoinSessionRequest{ + Key: invitation.Key, + } + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + err = protocol.WriteMessage(conn, request) + if err != nil { + return nil, err + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return nil, err + } + + conn.SetDeadline(time.Time{}) + + switch msg := message.(type) { + case protocol.Response: + if msg.Code != 0 { + return nil, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + } + return conn, nil + default: + return nil, fmt.Errorf("protocol error: expecting response got %v", msg) + } +} + +func configForCerts(certs []tls.Certificate) *tls.Config { + return &tls.Config{ + Certificates: certs, + NextProtos: []string{protocol.ProtocolName}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } +} diff --git a/main.go b/main.go index 3c4d533e..5ca06068 100644 --- a/main.go +++ b/main.go @@ -6,13 +6,13 @@ import ( "crypto/tls" "flag" "log" - "os" + "net" "path/filepath" - "sync" "time" - syncthingprotocol "github.com/syncthing/protocol" "github.com/syncthing/relaysrv/protocol" + + syncthingprotocol "github.com/syncthing/protocol" ) var ( @@ -26,26 +26,11 @@ var ( networkTimeout time.Duration pingInterval time.Duration messageTimeout time.Duration - - pingMessage message - - mut = sync.RWMutex{} - outbox = make(map[syncthingprotocol.DeviceID]chan message) ) func main() { var dir, extAddress string - pingPayload := protocol.Ping{}.MustMarshalXDR() - pingMessage = message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypePing, - MessageLength: int32(len(pingPayload)), - }, - payload: pingPayload, - } - flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") @@ -54,7 +39,20 @@ func main() { flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + if extAddress == "" { + extAddress = listenSession + } + + addr, err := net.ResolveTCPAddr("tcp", extAddress) + if err != nil { + log.Fatal(err) + } + + sessionAddress = addr.IP[:] + sessionPort = uint16(addr.Port) + flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.Parse() certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") @@ -80,7 +78,10 @@ func main() { }, } - log.SetOutput(os.Stdout) + id := syncthingprotocol.NewDeviceID(cert.Certificate[0]) + if debug { + log.Println("ID:", id) + } go sessionListener(listenSession) diff --git a/protocol/packets.go b/protocol/packets.go index 4675d1cf..658bc536 100644 --- a/protocol/packets.go +++ b/protocol/packets.go @@ -5,39 +5,41 @@ package protocol -import ( - "unsafe" -) - const ( - Magic = 0x9E79BC40 - HeaderSize = unsafe.Sizeof(&Header{}) - ProtocolName = "bep-relay" + messageTypePing int32 = iota + messageTypePong + messageTypeJoinRelayRequest + messageTypeJoinSessionRequest + messageTypeResponse + messageTypeConnectRequest + messageTypeSessionInvitation ) -const ( - MessageTypePing int32 = iota - MessageTypePong - MessageTypeJoinRequest - MessageTypeConnectRequest - MessageTypeSessionInvitation -) - -type Header struct { - Magic uint32 - MessageType int32 - MessageLength int32 +type header struct { + magic uint32 + messageType int32 + messageLength int32 } type Ping struct{} type Pong struct{} -type JoinRequest struct{} +type JoinRelayRequest struct{} + +type JoinSessionRequest struct { + Key []byte // max:32 +} + +type Response struct { + Code int32 + Message string +} type ConnectRequest struct { ID []byte // max:32 } type SessionInvitation struct { + From []byte // max:32 Key []byte // max:32 Address []byte // max:32 Port uint16 diff --git a/protocol/packets_xdr.go b/protocol/packets_xdr.go index ca547e00..f18e18c1 100644 --- a/protocol/packets_xdr.go +++ b/protocol/packets_xdr.go @@ -13,37 +13,37 @@ import ( /* -Header Structure: +header Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Magic | +| magic | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Message Type | +| message Type | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Message Length | +| message Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -struct Header { - unsigned int Magic; - int MessageType; - int MessageLength; +struct header { + unsigned int magic; + int messageType; + int messageLength; } */ -func (o Header) EncodeXDR(w io.Writer) (int, error) { +func (o header) EncodeXDR(w io.Writer) (int, error) { var xw = xdr.NewWriter(w) return o.EncodeXDRInto(xw) } -func (o Header) MarshalXDR() ([]byte, error) { +func (o header) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o Header) MustMarshalXDR() []byte { +func (o header) MustMarshalXDR() []byte { bs, err := o.MarshalXDR() if err != nil { panic(err) @@ -51,35 +51,35 @@ func (o Header) MustMarshalXDR() []byte { return bs } -func (o Header) AppendXDR(bs []byte) ([]byte, error) { +func (o header) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) _, err := o.EncodeXDRInto(xw) return []byte(aw), err } -func (o Header) EncodeXDRInto(xw *xdr.Writer) (int, error) { - xw.WriteUint32(o.Magic) - xw.WriteUint32(uint32(o.MessageType)) - xw.WriteUint32(uint32(o.MessageLength)) +func (o header) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(o.magic) + xw.WriteUint32(uint32(o.messageType)) + xw.WriteUint32(uint32(o.messageLength)) return xw.Tot(), xw.Error() } -func (o *Header) DecodeXDR(r io.Reader) error { +func (o *header) DecodeXDR(r io.Reader) error { xr := xdr.NewReader(r) return o.DecodeXDRFrom(xr) } -func (o *Header) UnmarshalXDR(bs []byte) error { +func (o *header) UnmarshalXDR(bs []byte) error { var br = bytes.NewReader(bs) var xr = xdr.NewReader(br) return o.DecodeXDRFrom(xr) } -func (o *Header) DecodeXDRFrom(xr *xdr.Reader) error { - o.Magic = xr.ReadUint32() - o.MessageType = int32(xr.ReadUint32()) - o.MessageLength = int32(xr.ReadUint32()) +func (o *header) DecodeXDRFrom(xr *xdr.Reader) error { + o.magic = xr.ReadUint32() + o.messageType = int32(xr.ReadUint32()) + o.messageLength = int32(xr.ReadUint32()) return xr.Error() } @@ -199,28 +199,28 @@ func (o *Pong) DecodeXDRFrom(xr *xdr.Reader) error { /* -JoinRequest Structure: +JoinRelayRequest Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -struct JoinRequest { +struct JoinRelayRequest { } */ -func (o JoinRequest) EncodeXDR(w io.Writer) (int, error) { +func (o JoinRelayRequest) EncodeXDR(w io.Writer) (int, error) { var xw = xdr.NewWriter(w) return o.EncodeXDRInto(xw) } -func (o JoinRequest) MarshalXDR() ([]byte, error) { +func (o JoinRelayRequest) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o JoinRequest) MustMarshalXDR() []byte { +func (o JoinRelayRequest) MustMarshalXDR() []byte { bs, err := o.MarshalXDR() if err != nil { panic(err) @@ -228,29 +228,169 @@ func (o JoinRequest) MustMarshalXDR() []byte { return bs } -func (o JoinRequest) AppendXDR(bs []byte) ([]byte, error) { +func (o JoinRelayRequest) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) _, err := o.EncodeXDRInto(xw) return []byte(aw), err } -func (o JoinRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { +func (o JoinRelayRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { return xw.Tot(), xw.Error() } -func (o *JoinRequest) DecodeXDR(r io.Reader) error { +func (o *JoinRelayRequest) DecodeXDR(r io.Reader) error { xr := xdr.NewReader(r) return o.DecodeXDRFrom(xr) } -func (o *JoinRequest) UnmarshalXDR(bs []byte) error { +func (o *JoinRelayRequest) UnmarshalXDR(bs []byte) error { var br = bytes.NewReader(bs) var xr = xdr.NewReader(br) return o.DecodeXDRFrom(xr) } -func (o *JoinRequest) DecodeXDRFrom(xr *xdr.Reader) error { +func (o *JoinRelayRequest) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +JoinSessionRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Key | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Key (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct JoinSessionRequest { + opaque Key<32>; +} + +*/ + +func (o JoinSessionRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o JoinSessionRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o JoinSessionRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o JoinSessionRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o JoinSessionRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.Key); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) + } + xw.WriteBytes(o.Key) + return xw.Tot(), xw.Error() +} + +func (o *JoinSessionRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinSessionRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinSessionRequest) DecodeXDRFrom(xr *xdr.Reader) error { + o.Key = xr.ReadBytesMax(32) + return xr.Error() +} + +/* + +Response Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Code | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Message | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Message (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Response { + int Code; + string Message<>; +} + +*/ + +func (o Response) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Response) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Response) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Response) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Response) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(uint32(o.Code)) + xw.WriteString(o.Message) + return xw.Tot(), xw.Error() +} + +func (o *Response) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Response) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Response) DecodeXDRFrom(xr *xdr.Reader) error { + o.Code = int32(xr.ReadUint32()) + o.Message = xr.ReadString() return xr.Error() } @@ -330,6 +470,12 @@ SessionInvitation Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of From | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ From (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Length of Key | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / @@ -349,6 +495,7 @@ SessionInvitation Structure: struct SessionInvitation { + opaque From<32>; opaque Key<32>; opaque Address<32>; unsigned int Port; @@ -382,6 +529,10 @@ func (o SessionInvitation) AppendXDR(bs []byte) ([]byte, error) { } func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.From); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("From", l, 32) + } + xw.WriteBytes(o.From) if l := len(o.Key); l > 32 { return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) } @@ -407,6 +558,7 @@ func (o *SessionInvitation) UnmarshalXDR(bs []byte) error { } func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error { + o.From = xr.ReadBytesMax(32) o.Key = xr.ReadBytesMax(32) o.Address = xr.ReadBytesMax(32) o.Port = xr.ReadUint16() diff --git a/protocol/protocol.go b/protocol/protocol.go new file mode 100644 index 00000000..57a967ac --- /dev/null +++ b/protocol/protocol.go @@ -0,0 +1,114 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package protocol + +import ( + "fmt" + "io" +) + +const ( + magic = 0x9E79BC40 + ProtocolName = "bep-relay" +) + +var ( + ResponseSuccess = Response{0, "success"} + ResponseNotFound = Response{1, "not found"} + ResponseAlreadyConnected = Response{2, "already connected"} + ResponseInternalError = Response{99, "internal error"} + ResponseUnexpectedMessage = Response{100, "unexpected message"} +) + +func WriteMessage(w io.Writer, message interface{}) error { + header := header{ + magic: magic, + } + + var payload []byte + var err error + + switch msg := message.(type) { + case Ping: + payload, err = msg.MarshalXDR() + header.messageType = messageTypePing + case Pong: + payload, err = msg.MarshalXDR() + header.messageType = messageTypePong + case JoinRelayRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeJoinRelayRequest + case JoinSessionRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeJoinSessionRequest + case Response: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeResponse + case ConnectRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeConnectRequest + case SessionInvitation: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeSessionInvitation + default: + err = fmt.Errorf("Unknown message type") + } + + if err != nil { + return err + } + + header.messageLength = int32(len(payload)) + + headerpayload, err := header.MarshalXDR() + if err != nil { + return err + } + + _, err = w.Write(append(headerpayload, payload...)) + return err +} + +func ReadMessage(r io.Reader) (interface{}, error) { + var header header + if err := header.DecodeXDR(r); err != nil { + return nil, err + } + + if header.magic != magic { + return nil, fmt.Errorf("magic mismatch") + } + + switch header.messageType { + case messageTypePing: + var msg Ping + err := msg.DecodeXDR(r) + return msg, err + case messageTypePong: + var msg Pong + err := msg.DecodeXDR(r) + return msg, err + case messageTypeJoinRelayRequest: + var msg JoinRelayRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeJoinSessionRequest: + var msg JoinSessionRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeResponse: + var msg Response + err := msg.DecodeXDR(r) + return msg, err + case messageTypeConnectRequest: + var msg ConnectRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeSessionInvitation: + var msg SessionInvitation + err := msg.DecodeXDR(r) + return msg, err + } + + return nil, fmt.Errorf("Unknown message type") +} diff --git a/protocol_listener.go b/protocol_listener.go index b6d89b22..1e18b156 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -4,9 +4,9 @@ package main import ( "crypto/tls" - "io" "log" "net" + "sync" "time" syncthingprotocol "github.com/syncthing/protocol" @@ -14,10 +14,10 @@ import ( "github.com/syncthing/relaysrv/protocol" ) -type message struct { - header protocol.Header - payload []byte -} +var ( + outboxesMut = sync.RWMutex{} + outboxes = make(map[syncthingprotocol.DeviceID]chan interface{}) +) func protocolListener(addr string, config *tls.Config) { listener, err := net.Listen("tcp", addr) @@ -27,6 +27,7 @@ func protocolListener(addr string, config *tls.Config) { for { conn, err := listener.Accept() + setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -43,15 +44,12 @@ func protocolListener(addr string, config *tls.Config) { } func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { - err := setTCPOptions(tcpConn) - if err != nil && debug { - log.Println("Failed to set TCP options on protocol connection", tcpConn.RemoteAddr(), err) - } - conn := tls.Server(tcpConn, config) - err = conn.Handshake() + err := conn.Handshake() if err != nil { - log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + if debug { + log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + } conn.Close() return } @@ -63,168 +61,147 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { certs := state.PeerCertificates if len(certs) != 1 { - log.Println("Certificate list error") + if debug { + log.Println("Certificate list error") + } conn.Close() return } - deviceId := syncthingprotocol.NewDeviceID(certs[0].Raw) + id := syncthingprotocol.NewDeviceID(certs[0].Raw) - mut.RLock() - _, ok := outbox[deviceId] - mut.RUnlock() - if ok { - log.Println("Already have a peer with the same ID", deviceId, conn.RemoteAddr()) - conn.Close() - return - } + messages := make(chan interface{}) + errors := make(chan error, 1) + outbox := make(chan interface{}) - errorChannel := make(chan error) - messageChannel := make(chan message) - outboxChannel := make(chan message) - - go readerLoop(conn, messageChannel, errorChannel) + go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } + }(conn, messages, errors) pingTicker := time.NewTicker(pingInterval) - timeoutTicker := time.NewTimer(messageTimeout * 2) + timeoutTicker := time.NewTimer(networkTimeout) joined := false for { select { - case msg := <-messageChannel: - switch msg.header.MessageType { - case protocol.MessageTypeJoinRequest: - mut.Lock() - outbox[deviceId] = outboxChannel - mut.Unlock() - joined = true - case protocol.MessageTypeConnectRequest: - // We will disconnect after this message, no matter what, - // because, we've either sent out an invitation, or we don't - // have the peer available. - var fmsg protocol.ConnectRequest - err := fmsg.UnmarshalXDR(msg.payload) - if err != nil { - log.Println(err) + case message := <-messages: + timeoutTicker.Reset(networkTimeout) + if debug { + log.Printf("Message %T from %s", message, id) + } + switch msg := message.(type) { + case protocol.JoinRelayRequest: + outboxesMut.RLock() + _, ok := outboxes[id] + outboxesMut.RUnlock() + if ok { + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + if debug { + log.Println("Already have a peer with the same ID", id, conn.RemoteAddr()) + } conn.Close() continue } - requestedPeer := syncthingprotocol.DeviceIDFromBytes(fmsg.ID) - mut.RLock() - peerOutbox, ok := outbox[requestedPeer] - mut.RUnlock() + outboxesMut.Lock() + outboxes[id] = outbox + outboxesMut.Unlock() + joined = true + + protocol.WriteMessage(conn, protocol.ResponseSuccess) + case protocol.ConnectRequest: + requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) + outboxesMut.RLock() + peerOutbox, ok := outboxes[requestedPeer] + outboxesMut.RUnlock() if !ok { if debug { - log.Println("Do not have", requestedPeer) + log.Println(id, "is looking", requestedPeer, "which does not exist") } + protocol.WriteMessage(conn, protocol.ResponseNotFound) conn.Close() continue } ses := newSession() - smsg, err := ses.GetServerInvitationMessage() - if err != nil { - log.Println("Error getting server invitation", requestedPeer) - conn.Close() - continue - } - cmsg, err := ses.GetClientInvitationMessage() - if err != nil { - log.Println("Error getting client invitation", requestedPeer) - conn.Close() - continue - } - go ses.Serve() - if err := sendMessage(cmsg, conn); err != nil { - log.Println("Failed to send invitation message", err) - } else { - peerOutbox <- smsg + clientInvitation := ses.GetClientInvitationMessage(requestedPeer) + serverInvitation := ses.GetServerInvitationMessage(id) + + if err := protocol.WriteMessage(conn, clientInvitation); err != nil { if debug { - log.Println("Sent invitation from", deviceId, "to", requestedPeer) + log.Printf("Error sending invitation from %s to client: %s", id, err) } + conn.Close() + continue + } + + peerOutbox <- serverInvitation + + if debug { + log.Println("Sent invitation from", id, "to", requestedPeer) } conn.Close() - case protocol.MessageTypePong: - timeoutTicker.Reset(messageTimeout) + case protocol.Pong: + default: + if debug { + log.Printf("Unknown message %s: %T", id, message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) + conn.Close() } - case err := <-errorChannel: - log.Println("Closing connection:", err) + case err := <-errors: + if debug { + log.Printf("Closing connection %s: %s", id, err) + } + // Potentially closing a second time. + close(outbox) + conn.Close() + outboxesMut.Lock() + delete(outboxes, id) + outboxesMut.Unlock() return case <-pingTicker.C: if !joined { - log.Println(deviceId, "didn't join within", messageTimeout) + if debug { + log.Println(id, "didn't join within", pingInterval) + } conn.Close() continue } - if err := sendMessage(pingMessage, conn); err != nil { - log.Println(err) + if err := protocol.WriteMessage(conn, protocol.Ping{}); err != nil { + if debug { + log.Println(id, err) + } conn.Close() - continue } case <-timeoutTicker.C: - // We should receive a error, which will cause us to quit the - // loop. - conn.Close() - case msg := <-outboxChannel: + // We should receive a error from the reader loop, which will cause + // us to quit this loop. if debug { - log.Println("Sending message to", deviceId, msg) + log.Printf("%s timed out", id) } - if err := sendMessage(msg, conn); err == nil { - log.Println(err) + conn.Close() + case msg := <-outbox: + if debug { + log.Printf("Sending message %T to %s", msg, id) + } + if err := protocol.WriteMessage(conn, msg); err != nil { + if debug { + log.Println(id, err) + } conn.Close() - continue } } } } - -func readerLoop(conn *tls.Conn, messages chan<- message, errors chan<- error) { - header := make([]byte, protocol.HeaderSize) - data := make([]byte, 0, 0) - for { - _, err := io.ReadFull(conn, header) - if err != nil { - errors <- err - conn.Close() - return - } - - var hdr protocol.Header - err = hdr.UnmarshalXDR(header) - if err != nil { - conn.Close() - return - } - - if hdr.Magic != protocol.Magic { - conn.Close() - return - } - - if hdr.MessageLength > int32(cap(data)) { - data = make([]byte, 0, hdr.MessageLength) - } else { - data = data[:hdr.MessageLength] - } - - _, err = io.ReadFull(conn, data) - if err != nil { - errors <- err - conn.Close() - return - } - - msg := message{ - header: hdr, - payload: make([]byte, hdr.MessageLength), - } - copy(msg.payload, data[:hdr.MessageLength]) - - messages <- msg - } -} diff --git a/session.go b/session.go index 3466bd53..c5a09195 100644 --- a/session.go +++ b/session.go @@ -4,23 +4,27 @@ package main import ( "crypto/rand" + "encoding/hex" + "fmt" + "log" "net" "sync" "time" "github.com/syncthing/relaysrv/protocol" + + syncthingprotocol "github.com/syncthing/protocol" ) var ( - sessionmut = sync.Mutex{} + sessionMut = sync.Mutex{} sessions = make(map[string]*session, 0) ) type session struct { - serverkey string - clientkey string + serverkey []byte + clientkey []byte - mut sync.RWMutex conns chan net.Conn } @@ -37,16 +41,27 @@ func newSession() *session { return nil } - return &session{ - serverkey: string(serverkey), - clientkey: string(clientkey), + ses := &session{ + serverkey: serverkey, + clientkey: clientkey, conns: make(chan net.Conn), } + + if debug { + log.Println("New session", ses) + } + + sessionMut.Lock() + sessions[string(ses.serverkey)] = ses + sessions[string(ses.clientkey)] = ses + sessionMut.Unlock() + + return ses } func findSession(key string) *session { - sessionmut.Lock() - defer sessionmut.Unlock() + sessionMut.Lock() + defer sessionMut.Unlock() lob, ok := sessions[key] if !ok { return nil @@ -56,118 +71,128 @@ func findSession(key string) *session { return lob } -func (l *session) AddConnection(conn net.Conn) { +func (s *session) AddConnection(conn net.Conn) bool { + if debug { + log.Println("New connection for", s, "from", conn.RemoteAddr()) + } + select { - case l.conns <- conn: + case s.conns <- conn: + return true default: } + return false } -func (l *session) Serve() { - +func (s *session) Serve() { timedout := time.After(messageTimeout) - sessionmut.Lock() - sessions[l.serverkey] = l - sessions[l.clientkey] = l - sessionmut.Unlock() + if debug { + log.Println("Session", s, "serving") + } conns := make([]net.Conn, 0, 2) for { select { - case conn := <-l.conns: + case conn := <-s.conns: conns = append(conns, conn) if len(conns) < 2 { continue } - close(l.conns) + close(s.conns) + + if debug { + log.Println("Session", s, "starting between", conns[0].RemoteAddr(), conns[1].RemoteAddr()) + } wg := sync.WaitGroup{} - wg.Add(2) - go proxy(conns[0], conns[1], wg) - go proxy(conns[1], conns[0], wg) + errors := make(chan error, 2) + + go func() { + errors <- proxy(conns[0], conns[1]) + wg.Done() + }() + + go func() { + errors <- proxy(conns[1], conns[0]) + wg.Done() + }() wg.Wait() - break - case <-timedout: - sessionmut.Lock() - delete(sessions, l.serverkey) - delete(sessions, l.clientkey) - sessionmut.Unlock() - - for _, conn := range conns { - conn.Close() + if debug { + log.Println("Session", s, "ended, outcomes:", <-errors, <-errors) } - - break + goto done + case <-timedout: + if debug { + log.Println("Session", s, "timed out") + } + goto done } } +done: + sessionMut.Lock() + delete(sessions, string(s.serverkey)) + delete(sessions, string(s.clientkey)) + sessionMut.Unlock() + + for _, conn := range conns { + conn.Close() + } + + if debug { + log.Println("Session", s, "stopping") + } } -func (l *session) GetClientInvitationMessage() (message, error) { - invitation := protocol.SessionInvitation{ - Key: []byte(l.clientkey), - Address: nil, - Port: 123, +func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { + return protocol.SessionInvitation{ + From: from[:], + Key: []byte(s.clientkey), + Address: sessionAddress, + Port: sessionPort, ServerSocket: false, } - data, err := invitation.MarshalXDR() - if err != nil { - return message{}, err - } - - return message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypeSessionInvitation, - MessageLength: int32(len(data)), - }, - payload: data, - }, nil } -func (l *session) GetServerInvitationMessage() (message, error) { - invitation := protocol.SessionInvitation{ - Key: []byte(l.serverkey), - Address: nil, - Port: 123, +func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { + return protocol.SessionInvitation{ + From: from[:], + Key: []byte(s.serverkey), + Address: sessionAddress, + Port: sessionPort, ServerSocket: true, } - data, err := invitation.MarshalXDR() - if err != nil { - return message{}, err - } - - return message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypeSessionInvitation, - MessageLength: int32(len(data)), - }, - payload: data, - }, nil } -func proxy(c1, c2 net.Conn, wg sync.WaitGroup) { +func proxy(c1, c2 net.Conn) error { + if debug { + log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) + } + buf := make([]byte, 1024) for { - buf := make([]byte, 1024) c1.SetReadDeadline(time.Now().Add(networkTimeout)) - n, err := c1.Read(buf) + n, err := c1.Read(buf[0:]) if err != nil { - break + return err + } + + if debug { + log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) } c2.SetWriteDeadline(time.Now().Add(networkTimeout)) _, err = c2.Write(buf[:n]) if err != nil { - break + return err } } - c1.Close() - c2.Close() - wg.Done() +} + +func (s *session) String() string { + return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5]) } diff --git a/session_listener.go b/session_listener.go index b78c4f4b..6159ceef 100644 --- a/session_listener.go +++ b/session_listener.go @@ -3,10 +3,11 @@ package main import ( - "io" "log" "net" "time" + + "github.com/syncthing/relaysrv/protocol" ) func sessionListener(addr string) { @@ -17,6 +18,7 @@ func sessionListener(addr string) { for { conn, err := listener.Accept() + setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -33,27 +35,49 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - conn.SetReadDeadline(time.Now().Add(messageTimeout)) - key := make([]byte, 32) - - _, err := io.ReadFull(conn, key) + conn.SetDeadline(time.Now().Add(messageTimeout)) + message, err := protocol.ReadMessage(conn) if err != nil { + conn.Close() + return + } + + switch msg := message.(type) { + case protocol.JoinSessionRequest: + ses := findSession(string(msg.Key)) if debug { - log.Println("Failed to read key", err, conn.RemoteAddr()) + log.Println(conn.RemoteAddr(), "session lookup", ses) } - conn.Close() - return - } - ses := findSession(string(key)) - if debug { - log.Println("Key", key, "by", conn.RemoteAddr(), "session", ses) - } + if ses == nil { + protocol.WriteMessage(conn, protocol.ResponseNotFound) + conn.Close() + return + } - if ses != nil { - ses.AddConnection(conn) - } else { + if !ses.AddConnection(conn) { + if debug { + log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) + } + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + conn.Close() + return + } + + err := protocol.WriteMessage(conn, protocol.ResponseSuccess) + if err != nil { + if debug { + log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) + } + conn.Close() + return + } + conn.SetDeadline(time.Time{}) + default: + if debug { + log.Println("Unexpected message from", conn.RemoteAddr(), message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) conn.Close() - return } } diff --git a/testutil/main.go b/testutil/main.go new file mode 100644 index 00000000..10c22245 --- /dev/null +++ b/testutil/main.go @@ -0,0 +1,142 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "bufio" + "crypto/tls" + "flag" + "log" + "net" + "net/url" + "os" + "path/filepath" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/client" + "github.com/syncthing/relaysrv/protocol" +) + +func main() { + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + var connect, relay, dir string + var join bool + + flag.StringVar(&connect, "connect", "", "Device ID to which to connect to") + flag.BoolVar(&join, "join", false, "Join relay") + flag.StringVar(&relay, "relay", "relay://127.0.0.1:22067", "Relay address") + flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") + + flag.Parse() + + certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } + + id := syncthingprotocol.NewDeviceID(cert.Certificate[0]) + log.Println("ID:", id) + + uri, err := url.Parse(relay) + if err != nil { + log.Fatal(err) + } + + stdin := make(chan string) + + go stdinReader(stdin) + + if join { + log.Printf("Creating client") + relay := client.NewProtocolClient(uri, []tls.Certificate{cert}, nil) + log.Printf("Created client") + + go relay.Serve() + + recv := make(chan protocol.SessionInvitation) + + go func() { + log.Println("Starting invitation receiver") + for invite := range relay.Invitations { + select { + case recv <- invite: + log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + default: + log.Printf("Discarding invitation", invite) + } + } + }() + + for { + conn, err := client.JoinSession(<-recv) + if err != nil { + log.Fatalln("Failed to join", err) + } + log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr()) + connectToStdio(stdin, conn) + log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) + } + } else if connect != "" { + id, err := syncthingprotocol.DeviceIDFromString(connect) + if err != nil { + log.Fatal(err) + } + + invite, err := client.GetInvitationFromRelay(uri, id, []tls.Certificate{cert}) + if err != nil { + log.Fatal(err) + } + + log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + conn, err := client.JoinSession(invite) + if err != nil { + log.Fatalln("Failed to join", err) + } + log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr()) + connectToStdio(stdin, conn) + log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) + } else { + log.Fatal("Requires either join or connect") + } +} + +func stdinReader(c chan<- string) { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + c <- scanner.Text() + c <- "\n" + } +} + +func connectToStdio(stdin <-chan string, conn net.Conn) { + go func() { + + }() + + buf := make([]byte, 1024) + for { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + n, err := conn.Read(buf[0:]) + if err != nil { + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + log.Println(err) + return + } + } + os.Stdout.Write(buf[:n]) + + select { + case msg := <-stdin: + _, err := conn.Write([]byte(msg)) + if err != nil { + return + } + default: + } + } +} diff --git a/utils.go b/utils.go index 5388ba32..7d1f6bfa 100644 --- a/utils.go +++ b/utils.go @@ -5,7 +5,6 @@ package main import ( "errors" "net" - "time" ) func setTCPOptions(conn net.Conn) error { @@ -19,7 +18,7 @@ func setTCPOptions(conn net.Conn) error { if err := tcpConn.SetNoDelay(true); err != nil { return err } - if err := tcpConn.SetKeepAlivePeriod(60 * time.Second); err != nil { + if err := tcpConn.SetKeepAlivePeriod(networkTimeout); err != nil { return err } if err := tcpConn.SetKeepAlive(true); err != nil { @@ -27,27 +26,3 @@ func setTCPOptions(conn net.Conn) error { } return nil } - -func sendMessage(msg message, conn net.Conn) error { - header, err := msg.header.MarshalXDR() - if err != nil { - return err - } - - err = conn.SetWriteDeadline(time.Now().Add(networkTimeout)) - if err != nil { - return err - } - - _, err = conn.Write(header) - if err != nil { - return err - } - - _, err = conn.Write(msg.payload) - if err != nil { - return err - } - - return nil -} From c68c78d4120ecde3eb24f4431dcf2bab820dcb5d Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Sun, 28 Jun 2015 20:34:28 +0100 Subject: [PATCH 04/30] Do scheme validation in the client --- client/client.go | 4 ++++ client/methods.go | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/client/client.go b/client/client.go index b48320fd..d05944ac 100644 --- a/client/client.go +++ b/client/client.go @@ -52,6 +52,10 @@ type ProtocolClient struct { } func (c *ProtocolClient) connect() error { + if c.URI.Scheme != "relay" { + return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme) + } + conn, err := tls.Dial("tcp", c.URI.Host, c.config) if err != nil { return err diff --git a/client/methods.go b/client/methods.go index 1d457e29..c9b7a265 100644 --- a/client/methods.go +++ b/client/methods.go @@ -15,6 +15,10 @@ import ( ) func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate) (protocol.SessionInvitation, error) { + if uri.Scheme != "relay" { + return protocol.SessionInvitation{}, fmt.Errorf("Unsupported relay schema:", uri.Scheme) + } + conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs)) conn.SetDeadline(time.Now().Add(10 * time.Second)) if err != nil { From e1959afb6beca5c15575cf1573e9c0809b93eb49 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Sun, 28 Jun 2015 21:18:38 +0100 Subject: [PATCH 05/30] Change EOL --- client/debug.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/client/debug.go b/client/debug.go index 4a3608de..935e9fe6 100644 --- a/client/debug.go +++ b/client/debug.go @@ -1,15 +1,15 @@ -// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). - -package client - -import ( - "os" - "strings" - - "github.com/calmh/logger" -) - -var ( - debug = strings.Contains(os.Getenv("STTRACE"), "relay") || os.Getenv("STTRACE") == "all" - l = logger.DefaultLogger -) +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "os" + "strings" + + "github.com/calmh/logger" +) + +var ( + debug = strings.Contains(os.Getenv("STTRACE"), "relay") || os.Getenv("STTRACE") == "all" + l = logger.DefaultLogger +) From 2505f82ce5f923480dfe9a63d16a5a87a9fa9753 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Fri, 17 Jul 2015 20:17:49 +0100 Subject: [PATCH 06/30] General cleanup --- client/methods.go | 4 ++-- protocol/packets.go | 14 ++++++++++++++ testutil/main.go | 10 +++++----- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/client/methods.go b/client/methods.go index c9b7a265..ef6145e9 100644 --- a/client/methods.go +++ b/client/methods.go @@ -16,7 +16,7 @@ import ( func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate) (protocol.SessionInvitation, error) { if uri.Scheme != "relay" { - return protocol.SessionInvitation{}, fmt.Errorf("Unsupported relay schema:", uri.Scheme) + return protocol.SessionInvitation{}, fmt.Errorf("Unsupported relay scheme:", uri.Scheme) } conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs)) @@ -49,7 +49,7 @@ func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs [ return protocol.SessionInvitation{}, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) case protocol.SessionInvitation: if debug { - l.Debugln("Received invitation via", conn.LocalAddr()) + l.Debugln("Received invitation", msg, "via", conn.LocalAddr()) } ip := net.IP(msg.Address) if len(ip) == 0 || ip.IsUnspecified() { diff --git a/protocol/packets.go b/protocol/packets.go index 658bc536..84316da9 100644 --- a/protocol/packets.go +++ b/protocol/packets.go @@ -5,6 +5,12 @@ package protocol +import ( + "fmt" + syncthingprotocol "github.com/syncthing/protocol" + "net" +) + const ( messageTypePing int32 = iota messageTypePong @@ -45,3 +51,11 @@ type SessionInvitation struct { Port uint16 ServerSocket bool } + +func (i *SessionInvitation) String() string { + return fmt.Sprintf("%s@%s", syncthingprotocol.DeviceIDFromBytes(i.From), i.AddressString()) +} + +func (i *SessionInvitation) AddressString() string { + return fmt.Sprintf("%s:%d", net.IP(i.Address), i.Port) +} diff --git a/testutil/main.go b/testutil/main.go index 10c22245..69dbb00a 100644 --- a/testutil/main.go +++ b/testutil/main.go @@ -51,9 +51,9 @@ func main() { go stdinReader(stdin) if join { - log.Printf("Creating client") + log.Println("Creating client") relay := client.NewProtocolClient(uri, []tls.Certificate{cert}, nil) - log.Printf("Created client") + log.Println("Created client") go relay.Serve() @@ -64,9 +64,9 @@ func main() { for invite := range relay.Invitations { select { case recv <- invite: - log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + log.Println("Received invitation", invite) default: - log.Printf("Discarding invitation", invite) + log.Println("Discarding invitation", invite) } } }() @@ -91,7 +91,7 @@ func main() { log.Fatal(err) } - log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + log.Println("Received invitation", invite) conn, err := client.JoinSession(invite) if err != nil { log.Fatalln("Failed to join", err) From e97f75cad50e728b6938b1142781465aa9d1a6c4 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Fri, 17 Jul 2015 21:49:45 +0100 Subject: [PATCH 07/30] Change receiver type, add GoStringer --- protocol/packets.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/protocol/packets.go b/protocol/packets.go index 84316da9..7ff02011 100644 --- a/protocol/packets.go +++ b/protocol/packets.go @@ -52,10 +52,14 @@ type SessionInvitation struct { ServerSocket bool } -func (i *SessionInvitation) String() string { +func (i SessionInvitation) String() string { return fmt.Sprintf("%s@%s", syncthingprotocol.DeviceIDFromBytes(i.From), i.AddressString()) } -func (i *SessionInvitation) AddressString() string { +func (i SessionInvitation) GoString() string { + return i.String() +} + +func (i SessionInvitation) AddressString() string { return fmt.Sprintf("%s:%d", net.IP(i.Address), i.Port) } From f86946c6dfe5acec0a73e0ad4741b313de496ac4 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Fri, 17 Jul 2015 22:04:02 +0100 Subject: [PATCH 08/30] Fix bugs --- protocol_listener.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/protocol_listener.go b/protocol_listener.go index 1e18b156..c3321aa5 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -166,9 +166,13 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { // Potentially closing a second time. close(outbox) conn.Close() - outboxesMut.Lock() - delete(outboxes, id) - outboxesMut.Unlock() + // Only delete the outbox if the client join, as it migth be a + // lookup request coming from the same client. + if joined { + outboxesMut.Lock() + delete(outboxes, id) + outboxesMut.Unlock() + } return case <-pingTicker.C: if !joined { From dab1c4cfc9cb55c00cf5c83e505931c9a1f2c9a6 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 12:11:06 +0200 Subject: [PATCH 09/30] Build script from discosrv --- .gitignore | 2 ++ build.sh | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100755 build.sh diff --git a/.gitignore b/.gitignore index daf913b1..b7006615 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ _testmain.go *.exe *.test *.prof +*.tar.gz +*.zip diff --git a/build.sh b/build.sh new file mode 100755 index 00000000..5f605e1b --- /dev/null +++ b/build.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set -euo pipefail +set nullglob + +echo Get dependencies +go get -d + +rm -rf relaysrv-*-* + +build() { + export GOOS="$1" + export GOARCH="$2" + target="relaysrv-$GOOS-$GOARCH" + go build -v + mkdir "$target" + if [ -f relaysrv ] ; then + mv relaysrv "$target" + tar zcvf "$target.tar.gz" "$target" + fi + if [ -f relaysrv.exe ] ; then + mv relaysrv.exe "$target" + zip -r "$target.zip" "$target" + fi +} + +for goos in linux darwin windows freebsd openbsd netbsd solaris ; do + build "$goos" amd64 +done +for goos in linux windows freebsd openbsd netbsd ; do + build "$goos" 386 +done +build linux arm + +# Hack used because we run as root under Docker +if [[ ${CHOWN_USER:-} != "" ]] ; then + chown -R $CHOWN_USER . +fi From 35d20a19bc7d389d0367b9eae4769d323e5ecb13 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 13:25:08 +0200 Subject: [PATCH 10/30] Implement global and per session rate limiting --- main.go | 22 +++++++++++++--- protocol_listener.go | 2 +- session.go | 63 +++++++++++++++++++++++++++++++++++++++----- 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 5ca06068..b429d94a 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "path/filepath" "time" + "github.com/juju/ratelimit" "github.com/syncthing/relaysrv/protocol" syncthingprotocol "github.com/syncthing/protocol" @@ -26,6 +27,11 @@ var ( networkTimeout time.Duration pingInterval time.Duration messageTimeout time.Duration + + sessionLimitBps int + globalLimitBps int + sessionLimiter *ratelimit.Bucket + globalLimiter *ratelimit.Bucket ) func main() { @@ -38,6 +44,11 @@ func main() { flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s") + flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") + flag.BoolVar(&debug, "debug", false, "Enable debug output") + + flag.Parse() if extAddress == "" { extAddress = listenSession @@ -51,10 +62,6 @@ func main() { sessionAddress = addr.IP[:] sessionPort = uint16(addr.Port) - flag.BoolVar(&debug, "debug", false, "Enable debug output") - - flag.Parse() - certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -83,6 +90,13 @@ func main() { log.Println("ID:", id) } + if sessionLimitBps > 0 { + sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps)) + } + if globalLimitBps > 0 { + globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps)) + } + go sessionListener(listenSession) protocolListener(listenProtocol, tlsCfg) diff --git a/protocol_listener.go b/protocol_listener.go index c3321aa5..8825af82 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -130,7 +130,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { continue } - ses := newSession() + ses := newSession(sessionLimiter, globalLimiter) go ses.Serve() diff --git a/session.go b/session.go index c5a09195..c526ed5d 100644 --- a/session.go +++ b/session.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/juju/ratelimit" "github.com/syncthing/relaysrv/protocol" syncthingprotocol "github.com/syncthing/protocol" @@ -25,10 +26,12 @@ type session struct { serverkey []byte clientkey []byte + rateLimit func(bytes int64) + conns chan net.Conn } -func newSession() *session { +func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { serverkey := make([]byte, 32) _, err := rand.Read(serverkey) if err != nil { @@ -44,6 +47,7 @@ func newSession() *session { ses := &session{ serverkey: serverkey, clientkey: clientkey, + rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit), conns: make(chan net.Conn), } @@ -112,12 +116,12 @@ func (s *session) Serve() { errors := make(chan error, 2) go func() { - errors <- proxy(conns[0], conns[1]) + errors <- s.proxy(conns[0], conns[1]) wg.Done() }() go func() { - errors <- proxy(conns[1], conns[0]) + errors <- s.proxy(conns[1], conns[0]) wg.Done() }() @@ -169,14 +173,15 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr } } -func proxy(c1, c2 net.Conn) error { +func (s *session) proxy(c1, c2 net.Conn) error { if debug { log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) } - buf := make([]byte, 1024) + + buf := make([]byte, 65536) for { c1.SetReadDeadline(time.Now().Add(networkTimeout)) - n, err := c1.Read(buf[0:]) + n, err := c1.Read(buf) if err != nil { return err } @@ -185,6 +190,10 @@ func proxy(c1, c2 net.Conn) error { log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) } + if s.rateLimit != nil { + s.rateLimit(int64(n)) + } + c2.SetWriteDeadline(time.Now().Add(networkTimeout)) _, err = c2.Write(buf[:n]) if err != nil { @@ -196,3 +205,45 @@ func proxy(c1, c2 net.Conn) error { func (s *session) String() string { return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5]) } + +func makeRateLimitFunc(sessionRateLimit, globalRateLimit *ratelimit.Bucket) func(int64) { + // This may be a case of super duper premature optimization... We build an + // optimized function to do the rate limiting here based on what we need + // to do and then use it in the loop. + + if sessionRateLimit == nil && globalRateLimit == nil { + // No limiting needed. We could equally well return a func(int64){} and + // not do a nil check were we use it, but I think the nil check there + // makes it clear that there will be no limiting if none is + // configured... + return nil + } + + if sessionRateLimit == nil { + // We only have a global limiter + return func(bytes int64) { + globalRateLimit.Wait(bytes) + } + } + + if globalRateLimit == nil { + // We only have a session limiter + return func(bytes int64) { + sessionRateLimit.Wait(bytes) + } + } + + // We have both. Queue the bytes on both the global and session specific + // rate limiters. Wait for both in parallell, so that the actual send + // happens when both conditions are satisfied. In practice this just means + // wait the longer of the two times. + return func(bytes int64) { + t0 := sessionRateLimit.Take(bytes) + t1 := globalRateLimit.Take(bytes) + if t0 > t1 { + time.Sleep(t0) + } else { + time.Sleep(t1) + } + } +} From f9bd59f031caa6f246a91c5a919e2e7313962f60 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 11:38:00 +0200 Subject: [PATCH 11/30] Style and minor fixes, main package --- protocol_listener.go | 47 +++++++++++++++++++++++++++++++------------- session_listener.go | 29 +++++++++++++++++---------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/protocol_listener.go b/protocol_listener.go index 8825af82..a7243ff6 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -27,7 +27,6 @@ func protocolListener(addr string, config *tls.Config) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -35,6 +34,8 @@ func protocolListener(addr string, config *tls.Config) { continue } + setTCPOptions(conn) + if debug { log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) } @@ -74,16 +75,12 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { errors := make(chan error, 1) outbox := make(chan interface{}) - go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { - for { - msg, err := protocol.ReadMessage(conn) - if err != nil { - errors <- err - return - } - messages <- msg - } - }(conn, messages, errors) + // Read messages from the connection and send them on the messages + // channel. When there is an error, send it on the error channel and + // return. Applies also when the connection gets closed, so the pattern + // below is to close the connection on error, then wait for the error + // signal from messageReader to exit. + go messageReader(conn, messages, errors) pingTicker := time.NewTicker(pingInterval) timeoutTicker := time.NewTimer(networkTimeout) @@ -96,6 +93,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { if debug { log.Printf("Message %T from %s", message, id) } + switch msg := message.(type) { case protocol.JoinRelayRequest: outboxesMut.RLock() @@ -116,6 +114,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { joined = true protocol.WriteMessage(conn, protocol.ResponseSuccess) + case protocol.ConnectRequest: requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) outboxesMut.RLock() @@ -151,7 +150,10 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Println("Sent invitation from", id, "to", requestedPeer) } conn.Close() + case protocol.Pong: + // Nothing + default: if debug { log.Printf("Unknown message %s: %T", id, message) @@ -159,21 +161,25 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) conn.Close() } + case err := <-errors: if debug { log.Printf("Closing connection %s: %s", id, err) } - // Potentially closing a second time. close(outbox) + + // Potentially closing a second time. conn.Close() - // Only delete the outbox if the client join, as it migth be a - // lookup request coming from the same client. + + // Only delete the outbox if the client is joined, as it might be + // a lookup request coming from the same client. if joined { outboxesMut.Lock() delete(outboxes, id) outboxesMut.Unlock() } return + case <-pingTicker.C: if !joined { if debug { @@ -189,6 +195,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } conn.Close() } + case <-timeoutTicker.C: // We should receive a error from the reader loop, which will cause // us to quit this loop. @@ -196,6 +203,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Printf("%s timed out", id) } conn.Close() + case msg := <-outbox: if debug { log.Printf("Sending message %T to %s", msg, id) @@ -209,3 +217,14 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } } } + +func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } +} diff --git a/session_listener.go b/session_listener.go index 6159ceef..2f6bae9a 100644 --- a/session_listener.go +++ b/session_listener.go @@ -18,7 +18,6 @@ func sessionListener(addr string) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -26,6 +25,8 @@ func sessionListener(addr string) { continue } + setTCPOptions(conn) + if debug { log.Println("Session listener accepted connection from", conn.RemoteAddr()) } @@ -35,10 +36,17 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - conn.SetDeadline(time.Now().Add(messageTimeout)) + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } + message, err := protocol.ReadMessage(conn) if err != nil { - conn.Close() return } @@ -51,7 +59,6 @@ func sessionConnectionHandler(conn net.Conn) { if ses == nil { protocol.WriteMessage(conn, protocol.ResponseNotFound) - conn.Close() return } @@ -60,24 +67,26 @@ func sessionConnectionHandler(conn net.Conn) { log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) } protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) - conn.Close() return } - err := protocol.WriteMessage(conn, protocol.ResponseSuccess) - if err != nil { + if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { if debug { log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) } - conn.Close() return } - conn.SetDeadline(time.Time{}) + + if err := conn.SetDeadline(time.Time{}); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } default: if debug { log.Println("Unexpected message from", conn.RemoteAddr(), message) } protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) - conn.Close() } } From 049d92b52581015a759593e64314fcbfa2462ad2 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 11:56:10 +0200 Subject: [PATCH 12/30] Style and minor fixes, client package --- client/client.go | 282 ++++++++++++++++++++++++----------------------- 1 file changed, 143 insertions(+), 139 deletions(-) diff --git a/client/client.go b/client/client.go index d05944ac..7169e6a8 100644 --- a/client/client.go +++ b/client/client.go @@ -14,27 +14,6 @@ import ( "github.com/syncthing/relaysrv/protocol" ) -func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient { - closeInvitationsOnFinish := false - if invitations == nil { - closeInvitationsOnFinish = true - invitations = make(chan protocol.SessionInvitation) - } - return ProtocolClient{ - URI: uri, - Invitations: invitations, - - closeInvitationsOnFinish: closeInvitationsOnFinish, - - config: configForCerts(certs), - - timeout: time.Minute * 2, - - stop: make(chan struct{}), - stopped: make(chan struct{}), - } -} - type ProtocolClient struct { URI *url.URL Invitations chan protocol.SessionInvitation @@ -51,6 +30,129 @@ type ProtocolClient struct { conn *tls.Conn } +func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient { + closeInvitationsOnFinish := false + if invitations == nil { + closeInvitationsOnFinish = true + invitations = make(chan protocol.SessionInvitation) + } + + return &ProtocolClient{ + URI: uri, + Invitations: invitations, + + closeInvitationsOnFinish: closeInvitationsOnFinish, + + config: configForCerts(certs), + + timeout: time.Minute * 2, + + stop: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +func (c *ProtocolClient) Serve() { + c.stop = make(chan struct{}) + c.stopped = make(chan struct{}) + defer close(c.stopped) + + if err := c.connect(); err != nil { + l.Infoln("Relay connect:", err) + return + } + + if debug { + l.Debugln(c, "connected", c.conn.RemoteAddr()) + } + + if err := c.join(); err != nil { + c.conn.Close() + l.Infoln("Relay join:", err) + return + } + + if err := c.conn.SetDeadline(time.Time{}); err != nil { + l.Infoln("Relay set deadline:", err) + return + } + + if debug { + l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) + } + + defer c.cleanup() + + messages := make(chan interface{}) + errors := make(chan error, 1) + + go messageReader(c.conn, messages, errors) + + timeout := time.NewTimer(c.timeout) + + for { + select { + case message := <-messages: + timeout.Reset(c.timeout) + if debug { + log.Printf("%s received message %T", c, message) + } + + switch msg := message.(type) { + case protocol.Ping: + if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { + l.Infoln("Relay write:", err) + return + + } + if debug { + l.Debugln(c, "sent pong") + } + + case protocol.SessionInvitation: + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + c.Invitations <- msg + + default: + l.Infoln("Relay: protocol error: unexpected message %v", msg) + return + } + + case <-c.stop: + if debug { + l.Debugln(c, "stopping") + } + return + + case err := <-errors: + l.Infoln("Relay received:", err) + return + + case <-timeout.C: + if debug { + l.Debugln(c, "timed out") + } + return + } + } +} + +func (c *ProtocolClient) Stop() { + if c.stop == nil { + return + } + + close(c.stop) + <-c.stopped +} + +func (c *ProtocolClient) String() string { + return fmt.Sprintf("ProtocolClient@%p", c) +} + func (c *ProtocolClient) connect() error { if c.URI.Scheme != "relay" { return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme) @@ -61,9 +163,13 @@ func (c *ProtocolClient) connect() error { return err } - conn.SetDeadline(time.Now().Add(10 * time.Second)) + if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { + conn.Close() + return err + } if err := performHandshakeAndValidation(conn, c.URI); err != nil { + conn.Close() return err } @@ -71,101 +177,6 @@ func (c *ProtocolClient) connect() error { return nil } -func (c *ProtocolClient) Serve() { - if err := c.connect(); err != nil { - panic(err) - } - - if debug { - l.Debugln(c, "connected", c.conn.RemoteAddr()) - } - - if err := c.join(); err != nil { - c.conn.Close() - panic(err) - } - - c.conn.SetDeadline(time.Time{}) - - if debug { - l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) - } - - c.stop = make(chan struct{}) - c.stopped = make(chan struct{}) - - defer c.cleanup() - - messages := make(chan interface{}) - errors := make(chan error, 1) - - go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { - for { - msg, err := protocol.ReadMessage(conn) - if err != nil { - errors <- err - return - } - messages <- msg - } - }(c.conn, messages, errors) - - timeout := time.NewTimer(c.timeout) - for { - select { - case message := <-messages: - timeout.Reset(c.timeout) - if debug { - log.Printf("%s received message %T", c, message) - } - switch msg := message.(type) { - case protocol.Ping: - if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { - panic(err) - } - if debug { - l.Debugln(c, "sent pong") - } - case protocol.SessionInvitation: - ip := net.IP(msg.Address) - if len(ip) == 0 || ip.IsUnspecified() { - msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] - } - c.Invitations <- msg - default: - panic(fmt.Errorf("protocol error: unexpected message %v", msg)) - } - case <-c.stop: - if debug { - l.Debugln(c, "stopping") - } - break - case err := <-errors: - panic(err) - case <-timeout.C: - if debug { - l.Debugln(c, "timed out") - } - return - } - } - - c.stopped <- struct{}{} -} - -func (c *ProtocolClient) Stop() { - if c.stop == nil { - return - } - - c.stop <- struct{}{} - <-c.stopped -} - -func (c *ProtocolClient) String() string { - return fmt.Sprintf("ProtocolClient@%p", c) -} - func (c *ProtocolClient) cleanup() { if c.closeInvitationsOnFinish { close(c.Invitations) @@ -176,24 +187,11 @@ func (c *ProtocolClient) cleanup() { l.Debugln(c, "cleaning up") } - if c.stop != nil { - close(c.stop) - c.stop = nil - } - - if c.stopped != nil { - close(c.stopped) - c.stopped = nil - } - - if c.conn != nil { - c.conn.Close() - } + c.conn.Close() } func (c *ProtocolClient) join() error { - err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}) - if err != nil { + if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil { return err } @@ -207,6 +205,7 @@ func (c *ProtocolClient) join() error { if msg.Code != 0 { return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) } + default: return fmt.Errorf("protocol error: expecting response got %v", msg) } @@ -215,15 +214,12 @@ func (c *ProtocolClient) join() error { } func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { - err := conn.Handshake() - if err != nil { - conn.Close() + if err := conn.Handshake(); err != nil { return err } cs := conn.ConnectionState() if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName { - conn.Close() return fmt.Errorf("protocol negotiation error") } @@ -232,22 +228,30 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { if relayIDs != "" { relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs) if err != nil { - conn.Close() return fmt.Errorf("relay address contains invalid verification id: %s", err) } certs := cs.PeerCertificates if cl := len(certs); cl != 1 { - conn.Close() return fmt.Errorf("unexpected certificate count: %d", cl) } remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw) if remoteID != relayID { - conn.Close() return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID) } } return nil } + +func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } +} From 78ef42daa1943b4bd012e6ec7906b49755574967 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Wed, 22 Jul 2015 22:34:05 +0100 Subject: [PATCH 13/30] Add ability to lookup relay status --- client/client.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/client/client.go b/client/client.go index 7169e6a8..48e97b40 100644 --- a/client/client.go +++ b/client/client.go @@ -12,6 +12,7 @@ import ( syncthingprotocol "github.com/syncthing/protocol" "github.com/syncthing/relaysrv/protocol" + "github.com/syncthing/syncthing/internal/sync" ) type ProtocolClient struct { @@ -28,6 +29,9 @@ type ProtocolClient struct { stopped chan struct{} conn *tls.Conn + + mut sync.RWMutex + connected bool } func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient { @@ -49,6 +53,9 @@ func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan p stop: make(chan struct{}), stopped: make(chan struct{}), + + mut: sync.NewRWMutex(), + connected: false, } } @@ -82,6 +89,9 @@ func (c *ProtocolClient) Serve() { } defer c.cleanup() + c.mut.Lock() + c.connected = true + c.mut.Unlock() messages := make(chan interface{}) errors := make(chan error, 1) @@ -149,6 +159,13 @@ func (c *ProtocolClient) Stop() { <-c.stopped } +func (c *ProtocolClient) StatusOK() bool { + c.mut.RLock() + con := c.connected + c.mut.RUnlock() + return con +} + func (c *ProtocolClient) String() string { return fmt.Sprintf("ProtocolClient@%p", c) } @@ -187,6 +204,10 @@ func (c *ProtocolClient) cleanup() { l.Debugln(c, "cleaning up") } + c.mut.Lock() + c.connected = false + c.mut.Unlock() + c.conn.Close() } From eb29989dff2810d7317229f7742e8e58d116dc73 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Thu, 23 Jul 2015 20:53:16 +0100 Subject: [PATCH 14/30] Connection errors are debug errors --- client/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index 48e97b40..a1924e3e 100644 --- a/client/client.go +++ b/client/client.go @@ -65,7 +65,9 @@ func (c *ProtocolClient) Serve() { defer close(c.stopped) if err := c.connect(); err != nil { - l.Infoln("Relay connect:", err) + if debug { + l.Debugln("Relay connect:", err) + } return } From 7c6a31017968e7c1a69148db1ca3dea71eba8236 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Wed, 19 Aug 2015 20:49:34 +0100 Subject: [PATCH 15/30] Fix after package move --- client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index a1924e3e..94e4eedd 100644 --- a/client/client.go +++ b/client/client.go @@ -12,7 +12,7 @@ import ( syncthingprotocol "github.com/syncthing/protocol" "github.com/syncthing/relaysrv/protocol" - "github.com/syncthing/syncthing/internal/sync" + "github.com/syncthing/syncthing/lib/sync" ) type ProtocolClient struct { From f0c0c5483fe50b5ec060aac32db7da848368f160 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 20 Aug 2015 12:33:11 +0200 Subject: [PATCH 16/30] Cleaner build --- build.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/build.sh b/build.sh index 5f605e1b..42c2681d 100755 --- a/build.sh +++ b/build.sh @@ -11,15 +11,17 @@ build() { export GOOS="$1" export GOARCH="$2" target="relaysrv-$GOOS-$GOARCH" - go build -v + go build -i -v -ldflags -w mkdir "$target" if [ -f relaysrv ] ; then mv relaysrv "$target" - tar zcvf "$target.tar.gz" "$target" + tar zcvf "$target.tar.gz" "$target" + rm -r "$target" fi if [ -f relaysrv.exe ] ; then - mv relaysrv.exe "$target" + mv relaysrv.exe "$target" zip -r "$target.zip" "$target" + rm -r "$target" fi } From d7949aa58ee92df583793e43ce751ffcc0c4f283 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 20 Aug 2015 12:33:52 +0200 Subject: [PATCH 17/30] I contribute stuff --- CONTRIBUTORS | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index e69de29b..4b00258c 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -0,0 +1 @@ +Jakob Borg From f76a66fc555150cf752040193fa2d154cbc61b61 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 20 Aug 2015 12:59:44 +0200 Subject: [PATCH 18/30] Very basic status service --- main.go | 7 +++++++ session.go | 12 ++++++++++-- status.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 status.go diff --git a/main.go b/main.go index b429d94a..2defed87 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,8 @@ var ( globalLimitBps int sessionLimiter *ratelimit.Bucket globalLimiter *ratelimit.Bucket + + statusAddr string ) func main() { @@ -47,6 +49,7 @@ func main() { flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s") flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)") flag.Parse() @@ -99,5 +102,9 @@ func main() { go sessionListener(listenSession) + if statusAddr != "" { + go statusService(statusAddr) + } + protocolListener(listenProtocol, tlsCfg) } diff --git a/session.go b/session.go index c526ed5d..a189ab40 100644 --- a/session.go +++ b/session.go @@ -9,6 +9,7 @@ import ( "log" "net" "sync" + "sync/atomic" "time" "github.com/juju/ratelimit" @@ -18,8 +19,10 @@ import ( ) var ( - sessionMut = sync.Mutex{} - sessions = make(map[string]*session, 0) + sessionMut = sync.Mutex{} + sessions = make(map[string]*session, 0) + numProxies int64 + bytesProxied int64 ) type session struct { @@ -178,6 +181,9 @@ func (s *session) proxy(c1, c2 net.Conn) error { log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) } + atomic.AddInt64(&numProxies, 1) + defer atomic.AddInt64(&numProxies, -1) + buf := make([]byte, 65536) for { c1.SetReadDeadline(time.Now().Add(networkTimeout)) @@ -186,6 +192,8 @@ func (s *session) proxy(c1, c2 net.Conn) error { return err } + atomic.AddInt64(&bytesProxied, int64(n)) + if debug { log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) } diff --git a/status.go b/status.go new file mode 100644 index 00000000..d0391e17 --- /dev/null +++ b/status.go @@ -0,0 +1,39 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + "runtime" + "sync/atomic" +) + +func statusService(addr string) { + http.HandleFunc("/status", getStatus) + if err := http.ListenAndServe(addr, nil); err != nil { + log.Fatal(err) + } +} + +func getStatus(w http.ResponseWriter, r *http.Request) { + status := make(map[string]interface{}) + + sessionMut.Lock() + status["numSessions"] = len(sessions) + sessionMut.Unlock() + status["numProxies"] = atomic.LoadInt64(&numProxies) + status["bytesProxied"] = atomic.LoadInt64(&bytesProxied) + status["goVersion"] = runtime.Version() + status["goOS"] = runtime.GOOS + status["goAarch"] = runtime.GOARCH + status["goMaxProcs"] = runtime.GOMAXPROCS(-1) + + bs, err := json.MarshalIndent(status, "", " ") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(bs) +} From 37cbe68204c4bf16c6d7db374c5908580028e279 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 20 Aug 2015 13:58:07 +0200 Subject: [PATCH 19/30] Fix broken connection close --- session_listener.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/session_listener.go b/session_listener.go index 2f6bae9a..82d2ec73 100644 --- a/session_listener.go +++ b/session_listener.go @@ -36,8 +36,6 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - defer conn.Close() - if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { if debug { log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) @@ -59,6 +57,7 @@ func sessionConnectionHandler(conn net.Conn) { if ses == nil { protocol.WriteMessage(conn, protocol.ResponseNotFound) + conn.Close() return } @@ -67,6 +66,7 @@ func sessionConnectionHandler(conn net.Conn) { log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) } protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + conn.Close() return } @@ -81,12 +81,15 @@ func sessionConnectionHandler(conn net.Conn) { if debug { log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) } + conn.Close() return } + default: if debug { log.Println("Unexpected message from", conn.RemoteAddr(), message) } protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) + conn.Close() } } From 7fe1fdd8c751df165ea825bc8d3e895f118bb236 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 20 Aug 2015 14:02:52 +0200 Subject: [PATCH 20/30] Improve status reporter --- .gitignore | 1 + protocol_listener.go | 11 ++++++--- session.go | 13 ++++++----- status.go | 53 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index b7006615..775c43cc 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ _testmain.go *.prof *.tar.gz *.zip +relaysrv diff --git a/protocol_listener.go b/protocol_listener.go index a7243ff6..0ff0154d 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -7,6 +7,7 @@ import ( "log" "net" "sync" + "sync/atomic" "time" syncthingprotocol "github.com/syncthing/protocol" @@ -15,8 +16,9 @@ import ( ) var ( - outboxesMut = sync.RWMutex{} - outboxes = make(map[syncthingprotocol.DeviceID]chan interface{}) + outboxesMut = sync.RWMutex{} + outboxes = make(map[syncthingprotocol.DeviceID]chan interface{}) + numConnections int64 ) func protocolListener(addr string, config *tls.Config) { @@ -122,7 +124,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { outboxesMut.RUnlock() if !ok { if debug { - log.Println(id, "is looking", requestedPeer, "which does not exist") + log.Println(id, "is looking for", requestedPeer, "which does not exist") } protocol.WriteMessage(conn, protocol.ResponseNotFound) conn.Close() @@ -219,6 +221,9 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + atomic.AddInt64(&numConnections, 1) + defer atomic.AddInt64(&numConnections, -1) + for { msg, err := protocol.ReadMessage(conn) if err != nil { diff --git a/session.go b/session.go index a189ab40..c94cdc2a 100644 --- a/session.go +++ b/session.go @@ -110,30 +110,31 @@ func (s *session) Serve() { close(s.conns) if debug { - log.Println("Session", s, "starting between", conns[0].RemoteAddr(), conns[1].RemoteAddr()) + log.Println("Session", s, "starting between", conns[0].RemoteAddr(), "and", conns[1].RemoteAddr()) } wg := sync.WaitGroup{} wg.Add(2) - errors := make(chan error, 2) - + var err0 error go func() { - errors <- s.proxy(conns[0], conns[1]) + err0 = s.proxy(conns[0], conns[1]) wg.Done() }() + var err1 error go func() { - errors <- s.proxy(conns[1], conns[0]) + err1 = s.proxy(conns[1], conns[0]) wg.Done() }() wg.Wait() if debug { - log.Println("Session", s, "ended, outcomes:", <-errors, <-errors) + log.Println("Session", s, "ended, outcomes:", err0, "and", err1) } goto done + case <-timedout: if debug { log.Println("Session", s, "timed out") diff --git a/status.go b/status.go index d0391e17..53cea572 100644 --- a/status.go +++ b/status.go @@ -6,9 +6,14 @@ import ( "net/http" "runtime" "sync/atomic" + "time" ) +var rc *rateCalculator + func statusService(addr string) { + rc = newRateCalculator(360, 10*time.Second, &bytesProxied) + http.HandleFunc("/status", getStatus) if err := http.ListenAndServe(addr, nil); err != nil { log.Fatal(err) @@ -21,12 +26,21 @@ func getStatus(w http.ResponseWriter, r *http.Request) { sessionMut.Lock() status["numSessions"] = len(sessions) sessionMut.Unlock() + status["numConnections"] = atomic.LoadInt64(&numConnections) status["numProxies"] = atomic.LoadInt64(&numProxies) status["bytesProxied"] = atomic.LoadInt64(&bytesProxied) status["goVersion"] = runtime.Version() status["goOS"] = runtime.GOOS status["goAarch"] = runtime.GOARCH status["goMaxProcs"] = runtime.GOMAXPROCS(-1) + status["kbps10s1m5m15m30m60m"] = []int64{ + rc.rate(10/10) * 8 / 1000, + rc.rate(60/10) * 8 / 1000, + rc.rate(5*60/10) * 8 / 1000, + rc.rate(15*60/10) * 8 / 1000, + rc.rate(30*60/10) * 8 / 1000, + rc.rate(60*60/10) * 8 / 1000, + } bs, err := json.MarshalIndent(status, "", " ") if err != nil { @@ -37,3 +51,42 @@ func getStatus(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write(bs) } + +type rateCalculator struct { + rates []int64 + prev int64 + counter *int64 +} + +func newRateCalculator(keepIntervals int, interval time.Duration, counter *int64) *rateCalculator { + r := &rateCalculator{ + rates: make([]int64, keepIntervals), + counter: counter, + } + + go r.updateRates(interval) + + return r +} + +func (r *rateCalculator) updateRates(interval time.Duration) { + for { + now := time.Now() + next := now.Truncate(interval).Add(interval) + time.Sleep(next.Sub(now)) + + cur := atomic.LoadInt64(r.counter) + rate := int64(float64(cur-r.prev) / interval.Seconds()) + copy(r.rates[1:], r.rates) + r.rates[0] = rate + r.prev = cur + } +} + +func (r *rateCalculator) rate(periods int) int64 { + var tot int64 + for i := 0; i < periods; i++ { + tot += r.rates[i] + } + return tot / int64(periods) +} From c0554c9fbfcca55f701082a8245382d5d74440b5 Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Wed, 2 Sep 2015 21:35:52 +0100 Subject: [PATCH 21/30] Use a single socket for relaying --- protocol_listener.go => listener.go | 77 +++++++++++++++++++++-- main.go | 15 ++--- protocol/packets.go | 3 +- session_listener.go | 95 ----------------------------- 4 files changed, 79 insertions(+), 111 deletions(-) rename protocol_listener.go => listener.go (73%) delete mode 100644 session_listener.go diff --git a/protocol_listener.go b/listener.go similarity index 73% rename from protocol_listener.go rename to listener.go index 0ff0154d..a63bc4c8 100644 --- a/protocol_listener.go +++ b/listener.go @@ -11,6 +11,7 @@ import ( "time" syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/relaysrv/protocol" ) @@ -21,14 +22,16 @@ var ( numConnections int64 ) -func protocolListener(addr string, config *tls.Config) { - listener, err := net.Listen("tcp", addr) +func listener(addr string, config *tls.Config) { + tcpListener, err := net.Listen("tcp", addr) if err != nil { log.Fatalln(err) } + listener := tlsutil.DowngradingListener{tcpListener, nil} + for { - conn, err := listener.Accept() + conn, isTLS, err := listener.AcceptNoWrap() if err != nil { if debug { log.Println(err) @@ -39,10 +42,15 @@ func protocolListener(addr string, config *tls.Config) { setTCPOptions(conn) if debug { - log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) + log.Println("Listener accepted connection from", conn.RemoteAddr(), "tls", isTLS) + } + + if isTLS { + go protocolConnectionHandler(conn, config) + } else { + go sessionConnectionHandler(conn) } - go protocolConnectionHandler(conn, config) } } @@ -220,6 +228,65 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } } +func sessionConnectionHandler(conn net.Conn) { + if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return + } + + switch msg := message.(type) { + case protocol.JoinSessionRequest: + ses := findSession(string(msg.Key)) + if debug { + log.Println(conn.RemoteAddr(), "session lookup", ses) + } + + if ses == nil { + protocol.WriteMessage(conn, protocol.ResponseNotFound) + conn.Close() + return + } + + if !ses.AddConnection(conn) { + if debug { + log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) + } + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + conn.Close() + return + } + + if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { + if debug { + log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) + } + return + } + + if err := conn.SetDeadline(time.Time{}); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + conn.Close() + return + } + + default: + if debug { + log.Println("Unexpected message from", conn.RemoteAddr(), message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) + conn.Close() + } +} + func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { atomic.AddInt64(&numConnections, 1) defer atomic.AddInt64(&numConnections, -1) diff --git a/main.go b/main.go index 2defed87..c050c355 100644 --- a/main.go +++ b/main.go @@ -17,9 +17,8 @@ import ( ) var ( - listenProtocol string - listenSession string - debug bool + listen string + debug bool sessionAddress []byte sessionPort uint16 @@ -39,9 +38,7 @@ var ( func main() { var dir, extAddress string - flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") - flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") - flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") + flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") @@ -54,7 +51,7 @@ func main() { flag.Parse() if extAddress == "" { - extAddress = listenSession + extAddress = listen } addr, err := net.ResolveTCPAddr("tcp", extAddress) @@ -100,11 +97,9 @@ func main() { globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps)) } - go sessionListener(listenSession) - if statusAddr != "" { go statusService(statusAddr) } - protocolListener(listenProtocol, tlsCfg) + listener(listen, tlsCfg) } diff --git a/protocol/packets.go b/protocol/packets.go index 7ff02011..1b21eba2 100644 --- a/protocol/packets.go +++ b/protocol/packets.go @@ -7,8 +7,9 @@ package protocol import ( "fmt" - syncthingprotocol "github.com/syncthing/protocol" "net" + + syncthingprotocol "github.com/syncthing/protocol" ) const ( diff --git a/session_listener.go b/session_listener.go deleted file mode 100644 index 82d2ec73..00000000 --- a/session_listener.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). - -package main - -import ( - "log" - "net" - "time" - - "github.com/syncthing/relaysrv/protocol" -) - -func sessionListener(addr string) { - listener, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalln(err) - } - - for { - conn, err := listener.Accept() - if err != nil { - if debug { - log.Println(err) - } - continue - } - - setTCPOptions(conn) - - if debug { - log.Println("Session listener accepted connection from", conn.RemoteAddr()) - } - - go sessionConnectionHandler(conn) - } -} - -func sessionConnectionHandler(conn net.Conn) { - if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { - if debug { - log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) - } - return - } - - message, err := protocol.ReadMessage(conn) - if err != nil { - return - } - - switch msg := message.(type) { - case protocol.JoinSessionRequest: - ses := findSession(string(msg.Key)) - if debug { - log.Println(conn.RemoteAddr(), "session lookup", ses) - } - - if ses == nil { - protocol.WriteMessage(conn, protocol.ResponseNotFound) - conn.Close() - return - } - - if !ses.AddConnection(conn) { - if debug { - log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) - } - protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) - conn.Close() - return - } - - if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { - if debug { - log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) - } - return - } - - if err := conn.SetDeadline(time.Time{}); err != nil { - if debug { - log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) - } - conn.Close() - return - } - - default: - if debug { - log.Println("Unexpected message from", conn.RemoteAddr(), message) - } - protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) - conn.Close() - } -} From 541d05df1b732c26438d1e9b8758487c4a51db96 Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Wed, 2 Sep 2015 22:02:17 +0100 Subject: [PATCH 22/30] Use new method name --- listener.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/listener.go b/listener.go index a63bc4c8..2091b1dd 100644 --- a/listener.go +++ b/listener.go @@ -31,7 +31,7 @@ func listener(addr string, config *tls.Config) { listener := tlsutil.DowngradingListener{tcpListener, nil} for { - conn, isTLS, err := listener.AcceptNoWrap() + conn, isTLS, err := listener.AcceptNoWrapTLS() if err != nil { if debug { log.Println(err) From 11b2815b8858d3fadc36c9dcf5ae9b64b7d3d3ce Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Sun, 6 Sep 2015 18:35:38 +0100 Subject: [PATCH 23/30] Add a test method, fix nil pointer panic --- client/methods.go | 22 +++++++++++++++++++++- testutil/main.go | 9 ++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/client/methods.go b/client/methods.go index ef6145e9..cfc810f5 100644 --- a/client/methods.go +++ b/client/methods.go @@ -8,6 +8,7 @@ import ( "net" "net/url" "strconv" + "strings" "time" syncthingprotocol "github.com/syncthing/protocol" @@ -20,10 +21,10 @@ func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs [ } conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs)) - conn.SetDeadline(time.Now().Add(10 * time.Second)) if err != nil { return protocol.SessionInvitation{}, err } + conn.SetDeadline(time.Now().Add(10 * time.Second)) if err := performHandshakeAndValidation(conn, uri); err != nil { return protocol.SessionInvitation{}, err @@ -97,6 +98,25 @@ func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { } } +func TestRelay(uri *url.URL, certs []tls.Certificate) bool { + id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) + c := NewProtocolClient(uri, certs, nil) + go c.Serve() + defer c.Stop() + + for i := 0; i < 5; i++ { + _, err := GetInvitationFromRelay(uri, id, certs) + if err == nil { + return true + } + if !strings.Contains(err.Error(), "Incorrect response code") { + return false + } + time.Sleep(time.Second) + } + return false +} + func configForCerts(certs []tls.Certificate) *tls.Config { return &tls.Config{ Certificates: certs, diff --git a/testutil/main.go b/testutil/main.go index 69dbb00a..ffeb9942 100644 --- a/testutil/main.go +++ b/testutil/main.go @@ -23,10 +23,11 @@ func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) var connect, relay, dir string - var join bool + var join, test bool flag.StringVar(&connect, "connect", "", "Device ID to which to connect to") flag.BoolVar(&join, "join", false, "Join relay") + flag.BoolVar(&test, "test", false, "Generic relay test") flag.StringVar(&relay, "relay", "relay://127.0.0.1:22067", "Relay address") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") @@ -99,6 +100,12 @@ func main() { log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr()) connectToStdio(stdin, conn) log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) + } else if test { + if client.TestRelay(uri, []tls.Certificate{cert}) { + log.Println("OK") + } else { + log.Println("FAIL") + } } else { log.Fatal("Requires either join or connect") } From e3ca797dadcbec00b8e5bd7d98cb2f2b2bb43a7a Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Sun, 6 Sep 2015 20:25:53 +0100 Subject: [PATCH 24/30] Receive the invite, otherwise stop blocks, add extra arguments --- client/methods.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/client/methods.go b/client/methods.go index cfc810f5..67a9a71c 100644 --- a/client/methods.go +++ b/client/methods.go @@ -98,13 +98,17 @@ func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { } } -func TestRelay(uri *url.URL, certs []tls.Certificate) bool { +func TestRelay(uri *url.URL, certs []tls.Certificate, sleep time.Duration, times int) bool { id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) - c := NewProtocolClient(uri, certs, nil) + invs := make(chan protocol.SessionInvitation, 1) + c := NewProtocolClient(uri, certs, invs) go c.Serve() - defer c.Stop() + defer func() { + close(invs) + c.Stop() + }() - for i := 0; i < 5; i++ { + for i := 0; i < times; i++ { _, err := GetInvitationFromRelay(uri, id, certs) if err == nil { return true @@ -112,7 +116,7 @@ func TestRelay(uri *url.URL, certs []tls.Certificate) bool { if !strings.Contains(err.Error(), "Incorrect response code") { return false } - time.Sleep(time.Second) + time.Sleep(sleep) } return false } From eab5fd5bdd35f27e4134bec87f6518f823565fe8 Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Mon, 7 Sep 2015 09:21:23 +0100 Subject: [PATCH 25/30] Join relay pool by default --- main.go | 28 ++++++++++++++++++++++++- pool.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 pool.go diff --git a/main.go b/main.go index c050c355..4e7e9a13 100644 --- a/main.go +++ b/main.go @@ -5,9 +5,12 @@ package main import ( "crypto/tls" "flag" + "fmt" "log" "net" + "net/url" "path/filepath" + "strings" "time" "github.com/juju/ratelimit" @@ -32,7 +35,9 @@ var ( sessionLimiter *ratelimit.Bucket globalLimiter *ratelimit.Bucket - statusAddr string + statusAddr string + poolAddrs string + defaultPoolAddrs string = "https://relays.syncthing.net" ) func main() { @@ -47,6 +52,7 @@ func main() { flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") flag.BoolVar(&debug, "debug", false, "Enable debug output") flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)") + flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relau pool addresses to join") flag.Parse() @@ -101,5 +107,25 @@ func main() { go statusService(statusAddr) } + uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s", extAddress, id)) + if err != nil { + log.Fatalln("Failed to construct URI", err) + } + + if poolAddrs == defaultPoolAddrs { + log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + log.Println("!! Joining default relay pools, this relay will be available for public use. !!") + log.Println(`!! Use the -pools="" command line option to make the relay private. !!`) + log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + } + + pools := strings.Split(poolAddrs, ",") + for _, pool := range pools { + pool = strings.TrimSpace(pool) + if len(pool) > 0 { + go poolHandler(pool, uri) + } + } + listener(listen, tlsCfg) } diff --git a/pool.go b/pool.go new file mode 100644 index 00000000..398b0c0a --- /dev/null +++ b/pool.go @@ -0,0 +1,63 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "log" + "net/http" + "net/url" + "time" +) + +func poolHandler(pool string, uri *url.URL) { + for { + var b bytes.Buffer + json.NewEncoder(&b).Encode(struct { + URL string `json:"url"` + }{ + uri.String(), + }) + + resp, err := http.Post(pool, "application/json", &b) + if err != nil { + if debug { + log.Println("Error joining pool", pool, err) + } + } else if resp.StatusCode == 500 { + if debug { + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Println("Failed to read response body for", pool, err) + } else { + log.Println("Response for", pool, string(bs)) + } + resp.Body.Close() + } + } else if resp.StatusCode == 429 { + if debug { + log.Println(pool, "under load, will retry in a minute") + } + time.Sleep(time.Minute) + continue + } else if resp.StatusCode == 200 { + var x struct { + EvictionIn time.Duration `json:"evictionIn"` + } + err := json.NewDecoder(resp.Body).Decode(&x) + if err == nil { + rejoin := x.EvictionIn - (x.EvictionIn / 5) + if debug { + log.Println("Joined", pool, "rejoining in", rejoin) + } + time.Sleep(rejoin) + continue + } else if debug { + log.Println("Failed to deserialize respnse", err) + } + } + time.Sleep(time.Hour) + } +} From d180bc794b56b660a98b412ff193a7fa15a570ef Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Mon, 7 Sep 2015 18:12:18 +0100 Subject: [PATCH 26/30] Handle 403 --- pool.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pool.go b/pool.go index 398b0c0a..0f327f69 100644 --- a/pool.go +++ b/pool.go @@ -42,6 +42,11 @@ func poolHandler(pool string, uri *url.URL) { } time.Sleep(time.Minute) continue + } else if resp.StatusCode == 403 { + if debug { + log.Println(pool, "failed to join due to IP address not matching external address") + } + return } else if resp.StatusCode == 200 { var x struct { EvictionIn time.Duration `json:"evictionIn"` From 7e0106da0cf3eddba2a5ed338fd67906753d3194 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Fri, 11 Sep 2015 20:01:33 +0100 Subject: [PATCH 27/30] Tweaks 1. Advertise relay server paramters so that clients could make a decision wether or not to connect 2. Generate certificate if it's not there. --- main.go | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 4e7e9a13..1da3c455 100644 --- a/main.go +++ b/main.go @@ -15,20 +15,21 @@ import ( "github.com/juju/ratelimit" "github.com/syncthing/relaysrv/protocol" + "github.com/syncthing/syncthing/lib/tlsutil" syncthingprotocol "github.com/syncthing/protocol" ) var ( listen string - debug bool + debug bool = false sessionAddress []byte sessionPort uint16 - networkTimeout time.Duration - pingInterval time.Duration - messageTimeout time.Duration + networkTimeout time.Duration = 2 * time.Minute + pingInterval time.Duration = time.Minute + messageTimeout time.Duration = time.Minute sessionLimitBps int globalLimitBps int @@ -45,14 +46,14 @@ func main() { flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") - flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") - flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") - flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + flag.DurationVar(&networkTimeout, "network-timeout", networkTimeout, "Timeout for network operations between the client and the relay.\n\tIf no data is received between the client and the relay in this period of time, the connection is terminated.\n\tFurthermore, if no data is sent between either clients being relayed within this period of time, the session is also terminated.") + flag.DurationVar(&pingInterval, "ping-interval", pingInterval, "How often pings are sent") + flag.DurationVar(&messageTimeout, "message-timeout", messageTimeout, "Maximum amount of time we wait for relevant messages to arrive") flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s") flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") - flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.BoolVar(&debug, "debug", debug, "Enable debug output") flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)") - flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relau pool addresses to join") + flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join") flag.Parse() @@ -71,7 +72,11 @@ func main() { certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - log.Fatalln("Failed to load X509 key pair:", err) + log.Println("Failed to load keypair. Generating one, this might take a while...") + cert, err = tlsutil.NewCertificate(certFile, keyFile, "relaysrv", 3072) + if err != nil { + log.Fatalln("Failed to generate X509 key pair:", err) + } } tlsCfg := &tls.Config{ @@ -107,11 +112,15 @@ func main() { go statusService(statusAddr) } - uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s", extAddress, id)) + uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s&pingInterval=%s&networkTimeout=%s&sessionLimitBps=%d&globalLimitBps=%d&statusAddr=%s", extAddress, id, pingInterval, networkTimeout, sessionLimitBps, globalLimitBps, statusAddr)) if err != nil { log.Fatalln("Failed to construct URI", err) } + if debug { + log.Println("URI:", uri.String()) + } + if poolAddrs == defaultPoolAddrs { log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") log.Println("!! Joining default relay pools, this relay will be available for public use. !!") From f9f12131ae2fa70edc866413fee5ca4951d278b1 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Fri, 11 Sep 2015 22:29:50 +0100 Subject: [PATCH 28/30] Drop all sessions when we realize a node has gone away --- listener.go | 24 ++++++---- main.go | 2 + session.go | 131 +++++++++++++++++++++++++++++++++++++--------------- status.go | 4 +- 4 files changed, 114 insertions(+), 47 deletions(-) diff --git a/listener.go b/listener.go index 2091b1dd..9dde7627 100644 --- a/listener.go +++ b/listener.go @@ -4,6 +4,7 @@ package main import ( "crypto/tls" + "encoding/hex" "log" "net" "sync" @@ -34,7 +35,7 @@ func listener(addr string, config *tls.Config) { conn, isTLS, err := listener.AcceptNoWrapTLS() if err != nil { if debug { - log.Println(err) + log.Println("Listener failed to accept connection from", conn.RemoteAddr(), ". Possibly a TCP Ping.") } continue } @@ -138,13 +139,13 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { conn.Close() continue } - - ses := newSession(sessionLimiter, globalLimiter) + // requestedPeer is the server, id is the client + ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter) go ses.Serve() - clientInvitation := ses.GetClientInvitationMessage(requestedPeer) - serverInvitation := ses.GetServerInvitationMessage(id) + clientInvitation := ses.GetClientInvitationMessage() + serverInvitation := ses.GetServerInvitationMessage() if err := protocol.WriteMessage(conn, clientInvitation); err != nil { if debug { @@ -181,12 +182,19 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { // Potentially closing a second time. conn.Close() - // Only delete the outbox if the client is joined, as it might be - // a lookup request coming from the same client. if joined { + // Only delete the outbox if the client is joined, as it might be + // a lookup request coming from the same client. outboxesMut.Lock() delete(outboxes, id) outboxesMut.Unlock() + // Also, kill all sessions related to this node, as it probably + // went offline. This is for the other end to realize the client + // is no longer there faster. This also helps resolve + // 'already connected' errors when one of the sides is + // restarting, and connecting to the other peer before the other + // peer even realised that the node has gone away. + dropSessions(id) } return @@ -245,7 +253,7 @@ func sessionConnectionHandler(conn net.Conn) { case protocol.JoinSessionRequest: ses := findSession(string(msg.Key)) if debug { - log.Println(conn.RemoteAddr(), "session lookup", ses) + log.Println(conn.RemoteAddr(), "session lookup", ses, hex.EncodeToString(msg.Key)[:5]) } if ses == nil { diff --git a/main.go b/main.go index 1da3c455..614f82c7 100644 --- a/main.go +++ b/main.go @@ -42,6 +42,8 @@ var ( ) func main() { + log.SetFlags(log.Lshortfile | log.LstdFlags) + var dir, extAddress string flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") diff --git a/session.go b/session.go index c94cdc2a..bbd29d1f 100644 --- a/session.go +++ b/session.go @@ -19,22 +19,14 @@ import ( ) var ( - sessionMut = sync.Mutex{} - sessions = make(map[string]*session, 0) - numProxies int64 - bytesProxied int64 + sessionMut = sync.RWMutex{} + activeSessions = make([]*session, 0) + pendingSessions = make(map[string]*session, 0) + numProxies int64 + bytesProxied int64 ) -type session struct { - serverkey []byte - clientkey []byte - - rateLimit func(bytes int64) - - conns chan net.Conn -} - -func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { +func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { serverkey := make([]byte, 32) _, err := rand.Read(serverkey) if err != nil { @@ -49,9 +41,12 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { ses := &session{ serverkey: serverkey, + serverid: serverid, clientkey: clientkey, + clientid: clientid, rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit), - conns: make(chan net.Conn), + connsChan: make(chan net.Conn), + conns: make([]net.Conn, 0, 2), } if debug { @@ -59,8 +54,8 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { } sessionMut.Lock() - sessions[string(ses.serverkey)] = ses - sessions[string(ses.clientkey)] = ses + pendingSessions[string(ses.serverkey)] = ses + pendingSessions[string(ses.clientkey)] = ses sessionMut.Unlock() return ses @@ -69,13 +64,41 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { func findSession(key string) *session { sessionMut.Lock() defer sessionMut.Unlock() - lob, ok := sessions[key] + ses, ok := pendingSessions[key] if !ok { return nil } - delete(sessions, key) - return lob + delete(pendingSessions, key) + return ses +} + +func dropSessions(id syncthingprotocol.DeviceID) { + sessionMut.RLock() + for _, session := range activeSessions { + if session.HasParticipant(id) { + if debug { + log.Println("Dropping session", session, "involving", id) + } + session.CloseConns() + } + } + sessionMut.RUnlock() +} + +type session struct { + mut sync.Mutex + + serverkey []byte + serverid syncthingprotocol.DeviceID + + clientkey []byte + clientid syncthingprotocol.DeviceID + + rateLimit func(bytes int64) + + connsChan chan net.Conn + conns []net.Conn } func (s *session) AddConnection(conn net.Conn) bool { @@ -84,7 +107,7 @@ func (s *session) AddConnection(conn net.Conn) bool { } select { - case s.conns <- conn: + case s.connsChan <- conn: return true default: } @@ -98,19 +121,21 @@ func (s *session) Serve() { log.Println("Session", s, "serving") } - conns := make([]net.Conn, 0, 2) for { select { - case conn := <-s.conns: - conns = append(conns, conn) - if len(conns) < 2 { + case conn := <-s.connsChan: + s.mut.Lock() + s.conns = append(s.conns, conn) + s.mut.Unlock() + // We're the only ones mutating% s.conns, hence we are free to read it. + if len(s.conns) < 2 { continue } - close(s.conns) + close(s.connsChan) if debug { - log.Println("Session", s, "starting between", conns[0].RemoteAddr(), "and", conns[1].RemoteAddr()) + log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr()) } wg := sync.WaitGroup{} @@ -118,16 +143,20 @@ func (s *session) Serve() { var err0 error go func() { - err0 = s.proxy(conns[0], conns[1]) + err0 = s.proxy(s.conns[0], s.conns[1]) wg.Done() }() var err1 error go func() { - err1 = s.proxy(conns[1], conns[0]) + err1 = s.proxy(s.conns[1], s.conns[0]) wg.Done() }() + sessionMut.Lock() + activeSessions = append(activeSessions, s) + sessionMut.Unlock() + wg.Wait() if debug { @@ -143,23 +172,37 @@ func (s *session) Serve() { } } done: + // We can end up here in 3 cases: + // 1. Timeout joining, in which case there are potentially entries in pendingSessions + // 2. General session end/timeout, in which case there are entries in activeSessions + // 3. Protocol handler calls dropSession as one of it's clients disconnects. + sessionMut.Lock() - delete(sessions, string(s.serverkey)) - delete(sessions, string(s.clientkey)) + delete(pendingSessions, string(s.serverkey)) + delete(pendingSessions, string(s.clientkey)) + + for i, session := range activeSessions { + if session == s { + l := len(activeSessions) - 1 + activeSessions[i] = activeSessions[l] + activeSessions[l] = nil + activeSessions = activeSessions[:l] + } + } sessionMut.Unlock() - for _, conn := range conns { - conn.Close() - } + // If we are here because of case 2 or 3, we are potentially closing some or + // all connections a second time. + s.CloseConns() if debug { log.Println("Session", s, "stopping") } } -func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { +func (s *session) GetClientInvitationMessage() protocol.SessionInvitation { return protocol.SessionInvitation{ - From: from[:], + From: s.serverid[:], Key: []byte(s.clientkey), Address: sessionAddress, Port: sessionPort, @@ -167,9 +210,9 @@ func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) pr } } -func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { +func (s *session) GetServerInvitationMessage() protocol.SessionInvitation { return protocol.SessionInvitation{ - From: from[:], + From: s.clientid[:], Key: []byte(s.serverkey), Address: sessionAddress, Port: sessionPort, @@ -177,6 +220,18 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr } } +func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool { + return s.clientid == id || s.serverid == id +} + +func (s *session) CloseConns() { + s.mut.Lock() + for _, conn := range s.conns { + conn.Close() + } + s.mut.Unlock() +} + func (s *session) proxy(c1, c2 net.Conn) error { if debug { log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) diff --git a/status.go b/status.go index 53cea572..b18cf3ea 100644 --- a/status.go +++ b/status.go @@ -24,7 +24,9 @@ func getStatus(w http.ResponseWriter, r *http.Request) { status := make(map[string]interface{}) sessionMut.Lock() - status["numSessions"] = len(sessions) + // This can potentially be double the number of pending sessions, as each session has two keys, one for each side. + status["numPendingSessionKeys"] = len(pendingSessions) + status["numActiveSessions"] = len(activeSessions) sessionMut.Unlock() status["numConnections"] = atomic.LoadInt64(&numConnections) status["numProxies"] = atomic.LoadInt64(&numProxies) From fccb9c0bf4e1c4b7f9b3ab98ab23d31f0acbc4d7 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 14 Sep 2015 13:44:47 +0200 Subject: [PATCH 29/30] Server should respond to ping --- listener.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/listener.go b/listener.go index 2091b1dd..88fe8fd0 100644 --- a/listener.go +++ b/listener.go @@ -161,6 +161,15 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } conn.Close() + case protocol.Ping: + if err := protocol.WriteMessage(conn, protocol.Pong{}); err != nil { + if debug { + log.Println("Error writing pong:", err) + } + conn.Close() + continue + } + case protocol.Pong: // Nothing From 22783d8f6c07c34e686e824819b0945528510291 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 14 Sep 2015 13:55:00 +0200 Subject: [PATCH 30/30] Connected clients should know their own latency --- client/client.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index 94e4eedd..89b16e00 100644 --- a/client/client.go +++ b/client/client.go @@ -32,6 +32,7 @@ type ProtocolClient struct { mut sync.RWMutex connected bool + latency time.Duration } func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient { @@ -168,6 +169,13 @@ func (c *ProtocolClient) StatusOK() bool { return con } +func (c *ProtocolClient) Latency() time.Duration { + c.mut.RLock() + lat := c.latency + c.mut.RUnlock() + return lat +} + func (c *ProtocolClient) String() string { return fmt.Sprintf("ProtocolClient@%p", c) } @@ -177,11 +185,21 @@ func (c *ProtocolClient) connect() error { return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme) } - conn, err := tls.Dial("tcp", c.URI.Host, c.config) + t0 := time.Now() + tcpConn, err := net.Dial("tcp", c.URI.Host) if err != nil { return err } + c.mut.Lock() + c.latency = time.Since(t0) + c.mut.Unlock() + + conn := tls.Client(tcpConn, c.config) + if err = conn.Handshake(); err != nil { + return err + } + if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { conn.Close() return err