diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0276a2..07989cd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,3 +71,7 @@ jobs: nim --version nimble --version nimble test + + - name: Build examples + run: | + nimble build_examples diff --git a/examples/ping.nim b/examples/ping.nim new file mode 100644 index 0000000..bee1971 --- /dev/null +++ b/examples/ping.nim @@ -0,0 +1,21 @@ +import chronos, stew/byteutils +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/[sctp_transport, sctp_connection] + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4244") + let udp = UdpTransport.new(laddr) + let stun = Stun.new(udp) + let dtls = Dtls.new(stun) + let sctp = Sctp.new(dtls) + + let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13) + while true: + await conn.write("ping".toBytes) + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await sleepAsync(1.seconds) + +waitFor(main()) diff --git a/examples/pong.nim b/examples/pong.nim new file mode 100644 index 0000000..79c018c --- /dev/null +++ b/examples/pong.nim @@ -0,0 +1,27 @@ +import chronos, stew/byteutils +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/[sctp_transport, sctp_connection] + +proc sendPong(conn: SctpConn) {.async.} = + var i = 0 + while true: + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await conn.write(("pong " & $i).toBytes) + i.inc() + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4242") + let udp = UdpTransport.new(laddr) + let stun = Stun.new(udp) + let dtls = Dtls.new(stun) + let sctp = Sctp.new(dtls) + + sctp.listen(13) + while true: + let conn = await sctp.accept() + asyncSpawn conn.sendPong() + +waitFor(main()) diff --git a/tests/runalltests.nim b/tests/runalltests.nim index 3a49806..9862719 100644 --- a/tests/runalltests.nim +++ b/tests/runalltests.nim @@ -9,5 +9,8 @@ {.used.} +{.passc: "-DSCTP_DEBUG".} + import teststun import testdtls +import testsctp diff --git a/tests/testsctp.nim b/tests/testsctp.nim new file mode 100644 index 0000000..448714f --- /dev/null +++ b/tests/testsctp.nim @@ -0,0 +1,102 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.used.} + +import chronos +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/sctp_transport +import ../webrtc/sctp/sctp_connection +import ./asyncunit + +suite "SCTP": + teardown: + checkLeaks() + + type + SctpStackForTest = object + localAddress: TransportAddress + udp: UdpTransport + stun: Stun + dtls: Dtls + sctp: Sctp + + proc initSctpStack(la: TransportAddress): SctpStackForTest = + result.udp = UdpTransport.new(la) + result.localAddress = result.udp.localAddress() + result.stun = Stun.new(result.udp) + result.dtls = Dtls.new(result.stun) + result.sctp = Sctp.new(result.dtls) + result.sctp.listen() + + proc closeSctpStack(self: SctpStackForTest) {.async: (raises: [CancelledError]).} = + await self.sctp.stop() + await self.dtls.stop() + await self.stun.stop() + await self.udp.close() + + asyncTest "Two SCTP nodes connecting to each other, then sending/receiving data": + var + sctpServer = initSctpStack(initTAddress("127.0.0.1:0")) + sctpClient = initSctpStack(initTAddress("127.0.0.1:0")) + let + serverConnFut = sctpServer.sctp.accept() + clientConn = await sctpClient.sctp.connect(sctpServer.localAddress) + serverConn = await serverConnFut + + await clientConn.write(@[1'u8, 2, 3, 4]) + check (await serverConn.read()).data == @[1'u8, 2, 3, 4] + + await serverConn.write(@[5'u8, 6, 7, 8]) + check (await clientConn.read()).data == @[5'u8, 6, 7, 8] + + await clientConn.write(@[10'u8, 11, 12, 13]) + await serverConn.write(@[14'u8, 15, 16, 17]) + check (await clientConn.read()).data == @[14'u8, 15, 16, 17] + check (await serverConn.read()).data == @[10'u8, 11, 12, 13] + + await allFutures(clientConn.close(), serverConn.close()) + await allFutures(sctpClient.closeSctpStack(), sctpServer.closeSctpStack()) + + asyncTest "Two DTLS nodes connecting to the same DTLS server, sending/receiving data": + var + sctpServer = initSctpStack(initTAddress("127.0.0.1:0")) + sctpClient1 = initSctpStack(initTAddress("127.0.0.1:0")) + sctpClient2 = initSctpStack(initTAddress("127.0.0.1:0")) + let + serverConn1Fut = sctpServer.sctp.accept() + serverConn2Fut = sctpServer.sctp.accept() + clientConn1 = await sctpClient1.sctp.connect(sctpServer.localAddress) + clientConn2 = await sctpClient2.sctp.connect(sctpServer.localAddress) + serverConn1 = await serverConn1Fut + serverConn2 = await serverConn2Fut + + await serverConn1.write(@[1'u8, 2, 3, 4]) + await serverConn2.write(@[5'u8, 6, 7, 8]) + await clientConn1.write(@[9'u8, 10, 11, 12]) + await clientConn2.write(@[13'u8, 14, 15, 16]) + check: + (await clientConn1.read()).data == @[1'u8, 2, 3, 4] + (await clientConn2.read()).data == @[5'u8, 6, 7, 8] + (await serverConn1.read()).data == @[9'u8, 10, 11, 12] + (await serverConn2.read()).data == @[13'u8, 14, 15, 16] + await allFutures(clientConn1.close(), serverConn1.close()) + + await serverConn2.write(@[5'u8, 6, 7, 8]) + await clientConn2.write(@[13'u8, 14, 15, 16]) + check: + (await clientConn2.read()).data == @[5'u8, 6, 7, 8] + (await serverConn2.read()).data == @[13'u8, 14, 15, 16] + await allFutures(clientConn2.close(), serverConn2.close()) + + await allFutures(sctpClient1.closeSctpStack(), + sctpClient2.closeSctpStack(), + sctpServer.closeSctpStack()) diff --git a/webrtc.nimble b/webrtc.nimble index da9e66f..91134e9 100644 --- a/webrtc.nimble +++ b/webrtc.nimble @@ -24,10 +24,20 @@ var cfg = " --threads:on --opt:speed" when defined(windows): - cfg = cfg & " --clib:ws2_32" + # ws2_32 is required by MbedTLS and usrsctp + # iphlpapi is required by usrsctp + cfg = cfg & " --clib:ws2_32 --clib:iphlpapi" import hashes +proc buildExample(filename: string, run = false, extraFlags = "") = + var excstr = nimc & " " & lang & " " & flags & " -p:. " & extraFlags + excstr.add(" examples/" & filename) + exec excstr + if run: + exec "./examples/" & filename.toExe + rmFile "examples/" & filename.toExe + proc runTest(filename: string) = var excstr = nimc & " " & lang & " -d:debug " & cfg & " " & flags excstr.add(" -d:nimOldCaseObjects") # TODO: fix this in binary-serialization @@ -36,5 +46,9 @@ proc runTest(filename: string) = exec excstr & " -r " & " tests/" & filename rmFile "tests/" & filename.toExe -task test, "Run test": +task test, "Run the test suite": runTest("runalltests") + +task build_examples, "Build the examples": + buildExample("ping") + buildExample("pong") diff --git a/webrtc/sctp/sctp_connection.nim b/webrtc/sctp/sctp_connection.nim new file mode 100644 index 0000000..4c4b887 --- /dev/null +++ b/webrtc/sctp/sctp_connection.nim @@ -0,0 +1,329 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import nativesockets, bitops, sequtils +import usrsctp, chronos, chronicles, stew/[ptrops, endians2, byteutils] +import ./sctp_utils, ../errors, ../dtls/dtls_connection + +logScope: + topics = "webrtc sctp_connection" + +const + SctpConnTracker* = "webrtc.sctp.conn" + IPPROTO_SCTP = 132 # Official IANA number + +type + SctpConnOnClose* = proc() {.raises: [], gcsafe.} + + SctpState* = enum + SctpConnecting + SctpConnected + SctpClosed + + SctpMessageParameters* = object + # This object is used to help manage messages exchanged over SCTP + # within the DataChannel stack. + protocolId*: uint32 + # protocolId is used to distinguish different protocols within + # SCTP stream. In WebRTC, this is used to define the type of application + # data being transferred (text data, binary data...). + streamId*: uint16 + # streamId identifies the specific SCTP stream. In WebRTC, each + # DataChannel corresponds to a different stream, so the streamId is + # used to map the message to the appropriate DataChannel. + endOfRecord*: bool + # endOfRecord indicates whether the current SCTP message is the + # final part of a record or not. This is related to the + # fragmentation and reassembly of messages. + unordered*: bool + # The unordered flag determines whether the message should be + # delivered in order or not. SCTP allows for both ordered and + # unordered delivery of messages. + + SctpMessage* = ref object + data*: seq[byte] + info*: sctp_recvv_rn + params*: SctpMessageParameters + + SctpConn* = ref object + conn: DtlsConn # Underlying DTLS Connection + sctpSocket*: ptr socket # Current usrsctp socket + + state*: SctpState # Current Sctp State + onClose: seq[SctpConnOnClose] # List of procedure to run while closing a connection + + connectEvent*: AsyncEvent # Event fired when the connection is connected + acceptEvent*: AsyncEvent # Event fired when the connection is accepted + + # Infinite loop reading on the underlying DTLS Connection. + readLoop: Future[void].Raising([CancelledError, WebRtcError]) + + dataRecv: AsyncQueue[SctpMessage] # Queue of messages to be read + sendQueue: seq[byte] # Queue of messages to be sent + +proc remoteAddress*(self: SctpConn): TransportAddress = + if self.conn.isNil(): + raise newException(WebRtcError, "SCTP - Connection not set") + return self.conn.remoteAddress() + +template usrsctpAwait(self: SctpConn, body: untyped): untyped = + # usrsctpAwait is template which set `sendQueue` to @[] then calls + # an usrsctp function. If during the synchronous run of the usrsctp function + # `sendQueue` is set, it is sent at the end of the function. + proc trySend(conn: SctpConn) {.async: (raises: [CancelledError]).} = + try: + trace "Send To", address = conn.remoteAddress() + await conn.conn.write(self.sendQueue) + except CatchableError as exc: + trace "Send Failed", exceptionMsg = exc.msg + + self.sendQueue = @[] + when type(body) is void: + (body) + if self.sendQueue.len() > 0: + await self.trySend() + else: + let res = (body) + if self.sendQueue.len() > 0: + await self.trySend() + res + +# -- usrsctp send and receive callback -- + +proc recvCallback*(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when we receive data after a connection + # has been established. + let + conn = cast[SctpConn](data) + events = usrsctp_get_events(sock) + + trace "Receive callback", events + if bitand(events, SCTP_EVENT_READ) != 0: + var + message = SctpMessage(data: newSeq[byte](4096)) + address: Sockaddr_storage + rn: sctp_recvv_rn + addressLen = sizeof(Sockaddr_storage).SockLen + rnLen = sizeof(sctp_recvv_rn).SockLen + infotype: uint + flags: int + let n = sock.usrsctp_recvv( + cast[pointer](addr message.data[0]), + message.data.len.uint, + cast[ptr SockAddr](addr address), + cast[ptr SockLen](addr addressLen), + cast[pointer](addr message.info), + cast[ptr SockLen](addr rnLen), + cast[ptr cuint](addr infotype), + cast[ptr cint](addr flags), + ) + if n < 0: + warn "usrsctp_recvv", error = sctpStrerror() + return + elif n > 0: + message.data.delete(n ..< message.data.len()) + trace "message info from handle upcall", msginfo = message.info + message.params = SctpMessageParameters( + protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(), + streamId: message.info.recvv_rcvinfo.rcv_sid, + ) + if bitand(flags, MSG_NOTIFICATION) != 0: + let notif = cast[ptr sctp_notification](data) + trace "Notification received", notifType = notif.sn_header.sn_type + else: + try: + conn.dataRecv.addLastNoWait(message) + except AsyncQueueFullError: + trace "Queue full, dropping packet" + elif bitand(events, SCTP_EVENT_WRITE) != 0: + trace "sctp event write in the upcall" + else: + warn "Handle Upcall unexpected event", events + +proc sendCallback*( + ctx: pointer, buffer: pointer, length: uint, tos: uint8, set_df: uint8 +): cint {.cdecl.} = + # This proc is called by usrsctp everytime usrsctp tries to send data. + let + conn = cast[SctpConn](ctx) + buf = @(buffer.makeOpenArray(byte, int(length))) + trace "sendCallback", sctpPacket = $(buf.getSctpPacket()) + proc testSend() {.async: (raises: [CancelledError]).} = + try: + trace "Send To", address = conn.remoteAddress() + await conn.conn.write(buf) + except CatchableError as exc: + trace "Send Failed", message = exc.msg + + conn.sendQueue = buf + +proc addOnClose*(self: SctpConn, onCloseProc: SctpConnOnClose) = + ## Adds a proc to be called when SctpConn is closed + ## + self.onClose.add(onCloseProc) + +proc readLoopProc(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} = + while true: + var msg = await self.conn.read() + if msg == @[]: + trace "Sctp read loop stopped, DTLS connection closed" + return + trace "Receive data", + remoteAddress = self.conn.remoteAddress(), sctPacket = $(msg.getSctpPacket()) + self.usrsctpAwait: + usrsctp_conninput(cast[pointer](self), addr msg[0], uint(msg.len), 0) + +proc new*(T: typedesc[SctpConn], conn: DtlsConn): T = + result = T( + conn: conn, + state: SctpConnecting, + connectEvent: AsyncEvent(), + acceptEvent: AsyncEvent(), + dataRecv: newAsyncQueue[SctpMessage](), + ) + result.readLoop = result.readLoopProc() + usrsctp_register_address(cast[pointer](result)) + +proc connect*(self: SctpConn, sctpPort: uint16) {.async: (raises: [CancelledError, WebRtcError]).} = + var sconn: Sockaddr_conn + when compiles(sconn.sconn_len): + sconn.sconn_len = sizeof(sconn).uint8 + sconn.sconn_family = AF_CONN + sconn.sconn_port = htons(sctpPort) + sconn.sconn_addr = cast[pointer](self) + let connErr = self.usrsctpAwait: self.sctpSocket.usrsctp_connect( + cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)) + ) + if connErr != 0 and errno != SctpEINPROGRESS: + raise + newException(WebRtcError, "SCTP - Connection failed: " & sctpStrerror()) + +proc read*(self: SctpConn): Future[SctpMessage] {.async: (raises: [CancelledError, WebRtcError]).} = + # Used by DataChannel, returns SctpMessage in order to get the stream + # and protocol ids + if self.state == SctpClosed: + raise newException(WebRtcError, "Try to read on an already closed SctpConn") + return await self.dataRecv.popFirst() + +proc toFlags(params: SctpMessageParameters): uint16 = + if params.endOfRecord: + result = result or SCTP_EOR + if params.unordered: + result = result or SCTP_UNORDERED + +proc write*( + self: SctpConn, buf: seq[byte], sendParams = default(SctpMessageParameters) +) {.async: (raises: [CancelledError, WebRtcError]).} = + # Used by DataChannel, writes buf on the Dtls connection. + if self.state == SctpClosed: + raise newException(WebRtcError, "Try to write on an already closed SctpConn") + var cpy = buf + let sendvErr = + if sendParams == default(SctpMessageParameters): + # If writes is called by DataChannel, sendParams should never + # be the default value. This split is useful for testing. + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv( + cast[pointer](addr cpy[0]), + cpy.len().uint, + nil, + 0, + nil, + 0, + SCTP_SENDV_NOINFO.cuint, + 0, + ) + else: + var sendInfo = sctp_sndinfo( + snd_sid: sendParams.streamId, + snd_ppid: sendParams.protocolId.swapBytes(), + snd_flags: sendParams.toFlags(), + ) + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv( + cast[pointer](addr cpy[0]), + cpy.len().uint, + nil, + 0, + cast[pointer](addr sendInfo), + sizeof(sendInfo).SockLen, + SCTP_SENDV_SNDINFO.cuint, + 0, + ) + if sendvErr < 0: + raise newException(WebRtcError, "SCTP - " & sctpStrerror()) + +proc write*( + self: SctpConn, s: string +) {.async: (raises: [CancelledError, WebRtcError]).} = + await self.write(s.toBytes()) + +type + # This object is a workaround, srs_stream_list in usrsctp is an + # UncheckedArray, and they're not assignable. + sctp_reset_streams_workaround = object + srs_assoc_id: sctp_assoc_t + srs_flags: uint16 + srs_number_streams: uint16 + srs_stream_list: array[1, uint16] + +proc closeChannel*(self: SctpConn, streamId: uint16) = + ## Resets a specific outgoing SCTP stream identified by + ## streamId to close the associated DataChannel. + var srs: sctp_reset_streams_workaround + let len = sizeof(srs) + + srs.srs_flags = SCTP_STREAM_RESET_OUTGOING + srs.srs_number_streams = 1 + srs.srs_stream_list[0] = streamId + let ret = usrsctp_setsockopt( + self.sctpSocket, + IPPROTO_SCTP, + SCTP_RESET_STREAMS, + addr srs, + len.SockLen + ) + if ret < 0: + raise newException(WebRtcError, "SCTP - Close channel failed: " & sctpStrerror()) + +proc closeAllChannels*(self: SctpConn) = + ## Resets all outgoing SCTP streams, effectively closing all + ## open DataChannels for the current SCTP connection. + var srs: sctp_reset_streams_workaround + let len = sizeof(srs) - sizeof(srs.srs_stream_list) + + srs.srs_flags = SCTP_STREAM_RESET_OUTGOING + srs.srs_number_streams = 0 # 0 means all channels + let ret = usrsctp_setsockopt( + self.sctpSocket, + IPPROTO_SCTP, + SCTP_RESET_STREAMS, + addr srs, + len.SockLen + ) + if ret < 0: + raise newException(WebRtcError, "SCTP - Close all channels failed: " & sctpStrerror()) + +proc close*(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} = + ## Closes the entire SCTP connection by resetting all channels, + ## deregistering the address, stopping the read loop, and cleaning up resources. + if self.state == SctpClosed: + debug "Try to close SctpConn twice" + return + self.closeAllChannels() + usrsctp_deregister_address(cast[pointer](self)) + self.usrsctpAwait: + self.sctpSocket.usrsctp_close() + await self.readLoop.cancelAndWait() + self.state = SctpClosed + untrackCounter(SctpConnTracker) + await self.conn.close() + for onCloseProc in self.onClose: + onCloseProc() + self.onClose = @[] diff --git a/webrtc/sctp/sctp_logutils.nim b/webrtc/sctp/sctp_logutils.nim new file mode 100644 index 0000000..f75452c --- /dev/null +++ b/webrtc/sctp/sctp_logutils.nim @@ -0,0 +1,69 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import binary_serialization + +# This file defines custom objects and procedures to improve the +# readability and accuracy of logging SCTP messages. The default +# usrsctp logger may not provide sufficient detail or clarity for +# SCTP message analysis, so this implementation creates more structured +# and informative logs. By parsing and formatting SCTP packet headers +# and chunks into human-readable strings, it provides clearer insights +# into the data being transmitted. This aids debugging by offering a +# more descriptive view of SCTP traffic than what is available +# by default. + +type + SctpChunk* = object + chunkType*: uint8 + flag*: uint8 + length* {.bin_value: it.data.len() + 4.}: uint16 + data* {.bin_len: it.length - 4.}: seq[byte] + + SctpPacketHeader* = object + srcPort*: uint16 + dstPort*: uint16 + verifTag*: uint32 + checksum*: uint32 + + SctpPacketStructure* = object + header*: SctpPacketHeader + chunks*: seq[SctpChunk] + +proc dataToString(data: seq[byte]): string = + if data.len() < 8: + return $data + result = "@[" + result &= $data[0] & ", " & $data[1] & ", " & $data[2] & ", " & $data[3] & " ... " + result &= $data[^4] & ", " & $data[^3] & ", " & $data[^2] & ", " & $data[^1] & "]" + +proc `$`*(packet: SctpPacketStructure): string = + result = "{header: {srcPort: " + result &= $(packet.header.srcPort) & ", dstPort: " + result &= $(packet.header.dstPort) & "}, chunks: @[" + var counter = 0 + for chunk in packet.chunks: + result &= "{type: " & $(chunk.chunkType) & ", len: " + result &= $(chunk.length) & ", data: " + result &= chunk.data.dataToString() + counter += 1 + if counter < packet.chunks.len(): + result &= ", " + result &= "]}" + +proc getSctpPacket*(buffer: seq[byte]): SctpPacketStructure = + result.header = Binary.decode(buffer, SctpPacketHeader) + var size = sizeof(SctpPacketHeader) + while size < buffer.len: + let chunk = Binary.decode(buffer[size ..^ 1], SctpChunk) + result.chunks.add(chunk) + size.inc(chunk.length.int) + while size mod 4 != 0: + # padding; could use `size.inc(-size %% 4)` instead but it lacks clarity + size.inc(1) diff --git a/webrtc/sctp/sctp_transport.nim b/webrtc/sctp/sctp_transport.nim new file mode 100644 index 0000000..45ec122 --- /dev/null +++ b/webrtc/sctp/sctp_transport.nim @@ -0,0 +1,242 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import tables, bitops, nativesockets, strutils, sequtils +import usrsctp, chronos, chronicles +import + ./[sctp_connection, sctp_utils], ../errors, ../dtls/dtls_transport + +export chronicles + +const + SctpTransportTracker* = "webrtc.sctp.transport" + IPPROTO_SCTP = 132 + +logScope: + topics = "webrtc sctp" + +# Implementation of an Sctp client and server using the usrsctp library. +# Usrsctp is usable with a single thread but this is not the intended +# way to use it. As a result, there are many callbacks that calls each +# other synchronously. + +proc printf( + format: cstring +) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} + +type Sctp* = ref object + dtls: Dtls # Underlying Dtls Transport + connections: Table[TransportAddress, SctpConn] # List of all the Sctp connections + isServer: bool + sockServer: ptr socket # usrsctp "server" socket to accept new connections + +# -- usrsctp accept and connect callbacks -- + +proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when a connection is about to be accepted. + var + sconn: Sockaddr_conn + slen: SockLen = sizeof(Sockaddr_conn).uint32 + let + sctp = cast[Sctp](data) + sctpSocket = + usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen) + conn = cast[SctpConn](sconn.sconn_addr) + + if sctpSocket.isNil(): + warn "usrsctp_accept failed", error = sctpStrerror() + conn.state = SctpState.SctpClosed + else: + trace "Scpt connection accepted", remoteAddress = conn.remoteAddress() + conn.sctpSocket = sctpSocket + conn.state = SctpState.SctpConnected + conn.acceptEvent.fire() + +proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called during usrsctp_connect + let + conn = cast[SctpConn](data) + events = usrsctp_get_events(sock) + + if conn.state == SctpState.SctpConnecting: + if bitand(events, SCTP_EVENT_ERROR) != 0: + warn "Cannot connect", remoteAddress = conn.remoteAddress() + conn.state = SctpState.SctpClosed + elif bitand(events, SCTP_EVENT_WRITE) != 0: + conn.state = SctpState.SctpConnected + if usrsctp_set_upcall(conn.sctpSocket, recvCallback, data) != 0: + warn "usrsctp_set_upcall fails while connecting", error = sctpStrerror() + trace "Sctp connection connected", remoteAddress = conn.remoteAddress() + conn.connectEvent.fire() + else: + warn "Should never happen", currentState = conn.state + +proc stopServer*(self: Sctp) = + ## Sctp Transport stop acting like a server + ## + if not self.isServer: + trace "Try to close a client" + return + self.isServer = false + self.sockServer.usrsctp_close() + +proc serverSetup(self: Sctp, sctpPort: uint16): bool = + # This procedure setup usrsctp to be in "server mode" and + # creates an sctp socket on which we will listen + if usrsctp_sysctl_set_sctp_blackhole(2) != 0: + warn "usrsctp_sysctl_set_sctp_blackhole failed", error = sctpStrerror() + return false + + if usrsctp_sysctl_set_sctp_no_csum_on_loopback(0) != 0: + warn "usrsctp_sysctl_set_sctp_no_csum_on_loopback failed", error = sctpStrerror() + return false + + if usrsctp_sysctl_set_sctp_delayed_sack_time_default(0) != 0: + warn "usrsctp_sysctl_set_sctp_delayed_sack_time_default failed", error = sctpStrerror() + return false + + let sock = usrsctp_socket(AF_CONN, SOCK_STREAM.toInt(), IPPROTO_SCTP, nil, nil, 0, nil) + if usrsctp_set_non_blocking(sock, 1) != 0: + warn "usrsctp_set_non_blocking failed", error = sctpStrerror() + return false + + var sin: Sockaddr_in + sin.sin_family = type(sin.sin_family)(SctpAF_INET) + sin.sin_port = htons(sctpPort) + sin.sin_addr.s_addr = htonl(INADDR_ANY) + if usrsctp_bind(sock, cast[ptr SockAddr](addr sin), SockLen(sizeof(Sockaddr_in))) != 0: + warn "usrsctp_bind failed", error = sctpStrerror() + return false + + if usrsctp_listen(sock, 1) < 0: + warn "usrsctp_listen failed", error = sctpStrerror() + return false + + if sock.usrsctp_set_upcall(handleAccept, cast[pointer](self)) != 0: + warn "usrsctp_set_upcall failed", error = sctpStrerror() + return false + + self.sockServer = sock + return true + +proc listen*(self: Sctp, sctpPort: uint16 = 5000) = + ## `listen` marks the Sctp Transport as a transport that will be used to accept + ## incoming connection requests using accept. + ## + if self.isServer: + trace "Try to start the server twice" + return + self.isServer = true + trace "Sctp listening", sctpPort + if not self.serverSetup(sctpPort): + raise newException(WebRtcError, "SCTP - Fails to listen") + +proc new*(T: type Sctp, dtls: Dtls): T = + ## Creates a new Sctp Transport + ## + var self = T() + self.dtls = dtls + + when defined(windows): + usrsctp_init_nothreads(dtls.localAddress.port.uint16, sendCallback, nil) + else: + usrsctp_init_nothreads(dtls.localAddress.port.uint16, sendCallback, printf) + if usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_ALL.uint32) != 0: + # Enabling debug is not critical, doesn't matter if it fails + trace "usrsctp_sysctl_set_sctp_debug_on failed", error = sctpStrerror() + if usrsctp_sysctl_set_sctp_ecn_enable(1) != 0: + # In the same way, enabling explicit congestion notification isn't required + trace "usrsctp_sysctl_set_sctp_ecn_enable failed", error = sctpStrerror() + trackCounter(SctpTransportTracker) + return self + +proc stop*(self: Sctp) {.async: (raises: [CancelledError]).} = + ## Stops the Sctp Transport + ## + if self.isServer: + self.stopServer() + untrackCounter(SctpTransportTracker) + let connections = toSeq(self.connections.values()) + await allFutures(connections.mapIt(it.close())) + if usrsctp_finish() != 0: + warn "usrsct_finish failed", error = sctpStrerror() + +proc socketSetup( + conn: SctpConn, callback: proc(a1: ptr socket, a2: pointer, a3: cint) {.cdecl.} +): bool = + # This procedure setup SctpConn. It should be in `sctp_connection.nim` file but I + # prefer not to expose it. + if conn.sctpSocket.usrsctp_set_non_blocking(1) != 0: + warn "usrsctp_set_non_blocking failed", error = sctpStrerror() + return false + + if conn.sctpSocket.usrsctp_set_upcall(callback, cast[pointer](conn)) != 0: + warn "usrsctp_set_upcall failed", error = sctpStrerror() + return false + + var nodelay: uint32 = 1 + if conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, addr nodelay, sizeof(nodelay).SockLen) != 0: + warn "usrsctp_setsockopt nodelay failed", error = sctpStrerror() + return false + + var recvinfo: uint32 = 1 + if conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO, addr recvinfo, sizeof(recvinfo).SockLen) != 0: + warn "usrsctp_setsockopt recvinfo failed", error = sctpStrerror() + return false + return true + +proc addConnToTable(self: Sctp, conn: SctpConn) = + let remoteAddress = conn.remoteAddress() + proc cleanup() = + self.connections.del(remoteAddress) + self.connections[remoteAddress] = conn + conn.addOnClose(cleanup) + +proc accept*( + self: Sctp +): Future[SctpConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Accept an Sctp Connection + ## + if not self.isServer: + raise newException(WebRtcError, "SCTP - Not a server") + var conn: SctpConn + while true: + conn = SctpConn.new(await self.dtls.accept()) + conn.acceptEvent.clear() + await conn.acceptEvent.wait() + if conn.state == SctpState.SctpConnected and conn.socketSetup(recvCallback): + break + await conn.close() + + self.addConnToTable(conn) + trackCounter(SctpConnTracker) + return conn + +proc connect*( + self: Sctp, raddr: TransportAddress, sctpPort: uint16 = 5000 +): Future[SctpConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Connect to a remote address and returns an Sctp Connection + ## + let conn = SctpConn.new(await self.dtls.connect(raddr)) + conn.state = SctpState.SctpConnecting + conn.sctpSocket = + usrsctp_socket(AF_CONN, SOCK_STREAM.toInt(), IPPROTO_SCTP, nil, nil, 0, nil) + + if not conn.socketSetup(handleConnect): + raise newException(WebRtcError, "SCTP - Socket setup failed while connecting") + + await conn.connect(sctpPort) + + conn.connectEvent.clear() + await conn.connectEvent.wait() + if conn.state == SctpState.SctpClosed: + raise newException(WebRtcError, "SCTP - Connection failed") + self.addConnToTable(conn) + trackCounter(SctpConnTracker) + return conn diff --git a/webrtc/sctp/sctp_utils.nim b/webrtc/sctp/sctp_utils.nim new file mode 100644 index 0000000..17a5fca --- /dev/null +++ b/webrtc/sctp/sctp_utils.nim @@ -0,0 +1,29 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import nativesockets +import chronos + +var errno* {.importc, header: "".}: cint ## error variable + +when defined(windows): + import winlean + const + SctpAF_INET* = winlean.AF_INET + SctpEINPROGRESS* = winlean.WSAEINPROGRESS.cint +else: + const + SctpAF_INET* = nativesockets.AF_INET + SctpEINPROGRESS* = chronos.EINPROGRESS.cint + +proc sctpStrerror*(): string = + proc strerror( + error: int + ): cstring {.importc: "strerror", cdecl, header: "".} + return $(strerror(errno)) diff --git a/webrtc/stun/stun_connection.nim b/webrtc/stun/stun_connection.nim index 9d661a2..d732041 100644 --- a/webrtc/stun/stun_connection.nim +++ b/webrtc/stun/stun_connection.nim @@ -201,7 +201,7 @@ proc new*( ## var self = T( udp: udp, - laddr: udp.laddr, + laddr: udp.localAddress(), raddr: raddr, closed: false, dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), diff --git a/webrtc/stun/stun_transport.nim b/webrtc/stun/stun_transport.nim index becc593..3d52f53 100644 --- a/webrtc/stun/stun_transport.nim +++ b/webrtc/stun/stun_transport.nim @@ -108,7 +108,7 @@ proc new*( ## var self = T( udp: udp, - laddr: udp.laddr, + laddr: udp.localAddress(), usernameProvider: usernameProvider, usernameChecker: usernameChecker, passwordProvider: passwordProvider, diff --git a/webrtc/udp_transport.nim b/webrtc/udp_transport.nim index e458a5e..1f1625c 100644 --- a/webrtc/udp_transport.nim +++ b/webrtc/udp_transport.nim @@ -23,7 +23,7 @@ type raddr: TransportAddress UdpTransport* = ref object - laddr*: TransportAddress + laddr: TransportAddress udp: DatagramTransport dataRecv: AsyncQueue[UdpPacketInfo] closed: bool @@ -33,7 +33,7 @@ const UdpTransportTrackerName* = "webrtc.udp.transport" proc new*(T: type UdpTransport, laddr: TransportAddress): T = ## Initialize an Udp Transport ## - var self = T(laddr: laddr, closed: false) + var self = T(closed: false) proc onReceive( udp: DatagramTransport, @@ -49,6 +49,7 @@ proc new*(T: type UdpTransport, laddr: TransportAddress): T = self.dataRecv = newAsyncQueue[UdpPacketInfo]() self.udp = newDatagramTransport(onReceive, local = laddr) + self.laddr = self.udp.localAddress() trackCounter(UdpTransportTrackerName) return self @@ -87,3 +88,6 @@ proc read*(self: UdpTransport): Future[UdpPacketInfo] {.async: (raises: [Cancell return trace "UDP read" return await self.dataRecv.popFirst() + +proc localAddress*(self: UdpTransport): TransportAddress = + self.laddr