Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 132 additions & 2 deletions s2energy-connection/src/pairing/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::common::negotiate_version;
use crate::common::wire::{AccessToken, Deployment, PairingVersion, S2Role};
use crate::pairing::transport::{HashProvider, hash_providing_https_client};
use crate::pairing::{ConfigError, Error, Pairing, PairingRole};
use crate::{S2EndpointDescription, S2NodeId};
use crate::{S2EndpointDescription, S2NodeDescription, S2NodeId};

use super::NodeConfig;
use super::wire::*;
Expand Down Expand Up @@ -241,6 +241,61 @@ impl Client {
})
}

/// Get information about a specific endpoint and its nodes
#[tracing::instrument(skip_all, fields(remote = ?remote), level = tracing::Level::ERROR)]
pub async fn get_endpoint_descriptors(&self, remote: String) -> PairingResult<(S2EndpointDescription, Vec<S2NodeDescription>)> {
trace!("Querying remote for descriptors");

let url = Url::try_from(remote.as_str()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?;

let (client, _) = self.prepare_reqwest_client(&url)?;

trace!("Prepared reqwest client.");

let pairing_version = negotiate_version(&client, url.clone()).await?;

match pairing_version {
PairingVersion::V1 => {
let base_url = url.join("v1/").unwrap();

let endpoint_response = client
.get(base_url.join("s2endpoint").unwrap())
.send()
.await
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;

if endpoint_response.status() == StatusCode::UNAUTHORIZED {
return Err(ErrorKind::Rejected.into());
}
if endpoint_response.status() != StatusCode::OK {
return Err(ErrorKind::ProtocolError.into());
}

let endpoint: S2EndpointDescription = endpoint_response
.json()
.await
.map_err(|e| Error::new(ErrorKind::ProtocolError, e))?;

let node_response = client
.get(base_url.join("s2nodes").unwrap())
.send()
.await
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;

if node_response.status() == StatusCode::UNAUTHORIZED {
return Err(ErrorKind::Rejected.into());
}
if node_response.status() != StatusCode::OK {
return Err(ErrorKind::ProtocolError.into());
}

let nodes: Vec<S2NodeDescription> = node_response.json().await.map_err(|e| Error::new(ErrorKind::ProtocolError, e))?;

Ok((endpoint, nodes))
}
}
}

/// Create a longpoller for a given remote.
pub async fn longpoller(&self, remote: String) -> PairingResult<Longpoller> {
let span = span!(tracing::Level::ERROR, "longpolling", remote);
Expand Down Expand Up @@ -692,7 +747,10 @@ mod tests {
},
};

use axum::{Json, Router, routing::post};
use axum::{
Json, Router,
routing::{get, post},
};
use axum_server::{Handle, tls_rustls::RustlsConfig};
use http::StatusCode;
use rustls::pki_types::{CertificateDer, pem::PemObject};
Expand Down Expand Up @@ -755,6 +813,78 @@ mod tests {
setup_server_with_prepairing(config, NoopPrePairingHandler, overrides).await
}

#[tokio::test]
async fn descriptors() {
let server_config = NodeConfig::builder(basic_node_description(UUID_A, S2Role::Cem), vec![MessageVersion("v1".into())])
.with_connection_initiate_url("test.example.com".into())
.build()
.unwrap();

let (server_handle, _server_pairing, _) = setup_server(
server_config,
Router::new()
.route("/v1/s2endpoint", get(|| async { Json(S2EndpointDescription::default()) }))
.route(
"/v1/s2nodes",
get(|| async {
Json(vec![
basic_node_description(UUID_A, S2Role::Cem),
basic_node_description(UUID_B, S2Role::Rm),
])
}),
),
)
.await;

let addr = server_handle.listening().await.unwrap();

let client = Client::new(ClientConfig {
additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()],
endpoint_description: S2EndpointDescription::default(),
pairing_deployment: Deployment::Wan,
})
.unwrap();

let (endpoint, nodes) = client
.get_endpoint_descriptors(format!("https://localhost:{}/", addr.port()))
.await
.unwrap();
assert_eq!(endpoint.deployment, None);
assert_eq!(nodes.len(), 2);
assert_eq!(nodes.first().unwrap().id, UUID_A.into());
}

#[tokio::test]
async fn descriptors_forbidden() {
let server_config = NodeConfig::builder(basic_node_description(UUID_A, S2Role::Cem), vec![MessageVersion("v1".into())])
.with_connection_initiate_url("test.example.com".into())
.build()
.unwrap();

let (server_handle, _server_pairing, _) = setup_server(
server_config,
Router::new()
.route("/v1/s2endpoint", get(|| async { StatusCode::UNAUTHORIZED }))
.route("/v1/s2nodes", get(|| async { StatusCode::UNAUTHORIZED })),
)
.await;

let addr = server_handle.listening().await.unwrap();

let client = Client::new(ClientConfig {
additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()],
endpoint_description: S2EndpointDescription::default(),
pairing_deployment: Deployment::Wan,
})
.unwrap();

let err = client
.get_endpoint_descriptors(format!("https://localhost:{}/", addr.port()))
.await
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Rejected);
}

#[tokio::test]
async fn pairing_ok_rm_initiates() {
let server_config = NodeConfig::builder(basic_node_description(UUID_A, S2Role::Cem), vec![MessageVersion("v1".into())])
Expand Down
2 changes: 1 addition & 1 deletion s2energy-connection/src/pairing/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pub enum ErrorKind {
AlreadyPending,
/// Provided token was invalid.
InvalidToken,
/// Remote permanently rejects longpolling
/// Remote permanently rejects longpolling or querying of node information.
Rejected,
/// The pairing or longpolling session was cancelled.
Cancelled,
Expand Down
Loading