diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim new file mode 100644 index 0000000..1bc4cbf --- /dev/null +++ b/webrtc/dtls/dtls.nim @@ -0,0 +1,377 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import times, deques, tables, sequtils +import chronos, chronicles +import ./utils, ../stun/stun_connection + +import mbedtls/ssl +import mbedtls/ssl_cookie +import mbedtls/ssl_cache +import mbedtls/pk +import mbedtls/md +import mbedtls/entropy +import mbedtls/ctr_drbg +import mbedtls/rsa +import mbedtls/x509 +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/error +import mbedtls/net_sockets +import mbedtls/timing + +logScope: + topics = "webrtc dtls" + +# Implementation of a DTLS client and a DTLS Server by using the mbedtls library. +# Multiple things here are unintuitive partly because of the callbacks +# used by mbedtls and that those callbacks cannot be async. +# +# TODO: +# - Check the viability of the add/pop first/last of the asyncqueue with the limit. +# There might be some errors (or crashes) with some edge cases with the no wait option +# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE +# - Not critical - May be interesting to split Dtls and DtlsConn into two files + +# This limit is arbitrary, it could be interesting to make it configurable. +const PendingHandshakeLimit = 1024 + +# -- DtlsConn -- +# A Dtls connection to a specific IP address recovered by the receiving part of +# the Udp "connection" + +type + DtlsError* = object of CatchableError + DtlsConn* = ref object + conn: StunConn + laddr: TransportAddress + raddr*: TransportAddress + dataRecv: AsyncQueue[seq[byte]] + sendFuture: Future[void] + closed: bool + closeEvent: AsyncEvent + + timer: mbedtls_timing_delay_context + + ssl: mbedtls_ssl_context + config: mbedtls_ssl_config + cookie: mbedtls_ssl_cookie_ctx + cache: mbedtls_ssl_cache_context + + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + + localCert: seq[byte] + remoteCert: seq[byte] + +proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) = + self.conn = conn + self.laddr = laddr + self.dataRecv = newAsyncQueue[seq[byte]]() + self.closed = false + self.closeEvent = newAsyncEvent() + +proc join(self: DtlsConn) {.async.} = + await self.closeEvent.wait() + +proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} = + var shouldRead = isServer + while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + if isServer: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + self.sendFuture = nil + let res = mb_ssl_handshake_step(self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + shouldRead = false + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ: + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ssl) + shouldRead = isServer + continue + elif res != 0: + raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) + +proc close*(self: DtlsConn) {.async.} = + if self.closed: + debug "Try to close DtlsConn twice" + return + + self.closed = true + self.sendFuture = nil + # TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls + let x = mbedtls_ssl_close_notify(addr self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + self.closeEvent.fire() + +proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = + if self.closed: + debug "Try to write on an already closed DtlsConn" + return + var buf = msg + try: + let sendFuture = newFuture[void]("DtlsConn write") + self.sendFuture = nil + let write = mb_ssl_write(self.ssl, buf) + if not self.sendFuture.isNil(): + await self.sendFuture + trace "Dtls write", msgLen = msg.len(), actuallyWrote = write + except MbedTLSError as exc: + trace "Dtls write error", errorMsg = exc.msg + raise exc + +proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = + if self.closed: + debug "Try to read on an already closed DtlsConn" + return + var res = newSeq[byte](8192) + while true: + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + # TODO: Find a clear way to use the template `mb_ssl_read` without + # messing up things with exception + let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint) + if length == MBEDTLS_ERR_SSL_WANT_READ: + continue + if length < 0: + raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr())) + res.setLen(length) + return res + +# -- Dtls -- +# The Dtls object read every messages from the UdpConn/StunConn and, if the address +# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue +# to be accepted later, if the address is stored, add the message received to the +# corresponding DtlsConn `dataRecv` queue. + +type + Dtls* = ref object of RootObj + connections: Table[TransportAddress, DtlsConn] + pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])] + conn: StunConn + laddr: TransportAddress + started: bool + readLoop: Future[void] + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + + serverPrivKey: mbedtls_pk_context + serverCert: mbedtls_x509_crt + localCert: seq[byte] + +proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])], + raddr: TransportAddress, buf: seq[byte]) = + for kv in aq.mitems(): + if kv[0] == raddr: + kv[1] = buf + return + aq.addLastNoWait((raddr, buf)) + +proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) = + if self.started: + warn "Already started" + return + + proc readLoop() {.async.} = + while true: + let (buf, raddr) = await self.conn.read() + if self.connections.hasKey(raddr): + self.connections[raddr].dataRecv.addLastNoWait(buf) + else: + self.pendingHandshakes.updateOrAdd(raddr, buf) + + self.connections = initTable[TransportAddress, DtlsConn]() + self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit) + self.conn = conn + self.laddr = laddr + self.started = true + self.readLoop = readLoop() + + mb_ctr_drbg_init(self.ctr_drbg) + mb_entropy_init(self.entropy) + mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0) + + self.serverPrivKey = self.ctr_drbg.generateKey() + self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey) + self.localCert = newSeq[byte](self.serverCert.raw.len) + copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len) + +proc stop*(self: Dtls) {.async.} = + if not self.started: + warn "Already stopped" + return + + await allFutures(toSeq(self.connections.values()).mapIt(it.close())) + self.readLoop.cancel() + self.started = false + +# -- Remote / Local certificate getter -- + +proc remoteCertificate*(conn: DtlsConn): seq[byte] = + conn.remoteCert + +proc localCertificate*(conn: DtlsConn): seq[byte] = + conn.localCert + +proc localCertificate*(self: Dtls): seq[byte] = + self.localCert + +# -- MbedTLS Callbacks -- + +proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt, + state: cint, pflags: ptr uint32): cint {.cdecl.} = + # verify is the procedure called by mbedtls when receiving the remote + # certificate. It's usually used to verify the validity of the certificate. + # We use this procedure to store the remote certificate as it's mandatory + # to have it for the Prologue of the Noise protocol, aswell as the localCertificate. + var self = cast[DtlsConn](ctx) + let cert = pcert[] + + self.remoteCert = newSeq[byte](cert.raw.len) + copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len) + return 0 + +proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsSend is the procedure called by mbedtls when data needs to be sent. + # As the StunConn's write proc is asynchronous and dtlsSend cannot be async, + # we store the future of this write and await it after the end of the + # function (see write or dtlsHanshake for example). + var self = cast[DtlsConn](ctx) + var toWrite = newSeq[byte](len) + if len > 0: + copyMem(addr toWrite[0], buf, len) + trace "dtls send", len + self.sendFuture = self.conn.write(self.raddr, toWrite) + result = len.cint + +proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsRecv is the procedure called by mbedtls when data needs to be received. + # As we cannot asynchronously await for data to be received, we use a data received + # queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await + # when the mbedtls proc resumed (see read or dtlsHandshake for example) + let self = cast[DtlsConn](ctx) + if self.dataRecv.len() == 0: + return MBEDTLS_ERR_SSL_WANT_READ + + var dataRecv = self.dataRecv.popFirstNoWait() + copyMem(buf, addr dataRecv[0], dataRecv.len()) + result = dataRecv.len().cint + trace "dtls receive", len, result + +# -- Dtls Accept / Connect procedures -- + +proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} = + await conn.join() + self.connections.del(raddr) + +proc accept*(self: Dtls): Future[DtlsConn] {.async.} = + var res = DtlsConn() + + res.init(self.conn, self.laddr) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + mb_ssl_cookie_init(res.cookie) + mb_ssl_cache_init(res.cache) + + res.ctr_drbg = self.ctr_drbg + res.entropy = self.entropy + + var pkey = self.serverPrivKey + var srvcert = self.serverCert + res.localCert = self.localCert + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_conf_own_cert(res.config, srvcert, pkey) + mb_ssl_cookie_setup(res.cookie, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_dtls_cookies(res.config, res.cookie) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_session_reset(res.ssl) + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) + while true: + let (raddr, buf) = await self.pendingHandshakes.popFirst() + try: + res.raddr = raddr + res.dataRecv.addLastNoWait(buf) + self.connections[raddr] = res + await res.dtlsHandshake(true) + asyncSpawn self.removeConnection(res, raddr) + break + except CatchableError as exc: + trace "Handshake fail", remoteAddress = raddr, error = exc.msg + self.connections.del(raddr) + continue + return res + +proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = + var res = DtlsConn() + + res.init(self.conn, self.laddr) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + + res.ctr_drbg = self.ctr_drbg + res.entropy = self.entropy + + var pkey = res.ctr_drbg.generateKey() + var srvcert = res.ctr_drbg.generateCertificate(pkey) + res.localCert = newSeq[byte](srvcert.raw.len) + copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len) + + mb_ctr_drbg_init(res.ctr_drbg) + mb_entropy_init(res.entropy) + mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0) + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) + + res.raddr = raddr + self.connections[raddr] = res + + try: + await res.dtlsHandshake(false) + asyncSpawn self.removeConnection(res, raddr) + except CatchableError as exc: + trace "Handshake fail", remoteAddress = raddr, error = exc.msg + self.connections.del(raddr) + raise exc + + return res diff --git a/webrtc/dtls/utils.nim b/webrtc/dtls/utils.nim new file mode 100644 index 0000000..06fb990 --- /dev/null +++ b/webrtc/dtls/utils.nim @@ -0,0 +1,96 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/times + +import stew/byteutils + +import mbedtls/pk +import mbedtls/rsa +import mbedtls/ctr_drbg +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/md + +import chronicles + +# This sequence is used for debugging. +const mb_ssl_states* = @[ + "MBEDTLS_SSL_HELLO_REQUEST", + "MBEDTLS_SSL_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_HELLO", + "MBEDTLS_SSL_SERVER_CERTIFICATE", + "MBEDTLS_SSL_SERVER_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_REQUEST", + "MBEDTLS_SSL_SERVER_HELLO_DONE", + "MBEDTLS_SSL_CLIENT_CERTIFICATE", + "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_CLIENT_FINISHED", + "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_SERVER_FINISHED", + "MBEDTLS_SSL_FLUSH_BUFFERS", + "MBEDTLS_SSL_HANDSHAKE_WRAPUP", + "MBEDTLS_SSL_NEW_SESSION_TICKET", + "MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", + "MBEDTLS_SSL_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", + "MBEDTLS_SSL_END_OF_EARLY_DATA", + "MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED", + "MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_HANDSHAKE_OVER", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" +] + +template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = + var res: mbedtls_pk_context + mb_pk_init(res) + discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) + mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537) + let x = mb_pk_rsa(res) + res + +template generateCertificate*(random: mbedtls_ctr_drbg_context, + issuer_key: mbedtls_pk_context): mbedtls_x509_crt = + let + # To be honest, I have no clue what to put here as a name + name = "C=FR,O=Status,CN=webrtc" + time_format = initTimeFormat("YYYYMMddHHmmss") + time_from = times.now().format(time_format) + time_to = (times.now() + times.years(1)).format(time_format) + + var write_cert: mbedtls_x509write_cert + var serial_mpi: mbedtls_mpi + mb_x509write_crt_init(write_cert) + mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256); + mb_x509write_crt_set_subject_key(write_cert, issuer_key) + mb_x509write_crt_set_issuer_key(write_cert, issuer_key) + mb_x509write_crt_set_subject_name(write_cert, name) + mb_x509write_crt_set_issuer_name(write_cert, name) + mb_x509write_crt_set_validity(write_cert, time_from, time_to) + mb_x509write_crt_set_basic_constraints(write_cert, 0, -1) + mb_x509write_crt_set_subject_key_identifier(write_cert) + mb_x509write_crt_set_authority_key_identifier(write_cert) + mb_mpi_init(serial_mpi) + let serial_hex = mb_mpi_read_string(serial_mpi, 16) + mb_x509write_crt_set_serial(write_cert, serial_mpi) + let buf = + try: + mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random) + except MbedTLSError as e: + raise e + var res: mbedtls_x509_crt + mb_x509_crt_parse(res, buf) + res