Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for trivial natural encoding #44

Merged
merged 9 commits into from
Aug 21, 2024
2 changes: 2 additions & 0 deletions pkg/serialization/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package codec
type Codec interface {
Marshal(v interface{}) ([]byte, error)
MarshalGeneral(x uint64) ([]byte, error)
MarshalTrivialUint(x interface{}, l uint8) ([]byte, error)
pantrif marked this conversation as resolved.
Show resolved Hide resolved
Unmarshal(data []byte, v interface{}) error
UnmarshalGeneral(data []byte, v *uint64) error
UnmarshalTrivialUint(data []byte, v interface{}) error
}
2 changes: 1 addition & 1 deletion pkg/serialization/codec/jam/general_natural.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"math/bits"
)

// GeneralNatural implements the formula (275: able to encode naturals of up to 2^64)
// GeneralNatural implements the formula (able to encode naturals of up to 2^64)
type GeneralNatural struct{}

func (j *GeneralNatural) SerializeUint64(x uint64) []byte {
Expand Down
8 changes: 4 additions & 4 deletions pkg/serialization/codec/jam/general_natural_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ func TestEncodeDecodeUint64(t *testing.T) {

for _, tc := range testCases {
t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) {
// Marshal the input value
// Marshal the x value
serialized := gn.SerializeUint64(tc.input)

// Check if the serialized output matches the expected output
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for input %d", tc.input)
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input)

// Unmarshal the serialized data back into a uint64
var deserialized uint64
err := gn.DeserializeUint64(serialized, &deserialized)
require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized)

// Check if the deserialized value matches the original input
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for input %d", tc.input)
// Check if the deserialized value matches the original x
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for x %d", tc.input)
})
}
}
29 changes: 29 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package jam

import (
"math"
)

// TrivialNatural implements the trivial integer formula
// (This is utilized for almost all integer encoding across the protocol)
type TrivialNatural[T uint8 | uint16 | uint32 | uint64] struct{}

// Serialize serializes any unsigned integer type into a byte slice.
func (j *TrivialNatural[T]) Serialize(x T, l uint8) []byte {
pantrif marked this conversation as resolved.
Show resolved Hide resolved
bytes := make([]byte, 0, l) // Preallocate with length `l`
pantrif marked this conversation as resolved.
Show resolved Hide resolved
for i := uint8(0); i < l; i++ {
byteVal := byte((x >> (8 * i)) & T(math.MaxUint8))
pantrif marked this conversation as resolved.
Show resolved Hide resolved
bytes = append(bytes, byteVal)
pantrif marked this conversation as resolved.
Show resolved Hide resolved
}
return bytes
}

// Deserialize deserializes a byte slice into the provided unsigned integer type.
func (j *TrivialNatural[T]) Deserialize(serialized []byte, u *T) {
*u = 0

// Iterate over each byte in the serialized array
for i := 0; i < len(serialized); i++ {
*u |= T(serialized[i]) << (8 * i)
}
}
105 changes: 105 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package jam

