Skip to content

Commit

Permalink
Fix matrix buffer sizes (#241)
Browse files Browse the repository at this point in the history
Fixes #240 - retract v1.11.6

Add re-entrant tests.
  • Loading branch information
klauspost authored Feb 15, 2023
1 parent a6d2e3d commit 1bbcf49
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 19 deletions.
4 changes: 3 additions & 1 deletion galois.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

package reedsolomon

import "encoding/binary"
import (
"encoding/binary"
)

const (
// The number of elements in the field.
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ require golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect
retract (
v1.11.2 // https://github.com/klauspost/reedsolomon/pull/229
[v1.11.3, v1.11.5] // https://github.com/klauspost/reedsolomon/pull/238
v1.11.6 // https://github.com/klauspost/reedsolomon/issues/240
)
40 changes: 28 additions & 12 deletions reedsolomon.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package reedsolomon
import (
"bytes"
"errors"
"fmt"
"io"
"runtime"
"sync"
Expand Down Expand Up @@ -171,7 +172,8 @@ type reedSolomon struct {
tree *inversionTree
parity [][]byte
o options
mPool sync.Pool
mPoolSz int
mPool sync.Pool // Pool for temp matrices, etc
}

var _ = Extensions(&reedSolomon{})
Expand Down Expand Up @@ -571,12 +573,28 @@ func New(dataShards, parityShards int, opts ...Option) (Encoder, error) {
if avx2CodeGen && r.o.useAVX2 {
sz := r.dataShards * r.parityShards * 2 * 32
r.mPool.New = func() interface{} {
return make([]byte, sz)
return AllocAligned(1, sz)[0]
}
r.mPoolSz = sz
}
return &r, err
}

func (r *reedSolomon) getTmpSlice() []byte {
return r.mPool.Get().([]byte)
}

func (r *reedSolomon) putTmpSlice(b []byte) {
if b != nil && cap(b) >= r.mPoolSz {
r.mPool.Put(b[:r.mPoolSz])
return
}
if false {
// Sanity check
panic(fmt.Sprintf("got short tmp returned, want %d, got %d", r.mPoolSz, cap(b)))
}
}

// ErrTooFewShards is returned if too few shards where given to
// Encode/Verify/Reconstruct/Update. It will also be returned from Reconstruct
// if there were too few shards to reconstruct the missing data.
Expand Down Expand Up @@ -806,16 +824,16 @@ func (r *reedSolomon) codeSomeShards(matrixRows, inputs, outputs [][]byte, byteC
start += galMulSlicesGFNI(m, inputs, outputs, 0, byteCount)
end = len(inputs[0])
} else if r.canAVX2C(byteCount, len(inputs), len(outputs)) {
m := genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.mPool.Get().([]byte))
m := genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.getTmpSlice())
start += galMulSlicesAvx2(m, inputs, outputs, 0, byteCount)
r.mPool.Put(m)
r.putTmpSlice(m)
end = len(inputs[0])
} else if len(inputs)+len(outputs) > avx2CodeGenMinShards && r.canAVX2C(byteCount, maxAvx2Inputs, maxAvx2Outputs) {
var gfni [maxAvx2Inputs * maxAvx2Outputs]uint64
end = len(inputs[0])
inIdx := 0
m := r.mPool.Get().([]byte)
defer r.mPool.Put(m)
m := r.getTmpSlice()
defer r.putTmpSlice(m)
ins := inputs
for len(ins) > 0 {
inPer := ins
Expand Down Expand Up @@ -888,8 +906,8 @@ func (r *reedSolomon) codeSomeShardsP(matrixRows, inputs, outputs [][]byte, byte
var tmp [maxAvx2Inputs * maxAvx2Outputs]uint64
gfniMatrix = genGFNIMatrix(matrixRows, len(inputs), 0, len(outputs), tmp[:])
} else if useAvx2 {
avx2Matrix = genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.mPool.Get().([]byte))
defer r.mPool.Put(avx2Matrix)
avx2Matrix = genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.getTmpSlice())
defer r.putTmpSlice(avx2Matrix)
} else if r.o.useGFNI && byteCount < 10<<20 && len(inputs)+len(outputs) > avx2CodeGenMinShards &&
r.canAVX2C(byteCount/4, maxAvx2Inputs, maxAvx2Outputs) {
// It appears there is a switchover point at around 10MB where
Expand Down Expand Up @@ -977,10 +995,8 @@ func (r *reedSolomon) codeSomeShardsAVXP(matrixRows, inputs, outputs [][]byte, b
// Make a plan...
plan := make([]state, 0, ((len(inputs)+maxAvx2Inputs-1)/maxAvx2Inputs)*((len(outputs)+maxAvx2Outputs-1)/maxAvx2Outputs))

tmp := r.mPool.Get().([]byte)
defer func(b []byte) {
r.mPool.Put(b)
}(tmp)
tmp := r.getTmpSlice()
defer r.putTmpSlice(tmp)

// Flips between input first to output first.
// We put the smallest data load in the inner loop.
Expand Down
127 changes: 121 additions & 6 deletions reedsolomon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,26 @@ func TestEncoding(t *testing.T) {

// matrix sizes to test.
// note that par1 matrix will fail on some combinations.
var testSizes = [][2]int{
{1, 0}, {3, 0}, {5, 0}, {8, 0}, {10, 0}, {12, 0}, {14, 0}, {41, 0}, {49, 0},
{1, 1}, {1, 2}, {3, 3}, {3, 1}, {5, 3}, {8, 4}, {10, 30}, {12, 10}, {14, 7}, {41, 17}, {49, 1}, {5, 20},
{256, 20}, {500, 300}, {2945, 129},
func testSizes() [][2]int {
if testing.Short() {
return [][2]int{
{3, 0},
{1, 1}, {1, 2}, {8, 4}, {10, 30}, {41, 17},
{256, 20}, {500, 300},
}
}
return [][2]int{
{1, 0}, {10, 0}, {12, 0}, {49, 0},
{1, 1}, {1, 2}, {3, 3}, {3, 1}, {5, 3}, {8, 4}, {10, 30}, {12, 10}, {14, 7}, {41, 17}, {49, 1}, {5, 20},
{256, 20}, {500, 300}, {2945, 129},
}
}

var testDataSizes = []int{10, 100, 1000, 10001, 100003, 1000055}
var testDataSizesShort = []int{10, 10001, 100003}

func testEncoding(t *testing.T, o ...Option) {
for _, size := range testSizes {
for _, size := range testSizes() {
data, parity := size[0], size[1]
rng := rand.New(rand.NewSource(0xabadc0cac01a))
t.Run(fmt.Sprintf("%dx%d", data, parity), func(t *testing.T) {
Expand Down Expand Up @@ -398,7 +408,7 @@ func testEncoding(t *testing.T, o ...Option) {
}

func testEncodingIdx(t *testing.T, o ...Option) {
for _, size := range testSizes {
for _, size := range testSizes() {
data, parity := size[0], size[1]
rng := rand.New(rand.NewSource(0xabadc0cac01a))
t.Run(fmt.Sprintf("%dx%d", data, parity), func(t *testing.T) {
Expand Down Expand Up @@ -2100,3 +2110,108 @@ func BenchmarkParallel_8x8x32M(b *testing.B) { benchmarkParallel(b, 8, 8, 32<<
func BenchmarkParallel_8x3x1M(b *testing.B) { benchmarkParallel(b, 8, 3, 1<<20) }
func BenchmarkParallel_8x4x1M(b *testing.B) { benchmarkParallel(b, 8, 4, 1<<20) }
func BenchmarkParallel_8x5x1M(b *testing.B) { benchmarkParallel(b, 8, 5, 1<<20) }

func TestReentrant(t *testing.T) {
for optN, o := range testOpts() {
for _, size := range testSizes() {
data, parity := size[0], size[1]
rng := rand.New(rand.NewSource(0xabadc0cac01a))
t.Run(fmt.Sprintf("opt-%d-%dx%d", optN, data, parity), func(t *testing.T) {
perShard := 16384 + 1
if testing.Short() {
perShard = 1024 + 1
}
r, err := New(data, parity, testOptions(o...)...)
if err != nil {
t.Fatal(err)
}
x := r.(Extensions)
if want, got := data, x.DataShards(); want != got {
t.Errorf("DataShards returned %d, want %d", got, want)
}
if want, got := parity, x.ParityShards(); want != got {
t.Errorf("ParityShards returned %d, want %d", got, want)
}
if want, got := parity+data, x.TotalShards(); want != got {
t.Errorf("TotalShards returned %d, want %d", got, want)
}
mul := x.ShardSizeMultiple()
if mul <= 0 {
t.Fatalf("Got unexpected ShardSizeMultiple: %d", mul)
}
perShard = ((perShard + mul - 1) / mul) * mul
runs := 10
if testing.Short() {
runs = 2
}
for i := 0; i < runs; i++ {
shards := AllocAligned(data+parity, perShard)

err = r.Encode(shards)
if err != nil {
t.Fatal(err)
}
ok, err := r.Verify(shards)
if err != nil {
t.Fatal(err)
}
if !ok {
t.Fatal("Verification failed")
}

if parity == 0 {
// Check that Reconstruct and ReconstructData do nothing
err = r.ReconstructData(shards)
if err != nil {
t.Fatal(err)
}
err = r.Reconstruct(shards)
if err != nil {
t.Fatal(err)
}

// Skip integrity checks
continue
}

// Delete one in data
idx := rng.Intn(data)
want := shards[idx]
shards[idx] = nil

err = r.ReconstructData(shards)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(shards[idx], want) {
t.Fatal("did not ReconstructData correctly")
}

// Delete one randomly
idx = rng.Intn(data + parity)
want = shards[idx]
shards[idx] = nil
err = r.Reconstruct(shards)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(shards[idx], want) {
t.Fatal("did not Reconstruct correctly")
}

err = r.Encode(make([][]byte, 1))
if err != ErrTooFewShards {
t.Errorf("expected %v, got %v", ErrTooFewShards, err)
}

// Make one too short.
shards[idx] = shards[idx][:perShard-1]
err = r.Encode(shards)
if err != ErrShardSize {
t.Errorf("expected %v, got %v", ErrShardSize, err)
}
}
})
}
}
}

0 comments on commit 1bbcf49

Please sign in to comment.