Skip to content

Commit

Permalink
Fix signal handling (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
roysc committed Oct 3, 2023
1 parent e1ab6a1 commit cdff077
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 27 deletions.
40 changes: 22 additions & 18 deletions cmd/validateTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,43 +70,47 @@ func validateTrie() {
stateRootStr := viper.GetString("validator.stateRoot")
storageRootStr := viper.GetString("validator.storageRoot")
contractAddrStr := viper.GetString("validator.address")

if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
stateRoot := common.HexToHash(stateRootStr)

traversal := strings.ToLower(viper.GetString("validator.type"))
switch traversal {
case "f", "full":
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for full state validation")
}
stateRoot := common.HexToHash(stateRootStr)
logWithCommand.
WithField("root", stateRoot).
Debug("Validating full state")
if err = v.ValidateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State for root %s is not complete\r\nerr: %v", stateRoot.String(), err)
logWithCommand.Fatalf("Validation failed: %v", err)
}
logWithCommand.Infof("State for root %s is complete", stateRoot.String())
logWithCommand.Infof("State for root %s is complete", stateRoot)
case "state":
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
stateRoot := common.HexToHash(stateRootStr)
logWithCommand.
WithField("root", stateRoot).
Debug("Validating state trie")
if err = v.ValidateStateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State trie for root %s is not complete\r\nerr: %v", stateRoot.String(), err)
logWithCommand.Fatalf("Validation failed: %s", err)
}
logWithCommand.Infof("State trie for root %s is complete", stateRoot.String())
logWithCommand.Infof("State trie for root %s is complete", stateRoot)
case "storage":
if storageRootStr == "" {
logWithCommand.Fatal("must provide a storage root for storage trie validation")
}
if contractAddrStr == "" {
logWithCommand.Fatal("must provide a contract address for storage trie validation")
}
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
storageRoot := common.HexToHash(storageRootStr)
addr := common.HexToAddress(contractAddrStr)
stateRoot := common.HexToHash(stateRootStr)
logWithCommand.
WithField("contract", addr).
WithField("storage root", storageRoot).
Debug("Validating storage trie")
if err = v.ValidateStorageTrie(stateRoot, addr, storageRoot); err != nil {
logWithCommand.Fatalf("Storage trie for contract %s and root %s not complete\r\nerr: %v", addr.String(), storageRoot.String(), err)
logWithCommand.Fatalf("Validation failed", err)
}
logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr.String(), storageRoot.String())
logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr, storageRoot)
default:
logWithCommand.Fatalf("Invalid traversal level: '%s'", traversal)
}
Expand Down
44 changes: 36 additions & 8 deletions pkg/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"bytes"
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"

"github.com/spf13/viper"
Expand Down Expand Up @@ -143,7 +146,7 @@ func (v *Validator) ValidateTrie(stateRoot common.Hash) error {
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, true) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, true) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, fullTraversal), v.params.Workers, iterate)
}

Expand All @@ -155,7 +158,7 @@ func (v *Validator) ValidateStateTrie(stateRoot common.Hash) error {
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, stateTraversal), v.params.Workers, iterate)
}

Expand All @@ -167,7 +170,7 @@ func (v *Validator) ValidateStorageTrie(stateRoot common.Hash, address common.Ad
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, storageTraversal), v.params.Workers, iterate)
}

Expand All @@ -181,12 +184,18 @@ func (v *Validator) Close() error {

// Traverses one iterator fully
// If storage = true, also traverse storage tries for each leaf.
func (v *Validator) iterate(it trie.NodeIterator, storage bool) error {
func (v *Validator) iterate(ctx context.Context, it trie.NodeIterator, storage bool) error {
// Iterate through entire state trie. it.Next() will return false when we have
// either completed iteration of the entire trie or run into an error (e.g. a
// missing node). If we are able to iterate through the entire trie without error
// then the trie is complete.
for it.Next(true) {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

// This block adapted from geth - core/state/iterator.go
// If storage is not requested, or the state trie node is an internal entry, skip
if !storage || !it.Leaf() {
Expand Down Expand Up @@ -219,10 +228,15 @@ func (v *Validator) iterate(it trie.NodeIterator, storage bool) error {

// Traverses each iterator in a separate goroutine.
// Dumps to a recovery file on failure or interrupt.
func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn func(trie.NodeIterator) error) error {
ctx, _ := context.WithCancel(context.Background())
func iterateTracked(
tree state.Trie,
recoveryFile string,
iterCount uint,
fn func(context.Context, trie.NodeIterator) error,
) error {
tracker := tracker.New(recoveryFile, iterCount)
halt := func() {
log.Errorf("writing recovery file: %s", recoveryFile)
if err := tracker.CloseAndSave(); err != nil {
log.Errorf("failed to write recovery file: %v", err)
}
Expand All @@ -242,14 +256,28 @@ func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn fun
for i, it := range iters {
iters[i] = tracker.Tracked(it)
}
} else {
log.Debugf("restored %d iterators from: %s", len(iters), recoveryFile)
}

ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
defer halt()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Errorf("Signal received (%v), stopping", sig)
cancel()
}()

defer halt()
for _, it := range iters {
func(it trie.NodeIterator) {
g.Go(func() error { return fn(it) })
g.Go(func() error {
return fn(ctx, it)
})

}(it)
}
return g.Wait()
Expand Down
2 changes: 1 addition & 1 deletion test/compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ services:
restart: on-failure
depends_on:
- ipld-eth-db
image: git.vdb.to/cerc-io/ipld-eth-db/ipld-eth-db:v5.0.2-alpha
image: git.vdb.to/cerc-io/ipld-eth-db/ipld-eth-db:v5.0.5-alpha
environment:
DATABASE_USER: "vdbm"
DATABASE_NAME: "cerc_testing"
Expand Down

0 comments on commit cdff077

Please sign in to comment.