Skip to content

Commit

Permalink
webrtc: run onDone callback immediately on close
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Mar 9, 2024
1 parent 91e1025 commit c192570
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 74 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ require (
github.com/pion/datachannel v1.5.5
github.com/pion/ice/v2 v2.3.11
github.com/pion/logging v0.2.2
github.com/pion/sctp v1.8.9
github.com/pion/stun v0.6.1
github.com/pion/webrtc/v3 v3.2.23
github.com/prometheus/client_golang v1.18.0
Expand Down Expand Up @@ -105,7 +106,6 @@ require (
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.13 // indirect
github.com/pion/rtp v1.8.3 // indirect
github.com/pion/sctp v1.8.9 // indirect
github.com/pion/sdp/v3 v3.0.6 // indirect
github.com/pion/srtp/v2 v2.0.18 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
Expand Down
50 changes: 22 additions & 28 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ type stream struct {
// SetReadDeadline
// See: https://github.com/pion/sctp/pull/290
controlMessageReaderEndTime time.Time
controlMessageReaderDone sync.WaitGroup

onDoneOnce sync.Once
onDone func()
id uint16 // for logging purposes
dataChannel *datachannel.DataChannel
Expand All @@ -118,8 +118,6 @@ func newStream(
dataChannel: rwc.(*datachannel.DataChannel),
onDone: onDone,
}
// released when the controlMessageReader goroutine exits
s.controlMessageReaderDone.Add(1)
s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
s.dataChannel.OnBufferedAmountLow(func() {
s.notifyWriteStateChanged()
Expand All @@ -135,7 +133,7 @@ func (s *stream) Close() error {
if isClosed {
return nil
}

defer s.cleanup()
closeWriteErr := s.CloseWrite()
closeReadErr := s.CloseRead()
if closeWriteErr != nil || closeReadErr != nil {
Expand All @@ -147,10 +145,6 @@ func (s *stream) Close() error {
if s.controlMessageReaderEndTime.IsZero() {
s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait)
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
go func() {
s.controlMessageReaderDone.Wait()
s.cleanup()
}()
}
s.mx.Unlock()
return nil
Expand Down Expand Up @@ -227,17 +221,10 @@ func (s *stream) spawnControlMessageReader() {
s.controlMessageReaderOnce.Do(func() {
// Spawn a goroutine to ensure that we're not holding any locks
go func() {
defer s.controlMessageReaderDone.Done()
// cleanup the sctp deadline timer goroutine
defer s.setDataChannelReadDeadline(time.Time{})

setDeadline := func() bool {
if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) {
s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
return true
}
return false
}
defer s.dataChannel.Close()

// Unblock any Read call waiting on reader.ReadMsg
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
Expand All @@ -256,12 +243,22 @@ func (s *stream) spawnControlMessageReader() {
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
}
for s.closeForShutdownErr == nil &&
s.sendState != sendStateDataReceived && s.sendState != sendStateReset {
var msg pb.Message
if !setDeadline() {
var msg pb.Message
for {
// Connection closed. No need to cleanup the data channel.
if s.closeForShutdownErr != nil {
return
}
// Write half of the stream completed.
if s.sendState == sendStateDataReceived || s.sendState == sendStateReset {
return
}
// FIN_ACK wait deadling exceeded.
if !s.controlMessageReaderEndTime.IsZero() && time.Now().After(s.controlMessageReaderEndTime) {
return
}

s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
s.mx.Unlock()
err := s.reader.ReadMsg(&msg)
s.mx.Lock()
Expand All @@ -281,12 +278,9 @@ func (s *stream) spawnControlMessageReader() {
}

func (s *stream) cleanup() {
// Even if we close the datachannel pion keeps a reference to the datachannel around.
// Remove the onBufferedAmountLow callback to ensure that we at least garbage collect
// memory we allocated for this stream.
s.dataChannel.OnBufferedAmountLow(nil)
s.dataChannel.Close()
if s.onDone != nil {
s.onDone()
}
s.onDoneOnce.Do(func() {
if s.onDone != nil {
s.onDone()
}
})
}
109 changes: 72 additions & 37 deletions p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (

"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
"github.com/libp2p/go-msgio/pbio"
"google.golang.org/protobuf/proto"

"github.com/libp2p/go-libp2p/core/network"

"github.com/pion/datachannel"
"github.com/pion/sctp"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -98,6 +100,50 @@ func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) {
return <-answerChan, <-offerRWCChan
}

// checkDataChannelClosed checks if the datachannel has been closed.
// It sends empty messages on the data channel to check if the channel is still open.
// The control message reader goroutine depends on exclusive access to datachannel.Read
// so we have to depend on Write to determine whether the channel has been closed.
func checkDataChannelOpen(t *testing.T, dc *datachannel.DataChannel) {
t.Helper()
emptyMsg := &pb.Message{}
msg, err := proto.Marshal(emptyMsg)
if err != nil {
t.Fatal("unexpected mashalling error", err)
}
for i := 0; i < 3; i++ {
_, err := dc.Write(msg)
if err != nil {
t.Fatal("unexpected write err: ", err)
}
time.Sleep(50 * time.Millisecond)
}
}

// checkDataChannelClosed checks if the datachannel has been closed.
// It sends empty messages on the data channel to check if the channel has been closed.
// The control message reader goroutine depends on exclusive access to datachannel.Read
// so we have to depend on Write to determine whether the channel has been closed.
func checkDataChannelClosed(t *testing.T, dc *datachannel.DataChannel) {
t.Helper()
emptyMsg := &pb.Message{}
msg, err := proto.Marshal(emptyMsg)
if err != nil {
t.Fatal("unexpected mashalling error", err)
}
for i := 0; i < 5; i++ {
_, err := dc.Write(msg)
if err != nil {
if errors.Is(err, sctp.ErrStreamClosed) {
return
} else {
t.Fatal("unexpected write err: ", err)
}
}
time.Sleep(50 * time.Millisecond)
}
}

func TestStreamSimpleReadWriteClose(t *testing.T) {
client, server := getDetachedDataChannels(t)

Expand Down Expand Up @@ -357,27 +403,22 @@ func TestStreamCloseAfterFINACK(t *testing.T) {
serverStr := newStream(server.dc, server.rwc, func() {})

go func() {
done <- true
err := clientStr.Close()
assert.NoError(t, err)
}()
<-done

select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(200 * time.Millisecond):
t.Fatalf("Close should signal OnDone immediately")
}

// Reading FIN_ACK on server should trigger data channel close on the client
b := make([]byte, 1)
_, err := serverStr.Read(b)
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Errorf("Close should have completed")
}
checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half
Expand All @@ -400,8 +441,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) {

select {
case <-done:
t.Errorf("Close should not have completed without processing FIN_ACK")
case <-time.After(500 * time.Millisecond):
t.Errorf("Close should signal onDone immediately")
}

// serverStr has write half closed and read half open
Expand All @@ -410,11 +451,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) {
_, err := serverStr.Read(b)
require.NoError(t, err)
serverStr.Close() // Sends stop_sending, fin
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("Close should have completed")
}
checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

func TestStreamConcurrentClose(t *testing.T) {
Expand Down Expand Up @@ -446,26 +484,35 @@ func TestStreamConcurrentClose(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("concurrent close should succeed quickly")
}

// Wait for FIN_ACK AND datachannel close
checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))

}

func TestStreamResetAfterClose(t *testing.T) {
client, _ := getDetachedDataChannels(t)
client, server := getDetachedDataChannels(t)

done := make(chan bool, 2)
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
clientStr.Close()

select {
case <-done:
t.Fatalf("Close shouldn't run cleanup immediately")
case <-time.After(500 * time.Millisecond):
t.Fatalf("Close should run cleanup immediately")
}

// The server data channel should still be open
checkDataChannelOpen(t, server.rwc.(*datachannel.DataChannel))
clientStr.Reset()
// Reset closes the datachannels
checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("Reset should run callback immediately")
t.Fatalf("onDone should not be called twice")
case <-time.After(50 * time.Millisecond):
}
}

Expand All @@ -478,30 +525,18 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
clientStr.Close()

select {
case <-done:
t.Fatalf("Close shouldn't run cleanup immediately")
case <-time.After(500 * time.Millisecond):
t.Fatalf("Close should run cleanup immediately")
case <-done:
}

// sending FIN_ACK closes the datachannel
serverWriter := pbio.NewDelimitedWriter(server.rwc)
err := serverWriter.WriteMsg(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()})
require.NoError(t, err)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("Callback should be run on reading FIN_ACK")
}
b := make([]byte, 100)
N := 0
for {
n, err := server.rwc.Read(b)
N += n
if err != nil {
require.ErrorIs(t, err, io.EOF)
break
}
}
require.Less(t, N, 10)

checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

func TestStreamChunking(t *testing.T) {
Expand Down
14 changes: 6 additions & 8 deletions p2p/transport/webrtc/stream_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ func (s *stream) cancelWrite() error {
return nil
}
s.sendState = sendStateReset
// Remove reference to this stream from data channel
s.dataChannel.OnBufferedAmountLow(nil)
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil {
return err
}
return nil
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()})
}

func (s *stream) CloseWrite() error {
Expand All @@ -144,11 +143,10 @@ func (s *stream) CloseWrite() error {
return nil
}
s.sendState = sendStateDataSent
// Remove reference to this stream from data channel
s.dataChannel.OnBufferedAmountLow(nil)
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
return err
}
return nil
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()})
}

func (s *stream) notifyWriteStateChanged() {
Expand Down

0 comments on commit c192570

Please sign in to comment.