From 78a90b573e83df2a8f523ec626aaf268a7f5366e Mon Sep 17 00:00:00 2001 From: David Venhoek Date: Thu, 26 Mar 2026 14:57:51 +0100 Subject: [PATCH] Implement custom roots for LAN clients in communication. --- Cargo.lock | 4 +- Cargo.toml | 1 + s2energy-connection/Cargo.toml | 1 + .../examples/communication-client.rs | 4 + .../examples/pairing-client.rs | 4 +- .../examples/pairing-server.rs | 2 +- .../src/communication/client.rs | 49 +++- s2energy-connection/src/communication/mod.rs | 9 +- .../src/communication/transport.rs | 223 ++++++++++++++++++ s2energy-connection/src/lib.rs | 33 +++ s2energy-connection/src/pairing/client.rs | 39 ++- s2energy-connection/src/pairing/mod.rs | 25 +- s2energy-connection/src/pairing/server.rs | 10 +- s2energy-connection/src/pairing/transport.rs | 34 ++- s2energy-connection/src/pairing/wire.rs | 36 ++- 15 files changed, 431 insertions(+), 43 deletions(-) create mode 100644 s2energy-connection/src/communication/transport.rs diff --git a/Cargo.lock b/Cargo.lock index 16c382b..a834e4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -716,6 +716,7 @@ version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ + "serde", "typenum", "version_check", ] @@ -1273,7 +1274,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -1680,6 +1681,7 @@ dependencies = [ "axum-server", "base64", "futures-util", + "generic-array", "hmac", "http", "http-body-util", diff --git a/Cargo.toml b/Cargo.toml index 356c093..8c686bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ base64 = "0.22.1" bon = "3.8.0" chrono = { version = "0.4.42", features = ["serde"] } futures-util = "0.3.31" +generic-array = { version = "=0.14.7", features = ["serde"] } hmac = "0.12.1" http = "1.4.0" hyper = "1.8.1" diff --git a/s2energy-connection/Cargo.toml b/s2energy-connection/Cargo.toml index d2f9c95..ea1de77 100644 --- a/s2energy-connection/Cargo.toml +++ b/s2energy-connection/Cargo.toml @@ -8,6 +8,7 @@ axum.workspace = true axum-extra.workspace = true base64.workspace = true futures-util.workspace = true +generic-array.workspace = true hmac.workspace = true http.workspace = true hyper.workspace = true diff --git a/s2energy-connection/examples/communication-client.rs b/s2energy-connection/examples/communication-client.rs index 6d42f5d..2b1e3ce 100644 --- a/s2energy-connection/examples/communication-client.rs +++ b/s2energy-connection/examples/communication-client.rs @@ -35,6 +35,10 @@ impl ClientPairing for &mut MemoryPairing { &self.communication_url } + fn certificate_hash(&self) -> Option { + None + } + async fn set_access_tokens(&mut self, tokens: Vec) -> Result<(), Self::Error> { self.tokens = tokens; Ok(()) diff --git a/s2energy-connection/examples/pairing-client.rs b/s2energy-connection/examples/pairing-client.rs index 3748f17..ae2ecfc 100644 --- a/s2energy-connection/examples/pairing-client.rs +++ b/s2energy-connection/examples/pairing-client.rs @@ -29,7 +29,7 @@ async fn main() { }, vec![MessageVersion("v1".into())], ) - .with_connection_initiate_url("client.example.com".into()) + .with_connection_initiate_url("https://client.example.com".into()) .build() .unwrap(); @@ -61,7 +61,7 @@ async fn main() { let pair_result = rx.await.unwrap(); match pair_result.role { - s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url } => { + s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url, .. } => { println!("Paired as client, url: {initiate_url}, token: {}", pair_result.token.0) } s2energy_connection::pairing::PairingRole::CommunicationServer => println!("Paired as server, token: {}", pair_result.token.0), diff --git a/s2energy-connection/examples/pairing-server.rs b/s2energy-connection/examples/pairing-server.rs index a9f8380..5d3c694 100644 --- a/s2energy-connection/examples/pairing-server.rs +++ b/s2energy-connection/examples/pairing-server.rs @@ -35,7 +35,7 @@ async fn main() { }, vec![MessageVersion("v1".into())], ) - .with_connection_initiate_url("test.example.com".into()) + .with_connection_initiate_url("https://test.example.com".into()) .build() .unwrap(); let app = server.get_router(); diff --git a/s2energy-connection/src/communication/client.rs b/s2energy-connection/src/communication/client.rs index e02563a..7edcf11 100644 --- a/s2energy-connection/src/communication/client.rs +++ b/s2energy-connection/src/communication/client.rs @@ -7,10 +7,11 @@ use tokio_tungstenite::{Connector, connect_async_tls_with_config, tungstenite::C use tracing::{debug, trace}; use crate::{ - AccessToken, CommunicationProtocol, EndpointDescription, NodeId, + AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, NodeId, common::negotiate_version, communication::{ CommunicationResult, ConnectionInfo, Error, ErrorKind, NodeConfig, WebSocketTransport, + transport::hash_checking_http_client, wire::{ CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse, UnpairRequest, }, @@ -52,6 +53,8 @@ pub trait ClientPairing: Send { fn access_tokens(&self) -> impl AsRef<[AccessToken]>; /// The communication url the client can use to contact the server. fn communication_url(&self) -> impl AsRef; + /// Hash of the root certificate the server uses. + fn certificate_hash(&self) -> Option; /// Store a new set of access tokens for the pairing. fn set_access_tokens(&mut self, tokens: Vec) -> impl Future> + Send; @@ -71,14 +74,17 @@ impl Client { /// upon success. #[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)] pub async fn unpair(&self, pairing: impl ClientPairing) -> CommunicationResult<()> { - let client = reqwest::Client::builder() - .tls_certs_merge( - self.additional_certificates - .iter() - .filter_map(|v| reqwest::Certificate::from_der(v).ok()), - ) - .build() - .map_err(|e| Error::new(ErrorKind::TransportFailed, e))?; + let client = match pairing.certificate_hash() { + Some(hash) => hash_checking_http_client(hash)?, + None => reqwest::Client::builder() + .tls_certs_merge( + self.additional_certificates + .iter() + .filter_map(|v| reqwest::Certificate::from_der(v).ok()), + ) + .build() + .map_err(|e| Error::new(ErrorKind::TransportFailed, e))?, + }; let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?; @@ -301,7 +307,7 @@ mod tests { use tokio::net::TcpListener; use crate::{ - AccessToken, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role, + AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role, common::wire::test::{UUID_A, UUID_B, basic_node_description}, communication::{ self, Client, ClientConfig, ClientPairing, ErrorKind, NodeConfig, PairingLookup, Server, ServerConfig, ServerPairing, @@ -370,6 +376,7 @@ mod tests { server: NodeId, tokens: Arc>>, url: String, + certificate_hash: Option, } impl ClientPairing for &TestPairing { @@ -391,6 +398,10 @@ mod tests { &self.url } + fn certificate_hash(&self) -> Option { + self.certificate_hash.clone() + } + async fn set_access_tokens(&mut self, tokens: Vec) -> Result<(), Self::Error> { *self.tokens.lock().unwrap() = tokens; Ok(()) @@ -451,6 +462,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert!(client.unpair(&pairing).await.is_ok()); @@ -483,6 +495,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("invalidtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let error = client.unpair(&pairing).await.unwrap_err(); @@ -511,6 +524,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -577,6 +591,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -637,6 +652,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -692,6 +708,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -740,6 +757,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -792,6 +810,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion); @@ -830,6 +849,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion); @@ -876,6 +896,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion); @@ -922,6 +943,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion); @@ -973,6 +995,10 @@ mod tests { &self.url } + fn certificate_hash(&self) -> Option { + None + } + async fn set_access_tokens(&mut self, _tokens: Vec) -> Result<(), Self::Error> { Err(std::io::ErrorKind::Other.into()) } @@ -995,6 +1021,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; let mut client_connection = client.connect(&pairing).await.unwrap(); @@ -1042,6 +1069,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError); @@ -1076,6 +1104,7 @@ mod tests { server: UUID_B.into(), tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), url: format!("https://localhost:{}/", addr.port()), + certificate_hash: None, }; assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError); diff --git a/s2energy-connection/src/communication/mod.rs b/s2energy-connection/src/communication/mod.rs index ad84fa5..6336e44 100644 --- a/s2energy-connection/src/communication/mod.rs +++ b/s2energy-connection/src/communication/mod.rs @@ -44,12 +44,13 @@ //! # use std::sync::Arc; //! # use std::convert::Infallible; //! # use s2energy_connection::communication::{NodeConfig, Client, ClientConfig, ClientPairing}; -//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId}; +//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId, CertificateHash}; //! struct MemoryClientPairing { //! client_id: NodeId, //! server_id: NodeId, //! communication_url: String, //! access_tokens: Vec, +//! certificate_hash: Option, //! } //! //! impl ClientPairing for MemoryClientPairing { @@ -71,6 +72,10 @@ //! &self.access_tokens //! } //! +//! fn certificate_hash(&self) -> Option { +//! self.certificate_hash.clone() +//! } +//! //! async fn set_access_tokens(&mut self, tokens: Vec) -> Result<(), Infallible> { //! self.access_tokens = tokens; //! Ok(()) @@ -84,6 +89,7 @@ //! server_id: NodeId::try_from("67e55044-10b1-426f-9247-bb680e5fe0c6").unwrap(), //! communication_url: "https://example.com".into(), //! access_tokens: vec![AccessToken("some-token-value".into())], +//! certificate_hash: None, //! }); //! ``` //! @@ -210,6 +216,7 @@ use crate::{EndpointDescription, MessageVersion, NodeDescription}; mod client; mod error; mod server; +mod transport; mod websocket; mod wire; diff --git a/s2energy-connection/src/communication/transport.rs b/s2energy-connection/src/communication/transport.rs new file mode 100644 index 0000000..b85d59a --- /dev/null +++ b/s2energy-connection/src/communication/transport.rs @@ -0,0 +1,223 @@ +use std::sync::{Arc, OnceLock}; + +use rustls::{ + RootCertStore, + client::{WebPkiServerVerifier, danger::ServerCertVerifier}, + pki_types::CertificateDer, +}; + +use crate::CertificateHash; + +use super::{CommunicationResult, Error, ErrorKind}; + +#[derive(Debug)] +struct HashedCertificateVerifier { + inner: rustls_platform_verifier::Verifier, + self_signed_state: OnceLock, + root_hash: CertificateHash, +} + +#[derive(Debug)] +struct SelfSignedState { + hash: CertificateHash, + verifier: SelfVerifier, +} + +#[derive(Debug)] +enum SelfVerifier { + WebPki(WebPkiServerVerifier), + None, +} + +impl ServerCertVerifier for SelfVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls::pki_types::UnixTime, + ) -> Result { + match self { + SelfVerifier::WebPki(web_pki_server_verifier) => { + web_pki_server_verifier.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + } + SelfVerifier::None => Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)), + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + match self { + SelfVerifier::WebPki(web_pki_server_verifier) => web_pki_server_verifier.verify_tls12_signature(message, cert, dss), + SelfVerifier::None => Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)), + } + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + match self { + SelfVerifier::WebPki(web_pki_server_verifier) => web_pki_server_verifier.verify_tls13_signature(message, cert, dss), + SelfVerifier::None => Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)), + } + } + + fn supported_verify_schemes(&self) -> Vec { + match self { + SelfVerifier::WebPki(web_pki_server_verifier) => web_pki_server_verifier.supported_verify_schemes(), + SelfVerifier::None => vec![], + } + } +} + +impl ServerCertVerifier for HashedCertificateVerifier { + fn verify_server_cert( + &self, + end_entity: &rustls::pki_types::CertificateDer<'_>, + intermediates: &[rustls::pki_types::CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls::pki_types::UnixTime, + ) -> Result { + let state = self.self_signed_state.get_or_init(|| { + let fallback = CertificateDer::from_slice(&[]); + let root_cert = intermediates.last().unwrap_or(&fallback); + let hash = CertificateHash::sha256(root_cert); + let mut root_store = RootCertStore::empty(); + // conciously ignore errors here, we just want to initialize + root_store.add(root_cert.clone()).ok(); + let verifier = match WebPkiServerVerifier::builder(Arc::new(root_store)).build() { + Ok(verifier) => SelfVerifier::WebPki(Arc::try_unwrap(verifier).unwrap()), + Err(_) => SelfVerifier::None, + }; + + SelfSignedState { hash, verifier } + }); + if state.hash != self.root_hash { + return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)); + } + state + .verifier + .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +pub(crate) fn hash_checking_http_client(root_hash: CertificateHash) -> CommunicationResult { + let rustls_config_builder = rustls::ClientConfig::builder(); + let crypto_provider = rustls_config_builder.crypto_provider().clone(); + let verifier = HashedCertificateVerifier { + inner: rustls_platform_verifier::Verifier::new(crypto_provider).map_err(|e| Error::new(ErrorKind::TransportFailed, e))?, + self_signed_state: OnceLock::new(), + root_hash, + }; + let client_config = rustls_config_builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + + let client = reqwest::Client::builder() + .use_preconfigured_tls(client_config) + .build() + .map_err(|e| Error::new(ErrorKind::TransportFailed, e))?; + + Ok(client) +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddr}; + + use axum::{Router, routing::get}; + use axum_server::tls_rustls::RustlsConfig; + use rustls::pki_types::{CertificateDer, pem::PemObject}; + + use crate::{CertificateHash, communication::transport::hash_checking_http_client}; + + #[tokio::test] + async fn matching_certificates() { + let rustls_config = RustlsConfig::from_pem( + include_bytes!("../../testdata/localhost.chain.pem").into(), + include_bytes!("../../testdata/localhost.key").into(), + ) + .await + .unwrap(); + let router = Router::new().route("/", get(|| async { "Hello world" })); + let https_server_handle = axum_server::Handle::new(); + let https_server_handle_clone = https_server_handle.clone(); + tokio::spawn(async move { + axum_server::bind_rustls(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0), rustls_config) + .handle(https_server_handle_clone) + .serve(router.into_make_service()) + .await + .unwrap(); + }); + let addr = https_server_handle.listening().await.unwrap(); + + let client = hash_checking_http_client(CertificateHash::sha256( + &CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap(), + )) + .unwrap(); + assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_ok()); + + https_server_handle.shutdown(); + } + + #[tokio::test] + async fn mismatching_certificates() { + let rustls_config = RustlsConfig::from_pem( + include_bytes!("../../testdata/localhost.chain.pem").into(), + include_bytes!("../../testdata/localhost.key").into(), + ) + .await + .unwrap(); + let router = Router::new().route("/", get(|| async { "Hello world" })); + let https_server_handle = axum_server::Handle::new(); + let https_server_handle_clone = https_server_handle.clone(); + tokio::spawn(async move { + axum_server::bind_rustls(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0), rustls_config) + .handle(https_server_handle_clone) + .serve(router.into_make_service()) + .await + .unwrap(); + }); + let addr = https_server_handle.listening().await.unwrap(); + + let client = hash_checking_http_client(CertificateHash::sha256( + &CertificateDer::from_pem_slice(include_bytes!("../../testdata/altroot.pem")).unwrap(), + )) + .unwrap(); + assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_err()); + + https_server_handle.shutdown(); + } +} diff --git a/s2energy-connection/src/lib.rs b/s2energy-connection/src/lib.rs index b38c993..3b4d71c 100644 --- a/s2energy-connection/src/lib.rs +++ b/s2energy-connection/src/lib.rs @@ -18,3 +18,36 @@ pub mod pairing; pub use common::wire::{ AccessToken, CommunicationProtocol, Deployment, EndpointDescription, InvalidNodeId, MessageVersion, NodeDescription, NodeId, Role, }; +use serde::{Deserialize, Serialize}; +use sha2::Digest; + +/// Hash of a TLS certificate. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub struct CertificateHash(CertificateHashInner); + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +enum CertificateHashInner { + Sha256(sha2::digest::generic_array::GenericArray::OutputSize>), +} + +impl CertificateHash { + pub(crate) fn sha256(data: &[u8]) -> Self { + Self(CertificateHashInner::Sha256(sha2::Sha256::digest(data))) + } +} + +impl AsRef for CertificateHash { + fn as_ref(&self) -> &CertificateHash { + self + } +} + +impl std::ops::Deref for CertificateHash { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + match &self.0 { + CertificateHashInner::Sha256(generic_array) => generic_array, + } + } +} diff --git a/s2energy-connection/src/pairing/client.rs b/s2energy-connection/src/pairing/client.rs index b05999f..e3cf5fa 100644 --- a/s2energy-connection/src/pairing/client.rs +++ b/s2energy-connection/src/pairing/client.rs @@ -8,7 +8,7 @@ use crate::common::negotiate_version; use crate::common::wire::{AccessToken, Deployment, PairingVersion, Role}; use crate::pairing::transport::{HashProvider, hash_providing_https_client}; use crate::pairing::{ConfigError, Error, Pairing, PairingRole}; -use crate::{EndpointDescription, NodeDescription, NodeId}; +use crate::{CertificateHash, EndpointDescription, NodeDescription, NodeId}; use super::NodeConfig; use super::wire::*; @@ -392,7 +392,9 @@ impl Client { } fn prepare_reqwest_client(&self, url: &Url) -> Result<(reqwest::Client, Option), Error> { - let (client, certhash) = if url.domain().map(|v| v.ends_with(".local")).unwrap_or_default() { + let (client, certhash) = if url.domain().map(|v| v.ends_with(".local")).unwrap_or_default() + || url.domain().map(|v| v.ends_with(".local.")).unwrap_or_default() + { let (client, certhash) = hash_providing_https_client()?; (client, Some(certhash)) } else { @@ -476,11 +478,11 @@ impl<'a> V1Session<'a> { let our_deployment = self.endpoint_description.deployment.unwrap_or(local_deployment); let our_role = self.config.node_description.role; - let network = if self.base_url.domain().map(|v| v.ends_with(".local")).unwrap_or_default() { - if let Some(hash) = certhash.as_ref().and_then(HashProvider::hash) { - Network::Lan { - fingerprint: hash.try_into().unwrap(), - } + let network = if self.base_url.domain().map(|v| v.ends_with(".local")).unwrap_or_default() + || self.base_url.domain().map(|v| v.ends_with(".local.")).unwrap_or_default() + { + if let Some(hash) = certhash.as_ref().and_then(HashProvider::leaf_hash) { + Network::Lan { fingerprint: hash.clone() } } else { return Err(ErrorKind::ProtocolError.into()); } @@ -552,6 +554,7 @@ impl<'a> V1Session<'a> { server_hmac_challenge_response, initiate_connection_url.clone(), access_token.clone(), + self.config.root_certificate.as_deref().map(CertificateHash::sha256), ) .await { @@ -573,12 +576,26 @@ impl<'a> V1Session<'a> { return Err(e); } }; + + let initiate_url = + Url::parse(&connection_details.initiate_connection_url).map_err(|e| Error::new(ErrorKind::ProtocolError, e))?; + let root_hash = if initiate_url.domain().map(|v| v.ends_with(".local")).unwrap_or_default() + || initiate_url.domain().map(|v| v.ends_with(".local.")).unwrap_or_default() + { + connection_details + .certificate_fingerprint + .or(certhash.as_ref().and_then(HashProvider::root_hash).cloned()) + } else { + None + }; + Pairing { remote_endpoint_description: request_pairing_response.server_endpoint_description, remote_node_description: request_pairing_response.server_node_description, token: connection_details.access_token, role: PairingRole::CommunicationClient { initiate_url: connection_details.initiate_connection_url, + root_hash, }, } } @@ -644,12 +661,14 @@ impl<'a> V1Session<'a> { server_hmac_challenge_response: HmacChallengeResponse, initiate_connection_url: String, access_token: AccessToken, + certificate_fingerprint: Option, ) -> PairingResult<()> { let request = PostConnectionDetailsRequest { server_hmac_challenge_response, connection_details: ConnectionDetails { initiate_connection_url, access_token, + certificate_fingerprint, }, }; let response = self @@ -888,12 +907,12 @@ mod tests { #[tokio::test] async fn pairing_ok_rm_initiates() { let server_config = NodeConfig::builder(basic_node_description(UUID_A, Role::Cem), vec![MessageVersion("v1".into())]) - .with_connection_initiate_url("test.example.com".into()) + .with_connection_initiate_url("https://test.example.com".into()) .build() .unwrap(); let client_config = NodeConfig::builder(basic_node_description(UUID_B, Role::Rm), vec![MessageVersion("v1".into())]) - .with_connection_initiate_url("client.example.com".into()) + .with_connection_initiate_url("https://client.example.com".into()) .build() .unwrap(); @@ -1502,7 +1521,7 @@ mod tests { #[tokio::test] async fn longpolling() { let server_config = NodeConfig::builder(basic_node_description(UUID_A, Role::Cem), vec![MessageVersion("v1".into())]) - .with_connection_initiate_url("test.example.com".into()) + .with_connection_initiate_url("https://test.example.com".into()) .build() .unwrap(); diff --git a/s2energy-connection/src/pairing/mod.rs b/s2energy-connection/src/pairing/mod.rs index 58ff60a..1a9491d 100644 --- a/s2energy-connection/src/pairing/mod.rs +++ b/s2energy-connection/src/pairing/mod.rs @@ -210,6 +210,7 @@ mod wire; use rand::CryptoRng; +use rustls::pki_types::CertificateDer; use wire::{HmacChallenge, HmacChallengeResponse}; pub use client::{Client, ClientConfig, LongpollHandler, Longpoller, PairingRemote, PrePairing}; @@ -219,7 +220,10 @@ pub use server::{ }; pub use wire::NodeIdAlias; -use crate::{CommunicationProtocol, Deployment, EndpointDescription, MessageVersion, NodeDescription, Role, common::wire::AccessToken}; +use crate::{ + CertificateHash, CommunicationProtocol, Deployment, EndpointDescription, MessageVersion, NodeDescription, Role, + common::wire::AccessToken, +}; /// Full description of an S2 node. #[derive(Debug, Clone)] @@ -228,6 +232,7 @@ pub struct NodeConfig { supported_message_versions: Vec, supported_communication_protocols: Vec, connection_initiate_url: Option, + root_certificate: Option>, } impl NodeConfig { @@ -251,6 +256,11 @@ impl NodeConfig { self.connection_initiate_url.as_deref() } + /// Root certificate used by the node in communication, if known. + pub fn root_certificate(&self) -> Option<&CertificateDer<'static>> { + self.root_certificate.as_ref() + } + /// Create a builder for a new [`NodeConfig`]. /// /// All node configurations must at least contain description of the node and supported message versions. Additional @@ -261,6 +271,7 @@ impl NodeConfig { supported_message_versions, supported_communication_protocols: vec![CommunicationProtocol("WebSocket".into())], connection_initiate_url: None, + root_certificate: None, } } } @@ -271,6 +282,7 @@ pub struct ConfigBuilder { supported_message_versions: Vec, supported_communication_protocols: Vec, connection_initiate_url: Option, + root_certificate: Option>, } impl ConfigBuilder { @@ -288,6 +300,12 @@ impl ConfigBuilder { self } + /// Set the root certificate used in communication by this node. + pub fn with_root_certificate(mut self, root_certificate: CertificateDer<'static>) -> Self { + self.root_certificate = Some(root_certificate); + self + } + /// Create the actual [`NodeConfig`], validating that it is reasonable. pub fn build(self) -> Result { if self.node_description.role == Role::Cem && self.connection_initiate_url.is_none() { @@ -298,6 +316,7 @@ impl ConfigBuilder { supported_message_versions: self.supported_message_versions, supported_communication_protocols: self.supported_communication_protocols, connection_initiate_url: self.connection_initiate_url, + root_certificate: self.root_certificate, }) } } @@ -309,6 +328,8 @@ pub enum PairingRole { CommunicationClient { /// URL to be used for initiating the connection. initiate_url: String, + /// Hash of the root certificate of the communication server + root_hash: Option, }, /// This node gets contacted by the other node to initiate a connection. CommunicationServer, @@ -400,7 +421,7 @@ pub type PairingResult = Result; #[derive(Debug)] enum Network { Wan, - Lan { fingerprint: [u8; 32] }, + Lan { fingerprint: CertificateHash }, } impl Network { diff --git a/s2energy-connection/src/pairing/server.rs b/s2energy-connection/src/pairing/server.rs index 3efeaa0..264a5dc 100644 --- a/s2energy-connection/src/pairing/server.rs +++ b/s2energy-connection/src/pairing/server.rs @@ -28,6 +28,7 @@ use tokio::{ use tracing::{Instrument, info, trace}; use crate::{ + CertificateHash, common::{ AbortingJoinHandle, root, wire::{AccessToken, EndpointDescription, NodeDescription, NodeId, PairingVersion}, @@ -133,7 +134,7 @@ impl Clone for Server { /// Configuration for the S2 pairing server. pub struct ServerConfig { - /// The root certificate of the server, if we are using a self-signed root. + /// The leaf certificate of the server, if we are using a self-signed root. /// Presence of this field indicates we are deployed on LAN. pub leaf_certificate: Option>, /// Endpoint description of the server @@ -212,7 +213,7 @@ impl Server { network: server_config .leaf_certificate .map(|v| Network::Lan { - fingerprint: sha2::Sha256::digest(v).into(), + fingerprint: CertificateHash::sha256(&v), }) .unwrap_or(Network::Wan), advertised_nodes: Mutex::new(server_config.advertised_nodes), @@ -1021,6 +1022,7 @@ async fn v1_request_connection_details( None => return (Err(StatusCode::BAD_REQUEST), None), }, access_token: AccessToken::new(&mut rng), + certificate_fingerprint: state.config.root_certificate.as_deref().map(CertificateHash::sha256), }; trace!("Generated connection details"); @@ -1094,6 +1096,7 @@ async fn v1_post_connection_details( access_token: req.connection_details.access_token, role: PairingRole::CommunicationClient { initiate_url: req.connection_details.initiate_connection_url, + root_hash: req.connection_details.certificate_fingerprint, }, }); @@ -2018,6 +2021,7 @@ mod tests { connection_details: ConnectionDetails { initiate_connection_url: "https://example.com/".into(), access_token: AccessToken::new(&mut rand::rng()), + certificate_fingerprint: None, }, }) .unwrap(), @@ -2073,6 +2077,7 @@ mod tests { connection_details: ConnectionDetails { initiate_connection_url: "https://example.com/".into(), access_token: AccessToken::new(&mut rand::rng()), + certificate_fingerprint: None, }, }) .unwrap(), @@ -2130,6 +2135,7 @@ mod tests { connection_details: ConnectionDetails { initiate_connection_url: "https://example.com/".into(), access_token: AccessToken::new(&mut rand::rng()), + certificate_fingerprint: None, }, }) .unwrap(), diff --git a/s2energy-connection/src/pairing/transport.rs b/s2energy-connection/src/pairing/transport.rs index 06b40cc..2fc122d 100644 --- a/s2energy-connection/src/pairing/transport.rs +++ b/s2energy-connection/src/pairing/transport.rs @@ -5,9 +5,8 @@ use rustls::{ client::{WebPkiServerVerifier, danger::ServerCertVerifier}, pki_types::CertificateDer, }; -use sha2::Digest; -use crate::pairing::Error; +use crate::{CertificateHash, pairing::Error}; use super::{ErrorKind, PairingResult}; @@ -19,7 +18,8 @@ struct HashingCertificateVerifier { #[derive(Debug)] struct SelfSignedState { - hash: CertificateHash, + root_hash: CertificateHash, + leaf_hash: CertificateHash, verifier: SelfVerifier, } @@ -78,8 +78,6 @@ impl ServerCertVerifier for SelfVerifier { } } -type CertificateHash = sha2::digest::generic_array::GenericArray::OutputSize>; - impl ServerCertVerifier for HashingCertificateVerifier { fn verify_server_cert( &self, @@ -98,7 +96,8 @@ impl ServerCertVerifier for HashingCertificateVerifier { let state = self.self_signed_state.get_or_init(|| { let fallback = CertificateDer::from_slice(&[]); let root_cert = intermediates.last().unwrap_or(&fallback); - let hash = sha2::Sha256::digest(end_entity); + let root_hash = CertificateHash::sha256(root_cert); + let leaf_hash = CertificateHash::sha256(end_entity); let mut root_store = RootCertStore::empty(); // conciously ignore errors here, we just want to initialize root_store.add(root_cert.clone()).ok(); @@ -107,7 +106,11 @@ impl ServerCertVerifier for HashingCertificateVerifier { Err(_) => SelfVerifier::None, }; - SelfSignedState { hash, verifier } + SelfSignedState { + root_hash, + leaf_hash, + verifier, + } }); state .verifier @@ -144,9 +147,16 @@ pub(crate) struct HashProvider { } impl HashProvider { - pub(crate) fn hash(&self) -> Option<&[u8]> { + pub(crate) fn leaf_hash(&self) -> Option<&CertificateHash> { + match self.state.get() { + Some(state) => Some(&state.leaf_hash), + None => None, + } + } + + pub(crate) fn root_hash(&self) -> Option<&CertificateHash> { match self.state.get() { - Some(state) => Some(&state.hash), + Some(state) => Some(&state.root_hash), None => None, } } @@ -205,7 +215,7 @@ mod tests { let (client, hash_provider) = hash_providing_https_client().unwrap(); assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_ok()); - assert!(hash_provider.hash().is_some()); + assert!(hash_provider.leaf_hash().is_some()); assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_ok()); https_server_handle.shutdown(); @@ -233,7 +243,7 @@ mod tests { let (client, hash_provider) = hash_providing_https_client().unwrap(); assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_ok()); - assert!(hash_provider.hash().is_some()); + assert!(hash_provider.leaf_hash().is_some()); https_server_handle.shutdown(); @@ -280,7 +290,7 @@ mod tests { let (client, hash_provider) = hash_providing_https_client().unwrap(); assert!(client.get(format!("https://localhost:{}/", addr.port())).send().await.is_ok()); - assert!(hash_provider.hash().is_some()); + assert!(hash_provider.leaf_hash().is_some()); https_server_handle.shutdown(); diff --git a/s2energy-connection/src/pairing/wire.rs b/s2energy-connection/src/pairing/wire.rs index f5bd5f2..73d5d91 100644 --- a/s2energy-connection/src/pairing/wire.rs +++ b/s2energy-connection/src/pairing/wire.rs @@ -2,13 +2,13 @@ use axum::{Json, extract::FromRequestParts, response::IntoResponse}; use axum_extra::{TypedHeader, headers}; use http::StatusCode; use rand::distr::{Alphanumeric, SampleString}; -use serde::*; +use serde::{ser::SerializeMap, *}; use subtle::ConstantTimeEq; use thiserror::Error; use tracing::info; use crate::{ - NodeId, + CertificateHash, CertificateHashInner, NodeId, common::wire::{AccessToken, CommunicationProtocol, EndpointDescription, MessageVersion, NodeDescription}, }; @@ -234,6 +234,38 @@ pub(crate) struct CancelPrePairingRequest { pub(crate) struct ConnectionDetails { pub initiate_connection_url: String, pub access_token: AccessToken, + #[serde( + default, + skip_serializing_if = "Option::is_none", + serialize_with = "serialize_fingerprint", + deserialize_with = "deserialize_fingerprint" + )] + pub certificate_fingerprint: Option, +} + +pub(crate) fn serialize_fingerprint(value: &Option, serializer: S) -> Result { + use base64::{Engine, engine::general_purpose::STANDARD}; + // Unwrap is ok here as we serialize only when not none. + let encoded = STANDARD.encode(value.as_deref().unwrap() as &[u8]); + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry("SHA256", &encoded)?; + map.end() +} + +pub(crate) fn deserialize_fingerprint<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + use base64::{Engine, engine::general_purpose::STANDARD}; + use std::{borrow::Cow, collections::HashMap}; + let data = HashMap::, Cow<'de, str>>::deserialize(deserializer)?; + if let Some(hash) = data.get("SHA256") { + let decoded = STANDARD.decode(hash.as_ref()).map_err(de::Error::custom)?; + Ok(Some(CertificateHash(CertificateHashInner::Sha256( + <[u8; 32]>::try_from(decoded) + .map_err(|_| de::Error::custom("Hash is wrong length"))? + .into(), + )))) + } else { + Err(de::Error::custom("Missing SHA256 hash")) + } } #[derive(Serialize, Deserialize)]