Skip to content

Commit

Permalink
Add support for trivial natural encoding (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
pantrif authored Aug 21, 2024
1 parent 9f97fd4 commit fd8314d
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 38 deletions.
2 changes: 1 addition & 1 deletion internal/block/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Test_HeaderEncodeDecode(t *testing.T) {
VRFSignature: randomSignature(t),
BlockSealSignature: randomSignature(t),
}
serializer := serialization.NewSerializer(&codec.SCALECodec{})
serializer := serialization.NewSerializer[uint64](&codec.SCALECodec[uint64]{})
bb, err := serializer.Encode(h)
if err != nil {
t.Fatal(err)
Expand Down
8 changes: 7 additions & 1 deletion pkg/serialization/codec/codec.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package codec

type Codec interface {
type Uint interface {
uint8 | uint16 | uint32 | uint64
}

type Codec[T Uint] interface {
Marshal(v interface{}) ([]byte, error)
MarshalGeneral(x uint64) ([]byte, error)
MarshalTrivialUint(x T, l uint8) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
UnmarshalGeneral(data []byte, v *uint64) error
UnmarshalTrivialUint(data []byte, v *T) error
}
2 changes: 1 addition & 1 deletion pkg/serialization/codec/jam/general_natural.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"math/bits"
)

// GeneralNatural implements the formula (275: able to encode naturals of up to 2^64)
// GeneralNatural implements the formula (able to encode naturals of up to 2^64)
type GeneralNatural struct{}

func (j *GeneralNatural) SerializeUint64(x uint64) []byte {
Expand Down
8 changes: 4 additions & 4 deletions pkg/serialization/codec/jam/general_natural_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ func TestEncodeDecodeUint64(t *testing.T) {

for _, tc := range testCases {
t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) {
// Marshal the input value
// Marshal the x value
serialized := gn.SerializeUint64(tc.input)

// Check if the serialized output matches the expected output
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for input %d", tc.input)
assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input)

// Unmarshal the serialized data back into a uint64
var deserialized uint64
err := gn.DeserializeUint64(serialized, &deserialized)
require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized)

// Check if the deserialized value matches the original input
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for input %d", tc.input)
// Check if the deserialized value matches the original x
assert.Equal(t, tc.input, deserialized, "deserialized value mismatch for x %d", tc.input)
})
}
}
22 changes: 22 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package jam

import (
"math"
)

func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte {
bytes := make([]byte, l)
for i := uint8(0); i < l; i++ {
bytes[i] = byte((x >> (8 * i)) & T(math.MaxUint8))
}
return bytes
}

func DeserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) {
*u = 0

// Iterate over each byte in the serialized array
for i := 0; i < len(serialized); i++ {
*u |= T(serialized[i]) << (8 * i)
}
}
83 changes: 83 additions & 0 deletions pkg/serialization/codec/jam/trivial_natural_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package jam

