all: Use context in lib/dialer (#6177)

* all: Use context in lib/dialer

* a bit slimmer

* https://github.com/syncthing/syncthing/pull/5753

* bot

* missed adding debug.go

* errors.Cause

* simultaneous dialing

* anti-leak
This commit is contained in:
Simon Frei
2019-11-26 08:39:51 +01:00
committed by Audrius Butkevicius
parent 4e151d380c
commit 1bae4b7f50
24 changed files with 175 additions and 204 deletions

View File

@@ -42,7 +42,7 @@ type quicDialer struct {
commonDialer
}
func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
uri = fixupPort(uri, config.DefaultQUICPort)
addr, err := net.ResolveUDPAddr("udp", uri.Host)
@@ -66,7 +66,7 @@ func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, erro
}
}
ctx, cancel := context.WithTimeout(context.Background(), quicOperationTimeout)
ctx, cancel := context.WithTimeout(ctx, quicOperationTimeout)
defer cancel()
session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig)

View File

@@ -7,6 +7,7 @@
package connections
import (
"context"
"crypto/tls"
"net/url"
"time"
@@ -27,13 +28,13 @@ type relayDialer struct {
commonDialer
}
func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second)
func (d *relayDialer) Dial(ctx context.Context, id protocol.DeviceID, uri *url.URL) (internalConn, error) {
inv, err := client.GetInvitationFromRelay(ctx, uri, id, d.tlsCfg.Certificates, 10*time.Second)
if err != nil {
return internalConn{}, err
}
conn, err := client.JoinSession(inv)
conn, err := client.JoinSession(ctx, inv)
if err != nil {
return internalConn{}, err
}

View File

@@ -13,6 +13,8 @@ import (
"sync"
"time"
"github.com/pkg/errors"
"github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/dialer"
"github.com/syncthing/syncthing/lib/nat"
@@ -70,9 +72,11 @@ func (t *relayListener) serve(ctx context.Context) error {
return err
}
conn, err := client.JoinSession(inv)
conn, err := client.JoinSession(ctx, inv)
if err != nil {
l.Infoln("Listen (BEP/relay): joining session:", err)
if errors.Cause(err) != context.Canceled {
l.Infoln("Listen (BEP/relay): joining session:", err)
}
continue
}

View File

@@ -9,7 +9,6 @@ package connections
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
@@ -31,6 +30,7 @@ import (
_ "github.com/syncthing/syncthing/lib/pmp"
_ "github.com/syncthing/syncthing/lib/upnp"
"github.com/pkg/errors"
"github.com/thejerf/suture"
"golang.org/x/time/rate"
)
@@ -463,7 +463,7 @@ func (s *service) connect(ctx context.Context) {
})
}
conn, ok := s.dialParallel(deviceCfg.DeviceID, dialTargets)
conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets)
if ok {
s.conns <- conn
}
@@ -701,6 +701,10 @@ func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry {
}
func (s *service) setConnectionStatus(address string, err error) {
if errors.Cause(err) != context.Canceled {
return
}
status := ConnectionStatusEntry{When: time.Now().UTC().Truncate(time.Second)}
if err != nil {
errStr := err.Error()
@@ -828,7 +832,7 @@ func IsAllowedNetwork(host string, allowed []string) bool {
return false
}
func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
// Group targets into buckets by priority
dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
for _, tgt := range dialTargets {
@@ -851,7 +855,7 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar
for _, tgt := range tgts {
wg.Add(1)
go func(tgt dialTarget) {
conn, err := tgt.Dial()
conn, err := tgt.Dial(ctx)
if err == nil {
// Closes the connection on error
err = s.validateIdentity(conn, deviceID)

View File

@@ -7,6 +7,7 @@
package connections
import (
"context"
"crypto/tls"
"fmt"
"io"
@@ -164,7 +165,7 @@ func (d *commonDialer) RedialFrequency() time.Duration {
}
type genericDialer interface {
Dial(protocol.DeviceID, *url.URL) (internalConn, error)
Dial(context.Context, protocol.DeviceID, *url.URL) (internalConn, error)
RedialFrequency() time.Duration
}
@@ -223,7 +224,7 @@ type dialTarget struct {
deviceID protocol.DeviceID
}
func (t dialTarget) Dial() (internalConn, error) {
func (t dialTarget) Dial(ctx context.Context) (internalConn, error) {
l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority)
return t.dialer.Dial(t.deviceID, t.uri)
return t.dialer.Dial(ctx, t.deviceID, t.uri)
}

View File

@@ -7,6 +7,7 @@
package connections
import (
"context"
"crypto/tls"
"net/url"
"time"
@@ -29,10 +30,12 @@ type tcpDialer struct {
commonDialer
}
func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
func (d *tcpDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
uri = fixupPort(uri, config.DefaultTCPPort)
conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)
timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
conn, err := dialer.DialContext(timeoutCtx, uri.Scheme, uri.Host)
if err != nil {
return internalConn{}, err
}