diff --git a/lib/discover/cache.go b/lib/discover/cache.go index ad1faf9b..60685880 100644 --- a/lib/discover/cache.go +++ b/lib/discover/cache.go @@ -26,7 +26,7 @@ type CachingMux struct { *suture.Supervisor finders []cachedFinder caches []*cache - mut sync.Mutex + mut sync.RWMutex } // A cachedFinder is a Finder with associated cache timeouts. @@ -54,7 +54,7 @@ type cachedError interface { func NewCachingMux() *CachingMux { return &CachingMux{ Supervisor: suture.NewSimple("discover.cachingMux"), - mut: sync.NewMutex(), + mut: sync.NewRWMutex(), } } @@ -75,7 +75,7 @@ func (m *CachingMux) Add(finder Finder, cacheTime, negCacheTime time.Duration, p func (m *CachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, relays []Relay, err error) { var pdirect []prioritizedAddress - m.mut.Lock() + m.mut.RLock() for i, finder := range m.finders { if cacheEntry, ok := m.caches[i].Get(deviceID); ok { // We have a cache entry. Lets see what it says. @@ -129,7 +129,7 @@ func (m *CachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, relays m.caches[i].Set(deviceID, entry) } } - m.mut.Unlock() + m.mut.RUnlock() direct = uniqueSortedAddrs(pdirect) relays = uniqueSortedRelays(relays) @@ -149,12 +149,12 @@ func (m *CachingMux) Error() error { } func (m *CachingMux) ChildErrors() map[string]error { - m.mut.Lock() children := make(map[string]error, len(m.finders)) + m.mut.RLock() for _, f := range m.finders { children[f.String()] = f.Error() } - m.mut.Unlock() + m.mut.RUnlock() return children } @@ -163,7 +163,7 @@ func (m *CachingMux) Cache() map[protocol.DeviceID]CacheEntry { // children's caches. res := make(map[protocol.DeviceID]CacheEntry) - m.mut.Lock() + m.mut.RLock() for i := range m.finders { // Each finder[i] has a corresponding cache at cache[i]. Go through it // and populate the total, if it's newer than what's already in there. @@ -183,7 +183,7 @@ func (m *CachingMux) Cache() map[protocol.DeviceID]CacheEntry { } } } - m.mut.Unlock() + m.mut.RUnlock() return res } diff --git a/lib/discover/cache_test.go b/lib/discover/cache_test.go index 0205e219..c7900c2f 100644 --- a/lib/discover/cache_test.go +++ b/lib/discover/cache_test.go @@ -91,3 +91,55 @@ func (f *fakeDiscovery) String() string { func (f *fakeDiscovery) Cache() map[protocol.DeviceID]CacheEntry { return nil } + +func TestCacheSlowLookup(t *testing.T) { + c := NewCachingMux() + c.ServeBackground() + defer c.Stop() + + // Add a slow discovery service. + + started := make(chan struct{}) + f1 := &slowDiscovery{time.Second, started} + c.Add(f1, time.Minute, 0, 0) + + // Start a lookup, which will take at least a second + + t0 := time.Now() + go c.Lookup(protocol.LocalDeviceID) + <-started // The slow lookup method has been called so we're inside the lock + + // It should be possible to get ChildErrors while it's running + + c.ChildErrors() + + // Only a small amount of time should have passed, not the full second + + diff := time.Since(t0) + if diff > 500*time.Millisecond { + t.Error("ChildErrors was blocked for", diff) + } +} + +type slowDiscovery struct { + delay time.Duration + started chan struct{} +} + +func (f *slowDiscovery) Lookup(deviceID protocol.DeviceID) (direct []string, relays []Relay, err error) { + close(f.started) + time.Sleep(f.delay) + return nil, nil, nil +} + +func (f *slowDiscovery) Error() error { + return nil +} + +func (f *slowDiscovery) String() string { + return "fake" +} + +func (f *slowDiscovery) Cache() map[protocol.DeviceID]CacheEntry { + return nil +}