diff --git a/Cargo.lock b/Cargo.lock index e0e72696c0..0a44f998cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -197,6 +197,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -768,6 +777,7 @@ version = "0.0.0" dependencies = [ "futures", "getrandom 0.3.3", + "heapless", "inspect", "inspect_counters", "libc", @@ -776,8 +786,10 @@ dependencies = [ "smoltcp", "socket2", "thiserror 2.0.16", + "tracelimit", "tracing", "windows-sys 0.61.0", + "zerocopy 0.8.27", ] [[package]] @@ -878,6 +890,12 @@ dependencies = [ "itertools 0.13.0", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -995,6 +1013,47 @@ dependencies = [ "vmsocket", ] +[[package]] +name = "defmt" +version = "0.3.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0963443817029b2024136fc4dd07a5107eb8f977eaf18fcd1fdeb11306b64ad" +dependencies = [ + "defmt 1.0.1", +] + +[[package]] +name = "defmt" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "548d977b6da32fa1d1fda2876453da1e7df63ad0304c8b3dae4dbe7b96f39b78" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d4fc12a85bcf441cfe44344c4b72d58493178ce635338a3f3b78943aceb258e" +dependencies = [ + "defmt-parser", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "defmt-parser" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10d60334b3b2e7c9d91ef8150abfb6fa4c1c39ebbcf4a81c2e346aad939fee3e" +dependencies = [ + "thiserror 2.0.16", +] + [[package]] name = "der" version = "0.7.10" @@ -2796,6 +2855,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -2885,6 +2953,19 @@ dependencies = [ name = "headervec" version = "0.0.0" +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "spin 0.9.8", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.4.1" @@ -5979,6 +6060,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -6855,12 +6958,15 @@ checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "smoltcp" -version = "0.8.2" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee34c1e1bfc7e9206cc0fb8030a90129b4e319ab53856249bb27642cab914fb3" +checksum = "7e9786ac45091b96f946693e05bfa4d8ca93e2d3341237d97a380107a6b38dea" dependencies = [ "bitflags 1.3.2", "byteorder", + "cfg-if", + "defmt 0.3.100", + "heapless", "managed", ] @@ -6893,6 +6999,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "spin" @@ -6922,6 +7031,12 @@ dependencies = [ "der", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + [[package]] name = "stackfuture" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 70550ee604..d325bef7cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -451,6 +451,7 @@ gptman = "2.0" grep-regex = "0.1" grep-searcher = "0.1" h2 = "0.4" +heapless = "0.7.16" heck = "0.5" hex = "0.4" http = "1" @@ -524,7 +525,7 @@ signal-hook = { version = "0.3", default-features = false } slab = "0.4" smallbox = "0.8" smallvec = "1.8" -smoltcp = { version = "0.8", default-features = false } +smoltcp = { version = "0.9", default-features = false } socket2 = "0.6" spin = "0.10.0" stackfuture = "0.3" diff --git a/vm/devices/net/net_consomme/consomme/Cargo.toml b/vm/devices/net/net_consomme/consomme/Cargo.toml index 3af40b4a53..6dc66b5461 100644 --- a/vm/devices/net/net_consomme/consomme/Cargo.toml +++ b/vm/devices/net/net_consomme/consomme/Cargo.toml @@ -7,16 +7,19 @@ edition.workspace = true rust-version.workspace = true [dependencies] +heapless.workspace = true inspect.workspace = true inspect_counters.workspace = true pal_async.workspace = true futures.workspace = true getrandom.workspace = true -smoltcp = { workspace = true, features = [ "proto-ipv4", "medium-ethernet", "socket-raw", "std", "proto-dhcpv4" ] } +smoltcp = { workspace = true, features = [ "proto-ipv4", "proto-ipv6", "medium-ethernet", "socket-raw", "std", "proto-dhcpv4" ] } socket2.workspace = true thiserror.workspace = true +tracelimit.workspace = true tracing.workspace = true +zerocopy.workspace = true [target.'cfg(unix)'.dependencies] libc.workspace = true diff --git a/vm/devices/net/net_consomme/consomme/src/dhcp.rs b/vm/devices/net/net_consomme/consomme/src/dhcp.rs index 1fa73f1bb4..07044a62d0 100644 --- a/vm/devices/net/net_consomme/consomme/src/dhcp.rs +++ b/vm/devices/net/net_consomme/consomme/src/dhcp.rs @@ -6,6 +6,7 @@ use super::Client; use super::DropReason; use crate::ChecksumState; use crate::MIN_MTU; +use heapless::Vec as HeaplessVec; use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::DHCP_MAX_DNS_SERVER_COUNT; use smoltcp::wire::DhcpMessageType; @@ -48,25 +49,29 @@ impl Access<'_, T> { } let dns_servers = if self.inner.state.params.nameservers.is_empty() { - None + let dns_servers: HeaplessVec = + HeaplessVec::new(); + Some(dns_servers) } else { - let mut dns_servers = [None; DHCP_MAX_DNS_SERVER_COUNT]; - for (&s, d) in self + let dns_servers: HeaplessVec = self .inner .state .params .nameservers .iter() - .zip(&mut dns_servers) - { - *d = Some(s); - } + .filter_map(|ip| match ip { + IpAddress::Ipv4(addr) => Some(*addr), + _ => None, + }) + .take(DHCP_MAX_DNS_SERVER_COUNT) + .collect::>(); Some(dns_servers) }; let resp_dhcp = if let Some(your_ip) = your_ip { DhcpRepr { message_type, + secs: 0, transaction_id: dhcp_req.transaction_id, client_hardware_address: dhcp_req.client_hardware_address, client_ip: Ipv4Address::UNSPECIFIED, @@ -83,10 +88,14 @@ impl Access<'_, T> { dns_servers, max_size: None, lease_duration: Some(86400), + rebind_duration: None, + renew_duration: None, + additional_options: &[], } } else { DhcpRepr { message_type: DhcpMessageType::Nak, + secs: 0, transaction_id: dhcp_req.transaction_id, client_hardware_address: dhcp_req.client_hardware_address, client_ip: Ipv4Address::UNSPECIFIED, @@ -103,6 +112,9 @@ impl Access<'_, T> { dns_servers: None, max_size: None, lease_duration: None, + rebind_duration: None, + renew_duration: None, + additional_options: &[], } }; @@ -113,7 +125,7 @@ impl Access<'_, T> { let resp_ipv4 = Ipv4Repr { src_addr: self.inner.state.params.gateway_ip, dst_addr: Ipv4Address::BROADCAST, - protocol: IpProtocol::Udp, + next_header: IpProtocol::Udp, payload_len: resp_udp.header_len() + resp_dhcp.buffer_len(), hop_limit: 64, }; diff --git a/vm/devices/net/net_consomme/consomme/src/dhcpv6.rs b/vm/devices/net/net_consomme/consomme/src/dhcpv6.rs new file mode 100644 index 0000000000..3686d38c34 --- /dev/null +++ b/vm/devices/net/net_consomme/consomme/src/dhcpv6.rs @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::Access; +use super::Client; +use super::DropReason; +use crate::ChecksumState; +use crate::MIN_MTU; +use smoltcp::phy::ChecksumCapabilities; +use smoltcp::wire::EthernetFrame; +use smoltcp::wire::EthernetProtocol; +use smoltcp::wire::EthernetRepr; +use smoltcp::wire::IpAddress; +use smoltcp::wire::IpProtocol; +use smoltcp::wire::Ipv6Address; +use smoltcp::wire::Ipv6Packet; +use smoltcp::wire::Ipv6Repr; +use smoltcp::wire::UdpPacket; +use smoltcp::wire::UdpRepr; +use std::collections::HashMap; +use std::mem::size_of; +use thiserror::Error; +use zerocopy::FromBytes; +use zerocopy::Immutable; +use zerocopy::IntoBytes; +use zerocopy::KnownLayout; +use zerocopy::Ref; +use zerocopy::big_endian::U16; + +const DHCPV6_ALL_AGENTS_MULTICAST: Ipv6Address = + Ipv6Address([0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2]); + +// DHCPv6 ports +pub const DHCPV6_SERVER: u16 = 547; +pub const DHCPV6_CLIENT: u16 = 546; + +/// DHCPv6 message types (RFC 8415) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum MessageType { + InformationRequest = 11, + Reply = 7, + Unknown(u8), +} + +impl MessageType { + fn from_u8(value: u8) -> Self { + match value { + 11 => MessageType::InformationRequest, + 7 => MessageType::Reply, + other => MessageType::Unknown(other), + } + } + + fn to_u8(self) -> u8 { + match self { + MessageType::InformationRequest => 11, + MessageType::Reply => 7, + MessageType::Unknown(v) => v, + } + } +} + +/// DHCPv6 option codes (RFC 8415) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u16)] +enum OptionCode { + ClientId = 1, + ServerId = 2, + DnsServers = 23, +} + +impl OptionCode { + fn from_u16(value: u16) -> Option { + match value { + 1 => Some(OptionCode::ClientId), + 2 => Some(OptionCode::ServerId), + 23 => Some(OptionCode::DnsServers), + _ => None, + } + } +} + +/// DHCPv6 option +#[derive(Debug, Clone)] +enum DhcpOption { + ClientId(Vec), + ServerId(Vec), + DnsServers(Vec), +} + +/// DHCPv6 message +struct Message { + msg_type: MessageType, + transaction_id: [u8; 3], + options: HashMap, +} + +#[derive(Debug, Error)] +enum DhcpV6Error { + #[error("message too short: {0:#x}")] + MessageTooShort(usize), + #[error("malformed option at offset {0:#x}")] + MalformedOption(usize), + #[error("invalid DNS Server option length {0:#x}")] + InvalidDnsServerOption(usize), +} + +#[repr(C)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)] +struct DhcpV6Header { + msg_type: u8, + transaction_id: [u8; 3], +} + +#[repr(C)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)] +struct DhcpV6Option { + code: U16, + len: U16, +} + +impl Message { + fn new(msg_type: MessageType) -> Self { + Self { + msg_type, + transaction_id: [0; 3], + options: HashMap::new(), + } + } + + fn decode(data: &[u8]) -> Result { + // Parse header using zerocopy + let (header, mut remaining) = Ref::<_, DhcpV6Header>::from_prefix(data) + .map_err(|_| DhcpV6Error::MessageTooShort(data.len()))?; + + let msg_type = MessageType::from_u8(header.msg_type); + let transaction_id = header.transaction_id; + + let mut options = HashMap::new(); + + // Parse options using zerocopy + while remaining.len() >= size_of::() { + let (option_header, rest) = Ref::<_, DhcpV6Option>::from_prefix(remaining) + .map_err(|_| DhcpV6Error::MalformedOption(data.len() - remaining.len()))?; + + let option_code = option_header.code.get(); + let option_len = option_header.len.get() as usize; + + if option_len > rest.len() { + return Err(DhcpV6Error::MalformedOption(data.len() - remaining.len())); + } + + let option_data = &rest[..option_len]; + remaining = &rest[option_len..]; + + if let Some(code) = OptionCode::from_u16(option_code) { + match code { + OptionCode::ClientId => { + options.insert(code, DhcpOption::ClientId(option_data.to_vec())); + } + OptionCode::ServerId => { + options.insert(code, DhcpOption::ServerId(option_data.to_vec())); + } + OptionCode::DnsServers => { + // DNS servers option contains a list of IPv6 addresses (16 bytes each) + if !option_len.is_multiple_of(16) { + return Err(DhcpV6Error::InvalidDnsServerOption(option_len)); + } + let mut dns_servers = Vec::new(); + for i in (0..option_len).step_by(16) { + let mut addr_bytes = [0u8; 16]; + addr_bytes.copy_from_slice(&option_data[i..i + 16]); + dns_servers.push(std::net::Ipv6Addr::from(addr_bytes)); + } + options.insert(code, DhcpOption::DnsServers(dns_servers)); + } + } + } + // Skip unknown options + } + + Ok(Self { + msg_type, + transaction_id, + options, + }) + } + + fn encode(&self) -> Vec { + let mut buffer = Vec::new(); + + // Message type (1 byte) + transaction ID (3 bytes) + buffer.push(self.msg_type.to_u8()); + buffer.extend_from_slice(&self.transaction_id); + + // Encode options + for (code, option) in &self.options { + let code_bytes = (*code as u16).to_be_bytes(); + buffer.extend_from_slice(&code_bytes); + + match option { + DhcpOption::ClientId(data) | DhcpOption::ServerId(data) => { + let len_bytes = (data.len() as u16).to_be_bytes(); + buffer.extend_from_slice(&len_bytes); + buffer.extend_from_slice(data); + } + DhcpOption::DnsServers(servers) => { + let len = (servers.len() * 16) as u16; + let len_bytes = len.to_be_bytes(); + buffer.extend_from_slice(&len_bytes); + for server in servers { + buffer.extend_from_slice(&server.octets()); + } + } + } + } + + buffer + } + + fn set_transaction_id(&mut self, xid: [u8; 3]) { + self.transaction_id = xid; + } + + fn insert_option(&mut self, option: DhcpOption) { + let code = match &option { + DhcpOption::ClientId(_) => OptionCode::ClientId, + DhcpOption::ServerId(_) => OptionCode::ServerId, + DhcpOption::DnsServers(_) => OptionCode::DnsServers, + }; + self.options.insert(code, option); + } + + fn get_option(&self, code: OptionCode) -> Option<&DhcpOption> { + self.options.get(&code) + } +} + +impl Access<'_, T> { + pub(crate) fn handle_dhcpv6( + &mut self, + payload: &[u8], + client_ip: Option, + ) -> Result<(), DropReason> { + // Parse the DHCPv6 message + let msg = Message::decode(payload).map_err(|e| { + tracing::info!(error = %e, "failed to decode DHCPv6 message"); + DropReason::MalformedPacket + })?; + + match msg.msg_type { + MessageType::InformationRequest => { + // Build DHCPv6 Reply response + let mut reply = Message::new(MessageType::Reply); + reply.set_transaction_id(msg.transaction_id); + + // Add Client Identifier option (echo back from the InformationRequest) + if let Some(DhcpOption::ClientId(client_id)) = msg.get_option(OptionCode::ClientId) + { + reply.insert_option(DhcpOption::ClientId(client_id.clone())); + } + + // Add Server Identifier option + // Use DUID-LL (type 3: Link-layer address) + let gateway_mac = self.inner.state.params.gateway_mac_ipv6.0; + let mut duid_bytes = vec![0x00, 0x03, 0x00, 0x01]; // Type 3 (LL), Hardware type 1 (Ethernet) + duid_bytes.extend_from_slice(&gateway_mac); + reply.insert_option(DhcpOption::ServerId(duid_bytes)); + + // Add DNS Name Server option if we have nameservers + let dns_servers: Vec = self + .inner + .state + .params + .nameservers + .iter() + .filter_map(|ip| match ip { + IpAddress::Ipv6(addr) => Some(*addr), + _ => None, + }) + .filter(|addr| { + !(addr.is_unspecified() + || addr.is_loopback() + || addr.is_link_local() + || addr.is_multicast() + || matches!(addr.0[0], 0xfc | 0xfd) // Is unique local address + || addr.0.starts_with(&[0xfe, 0xc0])) // Is synthetic DNS server + }) + .map(|addr| addr.into()) + .collect(); + + if !dns_servers.is_empty() { + reply.insert_option(DhcpOption::DnsServers(dns_servers)); + } + + let dhcpv6_buffer = reply.encode(); + + let resp_udp = UdpRepr { + src_port: DHCPV6_SERVER, + dst_port: DHCPV6_CLIENT, + }; + + let client_link_local = client_ip.unwrap_or(DHCPV6_ALL_AGENTS_MULTICAST); + let resp_ipv6 = Ipv6Repr { + src_addr: self.inner.state.params.gateway_link_local_ipv6, + dst_addr: client_link_local, + next_header: IpProtocol::Udp, + payload_len: resp_udp.header_len() + dhcpv6_buffer.len(), + hop_limit: 64, + }; + let resp_eth = EthernetRepr { + src_addr: self.inner.state.params.gateway_mac_ipv6, + dst_addr: self.inner.state.params.client_mac, + ethertype: EthernetProtocol::Ipv6, + }; + + // Construct the complete packet + let mut buffer = [0; MIN_MTU]; + let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer); + resp_eth.emit(&mut eth_frame); + + let mut ipv6_packet = Ipv6Packet::new_unchecked(eth_frame.payload_mut()); + resp_ipv6.emit(&mut ipv6_packet); + + let mut udp_packet = UdpPacket::new_unchecked(ipv6_packet.payload_mut()); + resp_udp.emit( + &mut udp_packet, + &IpAddress::Ipv6(resp_ipv6.src_addr), + &IpAddress::Ipv6(resp_ipv6.dst_addr), + dhcpv6_buffer.len(), + |udp_payload| { + udp_payload[..dhcpv6_buffer.len()].copy_from_slice(&dhcpv6_buffer); + }, + &ChecksumCapabilities::default(), + ); + + let total_len = resp_eth.buffer_len() + + resp_ipv6.buffer_len() + + resp_udp.header_len() + + dhcpv6_buffer.len(); + + self.client.recv(&buffer[..total_len], &ChecksumState::NONE); + } + _ => return Err(DropReason::UnsupportedDhcpv6(msg.msg_type)), + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper function to convert a hex string to bytes + fn hex_to_bytes(hex: &str) -> Vec { + (0..hex.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap()) + .collect() + } + + /// Helper function to convert bytes to hex string + fn bytes_to_hex(bytes: &[u8]) -> String { + bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect::>() + .join("") + } + + /// Helper function to create IPv6 address from hex string + fn hex_to_ipv6(hex: &str) -> std::net::Ipv6Addr { + std::net::Ipv6Addr::from(<[u8; 16]>::try_from(hex_to_bytes(hex).as_slice()).unwrap()) + } + + #[test] + fn test_message_decode() { + // This is a sample DHCPv6 InformationRequest message that was captured from a VM. + let input_hex = "0b1c57ca0008000200000001000e0001000130adec9800155d300e150010000e0000013700084d53465420352e30000600080011001700180020"; + let input_bytes = hex_to_bytes(input_hex); + let msg = Message::decode(&input_bytes).expect("Failed to decode message"); + + assert_eq!(msg.msg_type, MessageType::InformationRequest); + assert_eq!(msg.transaction_id, [0x1c, 0x57, 0xca]); + let client_id = "0001000130adec9800155d300e15"; + if let Some(DhcpOption::ClientId(data)) = msg.get_option(OptionCode::ClientId) { + assert_eq!(bytes_to_hex(data), client_id); + } else { + panic!("ClientId option not found"); + } + } + + #[test] + fn test_message_encode() { + const CLIENT_ID_HEX: &str = "0001000130adec9800155d300e15"; + const SERVER_ID_HEX: &str = "0003000152550a000102"; + const DNS1_HEX: &str = "20014898000000000000000010501050"; + const DNS2_HEX: &str = "20014898000000000000000010505050"; + const TRANSACTION_ID: [u8; 3] = [0x1c, 0x57, 0xca]; + + // Create a message with all option types + let mut msg = Message::new(MessageType::Reply); + msg.set_transaction_id(TRANSACTION_ID); + msg.insert_option(DhcpOption::ClientId(hex_to_bytes(CLIENT_ID_HEX))); + msg.insert_option(DhcpOption::ServerId(hex_to_bytes(SERVER_ID_HEX))); + + let dns_servers = vec![hex_to_ipv6(DNS1_HEX), hex_to_ipv6(DNS2_HEX)]; + msg.insert_option(DhcpOption::DnsServers(dns_servers.clone())); + + // Encode and decode to verify round-trip + let decoded = Message::decode(&msg.encode()).expect("Failed to decode encoded message"); + + assert_eq!(decoded.msg_type, MessageType::Reply); + assert_eq!(decoded.transaction_id, TRANSACTION_ID); + + let DhcpOption::ClientId(data) = decoded + .get_option(OptionCode::ClientId) + .expect("ClientId not found") + else { + panic!("Wrong option type for ClientId"); + }; + assert_eq!(bytes_to_hex(data), CLIENT_ID_HEX); + + let DhcpOption::ServerId(data) = decoded + .get_option(OptionCode::ServerId) + .expect("ServerId not found") + else { + panic!("Wrong option type for ServerId"); + }; + assert_eq!(bytes_to_hex(data), SERVER_ID_HEX); + + let DhcpOption::DnsServers(servers) = decoded + .get_option(OptionCode::DnsServers) + .expect("DnsServers not found") + else { + panic!("Wrong option type for DnsServers"); + }; + assert_eq!(servers, &dns_servers); + } +} diff --git a/vm/devices/net/net_consomme/consomme/src/dns_unix.rs b/vm/devices/net/net_consomme/consomme/src/dns_unix.rs index 14b0320e89..3cc0126285 100644 --- a/vm/devices/net/net_consomme/consomme/src/dns_unix.rs +++ b/vm/devices/net/net_consomme/consomme/src/dns_unix.rs @@ -2,7 +2,9 @@ // Licensed under the MIT License. use resolv_conf::ScopedIp; +use smoltcp::wire::IpAddress; use smoltcp::wire::Ipv4Address; +use smoltcp::wire::Ipv6Address; use thiserror::Error; #[derive(Debug, Error)] @@ -13,15 +15,16 @@ pub enum Error { Parse(#[from] resolv_conf::ParseError), } -pub fn nameservers() -> Result, Error> { +pub fn nameservers() -> Result, Error> { let contents = std::fs::read("/etc/resolv.conf")?; let config = resolv_conf::Config::parse(contents)?; Ok(config .nameservers .iter() .filter_map(|ns| match ns { - ScopedIp::V4(addr) => Some(Ipv4Address::from(*addr)), - ScopedIp::V6(_, _) => None, + ScopedIp::V4(addr) => Some(IpAddress::Ipv4(Ipv4Address::from(*addr))), + ScopedIp::V6(addr, None) => Some(IpAddress::Ipv6(Ipv6Address::from(*addr))), + ScopedIp::V6(_, Some(_)) => None, }) .collect()) } diff --git a/vm/devices/net/net_consomme/consomme/src/dns_windows.rs b/vm/devices/net/net_consomme/consomme/src/dns_windows.rs index 1ef3fae50c..3e6494b496 100644 --- a/vm/devices/net/net_consomme/consomme/src/dns_windows.rs +++ b/vm/devices/net/net_consomme/consomme/src/dns_windows.rs @@ -4,10 +4,11 @@ // UNSAFETY: Calling Win32 APIs to get DNS server information. #![expect(unsafe_code)] -use smoltcp::wire::Ipv4Address; +use smoltcp::wire::IpAddress; use std::alloc::Layout; use std::io; use std::net::Ipv4Addr; +use std::net::Ipv6Addr; use std::ptr::NonNull; use std::ptr::null_mut; use thiserror::Error; @@ -21,7 +22,10 @@ use windows_sys::Win32::NetworkManagement::IpHelper::GAA_FLAG_SKIP_UNICAST; use windows_sys::Win32::NetworkManagement::IpHelper::GetAdaptersAddresses; use windows_sys::Win32::NetworkManagement::IpHelper::IP_ADAPTER_ADDRESSES_LH; use windows_sys::Win32::Networking::WinSock::AF_INET; +use windows_sys::Win32::Networking::WinSock::AF_INET6; +use windows_sys::Win32::Networking::WinSock::AF_UNSPEC; use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN; +use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN6; #[derive(Debug, Error)] pub enum Error { @@ -29,7 +33,7 @@ pub enum Error { AdapterAddresses(#[source] io::Error), } -pub fn nameservers() -> Result, Error> { +pub fn nameservers() -> Result, Error> { let flags = GAA_FLAG_SKIP_UNICAST | GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST @@ -45,8 +49,13 @@ pub fn nameservers() -> Result, Error> { let mut addrs = Addresses::new(0); loop { let mut size = addrs.size(); - let r = - GetAdaptersAddresses(AF_INET.into(), flags, null_mut(), addrs.as_ptr(), &mut size); + let r = GetAdaptersAddresses( + AF_UNSPEC.into(), + flags, + null_mut(), + addrs.as_ptr(), + &mut size, + ); match r { ERROR_SUCCESS => break, ERROR_BUFFER_OVERFLOW => {} @@ -72,8 +81,12 @@ pub fn nameservers() -> Result, Error> { let dns_addr = &*dns.Address.lpSockaddr; if dns_addr.sa_family == AF_INET { let dns_addr = &*dns.Address.lpSockaddr.cast::(); - dns_servers - .push(Ipv4Addr::from(u32::from_be(dns_addr.sin_addr.S_un.S_addr)).into()); + let ipv4_addr = Ipv4Addr::from(u32::from_be(dns_addr.sin_addr.S_un.S_addr)); + dns_servers.push(ipv4_addr.into()); + } else if dns_addr.sa_family == AF_INET6 { + let dns_addr = &*dns.Address.lpSockaddr.cast::(); + let ipv6_addr = Ipv6Addr::from(u128::from_be_bytes(dns_addr.sin6_addr.u.Byte)); + dns_servers.push(ipv6_addr.into()); } dns_p = dns.Next; } diff --git a/vm/devices/net/net_consomme/consomme/src/icmp.rs b/vm/devices/net/net_consomme/consomme/src/icmp.rs index afe1d05f4f..1dead5d936 100644 --- a/vm/devices/net/net_consomme/consomme/src/icmp.rs +++ b/vm/devices/net/net_consomme/consomme/src/icmp.rs @@ -8,7 +8,6 @@ use super::Access; use super::Client; use super::ConsommeState; use super::DropReason; -use super::SocketAddress; use crate::ChecksumState; use crate::Ipv4Addresses; @@ -36,13 +35,14 @@ use std::mem::MaybeUninit; use std::net::IpAddr; use std::net::Ipv4Addr; use std::net::SocketAddr; +use std::net::SocketAddrV4; use std::task::Context; use std::task::Poll; const ICMPV4_HEADER_LEN: usize = 8; pub(crate) struct Icmp { - connections: HashMap, + connections: HashMap, } impl Icmp { @@ -57,7 +57,7 @@ impl Inspect for Icmp { fn inspect(&self, req: inspect::Request<'_>) { let mut resp = req.respond(); for (addr, conn) in &self.connections { - resp.field(&format!("{}:{}", addr.ip, addr.port), conn); + resp.field(&format!("{}:{}", addr.ip(), addr.port()), conn); } } } @@ -83,7 +83,7 @@ impl IcmpConnection { fn poll_conn( &mut self, cx: &mut Context<'_>, - dst_addr: &SocketAddress, + dst_addr: &SocketAddrV4, state: &mut ConsommeState, client: &mut impl Client, ) { @@ -106,7 +106,7 @@ impl IcmpConnection { eth.set_src_addr(state.params.gateway_mac); eth.set_dst_addr(self.guest_mac); let mut ipv4 = Ipv4Packet::new_unchecked(eth.payload_mut()); - ipv4.set_dst_addr(dst_addr.ip); + ipv4.set_dst_addr((*dst_addr.ip()).into()); ipv4.fill_checksum(); let len = ETHERNET_HEADER_LEN + n; client.recv(ð.as_ref()[..len], &ChecksumState::IPV4_ONLY); @@ -160,10 +160,7 @@ impl Access<'_, T> { hop_limit: u8, ) -> Result<(), DropReason> { let icmp_packet = smoltcp::wire::Icmpv4Packet::new_unchecked(payload); - let guest_addr = SocketAddress { - ip: addresses.src_addr, - port: 0, - }; + let guest_addr = SocketAddrV4::new(addresses.src_addr.into(), 0); let entry = self.inner.icmp.connections.entry(guest_addr); let conn = match entry { diff --git a/vm/devices/net/net_consomme/consomme/src/lib.rs b/vm/devices/net/net_consomme/consomme/src/lib.rs index 37f6e39949..41e9fef8d0 100644 --- a/vm/devices/net/net_consomme/consomme/src/lib.rs +++ b/vm/devices/net/net_consomme/consomme/src/lib.rs @@ -16,10 +16,12 @@ mod arp; mod dhcp; +mod dhcpv6; #[cfg_attr(unix, path = "dns_unix.rs")] #[cfg_attr(windows, path = "dns_windows.rs")] mod dns; mod icmp; +mod ndp; mod tcp; mod udp; mod windows; @@ -35,10 +37,13 @@ use smoltcp::wire::EthernetFrame; use smoltcp::wire::EthernetProtocol; use smoltcp::wire::EthernetRepr; use smoltcp::wire::IPV4_HEADER_LEN; +use smoltcp::wire::Icmpv6Packet; +use smoltcp::wire::IpAddress; use smoltcp::wire::IpProtocol; use smoltcp::wire::Ipv4Address; use smoltcp::wire::Ipv4Packet; -use std::net::SocketAddrV4; +use smoltcp::wire::Ipv6Address; +use smoltcp::wire::Ipv6Packet; use std::task::Context; use thiserror::Error; @@ -79,7 +84,26 @@ pub struct ConsommeParams { pub client_mac: EthernetAddress, /// Current list of DNS resolvers. #[inspect(with = "|x| inspect::iter_by_index(x).map_value(inspect::AsDisplay)")] - pub nameservers: Vec, + pub nameservers: Vec, + /// Current IPv6 network mask (if any). + #[inspect(display)] + pub prefix_len_ipv6: u8, + /// Current IPv6 gateway MAC address (if any). + #[inspect(display)] + pub gateway_mac_ipv6: EthernetAddress, + /// Gateway's link-local IPv6 address (derived from gateway_mac_ipv6). + /// + /// This is the address used as the source for NDP Router Advertisements + /// and as the target for Neighbor Solicitations. + #[inspect(display)] + pub gateway_link_local_ipv6: Ipv6Address, + /// Current IPv6 address learned from guest via SLAAC (if any). + /// + /// With SLAAC (Stateless Address Autoconfiguration), the guest generates + /// its own IPv6 address using the advertised prefix and its interface identifier. + /// This field is learned from incoming IPv6 traffic from the guest. + #[inspect(with = "Option::is_some")] + pub client_ip_ipv6: Option, } /// An error indicating that the CIDR is invalid. @@ -91,9 +115,12 @@ impl ConsommeParams { /// Create default dynamic network state. The default state is /// IP address: 10.0.0.2 / 24 /// gateway: 10.0.0.1 with MAC address 52-55-10-0-0-1 - /// no DNS resolvers + /// IPv6 address: is not assigned by us, we expect the guest to assign it via SLAAC + /// gateway ipv6 link local address: fe80::5055:aff:fe00:102 with MAC address 52-55-0A-00-01-02 pub fn new() -> Result { let nameservers = dns::nameservers()?; + let gateway_mac_ipv6 = EthernetAddress([0x52, 0x55, 0x0A, 0x00, 0x01, 0x02]); + Ok(Self { gateway_ip: Ipv4Address::new(10, 0, 0, 1), gateway_mac: EthernetAddress([0x52, 0x55, 10, 0, 0, 1]), @@ -101,6 +128,10 @@ impl ConsommeParams { client_mac: EthernetAddress([0x0, 0x0, 0x0, 0x0, 0x1, 0x0]), net_mask: Ipv4Address::new(255, 255, 255, 0), nameservers, + prefix_len_ipv6: 64, + gateway_mac_ipv6, + gateway_link_local_ipv6: Self::compute_link_local_address(gateway_mac_ipv6), + client_ip_ipv6: None, }) } @@ -118,6 +149,38 @@ impl ConsommeParams { self.net_mask = cidr.netmask(); Ok(()) } + + /// Compute a link-local IPv6 address from a MAC address using EUI-64 format. + /// + /// RFC 4291 Section 2.5.6: Link-local addresses are formed by combining + /// the link-local prefix (fe80::/64) with an interface identifier derived + /// from the MAC address using the EUI-64 format. + /// + /// EUI-64 format (RFC 2464 Section 4): + /// - Insert 0xFFFE in the middle of the 48-bit MAC address + /// - Invert the universal/local bit (bit 6 of the first byte) + pub fn compute_link_local_address(mac: EthernetAddress) -> Ipv6Address { + const LINK_LOCAL_PREFIX: [u8; 8] = [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let mut addr = [0u8; 16]; + + // Set link-local prefix (fe80::/64) + addr[0..8].copy_from_slice(&LINK_LOCAL_PREFIX); + + // Create EUI-64 interface identifier from MAC address + // MAC: AB:CD:EF:11:22:33 + // EUI-64: AB:CD:EF:FF:FE:11:22:33 with universal/local bit flipped + addr[8] = mac.0[0] ^ 0x02; // Flip the universal/local bit + addr[9] = mac.0[1]; + addr[10] = mac.0[2]; + addr[11] = 0xFF; + addr[12] = 0xFE; + addr[13] = mac.0[3]; + addr[14] = mac.0[4]; + addr[15] = mac.0[5]; + + Ipv6Address(addr) + } } /// An accessor for consomme. @@ -201,6 +264,12 @@ impl ChecksumState { udp: true, tso: None, }; + const TCP6: Self = Self { + ipv4: false, + tcp: true, + udp: false, + tso: None, + }; fn caps(&self) -> ChecksumCapabilities { let mut caps = ChecksumCapabilities::default(); @@ -221,36 +290,12 @@ impl ChecksumState { /// frame). pub const MIN_MTU: usize = 1514; -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -struct SocketAddress { - ip: Ipv4Address, - port: u16, -} - -impl From for SocketAddrV4 { - fn from(addr: SocketAddress) -> Self { - Self::new(addr.ip.into(), addr.port) - } -} - -impl From for socket2::SockAddr { - fn from(addr: SocketAddress) -> Self { - socket2::SockAddr::from(SocketAddrV4::from(addr)) - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -struct FourTuple { - dst: SocketAddress, - src: SocketAddress, -} - /// The reason a packet was dropped without being handled. #[derive(Debug, Error)] pub enum DropReason { /// The packet could not be parsed. #[error("packet parsing error")] - Packet(#[from] smoltcp::Error), + Packet(#[from] smoltcp::wire::Error), /// The ethertype is unknown. #[error("unsupported ethertype {0}")] UnsupportedEthertype(EthernetProtocol), @@ -278,6 +323,20 @@ pub enum DropReason { /// Specified port is not bound. #[error("port is not bound")] PortNotBound, + /// The DHCPv6 message type is unsupported. + #[error("unsupported dhcpv6 message type {0:?}")] + UnsupportedDhcpv6(dhcpv6::MessageType), + /// The NDP message type is unsupported. + #[error("unsupported ndp message type {0:?}")] + UnsupportedNdp(ndp::NdpMessageType), + /// An incoming packet was recognized but was self-contradictory. + /// E.g. a TCP packet with both SYN and FIN flags set. + #[error("packet is malformed")] + MalformedPacket, + /// An incoming IP packet has been split into several IP fragments and was dropped, + /// since IP reassembly is not supported. + #[error("packet fragmentation is not supported")] + FragmentedPacket, } /// An error to create a consomme instance. @@ -294,6 +353,34 @@ struct Ipv4Addresses { dst_addr: Ipv4Address, } +#[derive(Debug)] +struct Ipv6Addresses { + src_addr: Ipv6Address, + dst_addr: Ipv6Address, +} + +#[derive(Debug)] +enum IpAddresses { + V4(Ipv4Addresses), + V6(Ipv6Addresses), +} + +impl IpAddresses { + fn src_addr(&self) -> IpAddress { + match self { + IpAddresses::V4(addrs) => IpAddress::Ipv4(addrs.src_addr), + IpAddresses::V6(addrs) => IpAddress::Ipv6(addrs.src_addr), + } + } + + fn dst_addr(&self) -> IpAddress { + match self { + IpAddresses::V4(addrs) => IpAddress::Ipv4(addrs.dst_addr), + IpAddresses::V6(addrs) => IpAddress::Ipv6(addrs.dst_addr), + } + } +} + impl Consomme { /// Creates a new consomme instance with specified state. pub fn new(params: ConsommeParams) -> Self { @@ -385,6 +472,7 @@ impl Access<'_, T> { let frame = EthernetRepr::parse(&frame_packet)?; match frame.ethertype { EthernetProtocol::Ipv4 => self.handle_ipv4(&frame, frame_packet.payload(), checksum)?, + EthernetProtocol::Ipv6 => self.handle_ipv6(&frame, frame_packet.payload(), checksum)?, EthernetProtocol::Arp => self.handle_arp(&frame, frame_packet.payload())?, _ => return Err(DropReason::UnsupportedEthertype(frame.ethertype)), } @@ -403,7 +491,7 @@ impl Access<'_, T> { || payload.len() < ipv4.header_len().into() || payload.len() < ipv4.total_len().into() { - return Err(DropReason::Packet(smoltcp::Error::Malformed)); + return Err(DropReason::MalformedPacket); } let total_len = if checksum.tso.is_some() { @@ -412,11 +500,11 @@ impl Access<'_, T> { ipv4.total_len().into() }; if total_len < ipv4.header_len().into() { - return Err(DropReason::Packet(smoltcp::Error::Malformed)); + return Err(DropReason::MalformedPacket); } if ipv4.more_frags() || ipv4.frag_offset() != 0 { - return Err(DropReason::Packet(smoltcp::Error::Fragmented)); + return Err(DropReason::FragmentedPacket); } if !checksum.ipv4 && !ipv4.verify_checksum() { @@ -430,9 +518,11 @@ impl Access<'_, T> { let inner = &payload[ipv4.header_len().into()..total_len]; - match ipv4.protocol() { - IpProtocol::Tcp => self.handle_tcp(&addresses, inner, checksum)?, - IpProtocol::Udp => self.handle_udp(frame, &addresses, inner, checksum)?, + match ipv4.next_header() { + IpProtocol::Tcp => self.handle_tcp(&IpAddresses::V4(addresses), inner, checksum)?, + IpProtocol::Udp => { + self.handle_udp(frame, &IpAddresses::V4(addresses), inner, checksum)? + } IpProtocol::Icmp => { self.handle_icmp(frame, &addresses, inner, checksum, ipv4.hop_limit())? } @@ -440,4 +530,54 @@ impl Access<'_, T> { }; Ok(()) } + + fn handle_ipv6( + &mut self, + frame: &EthernetRepr, + payload: &[u8], + checksum: &ChecksumState, + ) -> Result<(), DropReason> { + let ipv6 = Ipv6Packet::new_unchecked(payload); + if payload.len() < smoltcp::wire::IPV6_HEADER_LEN || ipv6.version() != 6 { + return Err(DropReason::MalformedPacket); + } + + let required_len = smoltcp::wire::IPV6_HEADER_LEN + ipv6.payload_len() as usize; + if payload.len() < required_len { + return Err(DropReason::MalformedPacket); + } + + //TODO: Walk extension headers. + let next_header = ipv6.next_header(); + let inner = &payload[smoltcp::wire::IPV6_HEADER_LEN..]; + let addresses = Ipv6Addresses { + src_addr: ipv6.src_addr(), + dst_addr: ipv6.dst_addr(), + }; + + match next_header { + IpProtocol::Udp => { + self.handle_udp(frame, &IpAddresses::V6(addresses), inner, checksum)? + } + IpProtocol::Tcp => self.handle_tcp(&IpAddresses::V6(addresses), inner, checksum)?, + IpProtocol::Icmpv6 => { + // Check if this is an NDP packet + let icmpv6_packet = Icmpv6Packet::new_unchecked(inner); + let msg_type = icmpv6_packet.msg_type(); + + if msg_type == smoltcp::wire::Icmpv6Message::NeighborSolicit + || msg_type == smoltcp::wire::Icmpv6Message::NeighborAdvert + || msg_type == smoltcp::wire::Icmpv6Message::RouterSolicit + || msg_type == smoltcp::wire::Icmpv6Message::RouterAdvert + { + self.handle_ndp(frame, inner, ipv6.src_addr())?; + } else { + return Err(DropReason::UnsupportedIpProtocol(next_header)); + } + } + + p => return Err(DropReason::UnsupportedIpProtocol(p)), + }; + Ok(()) + } } diff --git a/vm/devices/net/net_consomme/consomme/src/ndp.rs b/vm/devices/net/net_consomme/consomme/src/ndp.rs new file mode 100644 index 0000000000..047f6bce75 --- /dev/null +++ b/vm/devices/net/net_consomme/consomme/src/ndp.rs @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! NDP (Neighbor Discovery Protocol) implementation for IPv6 SLAAC (Stateless Address Autoconfiguration) +//! +//! This module implements RFC 4861 (Neighbor Discovery) and RFC 4862 (IPv6 Stateless Address Autoconfiguration). +//! The implementation is stateless - we advertise prefixes via Router Advertisements and let clients +//! autoconfigure their own addresses using SLAAC. + +use super::Access; +use super::Client; +use super::DropReason; +use crate::ChecksumState; +use crate::MIN_MTU; +use smoltcp::phy::Medium; +use smoltcp::wire::EthernetAddress; +use smoltcp::wire::EthernetFrame; +use smoltcp::wire::EthernetProtocol; +use smoltcp::wire::EthernetRepr; +use smoltcp::wire::HardwareAddress; +use smoltcp::wire::Icmpv6Packet; +use smoltcp::wire::IpAddress; +use smoltcp::wire::IpProtocol; +use smoltcp::wire::Ipv6Address; +use smoltcp::wire::Ipv6Packet; +use smoltcp::wire::Ipv6Repr; +use smoltcp::wire::NdiscNeighborFlags; +use smoltcp::wire::NdiscPrefixInfoFlags; +use smoltcp::wire::NdiscPrefixInformation; +use smoltcp::wire::NdiscRepr; +use smoltcp::wire::NdiscRouterFlags; +use smoltcp::wire::RawHardwareAddress; + +const NETWORK_PREFIX_BASE: Ipv6Address = Ipv6Address::new(0x2001, 0xabcd, 0, 0, 0, 0, 0, 0); + +#[derive(Debug)] +pub enum NdpMessageType { + RouterSolicit, + RouterAdvert, + NeighborSolicit, + NeighborAdvert, + Redirect, +} + +impl Access<'_, T> { + /// Handle NDP messages from the guest + pub(crate) fn handle_ndp( + &mut self, + frame: &EthernetRepr, + payload: &[u8], + ipv6_src_addr: Ipv6Address, + ) -> Result<(), DropReason> { + let icmpv6_packet = Icmpv6Packet::new_unchecked(payload); + let ndp = NdiscRepr::parse(&icmpv6_packet)?; + + match ndp { + NdiscRepr::RouterSolicit { lladdr } => { + self.handle_router_solicit(frame, ipv6_src_addr, lladdr) + } + NdiscRepr::NeighborSolicit { + target_addr, + lladdr: source_lladdr, + } => self.handle_neighbor_solicit(frame, ipv6_src_addr, target_addr, source_lladdr), + NdiscRepr::NeighborAdvert { .. } => { + tracing::trace!("received unsolicited Neighbor Advertisement, ignoring"); + Ok(()) + } + NdiscRepr::RouterAdvert { .. } => { + tracing::trace!("received Router Advertisement, ignoring"); + Ok(()) + } + NdiscRepr::Redirect { .. } => { + tracing::trace!("received Redirect, ignoring"); + Ok(()) + } + } + } + + /// Handle Router Solicitation (RFC 4861 Section 6.2.6) + /// + /// Router Solicitations are sent by hosts to discover routers on the link. + /// We respond with a Router Advertisement containing prefix information for SLAAC. + fn handle_router_solicit( + &mut self, + frame: &EthernetRepr, + ipv6_src_addr: Ipv6Address, + lladdr: Option, + ) -> Result<(), DropReason> { + // RFC 4861 Section 6.1.1: Validate source link-layer address option + // If source is unspecified (::), there must be no source link-layer address option + if ipv6_src_addr.is_unspecified() && lladdr.is_some() { + tracelimit::warn_ratelimited!( + "invalid RS: source is :: but source link-layer address present" + ); + return Err(DropReason::MalformedPacket); + } + + // Verify this is from the expected client MAC (if link-layer address is provided) + if let Some(lladdr) = lladdr { + if let Ok(hw_addr) = lladdr.parse(Medium::Ethernet) { + let HardwareAddress::Ethernet(eth_addr) = hw_addr; + if eth_addr != self.inner.state.params.client_mac { + tracelimit::warn_ratelimited!( + "Router Solicitation from unexpected MAC, ignoring" + ); + return Ok(()); + } + } + } + + // Determine destination address for the reply + // RFC 4861 Section 6.2.6: If RS has source link-layer address option, + // unicast to source. Otherwise, use all-nodes multicast. + let reply_dst_addr = if lladdr.is_some() && !ipv6_src_addr.is_unspecified() { + ipv6_src_addr + } else { + Ipv6Address::LINK_LOCAL_ALL_NODES + }; + + // Determine Ethernet destination + let eth_dst_addr = if reply_dst_addr.is_multicast() { + // Multicast IPv6 to Ethernet address mapping (RFC 2464) + // 33:33:xx:xx:xx:xx where xx:xx:xx:xx are the low-order 32 bits of the IPv6 multicast address + EthernetAddress([ + 0x33, + 0x33, + reply_dst_addr.0[12], + reply_dst_addr.0[13], + reply_dst_addr.0[14], + reply_dst_addr.0[15], + ]) + } else { + frame.src_addr + }; + + self.send_router_advertisement(reply_dst_addr, eth_dst_addr) + } + + /// Send a Router Advertisement (RFC 4861 Section 4.2) + /// + /// Router Advertisements contain prefix information for SLAAC. Clients will use + /// the advertised prefix to generate their own IPv6 addresses. + fn send_router_advertisement( + &mut self, + dst_addr: Ipv6Address, + eth_dst_addr: EthernetAddress, + ) -> Result<(), DropReason> { + // Compute the network prefix from our configured IPv6 parameters + // This is the prefix that clients will use for SLAAC + let prefix = self + .compute_network_prefix(NETWORK_PREFIX_BASE, self.inner.state.params.prefix_len_ipv6); + + // RFC 4861 Section 4.6.2: Router Advertisement with Prefix Information + // We set the ADDRCONF flag to enable SLAAC and ON_LINK flag to indicate + // that addresses with this prefix are on-link. + let ndp_repr = NdiscRepr::RouterAdvert { + hop_limit: 255, + flags: NdiscRouterFlags::empty(), + router_lifetime: smoltcp::time::Duration::from_secs(9000), // https://www.rfc-editor.org/rfc/rfc4861#section-4.2 + reachable_time: smoltcp::time::Duration::from_millis(30000), // https://www.rfc-editor.org/rfc/rfc4861#section-6 + retrans_time: smoltcp::time::Duration::from_millis(1000), // https://www.rfc-editor.org/rfc/rfc4861#section-6 + lladdr: Some(RawHardwareAddress::from( + self.inner.state.params.gateway_mac_ipv6, + )), + mtu: None, + prefix_info: Some(NdiscPrefixInformation { + prefix_len: self.inner.state.params.prefix_len_ipv6, + prefix, + valid_lifetime: smoltcp::time::Duration::from_secs(2592000), // https://www.rfc-editor.org/rfc/rfc4861#section-6.2.1 + preferred_lifetime: smoltcp::time::Duration::from_secs(604800), // https://www.rfc-editor.org/rfc/rfc4861#section-6.2.1 + flags: NdiscPrefixInfoFlags::ON_LINK | NdiscPrefixInfoFlags::ADDRCONF, + }), + }; + + // Build IPv6 header + let ipv6_repr = Ipv6Repr { + src_addr: self.inner.state.params.gateway_link_local_ipv6, + dst_addr, + next_header: IpProtocol::Icmpv6, + payload_len: ndp_repr.buffer_len(), + hop_limit: 255, // Router advertisements must have a hop limit of 255 to indicate the packet was not forwarded by another router. + }; + + let eth_repr = EthernetRepr { + src_addr: self.inner.state.params.gateway_mac_ipv6, + dst_addr: eth_dst_addr, + ethertype: EthernetProtocol::Ipv6, + }; + + let mut buffer = [0; MIN_MTU]; + let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer); + eth_repr.emit(&mut eth_frame); + + let mut ipv6_packet = Ipv6Packet::new_unchecked(eth_frame.payload_mut()); + ipv6_repr.emit(&mut ipv6_packet); + + let mut icmpv6_packet = Icmpv6Packet::new_unchecked(ipv6_packet.payload_mut()); + ndp_repr.emit(&mut icmpv6_packet); + icmpv6_packet.fill_checksum( + &IpAddress::Ipv6(ipv6_repr.src_addr), + &IpAddress::Ipv6(ipv6_repr.dst_addr), + ); + + let total_len = eth_repr.buffer_len() + ipv6_repr.buffer_len() + ndp_repr.buffer_len(); + + self.client.recv(&buffer[..total_len], &ChecksumState::NONE); + Ok(()) + } + + /// Handle Neighbor Solicitation (RFC 4861 Section 7.2.3) + /// + /// Neighbor Solicitations are used for: + /// 1. Address resolution (discovering link-layer address of a neighbor) + /// 2. Duplicate Address Detection (DAD) - verifying address uniqueness + /// 3. Neighbor Unreachability Detection (NUD) + fn handle_neighbor_solicit( + &mut self, + frame: &EthernetRepr, + ipv6_src_addr: Ipv6Address, + target_addr: Ipv6Address, + source_lladdr: Option, + ) -> Result<(), DropReason> { + // RFC 4861 Section 7.1.1: If source is unspecified, there must be no + // source link-layer address option + if ipv6_src_addr.is_unspecified() && source_lladdr.is_some() { + tracelimit::warn_ratelimited!( + "invalid NS: source is :: but source link-layer address present" + ); + return Err(DropReason::MalformedPacket); + } + + // RFC 4862 Section 5.4.3: Handle Duplicate Address Detection (DAD) + // If source is unspecified (::), this is DAD - we should NOT respond + // to avoid interfering with the client's address configuration + if ipv6_src_addr.is_unspecified() { + tracelimit::warn_ratelimited!( + target_addr = %target_addr, + "received DAD Neighbor Solicitation, silently ignoring per RFC 4862" + ); + return Ok(()); + } + + // Verify this is from the expected client MAC + let client_mac_matches = source_lladdr + .and_then(|addr| addr.parse(Medium::Ethernet).ok()) + .map(|hw_addr| match hw_addr { + HardwareAddress::Ethernet(eth_addr) => { + eth_addr == self.inner.state.params.client_mac + } + #[allow(unreachable_patterns)] + _ => false, + }) + .unwrap_or(false); + + if !client_mac_matches { + tracelimit::warn_ratelimited!("Neighbor Solicitation from unexpected MAC, ignoring"); + return Ok(()); + } + + // Learn client IPv6 address from Neighbor Solicitation + // When the client performs address resolution using their SLAAC-configured + // global address, we learn it here. We only learn global unicast addresses + // (not link-local, multicast, or unspecified). + if !ipv6_src_addr.is_link_local() + && !ipv6_src_addr.is_multicast() + && !ipv6_src_addr.is_unspecified() + { + if self.inner.state.params.client_ip_ipv6.is_none() + || self.inner.state.params.client_ip_ipv6 != Some(ipv6_src_addr) + { + tracing::debug!( + client_ipv6 = %ipv6_src_addr, + "learned client IPv6 address from Neighbor Solicitation" + ); + self.inner.state.params.client_ip_ipv6 = Some(ipv6_src_addr); + } + } + + // Only respond if the target is our link-local address + // In a stateless NAT implementation, the gateway only responds for its own + // link-local address, not for global addresses that clients autoconfigure + if target_addr != self.inner.state.params.gateway_link_local_ipv6 { + tracing::debug!( + target_addr = %target_addr, + our_link_local = %self.inner.state.params.gateway_link_local_ipv6, + "NS target is not our link-local address, ignoring" + ); + return Ok(()); + } + + // Send Neighbor Advertisement + self.send_neighbor_advertisement(ipv6_src_addr, frame.src_addr, target_addr, true) + } + + /// Send a Neighbor Advertisement (RFC 4861 Section 7.2.4) + /// + /// Neighbor Advertisements are sent in response to Neighbor Solicitations + /// to provide our link-layer address for address resolution. + fn send_neighbor_advertisement( + &mut self, + dst_addr: Ipv6Address, + eth_dst_addr: EthernetAddress, + target_addr: Ipv6Address, + solicited: bool, + ) -> Result<(), DropReason> { + // RFC 4861 Section 7.2.4: Neighbor Advertisement format + // Solicited flag = 1 (this is a response to a solicitation) + // Override flag = 1 (we're authoritative for this address) + // Router flag = 1 (we are a router) + let mut flags = NdiscNeighborFlags::OVERRIDE; + if solicited { + flags |= NdiscNeighborFlags::SOLICITED; + } + flags |= NdiscNeighborFlags::ROUTER; + + let ndp_repr = NdiscRepr::NeighborAdvert { + flags, + target_addr, + lladdr: Some(RawHardwareAddress::from( + self.inner.state.params.gateway_mac_ipv6, + )), + }; + + // Build IPv6 header - destination is the source of the solicitation + let ipv6_repr = Ipv6Repr { + src_addr: target_addr, // Our address (the one being asked about) + dst_addr, // Respond to the solicitation's source + next_header: IpProtocol::Icmpv6, + payload_len: ndp_repr.buffer_len(), + hop_limit: 255, // RFC 4861: Neighbor Advertisements must have a hop limit of 255 to indicate the packet was not forwarded. + }; + + // Build Ethernet header + let eth_repr = EthernetRepr { + src_addr: self.inner.state.params.gateway_mac_ipv6, + dst_addr: eth_dst_addr, + ethertype: EthernetProtocol::Ipv6, + }; + + // Construct the complete packet + let mut buffer = [0; MIN_MTU]; + let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer); + eth_repr.emit(&mut eth_frame); + + let mut ipv6_packet = Ipv6Packet::new_unchecked(eth_frame.payload_mut()); + ipv6_repr.emit(&mut ipv6_packet); + + let mut icmpv6_packet = Icmpv6Packet::new_unchecked(ipv6_packet.payload_mut()); + ndp_repr.emit(&mut icmpv6_packet); + icmpv6_packet.fill_checksum( + &IpAddress::Ipv6(ipv6_repr.src_addr), + &IpAddress::Ipv6(ipv6_repr.dst_addr), + ); + + let total_len = eth_repr.buffer_len() + ipv6_repr.buffer_len() + ndp_repr.buffer_len(); + + self.client.recv(&buffer[..total_len], &ChecksumState::NONE); + Ok(()) + } + + /// Compute the network prefix from an IPv6 address and prefix length + /// + /// This extracts the network portion of an IPv6 address by applying + /// a mask based on the prefix length. + fn compute_network_prefix(&self, addr: Ipv6Address, prefix_len: u8) -> Ipv6Address { + if prefix_len >= 128 { + return addr; + } + + let addr_u128 = u128::from_be_bytes(addr.0); + let mask = if prefix_len == 0 { + 0u128 + } else { + (!0u128) << (128 - prefix_len) + }; + + Ipv6Address((addr_u128 & mask).to_be_bytes()) + } +} diff --git a/vm/devices/net/net_consomme/consomme/src/tcp.rs b/vm/devices/net/net_consomme/consomme/src/tcp.rs index 816b8ddb09..d7532afb46 100644 --- a/vm/devices/net/net_consomme/consomme/src/tcp.rs +++ b/vm/devices/net/net_consomme/consomme/src/tcp.rs @@ -6,11 +6,9 @@ mod ring; use super::Access; use super::Client; use super::DropReason; -use super::FourTuple; -use super::SocketAddress; use crate::ChecksumState; use crate::ConsommeState; -use crate::Ipv4Addresses; +use crate::IpAddresses; use futures::AsyncRead; use futures::AsyncWrite; use inspect::Inspect; @@ -22,9 +20,12 @@ use smoltcp::wire::ETHERNET_HEADER_LEN; use smoltcp::wire::EthernetFrame; use smoltcp::wire::EthernetProtocol; use smoltcp::wire::IPV4_HEADER_LEN; +use smoltcp::wire::IPV6_HEADER_LEN; +use smoltcp::wire::IpAddress; use smoltcp::wire::IpProtocol; +use smoltcp::wire::IpRepr; use smoltcp::wire::Ipv4Packet; -use smoltcp::wire::Ipv4Repr; +use smoltcp::wire::Ipv6Packet; use smoltcp::wire::TcpControl; use smoltcp::wire::TcpPacket; use smoltcp::wire::TcpRepr; @@ -41,16 +42,32 @@ use std::io; use std::io::ErrorKind; use std::io::IoSlice; use std::io::IoSliceMut; +use std::net::IpAddr; use std::net::Ipv4Addr; +use std::net::Ipv6Addr; use std::net::Shutdown; +use std::net::SocketAddr; use std::net::SocketAddrV4; +use std::net::SocketAddrV6; use std::pin::Pin; use std::task::Context; use std::task::Poll; use thiserror::Error; +trait SupportedAddressFamily {} + +impl SupportedAddressFamily for SocketAddrV4 {} +impl SupportedAddressFamily for SocketAddrV6 {} +impl SupportedAddressFamily for SocketAddr {} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +struct FourTuple { + src: T, + dst: T, +} + pub(crate) struct Tcp { - connections: HashMap, + connections: HashMap, TcpConnection>, listeners: HashMap, } @@ -77,7 +94,10 @@ impl Inspect for Tcp { resp.field( &format!( "{}:{}-{}:{}", - addr.src.ip, addr.src.port, addr.dst.ip, addr.dst.port + addr.src.ip(), + addr.src.port(), + addr.dst.ip(), + addr.dst.port() ), conn, ); @@ -208,12 +228,12 @@ impl Access<'_, T> { if let Some((socket, mut other_addr)) = result { // Check for loopback requests and replace the dest port. // This supports a guest owning both the sending and receiving ports. - if other_addr.ip.is_loopback() { + if other_addr.ip().is_loopback() { for (other_ft, connection) in self.inner.tcp.connections.iter() { - if connection.state == TcpState::Connecting && other_ft.dst.port == *port { + if connection.state == TcpState::Connecting && other_ft.dst.port() == *port { if let LoopbackPortInfo::ProxyForGuestPort{sending_port, guest_port} = connection.loopback_port { - if sending_port == other_addr.port { - other_addr.port = guest_port; + if sending_port == other_addr.port() { + other_addr.set_port(guest_port); break; } } @@ -221,10 +241,7 @@ impl Access<'_, T> { } } - let ft = FourTuple { dst: other_addr, src: SocketAddress { - ip: self.inner.state.params.client_ip, - port: *port, - } }; + let ft = FourTuple { dst: SocketAddr::V4(other_addr), src: SocketAddr::V4(SocketAddrV4::new(self.inner.state.params.client_ip.into(), *port)) }; match self.inner.tcp.connections.entry(ft) { hash_map::Entry::Vacant(e) => { @@ -290,35 +307,44 @@ impl Access<'_, T> { false } } - }) + }); } pub(crate) fn handle_tcp( &mut self, - addresses: &Ipv4Addresses, + addresses: &IpAddresses, payload: &[u8], checksum: &ChecksumState, ) -> Result<(), DropReason> { let tcp_packet = TcpPacket::new_checked(payload)?; let tcp = TcpRepr::parse( &tcp_packet, - &addresses.src_addr.into(), - &addresses.dst_addr.into(), + &addresses.src_addr(), + &addresses.dst_addr(), &checksum.caps(), )?; - tracing::trace!(?tcp, "tcp packet"); - - let ft = FourTuple { - dst: SocketAddress { - ip: addresses.dst_addr, - port: tcp.dst_port, + let ft = match addresses { + IpAddresses::V4(addresses) => FourTuple { + dst: SocketAddr::V4(SocketAddrV4::new(addresses.dst_addr.into(), tcp.dst_port)), + src: SocketAddr::V4(SocketAddrV4::new(addresses.src_addr.into(), tcp.src_port)), }, - src: SocketAddress { - ip: addresses.src_addr, - port: tcp.src_port, + IpAddresses::V6(addresses) => FourTuple { + dst: SocketAddr::V6(SocketAddrV6::new( + addresses.dst_addr.into(), + tcp.dst_port, + 0, + 0, + )), + src: SocketAddr::V6(SocketAddrV6::new( + addresses.src_addr.into(), + tcp.src_port, + 0, + 0, + )), }, }; + tracing::trace!(?tcp, "tcp packet"); let mut sender = Sender { ft: &ft, @@ -352,24 +378,25 @@ impl Access<'_, T> { /// Binds to the specified host IP and port for listening for incoming /// connections. - pub fn bind_tcp_port( - &mut self, - ip_addr: Option, - port: u16, - ) -> Result<(), DropReason> { + pub fn bind_tcp_port(&mut self, ip_addr: Option, port: u16) -> Result<(), DropReason> { + let ip_addr = match ip_addr { + Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)), + Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), + None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)), + }; match self.inner.tcp.listeners.entry(port) { hash_map::Entry::Occupied(_) => { tracing::warn!(port, "Duplicate TCP bind for port"); } hash_map::Entry::Vacant(e) => { - let ft = FourTuple { - dst: SocketAddress { - ip: Ipv4Addr::UNSPECIFIED.into(), - port: 0, + let ft = match ip_addr { + SocketAddr::V4(ip) => FourTuple { + dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), + src: SocketAddr::V4(ip), }, - src: SocketAddress { - ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(), - port, + SocketAddr::V6(ip) => FourTuple { + dst: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)), + src: SocketAddr::V6(ip), }, }; let mut sender = Sender { @@ -398,7 +425,7 @@ impl Access<'_, T> { } struct Sender<'a, T> { - ft: &'a FourTuple, + ft: &'a FourTuple, client: &'a mut T, state: &'a mut ConsommeState, } @@ -407,39 +434,73 @@ impl Sender<'_, T> { fn send_packet(&mut self, tcp: &TcpRepr<'_>, payload: Option>) { let buffer = &mut self.state.buffer; let mut eth_packet = EthernetFrame::new_unchecked(&mut buffer[..]); - eth_packet.set_ethertype(EthernetProtocol::Ipv4); eth_packet.set_dst_addr(self.state.params.client_mac); eth_packet.set_src_addr(self.state.params.gateway_mac); - let mut ipv4_packet = Ipv4Packet::new_unchecked(eth_packet.payload_mut()); - let ipv4 = Ipv4Repr { - src_addr: self.ft.dst.ip, - dst_addr: self.ft.src.ip, - protocol: IpProtocol::Tcp, - payload_len: tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()), - hop_limit: 64, + let copy_payload_into_buffer = |buf: &mut [u8], payload: Option>| { + if let Some(payload) = payload { + for (b, c) in buf.iter_mut().zip(payload.iter()) { + *b = *c; + } + } + }; + let ip = IpRepr::new( + self.ft.dst.ip().into(), + self.ft.src.ip().into(), + IpProtocol::Tcp, + tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()), + 64, + ); + // Set the ethernet type based on IP version + match ip { + IpRepr::Ipv4(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv4), + IpRepr::Ipv6(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv6), + } + + // Emit IP packet and get the TCP payload buffer (works for both IPv4 and IPv6) + let ip_packet_buf = eth_packet.payload_mut(); + ip.emit(&mut *ip_packet_buf, &ChecksumCapabilities::default()); + + let (tcp_payload_buf, ip_total_len) = match self.ft.dst { + SocketAddr::V4(_) => { + let ipv4_packet = Ipv4Packet::new_unchecked(&*ip_packet_buf); + let total_len = ipv4_packet.total_len() as usize; + let payload_offset = ipv4_packet.header_len() as usize; + (&mut ip_packet_buf[payload_offset..], total_len) + } + SocketAddr::V6(_) => { + let ipv6_packet = Ipv6Packet::new_unchecked(&*ip_packet_buf); + let total_len = ipv6_packet.total_len(); + let payload_offset = IPV6_HEADER_LEN; + (&mut ip_packet_buf[payload_offset..], total_len) + } }; - ipv4.emit(&mut ipv4_packet, &ChecksumCapabilities::default()); - let mut tcp_packet = TcpPacket::new_unchecked(ipv4_packet.payload_mut()); + + let dst_ip_addr: IpAddress = self.ft.dst.ip().into(); + let src_ip_addr: IpAddress = self.ft.src.ip().into(); + let mut tcp_packet = TcpPacket::new_unchecked(tcp_payload_buf); tcp.emit( &mut tcp_packet, - &self.ft.dst.ip.into(), - &self.ft.src.ip.into(), + &dst_ip_addr, + &src_ip_addr, &ChecksumCapabilities::default(), ); - if let Some(payload) = payload { - for (b, c) in tcp_packet.payload_mut().iter_mut().zip(payload.iter()) { - *b = *c; - } - } - tcp_packet.fill_checksum(&self.ft.dst.ip.into(), &self.ft.src.ip.into()); - let n = ETHERNET_HEADER_LEN + ipv4_packet.total_len() as usize; - self.client.recv(&buffer[..n], &ChecksumState::TCP4); + + // Copy payload into TCP packet + copy_payload_into_buffer(tcp_packet.payload_mut(), payload); + tcp_packet.fill_checksum(&self.ft.dst.ip().into(), &self.ft.src.ip().into()); + let n = ETHERNET_HEADER_LEN + ip_total_len; + let checksum_state = match self.ft.dst { + SocketAddr::V4(_) => ChecksumState::TCP4, + SocketAddr::V6(_) => ChecksumState::TCP6, + }; + + self.client.recv(&buffer[..n], &checksum_state); } fn rst(&mut self, seq: TcpSeqNumber, ack: Option) { let tcp = TcpRepr { - src_port: self.ft.dst.port, - dst_port: self.ft.src.port, + src_port: self.ft.dst.port(), + dst_port: self.ft.src.port(), control: TcpControl::Rst, seq_number: seq, ack_number: ack, @@ -505,24 +566,28 @@ impl TcpConnection { let mut this = Self::default(); this.initialize_from_first_client_packet(tcp)?; - let socket = - Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).map_err(DropReason::Io)?; + let socket = Socket::new( + match sender.ft.dst { + SocketAddr::V4(_) => Domain::IPV4, + SocketAddr::V6(_) => Domain::IPV6, + }, + Type::STREAM, + Some(Protocol::TCP), + ) + .map_err(DropReason::Io)?; // On Windows the default behavior for non-existent loopback sockets is // to wait and try again. This is different than the Linux behavior of // immediately failing. Default to the Linux behavior. #[cfg(windows)] - if sender.ft.dst.ip.is_loopback() { + if sender.ft.dst.ip().is_loopback() { if let Err(err) = crate::windows::disable_connection_retries(&socket) { tracing::trace!(err, "Failed to disable loopback retries"); } } let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?; - match socket - .get() - .connect(&SockAddr::from(SocketAddrV4::from(sender.ft.dst))) - { + match socket.get().connect(&SockAddr::from(sender.ft.dst)) { Ok(_) => unreachable!(), Err(err) if is_connect_incomplete_error(&err) => (), Err(err) => { @@ -535,12 +600,17 @@ impl TcpConnection { } } if let Ok(addr) = socket.get().local_addr() { - if let Some(addr) = addr.as_socket_ipv4() { - if addr.ip().is_loopback() { - this.loopback_port = LoopbackPortInfo::ProxyForGuestPort { - sending_port: addr.port(), - guest_port: sender.ft.src.port, - }; + match addr.as_socket() { + None => { + tracing::warn!("unable to get local socket address"); + } + Some(addr) => { + if addr.ip().is_loopback() { + this.loopback_port = LoopbackPortInfo::ProxyForGuestPort { + sending_port: addr.port(), + guest_port: sender.ft.src.port(), + }; + } } } } @@ -789,8 +859,8 @@ impl TcpConnection { // to truncate this to its own MTU calculation. let max_seg_size = u16::MAX; let tcp = TcpRepr { - src_port: sender.ft.dst.port, - dst_port: sender.ft.src.port, + src_port: sender.ft.dst.port(), + dst_port: sender.ft.src.port(), control: TcpControl::Syn, seq_number: self.tx_send, ack_number, @@ -821,8 +891,8 @@ impl TcpConnection { } let mut tcp = TcpRepr { - src_port: sender.ft.dst.port, - dst_port: sender.ft.src.port, + src_port: sender.ft.dst.port(), + dst_port: sender.ft.src.port(), control: TcpControl::None, seq_number: self.tx_send, ack_number: Some(self.rx_seq), @@ -843,7 +913,11 @@ impl TcpConnection { // 3. The configured maximum segment size. // 4. The client MTU. let tx_segment_end = { - let header_len = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp.header_len(); + let ip_header_len = match sender.ft.dst { + SocketAddr::V4(_) => IPV4_HEADER_LEN, + SocketAddr::V6(_) => IPV6_HEADER_LEN, + }; + let header_len = ETHERNET_HEADER_LEN + ip_header_len + tcp.header_len(); let mtu = rx_mtu.min(sender.state.buffer.len()); seq_min([ tx_payload_end, @@ -915,8 +989,8 @@ impl TcpConnection { /// by the peer. fn ack(&self, sender: &mut Sender<'_, impl Client>) { let tcp = TcpRepr { - src_port: sender.ft.dst.port, - dst_port: sender.ft.src.port, + src_port: sender.ft.dst.port(), + dst_port: sender.ft.src.port(), control: TcpControl::None, seq_number: self.tx_send, ack_number: Some(self.rx_seq), @@ -1168,17 +1242,14 @@ impl TcpListener { fn poll_listener( &mut self, cx: &mut Context<'_>, - ) -> Result, DropReason> { + ) -> Result, DropReason> { match self.socket.poll_accept(cx) { Poll::Ready(r) => match r { Ok((socket, address)) => match address.as_socket() { Some(addr) => match address.as_socket_ipv4() { Some(src_address) => Ok(Some(( socket, - SocketAddress { - ip: (*src_address.ip()).into(), - port: addr.port(), - }, + SocketAddrV4::new(*src_address.ip(), addr.port()), ))), None => { tracing::warn!(?address, "Not an IPv4 address from accept"); diff --git a/vm/devices/net/net_consomme/consomme/src/udp.rs b/vm/devices/net/net_consomme/consomme/src/udp.rs index 29b259d444..0ed6c62ee8 100644 --- a/vm/devices/net/net_consomme/consomme/src/udp.rs +++ b/vm/devices/net/net_consomme/consomme/src/udp.rs @@ -4,11 +4,11 @@ use super::Access; use super::Client; use super::DropReason; -use super::SocketAddress; use super::dhcp::DHCP_SERVER; +use super::dhcpv6::DHCPV6_SERVER; use crate::ChecksumState; use crate::ConsommeState; -use crate::Ipv4Addresses; +use crate::IpAddresses; use inspect::Inspect; use inspect::InspectMut; use inspect_counters::Counter; @@ -22,9 +22,13 @@ use smoltcp::wire::EthernetFrame; use smoltcp::wire::EthernetProtocol; use smoltcp::wire::EthernetRepr; use smoltcp::wire::IPV4_HEADER_LEN; +use smoltcp::wire::IPV6_HEADER_LEN; +use smoltcp::wire::IpAddress; use smoltcp::wire::IpProtocol; +use smoltcp::wire::IpRepr; use smoltcp::wire::Ipv4Packet; -use smoltcp::wire::Ipv4Repr; +use smoltcp::wire::Ipv6Address; +use smoltcp::wire::Ipv6Packet; use smoltcp::wire::UDP_HEADER_LEN; use smoltcp::wire::UdpPacket; use smoltcp::wire::UdpRepr; @@ -33,12 +37,16 @@ use std::collections::hash_map; use std::io::ErrorKind; use std::net::IpAddr; use std::net::Ipv4Addr; +use std::net::Ipv6Addr; +use std::net::SocketAddr; +use std::net::SocketAddrV4; +use std::net::SocketAddrV6; use std::net::UdpSocket; use std::task::Context; use std::task::Poll; pub(crate) struct Udp { - connections: HashMap, + connections: HashMap, } impl Udp { @@ -53,7 +61,8 @@ impl InspectMut for Udp { fn inspect_mut(&mut self, req: inspect::Request<'_>) { let mut resp = req.respond(); for (addr, conn) in &mut self.connections { - resp.field_mut(&format!("{}:{}", addr.ip, addr.port), conn); + let key = addr.to_string(); + resp.field_mut(&key, conn); } } } @@ -81,7 +90,7 @@ impl UdpConnection { fn poll_conn( &mut self, cx: &mut Context<'_>, - dst_addr: &SocketAddress, + dst_addr: &SocketAddr, state: &mut ConsommeState, client: &mut impl Client, ) -> bool { @@ -99,6 +108,12 @@ impl UdpConnection { if client.rx_mtu() == 0 { break true; } + + let header_offset = match dst_addr { + SocketAddr::V4(_) => IPV4_HEADER_LEN + UDP_HEADER_LEN, + SocketAddr::V6(_) => IPV6_HEADER_LEN + UDP_HEADER_LEN, + }; + match self.socket.as_mut().unwrap().poll_io( cx, InterestSlot::Read, @@ -106,34 +121,58 @@ impl UdpConnection { |socket| { socket .get() - .recv_from(&mut eth.payload_mut()[IPV4_HEADER_LEN + UDP_HEADER_LEN..]) + .recv_from(&mut eth.payload_mut()[header_offset..]) }, ) { Poll::Ready(Ok((n, src_addr))) => { - let src_ip = if let IpAddr::V4(ip) = src_addr.ip() { - ip - } else { - unreachable!() - }; - eth.set_ethertype(EthernetProtocol::Ipv4); - eth.set_src_addr(state.params.gateway_mac); eth.set_dst_addr(self.guest_mac); - let mut ipv4 = Ipv4Packet::new_unchecked(eth.payload_mut()); - Ipv4Repr { - src_addr: src_ip.into(), - dst_addr: dst_addr.ip, - protocol: IpProtocol::Udp, - payload_len: UDP_HEADER_LEN + n, - hop_limit: 64, + eth.set_src_addr(state.params.gateway_mac); + let ip = IpRepr::new( + src_addr.ip().into(), + dst_addr.ip().into(), + IpProtocol::Udp, + UDP_HEADER_LEN + n, + 64, + ); + + match ip { + IpRepr::Ipv4(_) => eth.set_ethertype(EthernetProtocol::Ipv4), + IpRepr::Ipv6(_) => eth.set_ethertype(EthernetProtocol::Ipv6), } - .emit(&mut ipv4, &ChecksumCapabilities::default()); - let mut udp = UdpPacket::new_unchecked(ipv4.payload_mut()); - udp.set_src_port(src_addr.port()); - udp.set_dst_port(dst_addr.port); - udp.set_len((UDP_HEADER_LEN + n) as u16); - udp.fill_checksum(&src_ip.into(), &dst_addr.ip.into()); - let len = ETHERNET_HEADER_LEN + ipv4.total_len() as usize; - client.recv(ð.as_ref()[..len], &ChecksumState::UDP4); + + let ip_packet_buf = eth.payload_mut(); + ip.emit(&mut *ip_packet_buf, &ChecksumCapabilities::default()); + let (udp_payload_buf, ip_total_len) = match dst_addr { + SocketAddr::V4(_) => { + let ipv4_packet = Ipv4Packet::new_unchecked(&*ip_packet_buf); + let total_len = ipv4_packet.total_len() as usize; + let payload_offset = ipv4_packet.header_len() as usize; + (&mut ip_packet_buf[payload_offset..], total_len) + } + SocketAddr::V6(_) => { + let ipv6_packet = Ipv6Packet::new_unchecked(&*ip_packet_buf); + let total_len = ipv6_packet.total_len(); + let payload_offset = IPV6_HEADER_LEN; + (&mut ip_packet_buf[payload_offset..], total_len) + } + }; + + let dst_ip_addr: IpAddress = dst_addr.ip().into(); + let src_ip_addr: IpAddress = src_addr.ip().into(); + let mut udp_packet = UdpPacket::new_unchecked(udp_payload_buf); + udp_packet.set_src_port(src_addr.port()); + udp_packet.set_dst_port(dst_addr.port()); + udp_packet.set_len((UDP_HEADER_LEN + n) as u16); + udp_packet.fill_checksum(&src_ip_addr, &dst_ip_addr); + + let packet_len = ETHERNET_HEADER_LEN + ip_total_len; + let checksum_state = match dst_addr { + SocketAddr::V4(_) => ChecksumState::UDP4, + SocketAddr::V6(_) => ChecksumState::NONE, + }; + + // Send packet to client + client.recv(ð.as_ref()[..packet_len], &checksum_state); self.stats.rx_packets.increment(); } Poll::Ready(Err(err)) => { @@ -175,36 +214,74 @@ impl Access<'_, T> { pub(crate) fn handle_udp( &mut self, frame: &EthernetRepr, - addresses: &Ipv4Addresses, + addresses: &IpAddresses, payload: &[u8], checksum: &ChecksumState, ) -> Result<(), DropReason> { let udp_packet = UdpPacket::new_checked(payload)?; - let udp = UdpRepr::parse( - &udp_packet, - &addresses.src_addr.into(), - &addresses.dst_addr.into(), - &checksum.caps(), - )?; - - if addresses.dst_addr == self.inner.state.params.gateway_ip - || addresses.dst_addr.is_broadcast() - { - if self.handle_gateway_udp(&udp_packet)? { - return Ok(()); + + // Parse UDP header and check gateway handling + let (guest_addr, dst_sock_addr) = match addresses { + IpAddresses::V4(addrs) => { + let udp = UdpRepr::parse( + &udp_packet, + &addrs.src_addr.into(), + &addrs.dst_addr.into(), + &checksum.caps(), + )?; + + // Check for gateway-destined packets + if addrs.dst_addr == self.inner.state.params.gateway_ip + || addrs.dst_addr.is_broadcast() + { + if self.handle_gateway_udp(&udp_packet)? { + return Ok(()); + } + } + + let guest_addr = + SocketAddr::V4(SocketAddrV4::new(addrs.src_addr.into(), udp.src_port)); + + let dst_sock_addr = + SocketAddr::V4(SocketAddrV4::new(addrs.dst_addr.into(), udp.dst_port)); + + (guest_addr, dst_sock_addr) } - } + IpAddresses::V6(addrs) => { + let udp = UdpRepr::parse( + &udp_packet, + &addrs.src_addr.into(), + &addrs.dst_addr.into(), + &checksum.caps(), + )?; - let guest_addr = SocketAddress { - ip: addresses.src_addr, - port: udp.src_port, + // Check for gateway-destined packets (IPv6 uses multicast instead of broadcast) + if addrs.dst_addr == self.inner.state.params.gateway_link_local_ipv6 + || addrs.dst_addr.0[0..2] == [0xff, 0x02] + { + if self.handle_gateway_udp_v6(&udp_packet, Some(addrs.src_addr))? { + return Ok(()); + } + } + + let guest_addr = + SocketAddr::V6(SocketAddrV6::new(addrs.src_addr.into(), udp.src_port, 0, 0)); + + let dst_sock_addr = + SocketAddr::V6(SocketAddrV6::new(addrs.dst_addr.into(), udp.dst_port, 0, 0)); + + (guest_addr, dst_sock_addr) + } }; - let conn = self.get_or_insert(guest_addr, None, Some(frame.src_addr))?; - match conn.socket.as_mut().unwrap().get().send_to( - udp_packet.payload(), - (Ipv4Addr::from(addresses.dst_addr), udp.dst_port), - ) { + let conn = self.get_or_insert(guest_addr, Some(frame.src_addr))?; + match conn + .socket + .as_mut() + .unwrap() + .get() + .send_to(udp_packet.payload(), dst_sock_addr) + { Ok(_) => { conn.stats.tx_packets.increment(); Ok(()) @@ -222,16 +299,23 @@ impl Access<'_, T> { fn get_or_insert( &mut self, - guest_addr: SocketAddress, - host_addr: Option, + guest_addr: SocketAddr, guest_mac: Option, ) -> Result<&mut UdpConnection, DropReason> { let entry = self.inner.udp.connections.entry(guest_addr); match entry { hash_map::Entry::Occupied(conn) => Ok(conn.into_mut()), hash_map::Entry::Vacant(e) => { - let socket = UdpSocket::bind((host_addr.unwrap_or(Ipv4Addr::UNSPECIFIED), 0)) - .map_err(DropReason::Io)?; + let bind_addr: SocketAddr = match guest_addr { + SocketAddr::V4(_) => { + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)) + } + SocketAddr::V6(_) => { + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)) + } + }; + + let socket = UdpSocket::bind(bind_addr).map_err(DropReason::Io)?; let socket = PolledSocket::new(self.client.driver(), socket).map_err(DropReason::Io)?; let conn = UdpConnection { @@ -256,30 +340,46 @@ impl Access<'_, T> { } } + fn handle_gateway_udp_v6( + &mut self, + udp: &UdpPacket<&[u8]>, + client_ip: Option, + ) -> Result { + let payload = udp.payload(); + match udp.dst_port() { + DHCPV6_SERVER => { + self.handle_dhcpv6(payload, client_ip)?; + Ok(true) + } + _ => Ok(false), + } + } + /// Binds to the specified host IP and port for forwarding inbound UDP /// packets to the guest. - pub fn bind_udp_port( - &mut self, - ip_addr: Option, - port: u16, - ) -> Result<(), DropReason> { - let guest_addr = SocketAddress { - ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(), - port, + pub fn bind_udp_port(&mut self, ip_addr: Option, port: u16) -> Result<(), DropReason> { + let guest_addr = match ip_addr { + Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)), + Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), + None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)), }; - let _ = self.get_or_insert(guest_addr, ip_addr, None)?; + let _ = self.get_or_insert(guest_addr, None)?; Ok(()) } - /// Unbinds from the specified host port. + /// Unbinds from the specified host port for both IPv4 and IPv6. pub fn unbind_udp_port(&mut self, port: u16) -> Result<(), DropReason> { - let guest_addr = SocketAddress { - ip: Ipv4Addr::UNSPECIFIED.into(), - port, - }; - match self.inner.udp.connections.remove(&guest_addr) { - Some(_) => Ok(()), - None => Err(DropReason::PortNotBound), + // Try to remove both IPv4 and IPv6 bindings + let v4_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)); + let v6_addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0)); + + let v4_removed = self.inner.udp.connections.remove(&v4_addr).is_some(); + let v6_removed = self.inner.udp.connections.remove(&v6_addr).is_some(); + + if v4_removed || v6_removed { + Ok(()) + } else { + Err(DropReason::PortNotBound) } } } diff --git a/vm/devices/net/net_consomme/src/lib.rs b/vm/devices/net/net_consomme/src/lib.rs index 290ca4fe5f..511d9e5753 100644 --- a/vm/devices/net/net_consomme/src/lib.rs +++ b/vm/devices/net/net_consomme/src/lib.rs @@ -31,7 +31,7 @@ use net_backend::TxSegmentType; use pal_async::driver::Driver; use parking_lot::Mutex; use std::collections::VecDeque; -use std::net::Ipv4Addr; +use std::net::IpAddr; use std::sync::Arc; use std::task::Context; use std::task::Poll; @@ -107,7 +107,7 @@ pub enum IpProtocol { struct MessageBindPort { protocol: IpProtocol, - address: Option, + address: Option, port: u16, } @@ -122,7 +122,7 @@ impl ConsommeControl { pub async fn bind_port( &self, protocol: IpProtocol, - ip_addr: Option, + ip_addr: Option, port: u16, ) -> Result<(), ConsommeMessageError> { self.send @@ -352,11 +352,15 @@ impl net_backend::Queue for ConsommeQueue { consomme::DropReason::UnsupportedEthertype(_) | consomme::DropReason::UnsupportedIpProtocol(_) | consomme::DropReason::UnsupportedDhcp(_) - | consomme::DropReason::UnsupportedArp => self.stats.tx_unknown.increment(), + | consomme::DropReason::UnsupportedArp + | consomme::DropReason::UnsupportedDhcpv6(_) + | consomme::DropReason::UnsupportedNdp(_) => self.stats.tx_unknown.increment(), consomme::DropReason::Packet(_) | consomme::DropReason::Ipv4Checksum | consomme::DropReason::Io(_) - | consomme::DropReason::BadTcpState(_) => self.stats.tx_errors.increment(), + | consomme::DropReason::BadTcpState(_) + | consomme::DropReason::FragmentedPacket + | consomme::DropReason::MalformedPacket => self.stats.tx_errors.increment(), consomme::DropReason::PortNotBound => unreachable!(), } }