all: Use context in lib/dialer (#6177)
* all: Use context in lib/dialer * a bit slimmer * https://github.com/syncthing/syncthing/pull/5753 * bot * missed adding debug.go * errors.Cause * simultaneous dialing * anti-leak
This commit is contained in:
committed by
Audrius Butkevicius
parent
4e151d380c
commit
1bae4b7f50
@@ -42,7 +42,7 @@ type quicDialer struct {
|
||||
commonDialer
|
||||
}
|
||||
|
||||
func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
uri = fixupPort(uri, config.DefaultQUICPort)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", uri.Host)
|
||||
@@ -66,7 +66,7 @@ func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, erro
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), quicOperationTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, quicOperationTimeout)
|
||||
defer cancel()
|
||||
|
||||
session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -27,13 +28,13 @@ type relayDialer struct {
|
||||
commonDialer
|
||||
}
|
||||
|
||||
func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second)
|
||||
func (d *relayDialer) Dial(ctx context.Context, id protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
inv, err := client.GetInvitationFromRelay(ctx, uri, id, d.tlsCfg.Certificates, 10*time.Second)
|
||||
if err != nil {
|
||||
return internalConn{}, err
|
||||
}
|
||||
|
||||
conn, err := client.JoinSession(inv)
|
||||
conn, err := client.JoinSession(ctx, inv)
|
||||
if err != nil {
|
||||
return internalConn{}, err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/syncthing/syncthing/lib/config"
|
||||
"github.com/syncthing/syncthing/lib/dialer"
|
||||
"github.com/syncthing/syncthing/lib/nat"
|
||||
@@ -70,9 +72,11 @@ func (t *relayListener) serve(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := client.JoinSession(inv)
|
||||
conn, err := client.JoinSession(ctx, inv)
|
||||
if err != nil {
|
||||
l.Infoln("Listen (BEP/relay): joining session:", err)
|
||||
if errors.Cause(err) != context.Canceled {
|
||||
l.Infoln("Listen (BEP/relay): joining session:", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ package connections
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -31,6 +30,7 @@ import (
|
||||
_ "github.com/syncthing/syncthing/lib/pmp"
|
||||
_ "github.com/syncthing/syncthing/lib/upnp"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/thejerf/suture"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
@@ -463,7 +463,7 @@ func (s *service) connect(ctx context.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
conn, ok := s.dialParallel(deviceCfg.DeviceID, dialTargets)
|
||||
conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets)
|
||||
if ok {
|
||||
s.conns <- conn
|
||||
}
|
||||
@@ -701,6 +701,10 @@ func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry {
|
||||
}
|
||||
|
||||
func (s *service) setConnectionStatus(address string, err error) {
|
||||
if errors.Cause(err) != context.Canceled {
|
||||
return
|
||||
}
|
||||
|
||||
status := ConnectionStatusEntry{When: time.Now().UTC().Truncate(time.Second)}
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
@@ -828,7 +832,7 @@ func IsAllowedNetwork(host string, allowed []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
|
||||
func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
|
||||
// Group targets into buckets by priority
|
||||
dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
|
||||
for _, tgt := range dialTargets {
|
||||
@@ -851,7 +855,7 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar
|
||||
for _, tgt := range tgts {
|
||||
wg.Add(1)
|
||||
go func(tgt dialTarget) {
|
||||
conn, err := tgt.Dial()
|
||||
conn, err := tgt.Dial(ctx)
|
||||
if err == nil {
|
||||
// Closes the connection on error
|
||||
err = s.validateIdentity(conn, deviceID)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -164,7 +165,7 @@ func (d *commonDialer) RedialFrequency() time.Duration {
|
||||
}
|
||||
|
||||
type genericDialer interface {
|
||||
Dial(protocol.DeviceID, *url.URL) (internalConn, error)
|
||||
Dial(context.Context, protocol.DeviceID, *url.URL) (internalConn, error)
|
||||
RedialFrequency() time.Duration
|
||||
}
|
||||
|
||||
@@ -223,7 +224,7 @@ type dialTarget struct {
|
||||
deviceID protocol.DeviceID
|
||||
}
|
||||
|
||||
func (t dialTarget) Dial() (internalConn, error) {
|
||||
func (t dialTarget) Dial(ctx context.Context) (internalConn, error) {
|
||||
l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority)
|
||||
return t.dialer.Dial(t.deviceID, t.uri)
|
||||
return t.dialer.Dial(ctx, t.deviceID, t.uri)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -29,10 +30,12 @@ type tcpDialer struct {
|
||||
commonDialer
|
||||
}
|
||||
|
||||
func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
func (d *tcpDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
|
||||
uri = fixupPort(uri, config.DefaultTCPPort)
|
||||
|
||||
conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := dialer.DialContext(timeoutCtx, uri.Scheme, uri.Host)
|
||||
if err != nil {
|
||||
return internalConn{}, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user