import (
"fmt"
"math"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSerializationTrivialNatural(t *testing.T) {
testCases := []struct {
x any
l uint8
expected []byte
}{
{uint8(0), 0, []byte{}},
{uint8(0), 1, []byte{0}},
{uint8(1), 3, []byte{1, 0, 0}},
{uint8(math.MaxInt8), 4, []byte{127, 0, 0, 0}},
{uint8(math.MaxUint8), 2, []byte{255, 0}},
{uint16(0), 0, []byte{}},
{uint16(0), 1, []byte{0}},
{uint16(math.MaxUint16), 2, []byte{255, 255}},
{uint32(0), 0, []byte{}},
{uint32(0), 1, []byte{0}},
{uint32(1), 3, []byte{1, 0, 0}},
{uint32(math.MaxInt8), 4, []byte{127, 0, 0, 0}},
{uint32(128), 1, []byte{128}},
{uint32(math.MaxUint8), 3, []byte{255, 0, 0}},
{uint32(256), 2, []byte{0, 1}},
{uint32(1023), 3, []byte{255, 3, 0}},
{uint32(1024), 2, []byte{0, 4}},
{uint32(16383), 4, []byte{255, 63, 0, 0}},
{uint32(math.MaxUint16), 3, []byte{255, 255, 0}},
{uint32(math.MaxUint32), 4, []byte{255, 255, 255, 255}},
{uint64(0), 0, []byte{}},
{uint64(0), 4, []byte{0, 0, 0, 0}},
{uint64(1), 3, []byte{1, 0, 0}},
{uint64(math.MaxUint16), 3, []byte{255, 255, 0}},
{uint64(math.MaxUint32), 6, []byte{255, 255, 255, 255, 0, 0}},
{uint64(math.MaxUint64), 8, []byte{255, 255, 255, 255, 255, 255, 255, 255}},
}

for _, tc := range testCases {
testName := fmt.Sprintf("%s_%v", reflect.TypeOf(tc.x).Name(), tc.x)
t.Run(testName, func(t *testing.T) {
var serialized []byte
switch v := tc.x.(type) {
case uint8:
serialized = SerializeTrivialNatural(v, tc.l)
case uint16:
serialized = SerializeTrivialNatural(v, tc.l)
case uint32:
serialized = SerializeTrivialNatural(v, tc.l)
case uint64:
serialized = SerializeTrivialNatural(v, tc.l)
}

assert.Equal(t, tc.expected, serialized, "serialized output mismatch")

switch v := tc.x.(type) {
case uint8:
var deserialized uint8
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint16:
var deserialized uint16
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint32:
var deserialized uint32
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
case uint64:
var deserialized uint64
DeserializeTrivialNatural(serialized, &deserialized)
assert.Equal(t, v, deserialized, "deserialized value mismatch")
}
})
}
}
25 changes: 18 additions & 7 deletions pkg/serialization/codec/jam_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,40 @@ import (
)

// JAMCodec implements the Codec interface for JSON encoding and decoding.
type JAMCodec struct {
type JAMCodec[T Uint] struct {
gn jam.GeneralNatural
}

// NewJamCodec initializes an instance of Jam codec
func NewJamCodec() *JAMCodec {
return &JAMCodec{gn: jam.GeneralNatural{}}
func NewJamCodec[T Uint]() *JAMCodec[T] {
return &JAMCodec[T]{
gn: jam.GeneralNatural{},
}
}

func (j *JAMCodec) Marshal(v interface{}) ([]byte, error) {
func (j *JAMCodec[T]) Marshal(v interface{}) ([]byte, error) {
// TODO
return nil, errors.New("not implemented")
}

func (j *JAMCodec) MarshalGeneral(v uint64) ([]byte, error) {
func (j *JAMCodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return j.gn.SerializeUint64(v), nil
}

func (j *JAMCodec) Unmarshal(data []byte, v interface{}) error {
func (j *JAMCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return jam.SerializeTrivialNatural(x, l), nil
}

func (j *JAMCodec[T]) Unmarshal(data []byte, v interface{}) error {
// TODO
return errors.New("not implemented")
}

func (j *JAMCodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (j *JAMCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return j.gn.DeserializeUint64(data, v)
}

func (j *JAMCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
jam.DeserializeTrivialNatural(data, x)
return nil
}
18 changes: 13 additions & 5 deletions pkg/serialization/codec/json_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,28 @@ import (
)

// JSONCodec implements the Codec interface for JSON encoding and decoding.
type JSONCodec struct{}
type JSONCodec[T Uint] struct{}

func (j *JSONCodec) Marshal(v interface{}) ([]byte, error) {
func (j *JSONCodec[T]) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

func (j *JSONCodec) MarshalGeneral(v uint64) ([]byte, error) {
func (j *JSONCodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return json.Marshal(v)
}

func (j *JSONCodec) Unmarshal(data []byte, v interface{}) error {
func (j *JSONCodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return json.Marshal(x)
}

func (j *JSONCodec[T]) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (j *JSONCodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return json.Unmarshal(data, v)
}

func (j *JSONCodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
return json.Unmarshal(data, x)
}
18 changes: 13 additions & 5 deletions pkg/serialization/codec/scale_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@ package codec
import "github.com/ChainSafe/gossamer/pkg/scale"

// SCALECodec implements the Codec interface for SCALE encoding and decoding.
type SCALECodec struct{}
type SCALECodec[T Uint] struct{}

func (s *SCALECodec) Marshal(v interface{}) ([]byte, error) {
func (s *SCALECodec[T]) Marshal(v interface{}) ([]byte, error) {
return scale.Marshal(v)
}

func (j *SCALECodec) MarshalGeneral(v uint64) ([]byte, error) {
func (s *SCALECodec[T]) MarshalGeneral(v uint64) ([]byte, error) {
return scale.Marshal(v)
}

func (s *SCALECodec) Unmarshal(data []byte, v interface{}) error {
func (s *SCALECodec[T]) MarshalTrivialUint(x T, l uint8) ([]byte, error) {
return scale.Marshal(x)
}

func (s *SCALECodec[T]) Unmarshal(data []byte, v interface{}) error {
return scale.Unmarshal(data, v)
}

func (s *SCALECodec) UnmarshalGeneral(data []byte, v *uint64) error {
func (s *SCALECodec[T]) UnmarshalGeneral(data []byte, v *uint64) error {
return scale.Unmarshal(data, v)
}

func (s *SCALECodec[T]) UnmarshalTrivialUint(data []byte, x *T) error {
return scale.Unmarshal(data, x)
}
26 changes: 18 additions & 8 deletions pkg/serialization/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,41 @@ package serialization
import "github.com/eigerco/strawberry/pkg/serialization/codec"

// Serializer provides methods to encode and decode using a specified codec.
type Serializer struct {
codec codec.Codec
type Serializer[T codec.Uint] struct {
codec codec.Codec[T]
}

// NewSerializer initializes a new Serializer with the given codec.
func NewSerializer(c codec.Codec) *Serializer {
return &Serializer{codec: c}
func NewSerializer[T codec.Uint](c codec.Codec[T]) *Serializer[T] {
return &Serializer[T]{codec: c}
}

// Encode serializes the given value using the codec.
func (s *Serializer) Encode(v interface{}) ([]byte, error) {
func (s *Serializer[T]) Encode(v interface{}) ([]byte, error) {
return s.codec.Marshal(v)
}

// EncodeGeneral is specific encoding for natural numbers up to 2^64
func (s *Serializer) EncodeGeneral(v uint64) ([]byte, error) {
func (s *Serializer[T]) EncodeGeneral(v uint64) ([]byte, error) {
return s.codec.MarshalGeneral(v)
}

// EncodeTrivialUint is the trivial encoding for natural numbers
func (s *Serializer[T]) EncodeTrivialUint(x T, l uint8) ([]byte, error) {
return s.codec.MarshalTrivialUint(x, l)
}

// Decode deserializes the given data into the specified value using the codec.
func (s *Serializer) Decode(data []byte, v interface{}) error {
func (s *Serializer[T]) Decode(data []byte, v interface{}) error {
return s.codec.Unmarshal(data, v)
}

// DecodeGeneral is specific decoding for natural numbers up to 2^64
func (s *Serializer) DecodeGeneral(data []byte, v *uint64) error {
func (s *Serializer[T]) DecodeGeneral(data []byte, v *uint64) error {
return s.codec.UnmarshalGeneral(data, v)
}

// DecodeTrivialUint is the trivial decoding for natural numbers
func (s *Serializer[T]) DecodeTrivialUint(data []byte, v *T) error {
return s.codec.UnmarshalTrivialUint(data, v)
}
Loading

0 comments on commit fd8314d

Please sign in to comment.