Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for trivial natural encoding #44

Merged
merged 9 commits into from
Aug 21, 2024
2 changes: 1 addition & 1 deletion internal/block/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Test_HeaderEncodeDecode(t *testing.T) {
VRFSignature: randomSignature(t),
BlockSealSignature: randomSignature(t),
}
serializer := serialization.NewSerializer(&codec.SCALECodec{})
serializer := serialization.NewSerializer[uint64](&codec.SCALECodec[uint64]{})
bb, err := serializer.Encode(h)
if err != nil {
t.Fatal(err)
Expand Down
8 changes: 7 additions & 1 deletion pkg/serialization/codec/codec.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package codec

type Codec interface {
type Uint interface {
uint8 | uint16 | uint32 | uint64
}

type Codec[T Uint] interface {
Marshal(v interface{}) ([]byte, error)
MarshalGeneral(x uint64) ([]byte, error)
MarshalTrivialUint(x T, l uint8) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
UnmarshalGeneral(data []byte, v *uint64) error
UnmarshalTrivialUint(data []byte, v *T) error
}
2 changes: 1 addition & 1 deletion pkg/serialization/codec/jam/general_natural.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"math/bits"
)

// GeneralNatural implements the formula (275: able to encode naturals of up to 2^64)
// GeneralNatural implements the formula (able to encode naturals of up to 2^64)
type GeneralNatural struct{}

func (j *GeneralNatural) SerializeUint64(x uint64) []byte {
Expand Down
8 changes: 4 additions & 4 deletions pkg/serialization/codec/jam/general_natural_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ func TestEncodeDecodeUint64(t *testing.T) {

for _, tc := range testCases {
t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) {
// Marshal the input value
// Marshal the x value
serialized := gn.SerializeUint64(tc.input)

// Check if the serialized output matches the expected output
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for input %d", tc.input)
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input)

// Unmarshal the serialized data back into a uint64
var deserialized uint64
err := gn.DeserializeUint64(serialized, &deserialized)
require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized)

// Check if the deserialized value matches the original input
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for input %d", tc.input)
// Check if the deserialized value matches the original x
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for x %d", tc.input)
})
}
}
23 changes: 23 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package jam

import (
"math"
)

func SerializeTrivialNatural[T uint8 | uint16 | uint32 | uint64](x T, l uint8) []byte {
pantrif marked this conversation as resolved.
Show resolved Hide resolved
bytes := make([]byte, 0, l) // Preallocate with length `l`
pantrif marked this conversation as resolved.
Show resolved Hide resolved
for i := uint8(0); i < l; i++ {
byteVal := byte((x >> (8 * i)) & T(math.MaxUint8))
pantrif marked this conversation as resolved.
Show resolved Hide resolved
bytes = append(bytes, byteVal)
pantrif marked this conversation as resolved.
Show resolved Hide resolved
}
return bytes
}

func DeserializeTrivialNatural[T uint8 | uint16 | uint32 | uint64](serialized []byte, u *T) {
pantrif marked this conversation as resolved.
Show resolved Hide resolved
*u = 0

// Iterate over each byte in the serialized array
for i := 0; i < len(serialized); i++ {
*u |= T(serialized[i]) << (8 * i)
}
}
79 changes: 79 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package jam

