From c49d864f14ffd477bb2574e1c56fc36aa8dbfab1 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 26 Mar 2018 06:56:50 -0400 Subject: [PATCH] lib/connections: Slightly refactor limiter juggling Two small behavior changes: don't "charge" the data to the global rate limit until it's been accepted by the device specific limiter, and fix the send/recv direction in the log print on per device rate limits. --- lib/config/config.go | 18 +++--- lib/connections/limiter.go | 110 ++++++++++++++++++++++--------------- 2 files changed, 74 insertions(+), 54 deletions(-) diff --git a/lib/config/config.go b/lib/config/config.go index 9d7820b8..1cda1630 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -377,6 +377,15 @@ func (cfg *Configuration) clean() error { return nil } +// DeviceMap returns a map of device ID to device configuration for the given configuration. +func (cfg *Configuration) DeviceMap() map[protocol.DeviceID]DeviceConfiguration { + m := make(map[protocol.DeviceID]DeviceConfiguration, len(cfg.Devices)) + for _, dev := range cfg.Devices { + m[dev.DeviceID] = dev + } + return m +} + func convertV27V28(cfg *Configuration) { // Show a notification about enabling filesystem watching cfg.Options.UnackedNotificationIDs = append(cfg.Options.UnackedNotificationIDs, "fsWatcherNotification") @@ -797,12 +806,3 @@ func filterURLSchemePrefix(addrs []string, prefix string) []string { } return addrs } - -// mapDeviceConfigs returns a map of device ID to device configuration for the given configuration. -func (cfg *Configuration) DeviceMap() map[protocol.DeviceID]DeviceConfiguration { - m := make(map[protocol.DeviceID]DeviceConfiguration, len(cfg.Devices)) - for _, dev := range cfg.Devices { - m[dev.DeviceID] = dev - } - return m -} diff --git a/lib/connections/limiter.go b/lib/connections/limiter.go index 119aa39d..50929447 100644 --- a/lib/connections/limiter.go +++ b/lib/connections/limiter.go @@ -21,12 +21,17 @@ import ( // limiter manages a read and write rate limit, reacting to config changes // as appropriate. type limiter struct { + mu sync.Mutex write *rate.Limiter read *rate.Limiter limitsLAN atomicBool deviceReadLimiters map[protocol.DeviceID]*rate.Limiter deviceWriteLimiters map[protocol.DeviceID]*rate.Limiter - mu sync.Mutex +} + +type waiter interface { + // This is the rate limiting operation + WaitN(ctx context.Context, n int) error } const limiterBurstSize = 4 * 128 << 10 @@ -96,7 +101,7 @@ func (lim *limiter) processDevicesConfigurationLocked(from, to config.Configurat writeLimitStr = fmt.Sprintf("limit is %d KiB/s", dev.MaxSendKbps) } - l.Infof("Device %s send rate %s, receive rate %s", dev.DeviceID, readLimitStr, writeLimitStr) + l.Infof("Device %s send rate %s, receive rate %s", dev.DeviceID, writeLimitStr, readLimitStr) } } @@ -169,49 +174,76 @@ func (lim *limiter) String() string { return "connections.limiter" } -func (lim *limiter) getLimiters(remoteID protocol.DeviceID, c internalConn, isLAN bool) (io.Reader, io.Writer) { +func (lim *limiter) getLimiters(remoteID protocol.DeviceID, rw io.ReadWriter, isLAN bool) (io.Reader, io.Writer) { lim.mu.Lock() - wr := lim.newLimitedWriterLocked(remoteID, c, isLAN) - rd := lim.newLimitedReaderLocked(remoteID, c, isLAN) + wr := lim.newLimitedWriterLocked(remoteID, rw, isLAN) + rd := lim.newLimitedReaderLocked(remoteID, rw, isLAN) lim.mu.Unlock() return rd, wr } func (lim *limiter) newLimitedReaderLocked(remoteID protocol.DeviceID, r io.Reader, isLAN bool) io.Reader { - return &limitedReader{reader: r, limiter: lim, deviceLimiter: lim.getReadLimiterLocked(remoteID), isLAN: isLAN} + return &limitedReader{ + reader: r, + limitsLAN: &lim.limitsLAN, + waiter: totalWaiter{lim.getReadLimiterLocked(remoteID), lim.read}, + isLAN: isLAN, + } } func (lim *limiter) newLimitedWriterLocked(remoteID protocol.DeviceID, w io.Writer, isLAN bool) io.Writer { - return &limitedWriter{writer: w, limiter: lim, deviceLimiter: lim.getWriteLimiterLocked(remoteID), isLAN: isLAN} + return &limitedWriter{ + writer: w, + limitsLAN: &lim.limitsLAN, + waiter: totalWaiter{lim.getWriteLimiterLocked(remoteID), lim.write}, + isLAN: isLAN, + } +} + +func (lim *limiter) getReadLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter { + return getRateLimiter(lim.deviceReadLimiters, deviceID) +} + +func (lim *limiter) getWriteLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter { + return getRateLimiter(lim.deviceWriteLimiters, deviceID) +} + +func getRateLimiter(m map[protocol.DeviceID]*rate.Limiter, deviceID protocol.DeviceID) *rate.Limiter { + limiter, ok := m[deviceID] + if !ok { + limiter = rate.NewLimiter(rate.Inf, limiterBurstSize) + m[deviceID] = limiter + } + return limiter } // limitedReader is a rate limited io.Reader type limitedReader struct { - reader io.Reader - limiter *limiter - deviceLimiter *rate.Limiter - isLAN bool + reader io.Reader + limitsLAN *atomicBool + waiter waiter + isLAN bool } func (r *limitedReader) Read(buf []byte) (int, error) { n, err := r.reader.Read(buf) - if !r.isLAN || r.limiter.limitsLAN.get() { - take(r.limiter.read, r.deviceLimiter, n) + if !r.isLAN || r.limitsLAN.get() { + take(r.waiter, n) } return n, err } // limitedWriter is a rate limited io.Writer type limitedWriter struct { - writer io.Writer - limiter *limiter - deviceLimiter *rate.Limiter - isLAN bool + writer io.Writer + limitsLAN *atomicBool + waiter waiter + isLAN bool } func (w *limitedWriter) Write(buf []byte) (int, error) { - if !w.isLAN || w.limiter.limitsLAN.get() { - take(w.limiter.write, w.deviceLimiter, len(buf)) + if !w.isLAN || w.limitsLAN.get() { + take(w.waiter, len(buf)) } return w.writer.Write(buf) } @@ -219,24 +251,21 @@ func (w *limitedWriter) Write(buf []byte) (int, error) { // take is a utility function to consume tokens from a overall rate.Limiter and deviceLimiter. // No call to WaitN can be larger than the limiter burst size so we split it up into // several calls when necessary. -func take(overallLimiter, deviceLimiter *rate.Limiter, tokens int) { +func take(waiter waiter, tokens int) { if tokens < limiterBurstSize { // This is the by far more common case so we get it out of the way // early. - deviceLimiter.WaitN(context.TODO(), tokens) - overallLimiter.WaitN(context.TODO(), tokens) + waiter.WaitN(context.TODO(), tokens) return } for tokens > 0 { // Consume limiterBurstSize tokens at a time until we're done. if tokens > limiterBurstSize { - deviceLimiter.WaitN(context.TODO(), limiterBurstSize) - overallLimiter.WaitN(context.TODO(), limiterBurstSize) + waiter.WaitN(context.TODO(), limiterBurstSize) tokens -= limiterBurstSize } else { - deviceLimiter.WaitN(context.TODO(), tokens) - overallLimiter.WaitN(context.TODO(), tokens) + waiter.WaitN(context.TODO(), tokens) tokens = 0 } } @@ -256,25 +285,16 @@ func (b *atomicBool) get() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -// Utility functions for atomic operations on device limiters map -func (lim *limiter) getWriteLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter { - limiter, ok := lim.deviceWriteLimiters[deviceID] +// totalWaiter waits for all of the waiters +type totalWaiter []waiter - if !ok { - limiter = rate.NewLimiter(rate.Inf, limiterBurstSize) - lim.deviceWriteLimiters[deviceID] = limiter +func (tw totalWaiter) WaitN(ctx context.Context, n int) error { + for _, w := range tw { + if err := w.WaitN(ctx, n); err != nil { + // error here is context cancellation, most likely, so we abort + // early + return err + } } - - return limiter -} - -func (lim *limiter) getReadLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter { - limiter, ok := lim.deviceReadLimiters[deviceID] - - if !ok { - limiter = rate.NewLimiter(rate.Inf, limiterBurstSize) - lim.deviceReadLimiters[deviceID] = limiter - } - - return limiter + return nil }