From 87c435512c8c98d4faee76c64a2042313db2a02e Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 21 Jun 2024 14:55:22 +0530 Subject: [PATCH] autonatv2: implement autonatv2 spec (#2469) --- config/config.go | 99 ++- core/network/network.go | 3 + defaults.go | 25 + libp2p_test.go | 6 + options.go | 26 + p2p/host/basic/basic_host.go | 21 + p2p/net/mock/mock_peernet.go | 4 + p2p/net/swarm/black_hole_detector.go | 200 +++--- p2p/net/swarm/black_hole_detector_test.go | 98 ++- p2p/net/swarm/swarm.go | 41 +- p2p/net/swarm/swarm_dial.go | 5 + p2p/net/swarm/swarm_dial_test.go | 2 +- p2p/net/swarm/swarm_metrics.go | 22 +- p2p/net/swarm/swarm_metrics_test.go | 4 +- p2p/net/swarm/testing/testing.go | 24 +- p2p/protocol/autonatv2/autonat.go | 236 +++++++ p2p/protocol/autonatv2/autonat_test.go | 661 +++++++++++++++++ p2p/protocol/autonatv2/client.go | 342 +++++++++ p2p/protocol/autonatv2/msg_reader.go | 38 + p2p/protocol/autonatv2/options.go | 56 ++ p2p/protocol/autonatv2/pb/autonatv2.pb.go | 818 ++++++++++++++++++++++ p2p/protocol/autonatv2/pb/autonatv2.proto | 64 ++ p2p/protocol/autonatv2/server.go | 449 ++++++++++++ p2p/protocol/autonatv2/server_test.go | 484 +++++++++++++ 24 files changed, 3555 insertions(+), 173 deletions(-) create mode 100644 p2p/protocol/autonatv2/autonat.go create mode 100644 p2p/protocol/autonatv2/autonat_test.go create mode 100644 p2p/protocol/autonatv2/client.go create mode 100644 p2p/protocol/autonatv2/msg_reader.go create mode 100644 p2p/protocol/autonatv2/options.go create mode 100644 p2p/protocol/autonatv2/pb/autonatv2.pb.go create mode 100644 p2p/protocol/autonatv2/pb/autonatv2.proto create mode 100644 p2p/protocol/autonatv2/server.go create mode 100644 p2p/protocol/autonatv2/server_test.go diff --git a/config/config.go b/config/config.go index cc689baf7b..ba326dd2d1 100644 --- a/config/config.go +++ b/config/config.go @@ -24,6 +24,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" @@ -131,6 +132,13 @@ type Config struct { SwarmOpts []swarm.Option DisableIdentifyAddressDiscovery bool + + EnableAutoNATv2 bool + + UDPBlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter + CustomUDPBlackHoleSuccessCounter bool + IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter + CustomIPv6BlackHoleSuccessCounter bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -165,7 +173,10 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa return nil, err } - opts := cfg.SwarmOpts + opts := append(cfg.SwarmOpts, + swarm.WithUDPBlackHoleSuccessCounter(cfg.UDPBlackHoleSuccessCounter), + swarm.WithIPv6BlackHoleSuccessCounter(cfg.IPv6BlackHoleSuccessCounter), + ) if cfg.Reporter != nil { opts = append(opts, swarm.WithMetrics(cfg.Reporter)) } @@ -193,6 +204,77 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa return swarm.NewSwarm(pid, cfg.Peerstore, eventBus, opts...) } +func (cfg *Config) makeAutoNATV2Host() (host.Host, error) { + autonatPrivKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + if err != nil { + return nil, err + } + ps, err := pstoremem.NewPeerstore() + if err != nil { + return nil, err + } + + autoNatCfg := Config{ + Transports: cfg.Transports, + Muxers: cfg.Muxers, + SecurityTransports: cfg.SecurityTransports, + Insecure: cfg.Insecure, + PSK: cfg.PSK, + ConnectionGater: cfg.ConnectionGater, + Reporter: cfg.Reporter, + PeerKey: autonatPrivKey, + Peerstore: ps, + DialRanker: swarm.NoDelayDialRanker, + UDPBlackHoleSuccessCounter: cfg.UDPBlackHoleSuccessCounter, + IPv6BlackHoleSuccessCounter: cfg.IPv6BlackHoleSuccessCounter, + ResourceManager: cfg.ResourceManager, + SwarmOpts: []swarm.Option{ + // Don't update black hole state for failed autonat dials + swarm.WithReadOnlyBlackHoleDetector(), + }, + } + fxopts, err := autoNatCfg.addTransports() + if err != nil { + return nil, err + } + var dialerHost host.Host + fxopts = append(fxopts, + fx.Provide(eventbus.NewBus), + fx.Provide(func(lifecycle fx.Lifecycle, b event.Bus) (*swarm.Swarm, error) { + lifecycle.Append(fx.Hook{ + OnStop: func(context.Context) error { + return ps.Close() + }}) + sw, err := autoNatCfg.makeSwarm(b, false) + return sw, err + }), + fx.Provide(func(sw *swarm.Swarm) *blankhost.BlankHost { + return blankhost.NewBlankHost(sw) + }), + fx.Provide(func(bh *blankhost.BlankHost) host.Host { + return bh + }), + fx.Provide(func() crypto.PrivKey { return autonatPrivKey }), + fx.Provide(func(bh host.Host) peer.ID { return bh.ID() }), + fx.Invoke(func(bh *blankhost.BlankHost) { + dialerHost = bh + }), + ) + app := fx.New(fxopts...) + if err := app.Err(); err != nil { + return nil, err + } + err = app.Start(context.Background()) + if err != nil { + return nil, err + } + go func() { + <-dialerHost.Network().(*swarm.Swarm).Done() + app.Stop(context.Background()) + }() + return dialerHost, nil +} + func (cfg *Config) addTransports() ([]fx.Option, error) { fxopts := []fx.Option{ fx.WithLogger(func() fxevent.Logger { return getFXLogger() }), @@ -291,6 +373,14 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { } func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) { + var autonatv2Dialer host.Host + if cfg.EnableAutoNATv2 { + ah, err := cfg.makeAutoNATV2Host() + if err != nil { + return nil, err + } + autonatv2Dialer = ah + } h, err := bhost.NewHost(swrm, &bhost.HostOpts{ EventBus: eventBus, ConnManager: cfg.ConnManager, @@ -306,6 +396,8 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery, + EnableAutoNATv2: cfg.EnableAutoNATv2, + AutoNATv2Dialer: autonatv2Dialer, }) if err != nil { return nil, err @@ -488,9 +580,8 @@ func (cfg *Config) addAutoNAT(h *bhost.BasicHost) error { Peerstore: ps, DialRanker: swarm.NoDelayDialRanker, SwarmOpts: []swarm.Option{ - // It is better to disable black hole detection and just attempt a dial for autonat - swarm.WithUDPBlackHoleConfig(false, 0, 0), - swarm.WithIPv6BlackHoleConfig(false, 0, 0), + swarm.WithUDPBlackHoleSuccessCounter(nil), + swarm.WithIPv6BlackHoleSuccessCounter(nil), }, } diff --git a/core/network/network.go b/core/network/network.go index 22efbf235d..d2e2bc818d 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -194,6 +194,9 @@ type Dialer interface { // Notify/StopNotify register and unregister a notifiee for signals Notify(Notifiee) StopNotify(Notifiee) + + // CanDial returns whether the dialer can dial peer p at addr + CanDial(p peer.ID, addr ma.Multiaddr) bool } // AddrDelay provides an address along with the delay after which the address diff --git a/defaults.go b/defaults.go index 8f7dc94671..1aba3bd7e2 100644 --- a/defaults.go +++ b/defaults.go @@ -10,6 +10,7 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/net/connmgr" + "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" @@ -133,6 +134,18 @@ var DefaultPrometheusRegisterer = func(cfg *Config) error { return cfg.Apply(PrometheusRegisterer(prometheus.DefaultRegisterer)) } +var defaultUDPBlackHoleDetector = func(cfg *Config) error { + // A black hole is a binary property. On a network if UDP dials are blocked, all dials will + // fail. So a low success rate of 5 out 100 dials is good enough. + return cfg.Apply(UDPBlackHoleSuccessCounter(&swarm.BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "UDP"})) +} + +var defaultIPv6BlackHoleDetector = func(cfg *Config) error { + // A black hole is a binary property. On a network if there is no IPv6 connectivity, all + // dials will fail. So a low success rate of 5 out 100 dials is good enough. + return cfg.Apply(IPv6BlackHoleSuccessCounter(&swarm.BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "IPv6"})) +} + // Complete list of default options and when to fallback on them. // // Please *DON'T* specify default options any other way. Putting this all here @@ -189,6 +202,18 @@ var defaults = []struct { fallback: func(cfg *Config) bool { return !cfg.DisableMetrics && cfg.PrometheusRegisterer == nil }, opt: DefaultPrometheusRegisterer, }, + { + fallback: func(cfg *Config) bool { + return !cfg.CustomUDPBlackHoleSuccessCounter && cfg.UDPBlackHoleSuccessCounter == nil + }, + opt: defaultUDPBlackHoleDetector, + }, + { + fallback: func(cfg *Config) bool { + return !cfg.CustomIPv6BlackHoleSuccessCounter && cfg.IPv6BlackHoleSuccessCounter == nil + }, + opt: defaultIPv6BlackHoleDetector, + }, } // Defaults configures libp2p to use the default options. Can be combined with diff --git a/libp2p_test.go b/libp2p_test.go index 8a9a8edcb5..ea52e56470 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -397,6 +397,12 @@ func TestInsecureConstructor(t *testing.T) { h.Close() } +func TestAutoNATv2Service(t *testing.T) { + h, err := New(EnableAutoNATv2()) + require.NoError(t, err) + h.Close() +} + func TestDisableIdentifyAddressDiscovery(t *testing.T) { h, err := New(DisableIdentifyAddressDiscovery()) require.NoError(t, err) diff --git a/options.go b/options.go index de95251ad3..1a8bc5dd55 100644 --- a/options.go +++ b/options.go @@ -609,3 +609,29 @@ func DisableIdentifyAddressDiscovery() Option { return nil } } + +// EnableAutoNATv2 enables autonat v2 +func EnableAutoNATv2() Option { + return func(cfg *Config) error { + cfg.EnableAutoNATv2 = true + return nil + } +} + +// UDPBlackHoleSuccessCounter configures libp2p to use f as the black hole filter for UDP addrs +func UDPBlackHoleSuccessCounter(f *swarm.BlackHoleSuccessCounter) Option { + return func(cfg *Config) error { + cfg.UDPBlackHoleSuccessCounter = f + cfg.CustomUDPBlackHoleSuccessCounter = true + return nil + } +} + +// IPv6BlackHoleSuccessCounter configures libp2p to use f as the black hole filter for IPv6 addrs +func IPv6BlackHoleSuccessCounter(f *swarm.BlackHoleSuccessCounter) Option { + return func(cfg *Config) error { + cfg.IPv6BlackHoleSuccessCounter = f + cfg.CustomIPv6BlackHoleSuccessCounter = true + return nil + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 8fc808e6b6..766b1b13bc 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -24,6 +24,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -105,6 +106,8 @@ type BasicHost struct { caBook peerstore.CertifiedAddrBook autoNat autonat.AutoNAT + + autonatv2 *autonatv2.AutoNAT } var _ host.Host = (*BasicHost)(nil) @@ -167,6 +170,8 @@ type HostOpts struct { // DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify DisableIdentifyAddressDiscovery bool + EnableAutoNATv2 bool + AutoNATv2Dialer host.Host } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. @@ -310,6 +315,13 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.pings = ping.NewPingService(h) } + if opts.EnableAutoNATv2 { + h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + } + n.SetStreamHandler(h.newStreamHandler) // register to be notified when the network's listen addrs change, @@ -398,6 +410,12 @@ func (h *BasicHost) Start() { h.psManager.Start() h.refCount.Add(1) h.ids.Start() + if h.autonatv2 != nil { + err := h.autonatv2.Start() + if err != nil { + log.Errorf("autonat v2 failed to start: %s", err) + } + } go h.background() } @@ -1100,6 +1118,9 @@ func (h *BasicHost) Close() error { if h.hps != nil { h.hps.Close() } + if h.autonatv2 != nil { + h.autonatv2.Close() + } _ = h.emitters.evtLocalProtocolsUpdated.Close() _ = h.emitters.evtLocalAddrsUpdated.Close() diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 2e56b7f2bb..0b525d3e64 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -434,3 +434,7 @@ func (pn *peernet) notifyAll(notification func(f network.Notifiee)) { func (pn *peernet) ResourceManager() network.ResourceManager { return &network.NullResourceManager{} } + +func (pn *peernet) CanDial(p peer.ID, addr ma.Multiaddr) bool { + return true +} diff --git a/p2p/net/swarm/black_hole_detector.go b/p2p/net/swarm/black_hole_detector.go index dd7849eea6..54782c1c01 100644 --- a/p2p/net/swarm/black_hole_detector.go +++ b/p2p/net/swarm/black_hole_detector.go @@ -29,35 +29,26 @@ func (st blackHoleState) String() string { } } -type blackHoleResult int - -const ( - blackHoleResultAllowed blackHoleResult = iota - blackHoleResultProbing - blackHoleResultBlocked -) - -// blackHoleFilter provides black hole filtering for dials. This filter should be used in -// concert with a UDP of IPv6 address filter to detect UDP or IPv6 black hole. In a black -// holed environments dial requests are blocked and only periodic probes to check the -// state of the black hole are allowed. -// -// Requests are blocked if the number of successes in the last n dials is less than -// minSuccesses. If a request succeeds in Blocked state, the filter state is reset and n -// subsequent requests are allowed before reevaluating black hole state. Dials cancelled -// when some other concurrent dial succeeded are counted as failures. A sufficiently large -// n prevents false negatives in such cases. -type blackHoleFilter struct { - // n serves the dual purpose of being the minimum number of requests after which we - // probe the state of the black hole in blocked state and the minimum number of - // completed dials required before evaluating black hole state. - n int - // minSuccesses is the minimum number of Success required in the last n dials +// BlackHoleSuccessCounter provides black hole filtering for dials. This filter should be used in concert +// with a UDP or IPv6 address filter to detect UDP or IPv6 black hole. In a black holed environment, +// dial requests are refused Requests are blocked if the number of successes in the last N dials is +// less than MinSuccesses. +// If a request succeeds in Blocked state, the filter state is reset and N subsequent requests are +// allowed before reevaluating black hole state. Dials cancelled when some other concurrent dial +// succeeded are counted as failures. A sufficiently large N prevents false negatives in such cases. +type BlackHoleSuccessCounter struct { + // N is + // 1. The minimum number of completed dials required before evaluating black hole state + // 2. the minimum number of requests after which we probe the state of the black hole in + // blocked state + N int + // MinSuccesses is the minimum number of Success required in the last n dials // to consider we are not blocked. - minSuccesses int - // name for the detector. - name string + MinSuccesses int + // Name for the detector. + Name string + mu sync.Mutex // requests counts number of dial requests to peers. We handle request at a peer // level and record results at individual address dial level. requests int @@ -67,22 +58,19 @@ type blackHoleFilter struct { successes int // state is the current state of the detector state blackHoleState - - mu sync.Mutex - metricsTracer MetricsTracer } -// RecordResult records the outcome of a dial. A successful dial will change the state -// of the filter to Allowed. A failed dial only blocks subsequent requests if the success +// RecordResult records the outcome of a dial. A successful dial in Blocked state will change the +// state of the filter to Probing. A failed dial only blocks subsequent requests if the success // fraction over the last n outcomes is less than the minSuccessFraction of the filter. -func (b *blackHoleFilter) RecordResult(success bool) { +func (b *BlackHoleSuccessCounter) RecordResult(success bool) { b.mu.Lock() defer b.mu.Unlock() if b.state == blackHoleStateBlocked && success { // If the call succeeds in a blocked state we reset to allowed. // This is better than slowly accumulating values till we cross the minSuccessFraction - // threshold since a blackhole is a binary property. + // threshold since a black hole is a binary property. b.reset() return } @@ -92,7 +80,7 @@ func (b *blackHoleFilter) RecordResult(success bool) { } b.dialResults = append(b.dialResults, success) - if len(b.dialResults) > b.n { + if len(b.dialResults) > b.N { if b.dialResults[0] { b.successes-- } @@ -100,58 +88,68 @@ func (b *blackHoleFilter) RecordResult(success bool) { } b.updateState() - b.trackMetrics() } // HandleRequest returns the result of applying the black hole filter for the request. -func (b *blackHoleFilter) HandleRequest() blackHoleResult { +func (b *BlackHoleSuccessCounter) HandleRequest() blackHoleState { b.mu.Lock() defer b.mu.Unlock() b.requests++ - b.trackMetrics() - if b.state == blackHoleStateAllowed { - return blackHoleResultAllowed - } else if b.state == blackHoleStateProbing || b.requests%b.n == 0 { - return blackHoleResultProbing + return blackHoleStateAllowed + } else if b.state == blackHoleStateProbing || b.requests%b.N == 0 { + return blackHoleStateProbing } else { - return blackHoleResultBlocked + return blackHoleStateBlocked } } -func (b *blackHoleFilter) reset() { +func (b *BlackHoleSuccessCounter) reset() { b.successes = 0 b.dialResults = b.dialResults[:0] b.requests = 0 b.updateState() } -func (b *blackHoleFilter) updateState() { +func (b *BlackHoleSuccessCounter) updateState() { st := b.state - if len(b.dialResults) < b.n { + if len(b.dialResults) < b.N { b.state = blackHoleStateProbing - } else if b.successes >= b.minSuccesses { + } else if b.successes >= b.MinSuccesses { b.state = blackHoleStateAllowed } else { b.state = blackHoleStateBlocked } if st != b.state { - log.Debugf("%s blackHoleDetector state changed from %s to %s", b.name, st, b.state) + log.Debugf("%s blackHoleDetector state changed from %s to %s", b.Name, st, b.state) } } -func (b *blackHoleFilter) trackMetrics() { - if b.metricsTracer == nil { - return - } +func (b *BlackHoleSuccessCounter) State() blackHoleState { + b.mu.Lock() + defer b.mu.Unlock() + + return b.state +} - nextRequestAllowedAfter := 0 +type blackHoleInfo struct { + name string + state blackHoleState + nextProbeAfter int + successFraction float64 +} + +func (b *BlackHoleSuccessCounter) info() blackHoleInfo { + b.mu.Lock() + defer b.mu.Unlock() + + nextProbeAfter := 0 if b.state == blackHoleStateBlocked { - nextRequestAllowedAfter = b.n - (b.requests % b.n) + nextProbeAfter = b.N - (b.requests % b.N) } successFraction := 0.0 @@ -159,22 +157,27 @@ func (b *blackHoleFilter) trackMetrics() { successFraction = float64(b.successes) / float64(len(b.dialResults)) } - b.metricsTracer.UpdatedBlackHoleFilterState( - b.name, - b.state, - nextRequestAllowedAfter, - successFraction, - ) + return blackHoleInfo{ + name: b.Name, + state: b.state, + nextProbeAfter: nextProbeAfter, + successFraction: successFraction, + } } -// blackHoleDetector provides UDP and IPv6 black hole detection using a `blackHoleFilter` -// for each. For details of the black hole detection logic see `blackHoleFilter`. +// blackHoleDetector provides UDP and IPv6 black hole detection using a `BlackHoleSuccessCounter` for each. +// For details of the black hole detection logic see `BlackHoleSuccessCounter`. +// In Read Only mode, detector doesn't update the state of underlying filters and refuses requests +// when black hole state is unknown. This is useful for Swarms made specifically for services like +// AutoNAT where we care about accurately reporting the reachability of a peer. // -// black hole filtering is done at a peer dial level to ensure that periodic probes to -// detect change of the black hole state are actually dialed and are not skipped -// because of dial prioritisation logic. +// Black hole filtering is done at a peer dial level to ensure that periodic probes to detect change +// of the black hole state are actually dialed and are not skipped because of dial prioritisation +// logic. type blackHoleDetector struct { - udp, ipv6 *blackHoleFilter + udp, ipv6 *BlackHoleSuccessCounter + mt MetricsTracer + readOnly bool } // FilterAddrs filters the peer's addresses removing black holed addresses @@ -192,14 +195,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia } } - udpRes := blackHoleResultAllowed + udpRes := blackHoleStateAllowed if d.udp != nil && hasUDP { - udpRes = d.udp.HandleRequest() + udpRes = d.getFilterState(d.udp) + d.trackMetrics(d.udp) } - ipv6Res := blackHoleResultAllowed + ipv6Res := blackHoleStateAllowed if d.ipv6 != nil && hasIPv6 { - ipv6Res = d.ipv6.HandleRequest() + ipv6Res = d.getFilterState(d.ipv6) + d.trackMetrics(d.ipv6) } blackHoled = make([]ma.Multiaddr, 0, len(addrs)) @@ -210,19 +215,19 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia return true } // allow all UDP addresses while probing irrespective of IPv6 black hole state - if udpRes == blackHoleResultProbing && isProtocolAddr(a, ma.P_UDP) { + if udpRes == blackHoleStateProbing && isProtocolAddr(a, ma.P_UDP) { return true } // allow all IPv6 addresses while probing irrespective of UDP black hole state - if ipv6Res == blackHoleResultProbing && isProtocolAddr(a, ma.P_IP6) { + if ipv6Res == blackHoleStateProbing && isProtocolAddr(a, ma.P_IP6) { return true } - if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) { + if udpRes == blackHoleStateBlocked && isProtocolAddr(a, ma.P_UDP) { blackHoled = append(blackHoled, a) return false } - if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) { + if ipv6Res == blackHoleStateBlocked && isProtocolAddr(a, ma.P_IP6) { blackHoled = append(blackHoled, a) return false } @@ -231,49 +236,36 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia ), blackHoled } -// RecordResult updates the state of the relevant `blackHoleFilter`s for addr +// RecordResult updates the state of the relevant BlackHoleSuccessCounters for addr func (d *blackHoleDetector) RecordResult(addr ma.Multiaddr, success bool) { - if !manet.IsPublicAddr(addr) { + if d.readOnly || !manet.IsPublicAddr(addr) { return } if d.udp != nil && isProtocolAddr(addr, ma.P_UDP) { d.udp.RecordResult(success) + d.trackMetrics(d.udp) } if d.ipv6 != nil && isProtocolAddr(addr, ma.P_IP6) { d.ipv6.RecordResult(success) + d.trackMetrics(d.ipv6) } } -// blackHoleConfig is the config used for black hole detection -type blackHoleConfig struct { - // Enabled enables black hole detection - Enabled bool - // N is the size of the sliding window used to evaluate black hole state - N int - // MinSuccesses is the minimum number of successes out of N required to not - // block requests - MinSuccesses int -} - -func newBlackHoleDetector(udpConfig, ipv6Config blackHoleConfig, mt MetricsTracer) *blackHoleDetector { - d := &blackHoleDetector{} - - if udpConfig.Enabled { - d.udp = &blackHoleFilter{ - n: udpConfig.N, - minSuccesses: udpConfig.MinSuccesses, - name: "UDP", - metricsTracer: mt, +func (d *blackHoleDetector) getFilterState(f *BlackHoleSuccessCounter) blackHoleState { + if d.readOnly { + if f.State() != blackHoleStateAllowed { + return blackHoleStateBlocked } + return blackHoleStateAllowed } + return f.HandleRequest() +} - if ipv6Config.Enabled { - d.ipv6 = &blackHoleFilter{ - n: ipv6Config.N, - minSuccesses: ipv6Config.MinSuccesses, - name: "IPv6", - metricsTracer: mt, - } +func (d *blackHoleDetector) trackMetrics(f *BlackHoleSuccessCounter) { + if d.readOnly || d.mt == nil { + return } - return d + // Track metrics only in non readOnly state + info := f.info() + d.mt.UpdatedBlackHoleSuccessCounter(info.name, info.state, info.nextProbeAfter, info.successFraction) } diff --git a/p2p/net/swarm/black_hole_detector_test.go b/p2p/net/swarm/black_hole_detector_test.go index 1ab2cbe587..a38b43f4ce 100644 --- a/p2p/net/swarm/black_hole_detector_test.go +++ b/p2p/net/swarm/black_hole_detector_test.go @@ -8,58 +8,70 @@ import ( "github.com/stretchr/testify/require" ) -func TestBlackHoleFilterReset(t *testing.T) { +func TestBlackHoleSuccessCounterReset(t *testing.T) { n := 10 - bhf := &blackHoleFilter{n: n, minSuccesses: 2, name: "test"} + bhf := &BlackHoleSuccessCounter{N: n, MinSuccesses: 2, Name: "test"} var i = 0 // calls up to n should be probing for i = 1; i <= n; i++ { - if bhf.HandleRequest() != blackHoleResultProbing { + if bhf.HandleRequest() != blackHoleStateProbing { t.Fatalf("expected calls up to n to be probes") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } // after threshold calls every nth call should be a probe for i = n + 1; i < 42; i++ { result := bhf.HandleRequest() - if (i%n == 0 && result != blackHoleResultProbing) || (i%n != 0 && result != blackHoleResultBlocked) { + if (i%n == 0 && result != blackHoleStateProbing) || (i%n != 0 && result != blackHoleStateBlocked) { t.Fatalf("expected every nth dial to be a probe") } + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } bhf.RecordResult(true) // check if calls up to n are probes again for i = 0; i < n; i++ { - if bhf.HandleRequest() != blackHoleResultProbing { + if bhf.HandleRequest() != blackHoleStateProbing { t.Fatalf("expected black hole detector state to reset after success") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } // next call should be blocked - if bhf.HandleRequest() != blackHoleResultBlocked { + if bhf.HandleRequest() != blackHoleStateBlocked { t.Fatalf("expected dial to be blocked") + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } } -func TestBlackHoleFilterSuccessFraction(t *testing.T) { +func TestBlackHoleSuccessCounterSuccessFraction(t *testing.T) { n := 10 tests := []struct { minSuccesses, successes int - result blackHoleResult + result blackHoleState }{ - {minSuccesses: 5, successes: 5, result: blackHoleResultAllowed}, - {minSuccesses: 3, successes: 3, result: blackHoleResultAllowed}, - {minSuccesses: 5, successes: 4, result: blackHoleResultBlocked}, - {minSuccesses: 5, successes: 7, result: blackHoleResultAllowed}, - {minSuccesses: 3, successes: 1, result: blackHoleResultBlocked}, - {minSuccesses: 0, successes: 0, result: blackHoleResultAllowed}, - {minSuccesses: 10, successes: 10, result: blackHoleResultAllowed}, + {minSuccesses: 5, successes: 5, result: blackHoleStateAllowed}, + {minSuccesses: 3, successes: 3, result: blackHoleStateAllowed}, + {minSuccesses: 5, successes: 4, result: blackHoleStateBlocked}, + {minSuccesses: 5, successes: 7, result: blackHoleStateAllowed}, + {minSuccesses: 3, successes: 1, result: blackHoleStateBlocked}, + {minSuccesses: 0, successes: 0, result: blackHoleStateAllowed}, + {minSuccesses: 10, successes: 10, result: blackHoleStateAllowed}, } for i, tc := range tests { t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { - bhf := blackHoleFilter{n: n, minSuccesses: tc.minSuccesses} + bhf := BlackHoleSuccessCounter{N: n, MinSuccesses: tc.minSuccesses} for i := 0; i < tc.successes; i++ { bhf.RecordResult(true) } @@ -75,9 +87,9 @@ func TestBlackHoleFilterSuccessFraction(t *testing.T) { } func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { - udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - ipv6Config := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(udpConfig, ipv6Config, nil) + udpF := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + ipv6F := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF, ipv6: ipv6F} addrs := []ma.Multiaddr{ ma.StringCast("/ip4/1.2.3.4/tcp/1234"), ma.StringCast("/ip4/1.2.3.4/tcp/1233"), @@ -94,8 +106,8 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { } func TestBlackHoleDetectorUDPDisabled(t *testing.T) { - ipv6Config := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(blackHoleConfig{Enabled: false}, ipv6Config, nil) + ipv6F := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{ipv6: ipv6F} publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") privAddr := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1") for i := 0; i < 100; i++ { @@ -110,8 +122,8 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) { } func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { - udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil) + udpF := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF} publicAddr := ma.StringCast("/ip6/2001::1/tcp/1234") privAddr := ma.StringCast("/ip6/::1/tcp/1234") for i := 0; i < 100; i++ { @@ -128,8 +140,8 @@ func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { func TestBlackHoleDetectorProbes(t *testing.T) { bhd := &blackHoleDetector{ - udp: &blackHoleFilter{n: 2, minSuccesses: 1, name: "udp"}, - ipv6: &blackHoleFilter{n: 3, minSuccesses: 1, name: "ipv6"}, + udp: &BlackHoleSuccessCounter{N: 2, MinSuccesses: 1, Name: "udp"}, + ipv6: &BlackHoleSuccessCounter{N: 3, MinSuccesses: 1, Name: "ipv6"}, } udp6Addr := ma.StringCast("/ip6/2001::1/udp/1234/quic-v1") addrs := []ma.Multiaddr{udp6Addr} @@ -163,8 +175,8 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { makeBHD := func(udpBlocked, ipv6Blocked bool) *blackHoleDetector { bhd := &blackHoleDetector{ - udp: &blackHoleFilter{n: 100, minSuccesses: 10, name: "udp"}, - ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"}, + udp: &BlackHoleSuccessCounter{N: 100, MinSuccesses: 10, Name: "udp"}, + ipv6: &BlackHoleSuccessCounter{N: 100, MinSuccesses: 10, Name: "ipv6"}, } for i := 0; i < 100; i++ { bhd.RecordResult(udp4Pub, !udpBlocked) @@ -199,3 +211,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { require.ElementsMatch(t, bothBlockedOutput, gotAddrs) require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs) } + +func TestBlackHoleDetectorReadOnlyMode(t *testing.T) { + udpF := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + ipv6F := &BlackHoleSuccessCounter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF, ipv6: ipv6F, readOnly: true} + publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + privAddr := ma.StringCast("/ip6/::1/tcp/1234") + for i := 0; i < 100; i++ { + bhd.RecordResult(publicAddr, true) + } + allAddr := []ma.Multiaddr{privAddr, publicAddr} + // public addr filtered because state is probing + wantAddrs := []ma.Multiaddr{privAddr} + wantRemovedAddrs := []ma.Multiaddr{publicAddr} + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allAddr) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) + + // a non readonly shared state black hole detector + nbhd := &blackHoleDetector{udp: bhd.udp, ipv6: bhd.ipv6, readOnly: false} + for i := 0; i < 100; i++ { + nbhd.RecordResult(publicAddr, true) + } + // no addresses filtered because state is allowed + wantAddrs = []ma.Multiaddr{privAddr, publicAddr} + wantRemovedAddrs = []ma.Multiaddr{} + + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allAddr) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 7897277cc7..02cff1e881 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -112,22 +112,33 @@ func WithDialRanker(d network.DialRanker) Option { } } -// WithUDPBlackHoleConfig configures swarm to use c as the config for UDP black hole detection +// WithUDPBlackHoleSuccessCounter configures swarm to use the provided config for UDP black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests -func WithUDPBlackHoleConfig(enabled bool, n, min int) Option { +func WithUDPBlackHoleSuccessCounter(f *BlackHoleSuccessCounter) Option { return func(s *Swarm) error { - s.udpBlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min} + s.udpBHF = f return nil } } -// WithIPv6BlackHoleConfig configures swarm to use c as the config for IPv6 black hole detection +// WithIPv6BlackHoleSuccessCounter configures swarm to use the provided config for IPv6 black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests -func WithIPv6BlackHoleConfig(enabled bool, n, min int) Option { +func WithIPv6BlackHoleSuccessCounter(f *BlackHoleSuccessCounter) Option { return func(s *Swarm) error { - s.ipv6BlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min} + s.ipv6BHF = f + return nil + } +} + +// WithReadOnlyBlackHoleDetector configures the swarm to use the black hole detector in +// read only mode. In Read Only mode dial requests are refused in unknown state and +// no updates to the detector state are made. This is useful for services like AutoNAT that +// care about accurately providing reachability info. +func WithReadOnlyBlackHoleDetector() Option { + return func(s *Swarm) error { + s.readOnlyBHD = true return nil } } @@ -203,10 +214,11 @@ type Swarm struct { dialRanker network.DialRanker - udpBlackHoleConfig blackHoleConfig - ipv6BlackHoleConfig blackHoleConfig - bhd *blackHoleDetector connectednessEventEmitter *connectednessEventEmitter + udpBHF *BlackHoleSuccessCounter + ipv6BHF *BlackHoleSuccessCounter + bhd *blackHoleDetector + readOnlyBHD bool } // NewSwarm constructs a Swarm. @@ -230,8 +242,8 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts // A black hole is a binary property. On a network if UDP dials are blocked or there is // no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials // is good enough. - udpBlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5}, - ipv6BlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5}, + udpBHF: &BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "UDP"}, + ipv6BHF: &BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "IPv6"}, } s.conns.m = make(map[peer.ID][]*Conn) @@ -255,7 +267,12 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.limiter = newDialLimiter(s.dialAddr) s.backf.init(s.ctx) - s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer) + s.bhd = &blackHoleDetector{ + udp: s.udpBHF, + ipv6: s.ipv6BHF, + mt: s.metricsTracer, + readOnly: s.readOnlyBHD, + } return s, nil } diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index f639ce16a2..446ece4504 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -416,6 +416,11 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, return nil } +func (s *Swarm) CanDial(p peer.ID, addr ma.Multiaddr) bool { + dialable, _ := s.filterKnownUndialables(p, []ma.Multiaddr{addr}) + return len(dialable) > 0 +} + func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 83f94b9d91..f4c33170a9 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -364,7 +364,7 @@ func TestBlackHoledAddrBlocked(t *testing.T) { defer s.Close() n := 3 - s.bhd.ipv6 = &blackHoleFilter{n: n, minSuccesses: 1, name: "IPv6"} + s.bhd.ipv6 = &BlackHoleSuccessCounter{N: n, MinSuccesses: 1, Name: "IPv6"} // All dials to this addr will fail. // manet.IsPublic is aggressive for IPv6 addresses. Use a NAT64 address. diff --git a/p2p/net/swarm/swarm_metrics.go b/p2p/net/swarm/swarm_metrics.go index b5c0f2e499..929f3f4946 100644 --- a/p2p/net/swarm/swarm_metrics.go +++ b/p2p/net/swarm/swarm_metrics.go @@ -85,7 +85,7 @@ var ( Buckets: []float64{0.001, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1, 2}, }, ) - blackHoleFilterState = prometheus.NewGaugeVec( + blackHoleSuccessCounterState = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: metricNamespace, Name: "black_hole_filter_state", @@ -93,7 +93,7 @@ var ( }, []string{"name"}, ) - blackHoleFilterSuccessFraction = prometheus.NewGaugeVec( + blackHoleSuccessCounterSuccessFraction = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: metricNamespace, Name: "black_hole_filter_success_fraction", @@ -101,7 +101,7 @@ var ( }, []string{"name"}, ) - blackHoleFilterNextRequestAllowedAfter = prometheus.NewGaugeVec( + blackHoleSuccessCounterNextRequestAllowedAfter = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: metricNamespace, Name: "black_hole_filter_next_request_allowed_after", @@ -118,9 +118,9 @@ var ( connHandshakeLatency, dialsPerPeer, dialRankingDelay, - blackHoleFilterSuccessFraction, - blackHoleFilterState, - blackHoleFilterNextRequestAllowedAfter, + blackHoleSuccessCounterSuccessFraction, + blackHoleSuccessCounterState, + blackHoleSuccessCounterNextRequestAllowedAfter, } ) @@ -131,7 +131,7 @@ type MetricsTracer interface { FailedDialing(ma.Multiaddr, error, error) DialCompleted(success bool, totalDials int) DialRankingDelay(d time.Duration) - UpdatedBlackHoleFilterState(name string, state blackHoleState, nextProbeAfter int, successFraction float64) + UpdatedBlackHoleSuccessCounter(name string, state blackHoleState, nextProbeAfter int, successFraction float64) } type metricsTracer struct{} @@ -274,14 +274,14 @@ func (m *metricsTracer) DialRankingDelay(d time.Duration) { dialRankingDelay.Observe(d.Seconds()) } -func (m *metricsTracer) UpdatedBlackHoleFilterState(name string, state blackHoleState, +func (m *metricsTracer) UpdatedBlackHoleSuccessCounter(name string, state blackHoleState, nextProbeAfter int, successFraction float64) { tags := metricshelper.GetStringSlice() defer metricshelper.PutStringSlice(tags) *tags = append(*tags, name) - blackHoleFilterState.WithLabelValues(*tags...).Set(float64(state)) - blackHoleFilterSuccessFraction.WithLabelValues(*tags...).Set(successFraction) - blackHoleFilterNextRequestAllowedAfter.WithLabelValues(*tags...).Set(float64(nextProbeAfter)) + blackHoleSuccessCounterState.WithLabelValues(*tags...).Set(float64(state)) + blackHoleSuccessCounterSuccessFraction.WithLabelValues(*tags...).Set(successFraction) + blackHoleSuccessCounterNextRequestAllowedAfter.WithLabelValues(*tags...).Set(float64(nextProbeAfter)) } diff --git a/p2p/net/swarm/swarm_metrics_test.go b/p2p/net/swarm/swarm_metrics_test.go index e415c55fb8..45ce0c2e47 100644 --- a/p2p/net/swarm/swarm_metrics_test.go +++ b/p2p/net/swarm/swarm_metrics_test.go @@ -94,8 +94,8 @@ func TestMetricsNoAllocNoCover(t *testing.T) { "FailedDialing": func() { mt.FailedDialing(randItem(addrs), randItem(errors), randItem(errors)) }, "DialCompleted": func() { mt.DialCompleted(mrand.Intn(2) == 1, mrand.Intn(10)) }, "DialRankingDelay": func() { mt.DialRankingDelay(time.Duration(mrand.Intn(1e10))) }, - "UpdatedBlackHoleFilterState": func() { - mt.UpdatedBlackHoleFilterState( + "UpdatedBlackHoleSuccessCounter": func() { + mt.UpdatedBlackHoleSuccessCounter( randItem(bhfNames), randItem(bhfState), mrand.Intn(100), diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 627114f645..2bbe8b27a5 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -53,63 +53,63 @@ func (rc realclock) Now() time.Time { } // Option is an option that can be passed when constructing a test swarm. -type Option func(*testing.T, *config) +type Option func(testing.TB, *config) // WithClock sets the clock to use for this swarm func WithClock(clock clock) Option { - return func(_ *testing.T, c *config) { + return func(_ testing.TB, c *config) { c.clock = clock } } func WithSwarmOpts(swarmOpts ...swarm.Option) Option { - return func(_ *testing.T, c *config) { + return func(_ testing.TB, c *config) { c.swarmOpts = swarmOpts } } // OptDisableReuseport disables reuseport in this test swarm. -var OptDisableReuseport Option = func(_ *testing.T, c *config) { +var OptDisableReuseport Option = func(_ testing.TB, c *config) { c.disableReuseport = true } // OptDialOnly prevents the test swarm from listening. -var OptDialOnly Option = func(_ *testing.T, c *config) { +var OptDialOnly Option = func(_ testing.TB, c *config) { c.dialOnly = true } // OptDisableTCP disables TCP. -var OptDisableTCP Option = func(_ *testing.T, c *config) { +var OptDisableTCP Option = func(_ testing.TB, c *config) { c.disableTCP = true } // OptDisableQUIC disables QUIC. -var OptDisableQUIC Option = func(_ *testing.T, c *config) { +var OptDisableQUIC Option = func(_ testing.TB, c *config) { c.disableQUIC = true } // OptConnGater configures the given connection gater on the test func OptConnGater(cg connmgr.ConnectionGater) Option { - return func(_ *testing.T, c *config) { + return func(_ testing.TB, c *config) { c.connectionGater = cg } } // OptPeerPrivateKey configures the peer private key which is then used to derive the public key and peer ID. func OptPeerPrivateKey(sk crypto.PrivKey) Option { - return func(_ *testing.T, c *config) { + return func(_ testing.TB, c *config) { c.sk = sk } } func EventBus(b event.Bus) Option { - return func(_ *testing.T, c *config) { + return func(_ testing.TB, c *config) { c.eventBus = b } } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(t *testing.T, n *swarm.Swarm, connGater connmgr.ConnectionGater, opts ...tptu.Option) transport.Upgrader { +func GenUpgrader(t testing.TB, n *swarm.Swarm, connGater connmgr.ConnectionGater, opts ...tptu.Option) transport.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) st := insecure.NewWithIdentity(insecure.ID, id, pk) @@ -120,7 +120,7 @@ func GenUpgrader(t *testing.T, n *swarm.Swarm, connGater connmgr.ConnectionGater } // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { +func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm { var cfg config cfg.clock = realclock{} for _, o := range opts { diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go new file mode 100644 index 0000000000..1ae11edbfa --- /dev/null +++ b/p2p/protocol/autonatv2/autonat.go @@ -0,0 +1,236 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "golang.org/x/exp/rand" + "golang.org/x/exp/slices" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +const ( + ServiceName = "libp2p.autonatv2" + DialBackProtocol = "/libp2p/autonat/2/dial-back" + DialProtocol = "/libp2p/autonat/2/dial-request" + + maxMsgSize = 8192 + streamTimeout = time.Minute + dialBackStreamTimeout = 5 * time.Second + dialBackDialTimeout = 30 * time.Second + dialBackMaxMsgSize = 1024 + minHandshakeSizeBytes = 30_000 // for amplification attack prevention + maxHandshakeSizeBytes = 100_000 + // maxPeerAddresses is the number of addresses in a dial request the server + // will inspect, rest are ignored. + maxPeerAddresses = 50 +) + +var ( + ErrNoValidPeers = errors.New("no valid peers for autonat v2") + ErrDialRefused = errors.New("dial refused") + + log = logging.Logger("autonatv2") +) + +// Request is the request to verify reachability of a single address +type Request struct { + // Addr is the multiaddr to verify + Addr ma.Multiaddr + // SendDialData indicates whether to send dial data if the server requests it for Addr + SendDialData bool +} + +// Result is the result of the CheckReachability call +type Result struct { + // Addr is the dialed address + Addr ma.Multiaddr + // Reachability of the dialed address + Reachability network.Reachability + // Status is the outcome of the dialback + Status pb.DialStatus +} + +// AutoNAT implements the AutoNAT v2 client and server. +// Users can check reachability for their addresses using the CheckReachability method. +// The server provides amplification attack prevention and rate limiting. +type AutoNAT struct { + host host.Host + + // for cleanly closing + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + srv *server + cli *client + + mx sync.Mutex + peers *peersMap + // allowPrivateAddrs enables using private and localhost addresses for reachability checks. + // This is only useful for testing. + allowPrivateAddrs bool +} + +// New returns a new AutoNAT instance. +// host and dialerHost should have the same dialing capabilities. In case the host doesn't support +// a transport, dial back requests for address for that transport will be ignored. +func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { + s := defaultSettings() + for _, o := range opts { + if err := o(s); err != nil { + return nil, fmt.Errorf("failed to apply option: %w", err) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + an := &AutoNAT{ + host: host, + ctx: ctx, + cancel: cancel, + srv: newServer(host, dialerHost, s), + cli: newClient(host), + allowPrivateAddrs: s.allowPrivateAddrs, + peers: newPeersMap(), + } + return an, nil +} + +func (an *AutoNAT) background(sub event.Subscription) { + for { + select { + case <-an.ctx.Done(): + sub.Close() + an.wg.Done() + return + case e := <-sub.Out(): + switch evt := e.(type) { + case event.EvtPeerProtocolsUpdated: + an.updatePeer(evt.Peer) + case event.EvtPeerConnectednessChanged: + an.updatePeer(evt.Peer) + case event.EvtPeerIdentificationCompleted: + an.updatePeer(evt.Peer) + } + } + } +} + +func (an *AutoNAT) Start() error { + // Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged + // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. + sub, err := an.host.EventBus().Subscribe([]interface{}{ + new(event.EvtPeerProtocolsUpdated), + new(event.EvtPeerConnectednessChanged), + new(event.EvtPeerIdentificationCompleted), + }) + if err != nil { + return fmt.Errorf("event subscription failed: %w", err) + } + an.cli.Start() + an.srv.Start() + + an.wg.Add(1) + go an.background(sub) + return nil +} + +func (an *AutoNAT) Close() { + an.cancel() + an.wg.Wait() + an.srv.Close() + an.cli.Close() + an.peers = nil +} + +// GetReachability makes a single dial request for checking reachability for requested addresses +func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) { + if !an.allowPrivateAddrs { + for _, r := range reqs { + if !manet.IsPublicAddr(r.Addr) { + return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr) + } + } + } + an.mx.Lock() + p := an.peers.GetRand() + an.mx.Unlock() + if p == "" { + return Result{}, ErrNoValidPeers + } + + res, err := an.cli.GetReachability(ctx, p, reqs) + if err != nil { + log.Debugf("reachability check with %s failed, err: %s", p, err) + return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) + } + log.Debugf("reachability check with %s successful", p) + return res, nil +} + +func (an *AutoNAT) updatePeer(p peer.ID) { + an.mx.Lock() + defer an.mx.Unlock() + + // There are no ordering gurantees between identify and swarm events. Check peerstore + // and swarm for the current state + protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) + connectedness := an.host.Network().Connectedness(p) + if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected { + an.peers.Put(p) + } else { + an.peers.Delete(p) + } +} + +// peersMap provides random access to a set of peers. This is useful when the map iteration order is +// not sufficiently random. +type peersMap struct { + peerIdx map[peer.ID]int + peers []peer.ID +} + +func newPeersMap() *peersMap { + return &peersMap{ + peerIdx: make(map[peer.ID]int), + peers: make([]peer.ID, 0), + } +} + +func (p *peersMap) GetRand() peer.ID { + if len(p.peers) == 0 { + return "" + } + return p.peers[rand.Intn(len(p.peers))] +} + +func (p *peersMap) Put(pid peer.ID) { + if _, ok := p.peerIdx[pid]; ok { + return + } + p.peers = append(p.peers, pid) + p.peerIdx[pid] = len(p.peers) - 1 +} + +func (p *peersMap) Delete(pid peer.ID) { + idx, ok := p.peerIdx[pid] + if !ok { + return + } + p.peers[idx] = p.peers[len(p.peers)-1] + p.peerIdx[p.peers[idx]] = idx + p.peers = p.peers[:len(p.peers)-1] + delete(p.peerIdx, pid) +} diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go new file mode 100644 index 0000000000..ee97ccc695 --- /dev/null +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -0,0 +1,661 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT { + t.Helper() + b := eventbus.NewBus() + h := bhost.NewBlankHost( + swarmt.GenSwarm(t, swarmt.EventBus(b)), bhost.WithEventBus(b)) + if dialer == nil { + dialer = bhost.NewBlankHost( + swarmt.GenSwarm(t, + swarmt.WithSwarmOpts( + swarm.WithUDPBlackHoleSuccessCounter(nil), + swarm.WithIPv6BlackHoleSuccessCounter(nil)))) + } + an, err := New(h, dialer, opts...) + if err != nil { + t.Error(err) + } + an.Start() + t.Cleanup(an.Close) + return an +} + +func parseAddrs(t *testing.T, msg *pb.Message) []ma.Multiaddr { + t.Helper() + req := msg.GetDialRequest() + addrs := make([]ma.Multiaddr, 0) + for _, ab := range req.Addrs { + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + t.Error("invalid addr bytes", ab) + } + addrs = append(addrs, a) + } + return addrs +} + +// idAndConnect identifies b to a and connects them +func idAndConnect(t testing.TB, a, b host.Host) { + a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL) + a.Peerstore().AddProtocols(b.ID(), DialProtocol) + + err := a.Connect(context.Background(), peer.AddrInfo{ID: b.ID()}) + require.NoError(t, err) +} + +// waitForPeer waits for a to have 1 peer in the peerMap +func waitForPeer(t testing.TB, a *AutoNAT) { + t.Helper() + require.Eventually(t, func() bool { + a.mx.Lock() + defer a.mx.Unlock() + return a.peers.GetRand() != "" + }, 5*time.Second, 100*time.Millisecond) +} + +// idAndWait provides server address and protocol to client +func idAndWait(t testing.TB, cli *AutoNAT, srv *AutoNAT) { + idAndConnect(t, cli.host, srv.host) + waitForPeer(t, cli) +} + +func TestAutoNATPrivateAddr(t *testing.T) { + an := newAutoNAT(t, nil) + res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) + require.Equal(t, res, Result{}) + require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") +} + +func TestClientRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + addrs := an.host.Addrs() + addrbs := make([][]byte, len(addrs)) + for i := 0; i < len(addrs); i++ { + addrbs[i] = addrs[i].Bytes() + } + + var receivedRequest atomic.Bool + b.SetStreamHandler(DialProtocol, func(s network.Stream) { + receivedRequest.Store(true) + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + assert.NotNil(t, msg.GetDialRequest()) + assert.Equal(t, addrbs, msg.GetDialRequest().Addrs) + s.Reset() + }) + + res, err := an.GetReachability(context.Background(), []Request{ + {Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.True(t, receivedRequest.Load()) +} + +func TestClientServerError(t *testing.T) { + an := newAutoNAT(t, nil, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + errorStr string + }{ + { + handler: func(s network.Stream) { + s.Reset() + }, + errorStr: "stream reset", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialRequest{DialRequest: &pb.DialRequest{}}})) + }, + errorStr: "invalid msg type", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }}, + )) + }, + errorStr: ErrDialRefused.Error(), + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + res, err := an.GetReachability( + context.Background(), + newTestRequests(addrs, false)) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.Contains(t, err.Error(), tc.errorStr) + }) + } +} + +func TestClientDataRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + name string + }{ + { + name: "provides dial data", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + var dialData []byte + for len(dialData) < 10000 { + if err := r.ReadMsg(&msg); err != nil { + t.Error(err) + s.Reset() + return + } + if msg.GetDialDataResponse() == nil { + t.Errorf("expected to receive msg of type DialDataResponse") + s.Reset() + return + } + dialData = append(dialData, msg.GetDialDataResponse().Data...) + } + s.Reset() + }, + }, + { + name: "low priority addr", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 1, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + { + name: "too high dial data request", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 1 << 32, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + + res, err := an.GetReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + }) + } +} + +func TestClientDialBacks(t *testing.T) { + an := newAutoNAT(t, nil, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + + readReq := func(r pbio.Reader) ([]ma.Multiaddr, uint64, error) { + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + return nil, 0, err + } + if msg.GetDialRequest() == nil { + return nil, 0, errors.New("no dial request in msg") + } + addrs := parseAddrs(t, &msg) + return addrs, msg.GetDialRequest().GetNonce(), nil + } + + writeNonce := func(addr ma.Multiaddr, nonce uint64) error { + pid := an.host.ID() + dialerHost.Peerstore().AddAddr(pid, addr, peerstore.PermanentAddrTTL) + defer func() { + dialerHost.Network().ClosePeer(pid) + dialerHost.Peerstore().RemovePeer(pid) + dialerHost.Peerstore().ClearAddrs(pid) + }() + as, err := dialerHost.NewStream(context.Background(), pid, DialBackProtocol) + if err != nil { + return err + } + w := pbio.NewDelimitedWriter(as) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + return err + } + as.CloseWrite() + data := make([]byte, 1) + as.Read(data) + as.Close() + return nil + } + + tests := []struct { + name string + handler func(network.Stream) + success bool + }{ + { + name: "correct dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 1, + }, + }, + }) + s.Close() + }, + success: true, + }, + { + name: "no dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + if _, _, err := readReq(r); err != nil { + s.Reset() + t.Error(err) + return + } + resp := &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: resp, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid reported address", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid nonce", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[0], nonce-1); err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid addr index", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + _, _, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 10, + }, + }, + }) + s.Close() + }, + success: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addrs := an.host.Addrs() + b.SetStreamHandler(DialProtocol, tc.handler) + res, err := an.GetReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + if !tc.success { + require.Error(t, err) + require.Equal(t, Result{}, res) + } else { + require.NoError(t, err) + require.Equal(t, res.Reachability, network.ReachabilityPublic) + require.Equal(t, res.Status, pb.DialStatus_OK) + } + }) + } +} + +func TestEventSubscription(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + c := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer c.Close() + + idAndConnect(t, an.host, b) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + idAndConnect(t, an.host, c) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 2 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(b.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(c.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 0 + }, 5*time.Second, 100*time.Millisecond) +} + +func TestPeersMap(t *testing.T) { + emptyPeerID := peer.ID("") + + t.Run("single_item", func(t *testing.T) { + p := newPeersMap() + p.Put("peer1") + p.Delete("peer1") + p.Put("peer1") + require.Equal(t, peer.ID("peer1"), p.GetRand()) + p.Delete("peer1") + require.Equal(t, emptyPeerID, p.GetRand()) + }) + + t.Run("multiple_items", func(t *testing.T) { + p := newPeersMap() + require.Equal(t, emptyPeerID, p.GetRand()) + + allPeers := make(map[peer.ID]bool) + for i := 0; i < 20; i++ { + pid := peer.ID(fmt.Sprintf("peer-%d", i)) + allPeers[pid] = true + p.Put(pid) + } + foundPeers := make(map[peer.ID]bool) + for i := 0; i < 1000; i++ { + pid := p.GetRand() + require.NotEqual(t, emptyPeerID, p) + require.True(t, allPeers[pid]) + foundPeers[pid] = true + if len(foundPeers) == len(allPeers) { + break + } + } + for pid := range allPeers { + p.Delete(pid) + } + require.Equal(t, emptyPeerID, p.GetRand()) + }) +} + +func TestAreAddrsConsistency(t *testing.T) { + c := &client{ + normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr { + for { + rest, l := ma.SplitLast(a) + if _, err := l.ValueForProtocol(ma.P_CERTHASH); err != nil { + return a + } + a = rest + } + }, + } + tests := []struct { + name string + localAddr ma.Multiaddr + dialAddr ma.Multiaddr + success bool + }{ + { + name: "simple match", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "nat64", + localAddr: ma.StringCast("/ip6/1::1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: false, + }, + { + name: "simple mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/23232/quic-v1"), + success: false, + }, + { + name: "quic-vs-webtransport", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/webtransport"), + success: false, + }, + { + name: "webtransport-certhash", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1/webtransport"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/webtransport/certhash/uEgNmb28"), + success: true, + }, + { + name: "dns", + localAddr: ma.StringCast("/dns/lib.p2p/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), + success: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if c.areAddrsConsistent(tc.localAddr, tc.dialAddr) != tc.success { + wantStr := "match" + if !tc.success { + wantStr = "mismatch" + } + t.Errorf("expected %s between\nlocal addr: %s\ndial addr: %s", wantStr, tc.localAddr, tc.dialAddr) + } + }) + } + +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go new file mode 100644 index 0000000000..f93cc31377 --- /dev/null +++ b/p2p/protocol/autonatv2/client.go @@ -0,0 +1,342 @@ +package autonatv2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "golang.org/x/exp/rand" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +// client implements the client for making dial requests for AutoNAT v2. It verifies successful +// dials and provides an option to send data for dial requests. +type client struct { + host host.Host + dialData []byte + normalizeMultiaddr func(ma.Multiaddr) ma.Multiaddr + + mu sync.Mutex + // dialBackQueues maps nonce to the channel for providing the local multiaddr of the connection + // the nonce was received on + dialBackQueues map[uint64]chan ma.Multiaddr +} + +type normalizeMultiaddrer interface { + NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr +} + +func newClient(h host.Host) *client { + normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a } + if hn, ok := h.(normalizeMultiaddrer); ok { + normalizeMultiaddr = hn.NormalizeMultiaddr + } + return &client{ + host: h, + dialData: make([]byte, 4000), + normalizeMultiaddr: normalizeMultiaddr, + dialBackQueues: make(map[uint64]chan ma.Multiaddr), + } +} + +func (ac *client) Start() { + ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) +} + +func (ac *client) Close() { + ac.host.RemoveStreamHandler(DialBackProtocol) +} + +// GetReachability verifies address reachability with a AutoNAT v2 server p. +func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) { + ctx, cancel := context.WithTimeout(ctx, streamTimeout) + defer cancel() + + s, err := ac.host.NewStream(ctx, p, DialProtocol) + if err != nil { + return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err) + } + + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err) + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err) + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(streamTimeout)) + defer s.Close() + + nonce := rand.Uint64() + ch := make(chan ma.Multiaddr, 1) + ac.mu.Lock() + ac.dialBackQueues[nonce] = ch + ac.mu.Unlock() + defer func() { + ac.mu.Lock() + delete(ac.dialBackQueues, nonce) + ac.mu.Unlock() + }() + + msg := newDialRequest(reqs, nonce) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial request write failed: %w", err) + } + + r := pbio.NewDelimitedReader(s, maxMsgSize) + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial msg read failed: %w", err) + } + + switch { + case msg.GetDialResponse() != nil: + break + // provide dial data if appropriate + case msg.GetDialDataRequest() != nil: + if err := ac.validateDialDataRequest(reqs, &msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("invalid dial data request: %w", err) + } + // dial data request is valid and we want to send data + if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial data send failed: %w", err) + } + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial response read failed: %w", err) + } + if msg.GetDialResponse() == nil { + s.Reset() + return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg) + } + default: + s.Reset() + return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg) + } + + resp := msg.GetDialResponse() + if resp.GetStatus() != pb.DialResponse_OK { + // E_DIAL_REFUSED has implication for deciding future address verificiation priorities + // wrap a distinct error for convenient errors.Is usage + if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED { + return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused) + } + return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(), + pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())]) + } + if resp.GetDialStatus() == pb.DialStatus_UNUSED { + return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED") + } + if int(resp.AddrIdx) >= len(reqs) { + return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs)) + } + + // wait for nonce from the server + var dialBackAddr ma.Multiaddr + if resp.GetDialStatus() == pb.DialStatus_OK { + timer := time.NewTimer(dialBackStreamTimeout) + select { + case at := <-ch: + dialBackAddr = at + case <-ctx.Done(): + case <-timer.C: + } + timer.Stop() + } + return ac.newResult(resp, reqs, dialBackAddr) +} + +func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error { + idx := int(msg.GetDialDataRequest().AddrIdx) + if idx >= len(reqs) { // invalid address index + return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs)) + } + if msg.GetDialDataRequest().NumBytes > maxHandshakeSizeBytes { // data request is too high + return fmt.Errorf("requested data too high: %d", msg.GetDialDataRequest().NumBytes) + } + if !reqs[idx].SendDialData { // low priority addr + return fmt.Errorf("low priority addr: %s index %d", reqs[idx].Addr, idx) + } + return nil +} + +func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { + idx := int(resp.AddrIdx) + addr := reqs[idx].Addr + + var rch network.Reachability + switch resp.DialStatus { + case pb.DialStatus_OK: + if !ac.areAddrsConsistent(dialBackAddr, addr) { + // the server is misinforming us about the address it successfully dialed + // either we received no dialback or the address on the dialback is inconsistent with + // what the server is telling us + return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) + } + rch = network.ReachabilityPublic + case pb.DialStatus_E_DIAL_ERROR: + rch = network.ReachabilityPrivate + case pb.DialStatus_E_DIAL_BACK_ERROR: + if ac.areAddrsConsistent(dialBackAddr, addr) { + // We received the dial back but the server claims the dial back errored. + // As long as we received the correct nonce in dial back it is safe to assume + // that we are public. + rch = network.ReachabilityPublic + } else { + rch = network.ReachabilityUnknown + } + default: + // Unexpected response code. Discard the response and fail. + log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) + return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus) + } + + return Result{ + Addr: addr, + Reachability: rch, + Status: resp.DialStatus, + }, nil +} + +func sendDialData(dialData []byte, numBytes int, w pbio.Writer, msg *pb.Message) (err error) { + ddResp := &pb.DialDataResponse{Data: dialData} + *msg = pb.Message{ + Msg: &pb.Message_DialDataResponse{ + DialDataResponse: ddResp, + }, + } + for remain := numBytes; remain > 0; { + if remain < len(ddResp.Data) { + ddResp.Data = ddResp.Data[:remain] + } + if err := w.WriteMsg(msg); err != nil { + return fmt.Errorf("write failed: %w", err) + } + remain -= len(dialData) + } + return nil +} + +func newDialRequest(reqs []Request, nonce uint64) pb.Message { + addrbs := make([][]byte, len(reqs)) + for i, r := range reqs { + addrbs[i] = r.Addr.Bytes() + } + return pb.Message{ + Msg: &pb.Message_DialRequest{ + DialRequest: &pb.DialRequest{ + Addrs: addrbs, + Nonce: nonce, + }, + }, + } +} + +// handleDialBack receives the nonce on the dial-back stream +func (ac *client) handleDialBack(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + s.Reset() + return + } + + if err := s.Scope().ReserveMemory(dialBackMaxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("failed to reserve memory for stream %s: %w", DialBackProtocol, err) + s.Reset() + return + } + defer s.Scope().ReleaseMemory(dialBackMaxMsgSize) + + s.SetDeadline(time.Now().Add(dialBackStreamTimeout)) + defer s.Close() + + r := pbio.NewDelimitedReader(s, dialBackMaxMsgSize) + var msg pb.DialBack + if err := r.ReadMsg(&msg); err != nil { + log.Debugf("failed to read dialback msg from %s: %s", s.Conn().RemotePeer(), err) + s.Reset() + return + } + nonce := msg.GetNonce() + + ac.mu.Lock() + ch := ac.dialBackQueues[nonce] + ac.mu.Unlock() + if ch == nil { + log.Debugf("dialback received with invalid nonce: localAdds: %s peer: %s nonce: %d", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer(), nonce) + s.Reset() + return + } + select { + case ch <- s.Conn().LocalMultiaddr(): + default: + log.Debugf("multiple dialbacks received: localAddr: %s peer: %s", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer()) + s.Reset() + return + } + w := pbio.NewDelimitedWriter(s) + res := pb.DialBackResponse{} + if err := w.WriteMsg(&res); err != nil { + log.Debugf("failed to write dialback response: %s", err) + s.Reset() + } +} + +func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool { + if connLocalAddr == nil || dialedAddr == nil { + return false + } + connLocalAddr = ac.normalizeMultiaddr(connLocalAddr) + dialedAddr = ac.normalizeMultiaddr(dialedAddr) + + localProtos := connLocalAddr.Protocols() + externalProtos := dialedAddr.Protocols() + if len(localProtos) != len(externalProtos) { + return false + } + for i := 0; i < len(localProtos); i++ { + if i == 0 { + switch externalProtos[i].Code { + case ma.P_DNS, ma.P_DNSADDR: + if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 { + continue + } + return false + case ma.P_DNS4: + if localProtos[i].Code == ma.P_IP4 { + continue + } + return false + case ma.P_DNS6: + if localProtos[i].Code == ma.P_IP6 { + continue + } + return false + } + if localProtos[i].Code != externalProtos[i].Code { + return false + } + } else { + if localProtos[i].Code != externalProtos[i].Code { + return false + } + } + } + return true +} diff --git a/p2p/protocol/autonatv2/msg_reader.go b/p2p/protocol/autonatv2/msg_reader.go new file mode 100644 index 0000000000..87849a55ac --- /dev/null +++ b/p2p/protocol/autonatv2/msg_reader.go @@ -0,0 +1,38 @@ +package autonatv2 + +import ( + "io" + + "github.com/multiformats/go-varint" +) + +// msgReader reads a varint prefixed message from R without any buffering +type msgReader struct { + R io.Reader + Buf []byte +} + +func (m *msgReader) ReadByte() (byte, error) { + buf := m.Buf[:1] + _, err := m.R.Read(buf) + return buf[0], err +} + +func (m *msgReader) ReadMsg() ([]byte, error) { + sz, err := varint.ReadUvarint(m) + if err != nil { + return nil, err + } + if sz > uint64(len(m.Buf)) { + return nil, io.ErrShortBuffer + } + n := 0 + for n < int(sz) { + nr, err := m.R.Read(m.Buf[n:sz]) + if err != nil { + return nil, err + } + n += nr + } + return m.Buf[:sz], nil +} diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go new file mode 100644 index 0000000000..5cdb0fe928 --- /dev/null +++ b/p2p/protocol/autonatv2/options.go @@ -0,0 +1,56 @@ +package autonatv2 + +import "time" + +// autoNATSettings is used to configure AutoNAT +type autoNATSettings struct { + allowPrivateAddrs bool + serverRPM int + serverPerPeerRPM int + serverDialDataRPM int + dataRequestPolicy dataRequestPolicyFunc + now func() time.Time + amplificatonAttackPreventionDialWait time.Duration +} + +func defaultSettings() *autoNATSettings { + return &autoNATSettings{ + allowPrivateAddrs: false, + serverRPM: 60, // 1 every second + serverPerPeerRPM: 12, // 1 every 5 seconds + serverDialDataRPM: 12, // 1 every 5 seconds + dataRequestPolicy: amplificationAttackPrevention, + amplificatonAttackPreventionDialWait: 3 * time.Second, + now: time.Now, + } +} + +type AutoNATOption func(s *autoNATSettings) error + +func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption { + return func(s *autoNATSettings) error { + s.serverRPM = rpm + s.serverPerPeerRPM = perPeerRPM + s.serverDialDataRPM = dialDataRPM + return nil + } +} + +func withDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { + return func(s *autoNATSettings) error { + s.dataRequestPolicy = drp + return nil + } +} + +func allowPrivateAddrs(s *autoNATSettings) error { + s.allowPrivateAddrs = true + return nil +} + +func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption { + return func(s *autoNATSettings) error { + s.amplificatonAttackPreventionDialWait = d + return nil + } +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.pb.go b/p2p/protocol/autonatv2/pb/autonatv2.pb.go new file mode 100644 index 0000000000..5c3ea8089f --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.pb.go @@ -0,0 +1,818 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v4.25.3 +// source: pb/autonatv2.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DialStatus int32 + +const ( + DialStatus_UNUSED DialStatus = 0 + DialStatus_E_DIAL_ERROR DialStatus = 100 + DialStatus_E_DIAL_BACK_ERROR DialStatus = 101 + DialStatus_OK DialStatus = 200 +) + +// Enum value maps for DialStatus. +var ( + DialStatus_name = map[int32]string{ + 0: "UNUSED", + 100: "E_DIAL_ERROR", + 101: "E_DIAL_BACK_ERROR", + 200: "OK", + } + DialStatus_value = map[string]int32{ + "UNUSED": 0, + "E_DIAL_ERROR": 100, + "E_DIAL_BACK_ERROR": 101, + "OK": 200, + } +) + +func (x DialStatus) Enum() *DialStatus { + p := new(DialStatus) + *p = x + return p +} + +func (x DialStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[0].Descriptor() +} + +func (DialStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[0] +} + +func (x DialStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialStatus.Descriptor instead. +func (DialStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +type DialResponse_ResponseStatus int32 + +const ( + DialResponse_E_INTERNAL_ERROR DialResponse_ResponseStatus = 0 + DialResponse_E_REQUEST_REJECTED DialResponse_ResponseStatus = 100 + DialResponse_E_DIAL_REFUSED DialResponse_ResponseStatus = 101 + DialResponse_OK DialResponse_ResponseStatus = 200 +) + +// Enum value maps for DialResponse_ResponseStatus. +var ( + DialResponse_ResponseStatus_name = map[int32]string{ + 0: "E_INTERNAL_ERROR", + 100: "E_REQUEST_REJECTED", + 101: "E_DIAL_REFUSED", + 200: "OK", + } + DialResponse_ResponseStatus_value = map[string]int32{ + "E_INTERNAL_ERROR": 0, + "E_REQUEST_REJECTED": 100, + "E_DIAL_REFUSED": 101, + "OK": 200, + } +) + +func (x DialResponse_ResponseStatus) Enum() *DialResponse_ResponseStatus { + p := new(DialResponse_ResponseStatus) + *p = x + return p +} + +func (x DialResponse_ResponseStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialResponse_ResponseStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[1].Descriptor() +} + +func (DialResponse_ResponseStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[1] +} + +func (x DialResponse_ResponseStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialResponse_ResponseStatus.Descriptor instead. +func (DialResponse_ResponseStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3, 0} +} + +type DialBackResponse_DialBackStatus int32 + +const ( + DialBackResponse_OK DialBackResponse_DialBackStatus = 0 +) + +// Enum value maps for DialBackResponse_DialBackStatus. +var ( + DialBackResponse_DialBackStatus_name = map[int32]string{ + 0: "OK", + } + DialBackResponse_DialBackStatus_value = map[string]int32{ + "OK": 0, + } +) + +func (x DialBackResponse_DialBackStatus) Enum() *DialBackResponse_DialBackStatus { + p := new(DialBackResponse_DialBackStatus) + *p = x + return p +} + +func (x DialBackResponse_DialBackStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialBackResponse_DialBackStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[2].Descriptor() +} + +func (DialBackResponse_DialBackStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[2] +} + +func (x DialBackResponse_DialBackStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialBackResponse_DialBackStatus.Descriptor instead. +func (DialBackResponse_DialBackStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{6, 0} +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Msg: + // + // *Message_DialRequest + // *Message_DialResponse + // *Message_DialDataRequest + // *Message_DialDataResponse + Msg isMessage_Msg `protobuf_oneof:"msg"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +func (m *Message) GetMsg() isMessage_Msg { + if m != nil { + return m.Msg + } + return nil +} + +func (x *Message) GetDialRequest() *DialRequest { + if x, ok := x.GetMsg().(*Message_DialRequest); ok { + return x.DialRequest + } + return nil +} + +func (x *Message) GetDialResponse() *DialResponse { + if x, ok := x.GetMsg().(*Message_DialResponse); ok { + return x.DialResponse + } + return nil +} + +func (x *Message) GetDialDataRequest() *DialDataRequest { + if x, ok := x.GetMsg().(*Message_DialDataRequest); ok { + return x.DialDataRequest + } + return nil +} + +func (x *Message) GetDialDataResponse() *DialDataResponse { + if x, ok := x.GetMsg().(*Message_DialDataResponse); ok { + return x.DialDataResponse + } + return nil +} + +type isMessage_Msg interface { + isMessage_Msg() +} + +type Message_DialRequest struct { + DialRequest *DialRequest `protobuf:"bytes,1,opt,name=dialRequest,proto3,oneof"` +} + +type Message_DialResponse struct { + DialResponse *DialResponse `protobuf:"bytes,2,opt,name=dialResponse,proto3,oneof"` +} + +type Message_DialDataRequest struct { + DialDataRequest *DialDataRequest `protobuf:"bytes,3,opt,name=dialDataRequest,proto3,oneof"` +} + +type Message_DialDataResponse struct { + DialDataResponse *DialDataResponse `protobuf:"bytes,4,opt,name=dialDataResponse,proto3,oneof"` +} + +func (*Message_DialRequest) isMessage_Msg() {} + +func (*Message_DialResponse) isMessage_Msg() {} + +func (*Message_DialDataRequest) isMessage_Msg() {} + +func (*Message_DialDataResponse) isMessage_Msg() {} + +type DialRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Addrs [][]byte `protobuf:"bytes,1,rep,name=addrs,proto3" json:"addrs,omitempty"` + Nonce uint64 `protobuf:"fixed64,2,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialRequest) Reset() { + *x = DialRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialRequest) ProtoMessage() {} + +func (x *DialRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialRequest.ProtoReflect.Descriptor instead. +func (*DialRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{1} +} + +func (x *DialRequest) GetAddrs() [][]byte { + if x != nil { + return x.Addrs + } + return nil +} + +func (x *DialRequest) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +type DialDataRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AddrIdx uint32 `protobuf:"varint,1,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + NumBytes uint64 `protobuf:"varint,2,opt,name=numBytes,proto3" json:"numBytes,omitempty"` +} + +func (x *DialDataRequest) Reset() { + *x = DialDataRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataRequest) ProtoMessage() {} + +func (x *DialDataRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataRequest.ProtoReflect.Descriptor instead. +func (*DialDataRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{2} +} + +func (x *DialDataRequest) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialDataRequest) GetNumBytes() uint64 { + if x != nil { + return x.NumBytes + } + return 0 +} + +type DialResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status DialResponse_ResponseStatus `protobuf:"varint,1,opt,name=status,proto3,enum=autonatv2.pb.DialResponse_ResponseStatus" json:"status,omitempty"` + AddrIdx uint32 `protobuf:"varint,2,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + DialStatus DialStatus `protobuf:"varint,3,opt,name=dialStatus,proto3,enum=autonatv2.pb.DialStatus" json:"dialStatus,omitempty"` +} + +func (x *DialResponse) Reset() { + *x = DialResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialResponse) ProtoMessage() {} + +func (x *DialResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialResponse.ProtoReflect.Descriptor instead. +func (*DialResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3} +} + +func (x *DialResponse) GetStatus() DialResponse_ResponseStatus { + if x != nil { + return x.Status + } + return DialResponse_E_INTERNAL_ERROR +} + +func (x *DialResponse) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialResponse) GetDialStatus() DialStatus { + if x != nil { + return x.DialStatus + } + return DialStatus_UNUSED +} + +type DialDataResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *DialDataResponse) Reset() { + *x = DialDataResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataResponse) ProtoMessage() {} + +func (x *DialDataResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataResponse.ProtoReflect.Descriptor instead. +func (*DialDataResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{4} +} + +func (x *DialDataResponse) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type DialBack struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Nonce uint64 `protobuf:"fixed64,1,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialBack) Reset() { + *x = DialBack{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialBack) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialBack) ProtoMessage() {} + +func (x *DialBack) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialBack.ProtoReflect.Descriptor instead. +func (*DialBack) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{5} +} + +func (x *DialBack) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +type DialBackResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status DialBackResponse_DialBackStatus `protobuf:"varint,1,opt,name=status,proto3,enum=autonatv2.pb.DialBackResponse_DialBackStatus" json:"status,omitempty"` +} + +func (x *DialBackResponse) Reset() { + *x = DialBackResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialBackResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialBackResponse) ProtoMessage() {} + +func (x *DialBackResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialBackResponse.ProtoReflect.Descriptor instead. +func (*DialBackResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{6} +} + +func (x *DialBackResponse) GetStatus() DialBackResponse_DialBackStatus { + if x != nil { + return x.Status + } + return DialBackResponse_OK +} + +var File_pb_autonatv2_proto protoreflect.FileDescriptor + +var file_pb_autonatv2_proto_rawDesc = []byte{ + 0x0a, 0x12, 0x70, 0x62, 0x2f, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x22, 0xaa, 0x02, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3d, + 0x0a, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, + 0x52, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x40, 0x0a, + 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, + 0x00, 0x52, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x49, 0x0a, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, + 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, + 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x4c, 0x0a, 0x10, 0x64, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x10, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, + 0x39, 0x0a, 0x0b, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x61, 0x64, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x05, 0x61, + 0x64, 0x64, 0x72, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x22, 0x47, 0x0a, 0x0f, 0x44, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, + 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, + 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x22, 0x82, 0x02, 0x0a, 0x0c, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, + 0x64, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, + 0x78, 0x12, 0x38, 0x0a, 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, + 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x5b, 0x0a, 0x0e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, + 0x10, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x45, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x45, 0x53, 0x54, + 0x5f, 0x52, 0x45, 0x4a, 0x45, 0x43, 0x54, 0x45, 0x44, 0x10, 0x64, 0x12, 0x12, 0x0a, 0x0e, 0x45, + 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x10, 0x65, 0x12, + 0x07, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0xc8, 0x01, 0x22, 0x26, 0x0a, 0x10, 0x44, 0x69, 0x61, 0x6c, + 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, + 0x22, 0x20, 0x0a, 0x08, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x12, 0x14, 0x0a, 0x05, + 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, + 0x63, 0x65, 0x22, 0x73, 0x0a, 0x10, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x45, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2d, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, + 0x76, 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x18, 0x0a, + 0x0e, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x06, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0x00, 0x2a, 0x4a, 0x0a, 0x0a, 0x44, 0x69, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x0a, 0x0a, 0x06, 0x55, 0x4e, 0x55, 0x53, 0x45, 0x44, 0x10, + 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x64, 0x12, 0x15, 0x0a, 0x11, 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x42, 0x41, + 0x43, 0x4b, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x65, 0x12, 0x07, 0x0a, 0x02, 0x4f, 0x4b, + 0x10, 0xc8, 0x01, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pb_autonatv2_proto_rawDescOnce sync.Once + file_pb_autonatv2_proto_rawDescData = file_pb_autonatv2_proto_rawDesc +) + +func file_pb_autonatv2_proto_rawDescGZIP() []byte { + file_pb_autonatv2_proto_rawDescOnce.Do(func() { + file_pb_autonatv2_proto_rawDescData = protoimpl.X.CompressGZIP(file_pb_autonatv2_proto_rawDescData) + }) + return file_pb_autonatv2_proto_rawDescData +} + +var file_pb_autonatv2_proto_enumTypes = make([]protoimpl.EnumInfo, 3) +var file_pb_autonatv2_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_pb_autonatv2_proto_goTypes = []interface{}{ + (DialStatus)(0), // 0: autonatv2.pb.DialStatus + (DialResponse_ResponseStatus)(0), // 1: autonatv2.pb.DialResponse.ResponseStatus + (DialBackResponse_DialBackStatus)(0), // 2: autonatv2.pb.DialBackResponse.DialBackStatus + (*Message)(nil), // 3: autonatv2.pb.Message + (*DialRequest)(nil), // 4: autonatv2.pb.DialRequest + (*DialDataRequest)(nil), // 5: autonatv2.pb.DialDataRequest + (*DialResponse)(nil), // 6: autonatv2.pb.DialResponse + (*DialDataResponse)(nil), // 7: autonatv2.pb.DialDataResponse + (*DialBack)(nil), // 8: autonatv2.pb.DialBack + (*DialBackResponse)(nil), // 9: autonatv2.pb.DialBackResponse +} +var file_pb_autonatv2_proto_depIdxs = []int32{ + 4, // 0: autonatv2.pb.Message.dialRequest:type_name -> autonatv2.pb.DialRequest + 6, // 1: autonatv2.pb.Message.dialResponse:type_name -> autonatv2.pb.DialResponse + 5, // 2: autonatv2.pb.Message.dialDataRequest:type_name -> autonatv2.pb.DialDataRequest + 7, // 3: autonatv2.pb.Message.dialDataResponse:type_name -> autonatv2.pb.DialDataResponse + 1, // 4: autonatv2.pb.DialResponse.status:type_name -> autonatv2.pb.DialResponse.ResponseStatus + 0, // 5: autonatv2.pb.DialResponse.dialStatus:type_name -> autonatv2.pb.DialStatus + 2, // 6: autonatv2.pb.DialBackResponse.status:type_name -> autonatv2.pb.DialBackResponse.DialBackStatus + 7, // [7:7] is the sub-list for method output_type + 7, // [7:7] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_pb_autonatv2_proto_init() } +func file_pb_autonatv2_proto_init() { + if File_pb_autonatv2_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pb_autonatv2_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialBack); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialBackResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pb_autonatv2_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*Message_DialRequest)(nil), + (*Message_DialResponse)(nil), + (*Message_DialDataRequest)(nil), + (*Message_DialDataResponse)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pb_autonatv2_proto_rawDesc, + NumEnums: 3, + NumMessages: 7, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pb_autonatv2_proto_goTypes, + DependencyIndexes: file_pb_autonatv2_proto_depIdxs, + EnumInfos: file_pb_autonatv2_proto_enumTypes, + MessageInfos: file_pb_autonatv2_proto_msgTypes, + }.Build() + File_pb_autonatv2_proto = out.File + file_pb_autonatv2_proto_rawDesc = nil + file_pb_autonatv2_proto_goTypes = nil + file_pb_autonatv2_proto_depIdxs = nil +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.proto b/p2p/protocol/autonatv2/pb/autonatv2.proto new file mode 100644 index 0000000000..64dca1138f --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package autonatv2.pb; + +message Message { + oneof msg { + DialRequest dialRequest = 1; + DialResponse dialResponse = 2; + DialDataRequest dialDataRequest = 3; + DialDataResponse dialDataResponse = 4; + } +} + +message DialRequest { + repeated bytes addrs = 1; + fixed64 nonce = 2; +} + + +message DialDataRequest { + uint32 addrIdx = 1; + uint64 numBytes = 2; +} + + +enum DialStatus { + UNUSED = 0; + E_DIAL_ERROR = 100; + E_DIAL_BACK_ERROR = 101; + OK = 200; +} + + +message DialResponse { + enum ResponseStatus { + E_INTERNAL_ERROR = 0; + E_REQUEST_REJECTED = 100; + E_DIAL_REFUSED = 101; + OK = 200; + } + + ResponseStatus status = 1; + uint32 addrIdx = 2; + DialStatus dialStatus = 3; +} + + +message DialDataResponse { + bytes data = 1; +} + + +message DialBack { + fixed64 nonce = 1; +} + + +message DialBackResponse { + enum DialBackStatus { + OK = 0; + } + + DialBackStatus status = 1; +} \ No newline at end of file diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go new file mode 100644 index 0000000000..12a6fa7860 --- /dev/null +++ b/p2p/protocol/autonatv2/server.go @@ -0,0 +1,449 @@ +package autonatv2 + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + pool "github.com/libp2p/go-buffer-pool" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + + "math/rand" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool + +// server implements the AutoNATv2 server. +// It can ask client to provide dial data before attempting the requested dial. +// It rate limits requests on a global level, per peer level and on whether the request requires dial data. +type server struct { + host host.Host + dialerHost host.Host + limiter *rateLimiter + + // dialDataRequestPolicy is used to determine whether dialing the address requires receiving + // dial data. It is set to amplification attack prevention by default. + dialDataRequestPolicy dataRequestPolicyFunc + amplificatonAttackPreventionDialWait time.Duration + + // for tests + now func() time.Time + allowPrivateAddrs bool +} + +func newServer(host, dialer host.Host, s *autoNATSettings) *server { + return &server{ + dialerHost: dialer, + host: host, + dialDataRequestPolicy: s.dataRequestPolicy, + amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait, + allowPrivateAddrs: s.allowPrivateAddrs, + limiter: &rateLimiter{ + RPM: s.serverRPM, + PerPeerRPM: s.serverPerPeerRPM, + DialDataRPM: s.serverDialDataRPM, + now: s.now, + }, + now: s.now, + } +} + +// Enable attaches the stream handler to the host. +func (as *server) Start() { + as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) +} + +func (as *server) Close() { + as.host.RemoveStreamHandler(DialProtocol) + as.dialerHost.Close() + as.limiter.Close() +} + +// handleDialRequest is the dial-request protocol stream handler +func (as *server) handleDialRequest(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + return + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + log.Debugf("failed to reserve memory for stream %s: %w", DialProtocol, err) + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + deadline := as.now().Add(streamTimeout) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + s.SetDeadline(as.now().Add(streamTimeout)) + defer s.Close() + + p := s.Conn().RemotePeer() + + var msg pb.Message + w := pbio.NewDelimitedWriter(s) + // Check for rate limit before parsing the request + if !as.limiter.Accept(p) { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_REQUEST_REJECTED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write request rejected response to %s: %s", p, err) + return + } + log.Debugf("rejected request from %s: rate limit exceeded", p) + return + } + defer as.limiter.CompleteRequest(p) + + r := pbio.NewDelimitedReader(s, maxMsgSize) + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to read request from %s: %s", p, err) + return + } + if msg.GetDialRequest() == nil { + s.Reset() + log.Debugf("invalid message type from %s: %T expected: DialRequest", p, msg.Msg) + return + } + + // parse peer's addresses + var dialAddr ma.Multiaddr + var addrIdx int + for i, ab := range msg.GetDialRequest().GetAddrs() { + if i >= maxPeerAddresses { + break + } + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + continue + } + if !as.allowPrivateAddrs && !manet.IsPublicAddr(a) { + continue + } + if !as.dialerHost.Network().CanDial(p, a) { + continue + } + dialAddr = a + addrIdx = i + break + } + // No dialable address + if dialAddr == nil { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write dial refused response to %s: %s", p, err) + return + } + return + } + + nonce := msg.GetDialRequest().Nonce + + isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + if isDialDataRequired && !as.limiter.AcceptDialDataRequest(p) { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_REQUEST_REJECTED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write request rejected response to %s: %s", p, err) + return + } + log.Debugf("rejected request from %s: rate limit exceeded", p) + return + } + + if isDialDataRequired { + if err := getDialData(w, s, &msg, addrIdx); err != nil { + s.Reset() + log.Debugf("%s refused dial data request: %s", p, err) + return + } + // wait for a bit to prevent thundering herd style attacks on a victim + waitTime := time.Duration(rand.Intn(int(as.amplificatonAttackPreventionDialWait) + 1)) // the range is [0, n) + t := time.NewTimer(waitTime) + defer t.Stop() + select { + case <-ctx.Done(): + s.Reset() + log.Debugf("rejecting request without dialing: %s %p ", p, ctx.Err()) + return + case <-t.C: + } + } + + dialStatus := as.dialBack(ctx, s.Conn().RemotePeer(), dialAddr, nonce) + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: dialStatus, + AddrIdx: uint32(addrIdx), + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } +} + +// getDialData gets data from the client for dialing the address +func getDialData(w pbio.Writer, s network.Stream, msg *pb.Message, addrIdx int) error { + numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes) + *msg = pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: uint32(addrIdx), + NumBytes: uint64(numBytes), + }, + }, + } + if err := w.WriteMsg(msg); err != nil { + return fmt.Errorf("dial data write: %w", err) + } + // pbio.Reader that we used so far on this stream is buffered. But at this point + // there is nothing unread on the stream. So it is safe to use the raw stream to + // read, reducing allocations. + return readDialData(numBytes, s) +} + +func readDialData(numBytes int, r io.Reader) error { + mr := &msgReader{R: r, Buf: pool.Get(maxMsgSize)} + defer pool.Put(mr.Buf) + for remain := numBytes; remain > 0; { + msg, err := mr.ReadMsg() + if err != nil { + return fmt.Errorf("dial data read: %w", err) + } + // protobuf format is: + // (oneof dialDataResponse:)(dial data:) + bytesLen := len(msg) + bytesLen -= 2 // fieldTag + varint first byte + if bytesLen > 127 { + bytesLen -= 1 // varint second byte + } + bytesLen -= 2 // second fieldTag + varint first byte + if bytesLen > 127 { + bytesLen -= 1 // varint second byte + } + if bytesLen > 0 { + remain -= bytesLen + } + // Check if the peer is not sending too little data forcing us to just do a lot of compute + if bytesLen < 100 && remain > 0 { + return fmt.Errorf("dial data msg too small: %d", bytesLen) + } + } + return nil +} + +func (as *server) dialBack(ctx context.Context, p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus { + ctx, cancel := context.WithTimeout(ctx, dialBackDialTimeout) + ctx = network.WithForceDirectDial(ctx, "autonatv2") + as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL) + defer func() { + cancel() + as.dialerHost.Network().ClosePeer(p) + as.dialerHost.Peerstore().ClearAddrs(p) + as.dialerHost.Peerstore().RemovePeer(p) + }() + + err := as.dialerHost.Connect(ctx, peer.AddrInfo{ID: p}) + if err != nil { + return pb.DialStatus_E_DIAL_ERROR + } + + s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol) + if err != nil { + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + defer s.Close() + s.SetDeadline(as.now().Add(dialBackStreamTimeout)) + + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + s.Reset() + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + // Since the underlying connection is on a separate dialer, it'll be closed after this + // function returns. Connection close will drop all the queued writes. To ensure message + // delivery, do a CloseWrite and read a byte from the stream. The peer actually sends a + // response of type DialBackResponse but we only care about the fact that the DialBack + // message has reached the peer. So we ignore that message on the read side. + s.CloseWrite() + s.SetDeadline(as.now().Add(5 * time.Second)) // 5 is a magic number + b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately + s.Read(b) + + return pb.DialStatus_OK +} + +// rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request +// per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data. +type rateLimiter struct { + // PerPeerRPM is the rate limit per peer + PerPeerRPM int + // RPM is the global rate limit + RPM int + // DialDataRPM is the rate limit for requests that require dial data + DialDataRPM int + + mu sync.Mutex + closed bool + reqs []entry + peerReqs map[peer.ID][]time.Time + dialDataReqs []time.Time + // ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the + // same peer + // TODO: Should we allow a few concurrent requests per peer? + ongoingReqs map[peer.ID]struct{} + + now func() time.Time // for tests +} + +type entry struct { + PeerID peer.ID + Time time.Time +} + +func (r *rateLimiter) Accept(p peer.ID) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return false + } + if r.peerReqs == nil { + r.peerReqs = make(map[peer.ID][]time.Time) + r.ongoingReqs = make(map[peer.ID]struct{}) + } + + nw := r.now() + r.cleanup(nw) + + if _, ok := r.ongoingReqs[p]; ok { + return false + } + if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM { + return false + } + + r.ongoingReqs[p] = struct{}{} + r.reqs = append(r.reqs, entry{PeerID: p, Time: nw}) + r.peerReqs[p] = append(r.peerReqs[p], nw) + return true +} + +func (r *rateLimiter) AcceptDialDataRequest(p peer.ID) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return false + } + if r.peerReqs == nil { + r.peerReqs = make(map[peer.ID][]time.Time) + r.ongoingReqs = make(map[peer.ID]struct{}) + } + nw := r.now() + r.cleanup(nw) + if len(r.dialDataReqs) >= r.DialDataRPM { + return false + } + r.dialDataReqs = append(r.dialDataReqs, nw) + return true +} + +// cleanup removes stale requests. +// +// This is fast enough in rate limited cases and the state is small enough to +// clean up quickly when blocking requests. +func (r *rateLimiter) cleanup(now time.Time) { + idx := len(r.reqs) + for i, e := range r.reqs { + if now.Sub(e.Time) >= time.Minute { + pi := len(r.peerReqs[e.PeerID]) + for j, t := range r.peerReqs[e.PeerID] { + if now.Sub(t) < time.Minute { + pi = j + break + } + } + r.peerReqs[e.PeerID] = r.peerReqs[e.PeerID][pi:] + if len(r.peerReqs[e.PeerID]) == 0 { + delete(r.peerReqs, e.PeerID) + } + } else { + idx = i + break + } + } + r.reqs = r.reqs[idx:] + + idx = len(r.dialDataReqs) + for i, t := range r.dialDataReqs { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.dialDataReqs = r.dialDataReqs[idx:] +} + +func (r *rateLimiter) CompleteRequest(p peer.ID) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.ongoingReqs, p) +} + +func (r *rateLimiter) Close() { + r.mu.Lock() + defer r.mu.Unlock() + r.closed = true + r.peerReqs = nil + r.ongoingReqs = nil + r.dialDataReqs = nil +} + +// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed +// IP address is different from the dial back IP address +func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { + connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) + if err != nil { + return true + } + dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr + return !connIP.Equal(dialIP) +} diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go new file mode 100644 index 0000000000..0b40f27535 --- /dev/null +++ b/p2p/protocol/autonatv2/server_test.go @@ -0,0 +1,484 @@ +package autonatv2 + +import ( + "bytes" + "context" + "fmt" + "io" + "math" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/test" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-varint" + "github.com/stretchr/testify/require" +) + +func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) { + reqs = make([]Request, len(addrs)) + for i := 0; i < len(addrs); i++ { + reqs[i] = Request{Addr: addrs[i], SendDialData: sendDialData} + } + return +} + +func TestServerInvalidAddrsRejected(t *testing.T) { + c := newAutoNAT(t, nil, allowPrivateAddrs, withAmplificationAttackPreventionDialWait(0)) + defer c.Close() + defer c.host.Close() + + t.Run("no transport", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("black holed addr", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm( + t, swarmt.WithSwarmOpts(swarm.WithReadOnlyBlackHoleDetector()))) + an := newAutoNAT(t, dialer) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), + []Request{{ + Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"), + SendDialData: true, + }}) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("private addrs", func(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("relay addrs", func(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), newTestRequests( + []ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("no addr", func(t *testing.T) { + _, err := c.GetReachability(context.Background(), nil) + require.Error(t, err) + }) + + t.Run("too many address", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + var addrs []ma.Multiaddr + for i := 0; i < 100; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 2000+i))) + } + addrs = append(addrs, c.host.Addrs()...) + // The dial should still fail because we have too many addresses that the server cannot dial + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("msg too large", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + var addrs []ma.Multiaddr + for i := 0; i < 10000; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 2000+i))) + } + addrs = append(addrs, c.host.Addrs()...) + // The dial should still fail because we have too many addresses that the server cannot dial + idAndWait(t, c, an) + + res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true)) + require.ErrorIs(t, err, network.ErrReset) + require.Equal(t, Result{}, res) + }) + +} + +func TestServerDataRequest(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( + func(s network.Stream, dialAddr ma.Multiaddr) bool { + if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return true + } + return false + }), + WithServerRateLimit(10, 10, 10), + withAmplificationAttackPreventionDialWait(0), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + quicAddr = a + } else if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + _, err := c.GetReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}}) + require.Error(t, err) + + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + require.NoError(t, err) + + require.Equal(t, Result{ + Addr: quicAddr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + + // Small messages should be rejected for dial data + c.cli.dialData = c.cli.dialData[:10] + _, err = c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + require.Error(t, err) +} +func TestServerDataRequestJitter(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( + func(s network.Stream, dialAddr ma.Multiaddr) bool { + if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return true + } + return false + }), + WithServerRateLimit(10, 10, 10), + withAmplificationAttackPreventionDialWait(5*time.Second), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + quicAddr = a + } else if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + for i := 0; i < 10; i++ { + st := time.Now() + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + took := time.Since(st) + require.NoError(t, err) + + require.Equal(t, Result{ + Addr: quicAddr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + if took > 500*time.Millisecond { + return + } + } + t.Fatalf("expected server to delay at least 1 dial") +} + +func TestServerDial(t *testing.T) { + an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowPrivateAddrs) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + unreachableAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") + hostAddrs := c.host.Addrs() + + t.Run("unreachable addr", func(t *testing.T) { + res, err := c.GetReachability(context.Background(), + append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: unreachableAddr, + Reachability: network.ReachabilityPrivate, + Status: pb.DialStatus_E_DIAL_ERROR, + }, res) + }) + + t.Run("reachable addr", func(t *testing.T) { + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: hostAddrs[0], + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + for _, addr := range c.host.Addrs() { + res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false)) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + } + }) + + t.Run("dialback error", func(t *testing.T) { + c.host.RemoveStreamHandler(DialBackProtocol) + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: hostAddrs[0], + Reachability: network.ReachabilityUnknown, + Status: pb.DialStatus_E_DIAL_BACK_ERROR, + }, res) + }) +} + +func TestRateLimiter(t *testing.T) { + cl := test.NewMockClock() + r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now} + + require.True(t, r.Accept("peer1")) + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1")) // first request is still active + r.CompleteRequest("peer1") + + require.True(t, r.Accept("peer1")) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1")) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer2")) + r.CompleteRequest("peer2") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer3")) + + cl.AdvanceBy(21 * time.Second) // first request expired + require.True(t, r.Accept("peer1")) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer3")) + r.CompleteRequest("peer3") + + cl.AdvanceBy(50 * time.Second) + require.True(t, r.Accept("peer3")) + r.CompleteRequest("peer3") + + cl.AdvanceBy(1 * time.Second) + require.False(t, r.Accept("peer3")) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer3")) +} + +func TestRateLimiterStress(t *testing.T) { + cl := test.NewMockClock() + for i := 0; i < 10; i++ { + r := rateLimiter{RPM: 20 + i, PerPeerRPM: 10 + i, DialDataRPM: i, now: cl.Now} + + peers := make([]peer.ID, 10+i) + for i := 0; i < len(peers); i++ { + peers[i] = peer.ID(fmt.Sprintf("peer-%d", i)) + } + peerSuccesses := make([]atomic.Int64, len(peers)) + var success, dialDataSuccesses atomic.Int64 + var wg sync.WaitGroup + for k := 0; k < 5; k++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 2*60; i++ { + for j, p := range peers { + if r.Accept(p) { + success.Add(1) + peerSuccesses[j].Add(1) + } + if r.AcceptDialDataRequest(p) { + dialDataSuccesses.Add(1) + } + r.CompleteRequest(p) + } + cl.AdvanceBy(time.Second) + } + }() + } + wg.Wait() + if int(success.Load()) > 10*r.RPM || int(success.Load()) < 9*r.RPM { + t.Fatalf("invalid successes, %d, expected %d-%d", success.Load(), 9*r.RPM, 10*r.RPM) + } + if int(dialDataSuccesses.Load()) > 10*r.DialDataRPM || int(dialDataSuccesses.Load()) < 9*r.DialDataRPM { + t.Fatalf("invalid dial data successes, %d expected %d-%d", dialDataSuccesses.Load(), 9*r.DialDataRPM, 10*r.DialDataRPM) + } + for i := range peerSuccesses { + // We cannot check the lower bound because some peers would be hitting the global rpm limit + if int(peerSuccesses[i].Load()) > 10*r.PerPeerRPM { + t.Fatalf("too many per peer successes, PerPeerRPM=%d", r.PerPeerRPM) + } + } + cl.AdvanceBy(1 * time.Minute) + require.True(t, r.Accept(peers[0])) + // Assert lengths to check that we are cleaning up correctly + require.Equal(t, len(r.reqs), 1) + require.Equal(t, len(r.peerReqs), 1) + require.Equal(t, len(r.peerReqs[peers[0]]), 1) + require.Equal(t, len(r.dialDataReqs), 0) + require.Equal(t, len(r.ongoingReqs), 1) + } +} + +func TestReadDialData(t *testing.T) { + for N := 30_000; N < 30_010; N++ { + for msgSize := 100; msgSize < 256; msgSize++ { + r, w := io.Pipe() + msg := &pb.Message{} + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + mw := pbio.NewDelimitedWriter(w) + err := sendDialData(make([]byte, msgSize), N, mw, msg) + if err != nil { + t.Error(err) + } + mw.Close() + }() + err := readDialData(N, r) + require.NoError(t, err) + wg.Wait() + } + + for msgSize := 1000; msgSize < 1256; msgSize++ { + r, w := io.Pipe() + msg := &pb.Message{} + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + mw := pbio.NewDelimitedWriter(w) + err := sendDialData(make([]byte, msgSize), N, mw, msg) + if err != nil { + t.Error(err) + } + mw.Close() + }() + err := readDialData(N, r) + require.NoError(t, err) + wg.Wait() + } + } +} + +func FuzzServerDialRequest(f *testing.F) { + a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32)) + c := newAutoNAT(f, nil) + idAndWait(f, c, a) + // reduce the streamTimeout before running this. TODO: fix this + f.Fuzz(func(t *testing.T, data []byte) { + s, err := c.host.NewStream(context.Background(), a.host.ID(), DialProtocol) + if err != nil { + t.Fatal(err) + } + s.SetDeadline(time.Now().Add(10 * time.Second)) + s.Write(data) + buf := make([]byte, 64) + s.Read(buf) // We only care that server didn't panic + s, err = c.host.NewStream(context.Background(), a.host.ID(), DialProtocol) + if err != nil { + t.Fatal(err) + } + + n := varint.PutUvarint(buf, uint64(len(data))) + s.SetDeadline(time.Now().Add(10 * time.Second)) + s.Write(buf[:n]) + s.Write(data) + s.Read(buf) // We only care that server didn't panic + s.Reset() + }) +} + +func FuzzReadDialData(f *testing.F) { + f.Fuzz(func(t *testing.T, numBytes int, data []byte) { + readDialData(numBytes, bytes.NewReader(data)) + }) +} + +func BenchmarkDialData(b *testing.B) { + b.ReportAllocs() + const N = 100_000 + streamBuffer := make([]byte, 2*N) + buf := bytes.NewBuffer(streamBuffer[:0]) + dialData := make([]byte, 4000) + msg := &pb.Message{} + w := pbio.NewDelimitedWriter(buf) + err := sendDialData(dialData, N, w, msg) + require.NoError(b, err) + dialDataBuf := buf.Bytes() + for i := 0; i < b.N; i++ { + err = readDialData(N, bytes.NewReader(dialDataBuf)) + require.NoError(b, err) + } +}