This commit is contained in:
Audrius Butkevicius 2015-06-28 01:52:01 +01:00
parent 8e191c8e6b
commit b72d31f87f
13 changed files with 1114 additions and 319 deletions

6
README.md Normal file
View File

@ -0,0 +1,6 @@
relaysrv
========
This is the relay server for the `syncthing` project.
`go get github.com/syncthing/relaysrv`

249
client/client.go Normal file
View File

@ -0,0 +1,249 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package client
import (
"crypto/tls"
"fmt"
"log"
"net"
"net/url"
"time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/relaysrv/protocol"
)
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient {
closeInvitationsOnFinish := false
if invitations == nil {
closeInvitationsOnFinish = true
invitations = make(chan protocol.SessionInvitation)
}
return ProtocolClient{
URI: uri,
Invitations: invitations,
closeInvitationsOnFinish: closeInvitationsOnFinish,
config: configForCerts(certs),
timeout: time.Minute * 2,
stop: make(chan struct{}),
stopped: make(chan struct{}),
}
}
type ProtocolClient struct {
URI *url.URL
Invitations chan protocol.SessionInvitation
closeInvitationsOnFinish bool
config *tls.Config
timeout time.Duration
stop chan struct{}
stopped chan struct{}
conn *tls.Conn
}
func (c *ProtocolClient) connect() error {
conn, err := tls.Dial("tcp", c.URI.Host, c.config)
if err != nil {
return err
}
conn.SetDeadline(time.Now().Add(10 * time.Second))
if err := performHandshakeAndValidation(conn, c.URI); err != nil {
return err
}
c.conn = conn
return nil
}
func (c *ProtocolClient) Serve() {
if err := c.connect(); err != nil {
panic(err)
}
if debug {
l.Debugln(c, "connected", c.conn.RemoteAddr())
}
if err := c.join(); err != nil {
c.conn.Close()
panic(err)
}
c.conn.SetDeadline(time.Time{})
if debug {
l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
}
c.stop = make(chan struct{})
c.stopped = make(chan struct{})
defer c.cleanup()
messages := make(chan interface{})
errors := make(chan error, 1)
go func(conn net.Conn, message chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}(c.conn, messages, errors)
timeout := time.NewTimer(c.timeout)
for {
select {
case message := <-messages:
timeout.Reset(c.timeout)
if debug {
log.Printf("%s received message %T", c, message)
}
switch msg := message.(type) {
case protocol.Ping:
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
panic(err)
}
if debug {
l.Debugln(c, "sent pong")
}
case protocol.SessionInvitation:
ip := net.IP(msg.Address)
if len(ip) == 0 || ip.IsUnspecified() {
msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
}
c.Invitations <- msg
default:
panic(fmt.Errorf("protocol error: unexpected message %v", msg))
}
case <-c.stop:
if debug {
l.Debugln(c, "stopping")
}
break
case err := <-errors:
panic(err)
case <-timeout.C:
if debug {
l.Debugln(c, "timed out")
}
return
}
}
c.stopped <- struct{}{}
}
func (c *ProtocolClient) Stop() {
if c.stop == nil {
return
}
c.stop <- struct{}{}
<-c.stopped
}
func (c *ProtocolClient) String() string {
return fmt.Sprintf("ProtocolClient@%p", c)
}
func (c *ProtocolClient) cleanup() {
if c.closeInvitationsOnFinish {
close(c.Invitations)
c.Invitations = make(chan protocol.SessionInvitation)
}
if debug {
l.Debugln(c, "cleaning up")
}
if c.stop != nil {
close(c.stop)
c.stop = nil
}
if c.stopped != nil {
close(c.stopped)
c.stopped = nil
}
if c.conn != nil {
c.conn.Close()
}
}
func (c *ProtocolClient) join() error {
err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{})
if err != nil {
return err
}
message, err := protocol.ReadMessage(c.conn)
if err != nil {
return err
}
switch msg := message.(type) {
case protocol.Response:
if msg.Code != 0 {
return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
}
default:
return fmt.Errorf("protocol error: expecting response got %v", msg)
}
return nil
}
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
err := conn.Handshake()
if err != nil {
conn.Close()
return err
}
cs := conn.ConnectionState()
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName {
conn.Close()
return fmt.Errorf("protocol negotiation error")
}
q := uri.Query()
relayIDs := q.Get("id")
if relayIDs != "" {
relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs)
if err != nil {
conn.Close()
return fmt.Errorf("relay address contains invalid verification id: %s", err)
}
certs := cs.PeerCertificates
if cl := len(certs); cl != 1 {
conn.Close()
return fmt.Errorf("unexpected certificate count: %d", cl)
}
remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw)
if remoteID != relayID {
conn.Close()
return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID)
}
}
return nil
}

15
client/debug.go Normal file
View File

@ -0,0 +1,15 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package client
import (
"os"
"strings"
"github.com/calmh/logger"
)
var (
debug = strings.Contains(os.Getenv("STTRACE"), "relay") || os.Getenv("STTRACE") == "all"
l = logger.DefaultLogger
)

113
client/methods.go Normal file
View File

