Skip to content

Commit

Permalink
internal/zstd: configure window size for single segment frames
Browse files Browse the repository at this point in the history
For #62513
  • Loading branch information
AlexanderYastrebov committed Sep 25, 2023
1 parent 5e9afab commit af57e1c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
Binary file not shown.
22 changes: 13 additions & 9 deletions src/internal/zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ retry:
// Figure out the maximum amount of data we need to retain
// for backreferences.
var windowSize int
if singleSegment {
// No window required, as all the data is in a single buffer.
windowSize = 0
} else {
if !singleSegment {
// Window descriptor. RFC 3.1.1.1.2.
windowDescriptor := r.scratch[0]
exponent := uint64(windowDescriptor >> 3)
Expand All @@ -252,11 +249,6 @@ retry:
if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
return r.makeError(relativeOffset, "windowSize too large")
}

// RFC 8878 permits us to set an 8M max on window size.
if windowSize > 8<<20 {
windowSize = 8 << 20
}
}

// Frame_Content_Size. RFC 3.1.1.4.
Expand All @@ -278,6 +270,18 @@ retry:
panic("unreachable")
}

// RFC 3.1.1.1.2.
// When Single_Segment_Flag is set, Window_Descriptor is not present.
// In this case, Window_Size is Frame_Content_Size, which can be any value from 0 to 2^64 - 1 bytes (16 ExaBytes).
if singleSegment {
windowSize = int(r.remainingFrameSize)
}

// RFC 8878 permits us to set an 8M max on window size.
if windowSize > 8<<20 {
windowSize = 8 << 20
}

relativeOffset += headerSize

r.sawFrameHeader = true
Expand Down
33 changes: 33 additions & 0 deletions src/internal/zstd/zstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zstd

import (
"bytes"
"crypto/sha256"
"fmt"
"internal/race"
"internal/testenv"
Expand Down Expand Up @@ -232,6 +233,38 @@ func TestAlloc(t *testing.T) {
}
}

func TestFileSamples(t *testing.T) {
samples, err := os.ReadDir("testdata/")
if err != nil {
t.Fatal(err)
}

for _, sample := range samples {
name := sample.Name()
if !strings.HasSuffix(name, ".zst") {
continue
}
t.Run(name, func(t *testing.T) {
f, err := os.Open("testdata/" + name)
if err != nil {
t.Fatal(err)
}

r := NewReader(f)
h := sha256.New()
if _, err := io.Copy(h, r); err != nil {
t.Fatal(err)
}
got := fmt.Sprintf("%x", h.Sum(nil))[:8]

want, _, _ := strings.Cut(name, ".")
if got != want {
t.Errorf("Wrong uncompressed content hash: want: %s, got: %s", want, got)
}
})
}
}

func BenchmarkLarge(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
Expand Down

0 comments on commit af57e1c

Please sign in to comment.