diff --git a/internal/block/block_test.go b/internal/block/block_test.go new file mode 100644 index 0000000..b25fa45 --- /dev/null +++ b/internal/block/block_test.go @@ -0,0 +1,190 @@ +package block + +import ( + "crypto/rand" + "github.com/eigerco/strawberry/internal/crypto" + "github.com/eigerco/strawberry/internal/jamtime" + "github.com/stretchr/testify/assert" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/eigerco/strawberry/pkg/serialization" + "github.com/eigerco/strawberry/pkg/serialization/codec" +) + +func Test_BlockEncodeDecode(t *testing.T) { + h := Header{ + ParentHash: randomHash(t), + PriorStateRoot: randomHash(t), + ExtrinsicHash: randomHash(t), + TimeSlotIndex: 123, + EpochMarker: &EpochMarker{ + Keys: [NumberOfValidators]crypto.BandersnatchPublicKey{ + randomPublicKey(t), + randomPublicKey(t), + }, + Entropy: randomHash(t), + }, + WinningTicketsMarker: [jamtime.TimeslotsPerEpoch]*Ticket{{ + Identifier: randomHash(t), + EntryIndex: 112, + }, + { + Identifier: randomHash(t), + EntryIndex: 222, + }}, + Verdicts: []crypto.Hash{ + randomHash(t), + randomHash(t), + }, + OffendersMarkers: []crypto.Ed25519PublicKey{ + randomED25519PublicKey(t), + }, + BlockAuthorIndex: 1, + VRFSignature: randomSignature(t), + BlockSealSignature: randomSignature(t), + } + + ticketProofs := []TicketProof{ + { + EntryIndex: uint8(0), + Proof: randomTicketProof(t), + }, + } + ticketExtrinsic := &TicketExtrinsic{ + TicketProofs: ticketProofs, + } + + preimageExtrinsic := &PreimageExtrinsic{ + { + ServiceIndex: uint32(1), + Data: []byte("preimage data"), + }, + } + + verdicts := []Verdict{ + { + ReportHash: randomHash(t), + EpochIndex: uint32(1), + Judgments: []Judgment{ + { + IsValid: true, + ValidatorIndex: uint16(2), + Signature: randomEd25519Signature(t), + }, + }, + }, + } + disputeExtrinsic := &DisputeExtrinsic{ + Verdicts: verdicts, + Culprits: []Culprit{ + { + ReportHash: randomHash(t), + ValidatorEd25519PublicKey: randomED25519PublicKey(t), + Signature: randomEd25519Signature(t), + }, + }, + Faults: []Fault{ + { + ReportHash: randomHash(t), + IsValid: true, + ValidatorEd25519PublicKey: randomED25519PublicKey(t), + Signature: randomEd25519Signature(t), + }, + }, + } + + assurancesExtrinsic := &AssurancesExtrinsic{ + { + Anchor: randomHash(t), + Flag: true, + ValidatorIndex: uint16(1), + Signature: randomEd25519Signature(t), + }, + } + + guaranteesExtrinsic := &GuaranteesExtrinsic{ + Guarantees: []Guarantee{ + { + WorkReport: WorkReport{ + Specification: WorkPackageSpecification{ + Hash: randomHash(t), + Length: uint32(100), + ErasureRoot: randomHash(t), + SegmentRoot: randomHash(t), + }, + Context: RefinementContext{ + AnchorHeaderHash: randomHash(t), + AnchorPosteriorStateRoot: randomHash(t), + AnchorPosteriorBeefyRoot: randomHash(t), + LookupAnchorHeaderHash: randomHash(t), + LookupAnchorTimeslot: 125, + PrerequisiteHash: nil, + }, + CoreIndex: uint16(1), + AuthorizerHash: randomHash(t), + Output: []byte("output data"), + Results: []WorkResult{ + { + ServiceIndex: uint32(1), + CodeHash: randomHash(t), + PayloadHash: randomHash(t), + GasRatio: uint64(10), + Output: WorkResultOutput{ + Data: []byte("work result data"), + Error: NoError, + }, + }, + }, + }, + Credentials: []CredentialSignature{ + { + ValidatorIndex: uint32(1), + Signature: randomEd25519Signature(t), + }, + }, + Timeslot: 200, + }, + }, + } + + e := Extrinsic{ + ET: ticketExtrinsic, + EP: preimageExtrinsic, + ED: disputeExtrinsic, + EA: assurancesExtrinsic, + EG: guaranteesExtrinsic, + } + + originalBlock := Block{ + Header: &h, + Extrinsic: &e, + } + + serializer := serialization.NewSerializer(&codec.JAMCodec{}) + serialized, err := serializer.Encode(originalBlock) + require.NoError(t, err) + + var deserializedBlock Block + err = serializer.Decode(serialized, &deserializedBlock) + require.NoError(t, err) + + assert.Equal(t, originalBlock, deserializedBlock) +} + +func randomTicketProof(t *testing.T) [ticketProofSize]byte { + var hash [ticketProofSize]byte + _, err := rand.Read(hash[:]) + require.NoError(t, err) + + return hash +} + +func randomEd25519Signature(t *testing.T) [crypto.Ed25519SignatureSize]byte { + var hash [crypto.Ed25519SignatureSize]byte + _, err := rand.Read(hash[:]) + require.NoError(t, err) + + return hash +} diff --git a/internal/block/header_test.go b/internal/block/header_test.go index b708d50..4dfcd05 100644 --- a/internal/block/header_test.go +++ b/internal/block/header_test.go @@ -3,13 +3,13 @@ package block import ( "crypto/ed25519" "crypto/rand" + "github.com/eigerco/strawberry/internal/jamtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/eigerco/strawberry/internal/crypto" - "github.com/eigerco/strawberry/internal/jamtime" "github.com/eigerco/strawberry/pkg/serialization" "github.com/eigerco/strawberry/pkg/serialization/codec" ) @@ -45,17 +45,13 @@ func Test_HeaderEncodeDecode(t *testing.T) { VRFSignature: randomSignature(t), BlockSealSignature: randomSignature(t), } - serializer := serialization.NewSerializer[uint64](&codec.SCALECodec[uint64]{}) + serializer := serialization.NewSerializer(&codec.JAMCodec{}) bb, err := serializer.Encode(h) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) h2 := Header{} err = serializer.Decode(bb, &h2) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Equal(t, h, h2) } @@ -66,6 +62,7 @@ func randomHash(t *testing.T) crypto.Hash { require.NoError(t, err) return crypto.Hash(hash) } + func randomED25519PublicKey(t *testing.T) crypto.Ed25519PublicKey { hash := make([]byte, ed25519.PublicKeySize) _, err := rand.Read(hash) @@ -78,6 +75,7 @@ func randomPublicKey(t *testing.T) crypto.BandersnatchPublicKey { require.NoError(t, err) return crypto.BandersnatchPublicKey(hash) } + func randomSignature(t *testing.T) crypto.BandersnatchSignature { hash := make([]byte, 96) _, err := rand.Read(hash) diff --git a/pkg/serialization/codec/codec.go b/pkg/serialization/codec/codec.go index abef721..4e22668 100644 --- a/pkg/serialization/codec/codec.go +++ b/pkg/serialization/codec/codec.go @@ -4,11 +4,7 @@ type Uint interface { uint8 | uint16 | uint32 | uint64 } -type Codec[T Uint] interface { +type Codec interface { Marshal(v interface{}) ([]byte, error) - MarshalGeneral(x uint64) ([]byte, error) - MarshalTrivialUint(x T, l uint8) ([]byte, error) Unmarshal(data []byte, v interface{}) error - UnmarshalGeneral(data []byte, v *uint64) error - UnmarshalTrivialUint(data []byte, v *T) error } diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go new file mode 100644 index 0000000..5f18290 --- /dev/null +++ b/pkg/serialization/codec/jam/decode.go @@ -0,0 +1,362 @@ +package jam + +import ( + "bytes" + "fmt" + "io" + "math" + "math/bits" + "reflect" +) + +func Unmarshal(data []byte, dst interface{}) error { + dstv := reflect.ValueOf(dst) + if dstv.Kind() != reflect.Ptr || dstv.IsNil() { + return fmt.Errorf(ErrUnsupportedType, dst) + } + + ds := byteReader{} + ds.Reader = bytes.NewBuffer(data) + + return ds.unmarshal(indirect(dstv)) +} + +type byteReader struct { + io.Reader +} + +func (br *byteReader) unmarshal(value reflect.Value) error { + in := value.Interface() + switch in.(type) { + + case int, uint: + return br.decodeUint(value) + case int8, uint8, int16, uint16, int32, uint32, int64, uint64: + return br.decodeFixedWidthInt(value) + case []byte: + return br.decodeBytes(value) + case bool: + return br.decodeBool(value) + default: + return br.handleReflectTypes(value) + } +} + +func (br *byteReader) handleReflectTypes(value reflect.Value) error { + switch value.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return br.decodeCustomPrimitive(value) + case reflect.Ptr: + return br.decodePointer(value) + case reflect.Struct: + return br.decodeStruct(value) + case reflect.Array: + return br.decodeArray(value) + case reflect.Slice: + return br.decodeSlice(value) + default: + return fmt.Errorf(ErrUnsupportedType, value.Interface()) + } +} + +func (br *byteReader) decodeCustomPrimitive(value reflect.Value) error { + in := value.Interface() + inType := reflect.TypeOf(in) + var temp reflect.Value + + switch inType.Kind() { + case reflect.Bool: + temp = reflect.New(reflect.TypeOf(false)) + case reflect.Int: + temp = reflect.New(reflect.TypeOf(0)) + case reflect.Int8: + temp = reflect.New(reflect.TypeOf(int8(0))) + case reflect.Int16: + temp = reflect.New(reflect.TypeOf(int16(0))) + case reflect.Int32: + temp = reflect.New(reflect.TypeOf(int32(0))) + case reflect.Int64: + temp = reflect.New(reflect.TypeOf(int64(0))) + case reflect.Uint: + temp = reflect.New(reflect.TypeOf(uint(0))) + case reflect.Uint8: + temp = reflect.New(reflect.TypeOf(uint8(0))) + case reflect.Uint16: + temp = reflect.New(reflect.TypeOf(uint16(0))) + case reflect.Uint32: + temp = reflect.New(reflect.TypeOf(uint32(0))) + case reflect.Uint64: + temp = reflect.New(reflect.TypeOf(uint64(0))) + default: + return fmt.Errorf(ErrUnsupportedType, in) + } + + if err := br.unmarshal(temp.Elem()); err != nil { + return err + } + + value.Set(temp.Elem().Convert(inType)) + + return nil +} + +func (br *byteReader) ReadOctet() (byte, error) { + var b [1]byte + _, err := br.Reader.Read(b[:]) + if err != nil { + return 0, err + } + return b[0], nil +} + +func (br *byteReader) decodePointer(value reflect.Value) error { + rb, err := br.ReadOctet() + if err != nil { + return err + } + + switch rb { + case 0x00: + // Handle the nil pointer case by setting the destination to nil if necessary + if !value.IsNil() { + value.Set(reflect.Zero(value.Type())) + } + case 0x01: + // Check if the destination is a non-nil pointer + if !value.IsZero() { + // If it's a pointer to another pointer, we need to handle it recursively + if value.Elem().Kind() == reflect.Ptr { + return br.unmarshal(value.Elem().Elem()) + } + return br.unmarshal(value.Elem()) + } + + // If value is nil or zero, we need to create a new instance + elemType := value.Type().Elem() + tempElem := reflect.New(elemType) + if err := br.unmarshal(tempElem.Elem()); err != nil { + return err + } + value.Set(tempElem) + default: + return ErrInvalidPointer + } + return nil +} + +func (br *byteReader) decodeSlice(value reflect.Value) error { + l, err := br.decodeLength() + if err != nil { + return err + } + in := value.Interface() + temp := reflect.New(reflect.ValueOf(in).Type()) + for i := uint(0); i < l; i++ { + tempElemType := reflect.TypeOf(in).Elem() + tempElem := reflect.New(tempElemType).Elem() + + err = br.unmarshal(tempElem) + if err != nil { + return err + } + temp.Elem().Set(reflect.Append(temp.Elem(), tempElem)) + } + value.Set(temp.Elem()) + + return nil +} + +func (br *byteReader) decodeArray(value reflect.Value) error { + in := value.Interface() + temp := reflect.New(reflect.ValueOf(in).Type()) + for i := 0; i < temp.Elem().Len(); i++ { + elem := temp.Elem().Index(i) + err := br.unmarshal(elem) + if err != nil { + return err + } + } + value.Set(temp.Elem()) + + return nil +} + +func (br *byteReader) decodeStruct(value reflect.Value) error { + t := value.Type() + + // Iterate over each field in the struct + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + fieldType := t.Field(i) + + // Skip unexported fields + if !field.CanSet() { + continue + } + + // Decode the field value + err := br.unmarshal(field) + if err != nil { + return fmt.Errorf(ErrDecodingStructField, fieldType.Name, err) + } + } + + return nil +} + +func (br *byteReader) decodeBool(value reflect.Value) error { + rb, err := br.ReadOctet() + if err != nil { + return err + } + + switch rb { + case 0x00: + value.SetBool(false) + case 0x01: + value.SetBool(true) + default: + return ErrDecodingBool + } + + return nil +} + +func (br *byteReader) decodeUint(value reflect.Value) error { + // Read the first byte to determine how many bytes are used in the encoding + prefix, err := br.ReadOctet() + if err != nil { + return fmt.Errorf(ErrReadingByte, err) + } + + var serialized []byte + + // Determine the number of additional bytes using the prefix + l := uint8(bits.LeadingZeros8(^prefix)) + + serialized = make([]byte, l+1) + serialized[0] = prefix + _, err = br.Read(serialized[1:]) + if err != nil { + return fmt.Errorf(ErrReadingBytes, err) + } + + var v uint64 + err = DeserializeUint64WithLength(serialized, l, &v) + if err != nil { + return fmt.Errorf(ErrDecodingUint, err) + } + + // Set the decoded value into the destination + value.Set(reflect.ValueOf(v).Convert(value.Type())) + + return nil +} + +// decodeLength is helper method which calls decodeUint and casts to int +func (br *byteReader) decodeLength() (uint, error) { + var l uint + dstv := reflect.New(reflect.TypeOf(l)) + err := br.decodeUint(dstv.Elem()) + if err != nil { + return 0, fmt.Errorf(ErrDecodingUint, err) + } + l = dstv.Elem().Interface().(uint) + return l, nil +} + +// decodeBytes is used to decode with a destination of []byte +func (br *byteReader) decodeBytes(dstv reflect.Value) error { + length, err := br.decodeLength() + if err != nil { + return err + } + + if length > math.MaxUint32 { + return ErrExceedingByteArrayLimit + } + + b := make([]byte, length) + + if length > 0 { + _, err = br.Read(b) + if err != nil { + return err + } + } + + in := dstv.Interface() + inType := reflect.TypeOf(in) + dstv.Set(reflect.ValueOf(b).Convert(inType)) + return nil +} + +func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error { + in := dstv.Interface() + var buf []byte + var length int + + switch in.(type) { + case uint8: + length = 1 + case uint16: + length = 2 + case uint32: + length = 4 + case uint64: + length = 8 + default: + return fmt.Errorf(ErrUnsupportedType, in) + } + + // Read the appropriate number of bytes + buf = make([]byte, length) + _, err := br.Read(buf) + if err != nil { + return fmt.Errorf(ErrReadingByte, err) + } + + // Deserialize the value + switch in.(type) { + case uint8: + var temp uint8 + DeserializeTrivialNatural(buf, &temp) + dstv.Set(reflect.ValueOf(temp)) + case uint16: + var temp uint16 + DeserializeTrivialNatural(buf, &temp) + dstv.Set(reflect.ValueOf(temp)) + case uint32: + var temp uint32 + DeserializeTrivialNatural(buf, &temp) + dstv.Set(reflect.ValueOf(temp)) + case uint64: + var temp uint64 + DeserializeTrivialNatural(buf, &temp) + dstv.Set(reflect.ValueOf(temp)) + } + + return nil +} + +// indirect recursively dereferences pointers and interfaces, +// allocating new pointers as needed, until it reaches a non-pointer value. +func indirect(v reflect.Value) reflect.Value { + for { + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + case reflect.Interface: + if v.IsNil() { + return v + } + v = v.Elem() + default: + return v + } + } +} diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go new file mode 100644 index 0000000..1ebdc1a --- /dev/null +++ b/pkg/serialization/codec/jam/encode.go @@ -0,0 +1,209 @@ +package jam + +import ( + "bytes" + "fmt" + "io" + "reflect" +) + +func Marshal(v interface{}) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + es := byteWriter{ + Writer: buffer, + } + err := es.marshal(v) + if err != nil { + return nil, err + } + + b := buffer.Bytes() + + return b, nil +} + +type byteWriter struct { + io.Writer +} + +func (bw *byteWriter) marshal(in interface{}) error { + switch in := in.(type) { + case int: + return bw.encodeUint(uint(in)) + case uint: + return bw.encodeUint(in) + case uint8, uint16, uint32, uint64: + return bw.encodeFixedWidthUint(in) + case []byte: + return bw.encodeBytes(in) + case bool: + return bw.encodeBool(in) + default: + return bw.handleReflectTypes(in) + } +} + +func (bw *byteWriter) handleReflectTypes(in interface{}) error { + val := reflect.ValueOf(in) + switch val.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return bw.encodeCustomPrimitive(in) + case reflect.Ptr: + elem := reflect.ValueOf(in).Elem() + switch elem.IsValid() { + case false: + _, err := bw.Write([]byte{0}) + return err + default: + _, err := bw.Write([]byte{1}) + if err != nil { + return err + } + return bw.marshal(elem.Interface()) + } + case reflect.Struct: + return bw.encodeStruct(in) + case reflect.Array: + return bw.encodeArray(in) + case reflect.Slice: + return bw.encodeSlice(in) + default: + return fmt.Errorf(ErrUnsupportedType, in) + } +} + +func (bw *byteWriter) encodeCustomPrimitive(in interface{}) error { + switch reflect.TypeOf(in).Kind() { + case reflect.Bool: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(false)).Interface() + case reflect.Int: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(int(0))).Interface() + case reflect.Int8: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(int8(0))).Interface() + case reflect.Int16: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(int16(0))).Interface() + case reflect.Int32: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(int32(0))).Interface() + case reflect.Int64: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(int64(0))).Interface() + case reflect.Uint: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(uint(0))).Interface() + case reflect.Uint8: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(uint8(0))).Interface() + case reflect.Uint16: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(uint16(0))).Interface() + case reflect.Uint32: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(uint32(0))).Interface() + case reflect.Uint64: + in = reflect.ValueOf(in).Convert(reflect.TypeOf(uint64(0))).Interface() + default: + return fmt.Errorf(ErrUnsupportedType, in) + } + + return bw.marshal(in) +} + +func (bw *byteWriter) encodeSlice(in interface{}) error { + v := reflect.ValueOf(in) + err := bw.encodeLength(v.Len()) + if err != nil { + return err + } + for i := 0; i < v.Len(); i++ { + err = bw.marshal(v.Index(i).Interface()) + if err != nil { + return err + } + } + return nil +} + +func (bw *byteWriter) encodeArray(in interface{}) error { + v := reflect.ValueOf(in) + for i := 0; i < v.Len(); i++ { + err := bw.marshal(v.Index(i).Interface()) + if err != nil { + return err + } + } + return nil +} + +func (bw *byteWriter) encodeBool(l bool) error { + var err error + switch l { + case true: + _, err = bw.Write([]byte{0x01}) + case false: + _, err = bw.Write([]byte{0x00}) + } + + return err +} + +func (bw *byteWriter) encodeBytes(b []byte) error { + err := bw.encodeLength(len(b)) + if err != nil { + return err + } + + _, err = bw.Write(b) + return err +} + +func (bw *byteWriter) encodeFixedWidthUint(i interface{}) error { + var data []byte + + switch v := i.(type) { + case uint8: + data = SerializeTrivialNatural(v, 1) + case uint16: + data = SerializeTrivialNatural(v, 2) + case uint32: + data = SerializeTrivialNatural(v, 4) + case uint64: + data = SerializeTrivialNatural(v, 8) + default: + return fmt.Errorf(ErrUnsupportedType, i) + } + + _, err := bw.Write(data) + return err +} + +func (bw *byteWriter) encodeStruct(in interface{}) error { + v := reflect.ValueOf(in) + t := reflect.TypeOf(in) + + // Iterate over each field in the struct + for i := 0; i < t.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + // Skip unexported fields + if !field.CanInterface() { + continue + } + + // Marshal and encode the field value + err := bw.marshal(field.Interface()) + if err != nil { + return fmt.Errorf(ErrEncodingStructField, fieldType.Name, err) + } + } + + return nil +} + +func (bw *byteWriter) encodeLength(l int) error { + return bw.encodeUint(uint(l)) +} + +func (bw *byteWriter) encodeUint(i uint) error { + encodedBytes := SerializeUint64(uint64(i)) + + _, err := bw.Write(encodedBytes) + + return err +} diff --git a/pkg/serialization/codec/jam/encode_decode_jam_test.go b/pkg/serialization/codec/jam/encode_decode_jam_test.go new file mode 100644 index 0000000..830e4a2 --- /dev/null +++ b/pkg/serialization/codec/jam/encode_decode_jam_test.go @@ -0,0 +1,78 @@ +package jam_test + +import ( + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math" + "testing" +) + +type InnerStruct struct { + Uint64 uint64 + Uint32 uint32 + Uint16 uint16 + Uint8 uint8 +} +type TestStruct struct { + IntField int + BoolField bool + LargeUint uint + InnerSlice []InnerStruct +} + +func TestMarshalUnmarshal(t *testing.T) { + original := TestStruct{ + BoolField: true, + LargeUint: math.MaxUint, + InnerSlice: []InnerStruct{ + {1, 2, 3, 4}, + {2, 3, 4, 5}, + {3, 4, 5, 6}, + }, + } + + marshaledData, err := jam.Marshal(original) + require.NoError(t, err) + + var unmarshaled TestStruct + err = jam.Unmarshal(marshaledData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestEmptyStruct(t *testing.T) { + original := TestStruct{} + + marshaledData, err := jam.Marshal(original) + require.NoError(t, err) + + var unmarshaled TestStruct + err = jam.Unmarshal(marshaledData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestMarshalUnmarshalWithPointer(t *testing.T) { + type StructWithPointer struct { + IntField *uint + } + intVal := uint(42) + original := StructWithPointer{ + IntField: &intVal, + } + + marshaledData, err := jam.Marshal(original) + require.NoError(t, err) + + // Prepare a variable to hold the unmarshaled struct + var unmarshaled StructWithPointer + + // Unmarshal the data back into the struct + err = jam.Unmarshal(marshaledData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} diff --git a/pkg/serialization/codec/jam/errors.go b/pkg/serialization/codec/jam/errors.go index 57c7cca..52cde3a 100644 --- a/pkg/serialization/codec/jam/errors.go +++ b/pkg/serialization/codec/jam/errors.go @@ -7,13 +7,14 @@ import ( var ( // errFirstByteNineByteSerialization is returned when the first byte has wrong value in 9-byte serialization errFirstByteNineByteSerialization = errors.New("expected first byte to be 255 for 9-byte serialization") + ErrInvalidPointer = errors.New("invalid pointer") + ErrDecodingBool = errors.New("error decoding boolean") + ErrExceedingByteArrayLimit = errors.New("byte array length exceeds max value of uint32") - ErrEmptyData = errors.New("empty data") - ErrNonPointerOrNil = errors.New("value must be a not-nil pointer") - - ErrInvalidBooleanEncoding = errors.New("invalid boolean encoding") - - ErrUnsupportedType = "unsupported type: %T" - ErrArrayLengthMismatch = "array length mismatch: expected %d, got %d" - ErrDataLengthMismatch = "data length mismatch: expected %d, got %d" + ErrUnsupportedType = "unsupported type: %v" + ErrReadingBytes = "error reading bytes: %w" + ErrReadingByte = "error reading byte: %w" + ErrDecodingUint = "error decoding uint: : %w" + ErrEncodingStructField = "encoding struct field '%s': %w" + ErrDecodingStructField = "decoding struct field '%s': %w" ) diff --git a/pkg/serialization/codec/jam/general_natural.go b/pkg/serialization/codec/jam/general_natural.go index 587ac53..cff50ac 100644 --- a/pkg/serialization/codec/jam/general_natural.go +++ b/pkg/serialization/codec/jam/general_natural.go @@ -3,13 +3,10 @@ package jam import ( "encoding/binary" "math" - "math/bits" ) -// GeneralNatural implements the formula (able to encode naturals of up to 2^64) -type GeneralNatural struct{} - -func (j *GeneralNatural) SerializeUint64(x uint64) []byte { +// SerializeUint64 implements the general formula (able to encode naturals of up to 2^64) +func SerializeUint64(x uint64) []byte { var l uint8 // Determine the length needed to represent the value for l = 0; l < 8; l++ { @@ -33,8 +30,8 @@ func (j *GeneralNatural) SerializeUint64(x uint64) []byte { return bytes } -// DeserializeUint64 deserializes a byte slice into a uint64 value. -func (j *GeneralNatural) DeserializeUint64(serialized []byte, u *uint64) error { +// DeserializeUint64WithLength deserializes a byte slice into a uint64 value, with length `l`. +func DeserializeUint64WithLength(serialized []byte, l uint8, u *uint64) error { *u = 0 n := len(serialized) @@ -50,16 +47,13 @@ func (j *GeneralNatural) DeserializeUint64(serialized []byte, u *uint64) error { return nil } - prefix := serialized[0] - l := uint8(bits.LeadingZeros8(^prefix)) - // Deserialize the first `l` bytes for i := uint8(0); i < l; i++ { *u |= uint64(serialized[i+1]) << (8 * i) } // Combine the remaining part of the prefix - *u |= uint64(prefix&(math.MaxUint8>>l)) << (8 * l) + *u |= uint64(serialized[0]&(math.MaxUint8>>l)) << (8 * l) return nil } diff --git a/pkg/serialization/codec/jam/general_natural_test.go b/pkg/serialization/codec/jam/general_natural_test.go index b325049..958824b 100644 --- a/pkg/serialization/codec/jam/general_natural_test.go +++ b/pkg/serialization/codec/jam/general_natural_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "math" + "math/bits" "testing" ) @@ -49,19 +50,21 @@ func TestEncodeDecodeUint64(t *testing.T) { {1 << 63, []byte{255, 0, 0, 0, 0, 0, 0, 0, 128}}, // 9223372036854775808 } - gn := GeneralNatural{} - for _, tc := range testCases { t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) { // Marshal the x value - serialized := gn.SerializeUint64(tc.input) + serialized := SerializeUint64(tc.input) // Check if the serialized output matches the expected output assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input) + var l uint8 + if len(serialized) > 0 { + l = uint8(bits.LeadingZeros8(^serialized[0])) + } // Unmarshal the serialized data back into a uint64 var deserialized uint64 - err := gn.DeserializeUint64(serialized, &deserialized) + err := DeserializeUint64WithLength(serialized, l, &deserialized) require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized) // Check if the deserialized value matches the original x diff --git a/pkg/serialization/codec/jam_codec.go b/pkg/serialization/codec/jam_codec.go index 65515a6..cad1371 100644 --- a/pkg/serialization/codec/jam_codec.go +++ b/pkg/serialization/codec/jam_codec.go @@ -1,150 +1,24 @@ package codec import ( - "fmt" - "reflect" - "github.com/eigerco/strawberry/pkg/serialization/codec/jam" ) -// JAMCodec implements the Codec interface for JSON encoding and decoding. -type JAMCodec[T Uint] struct { - gn jam.GeneralNatural +// JAMCodec implements the Codec interface for JAM encoding and decoding. +type JAMCodec struct { } // NewJamCodec initializes an instance of Jam codec -func NewJamCodec[T Uint]() *JAMCodec[T] { - return &JAMCodec[T]{ - gn: jam.GeneralNatural{}, - } +func NewJamCodec() *JAMCodec { + return &JAMCodec{} } -// Marshal encodes the given value into a byte slice. -func (j *JAMCodec[T]) Marshal(v interface{}) ([]byte, error) { - val := reflect.ValueOf(v) - - switch val.Kind() { - case reflect.Bool: - return j.encodeBool(val.Bool()) - case reflect.Array, reflect.Slice: - if val.Type().Elem().Kind() == reflect.Uint8 { - return j.encodeByteSlice(val) - } - } - - return nil, fmt.Errorf(jam.ErrUnsupportedType, v) +// Marshal encodes the given value +func (j *JAMCodec) Marshal(v interface{}) ([]byte, error) { + return jam.Marshal(v) } // Unmarshal decodes the given byte slice into the provided value. -func (j *JAMCodec[T]) Unmarshal(data []byte, v interface{}) error { - if len(data) == 0 { - return jam.ErrEmptyData - } - - val := reflect.ValueOf(v) - if val.Kind() != reflect.Ptr || val.IsNil() { - return jam.ErrNonPointerOrNil - } - - elem := val.Elem() - - switch elem.Kind() { - case reflect.Bool: - return j.decodeBool(data, elem) - case reflect.Slice, reflect.Array: - if elem.Type().Elem().Kind() == reflect.Uint8 { - return j.decodeByteSlice(data, elem) - } - } - - return fmt.Errorf(jam.ErrUnsupportedType, v) -} - -func (j *JAMCodec[T]) MarshalGeneral(v uint64) ([]byte, error) { - return j.gn.SerializeUint64(v), nil -} - -func (j *JAMCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { - return jam.SerializeTrivialNatural(x, l), nil -} - -func (j *JAMCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { - return j.gn.DeserializeUint64(data, v) -} - -func (j *JAMCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { - jam.DeserializeTrivialNatural(data, x) - return nil -} - -// encodeBool encodes a boolean value into a byte slice. -func (j *JAMCodec[T]) encodeBool(b bool) ([]byte, error) { - if b { - return []byte{0x01}, nil // true -> 0x01 - } - return []byte{0x00}, nil // false -> 0x00 -} - -// decodeBool decodes a boolean value from a byte slice. -func (j *JAMCodec[T]) decodeBool(data []byte, elem reflect.Value) error { - if len(data) == 0 { - return jam.ErrEmptyData - } - - switch data[0] { - case 0x01: - elem.SetBool(true) - case 0x00: - elem.SetBool(false) - default: - return jam.ErrInvalidBooleanEncoding - } - return nil -} - -// encodeByteSlice encodes a byte slice or array into a byte slice with a prefixed length. -func (j *JAMCodec[T]) encodeByteSlice(val reflect.Value) ([]byte, error) { - byteSlice, err := j.toByteSlice(val) - if err != nil { - return nil, err - } - // Prepend the length to the byte slice - result := append([]byte{byte(len(byteSlice))}, byteSlice...) - return result, nil -} - -// decodeByteSlice decodes a byte slice or array from the given byte slice. -func (j *JAMCodec[T]) decodeByteSlice(data []byte, elem reflect.Value) error { - length := int(data[0]) - - if len(data)-1 < length { - return fmt.Errorf(jam.ErrDataLengthMismatch, length, len(data)-1) - } - - extractedData := data[len(data)-length:] - - if elem.Kind() == reflect.Slice { - elem.SetBytes(extractedData) - } else if elem.Kind() == reflect.Array { - if elem.Len() != len(extractedData) { - return fmt.Errorf(jam.ErrArrayLengthMismatch, elem.Len(), len(extractedData)) - } - reflect.Copy(elem, reflect.ValueOf(extractedData)) - } - - return nil -} - -// toByteSlice converts an array or slice of bytes to a byte slice. -func (j *JAMCodec[T]) toByteSlice(val reflect.Value) ([]byte, error) { - switch val.Kind() { - case reflect.Array: - b := make([]byte, val.Len()) - reflect.Copy(reflect.ValueOf(b), val) - return b, nil - case reflect.Slice: - return val.Interface().([]byte), nil - default: - return nil, fmt.Errorf(jam.ErrUnsupportedType, val.Kind()) - } +func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error { + return jam.Unmarshal(data, v) } diff --git a/pkg/serialization/codec/jam_codec_test.go b/pkg/serialization/codec/jam_codec_test.go index da8dab9..31f0bd8 100644 --- a/pkg/serialization/codec/jam_codec_test.go +++ b/pkg/serialization/codec/jam_codec_test.go @@ -3,35 +3,12 @@ package codec import ( "testing" - "github.com/eigerco/strawberry/pkg/serialization/codec/jam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestDecodeErrors(t *testing.T) { - codec := &JAMCodec[uint64]{} - - // Test with a non-pointer value - var nonPointer int - err := codec.Unmarshal([]byte{1}, nonPointer) - require.Error(t, err) - assert.Equal(t, jam.ErrNonPointerOrNil, err) - - // Test with a nil pointer - var nilPointer *int - err = codec.Unmarshal([]byte{1}, nilPointer) - require.Error(t, err) - assert.Equal(t, jam.ErrNonPointerOrNil, err) - - // Empty data - var dst *int - err = codec.Unmarshal([]byte{}, dst) - require.Error(t, err) - assert.Equal(t, jam.ErrEmptyData, err) -} - func TestEncodeDecodeSlice(t *testing.T) { - j := JAMCodec[uint64]{} + j := JAMCodec{} input := []byte{1, 2, 3, 4} // Marshal the input value @@ -51,7 +28,7 @@ func TestEncodeDecodeSlice(t *testing.T) { } func TestEncodeDecodeArray(t *testing.T) { - j := JAMCodec[uint32]{} + j := JAMCodec{} input := [4]byte{1, 2, 3, 4} // Marshal the input value @@ -59,7 +36,7 @@ func TestEncodeDecodeArray(t *testing.T) { require.NoError(t, err) // Check if the serialized output matches the expected output - assert.Equal(t, []byte{4, 1, 2, 3, 4}, serialized, "serialized output mismatch for input %d", input) + assert.Equal(t, []byte{1, 2, 3, 4}, serialized, "serialized output mismatch for input %d", input) // Unmarshal the serialized data back into byte var deserialized [4]byte @@ -71,7 +48,7 @@ func TestEncodeDecodeArray(t *testing.T) { } func TestEncodeDecodeBool(t *testing.T) { - j := JAMCodec[uint32]{} + j := JAMCodec{} input := true // Marshal the boolean value diff --git a/pkg/serialization/codec/json_codec.go b/pkg/serialization/codec/json_codec.go index 7445b1a..fb5e78f 100644 --- a/pkg/serialization/codec/json_codec.go +++ b/pkg/serialization/codec/json_codec.go @@ -5,28 +5,12 @@ import ( ) // JSONCodec implements the Codec interface for JSON encoding and decoding. -type JSONCodec[T Uint] struct{} +type JSONCodec struct{} -func (j *JSONCodec[T]) Marshal(v interface{}) ([]byte, error) { +func (j *JSONCodec) Marshal(v interface{}) ([]byte, error) { return json.Marshal(v) } -func (j *JSONCodec[T]) MarshalGeneral(v uint64) ([]byte, error) { - return json.Marshal(v) -} - -func (j *JSONCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { - return json.Marshal(x) -} - -func (j *JSONCodec[T]) Unmarshal(data []byte, v interface{}) error { +func (j *JSONCodec) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) } - -func (j *JSONCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { - return json.Unmarshal(data, v) -} - -func (j *JSONCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { - return json.Unmarshal(data, x) -} diff --git a/pkg/serialization/codec/scale_codec.go b/pkg/serialization/codec/scale_codec.go index 193c914..e9e101a 100644 --- a/pkg/serialization/codec/scale_codec.go +++ b/pkg/serialization/codec/scale_codec.go @@ -3,28 +3,12 @@ package codec import "github.com/ChainSafe/gossamer/pkg/scale" // SCALECodec implements the Codec interface for SCALE encoding and decoding. -type SCALECodec[T Uint] struct{} +type SCALECodec struct{} -func (s *SCALECodec[T]) Marshal(v interface{}) ([]byte, error) { +func (s *SCALECodec) Marshal(v interface{}) ([]byte, error) { return scale.Marshal(v) } -func (s *SCALECodec[T]) MarshalGeneral(v uint64) ([]byte, error) { - return scale.Marshal(v) -} - -func (s *SCALECodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { - return scale.Marshal(x) -} - -func (s *SCALECodec[T]) Unmarshal(data []byte, v interface{}) error { +func (s *SCALECodec) Unmarshal(data []byte, v interface{}) error { return scale.Unmarshal(data, v) } - -func (s *SCALECodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { - return scale.Unmarshal(data, v) -} - -func (s *SCALECodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { - return scale.Unmarshal(data, x) -} diff --git a/pkg/serialization/serializer.go b/pkg/serialization/serializer.go index 9459ded..41f2982 100644 --- a/pkg/serialization/serializer.go +++ b/pkg/serialization/serializer.go @@ -3,41 +3,21 @@ package serialization import "github.com/eigerco/strawberry/pkg/serialization/codec" // Serializer provides methods to encode and decode using a specified codec. -type Serializer[T codec.Uint] struct { - codec codec.Codec[T] +type Serializer struct { + codec codec.Codec } // NewSerializer initializes a new Serializer with the given codec. -func NewSerializer[T codec.Uint](c codec.Codec[T]) *Serializer[T] { - return &Serializer[T]{codec: c} +func NewSerializer(c codec.Codec) *Serializer { + return &Serializer{codec: c} } // Encode serializes the given value using the codec. -func (s *Serializer[T]) Encode(v interface{}) ([]byte, error) { +func (s *Serializer) Encode(v interface{}) ([]byte, error) { return s.codec.Marshal(v) } -// EncodeGeneral is specific encoding for natural numbers up to 2^64 -func (s *Serializer[T]) EncodeGeneral(v uint64) ([]byte, error) { - return s.codec.MarshalGeneral(v) -} - -// EncodeTrivialUint is the trivial encoding for natural numbers -func (s *Serializer[T]) EncodeTrivialUint(x T, l uint8) ([]byte, error) { - return s.codec.MarshalTrivialUint(x, l) -} - // Decode deserializes the given data into the specified value using the codec. -func (s *Serializer[T]) Decode(data []byte, v interface{}) error { +func (s *Serializer) Decode(data []byte, v interface{}) error { return s.codec.Unmarshal(data, v) } - -// DecodeGeneral is specific decoding for natural numbers up to 2^64 -func (s *Serializer[T]) DecodeGeneral(data []byte, v *uint64) error { - return s.codec.UnmarshalGeneral(data, v) -} - -// DecodeTrivialUint is the trivial decoding for natural numbers -func (s *Serializer[T]) DecodeTrivialUint(data []byte, v *T) error { - return s.codec.UnmarshalTrivialUint(data, v) -} diff --git a/pkg/serialization/serializer_test.go b/pkg/serialization/serializer_test.go index 69bb9e6..49f076f 100644 --- a/pkg/serialization/serializer_test.go +++ b/pkg/serialization/serializer_test.go @@ -16,8 +16,8 @@ type PayloadExample struct { } func TestJSONSerializer(t *testing.T) { - jsonCodec := &codec.JSONCodec[uint16]{} - serializer := serialization.NewSerializer[uint16](jsonCodec) + jsonCodec := &codec.JSONCodec{} + serializer := serialization.NewSerializer(jsonCodec) example := PayloadExample{ID: 1, Data: []byte{1, 2, 3}} @@ -34,8 +34,8 @@ func TestJSONSerializer(t *testing.T) { } func TestSCALESerializer(t *testing.T) { - scaleCodec := &codec.SCALECodec[uint64]{} - serializer := serialization.NewSerializer[uint64](scaleCodec) + scaleCodec := &codec.SCALECodec{} + serializer := serialization.NewSerializer(scaleCodec) example := PayloadExample{ID: 2, Data: []byte{1, 2, 3}} @@ -51,54 +51,51 @@ func TestSCALESerializer(t *testing.T) { assert.Equal(t, example, decoded) } -func TestGeneralSerializer(t *testing.T) { - jamCodec := codec.NewJamCodec[uint64]() - serializer := serialization.NewSerializer[uint64](jamCodec) +func TestGeneralSerialization(t *testing.T) { + jamCodec := codec.NewJamCodec() + serializer := serialization.NewSerializer(jamCodec) // Test Encoding - v := uint64(127) - encoded, err := serializer.EncodeGeneral(v) + v := uint(127) + encoded, err := serializer.Encode(v) require.NoError(t, err) require.Equal(t, []byte{127}, encoded) // Test Decoding - var decoded uint64 - err = serializer.DecodeGeneral(encoded, &decoded) + var decoded uint + err = serializer.Decode(encoded, &decoded) require.NoError(t, err) assert.Equal(t, v, decoded) } -func TestTrivialSerializer(t *testing.T) { - jamCodec := codec.NewJamCodec[uint32]() - serializer := serialization.NewSerializer[uint32](jamCodec) +func TestTrivialSerialization(t *testing.T) { + jamCodec := codec.NewJamCodec() + serializer := serialization.NewSerializer(jamCodec) // Test Encoding v := 127 - encoded, err := serializer.EncodeTrivialUint(uint32(v), 3) + encoded, err := serializer.Encode(uint32(v)) require.NoError(t, err) - require.Equal(t, []byte{127, 0, 0}, encoded) + require.Equal(t, []byte{127, 0, 0, 0}, encoded) // Test Decoding var d64 uint64 - serializer64 := serialization.NewSerializer[uint64](codec.NewJamCodec[uint64]()) - err = serializer64.DecodeTrivialUint(encoded, &d64) + err = serializer.Decode(encoded, &d64) require.NoError(t, err) assert.Equal(t, uint64(v), d64) var d32 uint32 - err = serializer.DecodeTrivialUint(encoded, &d32) + err = serializer.Decode(encoded, &d32) require.NoError(t, err) assert.Equal(t, uint32(v), d32) var d16 uint16 - serializer16 := serialization.NewSerializer[uint16](codec.NewJamCodec[uint16]()) - err = serializer16.DecodeTrivialUint(encoded, &d16) + err = serializer.Decode(encoded, &d16) require.NoError(t, err) assert.Equal(t, uint16(v), d16) var d8 uint8 - serializer8 := serialization.NewSerializer[uint8](codec.NewJamCodec[uint8]()) - err = serializer8.DecodeTrivialUint(encoded, &d8) + err = serializer.Decode(encoded, &d8) require.NoError(t, err) assert.Equal(t, uint8(v), d8) }