Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate extensions and abort on error #108

Merged
merged 5 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/common/decrypted_read_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ impl DecryptedReadHandler<'_> {
}
ServerRecord::ChangeCipherSpec(_) => Err(TlsError::InternalError),
ServerRecord::Handshake(ServerHandshake::NewSessionTicket(_)) => {
// Ignore
// TODO: we should validate extensions and abort. We can do this automatically
// as long as the connection is unsplit, however, split connections must be aborted
// by the user.
Ok(())
}
_ => {
Expand Down
72 changes: 60 additions & 12 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,18 @@ impl<'a> State {
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;
process_server_hello(handshake, key_schedule, record)
let result = process_server_hello(handshake, key_schedule, record);

handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ServerVerify => {
/*info!(
"SIZE of server record queue : {}",
core::mem::size_of_val(&records)
);*/
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;

process_server_verify(handshake, key_schedule, config, record)
let result = process_server_verify(handshake, key_schedule, config, record);

handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
Expand Down Expand Up @@ -296,16 +296,17 @@ impl<'a> State {
}
State::ServerHello => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
process_server_hello(handshake, key_schedule, record)

let result = process_server_hello(handshake, key_schedule, record);

handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ServerVerify => {
/*info!(
"SIZE of server record queue : {}",
core::mem::size_of_val(&records)
);*/
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;

process_server_verify(handshake, key_schedule, config, record)
let result = process_server_verify(handshake, key_schedule, config, record);

handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
Expand All @@ -326,6 +327,29 @@ impl<'a> State {
}
}

fn handle_processing_error_blocking<CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;

respond_blocking(tx, transport, key_schedule)?;
}

result
}

fn respond_blocking<CipherSuite>(
tx: &[u8],
transport: &mut impl BlockingWrite,
Expand All @@ -345,6 +369,30 @@ where
Ok(())
}

#[cfg(feature = "async")]
async fn handle_processing_error<'a, CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer<'a>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;

respond(tx, transport, key_schedule).await?;
}

result
}