@ -0,0 +1,113 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package client
import (
"crypto/tls"
"fmt"
"net"
"net/url"
"strconv"
"time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/relaysrv/protocol"
)
func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate) (protocol.SessionInvitation, error) {
conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs))
conn.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return protocol.SessionInvitation{}, err
}
if err := performHandshakeAndValidation(conn, uri); err != nil {
return protocol.SessionInvitation{}, err
}
defer conn.Close()
request := protocol.ConnectRequest{
ID: id[:],
}
if err := protocol.WriteMessage(conn, request); err != nil {
return protocol.SessionInvitation{}, err
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return protocol.SessionInvitation{}, err
}
switch msg := message.(type) {
case protocol.Response:
return protocol.SessionInvitation{}, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
case protocol.SessionInvitation:
if debug {
l.Debugln("Received invitation via", conn.LocalAddr())
}
ip := net.IP(msg.Address)
if len(ip) == 0 || ip.IsUnspecified() {
msg.Address = conn.RemoteAddr().(*net.TCPAddr).IP[:]
}
return msg, nil
default:
return protocol.SessionInvitation{}, fmt.Errorf("protocol error: unexpected message %v", msg)
}
}
func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) {
addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port)))
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
request := protocol.JoinSessionRequest{
Key: invitation.Key,
}
conn.SetDeadline(time.Now().Add(10 * time.Second))
err = protocol.WriteMessage(conn, request)
if err != nil {
return nil, err
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Time{})
switch msg := message.(type) {
case protocol.Response:
if msg.Code != 0 {
return nil, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
}
return conn, nil
default:
return nil, fmt.Errorf("protocol error: expecting response got %v", msg)
}
}
func configForCerts(certs []tls.Certificate) *tls.Config {
return &tls.Config{
Certificates: certs,
NextProtos: []string{protocol.ProtocolName},
ClientAuth: tls.RequestClientCert,
SessionTicketsDisabled: true,
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
},
}
}

39
main.go
View File

