From 56cae129e2db4ff37ae09049d2004abd2cde3326 Mon Sep 17 00:00:00 2001 From: Maelkum Date: Mon, 5 Feb 2024 23:32:25 +0100 Subject: [PATCH 1/5] Use reuseport for websocket Update tests Invert default state - reuseport is disabled Add a test for port reuse on dial --- p2p/transport/websocket/addrs_test.go | 3 +- p2p/transport/websocket/listener.go | 22 +++++--- p2p/transport/websocket/reuseport.go | 9 ++++ p2p/transport/websocket/websocket.go | 65 +++++++++++++++++++++-- p2p/transport/websocket/websocket_test.go | 48 +++++++++++++++++ 5 files changed, 135 insertions(+), 12 deletions(-) create mode 100644 p2p/transport/websocket/reuseport.go diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..d262eedbad 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,8 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + wt := &WebsocketTransport{} + ln, err := wt.newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index d7a1b885b8..d9e27a2e18 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -14,7 +14,7 @@ import ( ) type listener struct { - nl net.Listener + nl manet.Listener server http.Server // The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS, // so we can't rely on checking if server.TLSConfig is set. @@ -40,7 +40,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func (t *WebsocketTransport) newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -50,11 +50,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err + var nl manet.Listener + if !t.UseReuseport() { + nl, err = manet.Listen(a) + } else { + nl, err = t.reuse.Listen(a) + // Fallback to regular listener in case of an error. + if err != nil { + nl, err = manet.Listen(a) + } } - nl, err := net.Listen(lnet, lnaddr) if err != nil { return nil, err } @@ -88,10 +93,11 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { func (l *listener) serve() { defer close(l.closed) + list := manet.NetListener(l.nl) if !l.isWss { - l.server.Serve(l.nl) + l.server.Serve(list) } else { - l.server.ServeTLS(l.nl, "", "") + l.server.ServeTLS(list, "", "") } } diff --git a/p2p/transport/websocket/reuseport.go b/p2p/transport/websocket/reuseport.go new file mode 100644 index 0000000000..ea8bee7af8 --- /dev/null +++ b/p2p/transport/websocket/reuseport.go @@ -0,0 +1,9 @@ +package websocket + +import ( + "github.com/libp2p/go-reuseport" +) + +func reuseportIsAvailable() bool { + return reuseport.Available() +} diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 5142ca97a1..f3f4059916 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -80,6 +81,13 @@ func WithTLSConfig(conf *tls.Config) Option { } } +func EnableReuseport() Option { + return func(t *WebsocketTransport) error { + t.enableReuseport = true + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader @@ -87,6 +95,9 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + enableReuseport bool // Explicitly enable reuseport. + reuse reuseport.Transport } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -188,6 +199,32 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} + dialer.NetDialContext = func(ctx context.Context, network string, address string) (net.Conn, error) { + + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + maddr, err := manet.FromNetAddr(tcpAddr) + if err != nil { + return nil, err + } + + var conn manet.Conn + if t.UseReuseport() { + conn, err = t.reuse.DialContext(ctx, maddr) + } else { + var d manet.Dialer + conn, err = d.DialContext(ctx, maddr) + } + if err != nil { + return nil, err + } + + return conn, nil + } + if isWss { sni := "" sni, err = raddr.ValueForProtocol(ma.P_SNI) @@ -202,12 +239,29 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma ipAddr := wsurl.Host // Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution. // We set the `.Host` to the sni field so that the host header gets properly set. - dialer.NetDial = func(network, address string) (net.Conn, error) { + dialer.NetDialContext = func(ctx context.Context, network, address string) (net.Conn, error) { tcpAddr, err := net.ResolveTCPAddr(network, ipAddr) if err != nil { return nil, err } - return net.DialTCP("tcp", nil, tcpAddr) + + maddr, err := manet.FromNetAddr(tcpAddr) + if err != nil { + return nil, err + } + + var conn manet.Conn + if t.UseReuseport() { + conn, err = t.reuse.DialContext(ctx, maddr) + } else { + var d manet.Dialer + conn, err = d.DialContext(ctx, maddr) + } + if err != nil { + return nil, err + } + + return conn, nil } wsurl.Host = sni + ":" + wsurl.Port() } else { @@ -229,7 +283,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { - l, err := newListener(a, t.tlsConf) + l, err := t.newListener(a, t.tlsConf) if err != nil { return nil, err } @@ -244,3 +298,8 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) } return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil } + +// UseReuseport returns true if reuseport is enabled and available. +func (t *WebsocketTransport) UseReuseport() bool { + return t.enableReuseport && reuseportIsAvailable() +} diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 2023ee3528..ae45b2a40f 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -549,3 +549,51 @@ func TestResolveMultiaddr(t *testing.T) { }) } } + +func TestReusePortOnDial(t *testing.T) { + + // Create an endpoint that will accept connections. + // We'll use this to verify that the party initiating the connection reused port. + clientID, cu := newUpgrader(t) + client, err := New(cu, &network.NullResourceManager{}) + require.NoError(t, err) + + cliListen, err := client.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer cliListen.Close() + + // Create an endpoint that will initiate connection. + _, u := newUpgrader(t) + tpt, err := New(u, &network.NullResourceManager{}, EnableReuseport()) + require.NoError(t, err) + + // Start listening. + l, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer l.Close() + + // Take a note of the port on which we listen. This should be the address from which we dial too. + expectedAddr := l.Multiaddr() + + done := make(chan struct{}) + go func() { + defer close(done) + + conn, err := cliListen.Accept() + require.NoError(t, err) + defer conn.Close() + + remote := conn.RemoteMultiaddr() + require.Equal(t, expectedAddr, remote) + }() + + conn, err := tpt.Dial(context.Background(), cliListen.Multiaddr(), clientID) + require.NoError(t, err) + defer conn.Close() + + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + defer stream.Close() + + <-done +} From 794049ece621eb956dfdf0ad19fdb9f775d218ed Mon Sep 17 00:00:00 2001 From: Maelkum Date: Thu, 22 Feb 2024 20:34:43 +0100 Subject: [PATCH 2/5] Remove stream opening - not important for the test --- p2p/transport/websocket/websocket_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index ae45b2a40f..c99e5ae3cc 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -572,7 +572,7 @@ func TestReusePortOnDial(t *testing.T) { require.NoError(t, err) defer l.Close() - // Take a note of the port on which we listen. This should be the address from which we dial too. + // Take a note of the multiaddress on which we listen. This should be the address from which we dial too. expectedAddr := l.Multiaddr() done := make(chan struct{}) @@ -583,6 +583,7 @@ func TestReusePortOnDial(t *testing.T) { require.NoError(t, err) defer conn.Close() + // The meat of this test - verify that the connection was received from the same port as the listen port recorded above. remote := conn.RemoteMultiaddr() require.Equal(t, expectedAddr, remote) }() @@ -591,9 +592,5 @@ func TestReusePortOnDial(t *testing.T) { require.NoError(t, err) defer conn.Close() - stream, err := conn.OpenStream(context.Background()) - require.NoError(t, err) - defer stream.Close() - <-done } From b27549a77c6c67900fead4ab2d8e82ef11e52d38 Mon Sep 17 00:00:00 2001 From: Maelkum Date: Fri, 23 Feb 2024 17:44:21 +0100 Subject: [PATCH 3/5] Add test for port reuse on listen --- p2p/transport/websocket/websocket_test.go | 107 ++++++++++++++++++++-- 1 file changed, 98 insertions(+), 9 deletions(-) diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index c99e5ae3cc..43d93fa909 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -16,6 +16,7 @@ import ( "net" "net/http" "strings" + "sync" "testing" "time" @@ -32,6 +33,7 @@ import ( ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) @@ -554,13 +556,13 @@ func TestReusePortOnDial(t *testing.T) { // Create an endpoint that will accept connections. // We'll use this to verify that the party initiating the connection reused port. - clientID, cu := newUpgrader(t) - client, err := New(cu, &network.NullResourceManager{}) + serverID, cu := newUpgrader(t) + transport, err := New(cu, &network.NullResourceManager{}) require.NoError(t, err) - cliListen, err := client.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + server, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) require.NoError(t, err) - defer cliListen.Close() + defer server.Close() // Create an endpoint that will initiate connection. _, u := newUpgrader(t) @@ -568,18 +570,18 @@ func TestReusePortOnDial(t *testing.T) { require.NoError(t, err) // Start listening. - l, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + listener, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) require.NoError(t, err) - defer l.Close() + defer listener.Close() // Take a note of the multiaddress on which we listen. This should be the address from which we dial too. - expectedAddr := l.Multiaddr() + expectedAddr := listener.Multiaddr() done := make(chan struct{}) go func() { defer close(done) - conn, err := cliListen.Accept() + conn, err := server.Accept() require.NoError(t, err) defer conn.Close() @@ -588,9 +590,96 @@ func TestReusePortOnDial(t *testing.T) { require.Equal(t, expectedAddr, remote) }() - conn, err := tpt.Dial(context.Background(), cliListen.Multiaddr(), clientID) + conn, err := tpt.Dial(context.Background(), server.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() <-done } + +func TestReusePortOnListen(t *testing.T) { + + const ( + // how many connections we try to establish. + connectionCount = 20 + ) + + // Create an endpoint that will accept connections. + // We'll use this to verify that the party initiating the connection reused port. + _, cu := newUpgrader(t) + tpt, err := New(cu, &network.NullResourceManager{}, EnableReuseport()) + require.NoError(t, err) + + listener1, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + + // Get the port on which we should start the second listener + addr, ok := listener1.Addr().(*net.TCPAddr) + require.True(t, ok) + + port := addr.Port + listener2, err := tpt.maListen(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/ws", port))) + require.NoError(t, err) + + listeners := []manet.Listener{listener1, listener2} + + // Record which listener accepted how many connections. + requestCount := make(map[int]int) + var lock sync.Mutex + + var connsHandled sync.WaitGroup + connsHandled.Add(connectionCount) + for i, l := range listeners { + go func(index int, listener manet.Listener) { + for { + + conn, err := listener.Accept() + if err != nil { + // Stop condition - this happens when the listener is closed. + require.ErrorIs(t, err, transport.ErrListenerClosed) + return + } + defer conn.Close() + + connsHandled.Done() + + // Record which listener accepted the connection. + lock.Lock() + requestCount[index]++ + lock.Unlock() + } + }(i, l) + } + + _, u := newUpgrader(t) + tpt2, err := New(u, &network.NullResourceManager{}) + require.NoError(t, err) + + var dialers sync.WaitGroup + dialers.Add(connectionCount) + + for i := 0; i < connectionCount; i++ { + go func() { + defer dialers.Done() + conn, err := tpt2.maDial(context.Background(), listener1.Multiaddr()) + require.NoError(t, err) + defer conn.Close() + }() + } + + // Wait for all dialers to complete. + dialers.Wait() + + // Wait for listeners to complete their part. + connsHandled.Wait() + + // Cancel listeners to unblock any further pending accepts. + listener1.Close() + listener2.Close() + + require.NotZero(t, requestCount[0], "first listener accepted no connections") + require.NotZero(t, requestCount[1], "second listener accepted no connections") + + total := requestCount[0] + requestCount[1] + require.Equal(t, connectionCount, total, "not all requests were handled") +} From 7f6ce49e2f5b14191acbacc80e6a2e0020605c0d Mon Sep 17 00:00:00 2001 From: Maelkum Date: Sat, 24 Feb 2024 14:20:58 +0100 Subject: [PATCH 4/5] Fix flaky test --- p2p/transport/websocket/websocket_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 43d93fa909..208711c1ae 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -629,9 +629,10 @@ func TestReusePortOnListen(t *testing.T) { var connsHandled sync.WaitGroup connsHandled.Add(connectionCount) - for i, l := range listeners { - go func(index int, listener manet.Listener) { - for { + // For both listeners spin up goroutines to accept incoming connections. + for i, listener := range listeners { + for j := 0; j < connectionCount; j++ { + go func(index int, listener manet.Listener) { conn, err := listener.Accept() if err != nil { @@ -640,19 +641,18 @@ func TestReusePortOnListen(t *testing.T) { return } defer conn.Close() - connsHandled.Done() // Record which listener accepted the connection. lock.Lock() + defer lock.Unlock() requestCount[index]++ - lock.Unlock() - } - }(i, l) + }(i, listener) + } } - _, u := newUpgrader(t) - tpt2, err := New(u, &network.NullResourceManager{}) + // Create a different transport as you cannot self-dial using reuseport. + tpt2, err := New(cu, &network.NullResourceManager{}) require.NoError(t, err) var dialers sync.WaitGroup From 0a6f05102c6f526535b594e22954be7a1def4dcd Mon Sep 17 00:00:00 2001 From: Maelkum Date: Sat, 24 Feb 2024 16:53:40 +0100 Subject: [PATCH 5/5] Add comments documenting the behavior on Windows and Linux --- p2p/transport/websocket/websocket_test.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 208711c1ae..c26c8f83e5 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -15,6 +15,7 @@ import ( "math/big" "net" "net/http" + "runtime" "strings" "sync" "testing" @@ -677,8 +678,24 @@ func TestReusePortOnListen(t *testing.T) { listener1.Close() listener2.Close() - require.NotZero(t, requestCount[0], "first listener accepted no connections") - require.NotZero(t, requestCount[1], "second listener accepted no connections") + // For Windows we can't make any assumptions with regards to connection distribution: + // "Once the second socket has successfully bound, the behavior for all sockets bound to that port is indeterminate. + // For example, if all of the sockets on the same port provide TCP service, any incoming TCP connection requests over + // the port cannot be guaranteed to be handled by the correct socket — the behavior is non-deterministic." + // => https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + + // For MacOS (FreeBSD) it's the last socket to bind that receives the connections. Anegdotal evidence but: + // "Ironically it's the BSD semantics which support seamless server restarts. In my tests OS X's behavior (which I presume + // is identical to FreeBSD and other BSDs) is that the last socket to bind is the only one to receive new connections." + // => https://lwn.net/Articles/542629/ + // On FreeBSD it's the SO_REUSEPORT_LB variant that provides load balancing. + + // For Linux only - verify that both listeners handled some connections. + if runtime.GOOS == "linux" { + // We're not trying to verify an even distribution as it's not a perfect world. + require.NotZero(t, requestCount[0], "first listener accepted no connections") + require.NotZero(t, requestCount[1], "second listener accepted no connections") + } total := requestCount[0] + requestCount[1] require.Equal(t, connectionCount, total, "not all requests were handled")