Merge pull request #1699 from calmh/connsvc

Break out connection handling into a service
This commit is contained in:
Audrius Butkevicius 2015-04-25 15:37:08 +01:00
commit ecc8591c95
2 changed files with 92 additions and 30 deletions

View File

@ -15,23 +15,84 @@ import (
"time" "time"
"github.com/syncthing/protocol" "github.com/syncthing/protocol"
"github.com/syncthing/syncthing/internal/config"
"github.com/syncthing/syncthing/internal/events" "github.com/syncthing/syncthing/internal/events"
"github.com/syncthing/syncthing/internal/model" "github.com/syncthing/syncthing/internal/model"
"github.com/thejerf/suture"
) )
func listenConnect(myID protocol.DeviceID, m *model.Model, tlsCfg *tls.Config) { // The connection service listens on TLS and dials configured unconnected
var conns = make(chan *tls.Conn) // devices. Successfull connections are handed to the model.
type connectionSvc struct {
*suture.Supervisor
cfg *config.Wrapper
myID protocol.DeviceID
model *model.Model
tlsCfg *tls.Config
conns chan *tls.Conn
}
// Listen func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, model *model.Model, tlsCfg *tls.Config) *connectionSvc {
for _, addr := range cfg.Options().ListenAddress { svc := &connectionSvc{
go listenTLS(conns, addr, tlsCfg) Supervisor: suture.NewSimple("connectionSvc"),
cfg: cfg,
myID: myID,
model: model,
tlsCfg: tlsCfg,
conns: make(chan *tls.Conn),
} }
// Connect // There are several moving parts here; one routine per listening address
go dialTLS(m, conns, tlsCfg) // to handle incoming connections, one routine to periodically attempt
// outgoing connections, and lastly one routine to the the common handling
// regardless of whether the connection was incoming or outgoing. It ends
// up as in the diagram below. We embed a Supervisor to manage the
// routines (i.e. log and restart if they crash or exit, etc).
//
// +-----------------+
// Incoming | +---------------+-+ +-----------------+
// Connections | | | | | Outgoing
// -------------->| | svc.listen | | | Connections
// | | (1 per listen | | svc.connect |-------------->
// | | address) | | |
// +-+ | | |
// +-----------------+ +-----------------+
// v v
// | |
// | |
// +------------+-----------+
// |
// | svc.conns
// v
// +-----------------+
// | |
// | |
// | svc.handle |------> model.AddConnection()
// | |
// | |
// +-----------------+
//
// TODO: Clean shutdown, and/or handling config changes on the fly. We
// partly do this now - new devices and addresses will be picked up, but
// not new listen addresses and we don't support disconnecting devices
// that are removed and so on...
svc.Add(serviceFunc(svc.connect))
for _, addr := range svc.cfg.Options().ListenAddress {
addr := addr
listener := serviceFunc(func() {
svc.listen(addr)
})
svc.Add(listener)
}
svc.Add(serviceFunc(svc.handle))
return svc
}
func (s *connectionSvc) handle() {
next: next:
for conn := range conns { for conn := range s.conns {
cs := conn.ConnectionState() cs := conn.ConnectionState()
// We should have negotiated the next level protocol "bep/1.0" as part // We should have negotiated the next level protocol "bep/1.0" as part
@ -69,13 +130,13 @@ next:
// this one. But in case we are two devices connecting to each other // this one. But in case we are two devices connecting to each other
// in parallell we don't want to do that or we end up with no // in parallell we don't want to do that or we end up with no
// connections still established... // connections still established...
if m.ConnectedTo(remoteID) { if s.model.ConnectedTo(remoteID) {
l.Infof("Connected to already connected device (%s)", remoteID) l.Infof("Connected to already connected device (%s)", remoteID)
conn.Close() conn.Close()
continue continue
} }
for deviceID, deviceCfg := range cfg.Devices() { for deviceID, deviceCfg := range s.cfg.Devices() {
if deviceID == remoteID { if deviceID == remoteID {
// Verify the name on the certificate. By default we set it to // Verify the name on the certificate. By default we set it to
// "syncthing" when generating, but the user may have replaced // "syncthing" when generating, but the user may have replaced
@ -97,7 +158,7 @@ next:
// If rate limiting is set, and based on the address we should // If rate limiting is set, and based on the address we should
// limit the connection, then we wrap it in a limiter. // limit the connection, then we wrap it in a limiter.
limit := shouldLimit(conn.RemoteAddr()) limit := s.shouldLimit(conn.RemoteAddr())
wr := io.Writer(conn) wr := io.Writer(conn)
if limit && writeRateLimit != nil { if limit && writeRateLimit != nil {
@ -110,7 +171,7 @@ next:
} }
name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr()) name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr())
protoConn := protocol.NewConnection(remoteID, rd, wr, m, name, deviceCfg.Compression) protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
l.Infof("Established secure connection to %s at %s", remoteID, name) l.Infof("Established secure connection to %s at %s", remoteID, name)
if debugNet { if debugNet {
@ -121,12 +182,12 @@ next:
"addr": conn.RemoteAddr().String(), "addr": conn.RemoteAddr().String(),
}) })
m.AddConnection(conn, protoConn) s.model.AddConnection(conn, protoConn)
continue next continue next
} }
} }
if !cfg.IgnoredDevice(remoteID) { if !s.cfg.IgnoredDevice(remoteID) {
events.Default.Log(events.DeviceRejected, map[string]string{ events.Default.Log(events.DeviceRejected, map[string]string{
"device": remoteID.String(), "device": remoteID.String(),
"address": conn.RemoteAddr().String(), "address": conn.RemoteAddr().String(),
@ -140,7 +201,7 @@ next:
} }
} }
func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) { func (s *connectionSvc) listen(addr string) {
if debugNet { if debugNet {
l.Debugln("listening on", addr) l.Debugln("listening on", addr)
} }
@ -166,9 +227,9 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
} }
tcpConn := conn.(*net.TCPConn) tcpConn := conn.(*net.TCPConn)
setTCPOptions(tcpConn) s.setTCPOptions(tcpConn)
tc := tls.Server(conn, tlsCfg) tc := tls.Server(conn, s.tlsCfg)
err = tc.Handshake() err = tc.Handshake()
if err != nil { if err != nil {
l.Infoln("TLS handshake:", err) l.Infoln("TLS handshake:", err)
@ -176,21 +237,20 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
continue continue
} }
conns <- tc s.conns <- tc
} }
} }
func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) { func (s *connectionSvc) connect() {
delay := time.Second delay := time.Second
for { for {
nextDevice: nextDevice:
for deviceID, deviceCfg := range cfg.Devices() { for deviceID, deviceCfg := range s.cfg.Devices() {
if deviceID == myID { if deviceID == myID {
continue continue
} }
if m.ConnectedTo(deviceID) { if s.model.ConnectedTo(deviceID) {
continue continue
} }
@ -238,9 +298,9 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
continue continue
} }
setTCPOptions(conn) s.setTCPOptions(conn)
tc := tls.Client(conn, tlsCfg) tc := tls.Client(conn, s.tlsCfg)
err = tc.Handshake() err = tc.Handshake()
if err != nil { if err != nil {
l.Infoln("TLS handshake:", err) l.Infoln("TLS handshake:", err)
@ -248,20 +308,20 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
continue continue
} }
conns <- tc s.conns <- tc
continue nextDevice continue nextDevice
} }
} }
time.Sleep(delay) time.Sleep(delay)
delay *= 2 delay *= 2
if maxD := time.Duration(cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD { if maxD := time.Duration(s.cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD {
delay = maxD delay = maxD
} }
} }
} }
func setTCPOptions(conn *net.TCPConn) { func (*connectionSvc) setTCPOptions(conn *net.TCPConn) {
var err error var err error
if err = conn.SetLinger(0); err != nil { if err = conn.SetLinger(0); err != nil {
l.Infoln(err) l.Infoln(err)
@ -277,8 +337,8 @@ func setTCPOptions(conn *net.TCPConn) {
} }
} }
func shouldLimit(addr net.Addr) bool { func (s *connectionSvc) shouldLimit(addr net.Addr) bool {
if cfg.Options().LimitBandwidthInLan { if s.cfg.Options().LimitBandwidthInLan {
return true return true
} }

View File

@ -584,7 +584,9 @@ func syncthingMain() {
// Routine to connect out to configured devices // Routine to connect out to configured devices
discoverer = discovery(externalPort) discoverer = discovery(externalPort)
go listenConnect(myID, m, tlsCfg)
connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg)
mainSvc.Add(connectionSvc)
for _, folder := range cfg.Folders() { for _, folder := range cfg.Folders() {
// Routine to pull blocks from other devices to synchronize the local // Routine to pull blocks from other devices to synchronize the local