diff --git a/go.mod b/go.mod
index db01dcc..cca9618 100644
--- a/go.mod
+++ b/go.mod
@@ -8,8 +8,8 @@ require (
 	github.com/onsi/ginkgo/v2 v2.14.0
 	github.com/onsi/gomega v1.30.0
 	github.com/spf13/pflag v1.0.5
+	github.com/zeebo/blake3 v0.2.4
 	go.uber.org/zap v1.27.0
-	golang.org/x/crypto v0.21.0
 	sigs.k8s.io/controller-runtime v0.17.3
 )
 
@@ -21,6 +21,7 @@ require (
 	github.com/google/gofuzz v1.2.0 // indirect
 	github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
+	github.com/klauspost/cpuid/v2 v2.0.12 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	go.uber.org/multierr v1.11.0 // indirect
diff --git a/go.sum b/go.sum
index bfdcc66..11cfdee 100644
--- a/go.sum
+++ b/go.sum
@@ -30,6 +30,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
+github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -50,6 +52,12 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
+github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
+github.com/zeebo/blake3 v0.2.4 h1:KYQPkhpRtcqh0ssGYcKLG1JYvddkEA8QwCM/yBqhaZI=
+github.com/zeebo/blake3 v0.2.4/go.mod h1:7eeQ6d2iXWRGF6npfaxl2CU+xy2Fjo2gxeyZGCRUjcE=
+github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
+github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
@@ -59,8 +67,6 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
-golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
diff --git a/pkg/blockrsync/client.go b/pkg/blockrsync/client.go
index dde6ab0..90999b2 100644
--- a/pkg/blockrsync/client.go
+++ b/pkg/blockrsync/client.go
@@ -1,6 +1,7 @@
 package blockrsync
 
 import (
+	"bytes"
 	"encoding/binary"
 	"fmt"
 	"io"
@@ -42,19 +43,26 @@ func (b *BlockrsyncClient) ConnectToTarget() error {
 	}
 	b.log.Info("Opened file", "file", b.sourceFile)
 	defer f.Close()
-
+	b.log.V(3).Info("Connecting to target", "address", b.connectionProvider.TargetAddress())
+	conn, err := b.connectionProvider.Connect()
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	b.log.Info("Connected to target, reading file to hash")
 	size, err := b.hasher.HashFile(b.sourceFile)
 	if err != nil {
 		return err
 	}
 	b.sourceSize = size
 	b.log.V(5).Info("Hashed file", "filename", b.sourceFile, "size", size)
-	conn, err := b.connectionProvider.Connect()
-	if err != nil {
+	reader := snappy.NewReader(conn)
+	if match, err := b.hasher.CompareHashHash(conn); err != nil {
 		return err
+	} else if match {
+		b.log.Info("No differences found, exiting")
+		return nil
 	}
-	defer conn.Close()
-	reader := snappy.NewReader(conn)
 	var diff []int64
 	if blockSize, sourceHashes, err := b.hasher.DeserializeHashes(reader); err != nil {
 		return err
@@ -141,12 +149,7 @@ func (b *BlockrsyncClient) writeBlocksToServer(writer io.Writer, offsets []int64
 }
 
 func isEmptyBlock(buf []byte) bool {
-	for _, b := range buf {
-		if b != 0 {
-			return false
-		}
-	}
-	return true
+	return bytes.Equal(buf, emptyBlock)
 }
 
 func int64SortFunc(i, j int64) int {
@@ -160,6 +163,7 @@ func int64SortFunc(i, j int64) int {
 
 type ConnectionProvider interface {
 	Connect() (io.ReadWriteCloser, error)
+	TargetAddress() string
 }
 
 type NetworkConnectionProvider struct {
@@ -177,9 +181,13 @@ func (n *NetworkConnectionProvider) Connect() (io.ReadWriteCloser, error) {
 			if retryCount > 30 {
 				return nil, fmt.Errorf("unable to connect to target after %d retries", retryCount)
 			}
-			time.Sleep(time.Second)
+			time.Sleep(time.Second * 10)
 			retryCount++
 		}
 	}
 	return conn, nil
 }
+
+func (n *NetworkConnectionProvider) TargetAddress() string {
+	return n.targetAddress
+}
diff --git a/pkg/blockrsync/hasher.go b/pkg/blockrsync/hasher.go
index 41eef44..a97e519 100644
--- a/pkg/blockrsync/hasher.go
+++ b/pkg/blockrsync/hasher.go
@@ -15,12 +15,18 @@ import (
 	"time"
 
 	"github.com/go-logr/logr"
-	"golang.org/x/crypto/blake2b"
+	"github.com/zeebo/blake3"
 )
 
 const (
 	DefaultBlockSize   = int64(64 * 1024)
 	defaultConcurrency = 25
+	blake3HashLength   = 32
+)
+
+var (
+	emptyBlock []byte
+	zeroHash   []byte
 )
 
 type Hasher interface {
@@ -30,6 +36,7 @@ type Hasher interface {
 	SerializeHashes(io.Writer) error
 	DeserializeHashes(io.Reader) (int64, map[int64][]byte, error)
 	BlockSize() int64
+	CompareHashHash(io.ReadWriter) (bool, error)
 }
 
 type OffsetHash struct {
@@ -44,9 +51,13 @@ type FileHasher struct {
 	blockSize int64
 	fileSize  int64
 	log       logr.Logger
+	// Hash of hashes
+	hashHash []byte
 }
 
 func NewFileHasher(blockSize int64, log logr.Logger) Hasher {
+	emptyBlock = make([]byte, blockSize)
+	zeroHash = computeZeroHash()
 	return &FileHasher{
 		blockSize: blockSize,
 		queue:     make(chan int64, defaultConcurrency),
@@ -56,10 +67,17 @@ func NewFileHasher(blockSize int64, log logr.Logger) Hasher {
 	}
 }
 
+func computeZeroHash() []byte {
+	h := blake3.New()
+	h.Write(emptyBlock)
+	return h.Sum(nil)
+}
+
 func (f *FileHasher) HashFile(fileName string) (int64, error) {
 	f.log.V(3).Info("Hashing file", "file", fileName)
 	t := time.Now()
 	defer func() {
+		f.hashHash = f.calculateHashHash()
 		f.log.V(3).Info("Hashing took", "milliseconds", time.Since(t).Milliseconds())
 	}()
 	done := make(chan struct{})
@@ -79,10 +97,7 @@ func (f *FileHasher) HashFile(fileName string) (int64, error) {
 
 	for i := 0; i < count; i++ {
 		wg.Add(1)
-		h, err := blake2b.New512(nil)
-		if err != nil {
-			return 0, err
-		}
+		h := blake3.New()
 		go func(h hash.Hash) {
 			defer wg.Done()
 			osFile, err := os.Open(fileName)
@@ -110,6 +125,14 @@ func (f *FileHasher) HashFile(fileName string) (int64, error) {
 	}
 }
 
+func (f *FileHasher) calculateHashHash() []byte {
+	h := blake3.New()
+	for _, v := range f.hashes {
+		h.Write(v)
+	}
+	return h.Sum(nil)
+}
+
 func (f *FileHasher) getFileSize(fileName string) (int64, error) {
 	file, err := os.Open(fileName)
 	if err != nil {
@@ -153,17 +176,23 @@ func (f *FileHasher) calculateHash(offset int64, rs io.ReadSeeker, h hash.Hash)
 		f.log.V(5).Info("Failed to read")
 		return err
 	}
-	n, err = h.Write(buf[:n])
-	if err != nil {
-		f.log.V(5).Info("Failed to write to hash")
-		return err
-	}
-	if n != len(buf) {
-		f.log.V(5).Info("Finished reading file")
+	var hash []byte
+	if bytes.Equal(buf, emptyBlock) {
+		hash = zeroHash
+	} else {
+		n, err = h.Write(buf[:n])
+		if err != nil {
+			f.log.V(5).Info("Failed to write to hash")
+			return err
+		}
+		if n != len(buf) {
+			f.log.V(5).Info("Finished reading file")
+		}
+		hash = h.Sum(nil)
 	}
 	offsetHash := OffsetHash{
 		Offset: offset,
-		Hash:   h.Sum(nil),
+		Hash:   hash,
 	}
 	f.res <- offsetHash
 	return nil
@@ -210,6 +239,7 @@ func (f *FileHasher) SerializeHashes(w io.Writer) error {
 	if err := binary.Write(w, binary.LittleEndian, int64(f.blockSize)); err != nil {
 		return err
 	}
+	length := len(f.hashes)
 	f.log.V(5).Info("Number of blocks", "size", length)
 	if err := binary.Write(w, binary.LittleEndian, int64(length)); err != nil {
 		return err
@@ -225,8 +255,8 @@ func (f *FileHasher) SerializeHashes(w io.Writer) error {
 		if err := binary.Write(w, binary.LittleEndian, k); err != nil {
 			return err
 		}
-		if len(f.hashes[k]) != 64 {
-			return errors.New("invalid hash length")
+		if len(f.hashes[k]) != blake3HashLength {
+			return fmt.Errorf("invalid hash length %d", len(f.hashes[k]))
 		}
 		if n, err := w.Write(f.hashes[k]); err != nil {
 			return err
@@ -248,6 +278,7 @@ func (f *FileHasher) DeserializeHashes(r io.Reader) (int64, map[int64][]byte, er
 	if err := binary.Read(r, binary.LittleEndian, &blockSize); err != nil {
 		return 0, nil, err
 	}
+	f.log.V(5).Info("Block size", "size", blockSize)
 	var length int64
 	if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
 		return 0, nil, err
@@ -263,7 +294,7 @@ func (f *FileHasher) DeserializeHashes(r io.Reader) (int64, map[int64][]byte, er
 		if offset < 0 || offset > length*blockSize {
 			return 0, nil, fmt.Errorf("invalid offset %d", offset)
 		}
-		hash := make([]byte, 64)
+		hash := make([]byte, blake3HashLength)
 		if n, err := io.ReadFull(r, hash); err != nil {
 			return 0, nil, err
 		} else {
@@ -275,6 +306,22 @@ func (f *FileHasher) DeserializeHashes(r io.Reader) (int64, map[int64][]byte, er
 	return blockSize, hashes, nil
 }
 
+func (f *FileHasher) CompareHashHash(rw io.ReadWriter) (bool, error) {
+	f.log.V(5).Info("Comparing hash of hashes", "hash", base64.StdEncoding.EncodeToString(f.hashHash))
+	if n, err := rw.Write(f.hashHash); err != nil {
+		return false, err
+	} else {
+		f.log.V(5).Info("Wrote hash of hashes", "bytes", n)
+	}
+	hashHash := make([]byte, blake3HashLength)
+	if n, err := io.ReadFull(rw, hashHash); err != nil {
+		return false, err
+	} else {
+		f.log.V(5).Info("Read hash of hashes", "bytes", n, "hash", base64.StdEncoding.EncodeToString(hashHash))
+	}
+	return bytes.Equal(hashHash, f.hashHash), nil
+}
+
 func (f *FileHasher) BlockSize() int64 {
 	return f.blockSize
 }
diff --git a/pkg/blockrsync/hasher_test.go b/pkg/blockrsync/hasher_test.go
index e6d85cc..501eee1 100644
--- a/pkg/blockrsync/hasher_test.go
+++ b/pkg/blockrsync/hasher_test.go
@@ -63,8 +63,8 @@ var _ = Describe("hasher tests", func() {
 			err = hasher.SerializeHashes(w)
 			Expect(err).ToNot(HaveOccurred())
 			hashes := hasher.GetHashes()
-			// 16 for the blocksize and length, 72 for each hash
-			Expect(b.Len()).To(Equal(72*len(hashes) + 16))
+			// 16 for the blocksize and length, 40 for each hash (32 bytes for the hash, 8 for the offset)
+			Expect(b.Len()).To(Equal(40*len(hashes) + 16))
 			r := io.Reader(&b)
 			blockSize, h, err := hasher.DeserializeHashes(r)
 			Expect(err).ToNot(HaveOccurred())
diff --git a/pkg/blockrsync/server.go b/pkg/blockrsync/server.go
index 9ae3e96..7edfc92 100644
--- a/pkg/blockrsync/server.go
+++ b/pkg/blockrsync/server.go
@@ -67,6 +67,12 @@ func (b *BlockrsyncServer) StartServer() error {
 		defer conn.Close()
 		writer := snappy.NewBufferedWriter(conn)
 		<-readyChan
+		if match, err := b.hasher.CompareHashHash(conn); err != nil {
+			return err
+		} else if match {
+			b.log.Info("No differences found, exiting")
+			return nil
+		}
 
 		if err := b.writeHashes(writer); err != nil {
 			return err
@@ -99,6 +105,7 @@ func (b *BlockrsyncServer) writeBlocksToFile(f *os.File, reader io.Reader) error
 		_, err = handleReadError(err, nocallback)
 		return err
 	}
+	b.targetFileSize = max(b.targetFileSize, sourceSize)
 	if err := b.truncateFileIfNeeded(f, sourceSize, b.targetFileSize); err != nil {
 		_, err = handleReadError(err, nocallback)
 		return err
@@ -131,16 +138,15 @@ func (b *BlockrsyncServer) truncateFileIfNeeded(f *os.File, sourceSize, targetSi
 	if err != nil {
 		return err
 	}
-	if targetSize > sourceSize {
-		b.log.V(5).Info("Source size", "size", sourceSize)
-		if info.Mode()&(os.ModeDevice|os.ModeCharDevice) == 0 {
-			// Not a block device, truncate the file if it is larger than the source file
-			// Truncate the target file if it is larger than the source file
-			b.log.V(5).Info("Source is smaller than target, truncating file")
-			if err := f.Truncate(sourceSize); err != nil {
-				return err
-			}
-		} else {
+	b.log.V(5).Info("Source size", "size", sourceSize)
+	if info.Mode()&(os.ModeDevice|os.ModeCharDevice) == 0 {
+		// Not a block device, set the file size to the received size
+		b.log.V(3).Info("Setting target file size", "size", targetSize)
+		if err := f.Truncate(sourceSize); err != nil {
+			return err
+		}
+	} else {
+		if targetSize > sourceSize {
 			// empty out existing blocks
 			PunchHole(f, sourceSize, targetSize-sourceSize)
 		}