Skip to content

Commit

Permalink
gopls/internal/util/persistent: {Map,Set}: use iter.Seq2
Browse files Browse the repository at this point in the history
This CL updates the tree data structures to use go1.23-style
iterators.

package persistent
func (*Map[K, V]) Keys() iter.Seq[K]
func (*Map[K, V]) All() iter.Seq2[K, V]
func (*Set[K]) All() iter.Seq[K]

Change-Id: I4b8917fa35c38e055e42e10cefea7997fe7b35f3
Reviewed-on: https://go-review.googlesource.com/c/tools/+/640035
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Alan Donovan <adonovan@google.com>
  • Loading branch information
adonovan authored and gopherbot committed Jan 7, 2025
1 parent a2408f8 commit a339e37
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 63 deletions.
17 changes: 9 additions & 8 deletions gopls/internal/cache/filemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package cache

import (
"iter"
"path/filepath"

"golang.org/x/tools/gopls/internal/file"
Expand Down Expand Up @@ -77,9 +78,9 @@ func (m *fileMap) get(key protocol.DocumentURI) (file.Handle, bool) {
return m.files.Get(key)
}

// foreach calls f for each (uri, fh) in the map.
func (m *fileMap) foreach(f func(uri protocol.DocumentURI, fh file.Handle)) {
m.files.Range(f)
// all returns the sequence of (uri, fh) entries in the map.
func (m *fileMap) all() iter.Seq2[protocol.DocumentURI, file.Handle] {
return m.files.All()
}

// set stores the given file handle for key, updating overlays and directories
Expand Down Expand Up @@ -130,9 +131,9 @@ func (m *fileMap) delete(key protocol.DocumentURI) {
// getOverlays returns a new unordered array of overlay files.
func (m *fileMap) getOverlays() []*overlay {
var overlays []*overlay
m.overlays.Range(func(_ protocol.DocumentURI, o *overlay) {
for _, o := range m.overlays.All() {
overlays = append(overlays, o)
})
}
return overlays
}

Expand All @@ -143,9 +144,9 @@ func (m *fileMap) getOverlays() []*overlay {
func (m *fileMap) getDirs() *persistent.Set[string] {
if m.dirs == nil {
m.dirs = new(persistent.Set[string])
m.files.Range(func(u protocol.DocumentURI, _ file.Handle) {
m.addDirs(u)
})
for uri := range m.files.All() {
m.addDirs(uri)
}
}
return m.dirs
}
8 changes: 4 additions & 4 deletions gopls/internal/cache/filemap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ func TestFileMap(t *testing.T) {
}

var gotFiles []string
m.foreach(func(uri protocol.DocumentURI, _ file.Handle) {
for uri := range m.all() {
gotFiles = append(gotFiles, normalize(uri.Path()))
})
}
sort.Strings(gotFiles)
if diff := cmp.Diff(test.wantFiles, gotFiles); diff != "" {
t.Errorf("Files mismatch (-want +got):\n%s", diff)
Expand All @@ -100,9 +100,9 @@ func TestFileMap(t *testing.T) {
}

var gotDirs []string
m.getDirs().Range(func(dir string) {
for dir := range m.getDirs().All() {
gotDirs = append(gotDirs, normalize(dir))
})
}
sort.Strings(gotDirs)
if diff := cmp.Diff(test.wantDirs, gotDirs); diff != "" {
t.Errorf("Dirs mismatch (-want +got):\n%s", diff)
Expand Down
4 changes: 2 additions & 2 deletions gopls/internal/cache/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork AllowNetwork, scopes .
s.mu.Lock()

// Assert the invariant s.packages.Get(id).m == s.meta.metadata[id].
s.packages.Range(func(id PackageID, ph *packageHandle) {
for id, ph := range s.packages.All() {
if s.meta.Packages[id] != ph.mp {
panic("inconsistent metadata")
}
})
}

// Compute the minimal metadata updates (for Clone)
// required to preserve the above invariant.
Expand Down
30 changes: 14 additions & 16 deletions gopls/internal/cache/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,11 @@ func (s *Snapshot) Templates() map[protocol.DocumentURI]file.Handle {
defer s.mu.Unlock()

tmpls := map[protocol.DocumentURI]file.Handle{}
s.files.foreach(func(k protocol.DocumentURI, fh file.Handle) {
for k, fh := range s.files.all() {
if s.FileKind(fh) == file.Tmpl {
tmpls[k] = fh
}
})
}
return tmpls
}

Expand Down Expand Up @@ -864,13 +864,13 @@ func (s *Snapshot) addKnownSubdirs(patterns map[protocol.RelativePattern]unit, w
s.mu.Lock()
defer s.mu.Unlock()

s.files.getDirs().Range(func(dir string) {
for dir := range s.files.getDirs().All() {
for _, wsDir := range wsDirs {
if pathutil.InDir(wsDir, dir) {
patterns[protocol.RelativePattern{Pattern: filepath.ToSlash(dir)}] = unit{}
}
}
})
}
}

// watchSubdirs reports whether gopls should request separate file watchers for
Expand Down Expand Up @@ -912,11 +912,11 @@ func (s *Snapshot) filesInDir(uri protocol.DocumentURI) []protocol.DocumentURI {
return nil
}
var files []protocol.DocumentURI
s.files.foreach(func(uri protocol.DocumentURI, _ file.Handle) {
for uri := range s.files.all() {
if pathutil.InDir(dir, uri.Path()) {
files = append(files, uri)
}
})
}
return files
}

Expand Down Expand Up @@ -1029,13 +1029,11 @@ func (s *Snapshot) clearShouldLoad(scopes ...loadScope) {
case packageLoadScope:
scopePath := PackagePath(scope)
var toDelete []PackageID
s.shouldLoad.Range(func(id PackageID, pkgPaths []PackagePath) {
for _, pkgPath := range pkgPaths {
if pkgPath == scopePath {
toDelete = append(toDelete, id)
}
for id, pkgPaths := range s.shouldLoad.All() {
if slices.Contains(pkgPaths, scopePath) {
toDelete = append(toDelete, id)
}
})
}
for _, id := range toDelete {
s.shouldLoad.Delete(id)
}
Expand Down Expand Up @@ -1183,7 +1181,7 @@ func (s *Snapshot) reloadWorkspace(ctx context.Context) {
var scopes []loadScope
var seen map[PackagePath]bool
s.mu.Lock()
s.shouldLoad.Range(func(_ PackageID, pkgPaths []PackagePath) {
for _, pkgPaths := range s.shouldLoad.All() {
for _, pkgPath := range pkgPaths {
if seen == nil {
seen = make(map[PackagePath]bool)
Expand All @@ -1194,7 +1192,7 @@ func (s *Snapshot) reloadWorkspace(ctx context.Context) {
seen[pkgPath] = true
scopes = append(scopes, packageLoadScope(pkgPath))
}
})
}
s.mu.Unlock()

if len(scopes) == 0 {
Expand Down Expand Up @@ -1886,13 +1884,13 @@ func deleteMostRelevantModFile(m *persistent.Map[protocol.DocumentURI, *memoize.
var mostRelevant protocol.DocumentURI
changedFile := changed.Path()

m.Range(func(modURI protocol.DocumentURI, _ *memoize.Promise) {
for modURI := range m.All() {
if len(modURI) > len(mostRelevant) {
if pathutil.InDir(modURI.DirPath(), changedFile) {
mostRelevant = modURI
}
}
})
}
if mostRelevant != "" {
m.Delete(mostRelevant)
}
Expand Down
2 changes: 1 addition & 1 deletion gopls/internal/cache/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ func (s *Snapshot) Vulnerabilities(modfiles ...protocol.DocumentURI) map[protoco
defer s.mu.Unlock()

if len(modfiles) == 0 { // empty means all modfiles
modfiles = s.vulns.Keys()
modfiles = slices.Collect(s.vulns.Keys())
}
for _, modfile := range modfiles {
vuln, _ := s.vulns.Get(modfile)
Expand Down
43 changes: 22 additions & 21 deletions gopls/internal/util/persistent/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package persistent

import (
"fmt"
"iter"
"math/rand"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -57,10 +58,10 @@ func (m *Map[K, V]) String() string {
var buf strings.Builder
buf.WriteByte('{')
var sep string
m.Range(func(k K, v V) {
for k, v := range m.All() {
fmt.Fprintf(&buf, "%s%v: %v", sep, k, v)
sep = ", "
})
}
buf.WriteByte('}')
return buf.String()
}
Expand Down Expand Up @@ -149,29 +150,29 @@ func (pm *Map[K, V]) Clear() {
pm.root = nil
}

// Keys returns all keys present in the map.
func (pm *Map[K, V]) Keys() []K {
var keys []K
pm.root.forEach(func(k, _ any) {
keys = append(keys, k.(K))
})
return keys
// Keys returns the ascending sequence of keys present in the map.
func (pm *Map[K, V]) Keys() iter.Seq[K] {
return func(yield func(K) bool) {
pm.root.forEach(func(k, _ any) bool {
return yield(k.(K))
})
}
}

// Range calls f sequentially in ascending key order for all entries in the map.
func (pm *Map[K, V]) Range(f func(key K, value V)) {
pm.root.forEach(func(k, v any) {
f(k.(K), v.(V))
})
// All returns the sequence of map entries in ascending key order.
func (pm *Map[K, V]) All() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
pm.root.forEach(func(k, v any) bool {
return yield(k.(K), v.(V))
})
}
}

func (node *mapNode) forEach(f func(key, value any)) {
if node == nil {
return
}
node.left.forEach(f)
f(node.key, node.value.value)
node.right.forEach(f)
func (node *mapNode) forEach(yield func(key, value any) bool) bool {
return node == nil ||
node.left.forEach(yield) &&
yield(node.key, node.value.value) &&
node.right.forEach(yield)
}

// Get returns the map value associated with the specified key.
Expand Down
4 changes: 2 additions & 2 deletions gopls/internal/util/persistent/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ func (vm *validatedMap) validate(t *testing.T) {
}

actualMap := make(map[int]int, len(vm.expected))
vm.impl.Range(func(key, value int) {
for key, value := range vm.impl.All() {
if other, ok := actualMap[key]; ok {
t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other)
}
actualMap[key] = value
})
}

assertSameMap(t, actualMap, vm.expected)
}
Expand Down
20 changes: 13 additions & 7 deletions gopls/internal/util/persistent/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

package persistent

import "golang.org/x/tools/gopls/internal/util/constraints"
import (
"iter"

"golang.org/x/tools/gopls/internal/util/constraints"
)

// Set is a collection of elements of type K.
//
Expand Down Expand Up @@ -43,12 +47,14 @@ func (s *Set[K]) Contains(key K) bool {
return ok
}

// Range calls f sequentially in ascending key order for all entries in the set.
func (s *Set[K]) Range(f func(key K)) {
if s.impl != nil {
s.impl.Range(func(key K, _ struct{}) {
f(key)
})
// All returns the sequence of set elements in ascending order.
func (s *Set[K]) All() iter.Seq[K] {
return func(yield func(K) bool) {
if s.impl != nil {
s.impl.root.forEach(func(k, _ any) bool {
return yield(k.(K))
})
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions gopls/internal/util/persistent/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ func diff[K constraints.Ordered](got *persistent.Set[K], want []K) string {
wantSet[w] = struct{}{}
}
var diff []string
got.Range(func(key K) {
for key := range got.All() {
if _, ok := wantSet[key]; !ok {
diff = append(diff, fmt.Sprintf("+%v", key))
}
})
}
for key := range wantSet {
if !got.Contains(key) {
diff = append(diff, fmt.Sprintf("-%v", key))
Expand Down

0 comments on commit a339e37

Please sign in to comment.