summaryrefslogtreecommitdiff
path: root/vendor/rustix/src/net/socket_addr_any.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/rustix/src/net/socket_addr_any.rs')
-rw-r--r--vendor/rustix/src/net/socket_addr_any.rs344
1 files changed, 344 insertions, 0 deletions
diff --git a/vendor/rustix/src/net/socket_addr_any.rs b/vendor/rustix/src/net/socket_addr_any.rs
new file mode 100644
index 00000000..7a953044
--- /dev/null
+++ b/vendor/rustix/src/net/socket_addr_any.rs
@@ -0,0 +1,344 @@
+//! The [`SocketAddrAny`] type and related utilities.
+
+#![allow(unsafe_code)]
+
+use crate::backend::c;
+use crate::backend::net::read_sockaddr;
+use crate::io::Errno;
+use crate::net::addr::{SocketAddrArg, SocketAddrLen, SocketAddrOpaque, SocketAddrStorage};
+#[cfg(unix)]
+use crate::net::SocketAddrUnix;
+use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6};
+use core::fmt;
+use core::mem::{size_of, MaybeUninit};
+use core::num::NonZeroU32;
+
+/// Temporary buffer for creating a `SocketAddrAny` from a syscall that writes
+/// to a `sockaddr_t` and `socklen_t`
+///
+/// Unlike `SocketAddrAny`, this does not maintain the invariant that `len`
+/// bytes are initialized.
+pub(crate) struct SocketAddrBuf {
+ pub(crate) len: c::socklen_t,
+ pub(crate) storage: MaybeUninit<SocketAddrStorage>,
+}
+
+impl SocketAddrBuf {
+ #[inline]
+ pub(crate) const fn new() -> Self {
+ Self {
+ len: size_of::<SocketAddrStorage>() as c::socklen_t,
+ storage: MaybeUninit::<SocketAddrStorage>::uninit(),
+ }
+ }
+
+ /// Convert the buffer into [`SocketAddrAny`].
+ ///
+ /// # Safety
+ ///
+ /// A valid address must have been written into `self.storage` and its
+ /// length written into `self.len`.
+ #[inline]
+ pub(crate) unsafe fn into_any(self) -> SocketAddrAny {
+ SocketAddrAny::new(self.storage, bitcast!(self.len))
+ }
+
+ /// Convert the buffer into [`Option<SocketAddrAny>`].
+ ///
+ /// This returns `None` if `len` is zero or other platform-specific
+ /// conditions define the address as empty.
+ ///
+ /// # Safety
+ ///
+ /// Either valid address must have been written into `self.storage` and its
+ /// length written into `self.len`, or `self.len` must have been set to 0.
+ #[inline]
+ pub(crate) unsafe fn into_any_option(self) -> Option<SocketAddrAny> {
+ let len = bitcast!(self.len);
+ if read_sockaddr::sockaddr_nonempty(self.storage.as_ptr().cast(), len) {
+ Some(SocketAddrAny::new(self.storage, len))
+ } else {
+ None
+ }
+ }
+}
+
+/// A type that can hold any kind of socket address, as a safe abstraction for
+/// `sockaddr_storage`.
+///
+/// Socket addresses can be converted to `SocketAddrAny` via the [`From`] and
+/// [`Into`] traits. `SocketAddrAny` can be converted back to a specific socket
+/// address type with [`TryFrom`] and [`TryInto`]. These implementations return
+/// [`Errno::AFNOSUPPORT`] if the address family does not match the requested
+/// type.
+#[derive(Clone)]
+#[doc(alias = "sockaddr_storage")]
+pub struct SocketAddrAny {
+ // Invariants:
+ // - `len` is at least `size_of::<backend::c::sa_family_t>()`
+ // - `len` is at most `size_of::<SocketAddrStorage>()`
+ // - The first `len` bytes of `storage` are initialized.
+ pub(crate) len: NonZeroU32,
+ pub(crate) storage: MaybeUninit<SocketAddrStorage>,
+}
+
+impl SocketAddrAny {
+ /// Creates a socket address from `storage`, which is initialized for `len`
+ /// bytes.
+ ///
+ /// # Panics
+ ///
+ /// if `len` is smaller than the sockaddr header or larger than
+ /// `SocketAddrStorage`.
+ ///
+ /// # Safety
+ ///
+ /// - `storage` must contain a valid socket address.
+ /// - `len` bytes must be initialized.
+ #[inline]
+ pub const unsafe fn new(storage: MaybeUninit<SocketAddrStorage>, len: SocketAddrLen) -> Self {
+ assert!(len as usize >= size_of::<read_sockaddr::sockaddr_header>());
+ assert!(len as usize <= size_of::<SocketAddrStorage>());
+ let len = NonZeroU32::new_unchecked(len);
+ Self { storage, len }
+ }
+
+ /// Creates a socket address from reading from `ptr`, which points at `len`
+ /// initialized bytes.
+ ///
+ /// # Panics
+ ///
+ /// if `len` is smaller than the sockaddr header or larger than
+ /// `SocketAddrStorage`.
+ ///
+ /// # Safety
+ ///
+ /// - `ptr` must be a pointer to memory containing a valid socket address.
+ /// - `len` bytes must be initialized.
+ pub unsafe fn read(ptr: *const SocketAddrStorage, len: SocketAddrLen) -> Self {
+ assert!(len as usize >= size_of::<read_sockaddr::sockaddr_header>());
+ assert!(len as usize <= size_of::<SocketAddrStorage>());
+ let mut storage = MaybeUninit::<SocketAddrStorage>::uninit();
+ core::ptr::copy_nonoverlapping(
+ ptr.cast::<u8>(),
+ storage.as_mut_ptr().cast::<u8>(),
+ len as usize,
+ );
+ let len = NonZeroU32::new_unchecked(len);
+ Self { storage, len }
+ }
+
+ /// Gets the initialized part of the storage as bytes.
+ #[inline]
+ fn bytes(&self) -> &[u8] {
+ let len = self.len.get() as usize;
+ unsafe { core::slice::from_raw_parts(self.storage.as_ptr().cast(), len) }
+ }
+
+ /// Gets the address family of this socket address.
+ #[inline]
+ pub fn address_family(&self) -> AddressFamily {
+ // SAFETY: Our invariants maintain that the `sa_family` field is
+ // initialized.
+ unsafe {
+ AddressFamily::from_raw(crate::backend::net::read_sockaddr::read_sa_family(
+ self.storage.as_ptr().cast(),
+ ))
+ }
+ }
+
+ /// Returns a raw pointer to the sockaddr.
+ #[inline]
+ pub fn as_ptr(&self) -> *const SocketAddrStorage {
+ self.storage.as_ptr()
+ }
+
+ /// Returns a raw mutable pointer to the sockaddr.
+ #[inline]
+ pub fn as_mut_ptr(&mut self) -> *mut SocketAddrStorage {
+ self.storage.as_mut_ptr()
+ }
+
+ /// Returns the length of the encoded sockaddr.
+ #[inline]
+ pub fn addr_len(&self) -> SocketAddrLen {
+ self.len.get()
+ }
+}
+
+impl PartialEq<Self> for SocketAddrAny {
+ fn eq(&self, other: &Self) -> bool {
+ self.bytes() == other.bytes()
+ }
+}
+
+impl Eq for SocketAddrAny {}
+
+// This just forwards to another `partial_cmp`.
+#[allow(clippy::non_canonical_partial_ord_impl)]
+impl PartialOrd<Self> for SocketAddrAny {
+ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+ self.bytes().partial_cmp(other.bytes())
+ }
+}
+
+impl Ord for SocketAddrAny {
+ fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+ self.bytes().cmp(other.bytes())
+ }
+}
+
+impl core::hash::Hash for SocketAddrAny {
+ fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+ self.bytes().hash(state)
+ }
+}
+
+impl fmt::Debug for SocketAddrAny {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self.address_family() {
+ AddressFamily::INET => {
+ if let Ok(addr) = SocketAddrV4::try_from(self.clone()) {
+ return addr.fmt(f);
+ }
+ }
+ AddressFamily::INET6 => {
+ if let Ok(addr) = SocketAddrV6::try_from(self.clone()) {
+ return addr.fmt(f);
+ }
+ }
+ #[cfg(unix)]
+ AddressFamily::UNIX => {
+ if let Ok(addr) = SocketAddrUnix::try_from(self.clone()) {
+ return addr.fmt(f);
+ }
+ }
+ #[cfg(target_os = "linux")]
+ AddressFamily::XDP => {
+ if let Ok(addr) = crate::net::xdp::SocketAddrXdp::try_from(self.clone()) {
+ return addr.fmt(f);
+ }
+ }
+ #[cfg(linux_kernel)]
+ AddressFamily::NETLINK => {
+ if let Ok(addr) = crate::net::netlink::SocketAddrNetlink::try_from(self.clone()) {
+ return addr.fmt(f);
+ }
+ }
+ _ => {}
+ }
+
+ f.debug_struct("SocketAddrAny")
+ .field("address_family", &self.address_family())
+ .field("namelen", &self.addr_len())
+ .finish()
+ }
+}
+
+// SAFETY: `with_sockaddr` calls `f` with a pointer to its own storage.
+unsafe impl SocketAddrArg for SocketAddrAny {
+ unsafe fn with_sockaddr<R>(
+ &self,
+ f: impl FnOnce(*const SocketAddrOpaque, SocketAddrLen) -> R,
+ ) -> R {
+ f(self.as_ptr().cast(), self.addr_len())
+ }
+}
+
+impl From<SocketAddr> for SocketAddrAny {
+ #[inline]
+ fn from(from: SocketAddr) -> Self {
+ from.as_any()
+ }
+}
+
+impl TryFrom<SocketAddrAny> for SocketAddr {
+ type Error = Errno;
+
+ /// Convert if the address is an IPv4 or IPv6 address.
+ ///
+ /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv4 or
+ /// IPv6.
+ #[inline]
+ fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
+ match value.address_family() {
+ AddressFamily::INET => read_sockaddr::read_sockaddr_v4(&value).map(SocketAddr::V4),
+ AddressFamily::INET6 => read_sockaddr::read_sockaddr_v6(&value).map(SocketAddr::V6),
+ _ => Err(Errno::AFNOSUPPORT),
+ }
+ }
+}
+
+impl From<SocketAddrV4> for SocketAddrAny {
+ #[inline]
+ fn from(from: SocketAddrV4) -> Self {
+ from.as_any()
+ }
+}
+
+impl TryFrom<SocketAddrAny> for SocketAddrV4 {
+ type Error = Errno;
+
+ /// Convert if the address is an IPv4 address.
+ ///
+ /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv4.
+ #[inline]
+ fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
+ read_sockaddr::read_sockaddr_v4(&value)
+ }
+}
+
+impl From<SocketAddrV6> for SocketAddrAny {
+ #[inline]
+ fn from(from: SocketAddrV6) -> Self {
+ from.as_any()
+ }
+}
+
+impl TryFrom<SocketAddrAny> for SocketAddrV6 {
+ type Error = Errno;
+
+ /// Convert if the address is an IPv6 address.
+ ///
+ /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv6.
+ #[inline]
+ fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
+ read_sockaddr::read_sockaddr_v6(&value)
+ }
+}
+
+#[cfg(unix)]
+impl From<SocketAddrUnix> for SocketAddrAny {
+ #[inline]
+ fn from(from: SocketAddrUnix) -> Self {
+ from.as_any()
+ }
+}
+
+#[cfg(unix)]
+impl TryFrom<SocketAddrAny> for SocketAddrUnix {
+ type Error = Errno;
+
+ /// Convert if the address is a Unix socket address.
+ ///
+ /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not Unix.
+ #[inline]
+ fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
+ read_sockaddr::read_sockaddr_unix(&value)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn any_read() {
+ let localhost = std::net::Ipv6Addr::LOCALHOST;
+ let addr = SocketAddrAny::from(SocketAddrV6::new(localhost, 7, 8, 9));
+ unsafe {
+ let same = SocketAddrAny::read(addr.as_ptr(), addr.addr_len());
+ assert_eq!(addr, same);
+ }
+ }
+}