lib: Ensure timely service termination (fixes #5860) (#5863)

This commit is contained in:
Simon Frei 2019-07-19 19:40:40 +02:00 committed by Jakob Borg
parent 1cb55904bc
commit 4d3432af3e
8 changed files with 174 additions and 150 deletions

View File

@ -8,7 +8,6 @@ package beacon
import ( import (
"net" "net"
stdsync "sync"
"github.com/thejerf/suture" "github.com/thejerf/suture"
) )
@ -24,21 +23,3 @@ type Interface interface {
Recv() ([]byte, net.Addr) Recv() ([]byte, net.Addr)
Error() error Error() error
} }
type errorHolder struct {
err error
mut stdsync.Mutex // uses stdlib sync as I want this to be trivially embeddable, and there is no risk of blocking
}
func (e *errorHolder) setError(err error) {
e.mut.Lock()
e.err = err
e.mut.Unlock()
}
func (e *errorHolder) Error() error {
e.mut.Lock()
err := e.err
e.mut.Unlock()
return err
}

View File

@ -11,8 +11,9 @@ import (
"net" "net"
"time" "time"
"github.com/syncthing/syncthing/lib/sync"
"github.com/thejerf/suture" "github.com/thejerf/suture"
"github.com/syncthing/syncthing/lib/util"
) )
type Broadcast struct { type Broadcast struct {
@ -46,14 +47,14 @@ func NewBroadcast(port int) *Broadcast {
b.br = &broadcastReader{ b.br = &broadcastReader{
port: port, port: port,
outbox: b.outbox, outbox: b.outbox,
connMut: sync.NewMutex(),
} }
b.br.ServiceWithError = util.AsServiceWithError(b.br.serve)
b.Add(b.br) b.Add(b.br)
b.bw = &broadcastWriter{ b.bw = &broadcastWriter{
port: port, port: port,
inbox: b.inbox, inbox: b.inbox,
connMut: sync.NewMutex(),
} }
b.bw.ServiceWithError = util.AsServiceWithError(b.bw.serve)
b.Add(b.bw) b.Add(b.bw)
return b return b
@ -76,34 +77,42 @@ func (b *Broadcast) Error() error {
} }
type broadcastWriter struct { type broadcastWriter struct {
util.ServiceWithError
port int port int
inbox chan []byte inbox chan []byte
conn *net.UDPConn
connMut sync.Mutex
errorHolder
} }
func (w *broadcastWriter) Serve() { func (w *broadcastWriter) serve(stop chan struct{}) error {
l.Debugln(w, "starting") l.Debugln(w, "starting")
defer l.Debugln(w, "stopping") defer l.Debugln(w, "stopping")
conn, err := net.ListenUDP("udp4", nil) conn, err := net.ListenUDP("udp4", nil)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
w.setError(err) return err
return
} }
defer conn.Close() done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
w.connMut.Lock() for {
w.conn = conn var bs []byte
w.connMut.Unlock() select {
case bs = <-w.inbox:
case <-stop:
return nil
}
for bs := range w.inbox {
addrs, err := net.InterfaceAddrs() addrs, err := net.InterfaceAddrs()
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
w.setError(err) w.SetError(err)
continue continue
} }
@ -134,14 +143,13 @@ func (w *broadcastWriter) Serve() {
// Write timeouts should not happen. We treat it as a fatal // Write timeouts should not happen. We treat it as a fatal
// error on the socket. // error on the socket.
l.Debugln(err) l.Debugln(err)
w.setError(err) return err
return
} }
if err != nil { if err != nil {
// Some other error that we don't expect. Debug and continue. // Some other error that we don't expect. Debug and continue.
l.Debugln(err) l.Debugln(err)
w.setError(err) w.SetError(err)
continue continue
} }
@ -150,57 +158,49 @@ func (w *broadcastWriter) Serve() {
} }
if success > 0 { if success > 0 {
w.setError(nil) w.SetError(nil)
} }
} }
} }
func (w *broadcastWriter) Stop() {
w.connMut.Lock()
if w.conn != nil {
w.conn.Close()
}
w.connMut.Unlock()
}
func (w *broadcastWriter) String() string { func (w *broadcastWriter) String() string {
return fmt.Sprintf("broadcastWriter@%p", w) return fmt.Sprintf("broadcastWriter@%p", w)
} }
type broadcastReader struct { type broadcastReader struct {
util.ServiceWithError
port int port int
outbox chan recv outbox chan recv
conn *net.UDPConn
connMut sync.Mutex
errorHolder
} }
func (r *broadcastReader) Serve() { func (r *broadcastReader) serve(stop chan struct{}) error {
l.Debugln(r, "starting") l.Debugln(r, "starting")
defer l.Debugln(r, "stopping") defer l.Debugln(r, "stopping")
conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: r.port}) conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: r.port})
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) return err
return
} }
defer conn.Close() done := make(chan struct{})
defer close(done)
r.connMut.Lock() go func() {
r.conn = conn select {
r.connMut.Unlock() case <-stop:
case <-done:
}
conn.Close()
}()
bs := make([]byte, 65536) bs := make([]byte, 65536)
for { for {
n, addr, err := conn.ReadFrom(bs) n, addr, err := conn.ReadFrom(bs)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) return err
return
} }
r.setError(nil) r.SetError(nil)
l.Debugf("recv %d bytes from %s", n, addr) l.Debugf("recv %d bytes from %s", n, addr)
@ -208,19 +208,12 @@ func (r *broadcastReader) Serve() {
copy(c, bs) copy(c, bs)
select { select {
case r.outbox <- recv{c, addr}: case r.outbox <- recv{c, addr}:
case <-stop:
return nil
default: default:
l.Debugln("dropping message") l.Debugln("dropping message")
} }
} }
}
func (r *broadcastReader) Stop() {
r.connMut.Lock()
if r.conn != nil {
r.conn.Close()
}
r.connMut.Unlock()
} }
func (r *broadcastReader) String() string { func (r *broadcastReader) String() string {

View File

@ -48,14 +48,14 @@ func NewMulticast(addr string) *Multicast {
addr: addr, addr: addr,
outbox: m.outbox, outbox: m.outbox,
} }
m.mr.Service = util.AsService(m.mr.serve) m.mr.ServiceWithError = util.AsServiceWithError(m.mr.serve)
m.Add(m.mr) m.Add(m.mr)
m.mw = &multicastWriter{ m.mw = &multicastWriter{
addr: addr, addr: addr,
inbox: m.inbox, inbox: m.inbox,
} }
m.mw.Service = util.AsService(m.mw.serve) m.mw.ServiceWithError = util.AsServiceWithError(m.mw.serve)
m.Add(m.mw) m.Add(m.mw)
return m return m
@ -78,29 +78,35 @@ func (m *Multicast) Error() error {
} }
type multicastWriter struct { type multicastWriter struct {
suture.Service util.ServiceWithError
addr string addr string
inbox <-chan []byte inbox <-chan []byte
errorHolder
} }
func (w *multicastWriter) serve(stop chan struct{}) { func (w *multicastWriter) serve(stop chan struct{}) error {
l.Debugln(w, "starting") l.Debugln(w, "starting")
defer l.Debugln(w, "stopping") defer l.Debugln(w, "stopping")
gaddr, err := net.ResolveUDPAddr("udp6", w.addr) gaddr, err := net.ResolveUDPAddr("udp6", w.addr)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
w.setError(err) return err
return
} }
conn, err := net.ListenPacket("udp6", ":0") conn, err := net.ListenPacket("udp6", ":0")
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
w.setError(err) return err
return
} }
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
pconn := ipv6.NewPacketConn(conn) pconn := ipv6.NewPacketConn(conn)
@ -113,14 +119,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
select { select {
case bs = <-w.inbox: case bs = <-w.inbox:
case <-stop: case <-stop:
return return nil
} }
intfs, err := net.Interfaces() intfs, err := net.Interfaces()
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
w.setError(err) return err
return
} }
success := 0 success := 0
@ -132,7 +137,7 @@ func (w *multicastWriter) serve(stop chan struct{}) {
if err != nil { if err != nil {
l.Debugln(err, "on write to", gaddr, intf.Name) l.Debugln(err, "on write to", gaddr, intf.Name)
w.setError(err) w.SetError(err)
continue continue
} }
@ -142,16 +147,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
select { select {
case <-stop: case <-stop:
return return nil
default: default:
} }
} }
if success > 0 { if success > 0 {
w.setError(nil) w.SetError(nil)
} else {
l.Debugln(err)
w.setError(err)
} }
} }
} }
@ -161,35 +163,40 @@ func (w *multicastWriter) String() string {
} }
type multicastReader struct { type multicastReader struct {
suture.Service util.ServiceWithError
addr string addr string
outbox chan<- recv outbox chan<- recv
errorHolder
} }
func (r *multicastReader) serve(stop chan struct{}) { func (r *multicastReader) serve(stop chan struct{}) error {
l.Debugln(r, "starting") l.Debugln(r, "starting")
defer l.Debugln(r, "stopping") defer l.Debugln(r, "stopping")
gaddr, err := net.ResolveUDPAddr("udp6", r.addr) gaddr, err := net.ResolveUDPAddr("udp6", r.addr)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) return err
return
} }
conn, err := net.ListenPacket("udp6", r.addr) conn, err := net.ListenPacket("udp6", r.addr)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) return err
return
} }
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
intfs, err := net.Interfaces() intfs, err := net.Interfaces()
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) return err
return
} }
pconn := ipv6.NewPacketConn(conn) pconn := ipv6.NewPacketConn(conn)
@ -206,16 +213,20 @@ func (r *multicastReader) serve(stop chan struct{}) {
if joined == 0 { if joined == 0 {
l.Debugln("no multicast interfaces available") l.Debugln("no multicast interfaces available")
r.setError(errors.New("no multicast interfaces available")) return errors.New("no multicast interfaces available")
return
} }
bs := make([]byte, 65536) bs := make([]byte, 65536)
for { for {
select {
case <-stop:
return nil
default:
}
n, _, addr, err := pconn.ReadFrom(bs) n, _, addr, err := pconn.ReadFrom(bs)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
r.setError(err) r.SetError(err)
continue continue
} }
l.Debugf("recv %d bytes from %s", n, addr) l.Debugf("recv %d bytes from %s", n, addr)
@ -224,8 +235,6 @@ func (r *multicastReader) serve(stop chan struct{}) {
copy(c, bs) copy(c, bs)
select { select {
case r.outbox <- recv{c, addr}: case r.outbox <- recv{c, addr}:
case <-stop:
return
default: default:
l.Debugln("dropping message") l.Debugln("dropping message")
} }

View File

@ -19,7 +19,7 @@ func Register(provider DiscoverFunc) {
providers = append(providers, provider) providers = append(providers, provider)
} }
func discoverAll(renewal, timeout time.Duration) map[string]Device { func discoverAll(renewal, timeout time.Duration, stop chan struct{}) map[string]Device {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(len(providers)) wg.Add(len(providers))
@ -28,20 +28,32 @@ func discoverAll(renewal, timeout time.Duration) map[string]Device {
for _, discoverFunc := range providers { for _, discoverFunc := range providers {
go func(f DiscoverFunc) { go func(f DiscoverFunc) {
defer wg.Done()
for _, dev := range f(renewal, timeout) { for _, dev := range f(renewal, timeout) {
c <- dev select {
case c <- dev:
case <-stop:
return
}
} }
wg.Done()
}(discoverFunc) }(discoverFunc)
} }
nats := make(map[string]Device) nats := make(map[string]Device)
go func() { go func() {
for dev := range c { defer close(done)
nats[dev.ID()] = dev for {
select {
case dev, ok := <-c:
if !ok {
return
}
nats[dev.ID()] = dev
case <-stop:
return
}
} }
close(done)
}() }()
wg.Wait() wg.Wait()

View File

@ -14,17 +14,21 @@ import (
stdsync "sync" stdsync "sync"
"time" "time"
"github.com/thejerf/suture"
"github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/util"
) )
// Service runs a loop for discovery of IGDs (Internet Gateway Devices) and // Service runs a loop for discovery of IGDs (Internet Gateway Devices) and
// setup/renewal of a port mapping. // setup/renewal of a port mapping.
type Service struct { type Service struct {
suture.Service
id protocol.DeviceID id protocol.DeviceID
cfg config.Wrapper cfg config.Wrapper
stop chan struct{}
mappings []*Mapping mappings []*Mapping
timer *time.Timer timer *time.Timer
@ -32,27 +36,28 @@ type Service struct {
} }
func NewService(id protocol.DeviceID, cfg config.Wrapper) *Service { func NewService(id protocol.DeviceID, cfg config.Wrapper) *Service {
return &Service{ s := &Service{
id: id, id: id,
cfg: cfg, cfg: cfg,
timer: time.NewTimer(0), timer: time.NewTimer(0),
mut: sync.NewRWMutex(), mut: sync.NewRWMutex(),
} }
s.Service = util.AsService(s.serve)
return s
} }
func (s *Service) Serve() { func (s *Service) serve(stop chan struct{}) {
announce := stdsync.Once{} announce := stdsync.Once{}
s.mut.Lock() s.mut.Lock()
s.timer.Reset(0) s.timer.Reset(0)
s.stop = make(chan struct{})
s.mut.Unlock() s.mut.Unlock()
for { for {
select { select {
case <-s.timer.C: case <-s.timer.C:
if found := s.process(); found != -1 { if found := s.process(stop); found != -1 {
announce.Do(func() { announce.Do(func() {
suffix := "s" suffix := "s"
if found == 1 { if found == 1 {
@ -61,7 +66,7 @@ func (s *Service) Serve() {
l.Infoln("Detected", found, "NAT service"+suffix) l.Infoln("Detected", found, "NAT service"+suffix)
}) })
} }
case <-s.stop: case <-stop:
s.timer.Stop() s.timer.Stop()
s.mut.RLock() s.mut.RLock()
for _, mapping := range s.mappings { for _, mapping := range s.mappings {
@ -73,7 +78,7 @@ func (s *Service) Serve() {
} }
} }
func (s *Service) process() int { func (s *Service) process(stop chan struct{}) int {
// toRenew are mappings which are due for renewal // toRenew are mappings which are due for renewal
// toUpdate are the remaining mappings, which will only be updated if one of // toUpdate are the remaining mappings, which will only be updated if one of
// the old IGDs has gone away, or a new IGD has appeared, but only if we // the old IGDs has gone away, or a new IGD has appeared, but only if we
@ -115,25 +120,19 @@ func (s *Service) process() int {
return -1 return -1
} }
nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second) nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second, stop)
for _, mapping := range toRenew { for _, mapping := range toRenew {
s.updateMapping(mapping, nats, true) s.updateMapping(mapping, nats, true, stop)
} }
for _, mapping := range toUpdate { for _, mapping := range toUpdate {
s.updateMapping(mapping, nats, false) s.updateMapping(mapping, nats, false, stop)
} }
return len(nats) return len(nats)
} }
func (s *Service) Stop() {
s.mut.RLock()
close(s.stop)
s.mut.RUnlock()
}
func (s *Service) NewMapping(protocol Protocol, ip net.IP, port int) *Mapping { func (s *Service) NewMapping(protocol Protocol, ip net.IP, port int) *Mapping {
mapping := &Mapping{ mapping := &Mapping{
protocol: protocol, protocol: protocol,
@ -178,17 +177,17 @@ func (s *Service) RemoveMapping(mapping *Mapping) {
// acquire mappings for natds which the mapping was unaware of before. // acquire mappings for natds which the mapping was unaware of before.
// Optionally takes renew flag which indicates whether or not we should renew // Optionally takes renew flag which indicates whether or not we should renew
// mappings with existing natds // mappings with existing natds
func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool) { func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) {
var added, removed []Address var added, removed []Address
renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute
mapping.expires = time.Now().Add(renewalTime) mapping.expires = time.Now().Add(renewalTime)
newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew) newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew, stop)
added = append(added, newAdded...) added = append(added, newAdded...)
removed = append(removed, newRemoved...) removed = append(removed, newRemoved...)
newAdded, newRemoved = s.acquireNewMappings(mapping, nats) newAdded, newRemoved = s.acquireNewMappings(mapping, nats, stop)
added = append(added, newAdded...) added = append(added, newAdded...)
removed = append(removed, newRemoved...) removed = append(removed, newRemoved...)
@ -197,12 +196,18 @@ func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew
} }
} }
func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool) ([]Address, []Address) { func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) ([]Address, []Address) {
var added, removed []Address var added, removed []Address
leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
for id, address := range mapping.addressMap() { for id, address := range mapping.addressMap() {
select {
case <-stop:
return nil, nil
default:
}
// Delete addresses for NATDevice's that do not exist anymore // Delete addresses for NATDevice's that do not exist anymore
nat, ok := nats[id] nat, ok := nats[id]
if !ok { if !ok {
@ -242,13 +247,19 @@ func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Devic
return added, removed return added, removed
} }
func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device) ([]Address, []Address) { func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device, stop chan struct{}) ([]Address, []Address) {
var added, removed []Address var added, removed []Address
leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
addrMap := mapping.addressMap() addrMap := mapping.addressMap()
for id, nat := range nats { for id, nat := range nats {
select {
case <-stop:
return nil, nil
default:
}
if _, ok := addrMap[id]; ok { if _, ok := addrMap[id]; ok {
continue continue
} }

View File

@ -69,15 +69,7 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
addrs = append(addrs, ruri.String()) addrs = append(addrs, ruri.String())
} }
defer func() { for _, addr := range relayAddressesOrder(addrs, stop) {
c.mut.RLock()
if c.client != nil {
c.client.Stop()
}
c.mut.RUnlock()
}()
for _, addr := range relayAddressesOrder(addrs) {
select { select {
case <-stop: case <-stop:
l.Debugln(c, "stopping") l.Debugln(c, "stopping")
@ -104,6 +96,15 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
return fmt.Errorf("could not find a connectable relay") return fmt.Errorf("could not find a connectable relay")
} }
func (c *dynamicClient) Stop() {
c.mut.RLock()
if c.client != nil {
c.client.Stop()
}
c.mut.RUnlock()
c.commonClient.Stop()
}
func (c *dynamicClient) Error() error { func (c *dynamicClient) Error() error {
c.mut.RLock() c.mut.RLock()
defer c.mut.RUnlock() defer c.mut.RUnlock()
@ -147,7 +148,7 @@ type dynamicAnnouncement struct {
// the closest 50ms, and puts them in buckets of 50ms latency ranges. Then // the closest 50ms, and puts them in buckets of 50ms latency ranges. Then
// shuffles each bucket, and returns all addresses starting with the ones from // shuffles each bucket, and returns all addresses starting with the ones from
// the lowest latency bucket, ending with the highest latency buceket. // the lowest latency bucket, ending with the highest latency buceket.
func relayAddressesOrder(input []string) []string { func relayAddressesOrder(input []string, stop chan struct{}) []string {
buckets := make(map[int][]string) buckets := make(map[int][]string)
for _, relay := range input { for _, relay := range input {
@ -159,6 +160,12 @@ func relayAddressesOrder(input []string) []string {
id := int(latency/time.Millisecond) / 50 id := int(latency/time.Millisecond) / 50
buckets[id] = append(buckets[id], relay) buckets[id] = append(buckets[id], relay)
select {
case <-stop:
return nil
default:
}
} }
var ids []int var ids []int

View File

@ -109,8 +109,8 @@ func New(cfg config.Wrapper, subscriber Subscriber, conn net.PacketConn) (*Servi
} }
func (s *Service) Stop() { func (s *Service) Stop() {
s.Service.Stop()
_ = s.stunConn.Close() _ = s.stunConn.Close()
s.Service.Stop()
} }
func (s *Service) serve(stop chan struct{}) { func (s *Service) serve(stop chan struct{}) {
@ -163,7 +163,11 @@ func (s *Service) serve(stop chan struct{}) {
// We failed to contact all provided stun servers or the nat is not punchable. // We failed to contact all provided stun servers or the nat is not punchable.
// Chillout for a while. // Chillout for a while.
time.Sleep(stunRetryInterval) select {
case <-time.After(stunRetryInterval):
case <-stop:
return
}
} }
} }

View File

@ -187,6 +187,7 @@ func AsService(fn func(stop chan struct{})) suture.Service {
type ServiceWithError interface { type ServiceWithError interface {
suture.Service suture.Service
Error() error Error() error
SetError(error)
} }
// AsServiceWithError does the same as AsService, except that it keeps track // AsServiceWithError does the same as AsService, except that it keeps track
@ -244,3 +245,9 @@ func (s *service) Error() error {
defer s.mut.Unlock() defer s.mut.Unlock()
return s.err return s.err
} }
func (s *service) SetError(err error) {
s.mut.Lock()
s.err = err
s.mut.Unlock()
}