From 5d7f68ca78c2b50eb9c0e9d73e89e263621156c2 Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Wed, 16 Aug 2023 16:55:04 +0100 Subject: [PATCH] Add include state machine (#95) Based on sm-bootstrap branch (#90) This adds a state machine for running the include process described in https://github.com/plprobelab/go-kademlia/issues/45. The state machine manages a queue of candidate nodes and processes them by checking whether they respond to a find node request. Candidates that respond with one or more closer nodes are considered live and added to the routing table. Nodes that do not respond or do not provide any suggested closer nodes are dropped from the queue. The number of concurrent checks in flight is configurable. Not done yet: - [ ] check timeouts - [ ] removing nodes failing checks from routing table - [ ] notifying of unroutable nodes --- coord/coordinator.go | 200 ++++++++++++++++--- coord/coordinator_test.go | 79 ++++++-- examples/statemachine/main.go | 7 +- routing/include.go | 281 +++++++++++++++++++++++++++ routing/include_test.go | 350 ++++++++++++++++++++++++++++++++++ 5 files changed, 872 insertions(+), 45 deletions(-) create mode 100644 routing/include.go create mode 100644 routing/include_test.go diff --git a/coord/coordinator.go b/coord/coordinator.go index 3eef338..a68a2da 100644 --- a/coord/coordinator.go +++ b/coord/coordinator.go @@ -43,7 +43,7 @@ func (q *eventQueue[E]) Enqueue(ctx context.Context, e E) { } // Dequeue reads an event from the queue. It returns the event and a true value -// if an event was read or the zero value if the event type and false if no event +// if an event was read or the zero value of the event type and false if no event // was read. This method is non-blocking. func (q *eventQueue[E]) Dequeue(ctx context.Context) (E, bool) { select { @@ -55,6 +55,10 @@ func (q *eventQueue[E]) Dequeue(ctx context.Context) (E, bool) { } } +// FindNodeRequestFunc is a function that creates a request to find the supplied node id +// TODO: consider this being a first class method of the Endpoint +type FindNodeRequestFunc[K kad.Key[K], A kad.Address[A]] func(kad.NodeID[K]) (address.ProtocolID, kad.Request[K, A]) + // A Coordinator coordinates the state machines that comprise a Kademlia DHT // Currently this is only queries and bootstrapping but will expand to include other state machines such as // routing table refresh, and reproviding. 
@@ -77,12 +81,22 @@ type Coordinator[K kad.Key[K], A kad.Address[A]] struct { // bootstrapEvents is a fifo queue of events that are to be processed by the bootstrap state machine bootstrapEvents *eventQueue[routing.BootstrapEvent] + // include is the include state machine, responsible for including candidate nodes into the routing table + include StateMachine[routing.IncludeState, routing.IncludeEvent] + + // includeEvents is a fifo queue of events that are to be processed by the include state machine + includeEvents *eventQueue[routing.IncludeEvent] + // rt is the routing table used to look up nodes by distance rt kad.RoutingTable[K, kad.NodeID[K]] // ep is the message endpoint used to send requests ep endpoint.Endpoint[K, A] + // findNodeFn is a function that creates a find node request that may be understood by the endpoint + // TODO: this should be a function of the endpoint + findNodeFn FindNodeRequestFunc[K, A] + // queue not used queue event.EventQueue @@ -155,7 +169,7 @@ func DefaultConfig() *Config { } } -func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], cfg *Config) (*Coordinator[K, A], error) { +func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], fn FindNodeRequestFunc[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], cfg *Config) (*Coordinator[K, A], error) { if cfg == nil { cfg = DefaultConfig() } else if err := cfg.Validate(); err != nil { @@ -182,17 +196,34 @@ func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpo bootstrap, err := routing.NewBootstrap(self, bootstrapCfg) if err != nil { - return nil, fmt.Errorf("query pool: %w", err) + return nil, fmt.Errorf("bootstrap: %w", err) + } + + includeCfg := routing.DefaultIncludeConfig() + includeCfg.Clock = cfg.Clock + includeCfg.Timeout = cfg.QueryTimeout + + // TODO: expose config + // includeCfg.QueueCapacity = cfg.IncludeQueueCapacity + // includeCfg.Concurrency = cfg.IncludeConcurrency + // includeCfg.Timeout = cfg.IncludeTimeout + + include, err := routing.NewInclude[K, A](rt, includeCfg) + if err != nil { + return nil, fmt.Errorf("include: %w", err) } return &Coordinator[K, A]{ self: self, cfg: *cfg, ep: ep, + findNodeFn: fn, rt: rt, pool: qp, poolEvents: newEventQueue[query.PoolEvent](20), // 20 is abitrary, move to config bootstrap: bootstrap, bootstrapEvents: newEventQueue[routing.BootstrapEvent](20), // 20 is abitrary, move to config + include: include, + includeEvents: newEventQueue[routing.IncludeEvent](20), // 20 is abitrary, move to config outboundEvents: make(chan KademliaEvent, 20), queue: event.NewChanQueue(DefaultChanqueueCapacity), planner: event.NewSimplePlanner(cfg.Clock), @@ -207,8 +238,28 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { ctx, span := util.StartSpan(ctx, "Coordinator.RunOne") defer span.End() + // Process state machines in priority order + // Give the bootstrap state machine priority - // No queries can be run while a bootstrap is in progress + if c.advanceBootstrap(ctx) { + return true + } + + // Attempt to advance the include state machine so candidate nodes + // are added to the routing table + if c.advanceInclude(ctx) { + return true + } + + // Attempt to advance an outbound query + if c.advancePool(ctx) { + return true + } + + return false +} + +func (c *Coordinator[K, A]) advanceBootstrap(ctx context.Context) bool { bev, ok := c.bootstrapEvents.Dequeue(ctx) if !ok { bev = &routing.EventBootstrapPoll{} @@ -217,11 
+268,11 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { bstate := c.bootstrap.Advance(ctx, bev) switch st := bstate.(type) { case *routing.StateBootstrapMessage[K, A]: - c.sendBootstrapMessage(ctx, st.ProtocolID, st.NodeID, st.Message, st.QueryID, st.Stats) + c.sendBootstrapFindNode(ctx, st.NodeID, st.QueryID, st.Stats) return true case *routing.StateBootstrapWaiting: - // bootstrap waiting for a message response, don't proceed with other state machines + // bootstrap waiting for a message response, proceed with other state machines return false case *routing.StateBootstrapFinished: @@ -232,12 +283,49 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { case *routing.StateBootstrapIdle: // bootstrap not running, can proceed to other state machines - break + return false default: panic(fmt.Sprintf("unexpected bootstrap state: %T", st)) } +} - // Attempt to advance an outbound query +func (c *Coordinator[K, A]) advanceInclude(ctx context.Context) bool { + // Attempt to advance the include state machine so candidate nodes + // are added to the routing table + iev, ok := c.includeEvents.Dequeue(ctx) + if !ok { + iev = &routing.EventIncludePoll{} + } + istate := c.include.Advance(ctx, iev) + switch st := istate.(type) { + case *routing.StateIncludeFindNodeMessage[K, A]: + // include wants to send a find node message to a node + c.sendIncludeFindNode(ctx, st.NodeInfo) + return true + case *routing.StateIncludeRoutingUpdated[K, A]: + // a node has been included in the routing table + c.outboundEvents <- &KademliaRoutingUpdatedEvent[K, A]{ + NodeInfo: st.NodeInfo, + } + return true + case *routing.StateIncludeWaitingAtCapacity: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeWaitingWithCapacity: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeWaitingFull: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeIdle: + // nothing to do except wait for message response or timeout + return false + default: + panic(fmt.Sprintf("unexpected include state: %T", st)) + } +} + +func (c *Coordinator[K, A]) advancePool(ctx context.Context) bool { pev, ok := c.poolEvents.Dequeue(ctx) if !ok { pev = &query.EventPoolPoll{} @@ -249,26 +337,26 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { c.sendQueryMessage(ctx, st.ProtocolID, st.NodeID, st.Message, st.QueryID, st.Stats) return true case *query.StatePoolWaitingAtCapacity: - // TODO + // nothing to do except wait for message response or timeout + return false case *query.StatePoolWaitingWithCapacity: - // TODO + // nothing to do except wait for message response or timeout + return false case *query.StatePoolQueryFinished: c.outboundEvents <- &KademliaOutboundQueryFinishedEvent{ QueryID: st.QueryID, Stats: st.Stats, } return true - - // TODO case *query.StatePoolQueryTimeout: // TODO + return false case *query.StatePoolIdle: - // TODO + // nothing to do + return false default: panic(fmt.Sprintf("unexpected pool state: %T", st)) } - - return false } func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID address.ProtocolID, to kad.NodeID[K], msg kad.Request[K, A], queryID query.QueryID, stats query.QueryStats) { @@ -326,8 +414,8 @@ func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID addres } } -func (c *Coordinator[K, A]) sendBootstrapMessage(ctx context.Context, protoID address.ProtocolID, to kad.NodeID[K], msg 
kad.Request[K, A], queryID query.QueryID, stats query.QueryStats) { - ctx, span := util.StartSpan(ctx, "Coordinator.sendBootstrapMessage") +func (c *Coordinator[K, A]) sendBootstrapFindNode(ctx context.Context, to kad.NodeID[K], queryID query.QueryID, stats query.QueryStats) { + ctx, span := util.StartSpan(ctx, "Coordinator.sendBootstrapFindNode") defer span.End() onSendError := func(ctx context.Context, err error) { @@ -373,13 +461,65 @@ func (c *Coordinator[K, A]) sendBootstrapMessage(ctx context.Context, protoID ad c.bootstrapEvents.Enqueue(ctx, bev) } + protoID, msg := c.findNodeFn(c.self) err := c.ep.SendRequestHandleResponse(ctx, protoID, to, msg, msg.EmptyResponse(), 0, onMessageResponse) if err != nil { onSendError(ctx, err) } } +func (c *Coordinator[K, A]) sendIncludeFindNode(ctx context.Context, to kad.NodeInfo[K, A]) { + ctx, span := util.StartSpan(ctx, "Coordinator.sendIncludeFindNode") + defer span.End() + + onSendError := func(ctx context.Context, err error) { + if errors.Is(err, endpoint.ErrCannotConnect) { + // here we can notify that the peer is unroutable, which would feed into peerstore and routing table + // TODO: remove from routing table + return + } + + iev := &routing.EventIncludeMessageFailure[K, A]{ + NodeInfo: to, + Error: err, + } + c.includeEvents.Enqueue(ctx, iev) + } + + onMessageResponse := func(ctx context.Context, resp kad.Response[K, A], err error) { + if err != nil { + onSendError(ctx, err) + return + } + + iev := &routing.EventIncludeMessageResponse[K, A]{ + NodeInfo: to, + Response: resp, + } + c.includeEvents.Enqueue(ctx, iev) + + if resp != nil { + candidates := resp.CloserNodes() + if len(candidates) > 0 { + // ignore error here + c.AddNodes(ctx, candidates) + } + } + } + + // this might be new node addressing info + c.ep.MaybeAddToPeerstore(ctx, to, c.cfg.PeerstoreTTL) + + protoID, msg := c.findNodeFn(c.self) + err := c.ep.SendRequestHandleResponse(ctx, protoID, to.ID(), msg, msg.EmptyResponse(), 0, onMessageResponse) + if err != nil { + onSendError(ctx, err) + } +} + func (c *Coordinator[K, A]) StartQuery(ctx context.Context, queryID query.QueryID, protocolID address.ProtocolID, msg kad.Request[K, A]) error { + ctx, span := util.StartSpan(ctx, "Coordinator.StartQuery") + defer span.End() knownClosestPeers := c.rt.NearestNodes(msg.Target(), 20) qev := &query.EventPoolAddQuery[K, A]{ @@ -394,6 +534,8 @@ func (c *Coordinator[K, A]) StartQuery(ctx context.Context, queryID query.QueryI } func (c *Coordinator[K, A]) StopQuery(ctx context.Context, queryID query.QueryID) error { + ctx, span := util.StartSpan(ctx, "Coordinator.StopQuery") + defer span.End() qev := &query.EventPoolStopQuery{ QueryID: queryID, } @@ -402,33 +544,29 @@ func (c *Coordinator[K, A]) StopQuery(ctx context.Context, queryID query.QueryID } // AddNodes suggests new DHT nodes and their associated addresses to be added to the routing table. -// If the routing table is been updated as a result of this operation a KademliaRoutingUpdatedEvent event is emitted. +// If the routing table is updated as a result of this operation a KademliaRoutingUpdatedEvent event is emitted. 
func (c *Coordinator[K, A]) AddNodes(ctx context.Context, infos []kad.NodeInfo[K, A]) error { + ctx, span := util.StartSpan(ctx, "Coordinator.AddNodes") + defer span.End() for _, info := range infos { if key.Equal(info.ID().Key(), c.self.Key()) { + // skip self continue } - isNew := c.rt.AddNode(info.ID()) - c.ep.MaybeAddToPeerstore(ctx, info, c.cfg.PeerstoreTTL) - - if isNew { - c.outboundEvents <- &KademliaRoutingUpdatedEvent[K, A]{ - NodeInfo: info, - } + // inject a new node into the coordinator's includeEvents queue + iev := &routing.EventIncludeAddCandidate[K, A]{ + NodeInfo: info, } + c.includeEvents.Enqueue(ctx, iev) } return nil } -// FindNodeRequestFunc is a function that creates a request to find the supplied node id -// TODO: consider this being a first class method of the Endpoint -type FindNodeRequestFunc[K kad.Key[K], A kad.Address[A]] func(kad.NodeID[K]) (address.ProtocolID, kad.Request[K, A]) - // Bootstrap instructs the coordinator to begin bootstrapping the routing table. // While bootstrap is in progress, no other queries will make progress. -func (c *Coordinator[K, A]) Bootstrap(ctx context.Context, seeds []kad.NodeID[K], fn FindNodeRequestFunc[K, A]) error { - protoID, msg := fn(c.self) +func (c *Coordinator[K, A]) Bootstrap(ctx context.Context, seeds []kad.NodeID[K]) error { + protoID, msg := c.findNodeFn(c.self) bev := &routing.EventBootstrapStart[K, A]{ ProtocolID: protoID, diff --git a/coord/coordinator_test.go b/coord/coordinator_test.go index 85c94f4..79dbc39 100644 --- a/coord/coordinator_test.go +++ b/coord/coordinator_test.go @@ -22,7 +22,7 @@ import ( "github.com/plprobelab/go-kademlia/sim" ) -func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []kad.RoutingTable[key.Key8, kad.NodeID[key.Key8]], *sim.LiteSimulator) { +func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], *sim.LiteSimulator) { // create node identifiers nodeCount := 4 ids := make([]*kadtest.ID[key.Key8], nodeCount) @@ -48,7 +48,7 @@ func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8 // create a fake router to virtually connect nodes router := sim.NewRouter[key.Key8, kadtest.StrAddr]() - rts := make([]kad.RoutingTable[key.Key8, kad.NodeID[key.Key8]], len(addrs)) + rts := make([]*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], len(addrs)) eps := make([]*sim.Endpoint[key.Key8, kadtest.StrAddr], len(addrs)) schedulers := make([]event.AwareScheduler, len(addrs)) servers := make([]*sim.Server[key.Key8, kadtest.StrAddr], len(addrs)) @@ -107,6 +107,10 @@ const peerstoreTTL = 10 * time.Minute var protoID = address.ProtocolID("/statemachine/1.0.0") // protocol ID for the test +var findNodeFn = func(n kad.NodeID[key.Key8]) (address.ProtocolID, kad.Request[key.Key8, kadtest.StrAddr]) { + return protoID, sim.NewRequest[key.Key8, kadtest.StrAddr](n.Key()) +} + // expectEventType selects on the event channel until an event of the expected type is sent. 
func expectEventType(t *testing.T, ctx context.Context, events <-chan KademliaEvent, expected KademliaEvent) (KademliaEvent, error) { t.Helper() @@ -195,7 +199,7 @@ func TestExhaustiveQuery(t *testing.T) { // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -272,7 +276,7 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -309,10 +313,6 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { require.NoError(t, err) } -var findNodeFn = func(n kad.NodeID[key.Key8]) (address.ProtocolID, kad.Request[key.Key8, kadtest.StrAddr]) { - return protoID, sim.NewRequest[key.Key8, kadtest.StrAddr](n.Key()) -} - func TestBootstrap(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() @@ -337,7 +337,7 @@ func TestBootstrap(t *testing.T) { }(ctx) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -349,10 +349,8 @@ func TestBootstrap(t *testing.T) { seeds := []kad.NodeID[key.Key8]{ nodes[1].ID(), } - err = c.Bootstrap(ctx, seeds, findNodeFn) - if err != nil { - t.Fatalf("failed to initiate bootstrap: %v", err) - } + err = c.Bootstrap(ctx, seeds) + require.NoError(t, err) // the query run by the coordinator should have received a response from nodes[1] ev, err := expectEventType(t, ctx, events, &KademliaOutboundQueryProgressedEvent[key.Key8, kadtest.StrAddr]{}) @@ -388,3 +386,58 @@ func TestBootstrap(t *testing.T) { require.Equal(t, 3, tevf.Stats.Success) require.Equal(t, 0, tevf.Stats.Failure) } + +func TestIncludeNode(t *testing.T) { + ctx, cancel := kadtest.Ctx(t) + defer cancel() + + nodes, eps, rts, siml := setupSimulation(t, ctx) + + clk := siml.Clock() + + ccfg := DefaultConfig() + ccfg.Clock = clk + ccfg.PeerstoreTTL = peerstoreTTL + + go func(ctx context.Context) { + for { + select { + case <-time.After(10 * time.Millisecond): + siml.Run(ctx) + case <-ctx.Done(): + return + } + } + }(ctx) + + self := nodes[0].ID() + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) + if err != nil { + log.Fatalf("unexpected error creating coordinator: %v", err) + } + siml.Add(c) + events := c.Events() + + candidate := nodes[3] // not in nodes[0] routing table + + // the routing table should not contain the node yet + foundNode, err := rts[0].Find(ctx, candidate.ID().Key()) + require.NoError(t, err) + require.Nil(t, foundNode) + + // inject a new node into the coordinator's includeEvents queue + err = c.AddNodes(ctx, []kad.NodeInfo[key.Key8, kadtest.StrAddr]{candidate}) + require.NoError(t, err) + + // the include state machine runs in the 
background and eventually should add the node to the routing table + ev, err := expectEventType(t, ctx, events, &KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]{}) + require.NoError(t, err) + + tev := ev.(*KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]) + require.Equal(t, candidate.ID(), tev.NodeInfo.ID()) + + // the routing table should contain the node + foundNode, err = rts[0].Find(ctx, candidate.ID().Key()) + require.NoError(t, err) + require.NotNil(t, foundNode) +} diff --git a/examples/statemachine/main.go b/examples/statemachine/main.go index 27628b5..bba44d8 100644 --- a/examples/statemachine/main.go +++ b/examples/statemachine/main.go @@ -20,6 +20,7 @@ import ( "github.com/plprobelab/go-kademlia/internal/kadtest" "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/network/address" "github.com/plprobelab/go-kademlia/network/endpoint" "github.com/plprobelab/go-kademlia/routing/simplert" "github.com/plprobelab/go-kademlia/sim" @@ -60,7 +61,7 @@ func main() { ccfg.Clock = siml.Clock() ccfg.PeerstoreTTL = peerstoreTTL - kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], rts[0], ccfg) + kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatal(err) } @@ -204,6 +205,10 @@ func debug(f string, args ...any) { fmt.Println(fmt.Sprintf(f, args...)) } +var findNodeFn = func(n kad.NodeID[key.Key256]) (address.ProtocolID, kad.Request[key.Key256, net.IP]) { + return protoID, sim.NewRequest[key.Key256, net.IP](n.Key()) +} + type RoutingUpdate any type Event any diff --git a/routing/include.go b/routing/include.go new file mode 100644 index 0000000..d789912 --- /dev/null +++ b/routing/include.go @@ -0,0 +1,281 @@ +package routing + +import ( + "context" + "fmt" + "time" + + "github.com/benbjohnson/clock" + + "github.com/plprobelab/go-kademlia/kad" + "github.com/plprobelab/go-kademlia/kaderr" + "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/util" +) + +type check[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] + Started time.Time +} + +type Include[K kad.Key[K], A kad.Address[A]] struct { + rt kad.RoutingTable[K, kad.NodeID[K]] + + // checks is an index of checks in progress + checks map[string]check[K, A] + + candidates *nodeQueue[K, A] + + // cfg is a copy of the optional configuration supplied to the Include + cfg IncludeConfig +} + +// IncludeConfig specifies optional configuration for an Include +type IncludeConfig struct { + QueueCapacity int // the maximum number of nodes that can be in the candidate queue + Concurrency int // the maximum number of include checks that may be in progress at any one time + Timeout time.Duration // the time to wait before terminating a check that is not making progress + Clock clock.Clock // a clock that may be replaced by a mock when testing +} + +// Validate checks the configuration options and returns an error if any have invalid values. 
+func (cfg *IncludeConfig) Validate() error { + if cfg.Clock == nil { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("clock must not be nil"), + } + } + + if cfg.Concurrency < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("concurrency must be greater than zero"), + } + } + + if cfg.Timeout < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("timeout must be greater than zero"), + } + } + + if cfg.QueueCapacity < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("queue size must be greater than zero"), + } + } + + return nil +} + +// DefaultIncludeConfig returns the default configuration options for an Include. +// Options may be overridden before passing to NewInclude +func DefaultIncludeConfig() *IncludeConfig { + return &IncludeConfig{ + Clock: clock.New(), // use standard time + Concurrency: 3, + Timeout: time.Minute, + QueueCapacity: 128, + } +} + +func NewInclude[K kad.Key[K], A kad.Address[A]](rt kad.RoutingTable[K, kad.NodeID[K]], cfg *IncludeConfig) (*Include[K, A], error) { + if cfg == nil { + cfg = DefaultIncludeConfig() + } else if err := cfg.Validate(); err != nil { + return nil, err + } + + return &Include[K, A]{ + candidates: newNodeQueue[K, A](cfg.QueueCapacity), + cfg: *cfg, + rt: rt, + checks: make(map[string]check[K, A], cfg.Concurrency), + }, nil +} + +// Advance advances the state of the include state machine by attempting to advance its query if running. +func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeState { + ctx, span := util.StartSpan(ctx, "Include.Advance") + defer span.End() + + switch tev := ev.(type) { + + case *EventIncludeAddCandidate[K, A]: + // TODO: potentially time out a check and make room in the queue + if !b.candidates.HasCapacity() { + return &StateIncludeWaitingFull{} + } + b.candidates.Enqueue(ctx, tev.NodeInfo) + + case *EventIncludeMessageResponse[K, A]: + ch, ok := b.checks[key.HexString(tev.NodeInfo.ID().Key())] + if ok { + delete(b.checks, key.HexString(tev.NodeInfo.ID().Key())) + // require that the node responded with at least one closer node + if tev.Response != nil && len(tev.Response.CloserNodes()) > 0 { + if b.rt.AddNode(tev.NodeInfo.ID()) { + return &StateIncludeRoutingUpdated[K, A]{ + NodeInfo: ch.NodeInfo, + } + } + } + } + case *EventIncludeMessageFailure[K, A]: + delete(b.checks, key.HexString(tev.NodeInfo.ID().Key())) + + case *EventIncludePoll: + // ignore, nothing to do + default: + panic(fmt.Sprintf("unexpected event: %T", tev)) + } + + if len(b.checks) == b.cfg.Concurrency { + if !b.candidates.HasCapacity() { + return &StateIncludeWaitingFull{} + } + return &StateIncludeWaitingAtCapacity{} + } + + candidate, ok := b.candidates.Dequeue(ctx) + if !ok { + // No candidate in queue + if len(b.checks) > 0 { + return &StateIncludeWaitingWithCapacity{} + } + return &StateIncludeIdle{} + } + + b.checks[key.HexString(candidate.ID().Key())] = check[K, A]{ + NodeInfo: candidate, + Started: b.cfg.Clock.Now(), + } + + // Ask the node to find itself + return &StateIncludeFindNodeMessage[K, A]{ + NodeInfo: candidate, + } +} + +// nodeQueue is a bounded queue of unique NodeIDs +type nodeQueue[K kad.Key[K], A kad.Address[A]] struct { + capacity int + nodes []kad.NodeInfo[K, A] + keys map[string]struct{} +} + +func newNodeQueue[K kad.Key[K], A kad.Address[A]](capacity int) *nodeQueue[K, A] { + return &nodeQueue[K, A]{ + capacity: capacity, + nodes: 
make([]kad.NodeInfo[K, A], 0, capacity), + keys: make(map[string]struct{}, capacity), + } +} + +// Enqueue adds a node to the queue. It returns true if the node was +// added and false otherwise. +func (q *nodeQueue[K, A]) Enqueue(ctx context.Context, n kad.NodeInfo[K, A]) bool { + if len(q.nodes) == q.capacity { + return false + } + + if _, exists := q.keys[key.HexString(n.ID().Key())]; exists { + return false + } + + q.nodes = append(q.nodes, n) + q.keys[key.HexString(n.ID().Key())] = struct{}{} + return true +} + +// Dequeue reads a node from the queue. It returns the node and a true value +// if a node was read or nil and false if no node was read. +func (q *nodeQueue[K, A]) Dequeue(ctx context.Context) (kad.NodeInfo[K, A], bool) { + if len(q.nodes) == 0 { + var v kad.NodeInfo[K, A] + return v, false + } + + var n kad.NodeInfo[K, A] + n, q.nodes = q.nodes[0], q.nodes[1:] + delete(q.keys, key.HexString(n.ID().Key())) + + return n, true +} + +func (q *nodeQueue[K, A]) HasCapacity() bool { + return len(q.nodes) < q.capacity +} + +// IncludeState is the state of an include. +type IncludeState interface { + includeState() +} + +// StateIncludeFindNodeMessage indicates that the include subsystem is waiting to send a find node message to a node. +// A find node message should be sent to the node, with the target being the node's key. +type StateIncludeFindNodeMessage[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node to send the message to +} + +// StateIncludeIdle indicates that the include has no checks in flight and no queued candidates. +type StateIncludeIdle struct{} + +// StateIncludeWaitingAtCapacity indicates that the include subsystem is waiting for responses for checks and +// that the maximum number of concurrent checks has been reached. +type StateIncludeWaitingAtCapacity struct{} + +// StateIncludeWaitingWithCapacity indicates that the include subsystem is waiting for responses for checks +// but has capacity to perform more. +type StateIncludeWaitingWithCapacity struct{} + +// StateIncludeWaitingFull indicates that the include subsystem is waiting for responses for checks and +// that the maximum number of queued candidates has been reached. +type StateIncludeWaitingFull struct{} + +// StateIncludeRoutingUpdated indicates the routing table has been updated with a new node. +type StateIncludeRoutingUpdated[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] +} + +// includeState() ensures that only Include states can be assigned to an IncludeState. +func (*StateIncludeFindNodeMessage[K, A]) includeState() {} +func (*StateIncludeIdle) includeState() {} +func (*StateIncludeWaitingAtCapacity) includeState() {} +func (*StateIncludeWaitingWithCapacity) includeState() {} +func (*StateIncludeWaitingFull) includeState() {} +func (*StateIncludeRoutingUpdated[K, A]) includeState() {} + +// IncludeEvent is an event intended to advance the state of an include. +type IncludeEvent interface { + includeEvent() +} + +// EventIncludePoll is an event that signals the include that it can perform housekeeping work such as timing out checks. +type EventIncludePoll struct{} + +// EventIncludeAddCandidate notifies that a node should be added to the candidate list +type EventIncludeAddCandidate[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the candidate node +} + +// EventIncludeMessageResponse notifies an include that a sent message has received a successful response. 
+type EventIncludeMessageResponse[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node the message was sent to + Response kad.Response[K, A] // the message response sent by the node +} + +// EventIncludeMessageFailure notifies an include that an attempt to send a message has failed. +type EventIncludeMessageFailure[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node the message was sent to + Error error // the error that caused the failure, if any +} + +// includeEvent() ensures that only Include events can be assigned to the IncludeEvent interface. +func (*EventIncludePoll) includeEvent() {} +func (*EventIncludeAddCandidate[K, A]) includeEvent() {} +func (*EventIncludeMessageResponse[K, A]) includeEvent() {} +func (*EventIncludeMessageFailure[K, A]) includeEvent() {} diff --git a/routing/include_test.go b/routing/include_test.go new file mode 100644 index 0000000..4b00311 --- /dev/null +++ b/routing/include_test.go @@ -0,0 +1,350 @@ +package routing + +import ( + "context" + "testing" + + "github.com/benbjohnson/clock" + "github.com/stretchr/testify/require" + + "github.com/plprobelab/go-kademlia/internal/kadtest" + "github.com/plprobelab/go-kademlia/kad" + "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/routing/simplert" +) + +func TestIncludeConfigValidate(t *testing.T) { + t.Run("default is valid", func(t *testing.T) { + cfg := DefaultIncludeConfig() + require.NoError(t, cfg.Validate()) + }) + + t.Run("clock is not nil", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Clock = nil + require.Error(t, cfg.Validate()) + }) + + t.Run("timeout positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Timeout = 0 + require.Error(t, cfg.Validate()) + cfg.Timeout = -1 + require.Error(t, cfg.Validate()) + }) + + t.Run("request concurrency positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Concurrency = 0 + require.Error(t, cfg.Validate()) + cfg.Concurrency = -1 + require.Error(t, cfg.Validate()) + }) + + t.Run("queue size positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.QueueCapacity = 0 + require.Error(t, cfg.Validate()) + cfg.QueueCapacity = -1 + require.Error(t, cfg.Validate()) + }) +} + +func TestIncludeStartsIdle(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + bs, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + state := bs.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeIdle{}, state) +} + +func TestIncludeAddCandidateStartsCheckIfCapacity(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 1 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + candidate := kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: candidate, + }) + // the state machine should attempt to send a message + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + st := state.(*StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]) + + // the 
message should be sent to the candidate node + require.Equal(t, candidate, st.NodeInfo) + + // the message should be looking for the candidate node + require.Equal(t, candidate.ID(), st.NodeInfo.ID()) + + // now the include reports that it is waiting since concurrency is 1 + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) +} + +func TestIncludeAddCandidateReportsCapacity(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + candidate := kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: candidate, + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // now the state machine reports that it is waiting with capacity since concurrency + // is greater than the number of checks in flight + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) +} + +func TestIncludeAddCandidateOverQueueLength(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.QueueCapacity = 2 // only allow two candidates in the queue + cfg.Concurrency = 3 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting and has capacity for more + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) + + // add second candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000010)), + []kadtest.StrAddr{kadtest.StrAddr("2")}, + ), + }) + // sends a message to the candidate + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting and has capacity for more + state = p.Advance(ctx, &EventIncludePoll{}) + // no message is sent, the existing checks are still in flight + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) + + // add third candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000011)), + []kadtest.StrAddr{kadtest.StrAddr("3")}, + ), + }) + // sends a message to the candidate + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting at capacity since 3 messages are in flight + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) + + // add fourth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000101)), 
[]kadtest.StrAddr{kadtest.StrAddr("5")}, + ), + }) + + // include reports that it is waiting at capacity since 3 messages are already in flight + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) + + // add fifth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000110)), + []kadtest.StrAddr{kadtest.StrAddr("6")}, + ), + }) + + // include reports that it is waiting and the candidate queue is full since it + // is configured to have 3 concurrent checks and 2 queued + require.IsType(t, &StateIncludeWaitingFull{}, state) + + // add sixth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000111)), + []kadtest.StrAddr{kadtest.StrAddr("7")}, + ), + }) + + // include reports that it is still waiting and the candidate queue is full since it + // is configured to have 3 concurrent checks and 2 queued + require.IsType(t, &StateIncludeWaitingFull{}, state) +} + +func TestIncludeMessageResponse(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was contacted successfully, with no closer nodes + state = p.Advance(ctx, &EventIncludeMessageResponse[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + Response: kadtest.NewResponse("resp", []kad.NodeInfo[key.Key8, kadtest.StrAddr]{ + kadtest.NewInfo(kadtest.NewID(key.Key8(4)), []kadtest.StrAddr{"addr_4"}), + kadtest.NewInfo(kadtest.NewID(key.Key8(6)), []kadtest.StrAddr{"addr_6"}), + }), + }) + + // should respond that the routing table was updated + require.IsType(t, &StateIncludeRoutingUpdated[key.Key8, kadtest.StrAddr]{}, state) + + st := state.(*StateIncludeRoutingUpdated[key.Key8, kadtest.StrAddr]) + + // the update is for the correct node + require.Equal(t, kadtest.NewID(key.Key8(4)), st.NodeInfo.ID()) + + // the routing table should contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.NotNil(t, foundNode) + + require.True(t, key.Equal(foundNode.Key(), key.Key8(4))) + + // advancing again should reports that it is idle + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeIdle{}, state) +} + +func TestIncludeMessageResponseInvalid(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, 
&StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was contacted successfully, but no closer nodes + state = p.Advance(ctx, &EventIncludeMessageResponse[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + // should respond that state machine is idle + require.IsType(t, &StateIncludeIdle{}, state) + + // the routing table should not contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.Nil(t, foundNode) +} + +func TestIncludeMessageFailure(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was not contacted successfully + state = p.Advance(ctx, &EventIncludeMessageFailure[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + + // should respond that state machine is idle + require.IsType(t, &StateIncludeIdle{}, state) + + // the routing table should not contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.Nil(t, foundNode) +}
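
The new state machine is driven entirely by events, as the tests above show: the caller feeds it EventInclude* events and acts on the StateInclude* values it returns. For illustration only (not part of the patch), a minimal sketch of that loop in isolation might look like the following. It mirrors routing/include_test.go and assumes it is compiled inside the go-kademlia module, since the kadtest helpers live under internal/.

```go
package main

import (
	"context"
	"fmt"

	"github.com/plprobelab/go-kademlia/internal/kadtest"
	"github.com/plprobelab/go-kademlia/kad"
	"github.com/plprobelab/go-kademlia/key"
	"github.com/plprobelab/go-kademlia/routing"
	"github.com/plprobelab/go-kademlia/routing/simplert"
)

func main() {
	ctx := context.Background()

	// routing table for a node whose key is 128, with bucket size 5 (as in the tests)
	rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5)

	// include state machine with the default config: 3 concurrent checks, 128 queued candidates
	in, err := routing.NewInclude[key.Key8, kadtest.StrAddr](rt, routing.DefaultIncludeConfig())
	if err != nil {
		panic(err)
	}

	// a candidate node that is not yet in the routing table
	candidate := kadtest.NewInfo(kadtest.NewID(key.Key8(4)), []kadtest.StrAddr{"addr_4"})

	// adding a candidate starts a check: the state machine asks the caller to send a find node request
	state := in.Advance(ctx, &routing.EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{NodeInfo: candidate})
	if st, ok := state.(*routing.StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]); ok {
		fmt.Println("send find node to", st.NodeInfo.ID())
	}

	// the caller reports the outcome; a response carrying at least one closer node marks the
	// candidate as live, so it is added to the routing table
	state = in.Advance(ctx, &routing.EventIncludeMessageResponse[key.Key8, kadtest.StrAddr]{
		NodeInfo: candidate,
		Response: kadtest.NewResponse("resp", []kad.NodeInfo[key.Key8, kadtest.StrAddr]{
			kadtest.NewInfo(kadtest.NewID(key.Key8(6)), []kadtest.StrAddr{"addr_6"}),
		}),
	})
	if st, ok := state.(*routing.StateIncludeRoutingUpdated[key.Key8, kadtest.StrAddr]); ok {
		fmt.Println("routing table updated with", st.NodeInfo.ID())
	}
}
```

In the patch itself this loop is owned by the Coordinator: AddNodes enqueues EventIncludeAddCandidate events, advanceInclude polls the state machine, and sendIncludeFindNode feeds EventIncludeMessageResponse or EventIncludeMessageFailure back through the includeEvents queue.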