lib: Replace done channel with contexts in and add names to util services (#6166)

This commit is contained in:
Simon Frei
2019-11-21 08:41:15 +01:00
committed by GitHub
parent 552ea68672
commit 90d85fd0a2
34 changed files with 240 additions and 218 deletions

View File

@@ -3,6 +3,7 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"net/url"
@@ -51,16 +52,16 @@ type commonClient struct {
mut sync.RWMutex
}
func newCommonClient(invitations chan protocol.SessionInvitation, serve func(chan struct{}) error) commonClient {
func newCommonClient(invitations chan protocol.SessionInvitation, serve func(context.Context) error, creator string) commonClient {
c := commonClient{
invitations: invitations,
mut: sync.NewRWMutex(),
}
newServe := func(stop chan struct{}) error {
newServe := func(ctx context.Context) error {
defer c.cleanup()
return serve(stop)
return serve(ctx)
}
c.ServiceWithError = util.AsServiceWithError(newServe)
c.ServiceWithError = util.AsServiceWithError(newServe, creator)
if c.invitations == nil {
c.closeInvitationsOnFinish = true
c.invitations = make(chan protocol.SessionInvitation)

View File

@@ -3,6 +3,7 @@
package client
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
@@ -32,11 +33,11 @@ func newDynamicClient(uri *url.URL, certs []tls.Certificate, invitations chan pr
certs: certs,
timeout: timeout,
}
c.commonClient = newCommonClient(invitations, c.serve)
c.commonClient = newCommonClient(invitations, c.serve, c.String())
return c
}
func (c *dynamicClient) serve(stop chan struct{}) error {
func (c *dynamicClient) serve(ctx context.Context) error {
uri := *c.pooladdr
// Trim off the `dynamic+` prefix
@@ -69,9 +70,9 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
addrs = append(addrs, ruri.String())
}
for _, addr := range relayAddressesOrder(addrs, stop) {
for _, addr := range relayAddressesOrder(ctx, addrs) {
select {
case <-stop:
case <-ctx.Done():
l.Debugln(c, "stopping")
return nil
default:
@@ -148,7 +149,7 @@ type dynamicAnnouncement struct {
// the closest 50ms, and puts them in buckets of 50ms latency ranges. Then
// shuffles each bucket, and returns all addresses starting with the ones from
// the lowest latency bucket, ending with the highest latency buceket.
func relayAddressesOrder(input []string, stop chan struct{}) []string {
func relayAddressesOrder(ctx context.Context, input []string) []string {
buckets := make(map[int][]string)
for _, relay := range input {
@@ -162,7 +163,7 @@ func relayAddressesOrder(input []string, stop chan struct{}) []string {
buckets[id] = append(buckets[id], relay)
select {
case <-stop:
case <-ctx.Done():
return nil
default:
}

View File

@@ -3,6 +3,7 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -39,11 +40,11 @@ func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan pro
messageTimeout: time.Minute * 2,
connectTimeout: timeout,
}
c.commonClient = newCommonClient(invitations, c.serve)
c.commonClient = newCommonClient(invitations, c.serve, c.String())
return c
}
func (c *staticClient) serve(stop chan struct{}) error {
func (c *staticClient) serve(ctx context.Context) error {
if err := c.connect(); err != nil {
l.Infof("Could not connect to relay %s: %s", c.uri, err)
return err
@@ -72,7 +73,7 @@ func (c *staticClient) serve(stop chan struct{}) error {
messages := make(chan interface{})
errors := make(chan error, 1)
go messageReader(c.conn, messages, errors, stop)
go messageReader(ctx, c.conn, messages, errors)
timeout := time.NewTimer(c.messageTimeout)
@@ -106,7 +107,7 @@ func (c *staticClient) serve(stop chan struct{}) error {
return fmt.Errorf("protocol error: unexpected message %v", msg)
}
case <-stop:
case <-ctx.Done():
l.Debugln(c, "stopping")
return nil
@@ -241,7 +242,7 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
return nil
}
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error, stop chan struct{}) {
func messageReader(ctx context.Context, conn net.Conn, messages chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
@@ -250,7 +251,7 @@ func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- err
}
select {
case messages <- msg:
case <-stop:
case <-ctx.Done():
return
}
}