import (
"fmt"
"math"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSerializationTrivialNatural(t *testing.T) {
pantrif marked this conversation as resolved.
Show resolved Hide resolved
testCases := []struct {
x any
l uint8
expected []byte
}{
{uint8(0), 1, []byte{0}},
{uint8(1), 3, []byte{1, 0, 0}},
{uint8(math.MaxInt8), 4, []byte{127, 0, 0, 0}},
{uint8(math.MaxUint8), 2, []byte{255, 0}},
{uint16(0), 1, []byte{0}},
{uint16(math.MaxUint16), 2, []byte{255, 255}},
{uint32(0), 1, []byte{0}},
{uint32(1), 3, []byte{1, 0, 0}},
{uint32(math.MaxInt8), 4, []byte{127, 0, 0, 0}},
{uint32(128), 1, []byte{128}},
{uint32(math.MaxUint8), 3, []byte{255, 0, 0}},
{uint32(256), 2, []byte{0, 1}},
{uint32(1023), 3, []byte{255, 3, 0}},
{uint32(1024), 2, []byte{0, 4}},
{uint32(16383), 4, []byte{255, 63, 0, 0}},
{uint32(math.MaxUint16), 3, []byte{255, 255, 0}},
{uint32(math.MaxUint32), 4, []byte{255, 255, 255, 255}},
{uint64(0), 4, []byte{0, 0, 0, 0}},
{uint64(1), 3, []byte{1, 0, 0}},
{uint64(math.MaxUint16), 3, []byte{255, 255, 0}},
{uint64(math.MaxUint32), 6, []byte{255, 255, 255, 255, 0, 0}},
{uint64(math.MaxUint64), 8, []byte{255, 255, 255, 255, 255, 255, 255, 255}},
}

for _, tc := range testCases {
testName := fmt.Sprintf("%s_%v", reflect.TypeOf(tc.x).Name(), tc.x)
t.Run(testName, func(t *testing.T) {
var serialized []byte
switch v := tc.x.(type) {
case uint8:
serialized = SerializeTrivialNatural(v, tc.l)
case uint16:
serialized = SerializeTrivialNatural(v, tc.l)
case uint32:
serialized = SerializeTrivialNatural(v, tc.l)
case uint64:
serialized = SerializeTrivialNatural(v, tc.l)
}

assert.Equal(t, tc.expected, serialized, "serialized output mismatch")

switch v := tc.x.(type) {
case uint8:
var deserialized uint8
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint16:
var deserialized uint16
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint32:
var deserialized uint32
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint64:
var deserialized uint64
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
}
})
}
}
25 changes: 18 additions & 7 deletions pkg/serialization/codec/jam_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,40 @@ import (
)

// JAMCodec implements the Codec interface for JSON encoding and decoding.
type JAMCodec struct {
type JAMCodec[T Uint] struct {
gn jam.GeneralNatural
}

// NewJamCodec initializes an instance of Jam codec
func NewJamCodec() *JAMCodec {
return &JAMCodec{gn: jam.GeneralNatural{}}
func NewJamCodec[T Uint]() *JAMCodec[T] {
return &JAMCodec[T]{
gn: jam.GeneralNatural{},
}
}

func (j *JAMCodec) Marshal(v interface{}) ([]byte, error) {
func (j *JAMCodec[T]) Marshal(v interface{}) ([]byte, error) {
// TODO
return nil, errors.New("not implemented")
}

func (j *JAMCodec) MarshalGeneral(v uint64) ([]byte, error) {
func (j *JAMCodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return j.gn.SerializeUint64(v), nil
}

func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error {
func (j *JAMCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return jam.SerializeTrivialNatural(x, l), nil
}

func (j *JAMCodec[T]) Unmarshal(data []byte, v interface{}) error {
// TODO
return errors.New("not implemented")
}

func (j *JAMCodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (j *JAMCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return j.gn.DeserializeUint64(data, v)
}

func (j *JAMCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
jam.DeserializeTrivialNatural(data, x)
return nil
}
18 changes: 13 additions & 5 deletions pkg/serialization/codec/json_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,28 @@ import (
)

// JSONCodec implements the Codec interface for JSON encoding and decoding.
type JSONCodec struct{}
type JSONCodec[T Uint] struct{}

func (j *JSONCodec) Marshal(v interface{}) ([]byte, error) {
func (j *JSONCodec[T]) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

func (j *JSONCodec) MarshalGeneral(v uint64) ([]byte, error) {
func (j *JSONCodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return json.Marshal(v)
}

func (j *JSONCodec) Unmarshal(data []byte, v interface{}) error {
func (j *JSONCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return json.Marshal(x)
}

func (j *JSONCodec[T]) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (j *JSONCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
return json.Unmarshal(data, x)
}
18 changes: 13 additions & 5 deletions pkg/serialization/codec/scale_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@ package codec
import "github.com/ChainSafe/gossamer/pkg/scale"

// SCALECodec implements the Codec interface for SCALE encoding and decoding.
type SCALECodec struct{}
type SCALECodec[T Uint] struct{}

func (s *SCALECodec) Marshal(v interface{}) ([]byte, error) {
func (s *SCALECodec[T]) Marshal(v interface{}) ([]byte, error) {
return scale.Marshal(v)
}

func (j *SCALECodec) MarshalGeneral(v uint64) ([]byte, error) {
func (s *SCALECodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return scale.Marshal(v)
}

func (s *SCALECodec) Unmarshal(data []byte, v interface{}) error {
func (s *SCALECodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return scale.Marshal(x)
}

func (s *SCALECodec[T]) Unmarshal(data []byte, v interface{}) error {
return scale.Unmarshal(data, v)
}

func (s *SCALECodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (s *SCALECodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return scale.Unmarshal(data, v)
}

func (s *SCALECodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
return scale.Unmarshal(data, x)
}
26 changes: 18 additions & 8 deletions pkg/serialization/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,41 @@ package serialization
import "github.com/eigerco/strawberry/pkg/serialization/codec"

// Serializer provides methods to encode and decode using a specified codec.
type Serializer struct {
codec codec.Codec
type Serializer[T codec.Uint] struct {
codec codec.Codec[T]
}

// NewSerializer initializes a new Serializer with the given codec.
func NewSerializer(c codec.Codec) *Serializer {
return &Serializer{codec: c}
func NewSerializer[T codec.Uint](c codec.Codec[T]) *Serializer[T] {
return &Serializer[T]{codec: c}
}

// Encode serializes the given value using the codec.
func (s *Serializer) Encode(v interface{}) ([]byte, error) {
func (s *Serializer[T]) Encode(v interface{}) ([]byte, error) {
return s.codec.Marshal(v)
}

// EncodeGeneral is specific encoding for natural numbers up to 2^64
func (s *Serializer) EncodeGeneral(v uint64) ([]byte, error) {
func (s *Serializer[T]) EncodeGeneral(v uint64) ([]byte, error) {
return s.codec.MarshalGeneral(v)
}

// EncodeTrivialUint is the trivial encoding for natural numbers
func (s *Serializer[T]) EncodeTrivialUint(x T, l uint8) ([]byte, error) {
return s.codec.MarshalTrivialUint(x, l)
}

// Decode deserializes the given data into the specified value using the codec.
func (s *Serializer) Decode(data []byte, v interface{}) error {
func (s *Serializer[T]) Decode(data []byte, v interface{}) error {
return s.codec.Unmarshal(data, v)
}

// DecodeGeneral is specific decoding for natural numbers up to 2^64
func (s *Serializer) DecodeGeneral(data []byte, v *uint64) error {
func (s *Serializer[T]) DecodeGeneral(data []byte, v *uint64) error {
return s.codec.UnmarshalGeneral(data, v)
}

// DecodeTrivialUint is the trivial decoding for natural numbers
func (s *Serializer[T]) DecodeTrivialUint(data []byte, v *T) error {
return s.codec.UnmarshalTrivialUint(data, v)
}
47 changes: 41 additions & 6 deletions pkg/serialization/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ type PayloadExample struct {
}

func TestJSONSerializer(t *testing.T) {
jsonCodec := &codec.JSONCodec{}
serializer := serialization.NewSerializer(jsonCodec)
jsonCodec := &codec.JSONCodec[uint16]{}
serializer := serialization.NewSerializer[uint16](jsonCodec)

example := PayloadExample{ID: 1, Data: []byte{1, 2, 3}}

Expand All @@ -34,8 +34,8 @@ func TestJSONSerializer(t *testing.T) {
}

func TestSCALESerializer(t *testing.T) {
scaleCodec := &codec.SCALECodec{}
serializer := serialization.NewSerializer(scaleCodec)
scaleCodec := &codec.SCALECodec[uint64]{}
serializer := serialization.NewSerializer[uint64](scaleCodec)

example := PayloadExample{ID: 2, Data: []byte{1, 2, 3}}

Expand All @@ -52,8 +52,8 @@ func TestSCALESerializer(t *testing.T) {
}

func TestGeneralSerializer(t *testing.T) {
jamCodec := codec.NewJamCodec()
serializer := serialization.NewSerializer(jamCodec)
jamCodec := codec.NewJamCodec[uint64]()
serializer := serialization.NewSerializer[uint64](jamCodec)

// Test Encoding
v := uint64(127)
Expand All @@ -67,3 +67,38 @@ func TestGeneralSerializer(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, v, decoded)
}

func TestTrivialSerializer(t *testing.T) {
jamCodec := codec.NewJamCodec[uint32]()
serializer := serialization.NewSerializer[uint32](jamCodec)

// Test Encoding
v := 127
encoded, err := serializer.EncodeTrivialUint(uint32(v), 3)
require.NoError(t, err)
require.Equal(t, []byte{127, 0, 0}, encoded)

// Test Decoding
var d64 uint64
serializer64 := serialization.NewSerializer[uint64](codec.NewJamCodec[uint64]())
err = serializer64.DecodeTrivialUint(encoded, &d64)
require.NoError(t, err)
assert.Equal(t, uint64(v), d64)

var d32 uint32
err = serializer.DecodeTrivialUint(encoded, &d32)
require.NoError(t, err)
assert.Equal(t, uint32(v), d32)

var d16 uint16
serializer16 := serialization.NewSerializer[uint16](codec.NewJamCodec[uint16]())
err = serializer16.DecodeTrivialUint(encoded, &d16)
require.NoError(t, err)
assert.Equal(t, uint16(v), d16)

var d8 uint8
serializer8 := serialization.NewSerializer[uint8](codec.NewJamCodec[uint8]())
err = serializer8.DecodeTrivialUint(encoded, &d8)
require.NoError(t, err)
assert.Equal(t, uint8(v), d8)
}