Skip to content

Commit

Permalink
fix(sampledconn): Correctly handle slow bytes and closed conns
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Dec 4, 2024
1 parent 9024f8e commit 07f78ef
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 15 deletions.
101 changes: 101 additions & 0 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
manet "github.com/multiformats/go-multiaddr/net"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSampledConn(t *testing.T) {
Expand Down Expand Up @@ -76,3 +77,103 @@ func TestSampledConn(t *testing.T) {
})
}
}

func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) {
serverConnChan := make(chan manet.Conn, 1)

listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
assert.NoError(t, err)
defer listener.Close()

serverAddr := listener.Multiaddr()

// Server goroutine
go func() {
conn, err := listener.Accept()
assert.NoError(t, err)
serverConnChan <- conn
}()

// Give the server a moment to start
time.Sleep(100 * time.Millisecond)

// Create a TCP client
clientConn, err = manet.Dial(serverAddr)
assert.NoError(t, err)

return <-serverConnChan, clientConn
}

func TestHandleNoBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
serverConn.Close()
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.EOF)
}

func TestHandle1ByteAndClose(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
defer serverConn.Close()
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.EOF)
}

func TestSlowBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)

interval := 100 * time.Millisecond

// Server goroutine
go func() {
defer serverConn.Close()

// Write < 3 bytes
time.Sleep(interval)
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("e"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("l"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("lo"))
assert.NoError(t, err)
}()

defer clientConn.Close()

err := clientConn.SetReadDeadline(time.Now().Add(interval * 10))
require.NoError(t, err)

peeked, clientConn, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.NoError(t, err)
assert.Equal(t, "hel", string(peeked[:]))

buf := make([]byte, 5)
_, err = io.ReadFull(clientConn, buf)
assert.NoError(t, err)
assert.Equal(t, "hello", string(buf))
}
24 changes: 9 additions & 15 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package sampledconn

import (
"errors"
"io"
"syscall"
)

Expand All @@ -15,27 +16,20 @@ func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
return s, err
}

readBytes := 0
var readErr error
var n int
err = rawConn.Read(func(fd uintptr) bool {
for readBytes < peekSize {
var n int
n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK)
if errors.Is(readErr, syscall.EAGAIN) {
return false
}
if readErr != nil {
return true
}
readBytes += n
}
return true
n, _, readErr = syscall.Recvfrom(int(fd), s[:], syscall.MSG_PEEK|syscall.MSG_WAITALL)
return !errors.Is(readErr, syscall.EAGAIN)
})
if err != nil {
return s, err
}
if readErr != nil {
return s, readErr
}
if err != nil {
return s, err
if n < peekSize {
return s, io.EOF
}

return s, nil
Expand Down

0 comments on commit 07f78ef

Please sign in to comment.