From 38f95955b1ad1c5bd146202c7a9fff05d9a2467e Mon Sep 17 00:00:00 2001 From: Renan Santos Date: Wed, 21 Aug 2024 16:41:11 -0300 Subject: [PATCH] feat(advancer): add advancer service --- internal/node/advancer/advancer.go | 138 ++++++++++ internal/node/advancer/advancer_test.go | 331 ++++++++++++++++++++++++ internal/node/advancer/poller/poller.go | 63 +++++ 3 files changed, 532 insertions(+) create mode 100644 internal/node/advancer/advancer.go create mode 100644 internal/node/advancer/advancer_test.go create mode 100644 internal/node/advancer/poller/poller.go diff --git a/internal/node/advancer/advancer.go b/internal/node/advancer/advancer.go new file mode 100644 index 000000000..32ef675ab --- /dev/null +++ b/internal/node/advancer/advancer.go @@ -0,0 +1,138 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package advancer + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/cartesi/rollups-node/internal/node/advancer/poller" + . "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/nodemachine" +) + +var ( + ErrInvalidMachines = errors.New("machines must not be nil") + ErrInvalidRepository = errors.New("repository must not be nil") + + ErrNoApp = errors.New("no machine for application") + ErrNoInputs = errors.New("no inputs") +) + +type Advancer struct { + machines Machines + repository Repository +} + +// New instantiates a new Advancer. +func New(machines Machines, repository Repository) (*Advancer, error) { + if machines == nil { + return nil, ErrInvalidMachines + } + if repository == nil { + return nil, ErrInvalidRepository + } + return &Advancer{machines: machines, repository: repository}, nil +} + +// Poller instantiates a new poller.Poller using the Advancer. +func (advancer *Advancer) Poller(pollingInterval time.Duration) (*poller.Poller, error) { + return poller.New("advancer", advancer, pollingInterval) +} + +// Step steps the Advancer for one processing cycle. +// It gets unprocessed inputs from the repository, +// runs them through the cartesi machine, +// and updates the repository with the ouputs. +func (advancer *Advancer) Step(ctx context.Context) error { + apps := keysFrom(advancer.machines) + + // Gets the unprocessed inputs (of all apps) from the repository. + slog.Info("advancer: getting unprocessed inputs") + inputs, err := advancer.repository.GetUnprocessedInputs(ctx, apps) + if err != nil { + return err + } + + // Processes each set of inputs. + for app, inputs := range inputs { + slog.Info(fmt.Sprintf("advancer: processing %d input(s) from %v", len(inputs), app)) + err := advancer.process(ctx, app, inputs) + if err != nil { + return err + } + } + + return nil +} + +// process sequentially processes inputs from the the application. +func (advancer *Advancer) process(ctx context.Context, app Address, inputs []*Input) error { + // Asserts that the app has an associated machine. + machine, ok := advancer.machines[app] + if !ok { + panic(fmt.Errorf("%w %s", ErrNoApp, app.String())) + } + + // Asserts that there are inputs to process. + if len(inputs) <= 0 { + panic(ErrNoInputs) + } + + for _, input := range inputs { + slog.Info("advancer: processing input", "id", input.Id, "index", input.Index) + + // Sends the input to the cartesi machine. + res, err := machine.Advance(ctx, input.RawData, input.Index) + if err != nil { + return err + } + + // Stores the result in the database. + err = advancer.repository.StoreAdvanceResult(ctx, input, res) + if err != nil { + return err + } + } + + // Updates the status of the epochs based on the last processed input. + lastInput := inputs[len(inputs)-1] + err := advancer.repository.UpdateEpochs(ctx, app, lastInput) + + return err +} + +// ------------------------------------------------------------------------------------------------ + +type Repository interface { + // Only needs Id, Index, and RawData fields from the retrieved Inputs. + GetUnprocessedInputs(_ context.Context, apps []Address) (map[Address][]*Input, error) + + StoreAdvanceResult(context.Context, *Input, *nodemachine.AdvanceResult) error + + UpdateEpochs(_ context.Context, app Address, lastInput *Input) error +} + +// A map of application addresses to machines. +type Machines = map[Address]Machine + +type Machine interface { + Advance(_ context.Context, input []byte, index uint64) (*nodemachine.AdvanceResult, error) +} + +// ------------------------------------------------------------------------------------------------ + +// keysFrom returns a slice with the keysFrom of a map. +func keysFrom[K comparable, V any](m map[K]V) []K { + keys := make([]K, len(m)) + i := 0 + for k := range m { + keys[i] = k + i++ + } + return keys +} diff --git a/internal/node/advancer/advancer_test.go b/internal/node/advancer/advancer_test.go new file mode 100644 index 000000000..07b6bb6c0 --- /dev/null +++ b/internal/node/advancer/advancer_test.go @@ -0,0 +1,331 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package advancer + +import ( + "context" + crand "crypto/rand" + "encoding/json" + "errors" + "fmt" + mrand "math/rand" + "testing" + + . "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/nodemachine" + + "github.com/stretchr/testify/suite" +) + +func TestAdvancer(t *testing.T) { + suite.Run(t, new(AdvancerSuite)) +} + +type AdvancerSuite struct{ suite.Suite } + +func (s *AdvancerSuite) TestNew() { + s.Run("Ok", func() { + require := s.Require() + var machines map[Address]Machine = Machines{randomAddress(): &MockMachine{}} + var repository Repository = &MockRepository{} + advancer, err := New(machines, repository) + require.NotNil(advancer) + require.Nil(err) + }) + + s.Run("InvalidMachines", func() { + require := s.Require() + var machines map[Address]Machine = nil + var repository Repository = &MockRepository{} + advancer, err := New(machines, repository) + require.Nil(advancer) + require.Error(err) + require.Equal(ErrInvalidMachines, err) + }) + + s.Run("InvalidRepository", func() { + require := s.Require() + var machines map[Address]Machine = Machines{randomAddress(): &MockMachine{}} + var repository Repository = nil + advancer, err := New(machines, repository) + require.Nil(advancer) + require.Error(err) + require.Equal(ErrInvalidRepository, err) + }) +} + +func (s *AdvancerSuite) TestPoller() { + s.T().Skip("TODO") +} + +func (s *AdvancerSuite) TestRun() { + s.Run("Ok", func() { + require := s.Require() + + machines := Machines{} + app1 := randomAddress() + machines[app1] = &MockMachine{} + app2 := randomAddress() + machines[app2] = &MockMachine{} + res1 := randomAdvanceResult() + res2 := randomAdvanceResult() + res3 := randomAdvanceResult() + + repository := &MockRepository{ + GetInputsReturn: map[Address][]*Input{ + app1: { + {Id: 1, RawData: marshal(res1)}, + {Id: 2, RawData: marshal(res2)}, + }, + app2: { + {Id: 5, RawData: marshal(res3)}, + }, + }, + } + + advancer, err := New(machines, repository) + require.NotNil(advancer) + require.Nil(err) + + err = advancer.Step(context.Background()) + require.Nil(err) + + require.Len(repository.StoredResults, 3) + }) + + // NOTE: missing more test cases +} + +func (s *AdvancerSuite) TestProcess() { + setup := func() (Machines, *MockRepository, *Advancer, Address) { + app := randomAddress() + machines := Machines{} + machines[app] = &MockMachine{} + repository := &MockRepository{} + advancer := &Advancer{machines, repository} + return machines, repository, advancer, app + } + + s.Run("Ok", func() { + require := s.Require() + + _, repository, advancer, app := setup() + inputs := []*Input{ + {Id: 1, RawData: marshal(randomAdvanceResult())}, + {Id: 2, RawData: marshal(randomAdvanceResult())}, + {Id: 3, RawData: marshal(randomAdvanceResult())}, + {Id: 4, RawData: marshal(randomAdvanceResult())}, + {Id: 5, RawData: marshal(randomAdvanceResult())}, + {Id: 6, RawData: marshal(randomAdvanceResult())}, + {Id: 7, RawData: marshal(randomAdvanceResult())}, + } + + err := advancer.process(context.Background(), app, inputs) + require.Nil(err) + require.Len(repository.StoredResults, 7) + require.Equal(*inputs[6], repository.LastInput) + }) + + s.Run("Panic", func() { + s.Run("ErrApp", func() { + require := s.Require() + + invalidApp := randomAddress() + _, _, advancer, _ := setup() + inputs := randomInputs(3) + + expected := fmt.Sprintf("%v %v", ErrNoApp, invalidApp) + require.PanicsWithError(expected, func() { + _ = advancer.process(context.Background(), invalidApp, inputs) + }) + }) + + s.Run("ErrInputs", func() { + require := s.Require() + + _, _, advancer, app := setup() + inputs := []*Input{} + + require.PanicsWithValue(ErrNoInputs, func() { + _ = advancer.process(context.Background(), app, inputs) + }) + }) + }) + + s.Run("Error", func() { + s.Run("Advance", func() { + require := s.Require() + + _, repository, advancer, app := setup() + inputs := []*Input{ + {Id: 1, RawData: marshal(randomAdvanceResult())}, + {Id: 2, RawData: []byte("advance error")}, + {Id: 3, RawData: []byte("unreachable")}, + } + + err := advancer.process(context.Background(), app, inputs) + require.Errorf(err, "advance error") + require.Len(repository.StoredResults, 1) + }) + + s.Run("StoreAdvance", func() { + require := s.Require() + + _, repository, advancer, app := setup() + inputs := []*Input{ + {Id: 1, RawData: marshal(randomAdvanceResult())}, + {Id: 2, RawData: []byte("unreachable")}, + } + repository.StoreAdvanceError = errors.New("store-advance error") + + err := advancer.process(context.Background(), app, inputs) + require.Errorf(err, "store-advance error") + require.Len(repository.StoredResults, 1) + }) + + s.Run("UpdateEpochs", func() { + require := s.Require() + + _, repository, advancer, app := setup() + inputs := []*Input{ + {Id: 1, RawData: marshal(randomAdvanceResult())}, + {Id: 2, RawData: marshal(randomAdvanceResult())}, + {Id: 3, RawData: marshal(randomAdvanceResult())}, + {Id: 4, RawData: marshal(randomAdvanceResult())}, + } + repository.UpdateEpochsError = errors.New("update-epochs error") + + err := advancer.process(context.Background(), app, inputs) + require.Errorf(err, "update-epochs error") + require.Len(repository.StoredResults, 4) + }) + }) + +} + +func (s *AdvancerSuite) TestKeysFrom() { + s.T().Skip("TODO") +} + +// ------------------------------------------------------------------------------------------------ + +type MockMachine struct{} + +func (mock *MockMachine) Advance( + _ context.Context, + input []byte, + _ uint64, +) (*nodemachine.AdvanceResult, error) { + var res nodemachine.AdvanceResult + err := json.Unmarshal(input, &res) + if err != nil { + return nil, errors.New(string(input)) + } + return &res, nil +} + +// ------------------------------------------------------------------------------------------------ + +type MockRepository struct { + GetInputsReturn map[Address][]*Input + GetInputsError error + StoreAdvanceError error + UpdateEpochsError error + + StoredResults []*nodemachine.AdvanceResult + LastInput Input +} + +func (mock *MockRepository) GetUnprocessedInputs( + _ context.Context, + appAddresses []Address, +) (map[Address][]*Input, error) { + return mock.GetInputsReturn, mock.GetInputsError +} + +func (mock *MockRepository) StoreAdvanceResult( + _ context.Context, + input *Input, + res *nodemachine.AdvanceResult, +) error { + mock.StoredResults = append(mock.StoredResults, res) + return mock.StoreAdvanceError +} + +func (mock *MockRepository) UpdateEpochs( + _ context.Context, + _ Address, + lastInput *Input, +) error { + mock.LastInput = *lastInput + return mock.UpdateEpochsError +} + +// ------------------------------------------------------------------------------------------------ + +func randomAddress() Address { + address := make([]byte, 20) + _, err := crand.Read(address) + if err != nil { + panic(err) + } + return Address(address) +} + +func randomHash() Hash { + hash := make([]byte, 32) + _, err := crand.Read(hash) + if err != nil { + panic(err) + } + return Hash(hash) +} + +func randomBytes() []byte { + size := mrand.Intn(100) + 1 + bytes := make([]byte, size) + _, err := crand.Read(bytes) + if err != nil { + panic(err) + } + return bytes +} + +func randomSliceOfBytes() [][]byte { + size := mrand.Intn(10) + 1 + slice := make([][]byte, size) + for i := 0; i < size; i++ { + slice[i] = randomBytes() + } + return slice +} + +func randomInputs(size int) []*Input { + slice := make([]*Input, size) + for i := 0; i < size; i++ { + slice[i] = &Input{Id: uint64(i), RawData: randomBytes()} + } + return slice + +} + +func randomAdvanceResult() *nodemachine.AdvanceResult { + res := &nodemachine.AdvanceResult{ + Status: InputStatusAccepted, + Outputs: randomSliceOfBytes(), + Reports: randomSliceOfBytes(), + OutputsHash: randomHash(), + MachineHash: new(Hash), + } + *res.MachineHash = randomHash() + return res +} + +func marshal(res *nodemachine.AdvanceResult) []byte { + data, err := json.Marshal(*res) + if err != nil { + panic(err) + } + return data +} diff --git a/internal/node/advancer/poller/poller.go b/internal/node/advancer/poller/poller.go new file mode 100644 index 000000000..f8f469e98 --- /dev/null +++ b/internal/node/advancer/poller/poller.go @@ -0,0 +1,63 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package poller + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync/atomic" + "time" +) + +type Service interface { + Step(context.Context) error +} + +type Poller struct { + name string + service Service + shouldStop atomic.Bool + ticker *time.Ticker +} + +var ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") + +func New(name string, service Service, pollingInterval time.Duration) (*Poller, error) { + if pollingInterval <= 0 { + return nil, ErrInvalidPollingInterval + } + ticker := time.NewTicker(pollingInterval) + return &Poller{name: name, service: service, ticker: ticker}, nil +} + +func (poller *Poller) Start(ctx context.Context, ready chan<- struct{}) error { + ready <- struct{}{} + + slog.Debug(fmt.Sprintf("%s poller started", poller.name)) + + for { + // Runs the service's inner routine. + err := poller.service.Step(ctx) + if err != nil { + return err + } + + // Checks if the service was ordered to stop. + if poller.shouldStop.Load() { + poller.shouldStop.Store(false) + slog.Debug(fmt.Sprintf("%s poller stopped", poller.name)) + return nil + } + + // Waits for the polling interval to elapse. + <-poller.ticker.C + } +} + +// Stop orders the service to stop, which will happen before the next poll. +func (poller *Poller) Stop() { + poller.shouldStop.Store(true) +}