diff --git a/crawler/crawler.go b/crawler/crawler.go index 5f4ec7732..dae145382 100644 --- a/crawler/crawler.go +++ b/crawler/crawler.go @@ -119,6 +119,8 @@ type HandleQueryResult func(p peer.ID, rtPeers []*peer.AddrInfo) // HandleQueryFail is a callback on failed peer query type HandleQueryFail func(p peer.ID, err error) +const dialAddressExtendDur time.Duration = time.Minute * 30 + // Run crawls dht peers from an initial seed of `startingPeers` func (c *Crawler) Run(ctx context.Context, startingPeers []*peer.AddrInfo, handleSuccess HandleQueryResult, handleFail HandleQueryFail) { jobs := make(chan peer.ID, 1) @@ -140,15 +142,27 @@ func (c *Crawler) Run(ctx context.Context, startingPeers []*peer.AddrInfo, handl defer wg.Wait() defer close(jobs) - toDial := make([]*peer.AddrInfo, 0, len(startingPeers)) + var toDial []*peer.AddrInfo peersSeen := make(map[peer.ID]struct{}) + numSkipped := 0 for _, ai := range startingPeers { + extendAddrs := c.host.Peerstore().Addrs(ai.ID) + if len(ai.Addrs) > 0 { + extendAddrs = append(extendAddrs, ai.Addrs...) + c.host.Peerstore().AddAddrs(ai.ID, extendAddrs, dialAddressExtendDur) + } + if len(extendAddrs) == 0 { + numSkipped++ + continue + } + toDial = append(toDial, ai) peersSeen[ai.ID] = struct{}{} - extendAddrs := c.host.Peerstore().Addrs(ai.ID) - extendAddrs = append(extendAddrs, ai.Addrs...) - c.host.Peerstore().AddAddrs(ai.ID, extendAddrs, time.Hour) + } + + if numSkipped > 0 { + logger.Infof("%d starting peers were skipped due to lack of addresses. Starting crawl with %d peers", numSkipped, len(toDial)) } numQueried := 0 @@ -168,7 +182,7 @@ func (c *Crawler) Run(ctx context.Context, startingPeers []*peer.AddrInfo, handl logger.Debugf("peer %v had %d peers", res.peer, len(res.data)) rtPeers := make([]*peer.AddrInfo, 0, len(res.data)) for p, ai := range res.data { - c.host.Peerstore().AddAddrs(p, ai.Addrs, time.Hour) + c.host.Peerstore().AddAddrs(p, ai.Addrs, dialAddressExtendDur) if _, ok := peersSeen[p]; !ok { peersSeen[p] = struct{}{} toDial = append(toDial, ai) @@ -208,7 +222,7 @@ func (c *Crawler) queryPeer(ctx context.Context, nextPeer peer.ID) *queryResult defer cancel() err = c.host.Connect(connCtx, peer.AddrInfo{ID: nextPeer}) if err != nil { - logger.Infof("could not connect to peer %v: %v", nextPeer, err) + logger.Debugf("could not connect to peer %v: %v", nextPeer, err) return &queryResult{nextPeer, nil, err} } diff --git a/ctx_mutex.go b/ctx_mutex.go deleted file mode 100644 index c28d89875..000000000 --- a/ctx_mutex.go +++ /dev/null @@ -1,28 +0,0 @@ -package dht - -import ( - "context" -) - -type ctxMutex chan struct{} - -func newCtxMutex() ctxMutex { - return make(ctxMutex, 1) -} - -func (m ctxMutex) Lock(ctx context.Context) error { - select { - case m <- struct{}{}: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -func (m ctxMutex) Unlock() { - select { - case <-m: - default: - panic("not locked") - } -} diff --git a/dht.go b/dht.go index 7c89ab56a..ae82c1396 100644 --- a/dht.go +++ b/dht.go @@ -16,6 +16,8 @@ import ( "github.com/libp2p/go-libp2p-core/routing" "github.com/libp2p/go-libp2p-kad-dht/internal" + dhtcfg "github.com/libp2p/go-libp2p-kad-dht/internal/config" + "github.com/libp2p/go-libp2p-kad-dht/internal/net" "github.com/libp2p/go-libp2p-kad-dht/metrics" pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" @@ -96,7 +98,7 @@ type IpfsDHT struct { proc goprocess.Process protoMessenger *pb.ProtocolMessenger - msgSender *messageSenderImpl + msgSender pb.MessageSender plk 
sync.Mutex @@ -163,15 +165,15 @@ var ( // If the Routing Table has more than "minRTRefreshThreshold" peers, we consider a peer as a Routing Table candidate ONLY when // we successfully get a query response from it OR if it send us a query. func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) { - var cfg config - if err := cfg.apply(append([]Option{defaults}, options...)...); err != nil { + var cfg dhtcfg.Config + if err := cfg.Apply(append([]Option{dhtcfg.Defaults}, options...)...); err != nil { return nil, err } - if err := cfg.applyFallbacks(h); err != nil { + if err := cfg.ApplyFallbacks(h); err != nil { return nil, err } - if err := cfg.validate(); err != nil { + if err := cfg.Validate(); err != nil { return nil, err } @@ -180,34 +182,30 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) return nil, fmt.Errorf("failed to create DHT, err=%s", err) } - dht.autoRefresh = cfg.routingTable.autoRefresh + dht.autoRefresh = cfg.RoutingTable.AutoRefresh - dht.maxRecordAge = cfg.maxRecordAge - dht.enableProviders = cfg.enableProviders - dht.enableValues = cfg.enableValues - dht.disableFixLowPeers = cfg.disableFixLowPeers + dht.maxRecordAge = cfg.MaxRecordAge + dht.enableProviders = cfg.EnableProviders + dht.enableValues = cfg.EnableValues + dht.disableFixLowPeers = cfg.DisableFixLowPeers - dht.Validator = cfg.validator - dht.msgSender = &messageSenderImpl{ - host: h, - strmap: make(map[peer.ID]*peerMessageSender), - protocols: dht.protocols, - } + dht.Validator = cfg.Validator + dht.msgSender = net.NewMessageSenderImpl(h, dht.protocols) dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender, pb.WithValidator(dht.Validator)) if err != nil { return nil, err } - dht.testAddressUpdateProcessing = cfg.testAddressUpdateProcessing + dht.testAddressUpdateProcessing = cfg.TestAddressUpdateProcessing - dht.auto = cfg.mode - switch cfg.mode { + dht.auto = cfg.Mode + switch cfg.Mode { case ModeAuto, ModeClient: dht.mode = modeClient case ModeAutoServer, ModeServer: dht.mode = modeServer default: - return nil, fmt.Errorf("invalid dht mode %d", cfg.mode) + return nil, fmt.Errorf("invalid dht mode %d", cfg.Mode) } if dht.mode == modeServer { @@ -265,20 +263,20 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT return dht } -func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { +func makeDHT(ctx context.Context, h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { var protocols, serverProtocols []protocol.ID - v1proto := cfg.protocolPrefix + kad1 + v1proto := cfg.ProtocolPrefix + kad1 - if cfg.v1ProtocolOverride != "" { - v1proto = cfg.v1ProtocolOverride + if cfg.V1ProtocolOverride != "" { + v1proto = cfg.V1ProtocolOverride } protocols = []protocol.ID{v1proto} serverProtocols = []protocol.ID{v1proto} dht := &IpfsDHT{ - datastore: cfg.datastore, + datastore: cfg.Datastore, self: h.ID(), selfKey: kb.ConvertPeerID(h.ID()), peerstore: h.Peerstore(), @@ -287,12 +285,12 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { protocols: protocols, protocolsStrs: protocol.ConvertToStrings(protocols), serverProtocols: serverProtocols, - bucketSize: cfg.bucketSize, - alpha: cfg.concurrency, - beta: cfg.resiliency, - queryPeerFilter: cfg.queryPeerFilter, - routingTablePeerFilter: cfg.routingTable.peerFilter, - rtPeerDiversityFilter: cfg.routingTable.diversityFilter, + bucketSize: cfg.BucketSize, + alpha: cfg.Concurrency, + beta: cfg.Resiliency, + queryPeerFilter: 
cfg.QueryPeerFilter, + routingTablePeerFilter: cfg.RoutingTable.PeerFilter, + rtPeerDiversityFilter: cfg.RoutingTable.DiversityFilter, fixLowPeersChan: make(chan struct{}, 1), @@ -306,12 +304,12 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { // query a peer as part of our refresh cycle. // To grok the Math Wizardy that produced these exact equations, please be patient as a document explaining it will // be published soon. - if cfg.concurrency < cfg.bucketSize { // (alpha < K) - l1 := math.Log(float64(1) / float64(cfg.bucketSize)) //(Log(1/K)) - l2 := math.Log(float64(1) - (float64(cfg.concurrency) / float64(cfg.bucketSize))) // Log(1 - (alpha / K)) - maxLastSuccessfulOutboundThreshold = time.Duration(l1 / l2 * float64(cfg.routingTable.refreshInterval)) + if cfg.Concurrency < cfg.BucketSize { // (alpha < K) + l1 := math.Log(float64(1) / float64(cfg.BucketSize)) //(Log(1/K)) + l2 := math.Log(float64(1) - (float64(cfg.Concurrency) / float64(cfg.BucketSize))) // Log(1 - (alpha / K)) + maxLastSuccessfulOutboundThreshold = time.Duration(l1 / l2 * float64(cfg.RoutingTable.RefreshInterval)) } else { - maxLastSuccessfulOutboundThreshold = cfg.routingTable.refreshInterval + maxLastSuccessfulOutboundThreshold = cfg.RoutingTable.RefreshInterval } // construct routing table @@ -321,7 +319,7 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { return nil, fmt.Errorf("failed to construct routing table,err=%s", err) } dht.routingTable = rt - dht.bootstrapPeers = cfg.bootstrapPeers + dht.bootstrapPeers = cfg.BootstrapPeers // rt refresh manager rtRefresh, err := makeRtRefreshManager(dht, cfg, maxLastSuccessfulOutboundThreshold) @@ -340,7 +338,7 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { // the DHT context should be done when the process is closed dht.ctx = goprocessctx.WithProcessClosing(ctxTags, dht.proc) - pm, err := providers.NewProviderManager(dht.ctx, h.ID(), cfg.datastore, cfg.providersOptions...) + pm, err := providers.NewProviderManager(dht.ctx, h.ID(), cfg.Datastore, cfg.ProvidersOptions...) 
if err != nil { return nil, err } @@ -351,7 +349,7 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { return dht, nil } -func makeRtRefreshManager(dht *IpfsDHT, cfg config, maxLastSuccessfulOutboundThreshold time.Duration) (*rtrefresh.RtRefreshManager, error) { +func makeRtRefreshManager(dht *IpfsDHT, cfg dhtcfg.Config, maxLastSuccessfulOutboundThreshold time.Duration) (*rtrefresh.RtRefreshManager, error) { keyGenFnc := func(cpl uint) (string, error) { p, err := dht.routingTable.GenRandPeerID(cpl) return string(p), err @@ -363,18 +361,18 @@ func makeRtRefreshManager(dht *IpfsDHT, cfg config, maxLastSuccessfulOutboundThr } r, err := rtrefresh.NewRtRefreshManager( - dht.host, dht.routingTable, cfg.routingTable.autoRefresh, + dht.host, dht.routingTable, cfg.RoutingTable.AutoRefresh, keyGenFnc, queryFnc, - cfg.routingTable.refreshQueryTimeout, - cfg.routingTable.refreshInterval, + cfg.RoutingTable.RefreshQueryTimeout, + cfg.RoutingTable.RefreshInterval, maxLastSuccessfulOutboundThreshold, dht.refreshFinishedCh) return r, err } -func makeRoutingTable(dht *IpfsDHT, cfg config, maxLastSuccessfulOutboundThreshold time.Duration) (*kb.RoutingTable, error) { +func makeRoutingTable(dht *IpfsDHT, cfg dhtcfg.Config, maxLastSuccessfulOutboundThreshold time.Duration) (*kb.RoutingTable, error) { // make a Routing Table Diversity Filter var filter *peerdiversity.Filter if dht.rtPeerDiversityFilter != nil { @@ -389,7 +387,7 @@ func makeRoutingTable(dht *IpfsDHT, cfg config, maxLastSuccessfulOutboundThresho filter = df } - rt, err := kb.NewRoutingTable(cfg.bucketSize, dht.selfKey, time.Minute, dht.host.Peerstore(), maxLastSuccessfulOutboundThreshold, filter) + rt, err := kb.NewRoutingTable(cfg.BucketSize, dht.selfKey, time.Minute, dht.host.Peerstore(), maxLastSuccessfulOutboundThreshold, filter) if err != nil { return nil, err } diff --git a/dht_filters.go b/dht_filters.go index df921963a..b6c041ea1 100644 --- a/dht_filters.go +++ b/dht_filters.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -14,14 +15,16 @@ import ( ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" + + dhtcfg "github.com/libp2p/go-libp2p-kad-dht/internal/config" ) // QueryFilterFunc is a filter applied when considering peers to dial when querying -type QueryFilterFunc func(dht *IpfsDHT, ai peer.AddrInfo) bool +type QueryFilterFunc = dhtcfg.QueryFilterFunc // RouteTableFilterFunc is a filter applied when considering connections to keep in // the local route table. 
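+// A custom filter written against the new alias signature receives the DHT as
+// an opaque value and can recover the host through a small interface
+// assertion. A sketch (the filter below is illustrative, not part of the API):
+//
+//	onlyConnected := func(d interface{}, p peer.ID) bool {
+//		h := d.(interface{ Host() host.Host }).Host()
+//		return len(h.Network().ConnsToPeer(p)) > 0
+//	}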
-type RouteTableFilterFunc func(dht *IpfsDHT, conns []network.Conn) bool +type RouteTableFilterFunc = dhtcfg.RouteTableFilterFunc var publicCIDR6 = "2000::/3" var public6 *net.IPNet @@ -59,7 +62,7 @@ func isPrivateAddr(a ma.Multiaddr) bool { } // PublicQueryFilter returns true if the peer is suspected of being publicly accessible -func PublicQueryFilter(_ *IpfsDHT, ai peer.AddrInfo) bool { +func PublicQueryFilter(_ interface{}, ai peer.AddrInfo) bool { if len(ai.Addrs) == 0 { return false } @@ -73,18 +76,25 @@ func PublicQueryFilter(_ *IpfsDHT, ai peer.AddrInfo) bool { return hasPublicAddr } +type hasHost interface { + Host() host.Host +} + var _ QueryFilterFunc = PublicQueryFilter // PublicRoutingTableFilter allows a peer to be added to the routing table if the connections to that peer indicate // that it is on a public network -func PublicRoutingTableFilter(dht *IpfsDHT, conns []network.Conn) bool { +func PublicRoutingTableFilter(dht interface{}, p peer.ID) bool { + d := dht.(hasHost) + + conns := d.Host().Network().ConnsToPeer(p) if len(conns) == 0 { return false } // Do we have a public address for this peer? id := conns[0].RemotePeer() - known := dht.peerstore.PeerInfo(id) + known := d.Host().Peerstore().PeerInfo(id) for _, a := range known.Addrs { if !isRelayAddr(a) && isPublicAddr(a) { return true @@ -97,7 +107,7 @@ func PublicRoutingTableFilter(dht *IpfsDHT, conns []network.Conn) bool { var _ RouteTableFilterFunc = PublicRoutingTableFilter // PrivateQueryFilter doens't currently restrict which peers we are willing to query from the local DHT. -func PrivateQueryFilter(dht *IpfsDHT, ai peer.AddrInfo) bool { +func PrivateQueryFilter(_ interface{}, ai peer.AddrInfo) bool { return len(ai.Addrs) > 0 } @@ -137,10 +147,19 @@ func getCachedRouter() routing.Router { // PrivateRoutingTableFilter allows a peer to be added to the routing table if the connections to that peer indicate // that it is on a private network -func PrivateRoutingTableFilter(dht *IpfsDHT, conns []network.Conn) bool { +func PrivateRoutingTableFilter(dht interface{}, p peer.ID) bool { + d := dht.(hasHost) + conns := d.Host().Network().ConnsToPeer(p) + return privRTFilter(d, conns) +} + +func privRTFilter(dht interface{}, conns []network.Conn) bool { + d := dht.(hasHost) + h := d.Host() + router := getCachedRouter() myAdvertisedIPs := make([]net.IP, 0) - for _, a := range dht.Host().Addrs() { + for _, a := range h.Addrs() { if isPublicAddr(a) && !isRelayAddr(a) { ip, err := manet.ToIP(a) if err != nil { diff --git a/dht_filters_test.go b/dht_filters_test.go index e4b098afd..000d4572f 100644 --- a/dht_filters_test.go +++ b/dht_filters_test.go @@ -53,7 +53,7 @@ func TestFilterCaching(t *testing.T) { d := setupDHT(ctx, t, true) remote, _ := manet.FromIP(net.IPv4(8, 8, 8, 8)) - if PrivateRoutingTableFilter(d, []network.Conn{&mockConn{ + if privRTFilter(d, []network.Conn{&mockConn{ local: d.Host().Peerstore().PeerInfo(d.Host().ID()), remote: peer.AddrInfo{ID: "", Addrs: []ma.Multiaddr{remote}}, }}) { diff --git a/dht_net.go b/dht_net.go index 278216625..18f4bd733 100644 --- a/dht_net.go +++ b/dht_net.go @@ -1,15 +1,12 @@ package dht import ( - "bufio" - "fmt" "io" - "sync" "time" "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-msgio/protoio" + "github.com/libp2p/go-libp2p-kad-dht/internal/net" "github.com/libp2p/go-libp2p-kad-dht/metrics" pb "github.com/libp2p/go-libp2p-kad-dht/pb" @@ -19,45 +16,10 @@ import ( "go.uber.org/zap" ) -var dhtReadMessageTimeout = 10 * time.Second var dhtStreamIdleTimeout = 1 * 
time.Minute // ErrReadTimeout is an error that occurs when no message is read within the timeout period. -var ErrReadTimeout = fmt.Errorf("timed out reading response") - -// The Protobuf writer performs multiple small writes when writing a message. -// We need to buffer those writes, to make sure that we're not sending a new -// packet for every single write. -type bufferedDelimitedWriter struct { - *bufio.Writer - protoio.WriteCloser -} - -var writerPool = sync.Pool{ - New: func() interface{} { - w := bufio.NewWriter(nil) - return &bufferedDelimitedWriter{ - Writer: w, - WriteCloser: protoio.NewDelimitedWriter(w), - } - }, -} - -func writeMsg(w io.Writer, mes *pb.Message) error { - bw := writerPool.Get().(*bufferedDelimitedWriter) - bw.Reset(w) - err := bw.WriteMsg(mes) - if err == nil { - err = bw.Flush() - } - bw.Reset(nil) - writerPool.Put(bw) - return err -} - -func (w *bufferedDelimitedWriter) Flush() error { - return w.Writer.Flush() -} +var ErrReadTimeout = net.ErrReadTimeout // handleNewStream implements the network.StreamHandler func (dht *IpfsDHT) handleNewStream(s network.Stream) { @@ -180,7 +142,7 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { } // send out response msg - err = writeMsg(s, resp) + err = net.WriteMsg(s, resp) if err != nil { stats.Record(ctx, metrics.ReceivedMessageErrors.M(1)) if c := baseLogger.Check(zap.DebugLevel, "error writing response"); c != nil { diff --git a/dht_options.go b/dht_options.go index 820145451..1f4e47afe 100644 --- a/dht_options.go +++ b/dht_options.go @@ -5,22 +5,19 @@ import ( "testing" "time" - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" + dhtcfg "github.com/libp2p/go-libp2p-kad-dht/internal/config" "github.com/libp2p/go-libp2p-kad-dht/providers" "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" ds "github.com/ipfs/go-datastore" - dssync "github.com/ipfs/go-datastore/sync" - "github.com/ipfs/go-ipns" ) // ModeOpt describes what mode the dht should operate in -type ModeOpt int +type ModeOpt = dhtcfg.ModeOpt const ( // ModeAuto utilizes EvtLocalReachabilityChanged events sent over the event bus to dynamically switch the DHT @@ -37,143 +34,13 @@ const ( // DefaultPrefix is the application specific prefix attached to all DHT protocols by default. const DefaultPrefix protocol.ID = "/ipfs" -// Options is a structure containing all the options that can be used when constructing a DHT. 
-type config struct { - datastore ds.Batching - validator record.Validator - validatorChanged bool // if true implies that the validator has been changed and that defaults should not be used - mode ModeOpt - protocolPrefix protocol.ID - v1ProtocolOverride protocol.ID - bucketSize int - concurrency int - resiliency int - maxRecordAge time.Duration - enableProviders bool - enableValues bool - providersOptions []providers.Option - queryPeerFilter QueryFilterFunc - - routingTable struct { - refreshQueryTimeout time.Duration - refreshInterval time.Duration - autoRefresh bool - latencyTolerance time.Duration - checkInterval time.Duration - peerFilter RouteTableFilterFunc - diversityFilter peerdiversity.PeerIPGroupFilter - } - - bootstrapPeers []peer.AddrInfo - - // test specific config options - disableFixLowPeers bool - testAddressUpdateProcessing bool -} - -func emptyQueryFilter(_ *IpfsDHT, ai peer.AddrInfo) bool { return true } -func emptyRTFilter(_ *IpfsDHT, conns []network.Conn) bool { return true } - -// apply applies the given options to this Option -func (c *config) apply(opts ...Option) error { - for i, opt := range opts { - if err := opt(c); err != nil { - return fmt.Errorf("dht option %d failed: %s", i, err) - } - } - return nil -} - -// applyFallbacks sets default values that could not be applied during config creation since they are dependent -// on other configuration parameters (e.g. optA is by default 2x optB) and/or on the Host -func (c *config) applyFallbacks(h host.Host) error { - if !c.validatorChanged { - nsval, ok := c.validator.(record.NamespacedValidator) - if ok { - if _, pkFound := nsval["pk"]; !pkFound { - nsval["pk"] = record.PublicKeyValidator{} - } - if _, ipnsFound := nsval["ipns"]; !ipnsFound { - nsval["ipns"] = ipns.Validator{KeyBook: h.Peerstore()} - } - } else { - return fmt.Errorf("the default validator was changed without being marked as changed") - } - } - return nil -} - -// Option DHT option type. -type Option func(*config) error - -const defaultBucketSize = 20 - -// defaults are the default DHT options. This option will be automatically -// prepended to any options you pass to the DHT constructor. 
-var defaults = func(o *config) error { - o.validator = record.NamespacedValidator{} - o.datastore = dssync.MutexWrap(ds.NewMapDatastore()) - o.protocolPrefix = DefaultPrefix - o.enableProviders = true - o.enableValues = true - o.queryPeerFilter = emptyQueryFilter - - o.routingTable.latencyTolerance = time.Minute - o.routingTable.refreshQueryTimeout = 1 * time.Minute - o.routingTable.refreshInterval = 10 * time.Minute - o.routingTable.autoRefresh = true - o.routingTable.peerFilter = emptyRTFilter - o.maxRecordAge = time.Hour * 36 - - o.bucketSize = defaultBucketSize - o.concurrency = 10 - o.resiliency = 3 - - return nil -} - -func (c *config) validate() error { - if c.protocolPrefix != DefaultPrefix { - return nil - } - if c.bucketSize != defaultBucketSize { - return fmt.Errorf("protocol prefix %s must use bucket size %d", DefaultPrefix, defaultBucketSize) - } - if !c.enableProviders { - return fmt.Errorf("protocol prefix %s must have providers enabled", DefaultPrefix) - } - if !c.enableValues { - return fmt.Errorf("protocol prefix %s must have values enabled", DefaultPrefix) - } - - nsval, isNSVal := c.validator.(record.NamespacedValidator) - if !isNSVal { - return fmt.Errorf("protocol prefix %s must use a namespaced validator", DefaultPrefix) - } - - if len(nsval) != 2 { - return fmt.Errorf("protocol prefix %s must have exactly two namespaced validators - /pk and /ipns", DefaultPrefix) - } - - if pkVal, pkValFound := nsval["pk"]; !pkValFound { - return fmt.Errorf("protocol prefix %s must support the /pk namespaced validator", DefaultPrefix) - } else if _, ok := pkVal.(record.PublicKeyValidator); !ok { - return fmt.Errorf("protocol prefix %s must use the record.PublicKeyValidator for the /pk namespace", DefaultPrefix) - } - - if ipnsVal, ipnsValFound := nsval["ipns"]; !ipnsValFound { - return fmt.Errorf("protocol prefix %s must support the /ipns namespaced validator", DefaultPrefix) - } else if _, ok := ipnsVal.(ipns.Validator); !ok { - return fmt.Errorf("protocol prefix %s must use ipns.Validator for the /ipns namespace", DefaultPrefix) - } - return nil -} +type Option = dhtcfg.Option // RoutingTableLatencyTolerance sets the maximum acceptable latency for peers // in the routing table's cluster. func RoutingTableLatencyTolerance(latency time.Duration) Option { - return func(c *config) error { - c.routingTable.latencyTolerance = latency + return func(c *dhtcfg.Config) error { + c.RoutingTable.LatencyTolerance = latency return nil } } @@ -181,8 +48,8 @@ func RoutingTableLatencyTolerance(latency time.Duration) Option { // RoutingTableRefreshQueryTimeout sets the timeout for routing table refresh // queries. func RoutingTableRefreshQueryTimeout(timeout time.Duration) Option { - return func(c *config) error { - c.routingTable.refreshQueryTimeout = timeout + return func(c *dhtcfg.Config) error { + c.RoutingTable.RefreshQueryTimeout = timeout return nil } } @@ -194,8 +61,8 @@ func RoutingTableRefreshQueryTimeout(timeout time.Duration) Option { // 1. Then searching for a random key in each bucket that hasn't been queried in // the last refresh period. func RoutingTableRefreshPeriod(period time.Duration) Option { - return func(c *config) error { - c.routingTable.refreshInterval = period + return func(c *dhtcfg.Config) error { + c.RoutingTable.RefreshInterval = period return nil } } @@ -204,8 +71,8 @@ func RoutingTableRefreshPeriod(period time.Duration) Option { // // Defaults to an in-memory (temporary) map. 
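+//
+// Usage sketch (myBatchingDatastore stands in for any ds.Batching
+// implementation supplied by the caller):
+//
+//	node, err := New(ctx, h, Datastore(myBatchingDatastore))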
func Datastore(ds ds.Batching) Option { - return func(c *config) error { - c.datastore = ds + return func(c *dhtcfg.Config) error { + c.Datastore = ds return nil } } @@ -214,8 +81,8 @@ func Datastore(ds ds.Batching) Option { // // Defaults to ModeAuto. func Mode(m ModeOpt) Option { - return func(c *config) error { - c.mode = m + return func(c *dhtcfg.Config) error { + c.Mode = m return nil } } @@ -227,9 +94,9 @@ func Mode(m ModeOpt) Option { // implies that the user wants to control the validators and therefore the default // public key and IPNS validators will not be added. func Validator(v record.Validator) Option { - return func(c *config) error { - c.validator = v - c.validatorChanged = true + return func(c *dhtcfg.Config) error { + c.Validator = v + c.ValidatorChanged = true return nil } } @@ -246,8 +113,8 @@ func Validator(v record.Validator) Option { // myValidator)`, all records with keys starting with `/ipns/` will be validated // with `myValidator`. func NamespacedValidator(ns string, v record.Validator) Option { - return func(c *config) error { - nsval, ok := c.validator.(record.NamespacedValidator) + return func(c *dhtcfg.Config) error { + nsval, ok := c.Validator.(record.NamespacedValidator) if !ok { return fmt.Errorf("can only add namespaced validators to a NamespacedValidator") } @@ -261,8 +128,8 @@ func NamespacedValidator(ns string, v record.Validator) Option { // // Defaults to dht.DefaultPrefix func ProtocolPrefix(prefix protocol.ID) Option { - return func(c *config) error { - c.protocolPrefix = prefix + return func(c *dhtcfg.Config) error { + c.ProtocolPrefix = prefix return nil } } @@ -270,8 +137,8 @@ func ProtocolPrefix(prefix protocol.ID) Option { // ProtocolExtension adds an application specific protocol to the DHT protocol. For example, // /ipfs/lan/kad/1.0.0 instead of /ipfs/kad/1.0.0. extension should be of the form /lan. func ProtocolExtension(ext protocol.ID) Option { - return func(c *config) error { - c.protocolPrefix += ext + return func(c *dhtcfg.Config) error { + c.ProtocolPrefix += ext return nil } } @@ -282,8 +149,8 @@ func ProtocolExtension(ext protocol.ID) Option { // // This option will override and ignore the ProtocolPrefix and ProtocolExtension options func V1ProtocolOverride(proto protocol.ID) Option { - return func(c *config) error { - c.v1ProtocolOverride = proto + return func(c *dhtcfg.Config) error { + c.V1ProtocolOverride = proto return nil } } @@ -292,8 +159,8 @@ func V1ProtocolOverride(proto protocol.ID) Option { // // The default value is 20. func BucketSize(bucketSize int) Option { - return func(c *config) error { - c.bucketSize = bucketSize + return func(c *dhtcfg.Config) error { + c.BucketSize = bucketSize return nil } } @@ -302,8 +169,8 @@ func BucketSize(bucketSize int) Option { // // The default value is 10. func Concurrency(alpha int) Option { - return func(c *config) error { - c.concurrency = alpha + return func(c *dhtcfg.Config) error { + c.Concurrency = alpha return nil } } @@ -313,8 +180,8 @@ func Concurrency(alpha int) Option { // // The default value is 3. func Resiliency(beta int) Option { - return func(c *config) error { - c.resiliency = beta + return func(c *dhtcfg.Config) error { + c.Resiliency = beta return nil } } @@ -326,8 +193,8 @@ func Resiliency(beta int) Option { // until the year 2020 (a great time in the future). 
For that record to stick around // it must be rebroadcasted more frequently than once every 'MaxRecordAge' func MaxRecordAge(maxAge time.Duration) Option { - return func(c *config) error { - c.maxRecordAge = maxAge + return func(c *dhtcfg.Config) error { + c.MaxRecordAge = maxAge return nil } } @@ -336,8 +203,8 @@ func MaxRecordAge(maxAge time.Duration) Option { // table. This means that we will neither refresh the routing table periodically // nor when the routing table size goes below the minimum threshold. func DisableAutoRefresh() Option { - return func(c *config) error { - c.routingTable.autoRefresh = false + return func(c *dhtcfg.Config) error { + c.RoutingTable.AutoRefresh = false return nil } } @@ -349,8 +216,8 @@ func DisableAutoRefresh() Option { // WARNING: do not change this unless you're using a forked DHT (i.e., a private // network and/or distinct DHT protocols with the `Protocols` option). func DisableProviders() Option { - return func(c *config) error { - c.enableProviders = false + return func(c *dhtcfg.Config) error { + c.EnableProviders = false return nil } } @@ -363,8 +230,8 @@ func DisableProviders() Option { // WARNING: do not change this unless you're using a forked DHT (i.e., a private // network and/or distinct DHT protocols with the `Protocols` option). func DisableValues() Option { - return func(c *config) error { - c.enableValues = false + return func(c *dhtcfg.Config) error { + c.EnableValues = false return nil } } @@ -375,16 +242,16 @@ func DisableValues() Option { // them in between. These options are passed to the provider manager allowing // customisation of things like the GC interval and cache implementation. func ProvidersOptions(opts []providers.Option) Option { - return func(c *config) error { - c.providersOptions = opts + return func(c *dhtcfg.Config) error { + c.ProvidersOptions = opts return nil } } // QueryFilter sets a function that approves which peers may be dialed in a query func QueryFilter(filter QueryFilterFunc) Option { - return func(c *config) error { - c.queryPeerFilter = filter + return func(c *dhtcfg.Config) error { + c.QueryPeerFilter = filter return nil } } @@ -392,8 +259,8 @@ func QueryFilter(filter QueryFilterFunc) Option { // RoutingTableFilter sets a function that approves which peers may be added to the routing table. The host should // already have at least one connection to the peer under consideration. func RoutingTableFilter(filter RouteTableFilterFunc) Option { - return func(c *config) error { - c.routingTable.peerFilter = filter + return func(c *dhtcfg.Config) error { + c.RoutingTable.PeerFilter = filter return nil } } @@ -401,8 +268,8 @@ func RoutingTableFilter(filter RouteTableFilterFunc) Option { // BootstrapPeers configures the bootstrapping nodes that we will connect to to seed // and refresh our Routing Table if it becomes empty. func BootstrapPeers(bootstrappers ...peer.AddrInfo) Option { - return func(c *config) error { - c.bootstrapPeers = bootstrappers + return func(c *dhtcfg.Config) error { + c.BootstrapPeers = bootstrappers return nil } } @@ -411,8 +278,8 @@ func BootstrapPeers(bootstrappers ...peer.AddrInfo) Option { // to construct the diversity filter for the Routing Table. // Please see the docs for `peerdiversity.PeerIPGroupFilter` AND `peerdiversity.Filter` for more details. 
func RoutingTablePeerDiversityFilter(pg peerdiversity.PeerIPGroupFilter) Option { - return func(c *config) error { - c.routingTable.diversityFilter = pg + return func(c *dhtcfg.Config) error { + c.RoutingTable.DiversityFilter = pg return nil } } @@ -420,8 +287,8 @@ func RoutingTablePeerDiversityFilter(pg peerdiversity.PeerIPGroupFilter) Option // disableFixLowPeersRoutine disables the "fixLowPeers" routine in the DHT. // This is ONLY for tests. func disableFixLowPeersRoutine(t *testing.T) Option { - return func(c *config) error { - c.disableFixLowPeers = true + return func(c *dhtcfg.Config) error { + c.DisableFixLowPeers = true return nil } } @@ -430,8 +297,8 @@ func disableFixLowPeersRoutine(t *testing.T) Option { // This occurs even when AutoRefresh has been disabled. // This is ONLY for tests. func forceAddressUpdateProcessing(t *testing.T) Option { - return func(c *config) error { - c.testAddressUpdateProcessing = true + return func(c *dhtcfg.Config) error { + c.TestAddressUpdateProcessing = true return nil } } diff --git a/dht_test.go b/dht_test.go index ea26b1bd3..3bc45426b 100644 --- a/dht_test.go +++ b/dht_test.go @@ -562,28 +562,6 @@ func TestValueGetInvalid(t *testing.T) { testSetGet("valid", "newer", nil) } -func TestInvalidMessageSenderTracking(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - dht := setupDHT(ctx, t, false) - defer dht.Close() - - foo := peer.ID("asdasd") - _, err := dht.msgSender.messageSenderForPeer(ctx, foo) - if err == nil { - t.Fatal("that shouldnt have succeeded") - } - - dht.msgSender.smlk.Lock() - mscnt := len(dht.msgSender.strmap) - dht.msgSender.smlk.Unlock() - - if mscnt > 0 { - t.Fatal("should have no message senders in map") - } -} - func TestProvides(t *testing.T) { // t.Skip("skipping test to debug another") ctx, cancel := context.WithCancel(context.Background()) @@ -1187,7 +1165,7 @@ func TestFindPeerWithQueryFilter(t *testing.T) { defer cancel() filteredPeer := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) - dhts := setupDHTS(t, ctx, 4, QueryFilter(func(_ *IpfsDHT, ai peer.AddrInfo) bool { + dhts := setupDHTS(t, ctx, 4, QueryFilter(func(_ interface{}, ai peer.AddrInfo) bool { return ai.ID != filteredPeer.ID() })) defer func() { @@ -1501,14 +1479,9 @@ func testFindPeerQuery(t *testing.T, val := "foobar" rtval := kb.ConvertKey(val) - out, err := guy.GetClosestPeers(ctx, val) + outpeers, err := guy.GetClosestPeers(ctx, val) require.NoError(t, err) - var outpeers []peer.ID - for p := range out { - outpeers = append(outpeers, p) - } - sort.Sort(peer.IDSlice(outpeers)) exp := kb.SortClosestPeers(peers, rtval)[:minInt(guy.bucketSize, len(peers))] @@ -1542,13 +1515,8 @@ func TestFindClosestPeers(t *testing.T) { t.Fatal(err) } - var out []peer.ID - for p := range peers { - out = append(out, p) - } - - if len(out) < querier.beta { - t.Fatalf("got wrong number of peers (got %d, expected at least %d)", len(out), querier.beta) + if len(peers) < querier.beta { + t.Fatalf("got wrong number of peers (got %d, expected at least %d)", len(peers), querier.beta) } } @@ -2134,18 +2102,16 @@ func TestPreconnectedNodes(t *testing.T) { } // See if it works - peerCh, err := d2.GetClosestPeers(ctx, "testkey") + peers, err := d2.GetClosestPeers(ctx, "testkey") if err != nil { t.Fatal(err) } - select { - case p := <-peerCh: - if p == h1.ID() { - break - } + if len(peers) != 1 { + t.Fatal("why is there more than one peer?") + } + + if peers[0] != h1.ID() { t.Fatal("could not find peer") - case <-ctx.Done(): - 
t.Fatal("test hung") } } diff --git a/dual/dual_test.go b/dual/dual_test.go index 41c8a112b..bab6d726d 100644 --- a/dual/dual_test.go +++ b/dual/dual_test.go @@ -7,7 +7,7 @@ import ( "github.com/ipfs/go-cid" u "github.com/ipfs/go-ipfs-util" - "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" peerstore "github.com/libp2p/go-libp2p-core/peerstore" dht "github.com/libp2p/go-libp2p-kad-dht" @@ -34,9 +34,17 @@ type customRtHelper struct { allow peer.ID } -func MkFilterForPeer() (func(d *dht.IpfsDHT, conns []network.Conn) bool, *customRtHelper) { +func MkFilterForPeer() (func(_ interface{}, p peer.ID) bool, *customRtHelper) { helper := customRtHelper{} - f := func(_ *dht.IpfsDHT, conns []network.Conn) bool { + + type hasHost interface { + Host() host.Host + } + + f := func(dht interface{}, p peer.ID) bool { + d := dht.(hasHost) + conns := d.Host().Network().ConnsToPeer(p) + for _, c := range conns { if c.RemotePeer() == helper.allow { return true diff --git a/fullrt/dht.go b/fullrt/dht.go new file mode 100644 index 000000000..f01662737 --- /dev/null +++ b/fullrt/dht.go @@ -0,0 +1,1326 @@ +package fullrt + +import ( + "bytes" + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/multiformats/go-base32" + "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multihash" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-core/routing" + + "github.com/gogo/protobuf/proto" + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + dssync "github.com/ipfs/go-datastore/sync" + u "github.com/ipfs/go-ipfs-util" + logging "github.com/ipfs/go-log" + + kaddht "github.com/libp2p/go-libp2p-kad-dht" + "github.com/libp2p/go-libp2p-kad-dht/crawler" + "github.com/libp2p/go-libp2p-kad-dht/internal" + internalConfig "github.com/libp2p/go-libp2p-kad-dht/internal/config" + "github.com/libp2p/go-libp2p-kad-dht/internal/net" + dht_pb "github.com/libp2p/go-libp2p-kad-dht/pb" + "github.com/libp2p/go-libp2p-kad-dht/providers" + kb "github.com/libp2p/go-libp2p-kbucket" + + record "github.com/libp2p/go-libp2p-record" + recpb "github.com/libp2p/go-libp2p-record/pb" + + "github.com/libp2p/go-libp2p-xor/kademlia" + kadkey "github.com/libp2p/go-libp2p-xor/key" + "github.com/libp2p/go-libp2p-xor/trie" +) + +var logger = logging.Logger("fullrtdht") + +// FullRT is an experimental DHT client that is under development. Expect breaking changes to occur in this client +// until it stabilizes. 
+type FullRT struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + enableValues, enableProviders bool + Validator record.Validator + ProviderManager *providers.ProviderManager + datastore ds.Datastore + h host.Host + + crawlerInterval time.Duration + lastCrawlTime time.Time + + crawler *crawler.Crawler + protoMessenger *dht_pb.ProtocolMessenger + messageSender dht_pb.MessageSender + + filterFromTable kaddht.QueryFilterFunc + rtLk sync.RWMutex + rt *trie.Trie + + kMapLk sync.RWMutex + keyToPeerMap map[string]peer.ID + + peerAddrsLk sync.RWMutex + peerAddrs map[peer.ID][]multiaddr.Multiaddr + + bootstrapPeers []*peer.AddrInfo + + bucketSize int + + triggerRefresh chan struct{} + + waitFrac float64 + timeoutPerOp time.Duration + + bulkSendParallelism int +} + +// NewFullRT creates a DHT client that tracks the full network. It takes a protocol prefix for the given network, +// For example, the protocol /ipfs/kad/1.0.0 has the prefix /ipfs. +// +// FullRT is an experimental DHT client that is under development. Expect breaking changes to occur in this client +// until it stabilizes. +// +// Not all of the standard DHT options are supported in this DHT. +func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*FullRT, error) { + var fullrtcfg config + if err := fullrtcfg.apply(options...); err != nil { + return nil, err + } + + dhtcfg := &internalConfig.Config{ + Datastore: dssync.MutexWrap(ds.NewMapDatastore()), + Validator: record.NamespacedValidator{}, + ValidatorChanged: false, + EnableProviders: true, + EnableValues: true, + ProtocolPrefix: protocolPrefix, + } + + if err := dhtcfg.Apply(fullrtcfg.dhtOpts...); err != nil { + return nil, err + } + if err := dhtcfg.ApplyFallbacks(h); err != nil { + return nil, err + } + + if err := dhtcfg.Validate(); err != nil { + return nil, err + } + + ms := net.NewMessageSenderImpl(h, []protocol.ID{dhtcfg.ProtocolPrefix + "/kad/1.0.0"}) + protoMessenger, err := dht_pb.NewProtocolMessenger(ms, dht_pb.WithValidator(dhtcfg.Validator)) + if err != nil { + return nil, err + } + + c, err := crawler.New(h, crawler.WithParallelism(200)) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + + pm, err := providers.NewProviderManager(ctx, h.ID(), dhtcfg.Datastore) + if err != nil { + cancel() + return nil, err + } + + var bsPeers []*peer.AddrInfo + + for _, ai := range dhtcfg.BootstrapPeers { + tmpai := ai + bsPeers = append(bsPeers, &tmpai) + } + + rt := &FullRT{ + ctx: ctx, + cancel: cancel, + + enableValues: dhtcfg.EnableValues, + enableProviders: dhtcfg.EnableProviders, + Validator: dhtcfg.Validator, + ProviderManager: pm, + datastore: dhtcfg.Datastore, + h: h, + crawler: c, + messageSender: ms, + protoMessenger: protoMessenger, + filterFromTable: kaddht.PublicQueryFilter, + rt: trie.New(), + keyToPeerMap: make(map[string]peer.ID), + bucketSize: dhtcfg.BucketSize, + + peerAddrs: make(map[peer.ID][]multiaddr.Multiaddr), + bootstrapPeers: bsPeers, + + triggerRefresh: make(chan struct{}), + + waitFrac: 0.3, + timeoutPerOp: 5 * time.Second, + + crawlerInterval: time.Minute * 60, + + bulkSendParallelism: 10, + } + + rt.wg.Add(1) + go rt.runCrawler(ctx) + + return rt, nil +} + +type crawlVal struct { + addrs []multiaddr.Multiaddr + key kadkey.Key +} + +func (dht *FullRT) TriggerRefresh(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case dht.triggerRefresh <- struct{}{}: + return nil + case <-dht.ctx.Done(): + return fmt.Errorf("dht is 
closed") + } +} + +func (dht *FullRT) Stat() map[string]peer.ID { + newMap := make(map[string]peer.ID) + + dht.kMapLk.RLock() + for k, v := range dht.keyToPeerMap { + newMap[k] = v + } + dht.kMapLk.RUnlock() + return newMap +} + +func (dht *FullRT) Ready() bool { + dht.rtLk.RLock() + lastCrawlTime := dht.lastCrawlTime + dht.rtLk.RUnlock() + + if time.Since(lastCrawlTime) > dht.crawlerInterval { + return false + } + + // TODO: This function needs to be better defined. Perhaps based on going through the peer map and seeing when the + // last time we were connected to any of them was. + dht.peerAddrsLk.RLock() + rtSize := len(dht.keyToPeerMap) + dht.peerAddrsLk.RUnlock() + + return rtSize > len(dht.bootstrapPeers)+1 +} + +func (dht *FullRT) Host() host.Host { + return dht.h +} + +func (dht *FullRT) runCrawler(ctx context.Context) { + defer dht.wg.Done() + t := time.NewTicker(dht.crawlerInterval) + + m := make(map[peer.ID]*crawlVal) + mxLk := sync.Mutex{} + + initialTrigger := make(chan struct{}, 1) + initialTrigger <- struct{}{} + + for { + select { + case <-t.C: + case <-initialTrigger: + case <-dht.triggerRefresh: + case <-ctx.Done(): + return + } + + var addrs []*peer.AddrInfo + dht.peerAddrsLk.Lock() + for k := range m { + addrs = append(addrs, &peer.AddrInfo{ID: k}) // Addrs: v.addrs + } + + addrs = append(addrs, dht.bootstrapPeers...) + dht.peerAddrsLk.Unlock() + + for k := range m { + delete(m, k) + } + + start := time.Now() + dht.crawler.Run(ctx, addrs, + func(p peer.ID, rtPeers []*peer.AddrInfo) { + conns := dht.h.Network().ConnsToPeer(p) + var addrs []multiaddr.Multiaddr + for _, conn := range conns { + addr := conn.RemoteMultiaddr() + addrs = append(addrs, addr) + } + + if len(addrs) == 0 { + logger.Debugf("no connections to %v after successful query. 
keeping addresses from the peerstore", p) + addrs = dht.h.Peerstore().Addrs(p) + } + + keep := kaddht.PublicRoutingTableFilter(dht, p) + if !keep { + return + } + + mxLk.Lock() + defer mxLk.Unlock() + m[p] = &crawlVal{ + addrs: addrs, + } + }, + func(p peer.ID, err error) {}) + dur := time.Since(start) + logger.Infof("crawl took %v", dur) + + peerAddrs := make(map[peer.ID][]multiaddr.Multiaddr) + kPeerMap := make(map[string]peer.ID) + newRt := trie.New() + for k, v := range m { + v.key = kadkey.KbucketIDToKey(kb.ConvertPeerID(k)) + peerAddrs[k] = v.addrs + kPeerMap[string(v.key)] = k + newRt.Add(v.key) + } + + dht.peerAddrsLk.Lock() + dht.peerAddrs = peerAddrs + dht.peerAddrsLk.Unlock() + + dht.kMapLk.Lock() + dht.keyToPeerMap = kPeerMap + dht.kMapLk.Unlock() + + dht.rtLk.Lock() + dht.rt = newRt + dht.lastCrawlTime = time.Now() + dht.rtLk.Unlock() + } +} + +func (dht *FullRT) Close() error { + dht.cancel() + err := dht.ProviderManager.Process().Close() + dht.wg.Wait() + return err +} + +func (dht *FullRT) Bootstrap(ctx context.Context) error { + return nil +} + +// CheckPeers return (success, total) +func (dht *FullRT) CheckPeers(ctx context.Context, peers ...peer.ID) (int, int) { + var peerAddrs chan interface{} + var total int + if len(peers) == 0 { + dht.peerAddrsLk.RLock() + total = len(dht.peerAddrs) + peerAddrs = make(chan interface{}, total) + for k, v := range dht.peerAddrs { + peerAddrs <- peer.AddrInfo{ + ID: k, + Addrs: v, + } + } + close(peerAddrs) + dht.peerAddrsLk.RUnlock() + } else { + total = len(peers) + peerAddrs = make(chan interface{}, total) + dht.peerAddrsLk.RLock() + for _, p := range peers { + peerAddrs <- peer.AddrInfo{ + ID: p, + Addrs: dht.peerAddrs[p], + } + } + close(peerAddrs) + dht.peerAddrsLk.RUnlock() + } + + var success uint64 + + workers(100, func(i interface{}) { + a := i.(peer.AddrInfo) + dialctx, dialcancel := context.WithTimeout(ctx, time.Second*3) + if err := dht.h.Connect(dialctx, a); err == nil { + atomic.AddUint64(&success, 1) + } + dialcancel() + }, peerAddrs) + return int(success), total +} + +func workers(numWorkers int, fn func(interface{}), inputs <-chan interface{}) { + jobs := make(chan interface{}) + defer close(jobs) + for i := 0; i < numWorkers; i++ { + go func() { + for j := range jobs { + fn(j) + } + }() + } + for i := range inputs { + jobs <- i + } +} + +func (dht *FullRT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, error) { + kbID := kb.ConvertKey(key) + kadKey := kadkey.KbucketIDToKey(kbID) + dht.rtLk.RLock() + closestKeys := kademlia.ClosestN(kadKey, dht.rt, dht.bucketSize) + dht.rtLk.RUnlock() + + peers := make([]peer.ID, 0, len(closestKeys)) + for _, k := range closestKeys { + dht.kMapLk.RLock() + p, ok := dht.keyToPeerMap[string(k)] + if !ok { + logger.Errorf("key not found in map") + } + dht.kMapLk.RUnlock() + dht.peerAddrsLk.RLock() + peerAddrs := dht.peerAddrs[p] + dht.peerAddrsLk.RUnlock() + + dht.h.Peerstore().AddAddrs(p, peerAddrs, peerstore.TempAddrTTL) + peers = append(peers, p) + } + return peers, nil +} + +// PutValue adds value corresponding to given Key. +// This is the top level "Store" operation of the DHT +func (dht *FullRT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) { + if !dht.enableValues { + return routing.ErrNotSupported + } + + logger.Debugw("putting value", "key", internal.LoggableRecordKeyString(key)) + + // don't even allow local users to put bad values. 
+ if err := dht.Validator.Validate(key, value); err != nil { + return err + } + + old, err := dht.getLocal(key) + if err != nil { + // Means something is wrong with the datastore. + return err + } + + // Check if we have an old value that's not the same as the new one. + if old != nil && !bytes.Equal(old.GetValue(), value) { + // Check to see if the new one is better. + i, err := dht.Validator.Select(key, [][]byte{value, old.GetValue()}) + if err != nil { + return err + } + if i != 0 { + return fmt.Errorf("can't replace a newer value with an older value") + } + } + + rec := record.MakePutRecord(key, value) + rec.TimeReceived = u.FormatRFC3339(time.Now()) + err = dht.putLocal(key, rec) + if err != nil { + return err + } + + peers, err := dht.GetClosestPeers(ctx, key) + if err != nil { + return err + } + + successes := dht.execOnMany(ctx, func(ctx context.Context, p peer.ID) error { + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.Value, + ID: p, + }) + err := dht.protoMessenger.PutValue(ctx, p, rec) + return err + }, peers) + + if successes == 0 { + return fmt.Errorf("failed to complete put") + } + + return nil +} + +// RecvdVal stores a value and the peer from which we got the value. +type RecvdVal struct { + Val []byte + From peer.ID +} + +// GetValue searches for the value corresponding to given Key. +func (dht *FullRT) GetValue(ctx context.Context, key string, opts ...routing.Option) (_ []byte, err error) { + if !dht.enableValues { + return nil, routing.ErrNotSupported + } + + // apply defaultQuorum if relevant + var cfg routing.Options + if err := cfg.Apply(opts...); err != nil { + return nil, err + } + opts = append(opts, kaddht.Quorum(internalConfig.GetQuorum(&cfg))) + + responses, err := dht.SearchValue(ctx, key, opts...) + if err != nil { + return nil, err + } + var best []byte + + for r := range responses { + best = r + } + + if ctx.Err() != nil { + return best, ctx.Err() + } + + if best == nil { + return nil, routing.ErrNotFound + } + logger.Debugf("GetValue %v %x", internal.LoggableRecordKeyString(key), best) + return best, nil +} + +// SearchValue searches for the value corresponding to given Key and streams the results. 
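+// If a Quorum routing option is supplied, the search stops early once that
+// many responses have been seen; otherwise it keeps emitting better values
+// until the lookup over the closest peers completes or the context is
+// cancelled.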
+func (dht *FullRT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) { + if !dht.enableValues { + return nil, routing.ErrNotSupported + } + + var cfg routing.Options + if err := cfg.Apply(opts...); err != nil { + return nil, err + } + + responsesNeeded := 0 + if !cfg.Offline { + responsesNeeded = internalConfig.GetQuorum(&cfg) + } + + stopCh := make(chan struct{}) + valCh, lookupRes := dht.getValues(ctx, key, stopCh) + + out := make(chan []byte) + go func() { + defer close(out) + best, peersWithBest, aborted := dht.searchValueQuorum(ctx, key, valCh, stopCh, out, responsesNeeded) + if best == nil || aborted { + return + } + + updatePeers := make([]peer.ID, 0, dht.bucketSize) + select { + case l := <-lookupRes: + if l == nil { + return + } + + for _, p := range l.peers { + if _, ok := peersWithBest[p]; !ok { + updatePeers = append(updatePeers, p) + } + } + case <-ctx.Done(): + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + dht.updatePeerValues(ctx, key, best, updatePeers) + cancel() + }() + + return out, nil +} + +func (dht *FullRT) searchValueQuorum(ctx context.Context, key string, valCh <-chan RecvdVal, stopCh chan struct{}, + out chan<- []byte, nvals int) ([]byte, map[peer.ID]struct{}, bool) { + numResponses := 0 + return dht.processValues(ctx, key, valCh, + func(ctx context.Context, v RecvdVal, better bool) bool { + numResponses++ + if better { + select { + case out <- v.Val: + case <-ctx.Done(): + return false + } + } + + if nvals > 0 && numResponses > nvals { + close(stopCh) + return true + } + return false + }) +} + +// GetValues gets nvals values corresponding to the given key. +func (dht *FullRT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) { + if !dht.enableValues { + return nil, routing.ErrNotSupported + } + + queryCtx, cancel := context.WithCancel(ctx) + defer cancel() + valCh, _ := dht.getValues(queryCtx, key, nil) + + out := make([]RecvdVal, 0, nvals) + for val := range valCh { + out = append(out, val) + if len(out) == nvals { + cancel() + } + } + + return out, ctx.Err() +} + +func (dht *FullRT) processValues(ctx context.Context, key string, vals <-chan RecvdVal, + newVal func(ctx context.Context, v RecvdVal, better bool) bool) (best []byte, peersWithBest map[peer.ID]struct{}, aborted bool) { +loop: + for { + if aborted { + return + } + + select { + case v, ok := <-vals: + if !ok { + break loop + } + + // Select best value + if best != nil { + if bytes.Equal(best, v.Val) { + peersWithBest[v.From] = struct{}{} + aborted = newVal(ctx, v, false) + continue + } + sel, err := dht.Validator.Select(key, [][]byte{best, v.Val}) + if err != nil { + logger.Warnw("failed to select best value", "key", internal.LoggableRecordKeyString(key), "error", err) + continue + } + if sel != 1 { + aborted = newVal(ctx, v, false) + continue + } + } + peersWithBest = make(map[peer.ID]struct{}) + peersWithBest[v.From] = struct{}{} + best = v.Val + aborted = newVal(ctx, v, true) + case <-ctx.Done(): + return + } + } + + return +} + +func (dht *FullRT) updatePeerValues(ctx context.Context, key string, val []byte, peers []peer.ID) { + fixupRec := record.MakePutRecord(key, val) + for _, p := range peers { + go func(p peer.ID) { + //TODO: Is this possible? 
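+ // If the local node is among the peers to correct, write the record
+ // directly to the local datastore instead of messaging ourselves.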
+ if p == dht.h.ID() { + err := dht.putLocal(key, fixupRec) + if err != nil { + logger.Error("Error correcting local dht entry:", err) + } + return + } + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + err := dht.protoMessenger.PutValue(ctx, p, fixupRec) + if err != nil { + logger.Debug("Error correcting DHT entry: ", err) + } + }(p) + } +} + +type lookupWithFollowupResult struct { + peers []peer.ID // the top K not unreachable peers at the end of the query +} + +func (dht *FullRT) getValues(ctx context.Context, key string, stopQuery chan struct{}) (<-chan RecvdVal, <-chan *lookupWithFollowupResult) { + valCh := make(chan RecvdVal, 1) + lookupResCh := make(chan *lookupWithFollowupResult, 1) + + logger.Debugw("finding value", "key", internal.LoggableRecordKeyString(key)) + + if rec, err := dht.getLocal(key); rec != nil && err == nil { + select { + case valCh <- RecvdVal{ + Val: rec.GetValue(), + From: dht.h.ID(), + }: + case <-ctx.Done(): + } + } + peers, err := dht.GetClosestPeers(ctx, key) + if err != nil { + lookupResCh <- &lookupWithFollowupResult{} + close(valCh) + close(lookupResCh) + return valCh, lookupResCh + } + + go func() { + defer close(valCh) + defer close(lookupResCh) + queryFn := func(ctx context.Context, p peer.ID) error { + // For DHT query command + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.SendingQuery, + ID: p, + }) + + rec, peers, err := dht.protoMessenger.GetValue(ctx, p, key) + switch err { + case routing.ErrNotFound: + // in this case, they responded with nothing, + // still send a notification so listeners can know the + // request has completed 'successfully' + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.PeerResponse, + ID: p, + }) + return nil + default: + return err + case nil, internal.ErrInvalidRecord: + // in either of these cases, we want to keep going + } + + // TODO: What should happen if the record is invalid? + // Pre-existing code counted it towards the quorum, but should it? + if rec != nil && rec.GetValue() != nil { + rv := RecvdVal{ + Val: rec.GetValue(), + From: p, + } + + select { + case valCh <- rv: + case <-ctx.Done(): + return ctx.Err() + } + } + + // For DHT query command + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.PeerResponse, + ID: p, + Responses: peers, + }) + + return nil + } + + dht.execOnMany(ctx, queryFn, peers) + lookupResCh <- &lookupWithFollowupResult{peers: peers} + }() + return valCh, lookupResCh +} + +// Provider abstraction for indirect stores. +// Some DHTs store values directly, while an indirect store stores pointers to +// locations of the value, similarly to Coral and Mainline DHT. + +// Provide makes this node announce that it can provide a value for the given key +func (dht *FullRT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) { + if !dht.enableProviders { + return routing.ErrNotSupported + } else if !key.Defined() { + return fmt.Errorf("invalid cid: undefined") + } + keyMH := key.Hash() + logger.Debugw("providing", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH)) + + // add self locally + dht.ProviderManager.AddProvider(ctx, keyMH, dht.h.ID()) + if !brdcst { + return nil + } + + closerCtx := ctx + if deadline, ok := ctx.Deadline(); ok { + now := time.Now() + timeout := deadline.Sub(now) + + if timeout < 0 { + // timed out + return context.DeadlineExceeded + } else if timeout < 10*time.Second { + // Reserve 10% for the final put. 
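+ // e.g. with 5s remaining on the deadline, the closest-peers lookup gets
+ // 4.5s and the final provide messages keep the remaining 500ms.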
+ deadline = deadline.Add(-timeout / 10) + } else { + // Otherwise, reserve a second (we'll already be + // connected so this should be fast). + deadline = deadline.Add(-time.Second) + } + var cancel context.CancelFunc + closerCtx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + } + + var exceededDeadline bool + peers, err := dht.GetClosestPeers(closerCtx, string(keyMH)) + switch err { + case context.DeadlineExceeded: + // If the _inner_ deadline has been exceeded but the _outer_ + // context is still fine, provide the value to the closest peers + // we managed to find, even if they're not the _actual_ closest peers. + if ctx.Err() != nil { + return ctx.Err() + } + exceededDeadline = true + case nil: + default: + return err + } + + successes := dht.execOnMany(ctx, func(ctx context.Context, p peer.ID) error { + err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.h) + return err + }, peers) + + if exceededDeadline { + return context.DeadlineExceeded + } + + if successes == 0 { + return fmt.Errorf("failed to complete provide") + } + + return ctx.Err() +} + +func (dht *FullRT) execOnMany(ctx context.Context, fn func(context.Context, peer.ID) error, peers []peer.ID) int { + putctx, cancel := context.WithCancel(ctx) + defer cancel() + + waitAllCh := make(chan struct{}, len(peers)) + numSuccessfulToWaitFor := int(float64(len(peers)) * dht.waitFrac) + waitSuccessCh := make(chan struct{}, numSuccessfulToWaitFor) + for _, p := range peers { + go func(p peer.ID) { + fnCtx, fnCancel := context.WithTimeout(putctx, dht.timeoutPerOp) + defer fnCancel() + err := fn(fnCtx, p) + if err != nil { + logger.Debug(err) + } else { + waitSuccessCh <- struct{}{} + } + waitAllCh <- struct{}{} + }(p) + } + + numSuccess, numDone := 0, 0 + t := time.NewTimer(time.Hour) + for numDone != len(peers) { + select { + case <-waitAllCh: + numDone++ + case <-waitSuccessCh: + if numSuccess >= numSuccessfulToWaitFor { + t.Reset(time.Millisecond * 500) + } + numSuccess++ + numDone++ + case <-t.C: + cancel() + } + } + return numSuccess +} + +func (dht *FullRT) ProvideMany(ctx context.Context, keys []multihash.Multihash) error { + if !dht.enableProviders { + return routing.ErrNotSupported + } + + // Compute addresses once for all provides + pi := peer.AddrInfo{ + ID: dht.h.ID(), + Addrs: dht.h.Addrs(), + } + pbPeers := dht_pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) + + // TODO: We may want to limit the type of addresses in our provider records + // For example, in a WAN-only DHT prohibit sharing non-WAN addresses (e.g. 
192.168.0.100) + if len(pi.Addrs) < 1 { + return fmt.Errorf("no known addresses for self, cannot put provider") + } + + fn := func(ctx context.Context, k peer.ID) error { + peers, err := dht.GetClosestPeers(ctx, string(k)) + if err != nil { + return err + } + successes := dht.execOnMany(ctx, func(ctx context.Context, p peer.ID) error { + pmes := dht_pb.NewMessage(dht_pb.Message_ADD_PROVIDER, multihash.Multihash(k), 0) + pmes.ProviderPeers = pbPeers + + return dht.messageSender.SendMessage(ctx, p, pmes) + }, peers) + if successes == 0 { + return fmt.Errorf("no successful provides") + } + return nil + } + + keysAsPeerIDs := make([]peer.ID, 0, len(keys)) + for _, k := range keys { + keysAsPeerIDs = append(keysAsPeerIDs, peer.ID(k)) + } + + return dht.bulkMessageSend(ctx, keysAsPeerIDs, fn, true) +} + +func (dht *FullRT) PutMany(ctx context.Context, keys []string, values [][]byte) error { + if !dht.enableValues { + return routing.ErrNotSupported + } + + if len(keys) != len(values) { + return fmt.Errorf("number of keys does not match the number of values") + } + + keysAsPeerIDs := make([]peer.ID, 0, len(keys)) + keyRecMap := make(map[string][]byte) + for i, k := range keys { + keysAsPeerIDs = append(keysAsPeerIDs, peer.ID(k)) + keyRecMap[k] = values[i] + } + + if len(keys) != len(keyRecMap) { + return fmt.Errorf("does not support duplicate keys") + } + + fn := func(ctx context.Context, k peer.ID) error { + peers, err := dht.GetClosestPeers(ctx, string(k)) + if err != nil { + return err + } + successes := dht.execOnMany(ctx, func(ctx context.Context, p peer.ID) error { + keyStr := string(k) + return dht.protoMessenger.PutValue(ctx, p, record.MakePutRecord(keyStr, keyRecMap[keyStr])) + }, peers) + if successes == 0 { + return fmt.Errorf("no successful puts") + } + return nil + } + + return dht.bulkMessageSend(ctx, keysAsPeerIDs, fn, false) +} + +func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func(ctx context.Context, k peer.ID) error, isProvRec bool) error { + if len(keys) == 0 { + return nil + } + + sortedKeys := kb.SortClosestPeers(keys, kb.ID(make([]byte, 32))) + + var numSends uint64 = 0 + var numSendsSuccessful uint64 = 0 + + wg := sync.WaitGroup{} + wg.Add(dht.bulkSendParallelism) + chunkSize := len(sortedKeys) / dht.bulkSendParallelism + onePctKeys := uint64(len(sortedKeys)) / 100 + for i := 0; i < dht.bulkSendParallelism; i++ { + var chunk []peer.ID + end := (i + 1) * chunkSize + if end > len(sortedKeys) { + chunk = sortedKeys[i*chunkSize:] + } else { + chunk = sortedKeys[i*chunkSize : end] + } + + go func() { + defer wg.Done() + for _, key := range chunk { + sendsSoFar := atomic.AddUint64(&numSends, 1) + if sendsSoFar%onePctKeys == 0 { + logger.Infof("bulk sending goroutine: %.1f%% done - %d/%d done", 100*float64(sendsSoFar)/float64(len(sortedKeys)), sendsSoFar, len(sortedKeys)) + } + if err := fn(ctx, key); err != nil { + var l interface{} + if isProvRec { + l = internal.LoggableProviderRecordBytes(key) + } else { + l = internal.LoggableRecordKeyString(key) + } + logger.Infof("failed to complete bulk sending of key :%v. %v", l, err) + } else { + atomic.AddUint64(&numSendsSuccessful, 1) + } + } + }() + } + wg.Wait() + + if numSendsSuccessful == 0 { + return fmt.Errorf("failed to complete bulk sending") + } + + logger.Infof("bulk send complete: %d of %d successful", numSendsSuccessful, len(keys)) + + return nil +} + +// FindProviders searches until the context expires. 
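(FindProviders, which the comment above introduces, continues immediately after this note.) A hypothetical caller-side sketch of the two bulk entry points defined above — not part of this change; frt, mhs, keys and vals stand in for values the caller would supply:

	// ProvideMany announces every multihash in the batch in a single bulk pass;
	// PutMany stores the records, where keys and vals must line up by index and
	// contain no duplicate keys.
	func publishBatch(ctx context.Context, frt *fullrt.FullRT,
		mhs []multihash.Multihash, keys []string, vals [][]byte) error {
		if err := frt.ProvideMany(ctx, mhs); err != nil {
			return err
		}
		return frt.PutMany(ctx, keys, vals)
	}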
+func (dht *FullRT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) { + if !dht.enableProviders { + return nil, routing.ErrNotSupported + } else if !c.Defined() { + return nil, fmt.Errorf("invalid cid: undefined") + } + + var providers []peer.AddrInfo + for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) { + providers = append(providers, p) + } + return providers, nil +} + +// FindProvidersAsync is the same thing as FindProviders, but returns a channel. +// Peers will be returned on the channel as soon as they are found, even before +// the search query completes. If count is zero then the query will run until it +// completes. Note: not reading from the returned channel may block the query +// from progressing. +func (dht *FullRT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo { + if !dht.enableProviders || !key.Defined() { + peerOut := make(chan peer.AddrInfo) + close(peerOut) + return peerOut + } + + chSize := count + if count == 0 { + chSize = 1 + } + peerOut := make(chan peer.AddrInfo, chSize) + + keyMH := key.Hash() + + logger.Debugw("finding providers", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH)) + go dht.findProvidersAsyncRoutine(ctx, keyMH, count, peerOut) + return peerOut +} + +func (dht *FullRT) findProvidersAsyncRoutine(ctx context.Context, key multihash.Multihash, count int, peerOut chan peer.AddrInfo) { + defer close(peerOut) + + findAll := count == 0 + var ps *peer.Set + if findAll { + ps = peer.NewSet() + } else { + ps = peer.NewLimitedSet(count) + } + + provs := dht.ProviderManager.GetProviders(ctx, key) + for _, p := range provs { + // NOTE: Assuming that this list of peers is unique + if ps.TryAdd(p) { + pi := dht.h.Peerstore().PeerInfo(p) + select { + case peerOut <- pi: + case <-ctx.Done(): + return + } + } + + // If we have enough peers locally, don't bother with remote RPC + // TODO: is this a DOS vector? + if !findAll && ps.Size() >= count { + return + } + } + + peers, err := dht.GetClosestPeers(ctx, string(key)) + if err != nil { + return + } + + queryctx, cancelquery := context.WithCancel(ctx) + defer cancelquery() + + fn := func(ctx context.Context, p peer.ID) error { + // For DHT query command + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.SendingQuery, + ID: p, + }) + + provs, closest, err := dht.protoMessenger.GetProviders(ctx, p, key) + if err != nil { + return err + } + + logger.Debugf("%d provider entries", len(provs)) + + // Add unique providers from request, up to 'count' + for _, prov := range provs { + dht.maybeAddAddrs(prov.ID, prov.Addrs, peerstore.TempAddrTTL) + logger.Debugf("got provider: %s", prov) + if ps.TryAdd(prov.ID) { + logger.Debugf("using provider: %s", prov) + select { + case peerOut <- *prov: + case <-ctx.Done(): + logger.Debug("context timed out sending more providers") + return ctx.Err() + } + } + if !findAll && ps.Size() >= count { + logger.Debugf("got enough providers (%d/%d)", ps.Size(), count) + cancelquery() + return nil + } + } + + // Give closer peers back to the query to be queried + logger.Debugf("got closer peers: %d %s", len(closest), closest) + + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.PeerResponse, + ID: p, + Responses: closest, + }) + return nil + } + + dht.execOnMany(queryctx, fn, peers) +} + +// FindPeer searches for a peer with given ID. 
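(FindPeer, which the comment above introduces, follows right after this note.) A hypothetical consumer of the streaming provider search defined above — not part of this change; frt and c are placeholders:

	// Ask for at most five providers; the channel is closed once the query
	// finishes, and not draining it can stall the query.
	func printSomeProviders(ctx context.Context, frt *fullrt.FullRT, c cid.Cid) {
		for ai := range frt.FindProvidersAsync(ctx, c, 5) {
			fmt.Println("provider:", ai.ID)
		}
	}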
+func (dht *FullRT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo, err error) { + if err := id.Validate(); err != nil { + return peer.AddrInfo{}, err + } + + logger.Debugw("finding peer", "peer", id) + + // Check if were already connected to them + if pi := dht.FindLocal(id); pi.ID != "" { + return pi, nil + } + + peers, err := dht.GetClosestPeers(ctx, string(id)) + if err != nil { + return peer.AddrInfo{}, err + } + + queryctx, cancelquery := context.WithCancel(ctx) + defer cancelquery() + + addrsCh := make(chan *peer.AddrInfo, 1) + newAddrs := make([]multiaddr.Multiaddr, 0) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + addrsSoFar := make(map[multiaddr.Multiaddr]struct{}) + for { + select { + case ai, ok := <-addrsCh: + if !ok { + return + } + + for _, a := range ai.Addrs { + _, found := addrsSoFar[a] + if !found { + newAddrs = append(newAddrs, a) + addrsSoFar[a] = struct{}{} + } + } + case <-ctx.Done(): + return + } + } + }() + + fn := func(ctx context.Context, p peer.ID) error { + // For DHT query command + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.SendingQuery, + ID: p, + }) + + peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, id) + if err != nil { + logger.Debugf("error getting closer peers: %s", err) + return err + } + + // For DHT query command + routing.PublishQueryEvent(ctx, &routing.QueryEvent{ + Type: routing.PeerResponse, + ID: p, + Responses: peers, + }) + + for _, a := range peers { + if a.ID == id { + select { + case addrsCh <- a: + case <-ctx.Done(): + return ctx.Err() + } + return nil + } + } + return nil + } + + dht.execOnMany(queryctx, fn, peers) + + close(addrsCh) + wg.Wait() + + if len(newAddrs) > 0 { + connctx, cancelconn := context.WithTimeout(ctx, time.Second*5) + defer cancelconn() + _ = dht.h.Connect(connctx, peer.AddrInfo{ + ID: id, + Addrs: newAddrs, + }) + } + + // Return peer information if we tried to dial the peer during the query or we are (or recently were) connected + // to the peer. + connectedness := dht.h.Network().Connectedness(id) + if connectedness == network.Connected || connectedness == network.CanConnect { + return dht.h.Peerstore().PeerInfo(id), nil + } + + return peer.AddrInfo{}, routing.ErrNotFound +} + +var _ routing.Routing = (*FullRT)(nil) + +// getLocal attempts to retrieve the value from the datastore. +// +// returns nil, nil when either nothing is found or the value found doesn't properly validate. +// returns nil, some_error when there's a *datastore* error (i.e., something goes very wrong) +func (dht *FullRT) getLocal(key string) (*recpb.Record, error) { + logger.Debugw("finding value in datastore", "key", internal.LoggableRecordKeyString(key)) + + rec, err := dht.getRecordFromDatastore(mkDsKey(key)) + if err != nil { + logger.Warnw("get local failed", "key", internal.LoggableRecordKeyString(key), "error", err) + return nil, err + } + + // Double check the key. Can't hurt. 
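(The key double-check mentioned in the comment above continues immediately below.) A hypothetical caller-side sketch of FindPeer as defined above — not part of this change:

	// Resolve a peer's addresses through the accelerated client; FindPeer
	// returns routing.ErrNotFound when the peer was neither dialable during
	// the query nor (recently) connected.
	func lookupPeer(ctx context.Context, frt *fullrt.FullRT, id peer.ID) {
		ai, err := frt.FindPeer(ctx, id)
		if err != nil {
			return
		}
		fmt.Println("addrs for", id, ":", ai.Addrs)
	}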
+ if rec != nil && string(rec.GetKey()) != key { + logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", internal.LoggableRecordKeyString(key), "got", rec.GetKey()) + return nil, nil + + } + return rec, nil +} + +// putLocal stores the key value pair in the datastore +func (dht *FullRT) putLocal(key string, rec *recpb.Record) error { + data, err := proto.Marshal(rec) + if err != nil { + logger.Warnw("failed to put marshal record for local put", "error", err, "key", internal.LoggableRecordKeyString(key)) + return err + } + + return dht.datastore.Put(mkDsKey(key), data) +} + +func mkDsKey(s string) ds.Key { + return ds.NewKey(base32.RawStdEncoding.EncodeToString([]byte(s))) +} + +// returns nil, nil when either nothing is found or the value found doesn't properly validate. +// returns nil, some_error when there's a *datastore* error (i.e., something goes very wrong) +func (dht *FullRT) getRecordFromDatastore(dskey ds.Key) (*recpb.Record, error) { + buf, err := dht.datastore.Get(dskey) + if err == ds.ErrNotFound { + return nil, nil + } + if err != nil { + logger.Errorw("error retrieving record from datastore", "key", dskey, "error", err) + return nil, err + } + rec := new(recpb.Record) + err = proto.Unmarshal(buf, rec) + if err != nil { + // Bad data in datastore, log it but don't return an error, we'll just overwrite it + logger.Errorw("failed to unmarshal record from datastore", "key", dskey, "error", err) + return nil, nil + } + + err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue()) + if err != nil { + // Invalid record in datastore, probably expired but don't return an error, + // we'll just overwrite it + logger.Debugw("local record verify failed", "key", rec.GetKey(), "error", err) + return nil, nil + } + + return rec, nil +} + +// FindLocal looks for a peer with a given ID connected to this dht and returns the peer and the table it was found in. +func (dht *FullRT) FindLocal(id peer.ID) peer.AddrInfo { + switch dht.h.Network().Connectedness(id) { + case network.Connected, network.CanConnect: + return dht.h.Peerstore().PeerInfo(id) + default: + return peer.AddrInfo{} + } +} + +func (dht *FullRT) maybeAddAddrs(p peer.ID, addrs []multiaddr.Multiaddr, ttl time.Duration) { + // Don't add addresses for self or our connected peers. We have better ones. + if p == dht.h.ID() || dht.h.Network().Connectedness(p) == network.Connected { + return + } + dht.h.Peerstore().AddAddrs(p, addrs, ttl) +} diff --git a/fullrt/options.go b/fullrt/options.go new file mode 100644 index 000000000..cd0f9ba59 --- /dev/null +++ b/fullrt/options.go @@ -0,0 +1,28 @@ +package fullrt + +import ( + "fmt" + kaddht "github.com/libp2p/go-libp2p-kad-dht" +) + +type config struct { + dhtOpts []kaddht.Option +} + +func (cfg *config) apply(opts ...Option) error { + for i, o := range opts { + if err := o(cfg); err != nil { + return fmt.Errorf("fullrt dht option %d failed: %w", i, err) + } + } + return nil +} + +type Option func(opt *config) error + +func DHTOption(opts ...kaddht.Option) Option { + return func(c *config) error { + c.dhtOpts = append(c.dhtOpts, opts...) 
+ return nil + } +} diff --git a/go.mod b/go.mod index 2ba606e08..bf25c800b 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/libp2p/go-libp2p-routing-helpers v0.2.3 github.com/libp2p/go-libp2p-swarm v0.4.0 github.com/libp2p/go-libp2p-testing v0.4.0 + github.com/libp2p/go-libp2p-xor v0.0.0-20200501025846-71e284145d58 github.com/libp2p/go-msgio v0.0.6 github.com/libp2p/go-netroute v0.1.6 github.com/multiformats/go-base32 v0.0.3 diff --git a/go.sum b/go.sum index 298e40e2a..b54217d82 100644 --- a/go.sum +++ b/go.sum @@ -250,6 +250,7 @@ github.com/libp2p/go-libp2p-discovery v0.2.0/go.mod h1:s4VGaxYMbw4+4+tsoQTqh7wfx github.com/libp2p/go-libp2p-discovery v0.3.0/go.mod h1:o03drFnz9BVAZdzC/QUQ+NeQOu38Fu7LJGEOK2gQltw= github.com/libp2p/go-libp2p-discovery v0.5.0 h1:Qfl+e5+lfDgwdrXdu4YNCWyEo3fWuP+WgN9mN0iWviQ= github.com/libp2p/go-libp2p-discovery v0.5.0/go.mod h1:+srtPIU9gDaBNu//UHvcdliKBIcr4SfDcm0/PfPJLug= +github.com/libp2p/go-libp2p-kbucket v0.3.1/go.mod h1:oyjT5O7tS9CQurok++ERgc46YLwEpuGoFq9ubvoUOio= github.com/libp2p/go-libp2p-kbucket v0.4.7 h1:spZAcgxifvFZHBD8tErvppbnNiKA5uokDu3CV7axu70= github.com/libp2p/go-libp2p-kbucket v0.4.7/go.mod h1:XyVo99AfQH0foSf176k4jY1xUJ2+jUJIZCSDm7r2YKk= github.com/libp2p/go-libp2p-loggables v0.1.0 h1:h3w8QFfCt2UJl/0/NW4K829HX/0S4KD31PQ7m8UXXO8= @@ -312,6 +313,8 @@ github.com/libp2p/go-libp2p-transport-upgrader v0.2.0/go.mod h1:mQcrHj4asu6ArfSo github.com/libp2p/go-libp2p-transport-upgrader v0.3.0/go.mod h1:i+SKzbRnvXdVbU3D1dwydnTmKRPXiAR/fyvi1dXuL4o= github.com/libp2p/go-libp2p-transport-upgrader v0.4.0 h1:xwj4h3hJdBrxqMOyMUjwscjoVst0AASTsKtZiTChoHI= github.com/libp2p/go-libp2p-transport-upgrader v0.4.0/go.mod h1:J4ko0ObtZSmgn5BX5AmegP+dK3CSnU2lMCKsSq/EY0s= +github.com/libp2p/go-libp2p-xor v0.0.0-20200501025846-71e284145d58 h1:GcTNu27BMpOTtMnQqun03+kbtHA1qTxJ/J8cZRRYu2k= +github.com/libp2p/go-libp2p-xor v0.0.0-20200501025846-71e284145d58/go.mod h1:AYjOiqJIdcmI4SXE2ouKQuFrUbE5myv8txWaB2pl4TI= github.com/libp2p/go-libp2p-yamux v0.2.0/go.mod h1:Db2gU+XfLpm6E4rG5uGCFX6uXA8MEXOxFcRoXUODaK8= github.com/libp2p/go-libp2p-yamux v0.2.2/go.mod h1:lIohaR0pT6mOt0AZ0L2dFze9hds9Req3OfS+B+dv4qw= github.com/libp2p/go-libp2p-yamux v0.2.5/go.mod h1:Zpgj6arbyQrmZ3wxSZxfBmbdnWtbZ48OpsfmQVTErwA= @@ -510,6 +513,7 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/wangjia184/sortedset v0.0.0-20160527075905-f5d03557ba30/go.mod h1:YkocrP2K2tcw938x9gCOmT5G5eCD6jsTz0SZuyAqwIE= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1/go.mod h1:8UvriyWtv5Q5EOgjHaSseUEdkQfvwFv1I/In/O2M9gc= github.com/whyrusleeping/go-logging v0.0.0-20170515211332-0457bb6b88fc/go.mod h1:bopw91TMyo8J3tvftk8xmU2kPmlrt4nScJQZU2hE5EM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 000000000..8e805688c --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,161 @@ +package config + +import ( + "fmt" + "time" + + ds "github.com/ipfs/go-datastore" + dssync "github.com/ipfs/go-datastore/sync" + "github.com/ipfs/go-ipns" + "github.com/libp2p/go-libp2p-core/host" + 
"github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-kad-dht/providers" + "github.com/libp2p/go-libp2p-kbucket/peerdiversity" + record "github.com/libp2p/go-libp2p-record" +) + +// DefaultPrefix is the application specific prefix attached to all DHT protocols by default. +const DefaultPrefix protocol.ID = "/ipfs" + +const defaultBucketSize = 20 + +// ModeOpt describes what mode the dht should operate in +type ModeOpt int + +// QueryFilterFunc is a filter applied when considering peers to dial when querying +type QueryFilterFunc func(dht interface{}, ai peer.AddrInfo) bool + +// RouteTableFilterFunc is a filter applied when considering connections to keep in +// the local route table. +type RouteTableFilterFunc func(dht interface{}, p peer.ID) bool + +// Config is a structure containing all the options that can be used when constructing a DHT. +type Config struct { + Datastore ds.Batching + Validator record.Validator + ValidatorChanged bool // if true implies that the validator has been changed and that Defaults should not be used + Mode ModeOpt + ProtocolPrefix protocol.ID + V1ProtocolOverride protocol.ID + BucketSize int + Concurrency int + Resiliency int + MaxRecordAge time.Duration + EnableProviders bool + EnableValues bool + ProvidersOptions []providers.Option + QueryPeerFilter QueryFilterFunc + + RoutingTable struct { + RefreshQueryTimeout time.Duration + RefreshInterval time.Duration + AutoRefresh bool + LatencyTolerance time.Duration + CheckInterval time.Duration + PeerFilter RouteTableFilterFunc + DiversityFilter peerdiversity.PeerIPGroupFilter + } + + BootstrapPeers []peer.AddrInfo + + // test specific Config options + DisableFixLowPeers bool + TestAddressUpdateProcessing bool +} + +func EmptyQueryFilter(_ interface{}, ai peer.AddrInfo) bool { return true } +func EmptyRTFilter(_ interface{}, p peer.ID) bool { return true } + +// Apply applies the given options to this Option +func (c *Config) Apply(opts ...Option) error { + for i, opt := range opts { + if err := opt(c); err != nil { + return fmt.Errorf("dht option %d failed: %s", i, err) + } + } + return nil +} + +// ApplyFallbacks sets default values that could not be applied during config creation since they are dependent +// on other configuration parameters (e.g. optA is by default 2x optB) and/or on the Host +func (c *Config) ApplyFallbacks(h host.Host) error { + if !c.ValidatorChanged { + nsval, ok := c.Validator.(record.NamespacedValidator) + if ok { + if _, pkFound := nsval["pk"]; !pkFound { + nsval["pk"] = record.PublicKeyValidator{} + } + if _, ipnsFound := nsval["ipns"]; !ipnsFound { + nsval["ipns"] = ipns.Validator{KeyBook: h.Peerstore()} + } + } else { + return fmt.Errorf("the default Validator was changed without being marked as changed") + } + } + return nil +} + +// Option DHT option type. +type Option func(*Config) error + +// Defaults are the default DHT options. This option will be automatically +// prepended to any options you pass to the DHT constructor. 
+var Defaults = func(o *Config) error { + o.Validator = record.NamespacedValidator{} + o.Datastore = dssync.MutexWrap(ds.NewMapDatastore()) + o.ProtocolPrefix = DefaultPrefix + o.EnableProviders = true + o.EnableValues = true + o.QueryPeerFilter = EmptyQueryFilter + + o.RoutingTable.LatencyTolerance = time.Minute + o.RoutingTable.RefreshQueryTimeout = 1 * time.Minute + o.RoutingTable.RefreshInterval = 10 * time.Minute + o.RoutingTable.AutoRefresh = true + o.RoutingTable.PeerFilter = EmptyRTFilter + o.MaxRecordAge = time.Hour * 36 + + o.BucketSize = defaultBucketSize + o.Concurrency = 10 + o.Resiliency = 3 + + return nil +} + +func (c *Config) Validate() error { + if c.ProtocolPrefix != DefaultPrefix { + return nil + } + if c.BucketSize != defaultBucketSize { + return fmt.Errorf("protocol prefix %s must use bucket size %d", DefaultPrefix, defaultBucketSize) + } + if !c.EnableProviders { + return fmt.Errorf("protocol prefix %s must have providers enabled", DefaultPrefix) + } + if !c.EnableValues { + return fmt.Errorf("protocol prefix %s must have values enabled", DefaultPrefix) + } + + nsval, isNSVal := c.Validator.(record.NamespacedValidator) + if !isNSVal { + return fmt.Errorf("protocol prefix %s must use a namespaced Validator", DefaultPrefix) + } + + if len(nsval) != 2 { + return fmt.Errorf("protocol prefix %s must have exactly two namespaced validators - /pk and /ipns", DefaultPrefix) + } + + if pkVal, pkValFound := nsval["pk"]; !pkValFound { + return fmt.Errorf("protocol prefix %s must support the /pk namespaced Validator", DefaultPrefix) + } else if _, ok := pkVal.(record.PublicKeyValidator); !ok { + return fmt.Errorf("protocol prefix %s must use the record.PublicKeyValidator for the /pk namespace", DefaultPrefix) + } + + if ipnsVal, ipnsValFound := nsval["ipns"]; !ipnsValFound { + return fmt.Errorf("protocol prefix %s must support the /ipns namespaced Validator", DefaultPrefix) + } else if _, ok := ipnsVal.(ipns.Validator); !ok { + return fmt.Errorf("protocol prefix %s must use ipns.Validator for the /ipns namespace", DefaultPrefix) + } + return nil +} diff --git a/internal/config/quorum.go b/internal/config/quorum.go new file mode 100644 index 000000000..ce5fba2a8 --- /dev/null +++ b/internal/config/quorum.go @@ -0,0 +1,16 @@ +package config + +import "github.com/libp2p/go-libp2p-core/routing" + +type QuorumOptionKey struct{} + +const defaultQuorum = 0 + +// GetQuorum defaults to 0 if no option is found +func GetQuorum(opts *routing.Options) int { + responsesNeeded, ok := opts.Other[QuorumOptionKey{}].(int) + if !ok { + responsesNeeded = defaultQuorum + } + return responsesNeeded +} diff --git a/internal/ctx_mutex.go b/internal/ctx_mutex.go new file mode 100644 index 000000000..4e923f6e0 --- /dev/null +++ b/internal/ctx_mutex.go @@ -0,0 +1,28 @@ +package internal + +import ( + "context" +) + +type CtxMutex chan struct{} + +func NewCtxMutex() CtxMutex { + return make(CtxMutex, 1) +} + +func (m CtxMutex) Lock(ctx context.Context) error { + select { + case m <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (m CtxMutex) Unlock() { + select { + case <-m: + default: + panic("not locked") + } +} diff --git a/message_manager.go b/internal/net/message_manager.go similarity index 82% rename from message_manager.go rename to internal/net/message_manager.go index 8cc3e22e3..627c47a5b 100644 --- a/message_manager.go +++ b/internal/net/message_manager.go @@ -1,8 +1,10 @@ -package dht +package net import ( + "bufio" "context" "fmt" + "io" "sync" "time" @@ -11,14 
+13,25 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" - "github.com/libp2p/go-libp2p-kad-dht/metrics" - pb "github.com/libp2p/go-libp2p-kad-dht/pb" - + logging "github.com/ipfs/go-log" "github.com/libp2p/go-msgio" + "github.com/libp2p/go-msgio/protoio" + "go.opencensus.io/stats" "go.opencensus.io/tag" + + "github.com/libp2p/go-libp2p-kad-dht/internal" + "github.com/libp2p/go-libp2p-kad-dht/metrics" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" ) +var dhtReadMessageTimeout = 10 * time.Second + +// ErrReadTimeout is an error that occurs when no message is read within the timeout period. +var ErrReadTimeout = fmt.Errorf("timed out reading response") + +var logger = logging.Logger("dht") + // messageSenderImpl is responsible for sending requests and messages to peers efficiently, including reuse of streams. // It also tracks metrics for sent requests and messages. type messageSenderImpl struct { @@ -28,7 +41,15 @@ type messageSenderImpl struct { protocols []protocol.ID } -func (m *messageSenderImpl) streamDisconnect(ctx context.Context, p peer.ID) { +func NewMessageSenderImpl(h host.Host, protos []protocol.ID) pb.MessageSender { + return &messageSenderImpl{ + host: h, + strmap: make(map[peer.ID]*peerMessageSender), + protocols: protos, + } +} + +func (m *messageSenderImpl) OnDisconnect(ctx context.Context, p peer.ID) { m.smlk.Lock() defer m.smlk.Unlock() ms, ok := m.strmap[p] @@ -120,7 +141,7 @@ func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID) m.smlk.Unlock() return ms, nil } - ms = &peerMessageSender{p: p, m: m, lk: newCtxMutex()} + ms = &peerMessageSender{p: p, m: m, lk: internal.NewCtxMutex()} m.strmap[p] = ms m.smlk.Unlock() @@ -149,7 +170,7 @@ func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID) type peerMessageSender struct { s network.Stream r msgio.ReadCloser - lk ctxMutex + lk internal.CtxMutex p peer.ID m *messageSenderImpl @@ -297,7 +318,7 @@ func (ms *peerMessageSender) SendRequest(ctx context.Context, pmes *pb.Message) } func (ms *peerMessageSender) writeMsg(pmes *pb.Message) error { - return writeMsg(ms.s, pmes) + return WriteMsg(ms.s, pmes) } func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { @@ -325,3 +346,37 @@ func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) er return ErrReadTimeout } } + +// The Protobuf writer performs multiple small writes when writing a message. +// We need to buffer those writes, to make sure that we're not sending a new +// packet for every single write. 
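(The buffered delimited writer described in the comment above follows right after this note.) A minimal sketch of the internal.CtxMutex pattern the message sender now relies on — usable only within this repository, since the package is internal; not part of this change:

	// Lock honours the caller's context; Unlock panics if the mutex is not held.
	func withLock(ctx context.Context, lk internal.CtxMutex, fn func() error) error {
		if err := lk.Lock(ctx); err != nil {
			return err // context cancelled while waiting for the lock
		}
		defer lk.Unlock()
		return fn()
	}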
+type bufferedDelimitedWriter struct { + *bufio.Writer + protoio.WriteCloser +} + +var writerPool = sync.Pool{ + New: func() interface{} { + w := bufio.NewWriter(nil) + return &bufferedDelimitedWriter{ + Writer: w, + WriteCloser: protoio.NewDelimitedWriter(w), + } + }, +} + +func WriteMsg(w io.Writer, mes *pb.Message) error { + bw := writerPool.Get().(*bufferedDelimitedWriter) + bw.Reset(w) + err := bw.WriteMsg(mes) + if err == nil { + err = bw.Flush() + } + bw.Reset(nil) + writerPool.Put(bw) + return err +} + +func (w *bufferedDelimitedWriter) Flush() error { + return w.Writer.Flush() +} diff --git a/internal/net/message_manager_test.go b/internal/net/message_manager_test.go new file mode 100644 index 000000000..3bd6d2aea --- /dev/null +++ b/internal/net/message_manager_test.go @@ -0,0 +1,36 @@ +package net + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" +) + +func TestInvalidMessageSenderTracking(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + foo := peer.ID("asdasd") + + h := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + + msgSender := NewMessageSenderImpl(h, []protocol.ID{"/test/kad/1.0.0"}).(*messageSenderImpl) + + _, err := msgSender.messageSenderForPeer(ctx, foo) + if err == nil { + t.Fatal("that shouldnt have succeeded") + } + + msgSender.smlk.Lock() + mscnt := len(msgSender.strmap) + msgSender.smlk.Unlock() + + if mscnt > 0 { + t.Fatal("should have no message senders in map") + } +} diff --git a/lookup.go b/lookup.go index dff8bb244..88695dc4a 100644 --- a/lookup.go +++ b/lookup.go @@ -16,7 +16,7 @@ import ( // // If the context is canceled, this function will return the context error along // with the closest K peers it has found so far. 
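(The GetClosestPeers signature change documented by the surrounding lookup.go hunk continues below.) A hypothetical use of the now-exported buffered write path defined above — not part of this change; s and pmes are placeholders, and resetting the stream on error is an assumed cleanup policy:

	// WriteMsg coalesces protoio's small delimited writes into one flush per
	// message, reusing buffers from the pool above.
	func sendOne(s network.Stream, pmes *pb.Message) error {
		if err := net.WriteMsg(s, pmes); err != nil {
			_ = s.Reset()
			return err
		}
		return nil
	}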
-func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan peer.ID, error) { +func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, error) { if key == "" { return nil, fmt.Errorf("can't lookup empty key") } @@ -51,17 +51,10 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan pee return nil, err } - out := make(chan peer.ID, dht.bucketSize) - defer close(out) - - for _, p := range lookupRes.peers { - out <- p - } - if ctx.Err() == nil && lookupRes.completed { // refresh the cpl for this key as the query was successful dht.routingTable.ResetCplRefreshedAtForID(kb.ConvertKey(key), time.Now()) } - return out, ctx.Err() + return lookupRes.peers, ctx.Err() } diff --git a/routing.go b/routing.go index 6d31e0c5c..7793bebb4 100644 --- a/routing.go +++ b/routing.go @@ -15,6 +15,7 @@ import ( "github.com/ipfs/go-cid" u "github.com/ipfs/go-ipfs-util" "github.com/libp2p/go-libp2p-kad-dht/internal" + internalConfig "github.com/libp2p/go-libp2p-kad-dht/internal/config" "github.com/libp2p/go-libp2p-kad-dht/qpeerset" kb "github.com/libp2p/go-libp2p-kbucket" record "github.com/libp2p/go-libp2p-record" @@ -64,13 +65,13 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts return err } - pchan, err := dht.GetClosestPeers(ctx, key) + peers, err := dht.GetClosestPeers(ctx, key) if err != nil { return err } wg := sync.WaitGroup{} - for p := range pchan { + for _, p := range peers { wg.Add(1) go func(p peer.ID) { ctx, cancel := context.WithCancel(ctx) @@ -109,7 +110,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op if err := cfg.Apply(opts...); err != nil { return nil, err } - opts = append(opts, Quorum(getQuorum(&cfg, defaultQuorum))) + opts = append(opts, Quorum(internalConfig.GetQuorum(&cfg))) responses, err := dht.SearchValue(ctx, key, opts...) if err != nil { @@ -145,7 +146,7 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing responsesNeeded := 0 if !cfg.Offline { - responsesNeeded = getQuorum(&cfg, defaultQuorum) + responsesNeeded = internalConfig.GetQuorum(&cfg) } stopCh := make(chan struct{}) @@ -445,7 +446,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err } wg := sync.WaitGroup{} - for p := range peers { + for _, p := range peers { wg.Add(1) go func(p peer.ID) { defer wg.Done() diff --git a/routing_options.go b/routing_options.go index a1e5935b9..7352c098b 100644 --- a/routing_options.go +++ b/routing_options.go @@ -1,10 +1,9 @@ package dht -import "github.com/libp2p/go-libp2p-core/routing" - -type quorumOptionKey struct{} - -const defaultQuorum = 0 +import ( + "github.com/libp2p/go-libp2p-core/routing" + internalConfig "github.com/libp2p/go-libp2p-kad-dht/internal/config" +) // Quorum is a DHT option that tells the DHT how many peers it needs to get // values from before returning the best one. 
Zero means the DHT query @@ -16,15 +15,7 @@ func Quorum(n int) routing.Option { if opts.Other == nil { opts.Other = make(map[interface{}]interface{}, 1) } - opts.Other[quorumOptionKey{}] = n + opts.Other[internalConfig.QuorumOptionKey{}] = n return nil } } - -func getQuorum(opts *routing.Options, ndefault int) int { - responsesNeeded, ok := opts.Other[quorumOptionKey{}].(int) - if !ok { - responsesNeeded = ndefault - } - return responsesNeeded -} diff --git a/subscriber_notifee.go b/subscriber_notifee.go index 7cc9018f7..00ff4ba03 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -1,6 +1,7 @@ package dht import ( + "context" "fmt" "github.com/libp2p/go-libp2p-core/event" @@ -151,11 +152,21 @@ func (dht *IpfsDHT) validRTPeer(p peer.ID) (bool, error) { return false, err } - return dht.routingTablePeerFilter == nil || dht.routingTablePeerFilter(dht, dht.Host().Network().ConnsToPeer(p)), nil + return dht.routingTablePeerFilter == nil || dht.routingTablePeerFilter(dht, p), nil +} + +type disconnector interface { + OnDisconnect(ctx context.Context, p peer.ID) } func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { dht := nn.dht + + ms, ok := dht.msgSender.(disconnector) + if !ok { + return + } + select { case <-dht.Process().Closing(): return @@ -173,7 +184,7 @@ func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { return } - dht.msgSender.streamDisconnect(dht.Context(), p) + ms.OnDisconnect(dht.Context(), p) } func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {}
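A minimal sketch of the Quorum round-trip through the relocated option key, written as if inside package dht with the internalConfig alias used above — not part of this change:

	// Quorum stores n under internalConfig.QuorumOptionKey{}; GetQuorum reads
	// it back, defaulting to 0 when the option was never applied.
	func quorumRoundTrip() int {
		var cfg routing.Options
		_ = cfg.Apply(Quorum(3))
		return internalConfig.GetQuorum(&cfg)
	}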