Skip to content

Commit

Permalink
feat(advancer): add the advancer's repository
Browse files Browse the repository at this point in the history
  • Loading branch information
renan061 committed Aug 27, 2024
1 parent 82fc64b commit 0f1a63a
Show file tree
Hide file tree
Showing 8 changed files with 782 additions and 68 deletions.
16 changes: 10 additions & 6 deletions internal/node/advancer/advancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func (advancer *Advancer) Step(ctx context.Context) error {
}
}

// Updates the status of the epochs.
for _, app := range apps {
err := advancer.repository.UpdateEpochs(ctx, app)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -99,11 +107,7 @@ func (advancer *Advancer) process(ctx context.Context, app Address, inputs []*In
}
}

// Updates the status of the epochs based on the last processed input.
lastInput := inputs[len(inputs)-1]
err := advancer.repository.UpdateEpochs(ctx, app, lastInput)

return err
return nil
}

// ------------------------------------------------------------------------------------------------
Expand All @@ -114,7 +118,7 @@ type Repository interface {

StoreAdvanceResult(context.Context, *Input, *nodemachine.AdvanceResult) error

UpdateEpochs(_ context.Context, app Address, lastInput *Input) error
UpdateEpochs(_ context.Context, app Address) error
}

// A map of application addresses to machines.
Expand Down
43 changes: 10 additions & 33 deletions internal/node/advancer/advancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *AdvancerSuite) TestRun() {
res3 := randomAdvanceResult()

repository := &MockRepository{
GetInputsReturn: map[Address][]*Input{
GetUnprocessedInputsReturn: map[Address][]*Input{
app1: {
{Id: 1, RawData: marshal(res1)},
{Id: 2, RawData: marshal(res2)},
Expand All @@ -94,7 +94,9 @@ func (s *AdvancerSuite) TestRun() {
require.Len(repository.StoredResults, 3)
})

// NOTE: missing more test cases
s.Run("Error/UpdateEpochs", func() {
s.T().Skip("TODO")
})
}

func (s *AdvancerSuite) TestProcess() {
Expand Down Expand Up @@ -124,7 +126,6 @@ func (s *AdvancerSuite) TestProcess() {
err := advancer.process(context.Background(), app, inputs)
require.Nil(err)
require.Len(repository.StoredResults, 7)
require.Equal(*inputs[6], repository.LastInput)
})

s.Run("Panic", func() {
Expand Down Expand Up @@ -183,25 +184,7 @@ func (s *AdvancerSuite) TestProcess() {
require.Errorf(err, "store-advance error")
require.Len(repository.StoredResults, 1)
})

s.Run("UpdateEpochs", func() {
require := s.Require()

_, repository, advancer, app := setup()
inputs := []*Input{
{Id: 1, RawData: marshal(randomAdvanceResult())},
{Id: 2, RawData: marshal(randomAdvanceResult())},
{Id: 3, RawData: marshal(randomAdvanceResult())},
{Id: 4, RawData: marshal(randomAdvanceResult())},
}
repository.UpdateEpochsError = errors.New("update-epochs error")

err := advancer.process(context.Background(), app, inputs)
require.Errorf(err, "update-epochs error")
require.Len(repository.StoredResults, 4)
})
})

}

func (s *AdvancerSuite) TestKeysFrom() {
Expand All @@ -228,20 +211,19 @@ func (mock *MockMachine) Advance(
// ------------------------------------------------------------------------------------------------

type MockRepository struct {
GetInputsReturn map[Address][]*Input
GetInputsError error
StoreAdvanceError error
UpdateEpochsError error
GetUnprocessedInputsReturn map[Address][]*Input
GetUnprocessedInputsError error
StoreAdvanceError error
UpdateEpochsError error

StoredResults []*nodemachine.AdvanceResult
LastInput Input
}

func (mock *MockRepository) GetUnprocessedInputs(
_ context.Context,
appAddresses []Address,
) (map[Address][]*Input, error) {
return mock.GetInputsReturn, mock.GetInputsError
return mock.GetUnprocessedInputsReturn, mock.GetUnprocessedInputsError
}

func (mock *MockRepository) StoreAdvanceResult(
Expand All @@ -253,12 +235,7 @@ func (mock *MockRepository) StoreAdvanceResult(
return mock.StoreAdvanceError
}

func (mock *MockRepository) UpdateEpochs(
_ context.Context,
_ Address,
lastInput *Input,
) error {
mock.LastInput = *lastInput
func (mock *MockRepository) UpdateEpochs(_ context.Context, _ Address) error {
return mock.UpdateEpochsError
}

Expand Down
221 changes: 221 additions & 0 deletions internal/repository/advancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// (c) Cartesi and individual authors (see AUTHORS)
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)

package repository

import (
"context"
"errors"
"fmt"
"strings"

. "github.com/cartesi/rollups-node/internal/node/model"
"github.com/cartesi/rollups-node/internal/nodemachine"
"github.com/jackc/pgx/v5"
)

var ErrAdvancerRepository = errors.New("advancer repository error")

type AdvancerRepository struct{ *Database }

func (repo *AdvancerRepository) GetUnprocessedInputs(
ctx context.Context,
apps []Address,
) (map[Address][]*Input, error) {
result := map[Address][]*Input{}
if len(apps) == 0 {
return result, nil
}

query := fmt.Sprintf(`
SELECT id, application_address, raw_data
FROM input
WHERE status = 'NONE'
AND application_address IN %s
ORDER BY index ASC, application_address
`, toSqlIn(apps)) // NOTE: not sanitized
rows, err := repo.db.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("%w (failed querying inputs): %w", ErrAdvancerRepository, err)
}

var input Input
scans := []any{&input.Id, &input.AppAddress, &input.RawData}
_, err = pgx.ForEachRow(rows, scans, func() error {
input := input
if _, ok := result[input.AppAddress]; ok { //nolint:gosimple
result[input.AppAddress] = append(result[input.AppAddress], &input)
} else {
result[input.AppAddress] = []*Input{&input}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("%w (failed reading input rows): %w", ErrAdvancerRepository, err)
}

return result, nil
}

func (repo *AdvancerRepository) StoreAdvanceResult(
ctx context.Context,
input *Input,
res *nodemachine.AdvanceResult,
) error {
tx, err := repo.db.Begin(ctx)
if err != nil {
return errors.Join(ErrBeginTx, err)
}

// Inserts the outputs.
nextOutputIndex, err := repo.getNextIndex(ctx, tx, "output", input.AppAddress)
if err != nil {
return err
}
err = repo.insert(ctx, tx, "output", res.Outputs, input.Id, nextOutputIndex)
if err != nil {
return err
}

// Inserts the reports.
nextReportIndex, err := repo.getNextIndex(ctx, tx, "report", input.AppAddress)
if err != nil {
return err
}
err = repo.insert(ctx, tx, "report", res.Reports, input.Id, nextReportIndex)
if err != nil {
return err
}

// Updates the input's status.
err = repo.updateInput(ctx, tx, input.Id, res.Status, res.OutputsHash, res.MachineHash)
if err != nil {
return err
}

err = tx.Commit(ctx)
if err != nil {
return errors.Join(ErrCommitTx, err, tx.Rollback(ctx))
}

return nil
}

func (repo *AdvancerRepository) UpdateEpochs(ctx context.Context, app Address) error {
query := `
UPDATE epoch
SET status = 'PROCESSED_ALL_INPUTS'
WHERE id IN ((
SELECT DISTINCT epoch.id
FROM epoch INNER JOIN input ON (epoch.id = input.epoch_id)
WHERE epoch.application_address = @applicationAddress
AND epoch.status = 'CLOSED'
AND input.status != 'NONE'
) EXCEPT (
SELECT DISTINCT epoch.id
FROM epoch INNER JOIN input ON (epoch.id = input.epoch_id)
WHERE epoch.application_address = @applicationAddress
AND epoch.status = 'CLOSED'
AND input.status = 'NONE'))
`
args := pgx.NamedArgs{"applicationAddress": app}
_, err := repo.db.Exec(ctx, query, args)
if err != nil {
return errors.Join(ErrUpdateRow, err)
}
return nil
}

// ------------------------------------------------------------------------------------------------

func (_ *AdvancerRepository) getNextIndex(
ctx context.Context,
tx pgx.Tx,
tableName string,
appAddress Address,
) (uint64, error) {
var nextIndex uint64
query := fmt.Sprintf(`
SELECT COALESCE(MAX(%s.index) + 1, 0)
FROM input INNER JOIN %s ON input.id = %s.input_id
WHERE input.status = 'ACCEPTED'
AND input.application_address = $1
`, tableName, tableName, tableName)
err := tx.QueryRow(ctx, query, appAddress).Scan(&nextIndex)
if err != nil {
err = fmt.Errorf("failed to get the next %s index: %w", tableName, err)
return 0, errors.Join(err, tx.Rollback(ctx))
}
return nextIndex, nil
}

func (_ *AdvancerRepository) insert(
ctx context.Context,
tx pgx.Tx,
tableName string,
dataArray [][]byte,
inputId uint64,
nextIndex uint64,
) error {
lenOutputs := int64(len(dataArray))
if lenOutputs < 1 {
return nil
}

rows := [][]any{}
for i, data := range dataArray {
rows = append(rows, []any{inputId, nextIndex + uint64(i), data})
}

count, err := tx.CopyFrom(
ctx,
pgx.Identifier{tableName},
[]string{"input_id", "index", "raw_data"},
pgx.CopyFromRows(rows),
)
if err != nil {
return errors.Join(ErrCopyFrom, err, tx.Rollback(ctx))
}
if lenOutputs != count {
err := fmt.Errorf("not all %ss were inserted (%d != %d)", tableName, lenOutputs, count)
return errors.Join(err, tx.Rollback(ctx))
}

return nil
}

func (_ *AdvancerRepository) updateInput(
ctx context.Context,
tx pgx.Tx,
inputId uint64,
status InputCompletionStatus,
outputsHash Hash,
machineHash *Hash,
) error {
query := `
UPDATE input
SET (status, outputs_hash, machine_hash) = (@status, @outputsHash, @machineHash)
WHERE id = @id
`
args := pgx.NamedArgs{
"status": status,
"outputsHash": outputsHash,
"machineHash": machineHash,
"id": inputId,
}
_, err := tx.Exec(ctx, query, args)
if err != nil {
return errors.Join(ErrUpdateRow, err, tx.Rollback(ctx))
}
return nil
}

// ------------------------------------------------------------------------------------------------

func toSqlIn[T fmt.Stringer](a []T) string {
s := []string{}
for _, x := range a {
s = append(s, fmt.Sprintf("'\\x%s'", x.String()[2:]))
}
return fmt.Sprintf("(%s)", strings.Join(s, ", "))
}
Loading

0 comments on commit 0f1a63a

Please sign in to comment.