diff --git a/internal/block/header_test.go b/internal/block/header_test.go index bb7bc67..b708d50 100644 --- a/internal/block/header_test.go +++ b/internal/block/header_test.go @@ -45,7 +45,7 @@ func Test_HeaderEncodeDecode(t *testing.T) { VRFSignature: randomSignature(t), BlockSealSignature: randomSignature(t), } - serializer := serialization.NewSerializer(&codec.SCALECodec{}) + serializer := serialization.NewSerializer[uint64](&codec.SCALECodec[uint64]{}) bb, err := serializer.Encode(h) if err != nil { t.Fatal(err) diff --git a/pkg/serialization/codec/codec.go b/pkg/serialization/codec/codec.go index 11c424a..abef721 100644 --- a/pkg/serialization/codec/codec.go +++ b/pkg/serialization/codec/codec.go @@ -1,8 +1,14 @@ package codec -type Codec interface { +type Uint interface { + uint8 | uint16 | uint32 | uint64 +} + +type Codec[T Uint] interface { Marshal(v interface{}) ([]byte, error) MarshalGeneral(x uint64) ([]byte, error) + MarshalTrivialUint(x T, l uint8) ([]byte, error) Unmarshal(data []byte, v interface{}) error UnmarshalGeneral(data []byte, v *uint64) error + UnmarshalTrivialUint(data []byte, v *T) error } diff --git a/pkg/serialization/codec/jam/general_natural.go b/pkg/serialization/codec/jam/general_natural.go index 8860601..587ac53 100644 --- a/pkg/serialization/codec/jam/general_natural.go +++ b/pkg/serialization/codec/jam/general_natural.go @@ -6,7 +6,7 @@ import ( "math/bits" ) -// GeneralNatural implements the formula (275: able to encode naturals of up to 2^64) +// GeneralNatural implements the formula (able to encode naturals of up to 2^64) type GeneralNatural struct{} func (j *GeneralNatural) SerializeUint64(x uint64) []byte { diff --git a/pkg/serialization/codec/jam/general_natural_test.go b/pkg/serialization/codec/jam/general_natural_test.go index 1ab6df9..b325049 100644 --- a/pkg/serialization/codec/jam/general_natural_test.go +++ b/pkg/serialization/codec/jam/general_natural_test.go @@ -53,19 +53,19 @@ func TestEncodeDecodeUint64(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) { - // Marshal the input value + // Marshal the x value serialized := gn.SerializeUint64(tc.input) // Check if the serialized output matches the expected output - assert.Equal(t, tc.expected, serialized, "serialized output mismatch for input %d", tc.input) + assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input) // Unmarshal the serialized data back into a uint64 var deserialized uint64 err := gn.DeserializeUint64(serialized, &deserialized) require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized) - // Check if the deserialized value matches the original input - assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for input %d", tc.input) + // Check if the deserialized value matches the original x + assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for x %d", tc.input) }) } } diff --git a/pkg/serialization/codec/jam/trivial_natural.go b/pkg/serialization/codec/jam/trivial_natural.go new file mode 100644 index 0000000..aa98250 --- /dev/null +++ b/pkg/serialization/codec/jam/trivial_natural.go @@ -0,0 +1,22 @@ +package jam + +import ( + "math" +) + +func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { + bytes := make([]byte, l) + for i := uint8(0); i < l; i++ { + bytes[i] = byte((x >> (8 * i)) & T(math.MaxUint8)) + } + return bytes +} + +func DeserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) { + *u = 0 + + // Iterate over each byte in the serialized array + for i := 0; i < len(serialized); i++ { + *u |= T(serialized[i]) << (8 * i) + } +} diff --git a/pkg/serialization/codec/jam/trivial_natural_test.go b/pkg/serialization/codec/jam/trivial_natural_test.go new file mode 100644 index 0000000..10a77e9 --- /dev/null +++ b/pkg/serialization/codec/jam/trivial_natural_test.go @@ -0,0 +1,83 @@ +package jam + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSerializationTrivialNatural(t *testing.T) { + testCases := []struct { + x any + l uint8 + expected []byte + }{ + {uint8(0), 0, []byte{}}, + {uint8(0), 1, []byte{0}}, + {uint8(1), 3, []byte{1, 0, 0}}, + {uint8(math.MaxInt8), 4, []byte{127, 0, 0, 0}}, + {uint8(math.MaxUint8), 2, []byte{255, 0}}, + {uint16(0), 0, []byte{}}, + {uint16(0), 1, []byte{0}}, + {uint16(math.MaxUint16), 2, []byte{255, 255}}, + {uint32(0), 0, []byte{}}, + {uint32(0), 1, []byte{0}}, + {uint32(1), 3, []byte{1, 0, 0}}, + {uint32(math.MaxInt8), 4, []byte{127, 0, 0, 0}}, + {uint32(128), 1, []byte{128}}, + {uint32(math.MaxUint8), 3, []byte{255, 0, 0}}, + {uint32(256), 2, []byte{0, 1}}, + {uint32(1023), 3, []byte{255, 3, 0}}, + {uint32(1024), 2, []byte{0, 4}}, + {uint32(16383), 4, []byte{255, 63, 0, 0}}, + {uint32(math.MaxUint16), 3, []byte{255, 255, 0}}, + {uint32(math.MaxUint32), 4, []byte{255, 255, 255, 255}}, + {uint64(0), 0, []byte{}}, + {uint64(0), 4, []byte{0, 0, 0, 0}}, + {uint64(1), 3, []byte{1, 0, 0}}, + {uint64(math.MaxUint16), 3, []byte{255, 255, 0}}, + {uint64(math.MaxUint32), 6, []byte{255, 255, 255, 255, 0, 0}}, + {uint64(math.MaxUint64), 8, []byte{255, 255, 255, 255, 255, 255, 255, 255}}, + } + + for _, tc := range testCases { + testName := fmt.Sprintf("%s_%v", reflect.TypeOf(tc.x).Name(), tc.x) + t.Run(testName, func(t *testing.T) { + var serialized []byte + switch v := tc.x.(type) { + case uint8: + serialized = SerializeTrivialNatural(v, tc.l) + case uint16: + serialized = SerializeTrivialNatural(v, tc.l) + case uint32: + serialized = SerializeTrivialNatural(v, tc.l) + case uint64: + serialized = SerializeTrivialNatural(v, tc.l) + } + + assert.Equal(t, tc.expected, serialized, "serialized output mismatch") + + switch v := tc.x.(type) { + case uint8: + var deserialized uint8 + DeserializeTrivialNatural(serialized, &deserialized) + assert.Equal(t, v, deserialized, "deserialized value mismatch") + case uint16: + var deserialized uint16 + DeserializeTrivialNatural(serialized, &deserialized) + assert.Equal(t, v, deserialized, "deserialized value mismatch") + case uint32: + var deserialized uint32 + DeserializeTrivialNatural(serialized, &deserialized) + assert.Equal(t, v, deserialized, "deserialized value mismatch") + case uint64: + var deserialized uint64 + DeserializeTrivialNatural(serialized, &deserialized) + assert.Equal(t, v, deserialized, "deserialized value mismatch") + } + }) + } +} diff --git a/pkg/serialization/codec/jam_codec.go b/pkg/serialization/codec/jam_codec.go index b44f017..28d1408 100644 --- a/pkg/serialization/codec/jam_codec.go +++ b/pkg/serialization/codec/jam_codec.go @@ -6,29 +6,40 @@ import ( ) // JAMCodec implements the Codec interface for JSON encoding and decoding. -type JAMCodec struct { +type JAMCodec[T Uint] struct { gn jam.GeneralNatural } // NewJamCodec initializes an instance of Jam codec -func NewJamCodec() *JAMCodec { - return &JAMCodec{gn: jam.GeneralNatural{}} +func NewJamCodec[T Uint]() *JAMCodec[T] { + return &JAMCodec[T]{ + gn: jam.GeneralNatural{}, + } } -func (j *JAMCodec) Marshal(v interface{}) ([]byte, error) { +func (j *JAMCodec[T]) Marshal(v interface{}) ([]byte, error) { // TODO return nil, errors.New("not implemented") } -func (j *JAMCodec) MarshalGeneral(v uint64) ([]byte, error) { +func (j *JAMCodec[T]) MarshalGeneral(v uint64) ([]byte, error) { return j.gn.SerializeUint64(v), nil } -func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error { +func (j *JAMCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { + return jam.SerializeTrivialNatural(x, l), nil +} + +func (j *JAMCodec[T]) Unmarshal(data []byte, v interface{}) error { // TODO return errors.New("not implemented") } -func (j *JAMCodec) UnmarshalGeneral(data []byte, v *uint64) error { +func (j *JAMCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { return j.gn.DeserializeUint64(data, v) } + +func (j *JAMCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { + jam.DeserializeTrivialNatural(data, x) + return nil +} diff --git a/pkg/serialization/codec/json_codec.go b/pkg/serialization/codec/json_codec.go index e2f9bdd..7445b1a 100644 --- a/pkg/serialization/codec/json_codec.go +++ b/pkg/serialization/codec/json_codec.go @@ -5,20 +5,28 @@ import ( ) // JSONCodec implements the Codec interface for JSON encoding and decoding. -type JSONCodec struct{} +type JSONCodec[T Uint] struct{} -func (j *JSONCodec) Marshal(v interface{}) ([]byte, error) { +func (j *JSONCodec[T]) Marshal(v interface{}) ([]byte, error) { return json.Marshal(v) } -func (j *JSONCodec) MarshalGeneral(v uint64) ([]byte, error) { +func (j *JSONCodec[T]) MarshalGeneral(v uint64) ([]byte, error) { return json.Marshal(v) } -func (j *JSONCodec) Unmarshal(data []byte, v interface{}) error { +func (j *JSONCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { + return json.Marshal(x) +} + +func (j *JSONCodec[T]) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) } -func (j *JSONCodec) UnmarshalGeneral(data []byte, v *uint64) error { +func (j *JSONCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { return json.Unmarshal(data, v) } + +func (j *JSONCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { + return json.Unmarshal(data, x) +} diff --git a/pkg/serialization/codec/scale_codec.go b/pkg/serialization/codec/scale_codec.go index 0a3b2a3..193c914 100644 --- a/pkg/serialization/codec/scale_codec.go +++ b/pkg/serialization/codec/scale_codec.go @@ -3,20 +3,28 @@ package codec import "github.com/ChainSafe/gossamer/pkg/scale" // SCALECodec implements the Codec interface for SCALE encoding and decoding. -type SCALECodec struct{} +type SCALECodec[T Uint] struct{} -func (s *SCALECodec) Marshal(v interface{}) ([]byte, error) { +func (s *SCALECodec[T]) Marshal(v interface{}) ([]byte, error) { return scale.Marshal(v) } -func (j *SCALECodec) MarshalGeneral(v uint64) ([]byte, error) { +func (s *SCALECodec[T]) MarshalGeneral(v uint64) ([]byte, error) { return scale.Marshal(v) } -func (s *SCALECodec) Unmarshal(data []byte, v interface{}) error { +func (s *SCALECodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) { + return scale.Marshal(x) +} + +func (s *SCALECodec[T]) Unmarshal(data []byte, v interface{}) error { return scale.Unmarshal(data, v) } -func (s *SCALECodec) UnmarshalGeneral(data []byte, v *uint64) error { +func (s *SCALECodec[T]) UnmarshalGeneral(data []byte, v *uint64) error { return scale.Unmarshal(data, v) } + +func (s *SCALECodec[T]) UnmarshalTrivialUint(data []byte, x *T) error { + return scale.Unmarshal(data, x) +} diff --git a/pkg/serialization/serializer.go b/pkg/serialization/serializer.go index e0aa7bf..9459ded 100644 --- a/pkg/serialization/serializer.go +++ b/pkg/serialization/serializer.go @@ -3,31 +3,41 @@ package serialization import "github.com/eigerco/strawberry/pkg/serialization/codec" // Serializer provides methods to encode and decode using a specified codec. -type Serializer struct { - codec codec.Codec +type Serializer[T codec.Uint] struct { + codec codec.Codec[T] } // NewSerializer initializes a new Serializer with the given codec. -func NewSerializer(c codec.Codec) *Serializer { - return &Serializer{codec: c} +func NewSerializer[T codec.Uint](c codec.Codec[T]) *Serializer[T] { + return &Serializer[T]{codec: c} } // Encode serializes the given value using the codec. -func (s *Serializer) Encode(v interface{}) ([]byte, error) { +func (s *Serializer[T]) Encode(v interface{}) ([]byte, error) { return s.codec.Marshal(v) } // EncodeGeneral is specific encoding for natural numbers up to 2^64 -func (s *Serializer) EncodeGeneral(v uint64) ([]byte, error) { +func (s *Serializer[T]) EncodeGeneral(v uint64) ([]byte, error) { return s.codec.MarshalGeneral(v) } +// EncodeTrivialUint is the trivial encoding for natural numbers +func (s *Serializer[T]) EncodeTrivialUint(x T, l uint8) ([]byte, error) { + return s.codec.MarshalTrivialUint(x, l) +} + // Decode deserializes the given data into the specified value using the codec. -func (s *Serializer) Decode(data []byte, v interface{}) error { +func (s *Serializer[T]) Decode(data []byte, v interface{}) error { return s.codec.Unmarshal(data, v) } // DecodeGeneral is specific decoding for natural numbers up to 2^64 -func (s *Serializer) DecodeGeneral(data []byte, v *uint64) error { +func (s *Serializer[T]) DecodeGeneral(data []byte, v *uint64) error { return s.codec.UnmarshalGeneral(data, v) } + +// DecodeTrivialUint is the trivial decoding for natural numbers +func (s *Serializer[T]) DecodeTrivialUint(data []byte, v *T) error { + return s.codec.UnmarshalTrivialUint(data, v) +} diff --git a/pkg/serialization/serializer_test.go b/pkg/serialization/serializer_test.go index b066682..69bb9e6 100644 --- a/pkg/serialization/serializer_test.go +++ b/pkg/serialization/serializer_test.go @@ -16,8 +16,8 @@ type PayloadExample struct { } func TestJSONSerializer(t *testing.T) { - jsonCodec := &codec.JSONCodec{} - serializer := serialization.NewSerializer(jsonCodec) + jsonCodec := &codec.JSONCodec[uint16]{} + serializer := serialization.NewSerializer[uint16](jsonCodec) example := PayloadExample{ID: 1, Data: []byte{1, 2, 3}} @@ -34,8 +34,8 @@ func TestJSONSerializer(t *testing.T) { } func TestSCALESerializer(t *testing.T) { - scaleCodec := &codec.SCALECodec{} - serializer := serialization.NewSerializer(scaleCodec) + scaleCodec := &codec.SCALECodec[uint64]{} + serializer := serialization.NewSerializer[uint64](scaleCodec) example := PayloadExample{ID: 2, Data: []byte{1, 2, 3}} @@ -52,8 +52,8 @@ func TestSCALESerializer(t *testing.T) { } func TestGeneralSerializer(t *testing.T) { - jamCodec := codec.NewJamCodec() - serializer := serialization.NewSerializer(jamCodec) + jamCodec := codec.NewJamCodec[uint64]() + serializer := serialization.NewSerializer[uint64](jamCodec) // Test Encoding v := uint64(127) @@ -67,3 +67,38 @@ func TestGeneralSerializer(t *testing.T) { require.NoError(t, err) assert.Equal(t, v, decoded) } + +func TestTrivialSerializer(t *testing.T) { + jamCodec := codec.NewJamCodec[uint32]() + serializer := serialization.NewSerializer[uint32](jamCodec) + + // Test Encoding + v := 127 + encoded, err := serializer.EncodeTrivialUint(uint32(v), 3) + require.NoError(t, err) + require.Equal(t, []byte{127, 0, 0}, encoded) + + // Test Decoding + var d64 uint64 + serializer64 := serialization.NewSerializer[uint64](codec.NewJamCodec[uint64]()) + err = serializer64.DecodeTrivialUint(encoded, &d64) + require.NoError(t, err) + assert.Equal(t, uint64(v), d64) + + var d32 uint32 + err = serializer.DecodeTrivialUint(encoded, &d32) + require.NoError(t, err) + assert.Equal(t, uint32(v), d32) + + var d16 uint16 + serializer16 := serialization.NewSerializer[uint16](codec.NewJamCodec[uint16]()) + err = serializer16.DecodeTrivialUint(encoded, &d16) + require.NoError(t, err) + assert.Equal(t, uint16(v), d16) + + var d8 uint8 + serializer8 := serialization.NewSerializer[uint8](codec.NewJamCodec[uint8]()) + err = serializer8.DecodeTrivialUint(encoded, &d8) + require.NoError(t, err) + assert.Equal(t, uint8(v), d8) +}