@ -6,13 +6,13 @@ import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"log" "log"
"os" "net"
"path/filepath" "path/filepath"
"sync"
"time" "time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
syncthingprotocol "github.com/syncthing/protocol"
) )
var ( var (
@ -26,26 +26,11 @@ var (
networkTimeout time.Duration networkTimeout time.Duration
pingInterval time.Duration pingInterval time.Duration
messageTimeout time.Duration messageTimeout time.Duration
pingMessage message
mut = sync.RWMutex{}
outbox = make(map[syncthingprotocol.DeviceID]chan message)
) )
func main() { func main() {
var dir, extAddress string var dir, extAddress string
pingPayload := protocol.Ping{}.MustMarshalXDR()
pingMessage = message{
header: protocol.Header{
Magic: protocol.Magic,
MessageType: protocol.MessageTypePing,
MessageLength: int32(len(pingPayload)),
},
payload: pingPayload,
}
flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address")
flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address")
flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection")
@ -54,7 +39,20 @@ func main() {
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive")
if extAddress == "" {
extAddress = listenSession
}
addr, err := net.ResolveTCPAddr("tcp", extAddress)
if err != nil {
log.Fatal(err)
}
sessionAddress = addr.IP[:]
sessionPort = uint16(addr.Port)
flag.BoolVar(&debug, "debug", false, "Enable debug output") flag.BoolVar(&debug, "debug", false, "Enable debug output")
flag.Parse() flag.Parse()
certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
@ -80,7 +78,10 @@ func main() {
}, },
} }
log.SetOutput(os.Stdout) id := syncthingprotocol.NewDeviceID(cert.Certificate[0])
if debug {
log.Println("ID:", id)
}
go sessionListener(listenSession) go sessionListener(listenSession)

View File

@ -5,39 +5,41 @@
package protocol package protocol
import (
"unsafe"
)
const ( const (
Magic = 0x9E79BC40 messageTypePing int32 = iota
HeaderSize = unsafe.Sizeof(&Header{}) messageTypePong
ProtocolName = "bep-relay" messageTypeJoinRelayRequest
messageTypeJoinSessionRequest
messageTypeResponse
messageTypeConnectRequest
messageTypeSessionInvitation
) )
const ( type header struct {
MessageTypePing int32 = iota magic uint32
MessageTypePong messageType int32
MessageTypeJoinRequest messageLength int32
MessageTypeConnectRequest
MessageTypeSessionInvitation
)
type Header struct {
Magic uint32
MessageType int32
MessageLength int32
} }
type Ping struct{} type Ping struct{}
type Pong struct{} type Pong struct{}
type JoinRequest struct{} type JoinRelayRequest struct{}
type JoinSessionRequest struct {
Key []byte // max:32
}
type Response struct {
Code int32
Message string
}
type ConnectRequest struct { type ConnectRequest struct {
ID []byte // max:32 ID []byte // max:32
} }
type SessionInvitation struct { type SessionInvitation struct {
From []byte // max:32
Key []byte // max:32 Key []byte // max:32
Address []byte // max:32 Address []byte // max:32
Port uint16 Port uint16

View File

@ -13,37 +13,37 @@ import (
/* /*
Header Structure: header Structure:
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Magic | | magic |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Message Type | | message Type |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Message Length | | message Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
struct Header { struct header {
unsigned int Magic; unsigned int magic;
int MessageType; int messageType;
int MessageLength; int messageLength;
} }
*/ */
func (o Header) EncodeXDR(w io.Writer) (int, error) { func (o header) EncodeXDR(w io.Writer) (int, error) {
var xw = xdr.NewWriter(w) var xw = xdr.NewWriter(w)
return o.EncodeXDRInto(xw) return o.EncodeXDRInto(xw)
} }
func (o Header) MarshalXDR() ([]byte, error) { func (o header) MarshalXDR() ([]byte, error) {
return o.AppendXDR(make([]byte, 0, 128)) return o.AppendXDR(make([]byte, 0, 128))
} }
func (o Header) MustMarshalXDR() []byte { func (o header) MustMarshalXDR() []byte {
bs, err := o.MarshalXDR() bs, err := o.MarshalXDR()
if err != nil { if err != nil {
panic(err) panic(err)
@ -51,35 +51,35 @@ func (o Header) MustMarshalXDR() []byte {
return bs return bs
} }
func (o Header) AppendXDR(bs []byte) ([]byte, error) { func (o header) AppendXDR(bs []byte) ([]byte, error) {
var aw = xdr.AppendWriter(bs) var aw = xdr.AppendWriter(bs)
var xw = xdr.NewWriter(&aw) var xw = xdr.NewWriter(&aw)
_, err := o.EncodeXDRInto(xw) _, err := o.EncodeXDRInto(xw)
return []byte(aw), err return []byte(aw), err
} }
func (o Header) EncodeXDRInto(xw *xdr.Writer) (int, error) { func (o header) EncodeXDRInto(xw *xdr.Writer) (int, error) {
xw.WriteUint32(o.Magic) xw.WriteUint32(o.magic)
xw.WriteUint32(uint32(o.MessageType)) xw.WriteUint32(uint32(o.messageType))
xw.WriteUint32(uint32(o.MessageLength)) xw.WriteUint32(uint32(o.messageLength))
return xw.Tot(), xw.Error() return xw.Tot(), xw.Error()
} }
func (o *Header) DecodeXDR(r io.Reader) error { func (o *header) DecodeXDR(r io.Reader) error {
xr := xdr.NewReader(r) xr := xdr.NewReader(r)
return o.DecodeXDRFrom(xr) return o.DecodeXDRFrom(xr)
} }
func (o *Header) UnmarshalXDR(bs []byte) error { func (o *header) UnmarshalXDR(bs []byte) error {
var br = bytes.NewReader(bs) var br = bytes.NewReader(bs)
var xr = xdr.NewReader(br) var xr = xdr.NewReader(br)
return o.DecodeXDRFrom(xr) return o.DecodeXDRFrom(xr)
} }
func (o *Header) DecodeXDRFrom(xr *xdr.Reader) error { func (o *header) DecodeXDRFrom(xr *xdr.Reader) error {
o.Magic = xr.ReadUint32() o.magic = xr.ReadUint32()
o.MessageType = int32(xr.ReadUint32()) o.messageType = int32(xr.ReadUint32())
o.MessageLength = int32(xr.ReadUint32()) o.messageLength = int32(xr.ReadUint32())
return xr.Error() return xr.Error()
} }
@ -199,28 +199,28 @@ func (o *Pong) DecodeXDRFrom(xr *xdr.Reader) error {
/* /*
JoinRequest Structure: JoinRelayRequest Structure:
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
struct JoinRequest { struct JoinRelayRequest {
} }
*/ */
func (o JoinRequest) EncodeXDR(w io.Writer) (int, error) { func (o JoinRelayRequest) EncodeXDR(w io.Writer) (int, error) {
var xw = xdr.NewWriter(w) var xw = xdr.NewWriter(w)
return o.EncodeXDRInto(xw) return o.EncodeXDRInto(xw)
} }
func (o JoinRequest) MarshalXDR() ([]byte, error) { func (o JoinRelayRequest) MarshalXDR() ([]byte, error) {
return o.AppendXDR(make([]byte, 0, 128)) return o.AppendXDR(make([]byte, 0, 128))
} }
func (o JoinRequest) MustMarshalXDR() []byte { func (o JoinRelayRequest) MustMarshalXDR() []byte {
bs, err := o.MarshalXDR() bs, err := o.MarshalXDR()
if err != nil { if err != nil {
panic(err) panic(err)
@ -228,29 +228,169 @@ func (o JoinRequest) MustMarshalXDR() []byte {
return bs return bs
} }
func (o JoinRequest) AppendXDR(bs []byte) ([]byte, error) { func (o JoinRelayRequest) AppendXDR(bs []byte) ([]byte, error) {
var aw = xdr.AppendWriter(bs) var aw = xdr.AppendWriter(bs)
var xw = xdr.NewWriter(&aw) var xw = xdr.NewWriter(&aw)
_, err := o.EncodeXDRInto(xw) _, err := o.EncodeXDRInto(xw)
return []byte(aw), err return []byte(aw), err
} }
func (o JoinRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { func (o JoinRelayRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) {
return xw.Tot(), xw.Error() return xw.Tot(), xw.Error()
} }
func (o *JoinRequest) DecodeXDR(r io.Reader) error { func (o *JoinRelayRequest) DecodeXDR(r io.Reader) error {
xr := xdr.NewReader(r) xr := xdr.NewReader(r)
return o.DecodeXDRFrom(xr) return o.DecodeXDRFrom(xr)
} }
func (o *JoinRequest) UnmarshalXDR(bs []byte) error { func (o *JoinRelayRequest) UnmarshalXDR(bs []byte) error {
var br = bytes.NewReader(bs) var br = bytes.NewReader(bs)
var xr = xdr.NewReader(br) var xr = xdr.NewReader(br)
return o.DecodeXDRFrom(xr) return o.DecodeXDRFrom(xr)
} }
func (o *JoinRequest) DecodeXDRFrom(xr *xdr.Reader) error { func (o *JoinRelayRequest) DecodeXDRFrom(xr *xdr.Reader) error {
return xr.Error()
}
/*
JoinSessionRequest Structure:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of Key |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ Key (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
struct JoinSessionRequest {
opaque Key<32>;
}
*/
func (o JoinSessionRequest) EncodeXDR(w io.Writer) (int, error) {
var xw = xdr.NewWriter(w)
return o.EncodeXDRInto(xw)
}
func (o JoinSessionRequest) MarshalXDR() ([]byte, error) {
return o.AppendXDR(make([]byte, 0, 128))
}
func (o JoinSessionRequest) MustMarshalXDR() []byte {
bs, err := o.MarshalXDR()
if err != nil {
panic(err)
}
return bs
}
func (o JoinSessionRequest) AppendXDR(bs []byte) ([]byte, error) {
var aw = xdr.AppendWriter(bs)
var xw = xdr.NewWriter(&aw)
_, err := o.EncodeXDRInto(xw)
return []byte(aw), err
}
func (o JoinSessionRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) {
if l := len(o.Key); l > 32 {
return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32)
}
xw.WriteBytes(o.Key)
return xw.Tot(), xw.Error()
}
func (o *JoinSessionRequest) DecodeXDR(r io.Reader) error {
xr := xdr.NewReader(r)
return o.DecodeXDRFrom(xr)
}
func (o *JoinSessionRequest) UnmarshalXDR(bs []byte) error {
var br = bytes.NewReader(bs)
var xr = xdr.NewReader(br)
return o.DecodeXDRFrom(xr)
}
func (o *JoinSessionRequest) DecodeXDRFrom(xr *xdr.Reader) error {
o.Key = xr.ReadBytesMax(32)
return xr.Error()
}
/*
Response Structure:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Code |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of Message |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ Message (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
struct Response {
int Code;
string Message<>;
}
*/
func (o Response) EncodeXDR(w io.Writer) (int, error) {
var xw = xdr.NewWriter(w)
return o.EncodeXDRInto(xw)
}
func (o Response) MarshalXDR() ([]byte, error) {
return o.AppendXDR(make([]byte, 0, 128))
}
func (o Response) MustMarshalXDR() []byte {
bs, err := o.MarshalXDR()
if err != nil {
panic(err)
}
return bs
}
func (o Response) AppendXDR(bs []byte) ([]byte, error) {
var aw = xdr.AppendWriter(bs)
var xw = xdr.NewWriter(&aw)
_, err := o.EncodeXDRInto(xw)
return []byte(aw), err
}
func (o Response) EncodeXDRInto(xw *xdr.Writer) (int, error) {
xw.WriteUint32(uint32(o.Code))
xw.WriteString(o.Message)
return xw.Tot(), xw.Error()
}
func (o *Response) DecodeXDR(r io.Reader) error {
xr := xdr.NewReader(r)
return o.DecodeXDRFrom(xr)
}
func (o *Response) UnmarshalXDR(bs []byte) error {
var br = bytes.NewReader(bs)
var xr = xdr.NewReader(br)
return o.DecodeXDRFrom(xr)
}
func (o *Response) DecodeXDRFrom(xr *xdr.Reader) error {
o.Code = int32(xr.ReadUint32())
o.Message = xr.ReadString()
return xr.Error() return xr.Error()
} }
@ -330,6 +470,12 @@ SessionInvitation Structure:
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of From |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ From (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length of Key | | Length of Key |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ / / /
@ -349,6 +495,7 @@ SessionInvitation Structure:
struct SessionInvitation { struct SessionInvitation {
opaque From<32>;
opaque Key<32>; opaque Key<32>;
opaque Address<32>; opaque Address<32>;
unsigned int Port; unsigned int Port;
@ -382,6 +529,10 @@ func (o SessionInvitation) AppendXDR(bs []byte) ([]byte, error) {
} }
func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) { func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) {
if l := len(o.From); l > 32 {
return xw.Tot(), xdr.ElementSizeExceeded("From", l, 32)
}
xw.WriteBytes(o.From)
if l := len(o.Key); l > 32 { if l := len(o.Key); l > 32 {
return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32)
} }
@ -407,6 +558,7 @@ func (o *SessionInvitation) UnmarshalXDR(bs []byte) error {
} }
func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error { func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error {
o.From = xr.ReadBytesMax(32)
o.Key = xr.ReadBytesMax(32) o.Key = xr.ReadBytesMax(32)
o.Address = xr.ReadBytesMax(32) o.Address = xr.ReadBytesMax(32)
o.Port = xr.ReadUint16() o.Port = xr.ReadUint16()

114
protocol/protocol.go Normal file
View File

@ -0,0 +1,114 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package protocol
import (
"fmt"
"io"
)
const (
magic = 0x9E79BC40
ProtocolName = "bep-relay"
)
var (
ResponseSuccess = Response{0, "success"}
ResponseNotFound = Response{1, "not found"}
ResponseAlreadyConnected = Response{2, "already connected"}
ResponseInternalError = Response{99, "internal error"}
ResponseUnexpectedMessage = Response{100, "unexpected message"}
)
func WriteMessage(w io.Writer, message interface{}) error {
header := header{
magic: magic,
}
var payload []byte
var err error
switch msg := message.(type) {
case Ping:
payload, err = msg.MarshalXDR()
header.messageType = messageTypePing
case Pong:
payload, err = msg.MarshalXDR()
header.messageType = messageTypePong
case JoinRelayRequest:
payload, err = msg.MarshalXDR()
header.messageType = messageTypeJoinRelayRequest
case JoinSessionRequest:
payload, err = msg.MarshalXDR()
header.messageType = messageTypeJoinSessionRequest
case Response:
payload, err = msg.MarshalXDR()
header.messageType = messageTypeResponse
case ConnectRequest:
payload, err = msg.MarshalXDR()
header.messageType = messageTypeConnectRequest
case SessionInvitation:
payload, err = msg.MarshalXDR()
header.messageType = messageTypeSessionInvitation
default:
err = fmt.Errorf("Unknown message type")
}
if err != nil {
return err
}
header.messageLength = int32(len(payload))
headerpayload, err := header.MarshalXDR()
if err != nil {
return err
}
_, err = w.Write(append(headerpayload, payload...))
return err
}
func ReadMessage(r io.Reader) (interface{}, error) {
var header header
if err := header.DecodeXDR(r); err != nil {
return nil, err
}
if header.magic != magic {
return nil, fmt.Errorf("magic mismatch")
}
switch header.messageType {
case messageTypePing:
var msg Ping
err := msg.DecodeXDR(r)
return msg, err
case messageTypePong:
var msg Pong
err := msg.DecodeXDR(r)
return msg, err
case messageTypeJoinRelayRequest:
var msg JoinRelayRequest
err := msg.DecodeXDR(r)
return msg, err
case messageTypeJoinSessionRequest:
var msg JoinSessionRequest
err := msg.DecodeXDR(r)
return msg, err
case messageTypeResponse:
var msg Response
err := msg.DecodeXDR(r)
return msg, err
case messageTypeConnectRequest:
var msg ConnectRequest
err := msg.DecodeXDR(r)
return msg, err
case messageTypeSessionInvitation:
var msg SessionInvitation
err := msg.DecodeXDR(r)
return msg, err
}
return nil, fmt.Errorf("Unknown message type")
}

View File

@ -4,9 +4,9 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"io"
"log" "log"
"net" "net"
"sync"
"time" "time"
syncthingprotocol "github.com/syncthing/protocol" syncthingprotocol "github.com/syncthing/protocol"
@ -14,10 +14,10 @@ import (
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
) )
type message struct { var (
header protocol.Header outboxesMut = sync.RWMutex{}
payload []byte outboxes = make(map[syncthingprotocol.DeviceID]chan interface{})
} )
func protocolListener(addr string, config *tls.Config) { func protocolListener(addr string, config *tls.Config) {
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
@ -27,6 +27,7 @@ func protocolListener(addr string, config *tls.Config) {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
setTCPOptions(conn)
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println(err)
@ -43,15 +44,12 @@ func protocolListener(addr string, config *tls.Config) {
} }
func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
err := setTCPOptions(tcpConn)
if err != nil && debug {
log.Println("Failed to set TCP options on protocol connection", tcpConn.RemoteAddr(), err)
}
conn := tls.Server(tcpConn, config) conn := tls.Server(tcpConn, config)
err = conn.Handshake() err := conn.Handshake()
if err != nil { if err != nil {
if debug {
log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err)
}
conn.Close() conn.Close()
return return
} }
@ -63,168 +61,147 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
certs := state.PeerCertificates certs := state.PeerCertificates
if len(certs) != 1 { if len(certs) != 1 {
if debug {
log.Println("Certificate list error") log.Println("Certificate list error")
}
conn.Close() conn.Close()
return return
} }
deviceId := syncthingprotocol.NewDeviceID(certs[0].Raw) id := syncthingprotocol.NewDeviceID(certs[0].Raw)
mut.RLock() messages := make(chan interface{})
_, ok := outbox[deviceId] errors := make(chan error, 1)
mut.RUnlock() outbox := make(chan interface{})
if ok {
log.Println("Already have a peer with the same ID", deviceId, conn.RemoteAddr()) go func(conn net.Conn, message chan<- interface{}, errors chan<- error) {
conn.Close() for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return return
} }
messages <- msg
errorChannel := make(chan error) }
messageChannel := make(chan message) }(conn, messages, errors)
outboxChannel := make(chan message)
go readerLoop(conn, messageChannel, errorChannel)
pingTicker := time.NewTicker(pingInterval) pingTicker := time.NewTicker(pingInterval)
timeoutTicker := time.NewTimer(messageTimeout * 2) timeoutTicker := time.NewTimer(networkTimeout)
joined := false joined := false
for { for {
select { select {
case msg := <-messageChannel: case message := <-messages:
switch msg.header.MessageType { timeoutTicker.Reset(networkTimeout)
case protocol.MessageTypeJoinRequest: if debug {
mut.Lock() log.Printf("Message %T from %s", message, id)
outbox[deviceId] = outboxChannel }
mut.Unlock() switch msg := message.(type) {
joined = true case protocol.JoinRelayRequest:
case protocol.MessageTypeConnectRequest: outboxesMut.RLock()
// We will disconnect after this message, no matter what, _, ok := outboxes[id]
// because, we've either sent out an invitation, or we don't outboxesMut.RUnlock()
// have the peer available. if ok {
var fmsg protocol.ConnectRequest protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
err := fmsg.UnmarshalXDR(msg.payload) if debug {
if err != nil { log.Println("Already have a peer with the same ID", id, conn.RemoteAddr())
log.Println(err) }
conn.Close() conn.Close()
continue continue
} }
requestedPeer := syncthingprotocol.DeviceIDFromBytes(fmsg.ID) outboxesMut.Lock()
mut.RLock() outboxes[id] = outbox
peerOutbox, ok := outbox[requestedPeer] outboxesMut.Unlock()
mut.RUnlock() joined = true
protocol.WriteMessage(conn, protocol.ResponseSuccess)
case protocol.ConnectRequest:
requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
outboxesMut.RLock()
peerOutbox, ok := outboxes[requestedPeer]
outboxesMut.RUnlock()
if !ok { if !ok {
if debug { if debug {
log.Println("Do not have", requestedPeer) log.Println(id, "is looking", requestedPeer, "which does not exist")
} }
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close() conn.Close()
continue continue
} }
ses := newSession() ses := newSession()
smsg, err := ses.GetServerInvitationMessage()
if err != nil {
log.Println("Error getting server invitation", requestedPeer)
conn.Close()
continue
}
cmsg, err := ses.GetClientInvitationMessage()
if err != nil {
log.Println("Error getting client invitation", requestedPeer)
conn.Close()
continue
}
go ses.Serve() go ses.Serve()
if err := sendMessage(cmsg, conn); err != nil { clientInvitation := ses.GetClientInvitationMessage(requestedPeer)
log.Println("Failed to send invitation message", err) serverInvitation := ses.GetServerInvitationMessage(id)
} else {
peerOutbox <- smsg if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
if debug { if debug {
log.Println("Sent invitation from", deviceId, "to", requestedPeer) log.Printf("Error sending invitation from %s to client: %s", id, err)
}
} }
conn.Close() conn.Close()
case protocol.MessageTypePong: continue
timeoutTicker.Reset(messageTimeout)
} }
case err := <-errorChannel:
log.Println("Closing connection:", err) peerOutbox <- serverInvitation
if debug {
log.Println("Sent invitation from", id, "to", requestedPeer)
}
conn.Close()
case protocol.Pong:
default:
if debug {
log.Printf("Unknown message %s: %T", id, message)
}
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
}
case err := <-errors:
if debug {
log.Printf("Closing connection %s: %s", id, err)
}
// Potentially closing a second time.
close(outbox)
conn.Close()
outboxesMut.Lock()
delete(outboxes, id)
outboxesMut.Unlock()
return return
case <-pingTicker.C: case <-pingTicker.C:
if !joined { if !joined {
log.Println(deviceId, "didn't join within", messageTimeout) if debug {
log.Println(id, "didn't join within", pingInterval)
}
conn.Close() conn.Close()
continue continue
} }
if err := sendMessage(pingMessage, conn); err != nil { if err := protocol.WriteMessage(conn, protocol.Ping{}); err != nil {
log.Println(err) if debug {
log.Println(id, err)
}
conn.Close() conn.Close()
continue
} }
case <-timeoutTicker.C: case <-timeoutTicker.C:
// We should receive a error, which will cause us to quit the // We should receive a error from the reader loop, which will cause
// loop. // us to quit this loop.
conn.Close()
case msg := <-outboxChannel:
if debug { if debug {
log.Println("Sending message to", deviceId, msg) log.Printf("%s timed out", id)
} }
if err := sendMessage(msg, conn); err == nil {
log.Println(err)
conn.Close() conn.Close()
continue case msg := <-outbox:
if debug {
log.Printf("Sending message %T to %s", msg, id)
} }
if err := protocol.WriteMessage(conn, msg); err != nil {
if debug {
log.Println(id, err)
} }
}
}
func readerLoop(conn *tls.Conn, messages chan<- message, errors chan<- error) {
header := make([]byte, protocol.HeaderSize)
data := make([]byte, 0, 0)
for {
_, err := io.ReadFull(conn, header)
if err != nil {
errors <- err
conn.Close() conn.Close()
return }
} }
var hdr protocol.Header
err = hdr.UnmarshalXDR(header)
if err != nil {
conn.Close()
return
}
if hdr.Magic != protocol.Magic {
conn.Close()
return
}
if hdr.MessageLength > int32(cap(data)) {
data = make([]byte, 0, hdr.MessageLength)
} else {
data = data[:hdr.MessageLength]
}
_, err = io.ReadFull(conn, data)
if err != nil {
errors <- err
conn.Close()
return
}
msg := message{
header: hdr,
payload: make([]byte, hdr.MessageLength),
}
copy(msg.payload, data[:hdr.MessageLength])
messages <- msg
} }
} }

View File

@ -4,23 +4,27 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex"
"fmt"
"log"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
syncthingprotocol "github.com/syncthing/protocol"
) )
var ( var (
sessionmut = sync.Mutex{} sessionMut = sync.Mutex{}
sessions = make(map[string]*session, 0) sessions = make(map[string]*session, 0)
) )
type session struct { type session struct {
serverkey string serverkey []byte
clientkey string clientkey []byte
mut sync.RWMutex
conns chan net.Conn conns chan net.Conn
} }
@ -37,16 +41,27 @@ func newSession() *session {
return nil return nil
} }
return &session{ ses := &session{
serverkey: string(serverkey), serverkey: serverkey,
clientkey: string(clientkey), clientkey: clientkey,
conns: make(chan net.Conn), conns: make(chan net.Conn),
} }
if debug {
log.Println("New session", ses)
}
sessionMut.Lock()
sessions[string(ses.serverkey)] = ses
sessions[string(ses.clientkey)] = ses
sessionMut.Unlock()
return ses
} }
func findSession(key string) *session { func findSession(key string) *session {
sessionmut.Lock() sessionMut.Lock()
defer sessionmut.Unlock() defer sessionMut.Unlock()
lob, ok := sessions[key] lob, ok := sessions[key]
if !ok { if !ok {
return nil return nil
@ -56,118 +71,128 @@ func findSession(key string) *session {
return lob return lob
} }
func (l *session) AddConnection(conn net.Conn) { func (s *session) AddConnection(conn net.Conn) bool {
if debug {
log.Println("New connection for", s, "from", conn.RemoteAddr())
}
select { select {
case l.conns <- conn: case s.conns <- conn:
return true
default: default:
} }
return false
} }
func (l *session) Serve() { func (s *session) Serve() {
timedout := time.After(messageTimeout) timedout := time.After(messageTimeout)
sessionmut.Lock() if debug {
sessions[l.serverkey] = l log.Println("Session", s, "serving")
sessions[l.clientkey] = l }
sessionmut.Unlock()
conns := make([]net.Conn, 0, 2) conns := make([]net.Conn, 0, 2)
for { for {
select { select {
case conn := <-l.conns: case conn := <-s.conns:
conns = append(conns, conn) conns = append(conns, conn)
if len(conns) < 2 { if len(conns) < 2 {
continue continue
} }
close(l.conns) close(s.conns)
if debug {
log.Println("Session", s, "starting between", conns[0].RemoteAddr(), conns[1].RemoteAddr())
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go proxy(conns[0], conns[1], wg) errors := make(chan error, 2)
go proxy(conns[1], conns[0], wg)
go func() {
errors <- proxy(conns[0], conns[1])
wg.Done()
}()
go func() {
errors <- proxy(conns[1], conns[0])
wg.Done()
}()
wg.Wait() wg.Wait()
break if debug {
log.Println("Session", s, "ended, outcomes:", <-errors, <-errors)
}
goto done
case <-timedout: case <-timedout:
sessionmut.Lock() if debug {
delete(sessions, l.serverkey) log.Println("Session", s, "timed out")
delete(sessions, l.clientkey) }
sessionmut.Unlock() goto done
}
}
done:
sessionMut.Lock()
delete(sessions, string(s.serverkey))
delete(sessions, string(s.clientkey))
sessionMut.Unlock()
for _, conn := range conns { for _, conn := range conns {
conn.Close() conn.Close()
} }
break if debug {
} log.Println("Session", s, "stopping")
} }
} }
func (l *session) GetClientInvitationMessage() (message, error) { func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
invitation := protocol.SessionInvitation{ return protocol.SessionInvitation{
Key: []byte(l.clientkey), From: from[:],
Address: nil, Key: []byte(s.clientkey),
Port: 123, Address: sessionAddress,
Port: sessionPort,
ServerSocket: false, ServerSocket: false,
} }
data, err := invitation.MarshalXDR()
if err != nil {
return message{}, err
} }
return message{ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
header: protocol.Header{ return protocol.SessionInvitation{
Magic: protocol.Magic, From: from[:],
MessageType: protocol.MessageTypeSessionInvitation, Key: []byte(s.serverkey),
MessageLength: int32(len(data)), Address: sessionAddress,
}, Port: sessionPort,
payload: data,
}, nil
}
func (l *session) GetServerInvitationMessage() (message, error) {
invitation := protocol.SessionInvitation{
Key: []byte(l.serverkey),
Address: nil,
Port: 123,
ServerSocket: true, ServerSocket: true,
} }
data, err := invitation.MarshalXDR()
if err != nil {
return message{}, err
} }
return message{ func proxy(c1, c2 net.Conn) error {
header: protocol.Header{ if debug {
Magic: protocol.Magic, log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
MessageType: protocol.MessageTypeSessionInvitation,
MessageLength: int32(len(data)),
},
payload: data,
}, nil
} }
func proxy(c1, c2 net.Conn, wg sync.WaitGroup) {
for {
buf := make([]byte, 1024) buf := make([]byte, 1024)
for {
c1.SetReadDeadline(time.Now().Add(networkTimeout)) c1.SetReadDeadline(time.Now().Add(networkTimeout))
n, err := c1.Read(buf) n, err := c1.Read(buf[0:])
if err != nil { if err != nil {
break return err
}
if debug {
log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
} }
c2.SetWriteDeadline(time.Now().Add(networkTimeout)) c2.SetWriteDeadline(time.Now().Add(networkTimeout))
_, err = c2.Write(buf[:n]) _, err = c2.Write(buf[:n])
if err != nil { if err != nil {
break return err
} }
} }
c1.Close() }
c2.Close()
wg.Done() func (s *session) String() string {
return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
} }

View File

@ -3,10 +3,11 @@
package main package main
import ( import (
"io"
"log" "log"
"net" "net"
"time" "time"
"github.com/syncthing/relaysrv/protocol"
) )
func sessionListener(addr string) { func sessionListener(addr string) {
@ -17,6 +18,7 @@ func sessionListener(addr string) {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
setTCPOptions(conn)
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println(err)
@ -33,27 +35,49 @@ func sessionListener(addr string) {
} }
func sessionConnectionHandler(conn net.Conn) { func sessionConnectionHandler(conn net.Conn) {
conn.SetReadDeadline(time.Now().Add(messageTimeout)) conn.SetDeadline(time.Now().Add(messageTimeout))
key := make([]byte, 32) message, err := protocol.ReadMessage(conn)
if err != nil {
conn.Close()
return
}
_, err := io.ReadFull(conn, key) switch msg := message.(type) {
case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key))
if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses)
}
if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return
}
if !ses.AddConnection(conn) {
if debug {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
}
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return
}
err := protocol.WriteMessage(conn, protocol.ResponseSuccess)
if err != nil { if err != nil {
if debug { if debug {
log.Println("Failed to read key", err, conn.RemoteAddr()) log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
} }
conn.Close() conn.Close()
return return
} }
conn.SetDeadline(time.Time{})
ses := findSession(string(key)) default:
if debug { if debug {
log.Println("Key", key, "by", conn.RemoteAddr(), "session", ses) log.Println("Unexpected message from", conn.RemoteAddr(), message)
} }
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
if ses != nil {
ses.AddConnection(conn)
} else {
conn.Close() conn.Close()
return
} }
} }

142
testutil/main.go Normal file
View File

@ -0,0 +1,142 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package main
import (
"bufio"
"crypto/tls"
"flag"
"log"
"net"
"net/url"
"os"
"path/filepath"
"time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/relaysrv/client"
"github.com/syncthing/relaysrv/protocol"
)
func main() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
var connect, relay, dir string
var join bool
flag.StringVar(&connect, "connect", "", "Device ID to which to connect to")
flag.BoolVar(&join, "join", false, "Join relay")
flag.StringVar(&relay, "relay", "relay://127.0.0.1:22067", "Relay address")
flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored")
flag.Parse()
certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
log.Fatalln("Failed to load X509 key pair:", err)
}
id := syncthingprotocol.NewDeviceID(cert.Certificate[0])
log.Println("ID:", id)
uri, err := url.Parse(relay)
if err != nil {
log.Fatal(err)
}
stdin := make(chan string)
go stdinReader(stdin)
if join {
log.Printf("Creating client")
relay := client.NewProtocolClient(uri, []tls.Certificate{cert}, nil)
log.Printf("Created client")
go relay.Serve()
recv := make(chan protocol.SessionInvitation)
go func() {
log.Println("Starting invitation receiver")
for invite := range relay.Invitations {
select {
case recv <- invite:
log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port)
default:
log.Printf("Discarding invitation", invite)
}
}
}()
for {
conn, err := client.JoinSession(<-recv)
if err != nil {
log.Fatalln("Failed to join", err)
}
log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
connectToStdio(stdin, conn)
log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
}
} else if connect != "" {
id, err := syncthingprotocol.DeviceIDFromString(connect)
if err != nil {
log.Fatal(err)
}
invite, err := client.GetInvitationFromRelay(uri, id, []tls.Certificate{cert})
if err != nil {
log.Fatal(err)
}
log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port)
conn, err := client.JoinSession(invite)
if err != nil {
log.Fatalln("Failed to join", err)
}
log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
connectToStdio(stdin, conn)
log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
} else {
log.Fatal("Requires either join or connect")
}
}
func stdinReader(c chan<- string) {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
c <- scanner.Text()
c <- "\n"
}
}
func connectToStdio(stdin <-chan string, conn net.Conn) {
go func() {
}()
buf := make([]byte, 1024)
for {
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
n, err := conn.Read(buf[0:])
if err != nil {
nerr, ok := err.(net.Error)
if !ok || !nerr.Timeout() {
log.Println(err)
return
}
}
os.Stdout.Write(buf[:n])
select {
case msg := <-stdin:
_, err := conn.Write([]byte(msg))
if err != nil {
return
}
default:
}
}
}

