diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index c3aa0fa046..def7c5da56 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "io" "net" "sync" @@ -15,6 +16,22 @@ import ( quicmetrics "github.com/quic-go/quic-go/metrics" ) +type QUICListener interface { + Accept(ctx context.Context) (quic.Connection, error) + Close() error + Addr() net.Addr +} + +var _ QUICListener = &quic.Listener{} + +type QUICTransport interface { + Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) + Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) + WriteTo(b []byte, addr net.Addr) (int, error) + ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) + io.Closer +} + type ConnManager struct { reuseUDP4 *reuse reuseUDP6 *reuse @@ -101,6 +118,32 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) { } } +// LendTransport is an advanced method used to lend an existing QUICTransport +// to the ConnManager. The ConnManager will close the returned channel when it +// is done with the transport, so that the owner may safely close the transport. +func (c *ConnManager) LendTransport(network string, tr QUICTransport, conn net.PacketConn) (<-chan struct{}, error) { + c.quicListenersMu.Lock() + defer c.quicListenersMu.Unlock() + + localAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, errors.New("expected a conn.LocalAddr() to return a *net.UDPAddr") + } + + refCountedTr := &refcountedTransport{ + QUICTransport: tr, + packetConn: conn, + borrowDoneSignal: make(chan struct{}), + } + + var reuse *reuse + reuse, err := c.getReuse(network) + if err != nil { + return nil, err + } + return refCountedTr.borrowDoneSignal, reuse.AddTransport(refCountedTr, localAddr) +} + func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) { return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease) } @@ -175,7 +218,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr ctx: ctx, ctxCancel: cancel, owningTransport: t, - tr: &t.Transport, + tr: t.QUICTransport, }, nil } return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set") @@ -201,10 +244,12 @@ func (c *ConnManager) transportForListen(association any, network string, laddr } return &singleOwnerTransport{ packetConn: conn, - Transport: quic.Transport{ - Conn: conn, - StatelessResetKey: &c.srk, - TokenGeneratorKey: &c.tokenKey, + Transport: &wrappedQUICTransport{ + &quic.Transport{ + Conn: conn, + StatelessResetKey: &c.srk, + TokenGeneratorKey: &c.tokenKey, + }, }, }, nil } @@ -279,7 +324,7 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s if err != nil { return nil, err } - return &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil + return &singleOwnerTransport{Transport: &wrappedQUICTransport{&quic.Transport{Conn: conn, StatelessResetKey: &c.srk}}, packetConn: conn}, nil } func (c *ConnManager) Protocols() []int { @@ -299,3 +344,11 @@ func (c *ConnManager) Close() error { func (c *ConnManager) ClientConfig() *quic.Config { return c.clientConfig } + +type wrappedQUICTransport struct { + *quic.Transport +} + +func (t *wrappedQUICTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) { + return t.Transport.Listen(tlsConf, conf) +} diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index 8e7da2cd7f..51646bac98 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -97,7 +97,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) { quicTr, err := cm.transportForListen(nil, netw, naddr) require.NoError(t, err) defer quicTr.Close() - if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok { + if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok { t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") } } @@ -156,7 +156,7 @@ func TestConnectionPassedToQUICForDialing(t *testing.T) { require.NoError(t, err, "dial error") defer quicTr.Close() - if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok { + if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok { t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") } } @@ -257,3 +257,61 @@ func testListener(t *testing.T, enableReuseport bool) { checkClosed(t, cm) } + +func TestExternalTransport(t *testing.T) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero}) + require.NoError(t, err) + defer conn.Close() + port := conn.LocalAddr().(*net.UDPAddr).Port + tr := &quic.Transport{Conn: conn} + defer tr.Close() + + cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) + require.NoError(t, err) + doneWithTr, err := cm.LendTransport("udp4", &wrappedQUICTransport{tr}, conn) + require.NoError(t, err) + + // make sure this transport is used when listening on the same port + ln, err := cm.ListenQUICAndAssociate( + "quic", + ma.StringCast(fmt.Sprintf("/ip4/0.0.0.0/udp/%d", port)), + &tls.Config{NextProtos: []string{"libp2p"}}, + func(quic.Connection, uint64) bool { return false }, + ) + require.NoError(t, err) + defer ln.Close() + require.Equal(t, port, ln.Addr().(*net.UDPAddr).Port) + + // make sure this transport is used when dialing out + udpLn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + defer udpLn.Close() + addrChan := make(chan net.Addr, 1) + go func() { + _, addr, _ := udpLn.ReadFrom(make([]byte, 2000)) + addrChan <- addr + }() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + _, err = cm.DialQUIC( + ctx, + ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", udpLn.LocalAddr().(*net.UDPAddr).Port)), + &tls.Config{NextProtos: []string{"libp2p"}}, + func(quic.Connection, uint64) bool { return false }, + ) + require.ErrorIs(t, err, context.DeadlineExceeded) + + select { + case addr := <-addrChan: + require.Equal(t, port, addr.(*net.UDPAddr).Port) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + cm.Close() + select { + case <-doneWithTr: + default: + t.Fatal("doneWithTr not closed") + } +} diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go index 4ee20042d3..42f1d00cef 100644 --- a/p2p/transport/quicreuse/listener.go +++ b/p2p/transport/quicreuse/listener.go @@ -29,7 +29,7 @@ type protoConf struct { } type quicListener struct { - l *quic.Listener + l QUICListener transport refCountedQuicTransport running chan struct{} addrs []ma.Multiaddr diff --git a/p2p/transport/quicreuse/nonquic_packetconn.go b/p2p/transport/quicreuse/nonquic_packetconn.go index 2f950e76a1..833bd5804a 100644 --- a/p2p/transport/quicreuse/nonquic_packetconn.go +++ b/p2p/transport/quicreuse/nonquic_packetconn.go @@ -4,8 +4,6 @@ import ( "context" "net" "time" - - "github.com/quic-go/quic-go" ) // nonQUICPacketConn is a net.PacketConn that can be used to read and write @@ -13,7 +11,7 @@ import ( // other transports like WebRTC. type nonQUICPacketConn struct { owningTransport refCountedQuicTransport - tr *quic.Transport + tr QUICTransport ctx context.Context ctxCancel context.CancelFunc readCtx context.Context @@ -32,7 +30,7 @@ func (n *nonQUICPacketConn) Close() error { // LocalAddr implements net.PacketConn. func (n *nonQUICPacketConn) LocalAddr() net.Addr { - return n.tr.Conn.LocalAddr() + return n.owningTransport.LocalAddr() } // ReadFrom implements net.PacketConn. diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index c6fc611331..e329ea49dd 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -3,6 +3,8 @@ package quicreuse import ( "context" "crypto/tls" + "errors" + "fmt" "net" "sync" "time" @@ -25,23 +27,30 @@ type refCountedQuicTransport interface { IncreaseCount() Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) - Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error) + Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) } type singleOwnerTransport struct { - quic.Transport + Transport QUICTransport // Used to write packets directly around QUIC. packetConn net.PacketConn } +var _ QUICTransport = &singleOwnerTransport{} + func (c *singleOwnerTransport) IncreaseCount() {} -func (c *singleOwnerTransport) DecreaseCount() { - c.Transport.Close() +func (c *singleOwnerTransport) DecreaseCount() { c.Transport.Close() } +func (c *singleOwnerTransport) LocalAddr() net.Addr { + return c.packetConn.LocalAddr() } -func (c *singleOwnerTransport) LocalAddr() net.Addr { - return c.Transport.Conn.LocalAddr() +func (c *singleOwnerTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) { + return c.Transport.Dial(ctx, addr, tlsConf, conf) +} + +func (c *singleOwnerTransport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { + return c.Transport.ReadNonQUICPacket(ctx, b) } func (c *singleOwnerTransport) Close() error { @@ -54,6 +63,10 @@ func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) { return c.Transport.WriteTo(b, addr) } +func (c *singleOwnerTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) { + return c.Transport.Listen(tlsConf, conf) +} + // Constant. Defined as variables to simplify testing. var ( garbageCollectInterval = 30 * time.Second @@ -61,7 +74,7 @@ var ( ) type refcountedTransport struct { - quic.Transport + QUICTransport // Used to write packets directly around QUIC. packetConn net.PacketConn @@ -70,6 +83,11 @@ type refcountedTransport struct { refCount int unusedSince time.Time + // Only set for transports we are borrowing. + // If set, we will _never_ close the underlying transport. We only close this + // channel to signal to the owner that we are done with it. + borrowDoneSignal chan struct{} + assocations map[any]struct{} } @@ -109,17 +127,24 @@ func (c *refcountedTransport) IncreaseCount() { } func (c *refcountedTransport) Close() error { - // TODO(when we drop support for go 1.19) use errors.Join - c.Transport.Close() - return c.packetConn.Close() + if c.borrowDoneSignal != nil { + close(c.borrowDoneSignal) + return nil + } + + return errors.Join(c.QUICTransport.Close(), c.packetConn.Close()) } func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) { - return c.Transport.WriteTo(b, addr) + return c.QUICTransport.WriteTo(b, addr) } func (c *refcountedTransport) LocalAddr() net.Addr { - return c.Transport.Conn.LocalAddr() + return c.packetConn.LocalAddr() +} + +func (c *refcountedTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) { + return c.QUICTransport.Listen(tlsConf, conf) } func (c *refcountedTransport) DecreaseCount() { @@ -302,15 +327,34 @@ func (r *reuse) transportForDialLocked(association any, network string, source * if err != nil { return nil, err } - tr := &refcountedTransport{Transport: quic.Transport{ - Conn: conn, - StatelessResetKey: r.statelessResetKey, - TokenGeneratorKey: r.tokenGeneratorKey, - }, packetConn: conn} + tr := &refcountedTransport{ + QUICTransport: &wrappedQUICTransport{ + Transport: &quic.Transport{ + Conn: conn, + StatelessResetKey: r.statelessResetKey, + TokenGeneratorKey: r.tokenGeneratorKey, + }, + }, + packetConn: conn, + } r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr return tr, nil } +func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if !laddr.IP.IsUnspecified() { + return errors.New("adding transport for specific IP not supported") + } + if _, ok := r.globalDialers[laddr.Port]; ok { + return fmt.Errorf("already have global dialer for port %d", laddr.Port) + } + r.globalDialers[laddr.Port] = tr + return nil +} + func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) { r.mutex.Lock() defer r.mutex.Unlock() @@ -351,9 +395,11 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun } localAddr := conn.LocalAddr().(*net.UDPAddr) tr := &refcountedTransport{ - Transport: quic.Transport{ - Conn: conn, - StatelessResetKey: r.statelessResetKey, + QUICTransport: &wrappedQUICTransport{ + Transport: &quic.Transport{ + Conn: conn, + StatelessResetKey: r.statelessResetKey, + }, }, packetConn: conn, }