import (
"fmt"
"math"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEncodeDecodeTrivialUint8(t *testing.T) {
tn := TrivialNatural[uint8]{}
testCases := []struct {
x uint8
l uint8
expected []byte
}{
{0, 1, []byte{0}},
{1, 3, []byte{1, 0, 0}},
{math.MaxInt8, 4, []byte{127, 0, 0, 0}},
{math.MaxUint8, 2, []byte{255, 0}},
}

testEncodeDecodeTrivialUint(t, tn, testCases)
}

func TestEncodeDecodeTrivialUint16(t *testing.T) {
tn := TrivialNatural[uint16]{}
testCases := []struct {
x uint16
l uint8
expected []byte
}{
{0, 1, []byte{0}},
{math.MaxUint16, 2, []byte{255, 255}},
}

testEncodeDecodeTrivialUint(t, tn, testCases)
}

func TestEncodeDecodeTrivialUint32(t *testing.T) {
tn := TrivialNatural[uint32]{}
testCases := []struct {
x uint32
l uint8
expected []byte
}{
{0, 1, []byte{0}},
{1, 3, []byte{1, 0, 0}},
{math.MaxInt8, 4, []byte{127, 0, 0, 0}},
{128, 1, []byte{128}},
{math.MaxUint8, 3, []byte{255, 0, 0}},
{256, 2, []byte{0, 1}},
{1023, 3, []byte{255, 3, 0}},
{1024, 2, []byte{0, 4}},
{16383, 4, []byte{255, 63, 0, 0}},
{math.MaxUint16, 3, []byte{255, 255, 0}},
{math.MaxUint32, 4, []byte{255, 255, 255, 255}},
}

testEncodeDecodeTrivialUint(t, tn, testCases)
}

func TestEncodeDecodeTrivialUint64(t *testing.T) {
tn := TrivialNatural[uint64]{}
testCases := []struct {
x uint64
l uint8
expected []byte
}{
{0, 4, []byte{0, 0, 0, 0}},
{1, 3, []byte{1, 0, 0}},
{math.MaxUint16, 3, []byte{255, 255, 0}},
{math.MaxUint32, 6, []byte{255, 255, 255, 255, 0, 0}},
{math.MaxUint64, 8, []byte{255, 255, 255, 255, 255, 255, 255, 255}},
}

testEncodeDecodeTrivialUint(t, tn, testCases)
}

func testEncodeDecodeTrivialUint[T uint8 | uint16 | uint32 | uint64](t *testing.T, tn TrivialNatural[T], testCases []struct {
pantrif marked this conversation as resolved.
Show resolved Hide resolved
x T
l uint8
expected []byte
}) {
typeName := reflect.TypeOf(*new(T)).Name()

for _, tc := range testCases {
t.Run(fmt.Sprintf("%s(%v)", typeName, tc.x), func(t *testing.T) {
// Marshal the x value
serialized := tn.Serialize(tc.x, tc.l)

// Check if the serialized output matches the expected output
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %v", tc.x)

// Unmarshal the serialized data back into the type T
var deserialized T
tn.Deserialize(serialized, &deserialized)

// Check if the deserialized value matches the original x
assert.Equal(t, tc.x, deserialized, "deserialized value mismatch for x %v", tc.x)
})
}
}
51 changes: 49 additions & 2 deletions pkg/serialization/codec/jam_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,30 @@ package codec

import (
"errors"
"fmt"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
)

var unsupportedType = "unsupported type: %T"

// JAMCodec implements the Codec interface for JSON encoding and decoding.
type JAMCodec struct {
gn jam.GeneralNatural
gn jam.GeneralNatural
tn8 jam.TrivialNatural[uint8]
tn16 jam.TrivialNatural[uint16]
tn32 jam.TrivialNatural[uint32]
tn64 jam.TrivialNatural[uint64]
}

// NewJamCodec initializes an instance of Jam codec
func NewJamCodec() *JAMCodec {
return &JAMCodec{gn: jam.GeneralNatural{}}
return &JAMCodec{
gn: jam.GeneralNatural{},
tn8: jam.TrivialNatural[uint8]{},
tn16: jam.TrivialNatural[uint16]{},
tn32: jam.TrivialNatural[uint32]{},
tn64: jam.TrivialNatural[uint64]{},
}
}

func (j *JAMCodec) Marshal(v interface{}) ([]byte, error) {
Expand All @@ -24,6 +37,21 @@ func (j *JAMCodec) MarshalGeneral(v uint64) ([]byte, error) {
return j.gn.SerializeUint64(v), nil
}

func (j *JAMCodec) MarshalTrivialUint(x interface{}, l uint8) ([]byte, error) {
switch v := x.(type) {
case uint8:
pantrif marked this conversation as resolved.
Show resolved Hide resolved
return j.tn8.Serialize(v, l), nil
case uint16:
return j.tn16.Serialize(v, l), nil
case uint32:
return j.tn32.Serialize(v, l), nil
case uint64:
return j.tn64.Serialize(v, l), nil
default:
return nil, fmt.Errorf(unsupportedType, v)
}
}

func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error {
// TODO
return errors.New("not implemented")
Expand All @@ -32,3 +60,22 @@ func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error {
func (j *JAMCodec) UnmarshalGeneral(data []byte, v *uint64) error {
return j.gn.DeserializeUint64(data, v)
}

func (j *JAMCodec) UnmarshalTrivialUint(data []byte, x interface{}) error {
switch v := x.(type) {
case *uint8:
j.tn8.Deserialize(data, v)
return nil
case *uint16:
j.tn16.Deserialize(data, v)
return nil
case *uint32:
j.tn32.Deserialize(data, v)
return nil
case *uint64:
j.tn64.Deserialize(data, v)
return nil
default:
return fmt.Errorf(unsupportedType, v)
}
}
8 changes: 8 additions & 0 deletions pkg/serialization/codec/json_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@ func (j *JSONCodec) MarshalGeneral(v uint64) ([]byte, error) {
return json.Marshal(v)
}

func (j *JSONCodec) MarshalTrivialUint(x interface{}, l uint8) ([]byte, error) {
return json.Marshal(x)
}

func (j *JSONCodec) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec) UnmarshalGeneral(data []byte, v *uint64) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec) UnmarshalTrivialUint(data []byte, x interface{}) error {
return json.Unmarshal(data, x)
}
8 changes: 8 additions & 0 deletions pkg/serialization/codec/scale_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ func (j *SCALECodec) MarshalGeneral(v uint64) ([]byte, error) {
return scale.Marshal(v)
}

func (j *SCALECodec) MarshalTrivialUint(x interface{}, l uint8) ([]byte, error) {
return scale.Marshal(x)
}

func (s *SCALECodec) Unmarshal(data []byte, v interface{}) error {
return scale.Unmarshal(data, v)
}

func (s *SCALECodec) UnmarshalGeneral(data []byte, v *uint64) error {
return scale.Unmarshal(data, v)
}

func (j *SCALECodec) UnmarshalTrivialUint(data []byte, x interface{}) error {
return scale.Unmarshal(data, x)
}
10 changes: 10 additions & 0 deletions pkg/serialization/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ func (s *Serializer) EncodeGeneral(v uint64) ([]byte, error) {
return s.codec.MarshalGeneral(v)
}

// EncodeTrivialUint is the trivial encoding for natural numbers
func (s *Serializer) EncodeTrivialUint(x interface{}, l uint8) ([]byte, error) {
return s.codec.MarshalTrivialUint(x, l)
}

// Decode deserializes the given data into the specified value using the codec.
func (s *Serializer) Decode(data []byte, v interface{}) error {
return s.codec.Unmarshal(data, v)
Expand All @@ -31,3 +36,8 @@ func (s *Serializer) Decode(data []byte, v interface{}) error {
func (s *Serializer) DecodeGeneral(data []byte, v *uint64) error {
return s.codec.UnmarshalGeneral(data, v)
}

// DecodeTrivialUint is the trivial decoding for natural numbers
func (s *Serializer) DecodeTrivialUint(data []byte, v interface{}) error {
return s.codec.UnmarshalTrivialUint(data, v)
}
32 changes: 32 additions & 0 deletions pkg/serialization/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,35 @@ func TestGeneralSerializer(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, v, decoded)
}

func TestTrivialSerializer(t *testing.T) {
jamCodec := codec.NewJamCodec()
serializer := serialization.NewSerializer(jamCodec)

// Test Encoding
v := 127
encoded, err := serializer.EncodeTrivialUint(uint64(v), 3)
require.NoError(t, err)
require.Equal(t, []byte{127, 0, 0}, encoded)

// Test Decoding
var d64 uint64
err = serializer.DecodeTrivialUint(encoded, &d64)
require.NoError(t, err)
assert.Equal(t, uint64(v), d64)

var d32 uint32
err = serializer.DecodeTrivialUint(encoded, &d32)
require.NoError(t, err)
assert.Equal(t, uint32(v), d32)

var d16 uint16
err = serializer.DecodeTrivialUint(encoded, &d16)
require.NoError(t, err)
assert.Equal(t, uint16(v), d16)

var d8 uint8
err = serializer.DecodeTrivialUint(encoded, &d8)
require.NoError(t, err)
assert.Equal(t, uint8(v), d8)
}