View File

@ -5,7 +5,6 @@ package main
import ( import (
"errors" "errors"
"net" "net"
"time"
) )
func setTCPOptions(conn net.Conn) error { func setTCPOptions(conn net.Conn) error {
@ -19,7 +18,7 @@ func setTCPOptions(conn net.Conn) error {
if err := tcpConn.SetNoDelay(true); err != nil { if err := tcpConn.SetNoDelay(true); err != nil {
return err return err
} }
if err := tcpConn.SetKeepAlivePeriod(60 * time.Second); err != nil { if err := tcpConn.SetKeepAlivePeriod(networkTimeout); err != nil {
return err return err
} }
if err := tcpConn.SetKeepAlive(true); err != nil { if err := tcpConn.SetKeepAlive(true); err != nil {
@ -27,27 +26,3 @@ func setTCPOptions(conn net.Conn) error {
} }
return nil return nil
} }
func sendMessage(msg message, conn net.Conn) error {
header, err := msg.header.MarshalXDR()
if err != nil {
return err
}
err = conn.SetWriteDeadline(time.Now().Add(networkTimeout))
if err != nil {
return err
}
_, err = conn.Write(header)
if err != nil {
return err
}
_, err = conn.Write(msg.payload)
if err != nil {
return err
}
return nil
}