-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
625 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package machineadvancer | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/cartesi/rollups-node/internal/node/model" | ||
"github.com/cartesi/rollups-node/internal/node/nodemachine" | ||
) | ||
|
||
type Machine interface { | ||
Advance(_ context.Context, input []byte) (*nodemachine.AdvanceResponse, error) | ||
} | ||
|
||
type MachineAdvancer struct { | ||
machines map[model.Address]Machine | ||
repository Repository | ||
ticker *time.Ticker | ||
} | ||
|
||
var ( | ||
ErrInvalidMachines = errors.New("must have at least one machine") | ||
ErrInvalidRepository = errors.New("repository must not be nil") | ||
ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") | ||
|
||
ErrInvalidAddress = errors.New("invalid address from repository") | ||
) | ||
|
||
// Duration must be greater than 0. | ||
func New( | ||
machines map[model.Address]Machine, | ||
repository Repository, | ||
pollingInterval time.Duration, | ||
) (*MachineAdvancer, error) { | ||
if len(machines) <= 0 { | ||
return nil, ErrInvalidMachines | ||
} | ||
if repository == nil { | ||
return nil, ErrInvalidRepository | ||
} | ||
if pollingInterval <= 0 { | ||
return nil, ErrInvalidPollingInterval | ||
} | ||
return &MachineAdvancer{ | ||
machines: machines, | ||
repository: repository, | ||
ticker: time.NewTicker(pollingInterval), | ||
}, nil | ||
} | ||
|
||
func (advancer *MachineAdvancer) Start() error { | ||
addresses := keysToSlice(advancer.machines) | ||
for { | ||
// Gets the unprocessed inputs (of all apps) from the repository. | ||
inputs, err := advancer.repository.GetInputs(addresses) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for appAddress, inputs := range inputs { | ||
machine, ok := advancer.machines[appAddress] | ||
if !ok { | ||
return fmt.Errorf("%w: %s", ErrInvalidAddress, appAddress.String()) | ||
} | ||
|
||
// Processes all inputs sequentially. | ||
for _, input := range inputs { | ||
res, err := machine.Advance(context.Background(), input.RawData) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
err = advancer.repository.Store(input, res) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
// Waits for the current polling interval to elapse. | ||
<-advancer.ticker.C | ||
} | ||
} | ||
|
||
// keysToSlice returns a slice with the keys of a map. | ||
func keysToSlice[T comparable, U any](m map[T]U) []T { | ||
keys := make([]T, len(m)) | ||
i := 0 | ||
for key := range m { | ||
keys[i] = key | ||
i++ | ||
} | ||
return keys | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package machineadvancer | ||
|
||
import ( | ||
"context" | ||
crand "crypto/rand" | ||
"errors" | ||
mrand "math/rand" | ||
"testing" | ||
"time" | ||
|
||
"github.com/cartesi/rollups-node/internal/node/model" | ||
"github.com/cartesi/rollups-node/internal/node/nodemachine" | ||
|
||
"github.com/stretchr/testify/suite" | ||
) | ||
|
||
func TestMachineAdvancer(t *testing.T) { | ||
suite.Run(t, new(MachineAdvancerSuite)) | ||
} | ||
|
||
type MachineAdvancerSuite struct{ suite.Suite } | ||
|
||
func (s *MachineAdvancerSuite) TestNew() { | ||
s.Run("Ok", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(machines, repository, time.Nanosecond) | ||
require.NotNil(machineAdvancer) | ||
require.Nil(err) | ||
}) | ||
|
||
s.Run("InvalidMachines", func() { | ||
require := s.Require() | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(nil, repository, time.Nanosecond) | ||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidMachines, err) | ||
}) | ||
|
||
s.Run("InvalidRepository", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
machineAdvancer, err := New(machines, nil, time.Nanosecond) | ||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidRepository, err) | ||
}) | ||
|
||
s.Run("InvalidPollingInterval", func() { | ||
require := s.Require() | ||
machines := map[model.Address]Machine{randomAddress(): newMockMachine()} | ||
repository := newMockRepository() | ||
machineAdvancer, err := New(machines, repository, time.Duration(0)) | ||
require.Nil(machineAdvancer) | ||
require.Equal(ErrInvalidPollingInterval, err) | ||
}) | ||
} | ||
|
||
func (s *MachineAdvancerSuite) TestStart() { | ||
suite.Run(s.T(), new(StartSuite)) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type StartSuite struct { | ||
suite.Suite | ||
machines map[model.Address]Machine | ||
repository *MockRepository | ||
} | ||
|
||
func (s *StartSuite) SetupTest() { | ||
s.machines = map[model.Address]Machine{} | ||
s.repository = newMockRepository() | ||
} | ||
|
||
// NOTE: This test is very basic! We need more tests! | ||
func (s *StartSuite) TestBasic() { | ||
require := s.Require() | ||
|
||
appAddress := randomAddress() | ||
|
||
machine := newMockMachine() | ||
advanceResponse := randomAdvanceResponse() | ||
machine.add(advanceResponse, nil) | ||
s.machines[appAddress] = machine | ||
|
||
s.repository.add(map[model.Address][]model.Input{appAddress: randomInputs(1)}, nil, nil) | ||
|
||
machineAdvancer, err := New(s.machines, s.repository, time.Nanosecond) | ||
require.NotNil(machineAdvancer) | ||
require.Nil(err) | ||
|
||
err = machineAdvancer.Start() | ||
require.Equal(testFinished, err) | ||
|
||
require.Len(s.repository.stored, 1) | ||
require.Equal(advanceResponse, s.repository.stored[0]) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockMachine struct { | ||
index uint8 | ||
results []*nodemachine.AdvanceResponse | ||
errors []error | ||
} | ||
|
||
func newMockMachine() *MockMachine { | ||
return &MockMachine{ | ||
index: 0, | ||
results: []*nodemachine.AdvanceResponse{}, | ||
errors: []error{}, | ||
} | ||
} | ||
|
||
func (m *MockMachine) add(result *nodemachine.AdvanceResponse, err error) { | ||
m.results = append(m.results, result) | ||
m.errors = append(m.errors, err) | ||
} | ||
|
||
func (m *MockMachine) Advance( | ||
_ context.Context, | ||
input []byte, | ||
) (*nodemachine.AdvanceResponse, error) { | ||
result, err := m.results[m.index], m.errors[m.index] | ||
m.index += 1 | ||
return result, err | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockRepository struct { | ||
getInputsIndex uint8 | ||
getInputsResults []map[model.Address][]model.Input | ||
getInputsErrors []error | ||
|
||
storeIndex uint8 | ||
storeErrors []error | ||
stored []*nodemachine.AdvanceResponse | ||
} | ||
|
||
func newMockRepository() *MockRepository { | ||
return &MockRepository{ | ||
getInputsIndex: 0, | ||
getInputsResults: []map[model.Address][]model.Input{}, | ||
getInputsErrors: []error{}, | ||
storeIndex: 0, | ||
storeErrors: []error{}, | ||
stored: []*nodemachine.AdvanceResponse{}, | ||
} | ||
} | ||
|
||
func (r *MockRepository) add( | ||
getInputsResult map[model.Address][]model.Input, | ||
getInputsError error, | ||
storeError error, | ||
) { | ||
r.getInputsResults = append(r.getInputsResults, getInputsResult) | ||
r.getInputsErrors = append(r.getInputsErrors, getInputsError) | ||
r.storeErrors = append(r.storeErrors, storeError) | ||
} | ||
|
||
var testFinished = errors.New("test finished") | ||
|
||
func (r *MockRepository) GetInputs( | ||
appAddresses []model.Address, | ||
) (map[model.Address][]model.Input, error) { | ||
if int(r.getInputsIndex) == len(r.getInputsResults) { | ||
return nil, testFinished | ||
} | ||
result, err := r.getInputsResults[r.getInputsIndex], r.getInputsErrors[r.getInputsIndex] | ||
r.getInputsIndex += 1 | ||
return result, err | ||
} | ||
|
||
func (r *MockRepository) Store(input model.Input, res *nodemachine.AdvanceResponse) error { | ||
err := r.storeErrors[r.storeIndex] | ||
r.storeIndex += 1 | ||
r.stored = append(r.stored, res) | ||
return err | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
func randomAddress() model.Address { | ||
address := make([]byte, 20) | ||
_, err := crand.Read(address) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return model.Address(address) | ||
} | ||
|
||
func randomHash() model.Hash { | ||
hash := make([]byte, 32) | ||
_, err := crand.Read(hash) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return model.Hash(hash) | ||
} | ||
|
||
func randomBytes() []byte { | ||
size := mrand.Intn(100) + 1 | ||
bytes := make([]byte, size) | ||
_, err := crand.Read(bytes) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return bytes | ||
} | ||
|
||
func randomSliceOfBytes() [][]byte { | ||
size := mrand.Intn(10) + 1 | ||
slice := make([][]byte, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = randomBytes() | ||
} | ||
return slice | ||
} | ||
|
||
func randomInputs(size int) []model.Input { | ||
slice := make([]model.Input, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = model.Input{Id: uint64(i), RawData: randomBytes()} | ||
} | ||
return slice | ||
|
||
} | ||
|
||
func randomAdvanceResponse() *nodemachine.AdvanceResponse { | ||
return &nodemachine.AdvanceResponse{ | ||
Status: model.InputStatusAccepted, | ||
Outputs: randomSliceOfBytes(), | ||
Reports: randomSliceOfBytes(), | ||
OutputsHash: randomHash(), | ||
MachineHash: randomHash(), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package machineadvancer | ||
|
||
import ( | ||
"github.com/cartesi/rollups-node/internal/node/model" | ||
"github.com/cartesi/rollups-node/internal/node/nodemachine" | ||
) | ||
|
||
type Repository interface { | ||
// Only needs Id and RawData fields from model.Input. | ||
GetInputs(appAddresses []model.Address) (map[model.Address][]model.Input, error) | ||
|
||
Store(model.Input, *nodemachine.AdvanceResponse) error | ||
} |
Oops, something went wrong.