diff --git a/internal/files/blockmap.go b/internal/files/blockmap.go index ed8f6ef8..1dd246b0 100644 --- a/internal/files/blockmap.go +++ b/internal/files/blockmap.go @@ -90,6 +90,17 @@ func (m *BlockMap) Update(files []protocol.FileInfo) error { return m.db.Write(batch, nil) } +// Discard block map state, removing the given files +func (m *BlockMap) Discard(files []protocol.FileInfo) error { + batch := new(leveldb.Batch) + for _, file := range files { + for _, block := range file.Blocks { + batch.Delete(m.blockKey(block.Hash, file.Name)) + } + } + return m.db.Write(batch, nil) +} + // Drop block map, removing all entries related to this block map from the db. func (m *BlockMap) Drop() error { batch := new(leveldb.Batch) diff --git a/internal/files/set.go b/internal/files/set.go index 112c7386..1b7e8296 100644 --- a/internal/files/set.go +++ b/internal/files/set.go @@ -111,12 +111,22 @@ func (s *Set) Update(device protocol.DeviceID, fs []protocol.FileInfo) { normalizeFilenames(fs) s.mutex.Lock() defer s.mutex.Unlock() + if device == protocol.LocalDeviceID { + discards := make([]protocol.FileInfo, 0, len(fs)) + updates := make([]protocol.FileInfo, 0, len(fs)) + for _, newFile := range fs { + existingFile := ldbGet(s.db, []byte(s.folder), device[:], []byte(newFile.Name)) + if existingFile.Version <= newFile.Version { + discards = append(discards, existingFile) + updates = append(updates, newFile) + } + } + s.blockmap.Discard(discards) + s.blockmap.Update(updates) + } if lv := ldbUpdate(s.db, []byte(s.folder), device[:], fs); lv > s.localVersion[device] { s.localVersion[device] = lv } - if device == protocol.LocalDeviceID { - s.blockmap.Update(fs) - } } func (s *Set) WithNeed(device protocol.DeviceID, fn fileIterator) { diff --git a/internal/model/puller_test.go b/internal/model/puller_test.go index 8c76c30d..451b694f 100644 --- a/internal/model/puller_test.go +++ b/internal/model/puller_test.go @@ -250,3 +250,58 @@ func TestCopierFinder(t *testing.T) { os.Remove(tempFile) } + +// Test that updating a file removes it's old blocks from the blockmap +func TestCopierCleanup(t *testing.T) { + iterFn := func(folder, file string, index uint32) bool { + return true + } + + fcfg := config.FolderConfiguration{ID: "default", Path: "testdata"} + cfg := config.Configuration{Folders: []config.FolderConfiguration{fcfg}} + + db, _ := leveldb.Open(storage.NewMemStorage(), nil) + m := NewModel(config.Wrap("/tmp/test", cfg), "device", "syncthing", "dev", db) + m.AddFolder(fcfg) + + // Create a file + file := protocol.FileInfo{ + Name: "test", + Flags: 0, + Modified: 0, + Blocks: []protocol.BlockInfo{blocks[0]}, + } + + // Add file to index + m.updateLocal("default", file) + + if !m.finder.Iterate(blocks[0].Hash, iterFn) { + t.Error("Expected block not found") + } + + file.Blocks = []protocol.BlockInfo{blocks[1]} + file.Version++ + // Update index (removing old blocks) + m.updateLocal("default", file) + + if m.finder.Iterate(blocks[0].Hash, iterFn) { + t.Error("Unexpected block found") + } + + if !m.finder.Iterate(blocks[1].Hash, iterFn) { + t.Error("Expected block not found") + } + + file.Blocks = []protocol.BlockInfo{blocks[0]} + file.Version++ + // Update index (removing old blocks) + m.updateLocal("default", file) + + if !m.finder.Iterate(blocks[0].Hash, iterFn) { + t.Error("Unexpected block found") + } + + if m.finder.Iterate(blocks[1].Hash, iterFn) { + t.Error("Expected block not found") + } +}