From a339e37cca94adf4fec5665dc7f3172f9ea5263b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Jan 2025 12:22:23 -0500 Subject: [PATCH] gopls/internal/util/persistent: {Map,Set}: use iter.Seq2 This CL updates the tree data structures to use go1.23-style iterators. package persistent func (*Map[K, V]) Keys() iter.Seq[K] func (*Map[K, V]) All() iter.Seq2[K, V] func (*Set[K]) All() iter.Seq[K] Change-Id: I4b8917fa35c38e055e42e10cefea7997fe7b35f3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/640035 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- gopls/internal/cache/filemap.go | 17 +++++---- gopls/internal/cache/filemap_test.go | 8 ++-- gopls/internal/cache/load.go | 4 +- gopls/internal/cache/snapshot.go | 30 +++++++-------- gopls/internal/cache/view.go | 2 +- gopls/internal/util/persistent/map.go | 43 +++++++++++----------- gopls/internal/util/persistent/map_test.go | 4 +- gopls/internal/util/persistent/set.go | 20 ++++++---- gopls/internal/util/persistent/set_test.go | 4 +- 9 files changed, 69 insertions(+), 63 deletions(-) diff --git a/gopls/internal/cache/filemap.go b/gopls/internal/cache/filemap.go index c826141ed98..1f1fd947d71 100644 --- a/gopls/internal/cache/filemap.go +++ b/gopls/internal/cache/filemap.go @@ -5,6 +5,7 @@ package cache import ( + "iter" "path/filepath" "golang.org/x/tools/gopls/internal/file" @@ -77,9 +78,9 @@ func (m *fileMap) get(key protocol.DocumentURI) (file.Handle, bool) { return m.files.Get(key) } -// foreach calls f for each (uri, fh) in the map. -func (m *fileMap) foreach(f func(uri protocol.DocumentURI, fh file.Handle)) { - m.files.Range(f) +// all returns the sequence of (uri, fh) entries in the map. +func (m *fileMap) all() iter.Seq2[protocol.DocumentURI, file.Handle] { + return m.files.All() } // set stores the given file handle for key, updating overlays and directories @@ -130,9 +131,9 @@ func (m *fileMap) delete(key protocol.DocumentURI) { // getOverlays returns a new unordered array of overlay files. func (m *fileMap) getOverlays() []*overlay { var overlays []*overlay - m.overlays.Range(func(_ protocol.DocumentURI, o *overlay) { + for _, o := range m.overlays.All() { overlays = append(overlays, o) - }) + } return overlays } @@ -143,9 +144,9 @@ func (m *fileMap) getOverlays() []*overlay { func (m *fileMap) getDirs() *persistent.Set[string] { if m.dirs == nil { m.dirs = new(persistent.Set[string]) - m.files.Range(func(u protocol.DocumentURI, _ file.Handle) { - m.addDirs(u) - }) + for uri := range m.files.All() { + m.addDirs(uri) + } } return m.dirs } diff --git a/gopls/internal/cache/filemap_test.go b/gopls/internal/cache/filemap_test.go index 13f2c1a9ccd..24b3a19d108 100644 --- a/gopls/internal/cache/filemap_test.go +++ b/gopls/internal/cache/filemap_test.go @@ -83,9 +83,9 @@ func TestFileMap(t *testing.T) { } var gotFiles []string - m.foreach(func(uri protocol.DocumentURI, _ file.Handle) { + for uri := range m.all() { gotFiles = append(gotFiles, normalize(uri.Path())) - }) + } sort.Strings(gotFiles) if diff := cmp.Diff(test.wantFiles, gotFiles); diff != "" { t.Errorf("Files mismatch (-want +got):\n%s", diff) @@ -100,9 +100,9 @@ func TestFileMap(t *testing.T) { } var gotDirs []string - m.getDirs().Range(func(dir string) { + for dir := range m.getDirs().All() { gotDirs = append(gotDirs, normalize(dir)) - }) + } sort.Strings(gotDirs) if diff := cmp.Diff(test.wantDirs, gotDirs); diff != "" { t.Errorf("Dirs mismatch (-want +got):\n%s", diff) diff --git a/gopls/internal/cache/load.go b/gopls/internal/cache/load.go index 873cef56a2b..140cbc45490 100644 --- a/gopls/internal/cache/load.go +++ b/gopls/internal/cache/load.go @@ -262,11 +262,11 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork AllowNetwork, scopes . s.mu.Lock() // Assert the invariant s.packages.Get(id).m == s.meta.metadata[id]. - s.packages.Range(func(id PackageID, ph *packageHandle) { + for id, ph := range s.packages.All() { if s.meta.Packages[id] != ph.mp { panic("inconsistent metadata") } - }) + } // Compute the minimal metadata updates (for Clone) // required to preserve the above invariant. diff --git a/gopls/internal/cache/snapshot.go b/gopls/internal/cache/snapshot.go index de4a52ff6cb..ffca1dc00ec 100644 --- a/gopls/internal/cache/snapshot.go +++ b/gopls/internal/cache/snapshot.go @@ -344,11 +344,11 @@ func (s *Snapshot) Templates() map[protocol.DocumentURI]file.Handle { defer s.mu.Unlock() tmpls := map[protocol.DocumentURI]file.Handle{} - s.files.foreach(func(k protocol.DocumentURI, fh file.Handle) { + for k, fh := range s.files.all() { if s.FileKind(fh) == file.Tmpl { tmpls[k] = fh } - }) + } return tmpls } @@ -864,13 +864,13 @@ func (s *Snapshot) addKnownSubdirs(patterns map[protocol.RelativePattern]unit, w s.mu.Lock() defer s.mu.Unlock() - s.files.getDirs().Range(func(dir string) { + for dir := range s.files.getDirs().All() { for _, wsDir := range wsDirs { if pathutil.InDir(wsDir, dir) { patterns[protocol.RelativePattern{Pattern: filepath.ToSlash(dir)}] = unit{} } } - }) + } } // watchSubdirs reports whether gopls should request separate file watchers for @@ -912,11 +912,11 @@ func (s *Snapshot) filesInDir(uri protocol.DocumentURI) []protocol.DocumentURI { return nil } var files []protocol.DocumentURI - s.files.foreach(func(uri protocol.DocumentURI, _ file.Handle) { + for uri := range s.files.all() { if pathutil.InDir(dir, uri.Path()) { files = append(files, uri) } - }) + } return files } @@ -1029,13 +1029,11 @@ func (s *Snapshot) clearShouldLoad(scopes ...loadScope) { case packageLoadScope: scopePath := PackagePath(scope) var toDelete []PackageID - s.shouldLoad.Range(func(id PackageID, pkgPaths []PackagePath) { - for _, pkgPath := range pkgPaths { - if pkgPath == scopePath { - toDelete = append(toDelete, id) - } + for id, pkgPaths := range s.shouldLoad.All() { + if slices.Contains(pkgPaths, scopePath) { + toDelete = append(toDelete, id) } - }) + } for _, id := range toDelete { s.shouldLoad.Delete(id) } @@ -1183,7 +1181,7 @@ func (s *Snapshot) reloadWorkspace(ctx context.Context) { var scopes []loadScope var seen map[PackagePath]bool s.mu.Lock() - s.shouldLoad.Range(func(_ PackageID, pkgPaths []PackagePath) { + for _, pkgPaths := range s.shouldLoad.All() { for _, pkgPath := range pkgPaths { if seen == nil { seen = make(map[PackagePath]bool) @@ -1194,7 +1192,7 @@ func (s *Snapshot) reloadWorkspace(ctx context.Context) { seen[pkgPath] = true scopes = append(scopes, packageLoadScope(pkgPath)) } - }) + } s.mu.Unlock() if len(scopes) == 0 { @@ -1886,13 +1884,13 @@ func deleteMostRelevantModFile(m *persistent.Map[protocol.DocumentURI, *memoize. var mostRelevant protocol.DocumentURI changedFile := changed.Path() - m.Range(func(modURI protocol.DocumentURI, _ *memoize.Promise) { + for modURI := range m.All() { if len(modURI) > len(mostRelevant) { if pathutil.InDir(modURI.DirPath(), changedFile) { mostRelevant = modURI } } - }) + } if mostRelevant != "" { m.Delete(mostRelevant) } diff --git a/gopls/internal/cache/view.go b/gopls/internal/cache/view.go index 5fb03cb1152..33c77760e7f 100644 --- a/gopls/internal/cache/view.go +++ b/gopls/internal/cache/view.go @@ -1171,7 +1171,7 @@ func (s *Snapshot) Vulnerabilities(modfiles ...protocol.DocumentURI) map[protoco defer s.mu.Unlock() if len(modfiles) == 0 { // empty means all modfiles - modfiles = s.vulns.Keys() + modfiles = slices.Collect(s.vulns.Keys()) } for _, modfile := range modfiles { vuln, _ := s.vulns.Get(modfile) diff --git a/gopls/internal/util/persistent/map.go b/gopls/internal/util/persistent/map.go index 5cb556a482b..193f98791d8 100644 --- a/gopls/internal/util/persistent/map.go +++ b/gopls/internal/util/persistent/map.go @@ -9,6 +9,7 @@ package persistent import ( "fmt" + "iter" "math/rand" "strings" "sync/atomic" @@ -57,10 +58,10 @@ func (m *Map[K, V]) String() string { var buf strings.Builder buf.WriteByte('{') var sep string - m.Range(func(k K, v V) { + for k, v := range m.All() { fmt.Fprintf(&buf, "%s%v: %v", sep, k, v) sep = ", " - }) + } buf.WriteByte('}') return buf.String() } @@ -149,29 +150,29 @@ func (pm *Map[K, V]) Clear() { pm.root = nil } -// Keys returns all keys present in the map. -func (pm *Map[K, V]) Keys() []K { - var keys []K - pm.root.forEach(func(k, _ any) { - keys = append(keys, k.(K)) - }) - return keys +// Keys returns the ascending sequence of keys present in the map. +func (pm *Map[K, V]) Keys() iter.Seq[K] { + return func(yield func(K) bool) { + pm.root.forEach(func(k, _ any) bool { + return yield(k.(K)) + }) + } } -// Range calls f sequentially in ascending key order for all entries in the map. -func (pm *Map[K, V]) Range(f func(key K, value V)) { - pm.root.forEach(func(k, v any) { - f(k.(K), v.(V)) - }) +// All returns the sequence of map entries in ascending key order. +func (pm *Map[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + pm.root.forEach(func(k, v any) bool { + return yield(k.(K), v.(V)) + }) + } } -func (node *mapNode) forEach(f func(key, value any)) { - if node == nil { - return - } - node.left.forEach(f) - f(node.key, node.value.value) - node.right.forEach(f) +func (node *mapNode) forEach(yield func(key, value any) bool) bool { + return node == nil || + node.left.forEach(yield) && + yield(node.key, node.value.value) && + node.right.forEach(yield) } // Get returns the map value associated with the specified key. diff --git a/gopls/internal/util/persistent/map_test.go b/gopls/internal/util/persistent/map_test.go index effa1c1da85..88dced2a85f 100644 --- a/gopls/internal/util/persistent/map_test.go +++ b/gopls/internal/util/persistent/map_test.go @@ -240,12 +240,12 @@ func (vm *validatedMap) validate(t *testing.T) { } actualMap := make(map[int]int, len(vm.expected)) - vm.impl.Range(func(key, value int) { + for key, value := range vm.impl.All() { if other, ok := actualMap[key]; ok { t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other) } actualMap[key] = value - }) + } assertSameMap(t, actualMap, vm.expected) } diff --git a/gopls/internal/util/persistent/set.go b/gopls/internal/util/persistent/set.go index 2d5f4edac96..e47d046fb48 100644 --- a/gopls/internal/util/persistent/set.go +++ b/gopls/internal/util/persistent/set.go @@ -4,7 +4,11 @@ package persistent -import "golang.org/x/tools/gopls/internal/util/constraints" +import ( + "iter" + + "golang.org/x/tools/gopls/internal/util/constraints" +) // Set is a collection of elements of type K. // @@ -43,12 +47,14 @@ func (s *Set[K]) Contains(key K) bool { return ok } -// Range calls f sequentially in ascending key order for all entries in the set. -func (s *Set[K]) Range(f func(key K)) { - if s.impl != nil { - s.impl.Range(func(key K, _ struct{}) { - f(key) - }) +// All returns the sequence of set elements in ascending order. +func (s *Set[K]) All() iter.Seq[K] { + return func(yield func(K) bool) { + if s.impl != nil { + s.impl.root.forEach(func(k, _ any) bool { + return yield(k.(K)) + }) + } } } diff --git a/gopls/internal/util/persistent/set_test.go b/gopls/internal/util/persistent/set_test.go index 31911b451b3..192b1c74121 100644 --- a/gopls/internal/util/persistent/set_test.go +++ b/gopls/internal/util/persistent/set_test.go @@ -111,11 +111,11 @@ func diff[K constraints.Ordered](got *persistent.Set[K], want []K) string { wantSet[w] = struct{}{} } var diff []string - got.Range(func(key K) { + for key := range got.All() { if _, ok := wantSet[key]; !ok { diff = append(diff, fmt.Sprintf("+%v", key)) } - }) + } for key := range wantSet { if !got.Contains(key) { diff = append(diff, fmt.Sprintf("-%v", key))