diff --git a/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst b/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst new file mode 100644 index 00000000000000..afb4a2769b6389 Binary files /dev/null and b/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst differ diff --git a/src/internal/zstd/zstd.go b/src/internal/zstd/zstd.go index 60551a4371767f..1a7a0a381b02c5 100644 --- a/src/internal/zstd/zstd.go +++ b/src/internal/zstd/zstd.go @@ -235,10 +235,7 @@ retry: // Figure out the maximum amount of data we need to retain // for backreferences. var windowSize int - if singleSegment { - // No window required, as all the data is in a single buffer. - windowSize = 0 - } else { + if !singleSegment { // Window descriptor. RFC 3.1.1.1.2. windowDescriptor := r.scratch[0] exponent := uint64(windowDescriptor >> 3) @@ -252,11 +249,6 @@ retry: if fuzzing && (windowLog > 31 || windowSize > 1<<27) { return r.makeError(relativeOffset, "windowSize too large") } - - // RFC 8878 permits us to set an 8M max on window size. - if windowSize > 8<<20 { - windowSize = 8 << 20 - } } // Frame_Content_Size. RFC 3.1.1.4. @@ -278,6 +270,18 @@ retry: panic("unreachable") } + // RFC 3.1.1.1.2. + // When Single_Segment_Flag is set, Window_Descriptor is not present. + // In this case, Window_Size is Frame_Content_Size. + if singleSegment { + windowSize = int(r.remainingFrameSize) + } + + // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size. + if windowSize > 8<<20 { + windowSize = 8 << 20 + } + relativeOffset += headerSize r.sawFrameHeader = true diff --git a/src/internal/zstd/zstd_test.go b/src/internal/zstd/zstd_test.go index 22af814acfce57..b9e16c6de9b01d 100644 --- a/src/internal/zstd/zstd_test.go +++ b/src/internal/zstd/zstd_test.go @@ -6,6 +6,7 @@ package zstd import ( "bytes" + "crypto/sha256" "fmt" "internal/race" "internal/testenv" @@ -232,6 +233,38 @@ func TestAlloc(t *testing.T) { } } +func TestFileSamples(t *testing.T) { + samples, err := os.ReadDir("testdata/") + if err != nil { + t.Fatal(err) + } + + for _, sample := range samples { + name := sample.Name() + if !strings.HasSuffix(name, ".zst") { + continue + } + t.Run(name, func(t *testing.T) { + f, err := os.Open("testdata/" + name) + if err != nil { + t.Fatal(err) + } + + r := NewReader(f) + h := sha256.New() + if _, err := io.Copy(h, r); err != nil { + t.Fatal(err) + } + got := fmt.Sprintf("%x", h.Sum(nil))[:8] + + want, _, _ := strings.Cut(name, ".") + if got != want { + t.Errorf("Wrong uncompressed content hash: want: %s, got: %s", want, got) + } + }) + } +} + func BenchmarkLarge(b *testing.B) { b.StopTimer() b.ReportAllocs()