From c97dea0443d4ec9de0311d05fcb1f2287d2b372e Mon Sep 17 00:00:00 2001 From: Alex Potsides Date: Fri, 6 Oct 2023 13:00:43 +0100 Subject: [PATCH] fix: close webrtc streams without data loss (#2073) - Gracefully close streams on muxer shutdown - Refactor initiator/recipient flows for clarity - Wait for `bufferedAmount` to be `0` before closing a datachannel - Close datachannels on both initiator and recipient - Implements FIN_ACK for closing datachannels without data loss Supersedes #2048 --------- Co-authored-by: Chad Nehemiah --- packages/interface/package.json | 4 + packages/interface/src/stream-muxer/stream.ts | 68 ++++-- packages/interface/test/fixtures/logger.ts | 16 ++ .../test/stream-muxer/stream.spec.ts | 196 +++++++++++++++ packages/libp2p/src/connection/index.ts | 10 +- packages/libp2p/src/dcutr/dcutr.ts | 5 +- packages/libp2p/src/upgrader.ts | 7 +- packages/transport-webrtc/.aegir.js | 7 +- packages/transport-webrtc/package.json | 7 +- packages/transport-webrtc/src/index.ts | 34 +++ packages/transport-webrtc/src/maconn.ts | 9 +- packages/transport-webrtc/src/muxer.ts | 102 ++++---- .../transport-webrtc/src/pb/message.proto | 7 +- packages/transport-webrtc/src/pb/message.ts | 6 +- .../src/private-to-private/handler.ts | 177 -------------- .../private-to-private/initiate-connection.ts | 191 +++++++++++++++ .../src/private-to-private/listener.ts | 16 +- .../signaling-stream-handler.ts | 129 ++++++++++ .../src/private-to-private/transport.ts | 146 ++++++----- .../src/private-to-private/util.ts | 113 +++++++-- .../src/private-to-public/transport.ts | 8 +- packages/transport-webrtc/src/stream.ts | 222 ++++++++++++----- packages/transport-webrtc/src/util.ts | 60 +++++ packages/transport-webrtc/test/basics.spec.ts | 230 +++++++++++++++++- .../transport-webrtc/test/listener.spec.ts | 2 + .../test/peer.browser.spec.ts | 157 +++++++++--- .../test/stream.browser.spec.ts | 21 +- packages/transport-webrtc/test/stream.spec.ts | 65 +++-- packages/transport-webrtc/test/util.ts | 27 ++ packages/transport-websockets/test/node.ts | 4 + 30 files changed, 1548 insertions(+), 498 deletions(-) create mode 100644 packages/interface/test/fixtures/logger.ts create mode 100644 packages/interface/test/stream-muxer/stream.spec.ts delete mode 100644 packages/transport-webrtc/src/private-to-private/handler.ts create mode 100644 packages/transport-webrtc/src/private-to-private/initiate-connection.ts create mode 100644 packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts diff --git a/packages/interface/package.json b/packages/interface/package.json index af42d6097e..e143c35371 100644 --- a/packages/interface/package.json +++ b/packages/interface/package.json @@ -164,11 +164,15 @@ "it-stream-types": "^2.0.1", "multiformats": "^12.0.1", "p-defer": "^4.0.0", + "race-signal": "^1.0.0", "uint8arraylist": "^2.4.3" }, "devDependencies": { "@types/sinon": "^10.0.15", "aegir": "^40.0.8", + "delay": "^6.0.0", + "it-all": "^3.0.3", + "it-drain": "^3.0.3", "sinon": "^16.0.0", "sinon-ts": "^1.0.0" } diff --git a/packages/interface/src/stream-muxer/stream.ts b/packages/interface/src/stream-muxer/stream.ts index bac8c7f729..93d6b5adb9 100644 --- a/packages/interface/src/stream-muxer/stream.ts +++ b/packages/interface/src/stream-muxer/stream.ts @@ -1,12 +1,14 @@ import { abortableSource } from 'abortable-iterator' import { type Pushable, pushable } from 'it-pushable' import defer, { type DeferredPromise } from 'p-defer' +import { raceSignal } from 'race-signal' import { Uint8ArrayList } from 'uint8arraylist' import { CodeError } from '../errors.js' import type { Direction, ReadStatus, Stream, StreamStatus, StreamTimeline, WriteStatus } from '../connection/index.js' import type { AbortOptions } from '../index.js' import type { Source } from 'it-stream-types' +// copied from @libp2p/logger to break a circular dependency interface Logger { (formatter: any, ...args: any[]): void error: (formatter: any, ...args: any[]) => void @@ -16,6 +18,7 @@ interface Logger { const ERR_STREAM_RESET = 'ERR_STREAM_RESET' const ERR_SINK_INVALID_STATE = 'ERR_SINK_INVALID_STATE' +const DEFAULT_SEND_CLOSE_WRITE_TIMEOUT = 5000 export interface AbstractStreamInit { /** @@ -68,6 +71,12 @@ export interface AbstractStreamInit { * connection when closing the writable end of the stream. (default: 500) */ closeTimeout?: number + + /** + * After the stream sink has closed, a limit on how long it takes to send + * a close-write message to the remote peer. + */ + sendCloseWriteTimeout?: number } function isPromise (res?: any): res is Promise { @@ -94,6 +103,7 @@ export abstract class AbstractStream implements Stream { private readonly onCloseWrite?: () => void private readonly onReset?: () => void private readonly onAbort?: (err: Error) => void + private readonly sendCloseWriteTimeout: number protected readonly log: Logger @@ -113,6 +123,7 @@ export abstract class AbstractStream implements Stream { this.timeline = { open: Date.now() } + this.sendCloseWriteTimeout = init.sendCloseWriteTimeout ?? DEFAULT_SEND_CLOSE_WRITE_TIMEOUT this.onEnd = init.onEnd this.onCloseRead = init?.onCloseRead @@ -128,7 +139,6 @@ export abstract class AbstractStream implements Stream { this.log.trace('source ended') } - this.readStatus = 'closed' this.onSourceEnd(err) } }) @@ -173,11 +183,19 @@ export abstract class AbstractStream implements Stream { } } - this.log.trace('sink finished reading from source') - this.writeStatus = 'done' + this.log.trace('sink finished reading from source, write status is "%s"', this.writeStatus) + + if (this.writeStatus === 'writing') { + this.writeStatus = 'closing' + + this.log.trace('send close write to remote') + await this.sendCloseWrite({ + signal: AbortSignal.timeout(this.sendCloseWriteTimeout) + }) + + this.writeStatus = 'closed' + } - this.log.trace('sink calling closeWrite') - await this.closeWrite(options) this.onSinkEnd() } catch (err: any) { this.log.trace('sink ended with error, calling abort with error', err) @@ -196,6 +214,7 @@ export abstract class AbstractStream implements Stream { } this.timeline.closeRead = Date.now() + this.readStatus = 'closed' if (err != null && this.endErr == null) { this.endErr = err @@ -207,6 +226,10 @@ export abstract class AbstractStream implements Stream { this.log.trace('source and sink ended') this.timeline.close = Date.now() + if (this.status !== 'aborted' && this.status !== 'reset') { + this.status = 'closed' + } + if (this.onEnd != null) { this.onEnd(this.endErr) } @@ -221,6 +244,7 @@ export abstract class AbstractStream implements Stream { } this.timeline.closeWrite = Date.now() + this.writeStatus = 'closed' if (err != null && this.endErr == null) { this.endErr = err @@ -232,6 +256,10 @@ export abstract class AbstractStream implements Stream { this.log.trace('sink and source ended') this.timeline.close = Date.now() + if (this.status !== 'aborted' && this.status !== 'reset') { + this.status = 'closed' + } + if (this.onEnd != null) { this.onEnd(this.endErr) } @@ -266,16 +294,16 @@ export abstract class AbstractStream implements Stream { const readStatus = this.readStatus this.readStatus = 'closing' - if (readStatus === 'ready') { - this.log.trace('ending internal source queue') - this.streamSource.end() - } - if (this.status !== 'reset' && this.status !== 'aborted' && this.timeline.closeRead == null) { this.log.trace('send close read to remote') await this.sendCloseRead(options) } + if (readStatus === 'ready') { + this.log.trace('ending internal source queue') + this.streamSource.end() + } + this.log.trace('closed readable end of stream') } @@ -286,16 +314,13 @@ export abstract class AbstractStream implements Stream { this.log.trace('closing writable end of stream with starting write status "%s"', this.writeStatus) - const writeStatus = this.writeStatus - if (this.writeStatus === 'ready') { this.log.trace('sink was never sunk, sink an empty array') - await this.sink([]) - } - this.writeStatus = 'closing' + await raceSignal(this.sink([]), options.signal) + } - if (writeStatus === 'writing') { + if (this.writeStatus === 'writing') { // stop reading from the source passed to `.sink` in the microtask queue // - this lets any data queued by the user in the current tick get read // before we exit @@ -303,16 +328,12 @@ export abstract class AbstractStream implements Stream { queueMicrotask(() => { this.log.trace('aborting source passed to .sink') this.sinkController.abort() - this.sinkEnd.promise.then(resolve, reject) + raceSignal(this.sinkEnd.promise, options.signal) + .then(resolve, reject) }) }) } - if (this.status !== 'reset' && this.status !== 'aborted' && this.timeline.closeWrite == null) { - this.log.trace('send close write to remote') - await this.sendCloseWrite(options) - } - this.writeStatus = 'closed' this.log.trace('closed writable end of stream') @@ -357,6 +378,7 @@ export abstract class AbstractStream implements Stream { const err = new CodeError('stream reset', ERR_STREAM_RESET) this.status = 'reset' + this.timeline.reset = Date.now() this._closeSinkAndSource(err) this.onReset?.() } @@ -423,7 +445,7 @@ export abstract class AbstractStream implements Stream { return } - this.log.trace('muxer destroyed') + this.log.trace('stream destroyed') this._closeSinkAndSource() } diff --git a/packages/interface/test/fixtures/logger.ts b/packages/interface/test/fixtures/logger.ts new file mode 100644 index 0000000000..bd614bd944 --- /dev/null +++ b/packages/interface/test/fixtures/logger.ts @@ -0,0 +1,16 @@ +// copied from @libp2p/logger to break a circular dependency +interface Logger { + (): void + error: () => void + trace: () => void + enabled: boolean +} + +export function logger (): Logger { + const output = (): void => {} + output.trace = (): void => {} + output.error = (): void => {} + output.enabled = false + + return output +} diff --git a/packages/interface/test/stream-muxer/stream.spec.ts b/packages/interface/test/stream-muxer/stream.spec.ts new file mode 100644 index 0000000000..eaeb33f04a --- /dev/null +++ b/packages/interface/test/stream-muxer/stream.spec.ts @@ -0,0 +1,196 @@ +import { expect } from 'aegir/chai' +import delay from 'delay' +import all from 'it-all' +import drain from 'it-drain' +import Sinon from 'sinon' +import { Uint8ArrayList } from 'uint8arraylist' +import { AbstractStream } from '../../src/stream-muxer/stream.js' +import { logger } from '../fixtures/logger.js' +import type { AbortOptions } from '../../src/index.js' + +class TestStream extends AbstractStream { + async sendNewStream (options?: AbortOptions): Promise { + + } + + async sendData (buf: Uint8ArrayList, options?: AbortOptions): Promise { + + } + + async sendReset (options?: AbortOptions): Promise { + + } + + async sendCloseWrite (options?: AbortOptions): Promise { + + } + + async sendCloseRead (options?: AbortOptions): Promise { + + } +} + +describe('abstract stream', () => { + let stream: TestStream + + beforeEach(() => { + stream = new TestStream({ + id: 'test', + direction: 'outbound', + log: logger() + }) + }) + + it('sends data', async () => { + const sendSpy = Sinon.spy(stream, 'sendData') + const data = [ + Uint8Array.from([0, 1, 2, 3, 4]) + ] + + await stream.sink(data) + + const call = sendSpy.getCall(0) + expect(call.args[0].subarray()).to.equalBytes(data[0]) + }) + + it('receives data', async () => { + const data = new Uint8ArrayList( + Uint8Array.from([0, 1, 2, 3, 4]) + ) + + stream.sourcePush(data) + stream.remoteCloseWrite() + + const output = await all(stream.source) + expect(output[0].subarray()).to.equalBytes(data.subarray()) + }) + + it('closes', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.close() + + expect(sendCloseReadSpy.calledOnce).to.be.true() + expect(sendCloseWriteSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'closed') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('closes for reading', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.closeRead() + + expect(sendCloseReadSpy.calledOnce).to.be.true() + expect(sendCloseWriteSpy.called).to.be.false() + + expect(stream).to.have.property('status', 'open') + expect(stream).to.have.property('writeStatus', 'ready') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.not.have.nested.property('timeline.close') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.closeWrite') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('closes for writing', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + await stream.closeWrite() + + expect(sendCloseReadSpy.called).to.be.false() + expect(sendCloseWriteSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'open') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'ready') + expect(stream).to.not.have.nested.property('timeline.close') + expect(stream).to.not.have.nested.property('timeline.closeRead') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.not.have.nested.property('timeline.abort') + }) + + it('aborts', async () => { + const sendResetSpy = Sinon.spy(stream, 'sendReset') + + stream.abort(new Error('Urk!')) + + expect(sendResetSpy.calledOnce).to.be.true() + + expect(stream).to.have.property('status', 'aborted') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.reset') + expect(stream).to.have.nested.property('timeline.abort').that.is.a('number') + + await expect(stream.sink([])).to.eventually.be.rejected + .with.property('code', 'ERR_SINK_INVALID_STATE') + await expect(drain(stream.source)).to.eventually.be.rejected + .with('Urk!') + }) + + it('gets reset remotely', async () => { + stream.reset() + + expect(stream).to.have.property('status', 'reset') + expect(stream).to.have.property('writeStatus', 'closed') + expect(stream).to.have.property('readStatus', 'closed') + expect(stream).to.have.nested.property('timeline.close').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeRead').that.is.a('number') + expect(stream).to.have.nested.property('timeline.closeWrite').that.is.a('number') + expect(stream).to.have.nested.property('timeline.reset').that.is.a('number') + expect(stream).to.not.have.nested.property('timeline.abort') + + await expect(stream.sink([])).to.eventually.be.rejected + .with.property('code', 'ERR_SINK_INVALID_STATE') + await expect(drain(stream.source)).to.eventually.be.rejected + .with.property('code', 'ERR_STREAM_RESET') + }) + + it('does not send close read when remote closes write', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + + stream.remoteCloseWrite() + + await delay(100) + + expect(sendCloseReadSpy.called).to.be.false() + }) + + it('does not send close write when remote closes read', async () => { + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + stream.remoteCloseRead() + + await delay(100) + + expect(sendCloseWriteSpy.called).to.be.false() + }) + + it('does not send close read or write when remote resets', async () => { + const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') + const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + + stream.reset() + + await delay(100) + + expect(sendCloseReadSpy.called).to.be.false() + expect(sendCloseWriteSpy.called).to.be.false() + }) +}) diff --git a/packages/libp2p/src/connection/index.ts b/packages/libp2p/src/connection/index.ts index 8509b8d70a..5726d1cc54 100644 --- a/packages/libp2p/src/connection/index.ts +++ b/packages/libp2p/src/connection/index.ts @@ -158,16 +158,22 @@ export class ConnectionImpl implements Connection { } catch { } try { + log.trace('closing all streams') + // close all streams gracefully - this can throw if we're not multiplexed await Promise.all( this.streams.map(async s => s.close(options)) ) - // Close raw connection + log.trace('closing underlying transport') + + // close raw connection await this._close(options) - this.timeline.close = Date.now() + log.trace('updating timeline with close time') + this.status = 'closed' + this.timeline.close = Date.now() } catch (err: any) { log.error('error encountered during graceful close of connection to %a', this.remoteAddr, err) this.abort(err) diff --git a/packages/libp2p/src/dcutr/dcutr.ts b/packages/libp2p/src/dcutr/dcutr.ts index 190766bfa2..e6852b6d03 100644 --- a/packages/libp2p/src/dcutr/dcutr.ts +++ b/packages/libp2p/src/dcutr/dcutr.ts @@ -262,7 +262,10 @@ export class DefaultDCUtRService implements Startable { } log('unilateral connection upgrade to %p succeeded via %a, closing relayed connection', relayedConnection.remotePeer, connection.remoteAddr) - await relayedConnection.close() + + await relayedConnection.close({ + signal + }) return true } catch (err) { diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index c7a6829175..18934e3ad0 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -545,11 +545,16 @@ export class DefaultUpgrader implements Upgrader { newStream: newStream ?? errConnectionNotMultiplexed, getStreams: () => { if (muxer != null) { return muxer.streams } else { return [] } }, close: async (options?: AbortOptions) => { - await maConn.close(options) // Ensure remaining streams are closed gracefully if (muxer != null) { + log.trace('close muxer') await muxer.close(options) } + + log.trace('close maconn') + // close the underlying transport + await maConn.close(options) + log.trace('closed maconn') }, abort: (err) => { maConn.abort(err) diff --git a/packages/transport-webrtc/.aegir.js b/packages/transport-webrtc/.aegir.js index 491576df87..3b38600b7a 100644 --- a/packages/transport-webrtc/.aegir.js +++ b/packages/transport-webrtc/.aegir.js @@ -8,7 +8,6 @@ export default { before: async () => { const { createLibp2p } = await import('libp2p') const { circuitRelayServer } = await import('libp2p/circuit-relay') - const { identifyService } = await import('libp2p/identify') const { webSockets } = await import('@libp2p/websockets') const { noise } = await import('@chainsafe/libp2p-noise') const { yamux } = await import('@chainsafe/libp2p-yamux') @@ -34,11 +33,11 @@ export default { reservations: { maxReservations: Infinity } - }), - identify: identifyService() + }) }, connectionManager: { - minConnections: 0 + minConnections: 0, + inboundConnectionThreshold: Infinity } }) diff --git a/packages/transport-webrtc/package.json b/packages/transport-webrtc/package.json index ad308fea0f..6211a33571 100644 --- a/packages/transport-webrtc/package.json +++ b/packages/transport-webrtc/package.json @@ -54,6 +54,7 @@ "@multiformats/multiaddr": "^12.1.5", "@multiformats/multiaddr-matcher": "^1.0.1", "abortable-iterator": "^5.0.1", + "any-signal": "^4.1.1", "detect-browser": "^5.3.0", "it-length-prefixed": "^9.0.1", "it-pipe": "^3.0.1", @@ -63,10 +64,12 @@ "it-to-buffer": "^4.0.2", "multiformats": "^12.0.1", "multihashes": "^4.0.3", - "node-datachannel": "^0.4.3", + "node-datachannel": "^0.5.0-dev", "p-defer": "^4.0.0", "p-event": "^6.0.0", + "p-timeout": "^6.1.2", "protons-runtime": "^5.0.0", + "race-signal": "^1.0.0", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.6" }, @@ -78,10 +81,12 @@ "@types/sinon": "^10.0.15", "aegir": "^40.0.8", "delay": "^6.0.0", + "it-drain": "^3.0.3", "it-length": "^3.0.2", "it-map": "^3.0.3", "it-pair": "^2.0.6", "libp2p": "^0.46.12", + "p-retry": "^6.1.0", "protons": "^7.0.2", "sinon": "^16.0.0", "sinon-ts": "^1.0.0" diff --git a/packages/transport-webrtc/src/index.ts b/packages/transport-webrtc/src/index.ts index 0245aefcc9..a8bbf37e75 100644 --- a/packages/transport-webrtc/src/index.ts +++ b/packages/transport-webrtc/src/index.ts @@ -3,6 +3,40 @@ import { WebRTCDirectTransport, type WebRTCTransportDirectInit, type WebRTCDirec import type { WebRTCTransportComponents, WebRTCTransportInit } from './private-to-private/transport.js' import type { Transport } from '@libp2p/interface/transport' +export interface DataChannelOptions { + /** + * The maximum message size sendable over the channel in bytes (default 16KB) + */ + maxMessageSize?: number + + /** + * If the channel's `bufferedAmount` grows over this amount in bytes, wait + * for it to drain before sending more data (default: 16MB) + */ + maxBufferedAmount?: number + + /** + * When `bufferedAmount` is above `maxBufferedAmount`, we pause sending until + * the `bufferedAmountLow` event fires - this controls how long we wait for + * that event in ms (default: 30s) + */ + bufferedAmountLowEventTimeout?: number + + /** + * When closing a stream, we wait for `bufferedAmount` to become 0 before + * closing the underlying RTCDataChannel - this controls how long we wait + * in ms (default: 30s) + */ + drainTimeout?: number + + /** + * When closing a stream we first send a FIN flag to the remote and wait + * for a FIN_ACK reply before closing the underlying RTCDataChannel - this + * controls how long we wait for the acknowledgement in ms (default: 5s) + */ + closeTimeout?: number +} + /** * @param {WebRTCTransportDirectInit} init - WebRTC direct transport configuration * @param init.dataChannel - DataChannel configurations diff --git a/packages/transport-webrtc/src/maconn.ts b/packages/transport-webrtc/src/maconn.ts index c32d88760c..04318aa549 100644 --- a/packages/transport-webrtc/src/maconn.ts +++ b/packages/transport-webrtc/src/maconn.ts @@ -5,7 +5,7 @@ import type { CounterGroup } from '@libp2p/interface/metrics' import type { AbortOptions, Multiaddr } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' -const log = logger('libp2p:webrtc:connection') +const log = logger('libp2p:webrtc:maconn') interface WebRTCMultiaddrConnectionInit { /** @@ -65,8 +65,13 @@ export class WebRTCMultiaddrConnection implements MultiaddrConnection { this.timeline = init.timeline this.peerConnection = init.peerConnection + const initialState = this.peerConnection.connectionState + this.peerConnection.onconnectionstatechange = () => { - if (this.peerConnection.connectionState === 'closed' || this.peerConnection.connectionState === 'disconnected' || this.peerConnection.connectionState === 'failed') { + log.trace('peer connection state change', this.peerConnection.connectionState, 'initial state', initialState) + + if (this.peerConnection.connectionState === 'disconnected' || this.peerConnection.connectionState === 'failed' || this.peerConnection.connectionState === 'closed') { + // nothing else to do but close the connection this.timeline.close = Date.now() } } diff --git a/packages/transport-webrtc/src/muxer.ts b/packages/transport-webrtc/src/muxer.ts index 991d130230..fae5f15011 100644 --- a/packages/transport-webrtc/src/muxer.ts +++ b/packages/transport-webrtc/src/muxer.ts @@ -1,6 +1,7 @@ +import { logger } from '@libp2p/logger' import { createStream } from './stream.js' -import { nopSink, nopSource } from './util.js' -import type { DataChannelOpts } from './stream.js' +import { drainAndClose, nopSink, nopSource } from './util.js' +import type { DataChannelOptions } from './index.js' import type { Stream } from '@libp2p/interface/connection' import type { CounterGroup } from '@libp2p/interface/metrics' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface/stream-muxer' @@ -8,6 +9,8 @@ import type { AbortOptions } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' +const log = logger('libp2p:webrtc:muxer') + const PROTOCOL = '/webrtc' export interface DataChannelMuxerFactoryInit { @@ -17,19 +20,16 @@ export interface DataChannelMuxerFactoryInit { peerConnection: RTCPeerConnection /** - * Optional metrics for this data channel muxer + * The protocol to use */ - metrics?: CounterGroup + protocol?: string /** - * Data channel options + * Optional metrics for this data channel muxer */ - dataChannelOptions?: Partial + metrics?: CounterGroup - /** - * The protocol to use - */ - protocol?: string + dataChannelOptions?: DataChannelOptions } export class DataChannelMuxerFactory implements StreamMuxerFactory { @@ -41,23 +41,23 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { private readonly peerConnection: RTCPeerConnection private streamBuffer: Stream[] = [] private readonly metrics?: CounterGroup - private readonly dataChannelOptions?: Partial + private readonly dataChannelOptions?: DataChannelOptions constructor (init: DataChannelMuxerFactoryInit) { this.peerConnection = init.peerConnection this.metrics = init.metrics this.protocol = init.protocol ?? PROTOCOL - this.dataChannelOptions = init.dataChannelOptions + this.dataChannelOptions = init.dataChannelOptions ?? {} // store any datachannels opened before upgrade has been completed this.peerConnection.ondatachannel = ({ channel }) => { const stream = createStream({ channel, direction: 'inbound', - dataChannelOptions: init.dataChannelOptions, onEnd: () => { this.streamBuffer = this.streamBuffer.filter(s => s.id !== stream.id) - } + }, + ...this.dataChannelOptions }) this.streamBuffer.push(stream) } @@ -90,34 +90,15 @@ export class DataChannelMuxer implements StreamMuxer { public protocol: string private readonly peerConnection: RTCPeerConnection - private readonly dataChannelOptions?: DataChannelOpts + private readonly dataChannelOptions: DataChannelOptions private readonly metrics?: CounterGroup - /** - * Gracefully close all tracked streams and stop the muxer - */ - close: (options?: AbortOptions) => Promise = async () => { } - - /** - * Abort all tracked streams and stop the muxer - */ - abort: (err: Error) => void = () => { } - - /** - * The stream source, a no-op as the transport natively supports multiplexing - */ - source: AsyncGenerator = nopSource() - - /** - * The stream destination, a no-op as the transport natively supports multiplexing - */ - sink: Sink, Promise> = nopSink - constructor (readonly init: DataChannelMuxerInit) { this.streams = init.streams this.peerConnection = init.peerConnection this.protocol = init.protocol ?? PROTOCOL this.metrics = init.metrics + this.dataChannelOptions = init.dataChannelOptions ?? {} /** * Fired when a data channel has been added to the connection has been @@ -129,19 +110,19 @@ export class DataChannelMuxer implements StreamMuxer { const stream = createStream({ channel, direction: 'inbound', - dataChannelOptions: this.dataChannelOptions, onEnd: () => { + log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol) + drainAndClose(channel, `inbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout) this.streams = this.streams.filter(s => s.id !== stream.id) this.metrics?.increment({ stream_end: true }) init?.onStreamEnd?.(stream) - } + }, + ...this.dataChannelOptions }) this.streams.push(stream) - if ((init?.onIncomingStream) != null) { - this.metrics?.increment({ incoming_stream: true }) - init.onIncomingStream(stream) - } + this.metrics?.increment({ incoming_stream: true }) + init?.onIncomingStream?.(stream) } const onIncomingStream = init?.onIncomingStream @@ -150,19 +131,52 @@ export class DataChannelMuxer implements StreamMuxer { } } + /** + * Gracefully close all tracked streams and stop the muxer + */ + async close (options?: AbortOptions): Promise { + try { + await Promise.all( + this.streams.map(async stream => stream.close(options)) + ) + } catch (err: any) { + this.abort(err) + } + } + + /** + * Abort all tracked streams and stop the muxer + */ + abort (err: Error): void { + for (const stream of this.streams) { + stream.abort(err) + } + } + + /** + * The stream source, a no-op as the transport natively supports multiplexing + */ + source: AsyncGenerator = nopSource() + + /** + * The stream destination, a no-op as the transport natively supports multiplexing + */ + sink: Sink, Promise> = nopSink + newStream (): Stream { // The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label const channel = this.peerConnection.createDataChannel('') const stream = createStream({ channel, direction: 'outbound', - dataChannelOptions: this.dataChannelOptions, onEnd: () => { - channel.close() // Stream initiator is responsible for closing the channel + log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol) + drainAndClose(channel, `outbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout) this.streams = this.streams.filter(s => s.id !== stream.id) this.metrics?.increment({ stream_end: true }) this.init?.onStreamEnd?.(stream) - } + }, + ...this.dataChannelOptions }) this.streams.push(stream) this.metrics?.increment({ outgoing_stream: true }) diff --git a/packages/transport-webrtc/src/pb/message.proto b/packages/transport-webrtc/src/pb/message.proto index 9301bd802b..ea1ae55b99 100644 --- a/packages/transport-webrtc/src/pb/message.proto +++ b/packages/transport-webrtc/src/pb/message.proto @@ -2,7 +2,8 @@ syntax = "proto3"; message Message { enum Flag { - // The sender will no longer send messages on the stream. + // The sender will no longer send messages on the stream. The recipient + // should send a FIN_ACK back to the sender. FIN = 0; // The sender will no longer read messages on the stream. Incoming data is @@ -12,6 +13,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + + // The sender previously received a FIN. + // Workaround for https://bugs.chromium.org/p/chromium/issues/detail?id=1484907 + FIN_ACK = 3; } optional Flag flag = 1; diff --git a/packages/transport-webrtc/src/pb/message.ts b/packages/transport-webrtc/src/pb/message.ts index a74ca6dd06..f8abb7a4a9 100644 --- a/packages/transport-webrtc/src/pb/message.ts +++ b/packages/transport-webrtc/src/pb/message.ts @@ -17,13 +17,15 @@ export namespace Message { export enum Flag { FIN = 'FIN', STOP_SENDING = 'STOP_SENDING', - RESET = 'RESET' + RESET = 'RESET', + FIN_ACK = 'FIN_ACK' } enum __FlagValues { FIN = 0, STOP_SENDING = 1, - RESET = 2 + RESET = 2, + FIN_ACK = 3 } export namespace Flag { diff --git a/packages/transport-webrtc/src/private-to-private/handler.ts b/packages/transport-webrtc/src/private-to-private/handler.ts deleted file mode 100644 index 8564fc84d2..0000000000 --- a/packages/transport-webrtc/src/private-to-private/handler.ts +++ /dev/null @@ -1,177 +0,0 @@ -import { CodeError } from '@libp2p/interface/errors' -import { logger } from '@libp2p/logger' -import { abortableDuplex } from 'abortable-iterator' -import { pbStream } from 'it-protobuf-stream' -import pDefer, { type DeferredPromise } from 'p-defer' -import { DataChannelMuxerFactory } from '../muxer.js' -import { RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' -import { Message } from './pb/message.js' -import { readCandidatesUntilConnected, resolveOnConnected } from './util.js' -import type { DataChannelOpts } from '../stream.js' -import type { Stream } from '@libp2p/interface/connection' -import type { StreamMuxerFactory } from '@libp2p/interface/stream-muxer' -import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' - -const DEFAULT_TIMEOUT = 30 * 1000 - -const log = logger('libp2p:webrtc:peer') - -export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, dataChannelOptions?: Partial } & IncomingStreamData - -export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptions, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { - const signal = AbortSignal.timeout(DEFAULT_TIMEOUT) - const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) - const pc = new RTCPeerConnection(rtcConfiguration) - - try { - const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) - const connectedPromise: DeferredPromise = pDefer() - const answerSentPromise: DeferredPromise = pDefer() - - signal.onabort = () => { - connectedPromise.reject(new CodeError('Timed out while trying to connect', 'ERR_TIMEOUT')) - } - // candidate callbacks - pc.onicecandidate = ({ candidate }) => { - answerSentPromise.promise.then( - async () => { - await stream.write({ - type: Message.Type.ICE_CANDIDATE, - data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' - }) - }, - (err) => { - log.error('cannot set candidate since sending answer failed', err) - connectedPromise.reject(err) - } - ) - } - - resolveOnConnected(pc, connectedPromise) - - // read an SDP offer - const pbOffer = await stream.read() - if (pbOffer.type !== Message.Type.SDP_OFFER) { - throw new Error(`expected message type SDP_OFFER, received: ${pbOffer.type ?? 'undefined'} `) - } - const offer = new RTCSessionDescription({ - type: 'offer', - sdp: pbOffer.data - }) - - await pc.setRemoteDescription(offer).catch(err => { - log.error('could not execute setRemoteDescription', err) - throw new Error('Failed to set remoteDescription') - }) - - // create and write an SDP answer - const answer = await pc.createAnswer().catch(err => { - log.error('could not execute createAnswer', err) - answerSentPromise.reject(err) - throw new Error('Failed to create answer') - }) - // write the answer to the remote - await stream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }) - - await pc.setLocalDescription(answer).catch(err => { - log.error('could not execute setLocalDescription', err) - answerSentPromise.reject(err) - throw new Error('Failed to set localDescription') - }) - - answerSentPromise.resolve() - - // wait until candidates are connected - await readCandidatesUntilConnected(connectedPromise, pc, stream) - - const remoteAddress = parseRemoteAddress(pc.currentRemoteDescription?.sdp ?? '') - - return { pc, muxerFactory, remoteAddress } - } catch (err) { - pc.close() - throw err - } -} - -export interface ConnectOptions { - stream: Stream - signal: AbortSignal - rtcConfiguration?: RTCConfiguration - dataChannelOptions?: Partial -} - -export async function initiateConnection ({ rtcConfiguration, dataChannelOptions, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { - const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) - // setup peer connection - const pc = new RTCPeerConnection(rtcConfiguration) - - try { - const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) - - const connectedPromise: DeferredPromise = pDefer() - resolveOnConnected(pc, connectedPromise) - - // reject the connectedPromise if the signal aborts - signal.onabort = connectedPromise.reject - // we create the channel so that the peerconnection has a component for which - // to collect candidates. The label is not relevant to connection initiation - // but can be useful for debugging - const channel = pc.createDataChannel('init') - // setup callback to write ICE candidates to the remote - // peer - pc.onicecandidate = ({ candidate }) => { - void stream.write({ - type: Message.Type.ICE_CANDIDATE, - data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' - }) - .catch(err => { - log.error('error sending ICE candidate', err) - }) - } - - // create an offer - const offerSdp = await pc.createOffer() - // write the offer to the stream - await stream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }) - // set offer as local description - await pc.setLocalDescription(offerSdp).catch(err => { - log.error('could not execute setLocalDescription', err) - throw new Error('Failed to set localDescription') - }) - - // read answer - const answerMessage = await stream.read() - if (answerMessage.type !== Message.Type.SDP_ANSWER) { - throw new Error('remote should send an SDP answer') - } - - const answerSdp = new RTCSessionDescription({ type: 'answer', sdp: answerMessage.data }) - await pc.setRemoteDescription(answerSdp).catch(err => { - log.error('could not execute setRemoteDescription', err) - throw new Error('Failed to set remoteDescription') - }) - - await readCandidatesUntilConnected(connectedPromise, pc, stream) - channel.close() - - const remoteAddress = parseRemoteAddress(pc.currentRemoteDescription?.sdp ?? '') - - return { pc, muxerFactory, remoteAddress } - } catch (err) { - pc.close() - throw err - } -} - -function parseRemoteAddress (sdp: string): string { - // 'a=candidate:1746876089 1 udp 2113937151 0614fbad-b...ocal 54882 typ host generation 0 network-cost 999' - const candidateLine = sdp.split('\r\n').filter(line => line.startsWith('a=candidate')).pop() - const candidateParts = candidateLine?.split(' ') - - if (candidateLine == null || candidateParts == null || candidateParts.length < 5) { - log('could not parse remote address from', candidateLine) - return '/webrtc' - } - - return `/dnsaddr/${candidateParts[4]}/${candidateParts[2].toLowerCase()}/${candidateParts[5]}/webrtc` -} diff --git a/packages/transport-webrtc/src/private-to-private/initiate-connection.ts b/packages/transport-webrtc/src/private-to-private/initiate-connection.ts new file mode 100644 index 0000000000..408bef0ac1 --- /dev/null +++ b/packages/transport-webrtc/src/private-to-private/initiate-connection.ts @@ -0,0 +1,191 @@ +import { CodeError } from '@libp2p/interface/errors' +import { logger } from '@libp2p/logger' +import { peerIdFromString } from '@libp2p/peer-id' +import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' +import { pbStream } from 'it-protobuf-stream' +import pDefer, { type DeferredPromise } from 'p-defer' +import { type RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' +import { Message } from './pb/message.js' +import { SIGNALING_PROTO_ID, splitAddr, type WebRTCTransportMetrics } from './transport.js' +import { parseRemoteAddress, readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { DataChannelOptions } from '../index.js' +import type { Connection } from '@libp2p/interface/connection' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' +import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' +import type { TransportManager } from '@libp2p/interface-internal/transport-manager' + +const log = logger('libp2p:webrtc:initiate-connection') + +export interface IncomingStreamOpts extends IncomingStreamData { + rtcConfiguration?: RTCConfiguration + dataChannelOptions?: Partial + signal: AbortSignal +} + +export interface ConnectOptions { + peerConnection: RTCPeerConnection + multiaddr: Multiaddr + connectionManager: ConnectionManager + transportManager: TransportManager + dataChannelOptions?: Partial + signal?: AbortSignal + metrics?: WebRTCTransportMetrics +} + +export async function initiateConnection ({ peerConnection, signal, metrics, multiaddr: ma, connectionManager, transportManager }: ConnectOptions): Promise<{ remoteAddress: Multiaddr }> { + const { baseAddr, peerId } = splitAddr(ma) + + metrics?.dialerEvents.increment({ open: true }) + + log.trace('dialing base address: %a', baseAddr) + + const relayPeer = baseAddr.getPeerId() + + if (relayPeer == null) { + throw new CodeError('Relay peer was missing', 'ERR_INVALID_ADDRESS') + } + + const connections = connectionManager.getConnections(peerIdFromString(relayPeer)) + let connection: Connection + let shouldCloseConnection = false + + if (connections.length === 0) { + // use the transport manager to open a connection. Initiating a WebRTC + // connection takes place in the context of a dial - if we use the + // connection manager instead we can end up joining our own dial context + connection = await transportManager.dial(baseAddr, { + signal + }) + // this connection is unmanaged by the connection manager so we should + // close it when we are done + shouldCloseConnection = true + } else { + connection = connections[0] + } + + try { + const stream = await connection.newStream(SIGNALING_PROTO_ID, { + signal, + runOnTransientConnection: true + }) + + const messageStream = pbStream(stream).pb(Message) + const connectedPromise: DeferredPromise = pDefer() + const sdpAbortedListener = (): void => { + connectedPromise.reject(new CodeError('SDP handshake aborted', 'ERR_SDP_HANDSHAKE_ABORTED')) + } + + try { + resolveOnConnected(peerConnection, connectedPromise) + + // reject the connectedPromise if the signal aborts + signal?.addEventListener('abort', sdpAbortedListener) + + // we create the channel so that the RTCPeerConnection has a component for + // which to collect candidates. The label is not relevant to connection + // initiation but can be useful for debugging + const channel = peerConnection.createDataChannel('init') + + // setup callback to write ICE candidates to the remote peer + peerConnection.onicecandidate = ({ candidate }) => { + // a null candidate means end-of-candidates, an empty string candidate + // means end-of-candidates for this generation, otherwise this should + // be a valid candidate object + // see - https://www.w3.org/TR/webrtc/#rtcpeerconnectioniceevent + const data = JSON.stringify(candidate?.toJSON() ?? null) + + log.trace('initiator sending ICE candidate %s', data) + + void messageStream.write({ + type: Message.Type.ICE_CANDIDATE, + data + }, { + signal + }) + .catch(err => { + log.error('error sending ICE candidate', err) + }) + } + peerConnection.onicecandidateerror = (event) => { + log('initiator ICE candidate error', event) + } + + // create an offer + const offerSdp = await peerConnection.createOffer() + + log.trace('initiator send SDP offer %s', offerSdp.sdp) + + // write the offer to the stream + await messageStream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }, { + signal + }) + + // set offer as local description + await peerConnection.setLocalDescription(offerSdp).catch(err => { + log.error('could not execute setLocalDescription', err) + throw new CodeError('Failed to set localDescription', 'ERR_SDP_HANDSHAKE_FAILED') + }) + + // read answer + const answerMessage = await messageStream.read({ + signal + }) + + if (answerMessage.type !== Message.Type.SDP_ANSWER) { + throw new CodeError('remote should send an SDP answer', 'ERR_SDP_HANDSHAKE_FAILED') + } + + log.trace('initiator receive SDP answer %s', answerMessage.data) + + const answerSdp = new RTCSessionDescription({ type: 'answer', sdp: answerMessage.data }) + await peerConnection.setRemoteDescription(answerSdp).catch(err => { + log.error('could not execute setRemoteDescription', err) + throw new CodeError('Failed to set remoteDescription', 'ERR_SDP_HANDSHAKE_FAILED') + }) + + log.trace('initiator read candidates until connected') + + await readCandidatesUntilConnected(connectedPromise, peerConnection, messageStream, { + direction: 'initiator', + signal + }) + + log.trace('initiator connected, closing init channel') + channel.close() + + log.trace('initiator closing signalling stream') + await messageStream.unwrap().unwrap().close({ + signal + }) + + const remoteAddress = parseRemoteAddress(peerConnection.currentRemoteDescription?.sdp ?? '') + + log.trace('initiator connected to remote address %s', remoteAddress) + + return { + remoteAddress: multiaddr(remoteAddress).encapsulate(`/p2p/${peerId.toString()}`) + } + } catch (err: any) { + peerConnection.close() + stream.abort(err) + throw err + } finally { + // remove event listeners + signal?.removeEventListener('abort', sdpAbortedListener) + peerConnection.onicecandidate = null + peerConnection.onicecandidateerror = null + } + } finally { + // if we had to open a connection to perform the SDP handshake + // close it because it's not tracked by the connection manager + if (shouldCloseConnection) { + try { + await connection.close({ + signal + }) + } catch (err: any) { + connection.abort(err) + } + } + } +} diff --git a/packages/transport-webrtc/src/private-to-private/listener.ts b/packages/transport-webrtc/src/private-to-private/listener.ts index 1dccac6e25..53a3d299c6 100644 --- a/packages/transport-webrtc/src/private-to-private/listener.ts +++ b/packages/transport-webrtc/src/private-to-private/listener.ts @@ -5,20 +5,27 @@ import type { ListenerEvents, Listener } from '@libp2p/interface/transport' import type { TransportManager } from '@libp2p/interface-internal/transport-manager' import type { Multiaddr } from '@multiformats/multiaddr' -export interface ListenerOptions { +export interface WebRTCPeerListenerComponents { peerId: PeerId transportManager: TransportManager } +export interface WebRTCPeerListenerInit { + shutdownController: AbortController +} + export class WebRTCPeerListener extends EventEmitter implements Listener { private readonly peerId: PeerId private readonly transportManager: TransportManager + private readonly shutdownController: AbortController - constructor (opts: ListenerOptions) { + constructor (components: WebRTCPeerListenerComponents, init: WebRTCPeerListenerInit) { super() - this.peerId = opts.peerId - this.transportManager = opts.transportManager + this.peerId = components.peerId + this.transportManager = components.transportManager + + this.shutdownController = init.shutdownController } async listen (): Promise { @@ -39,6 +46,7 @@ export class WebRTCPeerListener extends EventEmitter implements } async close (): Promise { + this.shutdownController.abort() this.safeDispatchEvent('close', {}) } } diff --git a/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts b/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts new file mode 100644 index 0000000000..69db46361c --- /dev/null +++ b/packages/transport-webrtc/src/private-to-private/signaling-stream-handler.ts @@ -0,0 +1,129 @@ +import { CodeError } from '@libp2p/interface/errors' +import { logger } from '@libp2p/logger' +import { pbStream } from 'it-protobuf-stream' +import pDefer, { type DeferredPromise } from 'p-defer' +import { type RTCPeerConnection, RTCSessionDescription } from '../webrtc/index.js' +import { Message } from './pb/message.js' +import { parseRemoteAddress, readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { IncomingStreamData } from '@libp2p/interface-internal/registrar' + +const log = logger('libp2p:webrtc:signaling-stream-handler') + +export interface IncomingStreamOpts extends IncomingStreamData { + peerConnection: RTCPeerConnection + signal: AbortSignal +} + +export async function handleIncomingStream ({ peerConnection, stream, signal, connection }: IncomingStreamOpts): Promise<{ remoteAddress: string }> { + log.trace('new inbound signaling stream') + + const messageStream = pbStream(stream).pb(Message) + + try { + const connectedPromise: DeferredPromise = pDefer() + const answerSentPromise: DeferredPromise = pDefer() + + signal.onabort = () => { + connectedPromise.reject(new CodeError('Timed out while trying to connect', 'ERR_TIMEOUT')) + } + + // candidate callbacks + peerConnection.onicecandidate = ({ candidate }) => { + answerSentPromise.promise.then( + async () => { + // a null candidate means end-of-candidates, an empty string candidate + // means end-of-candidates for this generation, otherwise this should + // be a valid candidate object + // see - https://www.w3.org/TR/webrtc/#rtcpeerconnectioniceevent + const data = JSON.stringify(candidate?.toJSON() ?? null) + + log.trace('recipient sending ICE candidate %s', data) + + await messageStream.write({ + type: Message.Type.ICE_CANDIDATE, + data + }, { + signal + }) + }, + (err) => { + log.error('cannot set candidate since sending answer failed', err) + connectedPromise.reject(err) + } + ) + } + + resolveOnConnected(peerConnection, connectedPromise) + + // read an SDP offer + const pbOffer = await messageStream.read({ + signal + }) + + if (pbOffer.type !== Message.Type.SDP_OFFER) { + throw new CodeError(`expected message type SDP_OFFER, received: ${pbOffer.type ?? 'undefined'} `, 'ERR_SDP_HANDSHAKE_FAILED') + } + + log.trace('recipient receive SDP offer %s', pbOffer.data) + + const offer = new RTCSessionDescription({ + type: 'offer', + sdp: pbOffer.data + }) + + await peerConnection.setRemoteDescription(offer).catch(err => { + log.error('could not execute setRemoteDescription', err) + throw new CodeError('Failed to set remoteDescription', 'ERR_SDP_HANDSHAKE_FAILED') + }) + + // create and write an SDP answer + const answer = await peerConnection.createAnswer().catch(err => { + log.error('could not execute createAnswer', err) + answerSentPromise.reject(err) + throw new CodeError('Failed to create answer', 'ERR_SDP_HANDSHAKE_FAILED') + }) + + log.trace('recipient send SDP answer %s', answer.sdp) + + // write the answer to the remote + await messageStream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }, { + signal + }) + + await peerConnection.setLocalDescription(answer).catch(err => { + log.error('could not execute setLocalDescription', err) + answerSentPromise.reject(err) + throw new CodeError('Failed to set localDescription', 'ERR_SDP_HANDSHAKE_FAILED') + }) + + answerSentPromise.resolve() + + log.trace('recipient read candidates until connected') + + // wait until candidates are connected + await readCandidatesUntilConnected(connectedPromise, peerConnection, messageStream, { + direction: 'recipient', + signal + }) + + log.trace('recipient connected, closing signaling stream') + await messageStream.unwrap().unwrap().close({ + signal + }) + } catch (err: any) { + if (peerConnection.connectionState !== 'connected') { + log.error('error while handling signaling stream from peer %a', connection.remoteAddr, err) + + peerConnection.close() + throw err + } else { + log('error while handling signaling stream from peer %a, ignoring as the RTCPeerConnection is already connected', connection.remoteAddr, err) + } + } + + const remoteAddress = parseRemoteAddress(peerConnection.currentRemoteDescription?.sdp ?? '') + + log.trace('recipient connected to remote address %s', remoteAddress) + + return { remoteAddress } +} diff --git a/packages/transport-webrtc/src/private-to-private/transport.ts b/packages/transport-webrtc/src/private-to-private/transport.ts index 9ab8de1f1e..af5c3ea284 100644 --- a/packages/transport-webrtc/src/private-to-private/transport.ts +++ b/packages/transport-webrtc/src/private-to-private/transport.ts @@ -6,26 +6,36 @@ import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' import { WebRTC } from '@multiformats/multiaddr-matcher' import { codes } from '../error.js' import { WebRTCMultiaddrConnection } from '../maconn.js' -import { cleanup } from '../webrtc/index.js' -import { initiateConnection, handleIncomingStream } from './handler.js' +import { DataChannelMuxerFactory } from '../muxer.js' +import { cleanup, RTCPeerConnection } from '../webrtc/index.js' +import { initiateConnection } from './initiate-connection.js' import { WebRTCPeerListener } from './listener.js' -import type { DataChannelOpts } from '../stream.js' +import { handleIncomingStream } from './signaling-stream-handler.js' +import type { DataChannelOptions } from '../index.js' import type { Connection } from '@libp2p/interface/connection' import type { PeerId } from '@libp2p/interface/peer-id' import type { CounterGroup, Metrics } from '@libp2p/interface/src/metrics/index.js' import type { Startable } from '@libp2p/interface/startable' import type { IncomingStreamData, Registrar } from '@libp2p/interface-internal/registrar' +import type { ConnectionManager } from '@libp2p/interface-internal/src/connection-manager/index.js' import type { TransportManager } from '@libp2p/interface-internal/transport-manager' const log = logger('libp2p:webrtc:peer') const WEBRTC_TRANSPORT = '/webrtc' const CIRCUIT_RELAY_TRANSPORT = '/p2p-circuit' -const SIGNALING_PROTO_ID = '/webrtc-signaling/0.0.1' +export const SIGNALING_PROTO_ID = '/webrtc-signaling/0.0.1' +const INBOUND_CONNECTION_TIMEOUT = 30 * 1000 export interface WebRTCTransportInit { rtcConfiguration?: RTCConfiguration - dataChannel?: Partial + dataChannel?: DataChannelOptions + + /** + * Inbound connections must complete the upgrade within this many ms + * (default: 30s) + */ + inboundConnectionTimeout?: number } export interface WebRTCTransportComponents { @@ -33,6 +43,7 @@ export interface WebRTCTransportComponents { registrar: Registrar upgrader: Upgrader transportManager: TransportManager + connectionManager: ConnectionManager metrics?: Metrics } @@ -44,11 +55,14 @@ export interface WebRTCTransportMetrics { export class WebRTCTransport implements Transport, Startable { private _started = false private readonly metrics?: WebRTCTransportMetrics + private readonly shutdownController: AbortController constructor ( private readonly components: WebRTCTransportComponents, private readonly init: WebRTCTransportInit = {} ) { + this.shutdownController = new AbortController() + if (components.metrics != null) { this.metrics = { dialerEvents: components.metrics.registerCounterGroup('libp2p_webrtc_dialer_events_total', { @@ -83,7 +97,9 @@ export class WebRTCTransport implements Transport, Startable { } createListener (options: CreateListenerOptions): Listener { - return new WebRTCPeerListener(this.components) + return new WebRTCPeerListener(this.components, { + shutdownController: this.shutdownController + }) } readonly [Symbol.toStringTag] = '@libp2p/webrtc' @@ -102,84 +118,96 @@ export class WebRTCTransport implements Transport, Startable { * /p2p//p2p-circuit/webrtc/p2p/ */ async dial (ma: Multiaddr, options: DialOptions): Promise { - log.trace('dialing address: ', ma) - const { baseAddr, peerId } = splitAddr(ma) + log.trace('dialing address: %a', ma) - if (options.signal == null) { - const controller = new AbortController() - options.signal = controller.signal - } + const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration) + const muxerFactory = new DataChannelMuxerFactory({ + peerConnection, + dataChannelOptions: this.init.dataChannel + }) - this.metrics?.dialerEvents.increment({ open: true }) - const connection = await this.components.transportManager.dial(baseAddr, options) - const signalingStream = await connection.newStream(SIGNALING_PROTO_ID, { - ...options, - runOnTransientConnection: true + const { remoteAddress } = await initiateConnection({ + peerConnection, + multiaddr: ma, + dataChannelOptions: this.init.dataChannel, + signal: options.signal, + connectionManager: this.components.connectionManager, + transportManager: this.components.transportManager }) - try { - const { pc, muxerFactory, remoteAddress } = await initiateConnection({ - stream: signalingStream, - rtcConfiguration: this.init.rtcConfiguration, - dataChannelOptions: this.init.dataChannel, - signal: options.signal - }) + const webRTCConn = new WebRTCMultiaddrConnection({ + peerConnection, + timeline: { open: Date.now() }, + remoteAddr: remoteAddress, + metrics: this.metrics?.dialerEvents + }) - const result = await options.upgrader.upgradeOutbound( - new WebRTCMultiaddrConnection({ - peerConnection: pc, - timeline: { open: Date.now() }, - remoteAddr: multiaddr(remoteAddress).encapsulate(`/p2p/${peerId.toString()}`), - metrics: this.metrics?.dialerEvents - }), - { - skipProtection: true, - skipEncryption: true, - muxerFactory - } - ) - - // close the stream if SDP has been exchanged successfully - await signalingStream.close() - return result - } catch (err: any) { - this.metrics?.dialerEvents.increment({ error: true }) - // reset the stream in case of any error - signalingStream.abort(err) - throw err - } finally { - // Close the signaling connection - await connection.close() - } + const connection = await options.upgrader.upgradeOutbound(webRTCConn, { + skipProtection: true, + skipEncryption: true, + muxerFactory + }) + + // close the connection on shut down + this._closeOnShutdown(peerConnection, webRTCConn) + + return connection } async _onProtocol ({ connection, stream }: IncomingStreamData): Promise { + const signal = AbortSignal.timeout(this.init.inboundConnectionTimeout ?? INBOUND_CONNECTION_TIMEOUT) + const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection, dataChannelOptions: this.init.dataChannel }) + try { - const { pc, muxerFactory, remoteAddress } = await handleIncomingStream({ - rtcConfiguration: this.init.rtcConfiguration, + const { remoteAddress } = await handleIncomingStream({ + peerConnection, connection, stream, - dataChannelOptions: this.init.dataChannel + signal }) - await this.components.upgrader.upgradeInbound(new WebRTCMultiaddrConnection({ - peerConnection: pc, + const webRTCConn = new WebRTCMultiaddrConnection({ + peerConnection, timeline: { open: (new Date()).getTime() }, remoteAddr: multiaddr(remoteAddress).encapsulate(`/p2p/${connection.remotePeer.toString()}`), metrics: this.metrics?.listenerEvents - }), { + }) + + // close the connection on shut down + this._closeOnShutdown(peerConnection, webRTCConn) + + await this.components.upgrader.upgradeInbound(webRTCConn, { skipEncryption: true, skipProtection: true, muxerFactory }) + + // close the stream if SDP messages have been exchanged successfully + await stream.close({ + signal + }) } catch (err: any) { stream.abort(err) throw err - } finally { - // Close the signaling connection - await connection.close() } } + + private _closeOnShutdown (pc: RTCPeerConnection, webRTCConn: WebRTCMultiaddrConnection): void { + // close the connection on shut down + const shutDownListener = (): void => { + webRTCConn.close() + .catch(err => { + log.error('could not close WebRTCMultiaddrConnection', err) + }) + } + + this.shutdownController.signal.addEventListener('abort', shutDownListener) + + pc.addEventListener('close', () => { + this.shutdownController.signal.removeEventListener('abort', shutDownListener) + }) + } } export function splitAddr (ma: Multiaddr): { baseAddr: Multiaddr, peerId: PeerId } { diff --git a/packages/transport-webrtc/src/private-to-private/util.ts b/packages/transport-webrtc/src/private-to-private/util.ts index 6d2b97898d..b892e1d4b6 100644 --- a/packages/transport-webrtc/src/private-to-private/util.ts +++ b/packages/transport-webrtc/src/private-to-private/util.ts @@ -1,49 +1,101 @@ +import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' +import { abortableSource } from 'abortable-iterator' +import { anySignal } from 'any-signal' +import * as lp from 'it-length-prefixed' +import { AbortError, raceSignal } from 'race-signal' import { isFirefox } from '../util.js' import { RTCIceCandidate } from '../webrtc/index.js' import { Message } from './pb/message.js' +import type { Stream } from '@libp2p/interface/connection' +import type { AbortOptions, MessageStream } from 'it-protobuf-stream' import type { DeferredPromise } from 'p-defer' -interface MessageStream { - read: () => Promise - write: (d: Message) => void | Promise +const log = logger('libp2p:webrtc:peer:util') + +export interface ReadCandidatesOptions extends AbortOptions { + direction: string } -const log = logger('libp2p:webrtc:peer:util') +export const readCandidatesUntilConnected = async (connectedPromise: DeferredPromise, pc: RTCPeerConnection, stream: MessageStream, options: ReadCandidatesOptions): Promise => { + // if we connect, stop trying to read from the stream + const controller = new AbortController() + connectedPromise.promise.then(() => { + controller.abort() + }, () => { + controller.abort() + }) + + const signal = anySignal([ + controller.signal, + options.signal + ]) + + const source = abortableSource(stream.unwrap().unwrap().source, signal, { + returnOnAbort: true + }) + + try { + // read candidates until we are connected or we reach the end of the stream + for await (const buf of lp.decode(source)) { + const message = Message.decode(buf) -export const readCandidatesUntilConnected = async (connectedPromise: DeferredPromise, pc: RTCPeerConnection, stream: MessageStream): Promise => { - while (true) { - const readResult = await Promise.race([connectedPromise.promise, stream.read()]) - // check if readResult is a message - if (readResult instanceof Object) { - const message = readResult if (message.type !== Message.Type.ICE_CANDIDATE) { - throw new Error('expected only ice candidates') + throw new CodeError('ICE candidate message expected', 'ERR_NOT_ICE_CANDIDATE') } - // end of candidates has been signalled - if (message.data == null || message.data === '') { + + let candidateInit = JSON.parse(message.data ?? 'null') + + if (candidateInit === '') { + log.trace('end-of-candidates for this generation received') + candidateInit = { + candidate: '', + sdpMid: '0', + sdpMLineIndex: 0 + } + } + + if (candidateInit === null) { log.trace('end-of-candidates received') - break + candidateInit = { + candidate: null, + sdpMid: '0', + sdpMLineIndex: 0 + } } - log.trace('received new ICE candidate: %s', message.data) + // a null candidate means end-of-candidates + // see - https://www.w3.org/TR/webrtc/#rtcpeerconnectioniceevent + const candidate = new RTCIceCandidate(candidateInit) + + log.trace('%s received new ICE candidate', options.direction, candidate) + try { - await pc.addIceCandidate(new RTCIceCandidate(JSON.parse(message.data))) + await pc.addIceCandidate(candidate) } catch (err) { - log.error('bad candidate received: ', err) - throw new Error('bad candidate received') + log.error('%s bad candidate received', options.direction, err) } - } else { - // connected promise resolved - break } + } catch (err) { + log.error('%s error parsing ICE candidate', options.direction, err) + } finally { + signal.clear() } - await connectedPromise.promise + + if (options.signal?.aborted === true) { + throw new AbortError('Aborted while reading ICE candidates', 'ERR_ICE_CANDIDATES_READ_ABORTED') + } + + // read all available ICE candidates, wait for connection state change + await raceSignal(connectedPromise.promise, options.signal, { + errorMessage: 'Aborted before connected', + errorCode: 'ERR_ABORTED_BEFORE_CONNECTED' + }) } export function resolveOnConnected (pc: RTCPeerConnection, promise: DeferredPromise): void { pc[isFirefox ? 'oniceconnectionstatechange' : 'onconnectionstatechange'] = (_) => { - log.trace('receiver peerConnectionState state: ', pc.connectionState) + log.trace('receiver peerConnectionState state change: %s', pc.connectionState) switch (isFirefox ? pc.iceConnectionState : pc.connectionState) { case 'connected': promise.resolve() @@ -51,10 +103,23 @@ export function resolveOnConnected (pc: RTCPeerConnection, promise: DeferredProm case 'failed': case 'disconnected': case 'closed': - promise.reject(new Error('RTCPeerConnection was closed')) + promise.reject(new CodeError('RTCPeerConnection was closed', 'ERR_CONNECTION_CLOSED_BEFORE_CONNECTED')) break default: break } } } + +export function parseRemoteAddress (sdp: string): string { + // 'a=candidate:1746876089 1 udp 2113937151 0614fbad-b...ocal 54882 typ host generation 0 network-cost 999' + const candidateLine = sdp.split('\r\n').filter(line => line.startsWith('a=candidate')).pop() + const candidateParts = candidateLine?.split(' ') + + if (candidateLine == null || candidateParts == null || candidateParts.length < 5) { + log('could not parse remote address from', candidateLine) + return '/webrtc' + } + + return `/dnsaddr/${candidateParts[4]}/${candidateParts[2].toLowerCase()}/${candidateParts[5]}/webrtc` +} diff --git a/packages/transport-webrtc/src/private-to-public/transport.ts b/packages/transport-webrtc/src/private-to-public/transport.ts index 23bb5d1994..a82f84a2b5 100644 --- a/packages/transport-webrtc/src/private-to-public/transport.ts +++ b/packages/transport-webrtc/src/private-to-public/transport.ts @@ -16,7 +16,7 @@ import { RTCPeerConnection } from '../webrtc/index.js' import * as sdp from './sdp.js' import { genUfrag } from './util.js' import type { WebRTCDialOptions } from './options.js' -import type { DataChannelOpts } from '../stream.js' +import type { DataChannelOptions } from '../index.js' import type { Connection } from '@libp2p/interface/connection' import type { CounterGroup, Metrics } from '@libp2p/interface/metrics' import type { PeerId } from '@libp2p/interface/peer-id' @@ -56,7 +56,7 @@ export interface WebRTCMetrics { } export interface WebRTCTransportDirectInit { - dataChannel?: Partial + dataChannel?: DataChannelOptions } export class WebRTCDirectTransport implements Transport { @@ -81,7 +81,7 @@ export class WebRTCDirectTransport implements Transport { */ async dial (ma: Multiaddr, options: WebRTCDialOptions): Promise { const rawConn = await this._connect(ma, options) - log(`dialing address - ${ma.toString()}`) + log('dialing address: %a', ma) return rawConn } @@ -194,7 +194,7 @@ export class WebRTCDirectTransport implements Transport { // we pass in undefined for these parameters. const noise = Noise({ prologueBytes: fingerprintsPrologue })() - const wrappedChannel = createStream({ channel: handshakeDataChannel, direction: 'inbound', dataChannelOptions: this.init.dataChannel }) + const wrappedChannel = createStream({ channel: handshakeDataChannel, direction: 'inbound', ...(this.init.dataChannel ?? {}) }) const wrappedDuplex = { ...wrappedChannel, sink: wrappedChannel.sink.bind(wrappedChannel), diff --git a/packages/transport-webrtc/src/stream.ts b/packages/transport-webrtc/src/stream.ts index 1e9bb0c3c7..d4c37cc857 100644 --- a/packages/transport-webrtc/src/stream.ts +++ b/packages/transport-webrtc/src/stream.ts @@ -3,18 +3,18 @@ import { AbstractStream, type AbstractStreamInit } from '@libp2p/interface/strea import { logger } from '@libp2p/logger' import * as lengthPrefixed from 'it-length-prefixed' import { type Pushable, pushable } from 'it-pushable' +import pDefer from 'p-defer' import { pEvent, TimeoutError } from 'p-event' +import pTimeout from 'p-timeout' +import { raceSignal } from 'race-signal' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from './pb/message.js' +import type { DataChannelOptions } from './index.js' +import type { AbortOptions } from '@libp2p/interface' import type { Direction } from '@libp2p/interface/connection' +import type { DeferredPromise } from 'p-defer' -export interface DataChannelOpts { - maxMessageSize: number - maxBufferedAmount: number - bufferedAmountLowEventTimeout: number -} - -export interface WebRTCStreamInit extends AbstractStreamInit { +export interface WebRTCStreamInit extends AbstractStreamInit, DataChannelOptions { /** * The network channel used for bidirectional peer-to-peer transfers of * arbitrary data @@ -22,38 +22,46 @@ export interface WebRTCStreamInit extends AbstractStreamInit { * {@link https://developer.mozilla.org/en-US/docs/Web/API/RTCDataChannel} */ channel: RTCDataChannel - - dataChannelOptions?: Partial - - maxDataSize: number } -// Max message size that can be sent to the DataChannel -export const MAX_MESSAGE_SIZE = 16 * 1024 - -// How much can be buffered to the DataChannel at once +/** + * How much can be buffered to the DataChannel at once + */ export const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 -// How long time we wait for the 'bufferedamountlow' event to be emitted +/** + * How long time we wait for the 'bufferedamountlow' event to be emitted + */ export const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000 -// protobuf field definition overhead +/** + * protobuf field definition overhead + */ export const PROTOBUF_OVERHEAD = 5 -// Length of varint, in bytes. +/** + * Length of varint, in bytes + */ export const VARINT_LENGTH = 2 +/** + * Max message size that can be sent to the DataChannel + */ +export const MAX_MESSAGE_SIZE = 16 * 1024 + +/** + * When closing streams we send a FIN then wait for the remote to + * reply with a FIN_ACK. If that does not happen within this timeout + * we close the stream anyway. + */ +export const FIN_ACK_TIMEOUT = 5000 + export class WebRTCStream extends AbstractStream { /** * The data channel used to send and receive data */ private readonly channel: RTCDataChannel - /** - * Data channel options - */ - private readonly dataChannelOptions: DataChannelOpts - /** * push data from the underlying datachannel to the length prefix decoder * and then the protobuf decoder. @@ -62,24 +70,65 @@ export class WebRTCStream extends AbstractStream { private messageQueue?: Uint8ArrayList + private readonly maxBufferedAmount: number + + private readonly bufferedAmountLowEventTimeout: number + /** * The maximum size of a message in bytes */ - private readonly maxDataSize: number + private readonly maxMessageSize: number + + /** + * When this promise is resolved, the remote has sent us a FIN flag + */ + private readonly receiveFinAck: DeferredPromise + private readonly finAckTimeout: number + // private sentFinAck: boolean constructor (init: WebRTCStreamInit) { + // override onEnd to send/receive FIN_ACK before closing the stream + const originalOnEnd = init.onEnd + init.onEnd = (err?: Error): void => { + this.log.trace('readable and writeable ends closed', this.status) + + void Promise.resolve(async () => { + if (this.timeline.abort != null || this.timeline.reset !== null) { + return + } + + // wait for FIN_ACK if we haven't received it already + try { + await pTimeout(this.receiveFinAck.promise, { + milliseconds: this.finAckTimeout + }) + } catch (err) { + this.log.error('error receiving FIN_ACK', err) + } + }) + .then(() => { + // stop processing incoming messages + this.incomingData.end() + + // final cleanup + originalOnEnd?.(err) + }) + .catch(err => { + this.log.error('error ending stream', err) + }) + } + super(init) this.channel = init.channel this.channel.binaryType = 'arraybuffer' this.incomingData = pushable() this.messageQueue = new Uint8ArrayList() - this.dataChannelOptions = { - bufferedAmountLowEventTimeout: init.dataChannelOptions?.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT, - maxBufferedAmount: init.dataChannelOptions?.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT, - maxMessageSize: init.dataChannelOptions?.maxMessageSize ?? init.maxDataSize - } - this.maxDataSize = init.maxDataSize + this.bufferedAmountLowEventTimeout = init.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT + this.maxBufferedAmount = init.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT + this.maxMessageSize = (init.maxMessageSize ?? MAX_MESSAGE_SIZE) - PROTOBUF_OVERHEAD - VARINT_LENGTH + this.receiveFinAck = pDefer() + this.finAckTimeout = init.closeTimeout ?? FIN_ACK_TIMEOUT // set up initial state switch (this.channel.readyState) { @@ -105,17 +154,25 @@ export class WebRTCStream extends AbstractStream { this.channel.onopen = (_evt) => { this.timeline.open = new Date().getTime() - if (this.messageQueue != null) { + if (this.messageQueue != null && this.messageQueue.byteLength > 0) { + this.log.trace('dataChannel opened, sending queued messages', this.messageQueue.byteLength, this.channel.readyState) + // send any queued messages this._sendMessage(this.messageQueue) .catch(err => { + this.log.error('error sending queued messages', err) this.abort(err) }) - this.messageQueue = undefined } + + this.messageQueue = undefined } this.channel.onclose = (_evt) => { + // if the channel has closed we'll never receive a FIN_ACK so resolve the + // promise so we don't try to wait later + this.receiveFinAck.resolve() + void this.close().catch(err => { this.log.error('error closing stream after channel closed', err) }) @@ -126,8 +183,6 @@ export class WebRTCStream extends AbstractStream { this.abort(err) } - const self = this - this.channel.onmessage = async (event: MessageEvent) => { const { data } = event @@ -138,11 +193,13 @@ export class WebRTCStream extends AbstractStream { this.incomingData.push(new Uint8Array(data, 0, data.byteLength)) } + const self = this + // pipe framed protobuf messages through a length prefixed decoder, and // surface data from the `Message.message` field through a source. Promise.resolve().then(async () => { for await (const buf of lengthPrefixed.decode(this.incomingData)) { - const message = self.processIncomingProtobuf(buf.subarray()) + const message = self.processIncomingProtobuf(buf) if (message != null) { self.sourcePush(new Uint8ArrayList(message)) @@ -159,12 +216,12 @@ export class WebRTCStream extends AbstractStream { } async _sendMessage (data: Uint8ArrayList, checkBuffer: boolean = true): Promise { - if (checkBuffer && this.channel.bufferedAmount > this.dataChannelOptions.maxBufferedAmount) { + if (checkBuffer && this.channel.bufferedAmount > this.maxBufferedAmount) { try { - await pEvent(this.channel, 'bufferedamountlow', { timeout: this.dataChannelOptions.bufferedAmountLowEventTimeout }) + await pEvent(this.channel, 'bufferedamountlow', { timeout: this.bufferedAmountLowEventTimeout }) } catch (err: any) { if (err instanceof TimeoutError) { - throw new Error('Timed out waiting for DataChannel buffer to clear') + throw new CodeError(`Timed out waiting for DataChannel buffer to clear after ${this.bufferedAmountLowEventTimeout}ms`, 'ERR_BUFFER_CLEAR_TIMEOUT') } throw err @@ -172,7 +229,7 @@ export class WebRTCStream extends AbstractStream { } if (this.channel.readyState === 'closed' || this.channel.readyState === 'closing') { - throw new CodeError('Invalid datachannel state - closed or closing', 'ERR_INVALID_STATE') + throw new CodeError(`Invalid datachannel state - ${this.channel.readyState}`, 'ERR_INVALID_STATE') } if (this.channel.readyState === 'open') { @@ -194,10 +251,12 @@ export class WebRTCStream extends AbstractStream { } async sendData (data: Uint8ArrayList): Promise { + // sending messages is an async operation so use a copy of the list as it + // may be changed beneath us data = data.sublist() while (data.byteLength > 0) { - const toSend = Math.min(data.byteLength, this.maxDataSize) + const toSend = Math.min(data.byteLength, this.maxMessageSize) const buf = data.subarray(0, toSend) const msgbuf = Message.encode({ message: buf }) const sendbuf = lengthPrefixed.encode.single(msgbuf) @@ -211,8 +270,25 @@ export class WebRTCStream extends AbstractStream { await this._sendFlag(Message.Flag.RESET) } - async sendCloseWrite (): Promise { - await this._sendFlag(Message.Flag.FIN) + async sendCloseWrite (options: AbortOptions): Promise { + const sent = await this._sendFlag(Message.Flag.FIN) + + if (sent) { + this.log.trace('awaiting FIN_ACK') + try { + await raceSignal(this.receiveFinAck.promise, options?.signal, { + errorMessage: 'sending close-write was aborted before FIN_ACK was received', + errorCode: 'ERR_FIN_ACK_NOT_RECEIVED' + }) + } catch (err) { + this.log.error('failed to await FIN_ACK', err) + } + } else { + this.log.trace('sending FIN failed, not awaiting FIN_ACK') + } + + // if we've attempted to receive a FIN_ACK, do not try again + this.receiveFinAck.resolve() } async sendCloseRead (): Promise { @@ -222,14 +298,21 @@ export class WebRTCStream extends AbstractStream { /** * Handle incoming */ - private processIncomingProtobuf (buffer: Uint8Array): Uint8Array | undefined { + private processIncomingProtobuf (buffer: Uint8ArrayList): Uint8Array | undefined { const message = Message.decode(buffer) if (message.flag !== undefined) { + this.log.trace('incoming flag %s, write status "%s", read status "%s"', message.flag, this.writeStatus, this.readStatus) + if (message.flag === Message.Flag.FIN) { // We should expect no more data from the remote, stop reading - this.incomingData.end() this.remoteCloseWrite() + + this.log.trace('sending FIN_ACK') + void this._sendFlag(Message.Flag.FIN_ACK) + .catch(err => { + this.log.error('error sending FIN_ACK immediately', err) + }) } if (message.flag === Message.Flag.RESET) { @@ -241,21 +324,45 @@ export class WebRTCStream extends AbstractStream { // The remote has stopped reading this.remoteCloseRead() } + + if (message.flag === Message.Flag.FIN_ACK) { + this.log.trace('received FIN_ACK') + this.receiveFinAck.resolve() + } } - return message.message + // ignore data messages if we've closed the readable end already + if (this.readStatus === 'ready') { + return message.message + } } - private async _sendFlag (flag: Message.Flag): Promise { - this.log.trace('Sending flag: %s', flag.toString()) + private async _sendFlag (flag: Message.Flag): Promise { + if (this.channel.readyState !== 'open') { + // flags can be sent while we or the remote are closing the datachannel so + // if the channel isn't open, don't try to send it but return false to let + // the caller know and act if they need to + this.log.trace('not sending flag %s because channel is not open', flag.toString()) + return false + } + + this.log.trace('sending flag %s', flag.toString()) const msgbuf = Message.encode({ flag }) const prefixedBuf = lengthPrefixed.encode.single(msgbuf) - await this._sendMessage(prefixedBuf, false) + try { + await this._sendMessage(prefixedBuf, false) + + return true + } catch (err: any) { + this.log.error('could not send flag %s', flag.toString(), err) + } + + return false } } -export interface WebRTCStreamOptions { +export interface WebRTCStreamOptions extends DataChannelOptions { /** * The network channel used for bidirectional peer-to-peer transfers of * arbitrary data @@ -269,23 +376,18 @@ export interface WebRTCStreamOptions { */ direction: Direction - dataChannelOptions?: Partial - - maxMsgSize?: number - + /** + * A callback invoked when the channel ends + */ onEnd?: (err?: Error | undefined) => void } export function createStream (options: WebRTCStreamOptions): WebRTCStream { - const { channel, direction, onEnd, dataChannelOptions } = options + const { channel, direction } = options return new WebRTCStream({ id: direction === 'inbound' ? (`i${channel.id}`) : `r${channel.id}`, - direction, - maxDataSize: (dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE) - PROTOBUF_OVERHEAD - VARINT_LENGTH, - dataChannelOptions, - onEnd, - channel, - log: logger(`libp2p:webrtc:stream:${direction}:${channel.id}`) + log: logger(`libp2p:webrtc:stream:${direction}:${channel.id}`), + ...options }) } diff --git a/packages/transport-webrtc/src/util.ts b/packages/transport-webrtc/src/util.ts index e26e64dd5f..e35b90ca43 100644 --- a/packages/transport-webrtc/src/util.ts +++ b/packages/transport-webrtc/src/util.ts @@ -1,4 +1,9 @@ +import { logger } from '@libp2p/logger' import { detect } from 'detect-browser' +import pDefer from 'p-defer' +import pTimeout from 'p-timeout' + +const log = logger('libp2p:webrtc:utils') const browser = detect() export const isFirefox = ((browser != null) && browser.name === 'firefox') @@ -6,3 +11,58 @@ export const isFirefox = ((browser != null) && browser.name === 'firefox') export const nopSource = async function * nop (): AsyncGenerator {} export const nopSink = async (_: any): Promise => {} + +export const DATA_CHANNEL_DRAIN_TIMEOUT = 30 * 1000 + +export function drainAndClose (channel: RTCDataChannel, direction: string, drainTimeout: number = DATA_CHANNEL_DRAIN_TIMEOUT): void { + if (channel.readyState !== 'open') { + return + } + + void Promise.resolve() + .then(async () => { + // wait for bufferedAmount to become zero + if (channel.bufferedAmount > 0) { + log('%s drain channel with %d buffered bytes', direction, channel.bufferedAmount) + const deferred = pDefer() + let drained = false + + channel.bufferedAmountLowThreshold = 0 + + const closeListener = (): void => { + if (!drained) { + log('%s drain channel closed before drain', direction) + deferred.resolve() + } + } + + channel.addEventListener('close', closeListener, { + once: true + }) + + channel.addEventListener('bufferedamountlow', () => { + drained = true + channel.removeEventListener('close', closeListener) + deferred.resolve() + }) + + await pTimeout(deferred.promise, { + milliseconds: drainTimeout + }) + } + }) + .then(async () => { + // only close if the channel is still open + if (channel.readyState === 'open') { + channel.close() + } + }) + .catch(err => { + log.error('error closing outbound stream', err) + }) +} + +export interface AbortPromiseOptions { + signal?: AbortSignal + message?: string +} diff --git a/packages/transport-webrtc/test/basics.spec.ts b/packages/transport-webrtc/test/basics.spec.ts index 03d89a1c8b..f7170c6126 100644 --- a/packages/transport-webrtc/test/basics.spec.ts +++ b/packages/transport-webrtc/test/basics.spec.ts @@ -7,15 +7,19 @@ import * as filter from '@libp2p/websockets/filters' import { WebRTC } from '@multiformats/mafmt' import { multiaddr } from '@multiformats/multiaddr' import { expect } from 'aegir/chai' +import drain from 'it-drain' import map from 'it-map' import { pipe } from 'it-pipe' +import { pushable } from 'it-pushable' import toBuffer from 'it-to-buffer' import { createLibp2p } from 'libp2p' import { circuitRelayTransport } from 'libp2p/circuit-relay' -import { identifyService } from 'libp2p/identify' +import pDefer from 'p-defer' +import pRetry from 'p-retry' import { webRTC } from '../src/index.js' import type { Libp2p } from '@libp2p/interface' -import type { Connection } from '@libp2p/interface/connection' +import type { Connection, Stream } from '@libp2p/interface/connection' +import type { StreamHandler } from '@libp2p/interface/stream-handler' async function createNode (): Promise { return createLibp2p({ @@ -38,9 +42,6 @@ async function createNode (): Promise { streamMuxers: [ yamux() ], - services: { - identify: identifyService() - }, connectionGater: { denyDialMultiaddr: () => false }, @@ -55,6 +56,7 @@ describe('basics', () => { let localNode: Libp2p let remoteNode: Libp2p + let streamHandler: StreamHandler async function connectNodes (): Promise { const remoteAddr = remoteNode.getMultiaddrs() @@ -64,11 +66,8 @@ describe('basics', () => { throw new Error('Remote peer could not listen on relay') } - await remoteNode.handle(echo, ({ stream }) => { - void pipe( - stream, - stream - ) + await remoteNode.handle(echo, (info) => { + streamHandler(info) }, { runOnTransientConnection: true }) @@ -83,6 +82,13 @@ describe('basics', () => { } beforeEach(async () => { + streamHandler = ({ stream }) => { + void pipe( + stream, + stream + ) + } + localNode = await createNode() remoteNode = await createNode() }) @@ -101,9 +107,7 @@ describe('basics', () => { const connection = await connectNodes() // open a stream on the echo protocol - const stream = await connection.newStream(echo, { - runOnTransientConnection: true - }) + const stream = await connection.newStream(echo) // send and receive some data const input = new Array(5).fill(0).map(() => new Uint8Array(10)) @@ -138,4 +142,204 @@ describe('basics', () => { // asset that we got the right data expect(output).to.equalBytes(toBuffer(input)) }) + + it('can close local stream for reading but send a large file', async () => { + let output: Uint8Array = new Uint8Array(0) + const streamClosed = pDefer() + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + output = await toBuffer(map(stream.source, (buf) => buf.subarray())) + await stream.close() + streamClosed.resolve() + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + // close for reading + await stream.closeRead() + + // send some data + const input = new Array(5).fill(0).map(() => new Uint8Array(1024 * 1024)) + + await stream.sink(input) + await stream.close() + + // wait for remote to receive all data + await streamClosed.promise + + // asset that we got the right data + expect(output).to.equalBytes(toBuffer(input)) + }) + + it('can close local stream for writing but receive a large file', async () => { + const input = new Array(5).fill(0).map(() => new Uint8Array(1024 * 1024)) + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + // send some data + await stream.sink(input) + await stream.close() + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + // close for reading + await stream.closeWrite() + + // receive some data + const output = await toBuffer(map(stream.source, (buf) => buf.subarray())) + + await stream.close() + + // asset that we got the right data + expect(output).to.equalBytes(toBuffer(input)) + }) + + it('can close local stream for writing and reading while a remote stream is writing', async () => { + /** + * NodeA NodeB + * | <--- STOP_SENDING | + * | FIN ---> | + * | <--- FIN | + * | FIN_ACK ---> | + * | <--- FIN_ACK | + */ + + const getRemoteStream = pDefer() + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + getRemoteStream.resolve(stream) + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + const remoteStream = await getRemoteStream.promise + // close the readable end of the remote stream + await remoteStream.closeRead() + + // keep the remote write end open, this should delay the FIN_ACK reply to the local stream + const remoteInputStream = pushable() + void remoteStream.sink(remoteInputStream) + + const p = stream.closeWrite() + + // wait for remote to receive local close-write + await pRetry(() => { + if (remoteStream.readStatus !== 'closed') { + throw new Error('Remote stream read status ' + remoteStream.readStatus) + } + }, { + minTimeout: 100 + }) + + // remote closes write + remoteInputStream.end() + + // wait to receive FIN_ACK + await p + + // wait for remote to notice closure + await pRetry(() => { + if (remoteStream.status !== 'closed') { + throw new Error('Remote stream not closed') + } + }) + + assertStreamClosed(stream) + assertStreamClosed(remoteStream) + }) + + it('can close local stream for writing and reading while a remote stream is writing using source/sink', async () => { + /** + * NodeA NodeB + * | FIN ---> | + * | <--- FIN | + * | FIN_ACK ---> | + * | <--- FIN_ACK | + */ + + const getRemoteStream = pDefer() + + streamHandler = ({ stream }) => { + void Promise.resolve().then(async () => { + getRemoteStream.resolve(stream) + }) + } + + const connection = await connectNodes() + + // open a stream on the echo protocol + const stream = await connection.newStream(echo, { + runOnTransientConnection: true + }) + + const remoteStream = await getRemoteStream.promise + // close the readable end of the remote stream + await remoteStream.closeRead() + // readable end should finish + await drain(remoteStream.source) + + // keep the remote write end open, this should delay the FIN_ACK reply to the local stream + const p = stream.sink([]) + + // wait for remote to receive local close-write + await pRetry(() => { + if (remoteStream.readStatus !== 'closed') { + throw new Error('Remote stream read status ' + remoteStream.readStatus) + } + }, { + minTimeout: 100 + }) + + // remote closes write + await remoteStream.sink([]) + + // wait to receive FIN_ACK + await p + + // close read end of stream + await stream.closeRead() + // readable end should finish + await drain(stream.source) + + // wait for remote to notice closure + await pRetry(() => { + if (remoteStream.status !== 'closed') { + throw new Error('Remote stream not closed') + } + }) + + assertStreamClosed(stream) + assertStreamClosed(remoteStream) + }) }) + +function assertStreamClosed (stream: Stream): void { + expect(stream.status).to.equal('closed') + expect(stream.readStatus).to.equal('closed') + expect(stream.writeStatus).to.equal('closed') + + expect(stream.timeline.close).to.be.a('number') + expect(stream.timeline.closeRead).to.be.a('number') + expect(stream.timeline.closeWrite).to.be.a('number') +} diff --git a/packages/transport-webrtc/test/listener.spec.ts b/packages/transport-webrtc/test/listener.spec.ts index 34feedb859..036e727d7c 100644 --- a/packages/transport-webrtc/test/listener.spec.ts +++ b/packages/transport-webrtc/test/listener.spec.ts @@ -16,6 +16,8 @@ describe('webrtc private-to-private listener', () => { const listener = new WebRTCPeerListener({ peerId, transportManager + }, { + shutdownController: new AbortController() }) const otherListener = stubInterface({ diff --git a/packages/transport-webrtc/test/peer.browser.spec.ts b/packages/transport-webrtc/test/peer.browser.spec.ts index 623a8a8542..5e98c1078a 100644 --- a/packages/transport-webrtc/test/peer.browser.spec.ts +++ b/packages/transport-webrtc/test/peer.browser.spec.ts @@ -1,56 +1,119 @@ -import { mockConnection, mockMultiaddrConnection, mockRegistrar, mockStream, mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' +import { mockRegistrar, mockUpgrader, streamPair } from '@libp2p/interface-compliance-tests/mocks' import { createEd25519PeerId } from '@libp2p/peer-id-factory' -import { multiaddr } from '@multiformats/multiaddr' +import { multiaddr, type Multiaddr } from '@multiformats/multiaddr' import { expect } from 'aegir/chai' import { detect } from 'detect-browser' -import { pair } from 'it-pair' import { duplexPair } from 'it-pair/duplex' import { pbStream } from 'it-protobuf-stream' import Sinon from 'sinon' -import { initiateConnection, handleIncomingStream } from '../src/private-to-private/handler.js' +import { stubInterface, type StubbedInstance } from 'sinon-ts' +import { initiateConnection } from '../src/private-to-private/initiate-connection.js' import { Message } from '../src/private-to-private/pb/message.js' -import { WebRTCTransport, splitAddr } from '../src/private-to-private/transport.js' +import { handleIncomingStream } from '../src/private-to-private/signaling-stream-handler.js' +import { SIGNALING_PROTO_ID, WebRTCTransport, splitAddr } from '../src/private-to-private/transport.js' import { RTCPeerConnection, RTCSessionDescription } from '../src/webrtc/index.js' +import type { Connection, Stream } from '@libp2p/interface/connection' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' +import type { TransportManager } from '@libp2p/interface-internal/transport-manager' const browser = detect() +interface PrivateToPrivateComponents { + initiator: { + multiaddr: Multiaddr + peerConnection: RTCPeerConnection + connectionManager: StubbedInstance + transportManager: StubbedInstance + connection: StubbedInstance + stream: Stream + } + recipient: { + peerConnection: RTCPeerConnection + connection: StubbedInstance + abortController: AbortController + signal: AbortSignal + stream: Stream + } +} + +async function getComponents (): Promise { + const relayPeerId = await createEd25519PeerId() + const receiverPeerId = await createEd25519PeerId() + const receiverMultiaddr = multiaddr(`/ip4/123.123.123.123/tcp/123/p2p/${relayPeerId}/p2p-circuit/webrtc/p2p/${receiverPeerId}`) + const [initiatorToReceiver, receiverToInitiator] = duplexPair() + const [initiatorStream, receiverStream] = streamPair({ + duplex: initiatorToReceiver, + init: { + protocol: SIGNALING_PROTO_ID + } + }, { + duplex: receiverToInitiator, + init: { + protocol: SIGNALING_PROTO_ID + } + }) + + const recipientAbortController = new AbortController() + + return { + initiator: { + multiaddr: receiverMultiaddr, + peerConnection: new RTCPeerConnection(), + connectionManager: stubInterface(), + transportManager: stubInterface(), + connection: stubInterface(), + stream: initiatorStream + }, + recipient: { + peerConnection: new RTCPeerConnection(), + connection: stubInterface(), + abortController: recipientAbortController, + signal: recipientAbortController.signal, + stream: receiverStream + } + } +} + describe('webrtc basic', () => { const isFirefox = ((browser != null) && browser.name === 'firefox') it('should connect', async () => { - const [receiver, initiator] = duplexPair() - const dstPeerId = await createEd25519PeerId() - const connection = mockConnection( - mockMultiaddrConnection(pair(), dstPeerId) - ) - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ stream: mockStream(initiator), signal: controller.signal }) - const receiverPeerConnectionPromise = handleIncomingStream({ stream: mockStream(receiver), connection }) - await expect(initiatorPeerConnectionPromise).to.be.fulfilled() - await expect(receiverPeerConnectionPromise).to.be.fulfilled() - const [{ pc: pc0 }, { pc: pc1 }] = await Promise.all([initiatorPeerConnectionPromise, receiverPeerConnectionPromise]) + const { initiator, recipient } = await getComponents() + + // no existing connection + initiator.connectionManager.getConnections.returns([]) + + // transport manager dials recipient + initiator.transportManager.dial.resolves(initiator.connection) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + await expect( + Promise.all([ + initiateConnection(initiator), + handleIncomingStream(recipient) + ]) + ).to.eventually.be.fulfilled() + if (isFirefox) { - expect(pc0.iceConnectionState).eq('connected') - expect(pc1.iceConnectionState).eq('connected') + expect(initiator.peerConnection.iceConnectionState).eq('connected') + expect(recipient.peerConnection.iceConnectionState).eq('connected') return } - expect(pc0.connectionState).eq('connected') - expect(pc1.connectionState).eq('connected') + expect(initiator.peerConnection.connectionState).eq('connected') + expect(recipient.peerConnection.connectionState).eq('connected') - pc0.close() - pc1.close() + initiator.peerConnection.close() + recipient.peerConnection.close() }) }) describe('webrtc receiver', () => { it('should fail receiving on invalid sdp offer', async () => { - const [receiver, initiator] = duplexPair() - const dstPeerId = await createEd25519PeerId() - const connection = mockConnection( - mockMultiaddrConnection(pair(), dstPeerId) - ) - const receiverPeerConnectionPromise = handleIncomingStream({ stream: mockStream(receiver), connection }) - const stream = pbStream(initiator).pb(Message) + const { initiator, recipient } = await getComponents() + const receiverPeerConnectionPromise = handleIncomingStream(recipient) + const stream = pbStream(initiator.stream).pb(Message) await stream.write({ type: Message.Type.SDP_OFFER, data: 'bad' }) await expect(receiverPeerConnectionPromise).to.be.rejectedWith(/Failed to set remoteDescription/) @@ -59,10 +122,18 @@ describe('webrtc receiver', () => { describe('webrtc dialer', () => { it('should fail receiving on invalid sdp answer', async () => { - const [receiver, initiator] = duplexPair() - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ signal: controller.signal, stream: mockStream(initiator) }) - const stream = pbStream(receiver).pb(Message) + const { initiator, recipient } = await getComponents() + + // existing connection already exists + initiator.connectionManager.getConnections.returns([ + initiator.connection + ]) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + const initiatorPeerConnectionPromise = initiateConnection(initiator) + const stream = pbStream(recipient.stream).pb(Message) const offerMessage = await stream.read() expect(offerMessage.type).to.eq(Message.Type.SDP_OFFER) @@ -72,10 +143,19 @@ describe('webrtc dialer', () => { }) it('should fail on receiving a candidate before an answer', async () => { - const [receiver, initiator] = duplexPair() - const controller = new AbortController() - const initiatorPeerConnectionPromise = initiateConnection({ signal: controller.signal, stream: mockStream(initiator) }) - const stream = pbStream(receiver).pb(Message) + const { initiator, recipient } = await getComponents() + + // existing connection already exists + initiator.connectionManager.getConnections.returns([ + initiator.connection + ]) + + // signalling stream opens successfully + initiator.connection.newStream.withArgs(SIGNALING_PROTO_ID).resolves(initiator.stream) + + const initiatorPeerConnectionPromise = initiateConnection(initiator) + + const stream = pbStream(recipient.stream).pb(Message) const pc = new RTCPeerConnection() pc.onicecandidate = ({ candidate }) => { @@ -99,7 +179,8 @@ describe('webrtc dialer', () => { describe('webrtc filter', () => { it('can filter multiaddrs to dial', async () => { const transport = new WebRTCTransport({ - transportManager: Sinon.stub() as any, + transportManager: stubInterface(), + connectionManager: stubInterface(), peerId: Sinon.stub() as any, registrar: mockRegistrar(), upgrader: mockUpgrader({}) diff --git a/packages/transport-webrtc/test/stream.browser.spec.ts b/packages/transport-webrtc/test/stream.browser.spec.ts index 457f95317d..3e72b51e63 100644 --- a/packages/transport-webrtc/test/stream.browser.spec.ts +++ b/packages/transport-webrtc/test/stream.browser.spec.ts @@ -5,13 +5,14 @@ import { bytes } from 'multiformats' import { Message } from '../src/pb/message.js' import { createStream, type WebRTCStream } from '../src/stream.js' import { RTCPeerConnection } from '../src/webrtc/index.js' +import { receiveFinAck } from './util.js' import type { Stream } from '@libp2p/interface/connection' const TEST_MESSAGE = 'test_message' function setup (): { peerConnection: RTCPeerConnection, dataChannel: RTCDataChannel, stream: WebRTCStream } { const peerConnection = new RTCPeerConnection() const dataChannel = peerConnection.createDataChannel('whatever', { negotiated: true, id: 91 }) - const stream = createStream({ channel: dataChannel, direction: 'outbound' }) + const stream = createStream({ channel: dataChannel, direction: 'outbound', closeTimeout: 1 }) return { peerConnection, dataChannel, stream } } @@ -28,9 +29,10 @@ function generatePbByFlag (flag?: Message.Flag): Uint8Array { describe('Stream Stats', () => { let stream: WebRTCStream let peerConnection: RTCPeerConnection + let dataChannel: RTCDataChannel beforeEach(async () => { - ({ stream, peerConnection } = setup()) + ({ stream, peerConnection, dataChannel } = setup()) }) afterEach(() => { @@ -45,7 +47,10 @@ describe('Stream Stats', () => { it('close marks it closed', async () => { expect(stream.timeline.close).to.not.exist() + + receiveFinAck(dataChannel) await stream.close() + expect(stream.timeline.close).to.be.a('number') }) @@ -58,15 +63,23 @@ describe('Stream Stats', () => { it('closeWrite marks it write-closed only', async () => { expect(stream.timeline.close).to.not.exist() + + receiveFinAck(dataChannel) await stream.closeWrite() + expect(stream.timeline.close).to.not.exist() expect(stream.timeline.closeWrite).to.be.greaterThanOrEqual(stream.timeline.open) }) it('closeWrite AND closeRead = close', async () => { expect(stream.timeline.close).to.not.exist() - await stream.closeWrite() - await stream.closeRead() + + receiveFinAck(dataChannel) + await Promise.all([ + stream.closeRead(), + stream.closeWrite() + ]) + expect(stream.timeline.close).to.be.a('number') expect(stream.timeline.closeWrite).to.be.greaterThanOrEqual(stream.timeline.open) expect(stream.timeline.closeRead).to.be.greaterThanOrEqual(stream.timeline.open) diff --git a/packages/transport-webrtc/test/stream.spec.ts b/packages/transport-webrtc/test/stream.spec.ts index 500cbb02de..6a80366373 100644 --- a/packages/transport-webrtc/test/stream.spec.ts +++ b/packages/transport-webrtc/test/stream.spec.ts @@ -4,40 +4,35 @@ import { expect } from 'aegir/chai' import length from 'it-length' import * as lengthPrefixed from 'it-length-prefixed' import { pushable } from 'it-pushable' +import pDefer from 'p-defer' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from '../src/pb/message.js' import { MAX_BUFFERED_AMOUNT, MAX_MESSAGE_SIZE, PROTOBUF_OVERHEAD, createStream } from '../src/stream.js' - -const mockDataChannel = (opts: { send: (bytes: Uint8Array) => void, bufferedAmount?: number }): RTCDataChannel => { - return { - readyState: 'open', - close: () => { }, - addEventListener: (_type: string, _listener: () => void) => { }, - removeEventListener: (_type: string, _listener: () => void) => { }, - ...opts - } as RTCDataChannel -} +import { mockDataChannel, receiveFinAck } from './util.js' describe('Max message size', () => { it(`sends messages smaller or equal to ${MAX_MESSAGE_SIZE} bytes in one`, async () => { const sent: Uint8ArrayList = new Uint8ArrayList() const data = new Uint8Array(MAX_MESSAGE_SIZE - PROTOBUF_OVERHEAD) const p = pushable() + const channel = mockDataChannel({ + send: (bytes) => { + sent.append(bytes) + } + }) // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })) expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE) const webrtcStream = createStream({ - channel: mockDataChannel({ - send: (bytes) => { - sent.append(bytes) - } - }), - direction: 'outbound' + channel, + direction: 'outbound', + closeTimeout: 1 }) p.push(data) p.end() + receiveFinAck(channel) await webrtcStream.sink(p) expect(length(sent)).to.equal(6) @@ -51,22 +46,24 @@ describe('Max message size', () => { const sent: Uint8ArrayList = new Uint8ArrayList() const data = new Uint8Array(MAX_MESSAGE_SIZE) const p = pushable() + const channel = mockDataChannel({ + send: (bytes) => { + sent.append(bytes) + } + }) // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE + 1 // const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })).subarray() // expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE + 1) const webrtcStream = createStream({ - channel: mockDataChannel({ - send: (bytes) => { - sent.append(bytes) - } - }), + channel, direction: 'outbound' }) p.push(data) p.end() + receiveFinAck(channel) await webrtcStream.sink(p) expect(length(sent)).to.equal(6) @@ -78,31 +75,31 @@ describe('Max message size', () => { it('closes the stream if bufferamountlow timeout', async () => { const timeout = 100 - let closed = false - const webrtcStream = createStream({ - dataChannelOptions: { - bufferedAmountLowEventTimeout: timeout + const closed = pDefer() + const channel = mockDataChannel({ + send: () => { + throw new Error('Expected to not send') }, - channel: mockDataChannel({ - send: () => { - throw new Error('Expected to not send') - }, - bufferedAmount: MAX_BUFFERED_AMOUNT + 1 - }), + bufferedAmount: MAX_BUFFERED_AMOUNT + 1 + }) + const webrtcStream = createStream({ + bufferedAmountLowEventTimeout: timeout, + closeTimeout: 1, + channel, direction: 'outbound', onEnd: () => { - closed = true + closed.resolve() } }) const t0 = Date.now() await expect(webrtcStream.sink([new Uint8Array(1)])).to.eventually.be.rejected - .with.property('message', 'Timed out waiting for DataChannel buffer to clear') + .with.property('code', 'ERR_BUFFER_CLEAR_TIMEOUT') const t1 = Date.now() expect(t1 - t0).greaterThan(timeout) expect(t1 - t0).lessThan(timeout + 1000) // Some upper bound - expect(closed).true() + await closed.promise expect(webrtcStream.timeline.close).to.be.greaterThan(webrtcStream.timeline.open) expect(webrtcStream.timeline.abort).to.be.greaterThan(webrtcStream.timeline.open) }) diff --git a/packages/transport-webrtc/test/util.ts b/packages/transport-webrtc/test/util.ts index 70c492b6a7..2ee0cf36dd 100644 --- a/packages/transport-webrtc/test/util.ts +++ b/packages/transport-webrtc/test/util.ts @@ -1,4 +1,6 @@ import { expect } from 'aegir/chai' +import * as lengthPrefixed from 'it-length-prefixed' +import { Message } from '../src/pb/message.js' export const expectError = (error: unknown, message: string): void => { if (error instanceof Error) { @@ -7,3 +9,28 @@ export const expectError = (error: unknown, message: string): void => { expect('Did not throw error:').to.equal(message) } } + +/** + * simulates receiving a FIN_ACK on the passed datachannel + */ +export function receiveFinAck (channel: RTCDataChannel): void { + const msgbuf = Message.encode({ flag: Message.Flag.FIN_ACK }) + const data = lengthPrefixed.encode.single(msgbuf).subarray() + channel.onmessage?.(new MessageEvent('message', { data })) +} + +let mockDataChannelId = 0 + +export const mockDataChannel = (opts: { send: (bytes: Uint8Array) => void, bufferedAmount?: number }): RTCDataChannel => { + // @ts-expect-error incomplete implementation + const channel: RTCDataChannel = { + readyState: 'open', + close: () => { }, + addEventListener: (_type: string, _listener: (_: any) => void) => { }, + removeEventListener: (_type: string, _listener: (_: any) => void) => { }, + id: mockDataChannelId++, + ...opts + } + + return channel +} diff --git a/packages/transport-websockets/test/node.ts b/packages/transport-websockets/test/node.ts index 3526129a71..32aa1b02ae 100644 --- a/packages/transport-websockets/test/node.ts +++ b/packages/transport-websockets/test/node.ts @@ -332,6 +332,10 @@ describe('dial', () => { return !isLoopbackAddr(address) }) + if (addrs.length === 0) { + return + } + // Dial first no loopback address const conn = await ws.dial(addrs[0], { upgrader }) const s = goodbye({ source: [uint8ArrayFromString('hey')], sink: all })