#[cfg(feature = "async")]
async fn respond<CipherSuite>(
tx: &[u8],
Expand Down
6 changes: 3 additions & 3 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::supported_versions::ProtocolVersions;
use crate::TlsError;
use heapless::Vec;

#[derive(Debug)]
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ExtensionType {
ServerName = 0,
Expand All @@ -19,7 +19,7 @@ pub enum ExtensionType {
SupportedGroups = 10,
SignatureAlgorithms = 13,
UseSrtp = 14,
Heatbeat = 15,
Heartbeat = 15,
ApplicationLayerProtocolNegotiation = 16,
SignedCertificateTimestamp = 18,
ClientCertificateType = 19,
Expand Down Expand Up @@ -47,7 +47,7 @@ impl ExtensionType {
10 => Some(Self::SupportedGroups),
13 => Some(Self::SignatureAlgorithms),
14 => Some(Self::UseSrtp),
15 => Some(Self::Heatbeat),
15 => Some(Self::Heartbeat),
16 => Some(Self::ApplicationLayerProtocolNegotiation),
18 => Some(Self::SignedCertificateTimestamp),
19 => Some(Self::ClientCertificateType),
Expand Down
171 changes: 107 additions & 64 deletions src/extensions/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::alert::{AlertDescription, AlertLevel};
use crate::extensions::common::KeyShareEntry;
use crate::extensions::ExtensionType;
use crate::parse_buffer::{ParseBuffer, ParseError};
Expand All @@ -11,6 +12,9 @@ pub enum ServerExtension<'a> {
SupportedVersion(SupportedVersion),
KeyShare(KeyShare<'a>),
PreSharedKey(u16),

SupportedGroups,
ServerName,
}

#[derive(Debug)]
Expand All @@ -26,6 +30,29 @@ impl SupportedVersion {
}
}

pub struct ServerExtensionParserIterator<'a, 'b> {
buffer: &'b mut ParseBuffer<'a>,
allowed: &'b [ExtensionType],
}

impl<'a, 'b> ServerExtensionParserIterator<'a, 'b> {
pub fn new(buffer: &'b mut ParseBuffer<'a>, allowed: &'b [ExtensionType]) -> Self {
Self { buffer, allowed }
}
}

impl<'a, 'b> Iterator for ServerExtensionParserIterator<'a, 'b> {
type Item = Result<Option<ServerExtension<'a>>, TlsError>;

fn next(&mut self) -> Option<Self::Item> {
if self.buffer.is_empty() {
return None;
}

Some(ServerExtension::parse(&mut self.buffer, &self.allowed))
}
}

#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyShare<'a>(pub(crate) KeyShareEntry<'a>);
Expand All @@ -37,75 +64,91 @@ impl<'a> KeyShare<'a> {
}

impl<'a> ServerExtension<'a> {
pub fn parse_vector(
pub fn parse(
buf: &mut ParseBuffer<'a>,
) -> Result<Vec<ServerExtension<'a>, 16>, TlsError> {
let mut extensions = Vec::new();
allowed: &[ExtensionType],
) -> Result<Option<ServerExtension<'a>>, TlsError> {
let extension_type =
ExtensionType::of(buf.read_u16().map_err(|_| TlsError::UnknownExtensionType)?)
.ok_or(TlsError::UnknownExtensionType)?;

loop {
if buf.is_empty() {
break;
}
trace!("extension type {:?}", extension_type);

if !allowed.contains(&extension_type) {
warn!(
"{:?} extension is not allowed in this context",
extension_type
);

// Section 4.2. Extensions
// If an implementation receives an extension
// which it recognizes and which is not specified for the message in
// which it appears, it MUST abort the handshake with an
// "illegal_parameter" alert.
return Err(TlsError::AbortHandshake(
AlertLevel::Fatal,
AlertDescription::IllegalParameter,
));
}

let extension_type =
ExtensionType::of(buf.read_u16().map_err(|_| TlsError::UnknownExtensionType)?)
.ok_or(TlsError::UnknownExtensionType)?;

//info!("extension type {:?}", extension_type);

let extension_length = buf
.read_u16()
.map_err(|_| TlsError::InvalidExtensionsLength)?;

//info!("extension length {}", extension_length);

match extension_type {
ExtensionType::SupportedVersions => {
extensions
.push(ServerExtension::SupportedVersion(
SupportedVersion::parse(
&mut buf
.slice(extension_length as usize)
.map_err(|_| TlsError::InvalidExtensionsLength)?,
)
.map_err(|_| TlsError::InvalidSupportedVersions)?,
))
.map_err(|_| TlsError::DecodeError)?;
}
ExtensionType::KeyShare => {
extensions
.push(ServerExtension::KeyShare(
KeyShare::parse(
&mut buf
.slice(extension_length as usize)
.map_err(|_| TlsError::InvalidExtensionsLength)?,
)
.map_err(|_| TlsError::InvalidKeyShare)?,
))
.map_err(|_| TlsError::DecodeError)?;
}
ExtensionType::SupportedGroups => {
let _ = buf.slice(extension_length as usize);
}
ExtensionType::ServerName => {
let _ = buf.slice(extension_length as usize);
}
ExtensionType::PreSharedKey => {
let data = buf
.slice(extension_length as usize)
.map_err(|_| TlsError::DecodeError)?;
let data = data.as_slice();
let value = u16::from_be_bytes([data[0], data[1]]);
extensions
.push(ServerExtension::PreSharedKey(value))
.map_err(|_| TlsError::DecodeError)?;
}
t => {
info!("Unsupported extension type {:?}", t);
return Err(TlsError::Unimplemented);
}
let extension_length = buf
.read_u16()
.map_err(|_| TlsError::InvalidExtensionsLength)?;

trace!("extension length {}", extension_length);

Self::from_type_and_data(extension_type, &mut buf.slice(extension_length as usize)?)
}

pub fn parse_vector<const N: usize>(
buf: &mut ParseBuffer<'a>,
allowed: &[ExtensionType],
) -> Result<Vec<ServerExtension<'a>, N>, TlsError> {
let extensions_len = buf
.read_u16()
.map_err(|_| TlsError::InvalidExtensionsLength)?;

let mut ext_buf = buf.slice(extensions_len as usize)?;

let mut iter = ServerExtensionParserIterator::new(&mut ext_buf, allowed);

let mut extensions = Vec::new();

while let Some(extension) = iter.next() {
if let Some(extension) = extension? {
extensions
.push(extension)
.map_err(|_| TlsError::DecodeError)?;
}
}

Ok(extensions)
}

fn from_type_and_data<'b>(
extension_type: ExtensionType,
data: &mut ParseBuffer<'b>,
) -> Result<Option<ServerExtension<'b>>, TlsError> {
let extension = match extension_type {
ExtensionType::SupportedVersions => ServerExtension::SupportedVersion(
SupportedVersion::parse(data).map_err(|_| TlsError::InvalidSupportedVersions)?,
),
ExtensionType::KeyShare => ServerExtension::KeyShare(
KeyShare::parse(data).map_err(|_| TlsError::InvalidKeyShare)?,
),
ExtensionType::PreSharedKey => {
let value = data.read_u16()?;

ServerExtension::PreSharedKey(value)
}
ExtensionType::SupportedGroups => ServerExtension::SupportedGroups,
ExtensionType::ServerName => ServerExtension::ServerName,
t => {
warn!("Unimplemented extension: {:?}", t);
return Ok(None);
}
};

Ok(Some(extension))
}
}
13 changes: 10 additions & 3 deletions src/handshake/certificate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::buffer::CryptoBuffer;
use crate::extensions::server::ServerExtension;
use crate::extensions::ExtensionType;
use crate::parse_buffer::ParseBuffer;
use crate::TlsError;
use heapless::Vec;
Expand Down Expand Up @@ -68,6 +70,12 @@ pub enum CertificateEntryRef<'a> {
}

impl<'a> CertificateEntryRef<'a> {
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CT
const ALLOWED_EXTENSIONS: &[ExtensionType] = &[
ExtensionType::StatusRequest,
ExtensionType::SignedCertificateTimestamp,
];

pub fn parse_vector(
buf: &mut ParseBuffer<'a>,
) -> Result<Vec<CertificateEntryRef<'a>, 16>, TlsError> {
Expand All @@ -88,9 +96,8 @@ impl<'a> CertificateEntryRef<'a> {
.push(CertificateEntryRef::X509(cert.as_slice()))
.map_err(|_| TlsError::DecodeError)?;

let _extensions_len = buf
.read_u16()
.map_err(|_| TlsError::InvalidExtensionsLength)?;
// Validate extensions
ServerExtension::parse_vector::<2>(buf, Self::ALLOWED_EXTENSIONS)?;

if buf.is_empty() {
break;
Expand Down
Loading