diff --git a/coord/coordinator.go b/coord/coordinator.go index 3eef338..a68a2da 100644 --- a/coord/coordinator.go +++ b/coord/coordinator.go @@ -43,7 +43,7 @@ func (q *eventQueue[E]) Enqueue(ctx context.Context, e E) { } // Dequeue reads an event from the queue. It returns the event and a true value -// if an event was read or the zero value if the event type and false if no event +// if an event was read or the zero value of the event type and false if no event // was read. This method is non-blocking. func (q *eventQueue[E]) Dequeue(ctx context.Context) (E, bool) { select { @@ -55,6 +55,10 @@ func (q *eventQueue[E]) Dequeue(ctx context.Context) (E, bool) { } } +// FindNodeRequestFunc is a function that creates a request to find the supplied node id +// TODO: consider this being a first class method of the Endpoint +type FindNodeRequestFunc[K kad.Key[K], A kad.Address[A]] func(kad.NodeID[K]) (address.ProtocolID, kad.Request[K, A]) + // A Coordinator coordinates the state machines that comprise a Kademlia DHT // Currently this is only queries and bootstrapping but will expand to include other state machines such as // routing table refresh, and reproviding. @@ -77,12 +81,22 @@ type Coordinator[K kad.Key[K], A kad.Address[A]] struct { // bootstrapEvents is a fifo queue of events that are to be processed by the bootstrap state machine bootstrapEvents *eventQueue[routing.BootstrapEvent] + // include is the include state machine, responsible for including candidate nodes into the routing table + include StateMachine[routing.IncludeState, routing.IncludeEvent] + + // includeEvents is a fifo queue of events that are to be processed by the include state machine + includeEvents *eventQueue[routing.IncludeEvent] + // rt is the routing table used to look up nodes by distance rt kad.RoutingTable[K, kad.NodeID[K]] // ep is the message endpoint used to send requests ep endpoint.Endpoint[K, A] + // findNodeFn is a function that creates a find node request that may be understod by the endpoint + // TODO: thiis should be a function of the endpoint + findNodeFn FindNodeRequestFunc[K, A] + // queue not used queue event.EventQueue @@ -155,7 +169,7 @@ func DefaultConfig() *Config { } } -func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], cfg *Config) (*Coordinator[K, A], error) { +func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], fn FindNodeRequestFunc[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], cfg *Config) (*Coordinator[K, A], error) { if cfg == nil { cfg = DefaultConfig() } else if err := cfg.Validate(); err != nil { @@ -182,17 +196,34 @@ func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpo bootstrap, err := routing.NewBootstrap(self, bootstrapCfg) if err != nil { - return nil, fmt.Errorf("query pool: %w", err) + return nil, fmt.Errorf("bootstrap: %w", err) + } + + includeCfg := routing.DefaultIncludeConfig() + includeCfg.Clock = cfg.Clock + includeCfg.Timeout = cfg.QueryTimeout + + // TODO: expose config + // includeCfg.QueueCapacity = cfg.IncludeQueueCapacity + // includeCfg.Concurrency = cfg.IncludeConcurrency + // includeCfg.Timeout = cfg.IncludeTimeout + + include, err := routing.NewInclude[K, A](rt, includeCfg) + if err != nil { + return nil, fmt.Errorf("include: %w", err) } return &Coordinator[K, A]{ self: self, cfg: *cfg, ep: ep, + findNodeFn: fn, rt: rt, pool: qp, poolEvents: newEventQueue[query.PoolEvent](20), // 20 is abitrary, move to config bootstrap: bootstrap, bootstrapEvents: newEventQueue[routing.BootstrapEvent](20), // 20 is abitrary, move to config + include: include, + includeEvents: newEventQueue[routing.IncludeEvent](20), // 20 is abitrary, move to config outboundEvents: make(chan KademliaEvent, 20), queue: event.NewChanQueue(DefaultChanqueueCapacity), planner: event.NewSimplePlanner(cfg.Clock), @@ -207,8 +238,28 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { ctx, span := util.StartSpan(ctx, "Coordinator.RunOne") defer span.End() + // Process state machines in priority order + // Give the bootstrap state machine priority - // No queries can be run while a bootstrap is in progress + if c.advanceBootstrap(ctx) { + return true + } + + // Attempt to advance the include state machine so candidate nodes + // are added to the routing table + if c.advanceInclude(ctx) { + return true + } + + // Attempt to advance an outbound query + if c.advancePool(ctx) { + return true + } + + return false +} + +func (c *Coordinator[K, A]) advanceBootstrap(ctx context.Context) bool { bev, ok := c.bootstrapEvents.Dequeue(ctx) if !ok { bev = &routing.EventBootstrapPoll{} @@ -217,11 +268,11 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { bstate := c.bootstrap.Advance(ctx, bev) switch st := bstate.(type) { case *routing.StateBootstrapMessage[K, A]: - c.sendBootstrapMessage(ctx, st.ProtocolID, st.NodeID, st.Message, st.QueryID, st.Stats) + c.sendBootstrapFindNode(ctx, st.NodeID, st.QueryID, st.Stats) return true case *routing.StateBootstrapWaiting: - // bootstrap waiting for a message response, don't proceed with other state machines + // bootstrap waiting for a message response, proceed with other state machines return false case *routing.StateBootstrapFinished: @@ -232,12 +283,49 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { case *routing.StateBootstrapIdle: // bootstrap not running, can proceed to other state machines - break + return false default: panic(fmt.Sprintf("unexpected bootstrap state: %T", st)) } +} - // Attempt to advance an outbound query +func (c *Coordinator[K, A]) advanceInclude(ctx context.Context) bool { + // Attempt to advance the include state machine so candidate nodes + // are added to the routing table + iev, ok := c.includeEvents.Dequeue(ctx) + if !ok { + iev = &routing.EventIncludePoll{} + } + istate := c.include.Advance(ctx, iev) + switch st := istate.(type) { + case *routing.StateIncludeFindNodeMessage[K, A]: + // include wants to send a find node message to a node + c.sendIncludeFindNode(ctx, st.NodeInfo) + return true + case *routing.StateIncludeRoutingUpdated[K, A]: + // a node has been included in the routing table + c.outboundEvents <- &KademliaRoutingUpdatedEvent[K, A]{ + NodeInfo: st.NodeInfo, + } + return true + case *routing.StateIncludeWaitingAtCapacity: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeWaitingWithCapacity: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeWaitingFull: + // nothing to do except wait for message response or timeout + return false + case *routing.StateIncludeIdle: + // nothing to do except wait for message response or timeout + return false + default: + panic(fmt.Sprintf("unexpected include state: %T", st)) + } +} + +func (c *Coordinator[K, A]) advancePool(ctx context.Context) bool { pev, ok := c.poolEvents.Dequeue(ctx) if !ok { pev = &query.EventPoolPoll{} @@ -249,26 +337,26 @@ func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { c.sendQueryMessage(ctx, st.ProtocolID, st.NodeID, st.Message, st.QueryID, st.Stats) return true case *query.StatePoolWaitingAtCapacity: - // TODO + // nothing to do except wait for message response or timeout + return false case *query.StatePoolWaitingWithCapacity: - // TODO + // nothing to do except wait for message response or timeout + return false case *query.StatePoolQueryFinished: c.outboundEvents <- &KademliaOutboundQueryFinishedEvent{ QueryID: st.QueryID, Stats: st.Stats, } return true - - // TODO case *query.StatePoolQueryTimeout: // TODO + return false case *query.StatePoolIdle: - // TODO + // nothing to do + return false default: panic(fmt.Sprintf("unexpected pool state: %T", st)) } - - return false } func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID address.ProtocolID, to kad.NodeID[K], msg kad.Request[K, A], queryID query.QueryID, stats query.QueryStats) { @@ -326,8 +414,8 @@ func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID addres } } -func (c *Coordinator[K, A]) sendBootstrapMessage(ctx context.Context, protoID address.ProtocolID, to kad.NodeID[K], msg kad.Request[K, A], queryID query.QueryID, stats query.QueryStats) { - ctx, span := util.StartSpan(ctx, "Coordinator.sendBootstrapMessage") +func (c *Coordinator[K, A]) sendBootstrapFindNode(ctx context.Context, to kad.NodeID[K], queryID query.QueryID, stats query.QueryStats) { + ctx, span := util.StartSpan(ctx, "Coordinator.sendBootstrapFindNode") defer span.End() onSendError := func(ctx context.Context, err error) { @@ -373,13 +461,65 @@ func (c *Coordinator[K, A]) sendBootstrapMessage(ctx context.Context, protoID ad c.bootstrapEvents.Enqueue(ctx, bev) } + protoID, msg := c.findNodeFn(c.self) err := c.ep.SendRequestHandleResponse(ctx, protoID, to, msg, msg.EmptyResponse(), 0, onMessageResponse) if err != nil { onSendError(ctx, err) } } +func (c *Coordinator[K, A]) sendIncludeFindNode(ctx context.Context, to kad.NodeInfo[K, A]) { + ctx, span := util.StartSpan(ctx, "Coordinator.sendIncludeFindNode") + defer span.End() + + onSendError := func(ctx context.Context, err error) { + if errors.Is(err, endpoint.ErrCannotConnect) { + // here we can notify that the peer is unroutable, which would feed into peerstore and routing table + // TODO: remove from routing table + return + } + + iev := &routing.EventIncludeMessageFailure[K, A]{ + NodeInfo: to, + Error: err, + } + c.includeEvents.Enqueue(ctx, iev) + } + + onMessageResponse := func(ctx context.Context, resp kad.Response[K, A], err error) { + if err != nil { + onSendError(ctx, err) + return + } + + iev := &routing.EventIncludeMessageResponse[K, A]{ + NodeInfo: to, + Response: resp, + } + c.includeEvents.Enqueue(ctx, iev) + + if resp != nil { + candidates := resp.CloserNodes() + if len(candidates) > 0 { + // ignore error here + c.AddNodes(ctx, candidates) + } + } + } + + // this might be new node addressing info + c.ep.MaybeAddToPeerstore(ctx, to, c.cfg.PeerstoreTTL) + + protoID, msg := c.findNodeFn(c.self) + err := c.ep.SendRequestHandleResponse(ctx, protoID, to.ID(), msg, msg.EmptyResponse(), 0, onMessageResponse) + if err != nil { + onSendError(ctx, err) + } +} + func (c *Coordinator[K, A]) StartQuery(ctx context.Context, queryID query.QueryID, protocolID address.ProtocolID, msg kad.Request[K, A]) error { + ctx, span := util.StartSpan(ctx, "Coordinator.StartQuery") + defer span.End() knownClosestPeers := c.rt.NearestNodes(msg.Target(), 20) qev := &query.EventPoolAddQuery[K, A]{ @@ -394,6 +534,8 @@ func (c *Coordinator[K, A]) StartQuery(ctx context.Context, queryID query.QueryI } func (c *Coordinator[K, A]) StopQuery(ctx context.Context, queryID query.QueryID) error { + ctx, span := util.StartSpan(ctx, "Coordinator.StopQuery") + defer span.End() qev := &query.EventPoolStopQuery{ QueryID: queryID, } @@ -402,33 +544,29 @@ func (c *Coordinator[K, A]) StopQuery(ctx context.Context, queryID query.QueryID } // AddNodes suggests new DHT nodes and their associated addresses to be added to the routing table. -// If the routing table is been updated as a result of this operation a KademliaRoutingUpdatedEvent event is emitted. +// If the routing table is updated as a result of this operation a KademliaRoutingUpdatedEvent event is emitted. func (c *Coordinator[K, A]) AddNodes(ctx context.Context, infos []kad.NodeInfo[K, A]) error { + ctx, span := util.StartSpan(ctx, "Coordinator.AddNodes") + defer span.End() for _, info := range infos { if key.Equal(info.ID().Key(), c.self.Key()) { + // skip self continue } - isNew := c.rt.AddNode(info.ID()) - c.ep.MaybeAddToPeerstore(ctx, info, c.cfg.PeerstoreTTL) - - if isNew { - c.outboundEvents <- &KademliaRoutingUpdatedEvent[K, A]{ - NodeInfo: info, - } + // inject a new node into the coordinator's includeEvents queue + iev := &routing.EventIncludeAddCandidate[K, A]{ + NodeInfo: info, } + c.includeEvents.Enqueue(ctx, iev) } return nil } -// FindNodeRequestFunc is a function that creates a request to find the supplied node id -// TODO: consider this being a first class method of the Endpoint -type FindNodeRequestFunc[K kad.Key[K], A kad.Address[A]] func(kad.NodeID[K]) (address.ProtocolID, kad.Request[K, A]) - // Bootstrap instructs the coordinator to begin bootstrapping the routing table. // While bootstrap is in progress, no other queries will make progress. -func (c *Coordinator[K, A]) Bootstrap(ctx context.Context, seeds []kad.NodeID[K], fn FindNodeRequestFunc[K, A]) error { - protoID, msg := fn(c.self) +func (c *Coordinator[K, A]) Bootstrap(ctx context.Context, seeds []kad.NodeID[K]) error { + protoID, msg := c.findNodeFn(c.self) bev := &routing.EventBootstrapStart[K, A]{ ProtocolID: protoID, diff --git a/coord/coordinator_test.go b/coord/coordinator_test.go index 85c94f4..79dbc39 100644 --- a/coord/coordinator_test.go +++ b/coord/coordinator_test.go @@ -22,7 +22,7 @@ import ( "github.com/plprobelab/go-kademlia/sim" ) -func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []kad.RoutingTable[key.Key8, kad.NodeID[key.Key8]], *sim.LiteSimulator) { +func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], *sim.LiteSimulator) { // create node identifiers nodeCount := 4 ids := make([]*kadtest.ID[key.Key8], nodeCount) @@ -48,7 +48,7 @@ func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8 // create a fake router to virtually connect nodes router := sim.NewRouter[key.Key8, kadtest.StrAddr]() - rts := make([]kad.RoutingTable[key.Key8, kad.NodeID[key.Key8]], len(addrs)) + rts := make([]*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], len(addrs)) eps := make([]*sim.Endpoint[key.Key8, kadtest.StrAddr], len(addrs)) schedulers := make([]event.AwareScheduler, len(addrs)) servers := make([]*sim.Server[key.Key8, kadtest.StrAddr], len(addrs)) @@ -107,6 +107,10 @@ const peerstoreTTL = 10 * time.Minute var protoID = address.ProtocolID("/statemachine/1.0.0") // protocol ID for the test +var findNodeFn = func(n kad.NodeID[key.Key8]) (address.ProtocolID, kad.Request[key.Key8, kadtest.StrAddr]) { + return protoID, sim.NewRequest[key.Key8, kadtest.StrAddr](n.Key()) +} + // expectEventType selects on the event channel until an event of the expected type is sent. func expectEventType(t *testing.T, ctx context.Context, events <-chan KademliaEvent, expected KademliaEvent) (KademliaEvent, error) { t.Helper() @@ -195,7 +199,7 @@ func TestExhaustiveQuery(t *testing.T) { // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -272,7 +276,7 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -309,10 +313,6 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { require.NoError(t, err) } -var findNodeFn = func(n kad.NodeID[key.Key8]) (address.ProtocolID, kad.Request[key.Key8, kadtest.StrAddr]) { - return protoID, sim.NewRequest[key.Key8, kadtest.StrAddr](n.Key()) -} - func TestBootstrap(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() @@ -337,7 +337,7 @@ func TestBootstrap(t *testing.T) { }(ctx) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } @@ -349,10 +349,8 @@ func TestBootstrap(t *testing.T) { seeds := []kad.NodeID[key.Key8]{ nodes[1].ID(), } - err = c.Bootstrap(ctx, seeds, findNodeFn) - if err != nil { - t.Fatalf("failed to initiate bootstrap: %v", err) - } + err = c.Bootstrap(ctx, seeds) + require.NoError(t, err) // the query run by the coordinator should have received a response from nodes[1] ev, err := expectEventType(t, ctx, events, &KademliaOutboundQueryProgressedEvent[key.Key8, kadtest.StrAddr]{}) @@ -388,3 +386,58 @@ func TestBootstrap(t *testing.T) { require.Equal(t, 3, tevf.Stats.Success) require.Equal(t, 0, tevf.Stats.Failure) } + +func TestIncludeNode(t *testing.T) { + ctx, cancel := kadtest.Ctx(t) + defer cancel() + + nodes, eps, rts, siml := setupSimulation(t, ctx) + + clk := siml.Clock() + + ccfg := DefaultConfig() + ccfg.Clock = clk + ccfg.PeerstoreTTL = peerstoreTTL + + go func(ctx context.Context) { + for { + select { + case <-time.After(10 * time.Millisecond): + siml.Run(ctx) + case <-ctx.Done(): + return + } + } + }(ctx) + + self := nodes[0].ID() + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) + if err != nil { + log.Fatalf("unexpected error creating coordinator: %v", err) + } + siml.Add(c) + events := c.Events() + + candidate := nodes[3] // not in nodes[0] routing table + + // the routing table should not contain the node yet + foundNode, err := rts[0].Find(ctx, candidate.ID().Key()) + require.NoError(t, err) + require.Nil(t, foundNode) + + // inject a new node into the coordinator's includeEvents queue + err = c.AddNodes(ctx, []kad.NodeInfo[key.Key8, kadtest.StrAddr]{candidate}) + require.NoError(t, err) + + // the include state machine runs in the background and eventually should add the node to routing table + ev, err := expectEventType(t, ctx, events, &KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]{}) + require.NoError(t, err) + + tev := ev.(*KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]) + require.Equal(t, candidate.ID(), tev.NodeInfo.ID()) + + // the routing table should contain the node + foundNode, err = rts[0].Find(ctx, candidate.ID().Key()) + require.NoError(t, err) + require.NotNil(t, foundNode) +} diff --git a/examples/statemachine/main.go b/examples/statemachine/main.go index 27628b5..bba44d8 100644 --- a/examples/statemachine/main.go +++ b/examples/statemachine/main.go @@ -20,6 +20,7 @@ import ( "github.com/plprobelab/go-kademlia/internal/kadtest" "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/network/address" "github.com/plprobelab/go-kademlia/network/endpoint" "github.com/plprobelab/go-kademlia/routing/simplert" "github.com/plprobelab/go-kademlia/sim" @@ -60,7 +61,7 @@ func main() { ccfg.Clock = siml.Clock() ccfg.PeerstoreTTL = peerstoreTTL - kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], rts[0], ccfg) + kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], findNodeFn, rts[0], ccfg) if err != nil { log.Fatal(err) } @@ -204,6 +205,10 @@ func debug(f string, args ...any) { fmt.Println(fmt.Sprintf(f, args...)) } +var findNodeFn = func(n kad.NodeID[key.Key256]) (address.ProtocolID, kad.Request[key.Key256, net.IP]) { + return protoID, sim.NewRequest[key.Key256, net.IP](n.Key()) +} + type RoutingUpdate any type Event any diff --git a/routing/include.go b/routing/include.go new file mode 100644 index 0000000..d789912 --- /dev/null +++ b/routing/include.go @@ -0,0 +1,281 @@ +package routing + +import ( + "context" + "fmt" + "time" + + "github.com/benbjohnson/clock" + + "github.com/plprobelab/go-kademlia/kad" + "github.com/plprobelab/go-kademlia/kaderr" + "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/util" +) + +type check[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] + Started time.Time +} + +type Include[K kad.Key[K], A kad.Address[A]] struct { + rt kad.RoutingTable[K, kad.NodeID[K]] + + // checks is an index of checks in progress + checks map[string]check[K, A] + + candidates *nodeQueue[K, A] + + // cfg is a copy of the optional configuration supplied to the Include + cfg IncludeConfig +} + +// IncludeConfig specifies optional configuration for an Include +type IncludeConfig struct { + QueueCapacity int // the maximum number of nodes that can be in the candidate queue + Concurrency int // the maximum number of include checks that may be in progress at any one time + Timeout time.Duration // the time to wait before terminating a check that is not making progress + Clock clock.Clock // a clock that may replaced by a mock when testing +} + +// Validate checks the configuration options and returns an error if any have invalid values. +func (cfg *IncludeConfig) Validate() error { + if cfg.Clock == nil { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("clock must not be nil"), + } + } + + if cfg.Concurrency < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("concurrency must be greater than zero"), + } + } + + if cfg.Timeout < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("timeout must be greater than zero"), + } + } + + if cfg.QueueCapacity < 1 { + return &kaderr.ConfigurationError{ + Component: "IncludeConfig", + Err: fmt.Errorf("queue size must be greater than zero"), + } + } + + return nil +} + +// DefaultIncludeConfig returns the default configuration options for an Include. +// Options may be overridden before passing to NewInclude +func DefaultIncludeConfig() *IncludeConfig { + return &IncludeConfig{ + Clock: clock.New(), // use standard time + Concurrency: 3, + Timeout: time.Minute, + QueueCapacity: 128, + } +} + +func NewInclude[K kad.Key[K], A kad.Address[A]](rt kad.RoutingTable[K, kad.NodeID[K]], cfg *IncludeConfig) (*Include[K, A], error) { + if cfg == nil { + cfg = DefaultIncludeConfig() + } else if err := cfg.Validate(); err != nil { + return nil, err + } + + return &Include[K, A]{ + candidates: newNodeQueue[K, A](cfg.QueueCapacity), + cfg: *cfg, + rt: rt, + checks: make(map[string]check[K, A], cfg.Concurrency), + }, nil +} + +// Advance advances the state of the include state machine by attempting to advance its query if running. +func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeState { + ctx, span := util.StartSpan(ctx, "Include.Advance") + defer span.End() + + switch tev := ev.(type) { + + case *EventIncludeAddCandidate[K, A]: + // TODO: potentially time out a check and make room in the queue + if !b.candidates.HasCapacity() { + return &StateIncludeWaitingFull{} + } + b.candidates.Enqueue(ctx, tev.NodeInfo) + + case *EventIncludeMessageResponse[K, A]: + ch, ok := b.checks[key.HexString(tev.NodeInfo.ID().Key())] + if ok { + delete(b.checks, key.HexString(tev.NodeInfo.ID().Key())) + // require that the node responded with at least one closer node + if tev.Response != nil && len(tev.Response.CloserNodes()) > 0 { + if b.rt.AddNode(tev.NodeInfo.ID()) { + return &StateIncludeRoutingUpdated[K, A]{ + NodeInfo: ch.NodeInfo, + } + } + } + } + case *EventIncludeMessageFailure[K, A]: + delete(b.checks, key.HexString(tev.NodeInfo.ID().Key())) + + case *EventIncludePoll: + // ignore, nothing to do + default: + panic(fmt.Sprintf("unexpected event: %T", tev)) + } + + if len(b.checks) == b.cfg.Concurrency { + if !b.candidates.HasCapacity() { + return &StateIncludeWaitingFull{} + } + return &StateIncludeWaitingAtCapacity{} + } + + candidate, ok := b.candidates.Dequeue(ctx) + if !ok { + // No candidate in queue + if len(b.checks) > 0 { + return &StateIncludeWaitingWithCapacity{} + } + return &StateIncludeIdle{} + } + + b.checks[key.HexString(candidate.ID().Key())] = check[K, A]{ + NodeInfo: candidate, + Started: b.cfg.Clock.Now(), + } + + // Ask the node to find itself + return &StateIncludeFindNodeMessage[K, A]{ + NodeInfo: candidate, + } +} + +// nodeQueue is a bounded queue of unique NodeIDs +type nodeQueue[K kad.Key[K], A kad.Address[A]] struct { + capacity int + nodes []kad.NodeInfo[K, A] + keys map[string]struct{} +} + +func newNodeQueue[K kad.Key[K], A kad.Address[A]](capacity int) *nodeQueue[K, A] { + return &nodeQueue[K, A]{ + capacity: capacity, + nodes: make([]kad.NodeInfo[K, A], 0, capacity), + keys: make(map[string]struct{}, capacity), + } +} + +// Enqueue adds a node to the queue. It returns true if the node was +// added and false otherwise. +func (q *nodeQueue[K, A]) Enqueue(ctx context.Context, n kad.NodeInfo[K, A]) bool { + if len(q.nodes) == q.capacity { + return false + } + + if _, exists := q.keys[key.HexString(n.ID().Key())]; exists { + return false + } + + q.nodes = append(q.nodes, n) + q.keys[key.HexString(n.ID().Key())] = struct{}{} + return true +} + +// Dequeue reads an node from the queue. It returns the node and a true value +// if a node was read or nil and false if no node was read. +func (q *nodeQueue[K, A]) Dequeue(ctx context.Context) (kad.NodeInfo[K, A], bool) { + if len(q.nodes) == 0 { + var v kad.NodeInfo[K, A] + return v, false + } + + var n kad.NodeInfo[K, A] + n, q.nodes = q.nodes[0], q.nodes[1:] + delete(q.keys, key.HexString(n.ID().Key())) + + return n, true +} + +func (q *nodeQueue[K, A]) HasCapacity() bool { + return len(q.nodes) < q.capacity +} + +// IncludeState is the state of a include. +type IncludeState interface { + includeState() +} + +// StateIncludeFindNodeMessage indicates that the include subsystem is waiting to send a find node message a node. +// A find node message should be sent to the node, with the target being the node's key. +type StateIncludeFindNodeMessage[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node to send the mssage to +} + +// StateIncludeIdle indicates that the include is not running its query. +type StateIncludeIdle struct{} + +// StateIncludeWaitingAtCapacity indicates that the include subsystem is waiting for responses for checks and +// that the maximum number of concurrent checks has been reached. +type StateIncludeWaitingAtCapacity struct{} + +// StateIncludeWaitingWithCapacity indicates that the include subsystem is waiting for responses for checks +// but has capacity to perform more. +type StateIncludeWaitingWithCapacity struct{} + +// StateIncludeWaitingFull indicates that the include subsystem is waiting for responses for checks and +// that the maximum number of queued candidates has been reached. +type StateIncludeWaitingFull struct{} + +// StateIncludeRoutingUpdated indicates the routing table has been updated with a new node. +type StateIncludeRoutingUpdated[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] +} + +// includeState() ensures that only Include states can be assigned to an IncludeState. +func (*StateIncludeFindNodeMessage[K, A]) includeState() {} +func (*StateIncludeIdle) includeState() {} +func (*StateIncludeWaitingAtCapacity) includeState() {} +func (*StateIncludeWaitingWithCapacity) includeState() {} +func (*StateIncludeWaitingFull) includeState() {} +func (*StateIncludeRoutingUpdated[K, A]) includeState() {} + +// IncludeEvent is an event intended to advance the state of a include. +type IncludeEvent interface { + includeEvent() +} + +// EventIncludePoll is an event that signals the include that it can perform housekeeping work such as time out queries. +type EventIncludePoll struct{} + +// EventIncludeAddCandidate notifies that a node should be added to the candidate list +type EventIncludeAddCandidate[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the candidate node +} + +// EventIncludeMessageResponse notifies a include that a sent message has received a successful response. +type EventIncludeMessageResponse[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node the message was sent to + Response kad.Response[K, A] // the message response sent by the node +} + +// EventIncludeMessageFailure notifiesa include that an attempt to send a message has failed. +type EventIncludeMessageFailure[K kad.Key[K], A kad.Address[A]] struct { + NodeInfo kad.NodeInfo[K, A] // the node the message was sent to + Error error // the error that caused the failure, if any +} + +// includeEvent() ensures that only Include events can be assigned to the IncludeEvent interface. +func (*EventIncludePoll) includeEvent() {} +func (*EventIncludeAddCandidate[K, A]) includeEvent() {} +func (*EventIncludeMessageResponse[K, A]) includeEvent() {} +func (*EventIncludeMessageFailure[K, A]) includeEvent() {} diff --git a/routing/include_test.go b/routing/include_test.go new file mode 100644 index 0000000..4b00311 --- /dev/null +++ b/routing/include_test.go @@ -0,0 +1,350 @@ +package routing + +import ( + "context" + "testing" + + "github.com/benbjohnson/clock" + "github.com/stretchr/testify/require" + + "github.com/plprobelab/go-kademlia/internal/kadtest" + "github.com/plprobelab/go-kademlia/kad" + "github.com/plprobelab/go-kademlia/key" + "github.com/plprobelab/go-kademlia/routing/simplert" +) + +func TestIncludeConfigValidate(t *testing.T) { + t.Run("default is valid", func(t *testing.T) { + cfg := DefaultIncludeConfig() + require.NoError(t, cfg.Validate()) + }) + + t.Run("clock is not nil", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Clock = nil + require.Error(t, cfg.Validate()) + }) + + t.Run("timeout positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Timeout = 0 + require.Error(t, cfg.Validate()) + cfg.Timeout = -1 + require.Error(t, cfg.Validate()) + }) + + t.Run("request concurrency positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.Concurrency = 0 + require.Error(t, cfg.Validate()) + cfg.Concurrency = -1 + require.Error(t, cfg.Validate()) + }) + + t.Run("queue size positive", func(t *testing.T) { + cfg := DefaultIncludeConfig() + cfg.QueueCapacity = 0 + require.Error(t, cfg.Validate()) + cfg.QueueCapacity = -1 + require.Error(t, cfg.Validate()) + }) +} + +func TestIncludeStartsIdle(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + bs, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + state := bs.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeIdle{}, state) +} + +func TestIncludeAddCandidateStartsCheckIfCapacity(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 1 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + candidate := kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: candidate, + }) + // the state machine should attempt to send a message + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + st := state.(*StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]) + + // the message should be sent to the candidate node + require.Equal(t, candidate, st.NodeInfo) + + // the message should be looking for the candidate node + require.Equal(t, candidate.ID(), st.NodeInfo.ID()) + + // now the include reports that it is waiting since concurrency is 1 + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) +} + +func TestIncludeAddCandidateReportsCapacity(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + candidate := kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: candidate, + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // now the state machine reports that it is waiting with capacity since concurrency + // is greater than the number of checks in flight + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) +} + +func TestIncludeAddCandidateOverQueueLength(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.QueueCapacity = 2 // only allow two candidates in the queue + cfg.Concurrency = 3 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting and has capacity for more + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) + + // add second candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000010)), + []kadtest.StrAddr{kadtest.StrAddr("2")}, + ), + }) + // sends a message to the candidate + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting and has capacity for more + state = p.Advance(ctx, &EventIncludePoll{}) + // sends a message to the candidate + require.IsType(t, &StateIncludeWaitingWithCapacity{}, state) + + // add third candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000011)), + []kadtest.StrAddr{kadtest.StrAddr("3")}, + ), + }) + // sends a message to the candidate + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // include reports that it is waiting at capacity since 3 messages are in flight + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) + + // add fourth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000101)), + []kadtest.StrAddr{kadtest.StrAddr("5")}, + ), + }) + + // include reports that it is waiting at capacity since 3 messages are already in flight + require.IsType(t, &StateIncludeWaitingAtCapacity{}, state) + + // add fifth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000110)), + []kadtest.StrAddr{kadtest.StrAddr("6")}, + ), + }) + + // include reports that it is waiting and the candidate queue is full since it + // is configured to have 3 concurrent checks and 2 queued + require.IsType(t, &StateIncludeWaitingFull{}, state) + + // add sixth candidate + state = p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000111)), + []kadtest.StrAddr{kadtest.StrAddr("7")}, + ), + }) + + // include reports that it is still waiting and the candidate queue is full since it + // is configured to have 3 concurrent checks and 2 queued + require.IsType(t, &StateIncludeWaitingFull{}, state) +} + +func TestIncludeMessageResponse(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was contacted successfully, with no closer nodes + state = p.Advance(ctx, &EventIncludeMessageResponse[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + Response: kadtest.NewResponse("resp", []kad.NodeInfo[key.Key8, kadtest.StrAddr]{ + kadtest.NewInfo(kadtest.NewID(key.Key8(4)), []kadtest.StrAddr{"addr_4"}), + kadtest.NewInfo(kadtest.NewID(key.Key8(6)), []kadtest.StrAddr{"addr_6"}), + }), + }) + + // should respond that the routing table was updated + require.IsType(t, &StateIncludeRoutingUpdated[key.Key8, kadtest.StrAddr]{}, state) + + st := state.(*StateIncludeRoutingUpdated[key.Key8, kadtest.StrAddr]) + + // the update is for the correct node + require.Equal(t, kadtest.NewID(key.Key8(4)), st.NodeInfo.ID()) + + // the routing table should contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.NotNil(t, foundNode) + + require.True(t, key.Equal(foundNode.Key(), key.Key8(4))) + + // advancing again should reports that it is idle + state = p.Advance(ctx, &EventIncludePoll{}) + require.IsType(t, &StateIncludeIdle{}, state) +} + +func TestIncludeMessageResponseInvalid(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was contacted successfully, but no closer nodes + state = p.Advance(ctx, &EventIncludeMessageResponse[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + // should respond that state machine is idle + require.IsType(t, &StateIncludeIdle{}, state) + + // the routing table should not contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.Nil(t, foundNode) +} + +func TestIncludeMessageFailure(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultIncludeConfig() + cfg.Clock = clk + cfg.Concurrency = 2 + + rt := simplert.New[key.Key8, kad.NodeID[key.Key8]](kadtest.NewID(key.Key8(128)), 5) + + p, err := NewInclude[key.Key8, kadtest.StrAddr](rt, cfg) + require.NoError(t, err) + + // add a candidate + state := p.Advance(ctx, &EventIncludeAddCandidate[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + require.IsType(t, &StateIncludeFindNodeMessage[key.Key8, kadtest.StrAddr]{}, state) + + // notify that node was not contacted successfully + state = p.Advance(ctx, &EventIncludeMessageFailure[key.Key8, kadtest.StrAddr]{ + NodeInfo: kadtest.NewInfo( + kadtest.NewID(key.Key8(0b00000100)), + []kadtest.StrAddr{kadtest.StrAddr("4")}, + ), + }) + + // should respond that state machine is idle + require.IsType(t, &StateIncludeIdle{}, state) + + // the routing table should not contain the node + foundNode, err := rt.Find(ctx, key.Key8(4)) + require.NoError(t, err) + require.Nil(t, foundNode) +}