-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package advancer | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"sync" | ||
"time" | ||
|
||
"github.com/cartesi/rollups-node/internal/node/machine/nodemachine" | ||
. "github.com/cartesi/rollups-node/internal/node/model" | ||
) | ||
|
||
type Repository interface { | ||
// Only needs Id and RawData fields from model.Input. | ||
GetInputs(_ context.Context, appAddresses []Address) (map[Address][]*Input, error) | ||
|
||
StoreResults(context.Context, *Input, *nodemachine.AdvanceResult) error | ||
} | ||
|
||
type Machine interface { | ||
Advance(_ context.Context, input []byte) (*nodemachine.AdvanceResult, error) | ||
} | ||
|
||
type MachineAdvancer struct { | ||
machines *sync.Map // map[Address]Machine | ||
repository Repository | ||
ticker *time.Ticker | ||
} | ||
|
||
var ( | ||
ErrInvalidMachines = errors.New("machines must not be nil") | ||
ErrInvalidRepository = errors.New("repository must not be nil") | ||
ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") | ||
|
||
ErrInvalidAddress = errors.New("no machine for address") | ||
) | ||
|
||
func New( | ||
machines *sync.Map, | ||
repository Repository, | ||
pollingInterval time.Duration, | ||
) (*MachineAdvancer, error) { | ||
if machines == nil { | ||
return nil, ErrInvalidMachines | ||
} | ||
if repository == nil { | ||
return nil, ErrInvalidRepository | ||
} | ||
if pollingInterval <= 0 { | ||
return nil, ErrInvalidPollingInterval | ||
} | ||
return &MachineAdvancer{ | ||
machines: machines, | ||
repository: repository, | ||
ticker: time.NewTicker(pollingInterval), | ||
}, nil | ||
} | ||
|
||
func (advancer *MachineAdvancer) Start(ctx context.Context) error { | ||
for { | ||
appAddresses := keysToSlice(advancer.machines) | ||
slog.Info("advancer: getting unprocessed inputs from the database.", | ||
"appAddresses", appAddresses) | ||
|
||
// Gets the unprocessed inputs (of all apps) from the repository. | ||
inputs, err := advancer.repository.GetInputs(ctx, appAddresses) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for appAddress, inputs := range inputs { | ||
slog.Info("advancer: processing inputs.", "appAddress", appAddress) | ||
value, ok := advancer.machines.Load(appAddress) | ||
if !ok { | ||
return fmt.Errorf("%w %s", ErrInvalidAddress, appAddress.String()) | ||
} | ||
machine := value.(Machine) | ||
|
||
// Processes all inputs sequentially. | ||
for _, input := range inputs { | ||
slog.Info("advancer: processing input", "id", input.Id) | ||
slog.Debug("--->", "input", input) | ||
res, err := machine.Advance(ctx, input.RawData) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
slog.Info("advancer: storing result") | ||
slog.Debug("--->", "result", res) | ||
err = advancer.repository.StoreResults(ctx, input, res) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
// Waits for the current polling interval to elapse. | ||
slog.Info("advancer: waiting.") | ||
<-advancer.ticker.C | ||
} | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
// keysToSlice returns a slice with the keys of a sync.Map. | ||
func keysToSlice(m *sync.Map) []Address { | ||
keys := []Address{} | ||
m.Range(func(key, _ any) bool { | ||
keys = append(keys, key.(Address)) | ||
return true | ||
}) | ||
return keys | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package advancer | ||
|
||
import ( | ||
"context" | ||
crand "crypto/rand" | ||
"errors" | ||
mrand "math/rand" | ||
"testing" | ||
"time" | ||
|
||
"github.com/cartesi/rollups-node/internal/node/machine/nodemachine" | ||
"github.com/cartesi/rollups-node/internal/node/model" | ||
|
||
"github.com/stretchr/testify/suite" | ||
) | ||
|
||
func TestMachineAdvancer(t *testing.T) { | ||
suite.Run(t, new(MachineAdvancerSuite)) | ||
} | ||
|
||
type MachineAdvancerSuite struct{ suite.Suite } | ||
|
||
func (s *MachineAdvancerSuite) TestNew() { | ||
s.Run("Ok", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(machines, repository, time.Nanosecond) | ||
Check failure on line 31 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
Check failure on line 31 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
|
||
require.NotNil(machineAdvancer) | ||
require.Nil(err) | ||
}) | ||
|
||
s.Run("InvalidMachines", func() { | ||
require := s.Require() | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(nil, repository, time.Nanosecond) | ||
Check failure on line 39 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
|
||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidMachines, err) | ||
}) | ||
|
||
s.Run("InvalidRepository", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
machineAdvancer, err := New(machines, nil, time.Nanosecond) | ||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidRepository, err) | ||
}) | ||
|
||
s.Run("InvalidPollingInterval", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(machines, repository, time.Duration(0)) | ||
Check failure on line 56 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
Check failure on line 56 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
|
||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidPollingInterval, err) | ||
}) | ||
} | ||
|
||
func (s *MachineAdvancerSuite) TestStart() { | ||
suite.Run(s.T(), new(StartSuite)) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type StartSuite struct { | ||
suite.Suite | ||
machines map[model.Address]Machine | ||
repository *MockRepository | ||
} | ||
|
||
func (s *StartSuite) SetupTest() { | ||
s.machines = map[model.Address]Machine{} | ||
s.repository = newMockRepository() | ||
} | ||
|
||
// NOTE: This test is very basic! We need more tests! | ||
func (s *StartSuite) TestBasic() { | ||
require := s.Require() | ||
|
||
appAddress := randomAddress() | ||
|
||
machine := newMockMachine() | ||
advanceResponse := randomAdvanceResponse() | ||
machine.add(advanceResponse, nil) | ||
s.machines[appAddress] = machine | ||
|
||
s.repository.add(map[model.Address][]model.Input{appAddress: randomInputs(1)}, nil, nil) | ||
|
||
machineAdvancer, err := New(s.machines, s.repository, time.Nanosecond) | ||
Check failure on line 92 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
Check failure on line 92 in internal/node/machine/advancer/advancer_test.go GitHub Actions / test-go
|
||
require.NotNil(machineAdvancer) | ||
require.Nil(err) | ||
|
||
err = machineAdvancer.Start() | ||
require.Equal(testFinished, err) | ||
|
||
require.Len(s.repository.stored, 1) | ||
require.Equal(advanceResponse, s.repository.stored[0]) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockMachine struct { | ||
index uint8 | ||
results []*nodemachine.AdvanceResult | ||
errors []error | ||
} | ||
|
||
func newMockMachine() *MockMachine { | ||
return &MockMachine{ | ||
index: 0, | ||
results: []*nodemachine.AdvanceResult{}, | ||
errors: []error{}, | ||
} | ||
} | ||
|
||
func (m *MockMachine) add(result *nodemachine.AdvanceResult, err error) { | ||
m.results = append(m.results, result) | ||
m.errors = append(m.errors, err) | ||
} | ||
|
||
func (m *MockMachine) Advance( | ||
_ context.Context, | ||
input []byte, | ||
) (*nodemachine.AdvanceResult, error) { | ||
result, err := m.results[m.index], m.errors[m.index] | ||
m.index += 1 | ||
return result, err | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockRepository struct { | ||
getInputsIndex uint8 | ||
getInputsResults []map[model.Address][]model.Input | ||
getInputsErrors []error | ||
|
||
storeIndex uint8 | ||
storeErrors []error | ||
stored []*nodemachine.AdvanceResult | ||
} | ||
|
||
func newMockRepository() *MockRepository { | ||
return &MockRepository{ | ||
getInputsIndex: 0, | ||
getInputsResults: []map[model.Address][]model.Input{}, | ||
getInputsErrors: []error{}, | ||
storeIndex: 0, | ||
storeErrors: []error{}, | ||
stored: []*nodemachine.AdvanceResult{}, | ||
} | ||
} | ||
|
||
func (r *MockRepository) add( | ||
getInputsResult map[model.Address][]model.Input, | ||
getInputsError error, | ||
storeError error, | ||
) { | ||
r.getInputsResults = append(r.getInputsResults, getInputsResult) | ||
r.getInputsErrors = append(r.getInputsErrors, getInputsError) | ||
r.storeErrors = append(r.storeErrors, storeError) | ||
} | ||
|
||
var testFinished = errors.New("test finished") | ||
|
||
func (r *MockRepository) GetUnprocessedInputs( | ||
appAddresses []model.Address, | ||
) (map[model.Address][]model.Input, error) { | ||
if int(r.getInputsIndex) == len(r.getInputsResults) { | ||
return nil, testFinished | ||
} | ||
result, err := r.getInputsResults[r.getInputsIndex], r.getInputsErrors[r.getInputsIndex] | ||
r.getInputsIndex += 1 | ||
return result, err | ||
} | ||
|
||
func (r *MockRepository) Store(input model.Input, res *nodemachine.AdvanceResult) error { | ||
err := r.storeErrors[r.storeIndex] | ||
r.storeIndex += 1 | ||
r.stored = append(r.stored, res) | ||
return err | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
func randomAddress() model.Address { | ||
address := make([]byte, 20) | ||
_, err := crand.Read(address) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return model.Address(address) | ||
} | ||
|
||
func randomHash() model.Hash { | ||
hash := make([]byte, 32) | ||
_, err := crand.Read(hash) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return model.Hash(hash) | ||
} | ||
|
||
func randomBytes() []byte { | ||
size := mrand.Intn(100) + 1 | ||
bytes := make([]byte, size) | ||
_, err := crand.Read(bytes) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return bytes | ||
} | ||
|
||
func randomSliceOfBytes() [][]byte { | ||
size := mrand.Intn(10) + 1 | ||
slice := make([][]byte, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = randomBytes() | ||
} | ||
return slice | ||
} | ||
|
||
func randomInputs(size int) []model.Input { | ||
slice := make([]model.Input, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = model.Input{Id: uint64(i), RawData: randomBytes()} | ||
} | ||
return slice | ||
|
||
} | ||
|
||
func randomAdvanceResponse() *nodemachine.AdvanceResult { | ||
return &nodemachine.AdvanceResult{ | ||
Status: model.InputStatusAccepted, | ||
Outputs: randomSliceOfBytes(), | ||
Reports: randomSliceOfBytes(), | ||
OutputsHash: randomHash(), | ||
MachineHash: randomHash(), | ||
} | ||
} |