diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 135940e4ce..b4f06ef9f5 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -95,8 +95,8 @@ type stream struct { // SetReadDeadline // See: https://github.com/pion/sctp/pull/290 controlMessageReaderEndTime time.Time - controlMessageReaderDone sync.WaitGroup + onDoneOnce sync.Once onDone func() id uint16 // for logging purposes dataChannel *datachannel.DataChannel @@ -118,8 +118,6 @@ func newStream( dataChannel: rwc.(*datachannel.DataChannel), onDone: onDone, } - // released when the controlMessageReader goroutine exits - s.controlMessageReaderDone.Add(1) s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) s.dataChannel.OnBufferedAmountLow(func() { s.notifyWriteStateChanged() @@ -135,7 +133,7 @@ func (s *stream) Close() error { if isClosed { return nil } - + defer s.cleanup() closeWriteErr := s.CloseWrite() closeReadErr := s.CloseRead() if closeWriteErr != nil || closeReadErr != nil { @@ -147,10 +145,6 @@ func (s *stream) Close() error { if s.controlMessageReaderEndTime.IsZero() { s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait) s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) - go func() { - s.controlMessageReaderDone.Wait() - s.cleanup() - }() } s.mx.Unlock() return nil @@ -227,17 +221,10 @@ func (s *stream) spawnControlMessageReader() { s.controlMessageReaderOnce.Do(func() { // Spawn a goroutine to ensure that we're not holding any locks go func() { - defer s.controlMessageReaderDone.Done() // cleanup the sctp deadline timer goroutine defer s.setDataChannelReadDeadline(time.Time{}) - setDeadline := func() bool { - if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) { - s.setDataChannelReadDeadline(s.controlMessageReaderEndTime) - return true - } - return false - } + defer s.dataChannel.Close() // Unblock any Read call waiting on reader.ReadMsg s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) @@ -256,12 +243,22 @@ func (s *stream) spawnControlMessageReader() { s.processIncomingFlag(s.nextMessage.Flag) s.nextMessage = nil } - for s.closeForShutdownErr == nil && - s.sendState != sendStateDataReceived && s.sendState != sendStateReset { - var msg pb.Message - if !setDeadline() { + var msg pb.Message + for { + // Connection closed. No need to cleanup the data channel. + if s.closeForShutdownErr != nil { + return + } + // Write half of the stream completed. + if s.sendState == sendStateDataReceived || s.sendState == sendStateReset { + return + } + // FIN_ACK wait deadling exceeded. + if !s.controlMessageReaderEndTime.IsZero() && time.Now().After(s.controlMessageReaderEndTime) { return } + + s.setDataChannelReadDeadline(s.controlMessageReaderEndTime) s.mx.Unlock() err := s.reader.ReadMsg(&msg) s.mx.Lock() @@ -281,12 +278,9 @@ func (s *stream) spawnControlMessageReader() { } func (s *stream) cleanup() { - // Even if we close the datachannel pion keeps a reference to the datachannel around. - // Remove the onBufferedAmountLow callback to ensure that we at least garbage collect - // memory we allocated for this stream. - s.dataChannel.OnBufferedAmountLow(nil) - s.dataChannel.Close() - if s.onDone != nil { - s.onDone() - } + s.onDoneOnce.Do(func() { + if s.onDone != nil { + s.onDone() + } + }) } diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 52b464c0e4..514f91b5f1 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -11,10 +11,12 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" "github.com/libp2p/go-msgio/pbio" + "google.golang.org/protobuf/proto" "github.com/libp2p/go-libp2p/core/network" "github.com/pion/datachannel" + "github.com/pion/sctp" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -98,6 +100,44 @@ func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) { return <-answerChan, <-offerRWCChan } +func checkDataChannelOpen(t *testing.T, dc *datachannel.DataChannel) { + buf := make([]byte, 0) + dc.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + for { + _, err := dc.Read(buf) + if errors.Is(err, os.ErrDeadlineExceeded) { + break + } + if !errors.Is(err, io.ErrShortBuffer) { + t.Fatal("unexpected err", err) + } + } +} + +// checkDataChannelClosed checks if the datachannel has been closed. +// It sends empty messages on the data channel to check if the channel has been closed. +// The control message reader goroutine depends on exclusive access to datachannel.Read +// so we have to depend on Write to determine whether the channel has been closed. +func checkDataChannelClosed(t *testing.T, dc *datachannel.DataChannel) { + t.Helper() + emptyMsg := &pb.Message{} + msg, err := proto.Marshal(emptyMsg) + if err != nil { + t.Fatal("unexpected mashalling error", err) + } + for i := 0; i < 5; i++ { + _, err := dc.Write(msg) + if err != nil { + if errors.Is(err, sctp.ErrStreamClosed) { + return + } else { + t.Fatal("unexpected write err: ", err) + } + } + time.Sleep(50 * time.Millisecond) + } +} + func TestStreamSimpleReadWriteClose(t *testing.T) { client, server := getDetachedDataChannels(t) @@ -357,27 +397,22 @@ func TestStreamCloseAfterFINACK(t *testing.T) { serverStr := newStream(server.dc, server.rwc, func() {}) go func() { - done <- true err := clientStr.Close() assert.NoError(t, err) }() - <-done select { case <-done: - t.Fatalf("Close should not have completed without processing FIN_ACK") case <-time.After(200 * time.Millisecond): + t.Fatalf("Close should signal OnDone immediately") } + // Reading FIN_ACK on server should trigger data channel close on the client b := make([]byte, 1) _, err := serverStr.Read(b) require.Error(t, err) require.ErrorIs(t, err, io.EOF) - select { - case <-done: - case <-time.After(3 * time.Second): - t.Errorf("Close should have completed") - } + checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel)) } // TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half @@ -400,8 +435,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) { select { case <-done: - t.Errorf("Close should not have completed without processing FIN_ACK") case <-time.After(500 * time.Millisecond): + t.Errorf("Close should signal onDone immediately") } // serverStr has write half closed and read half open @@ -410,11 +445,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) { _, err := serverStr.Read(b) require.NoError(t, err) serverStr.Close() // Sends stop_sending, fin - select { - case <-done: - case <-time.After(5 * time.Second): - t.Fatalf("Close should have completed") - } + checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel)) + checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel)) } func TestStreamConcurrentClose(t *testing.T) { @@ -446,10 +478,15 @@ func TestStreamConcurrentClose(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("concurrent close should succeed quickly") } + + // Wait for FIN_ACK AND datachannel close + checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel)) + checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel)) + } func TestStreamResetAfterClose(t *testing.T) { - client, _ := getDetachedDataChannels(t) + client, server := getDetachedDataChannels(t) done := make(chan bool, 2) clientStr := newStream(client.dc, client.rwc, func() { done <- true }) @@ -457,15 +494,19 @@ func TestStreamResetAfterClose(t *testing.T) { select { case <-done: - t.Fatalf("Close shouldn't run cleanup immediately") case <-time.After(500 * time.Millisecond): + t.Fatalf("Close should run cleanup immediately") } - + // The server data channel should still be open + checkDataChannelOpen(t, server.rwc.(*datachannel.DataChannel)) clientStr.Reset() + // Reset closes the datachannels + checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel)) + checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel)) select { case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Reset should run callback immediately") + t.Fatalf("onDone should not be called twice") + case <-time.After(50 * time.Millisecond): } } @@ -478,30 +519,18 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) { clientStr.Close() select { - case <-done: - t.Fatalf("Close shouldn't run cleanup immediately") case <-time.After(500 * time.Millisecond): + t.Fatalf("Close should run cleanup immediately") + case <-done: } + // sending FIN_ACK closes the datachannel serverWriter := pbio.NewDelimitedWriter(server.rwc) err := serverWriter.WriteMsg(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}) require.NoError(t, err) - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Callback should be run on reading FIN_ACK") - } - b := make([]byte, 100) - N := 0 - for { - n, err := server.rwc.Read(b) - N += n - if err != nil { - require.ErrorIs(t, err, io.EOF) - break - } - } - require.Less(t, N, 10) + + checkDataChannelClosed(t, server.rwc.(*datachannel.DataChannel)) + checkDataChannelClosed(t, client.rwc.(*datachannel.DataChannel)) } func TestStreamChunking(t *testing.T) { diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 82d4ac287d..cacd605bb0 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -129,11 +129,10 @@ func (s *stream) cancelWrite() error { return nil } s.sendState = sendStateReset + // Remove reference to this stream from data channel + s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { - return err - } - return nil + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}) } func (s *stream) CloseWrite() error { @@ -144,11 +143,10 @@ func (s *stream) CloseWrite() error { return nil } s.sendState = sendStateDataSent + // Remove reference to this stream from data channel + s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { - return err - } - return nil + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}) } func (s *stream) notifyWriteStateChanged() {