From 07f78ef4ef88c4a187b9b7c769dcb6ae6c51bf10 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 4 Dec 2024 10:42:41 -0800 Subject: [PATCH] fix(sampledconn): Correctly handle slow bytes and closed conns --- .../internal/sampledconn/sampledconn_test.go | 101 ++++++++++++++++++ .../internal/sampledconn/sampledconn_unix.go | 24 ++--- 2 files changed, 110 insertions(+), 15 deletions(-) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go index d5b31009e2..e3b8c83a33 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -10,6 +10,7 @@ import ( manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSampledConn(t *testing.T) { @@ -76,3 +77,103 @@ func TestSampledConn(t *testing.T) { }) } } + +func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) { + serverConnChan := make(chan manet.Conn, 1) + + listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + assert.NoError(t, err) + defer listener.Close() + + serverAddr := listener.Multiaddr() + + // Server goroutine + go func() { + conn, err := listener.Accept() + assert.NoError(t, err) + serverConnChan <- conn + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + // Create a TCP client + clientConn, err = manet.Dial(serverAddr) + assert.NoError(t, err) + + return <-serverConnChan, clientConn +} + +func TestHandleNoBytes(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + defer clientConn.Close() + + // Server goroutine + go func() { + serverConn.Close() + }() + _, _, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.ErrorIs(t, err, io.EOF) +} + +func TestHandle1ByteAndClose(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + defer clientConn.Close() + + // Server goroutine + go func() { + defer serverConn.Close() + _, err := serverConn.Write([]byte("h")) + assert.NoError(t, err) + }() + _, _, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.ErrorIs(t, err, io.EOF) +} + +func TestSlowBytes(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + + interval := 100 * time.Millisecond + + // Server goroutine + go func() { + defer serverConn.Close() + + // Write < 3 bytes + time.Sleep(interval) + _, err := serverConn.Write([]byte("h")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("e")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("l")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("lo")) + assert.NoError(t, err) + }() + + defer clientConn.Close() + + err := clientConn.SetReadDeadline(time.Now().Add(interval * 10)) + require.NoError(t, err) + + peeked, clientConn, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.NoError(t, err) + assert.Equal(t, "hel", string(peeked[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(clientConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go index 9847e8d4be..7c7eaeb485 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go @@ -4,6 +4,7 @@ package sampledconn import ( "errors" + "io" "syscall" ) @@ -15,27 +16,20 @@ func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { return s, err } - readBytes := 0 var readErr error + var n int err = rawConn.Read(func(fd uintptr) bool { - for readBytes < peekSize { - var n int - n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK) - if errors.Is(readErr, syscall.EAGAIN) { - return false - } - if readErr != nil { - return true - } - readBytes += n - } - return true + n, _, readErr = syscall.Recvfrom(int(fd), s[:], syscall.MSG_PEEK|syscall.MSG_WAITALL) + return !errors.Is(readErr, syscall.EAGAIN) }) + if err != nil { + return s, err + } if readErr != nil { return s, readErr } - if err != nil { - return s, err + if n < peekSize { + return s, io.EOF } return s, nil