rustix/backend/linux_raw/net/
read_sockaddr.rs

1//! The BSD sockets API requires us to read the `ss_family` field before we can
2//! interpret the rest of a `sockaddr` produced by the kernel.
3#![allow(unsafe_code)]
4
5use crate::backend::c;
6use crate::io;
7#[cfg(target_os = "linux")]
8use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
9use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6};
10use core::mem::size_of;
11use core::slice;
12
// This must match the header of `sockaddr`.
//
// Only the leading address-family field is modelled. `read_ss_family` and
// `initialize_family_to_unspec` cast `c::sockaddr` pointers to this type to
// access that field; `#[repr(C)]` requests C-compatible field layout for
// that cast.
#[repr(C)]
struct sockaddr_header {
    // The address family, stored as a 16-bit value.
    ss_family: u16,
}
18
19/// Read the `ss_family` field from a socket address returned from the OS.
20///
21/// # Safety
22///
23/// `storage` must point to a valid socket address returned from the OS.
24#[inline]
25unsafe fn read_ss_family(storage: *const c::sockaddr) -> u16 {
26    // Assert that we know the layout of `sockaddr`.
27    let _ = c::sockaddr {
28        __storage: c::sockaddr_storage {
29            __bindgen_anon_1: linux_raw_sys::net::__kernel_sockaddr_storage__bindgen_ty_1 {
30                __bindgen_anon_1:
31                    linux_raw_sys::net::__kernel_sockaddr_storage__bindgen_ty_1__bindgen_ty_1 {
32                        ss_family: 0_u16,
33                        __data: [0; 126_usize],
34                    },
35            },
36        },
37    };
38
39    (*storage.cast::<sockaddr_header>()).ss_family
40}
41
42/// Set the `ss_family` field of a socket address to `AF_UNSPEC`, so that we
43/// can test for `AF_UNSPEC` to test whether it was stored to.
44#[inline]
45pub(crate) unsafe fn initialize_family_to_unspec(storage: *mut c::sockaddr) {
46    (*storage.cast::<sockaddr_header>()).ss_family = c::AF_UNSPEC as _;
47}
48
49/// Read a socket address encoded in a platform-specific format.
50///
51/// # Safety
52///
53/// `storage` must point to valid socket address storage.
54pub(crate) unsafe fn read_sockaddr(
55    storage: *const c::sockaddr,
56    len: usize,
57) -> io::Result<SocketAddrAny> {
58    let offsetof_sun_path = super::addr::offsetof_sun_path();
59
60    if len < size_of::<c::sa_family_t>() {
61        return Err(io::Errno::INVAL);
62    }
63    match read_ss_family(storage).into() {
64        c::AF_INET => {
65            if len < size_of::<c::sockaddr_in>() {
66                return Err(io::Errno::INVAL);
67            }
68            let decode = &*storage.cast::<c::sockaddr_in>();
69            Ok(SocketAddrAny::V4(SocketAddrV4::new(
70                Ipv4Addr::from(u32::from_be(decode.sin_addr.s_addr)),
71                u16::from_be(decode.sin_port),
72            )))
73        }
74        c::AF_INET6 => {
75            if len < size_of::<c::sockaddr_in6>() {
76                return Err(io::Errno::INVAL);
77            }
78            let decode = &*storage.cast::<c::sockaddr_in6>();
79            Ok(SocketAddrAny::V6(SocketAddrV6::new(
80                Ipv6Addr::from(decode.sin6_addr.in6_u.u6_addr8),
81                u16::from_be(decode.sin6_port),
82                u32::from_be(decode.sin6_flowinfo),
83                decode.sin6_scope_id,
84            )))
85        }
86        c::AF_UNIX => {
87            if len < offsetof_sun_path {
88                return Err(io::Errno::INVAL);
89            }
90            if len == offsetof_sun_path {
91                Ok(SocketAddrAny::Unix(SocketAddrUnix::new(&[][..])?))
92            } else {
93                let decode = &*storage.cast::<c::sockaddr_un>();
94
95                // On Linux check for Linux's [abstract namespace].
96                //
97                // [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
98                if decode.sun_path[0] == 0 {
99                    let bytes = &decode.sun_path[1..len - offsetof_sun_path];
100
101                    // SAFETY: Convert `&[c_char]` to `&[u8]`.
102                    let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
103
104                    return SocketAddrUnix::new_abstract_name(bytes).map(SocketAddrAny::Unix);
105                }
106
107                // Otherwise we expect a NUL-terminated filesystem path.
108                let bytes = &decode.sun_path[..len - 1 - offsetof_sun_path];
109
110                // SAFETY: Convert `&[c_char]` to `&[u8]`.
111                let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
112
113                assert_eq!(decode.sun_path[len - 1 - offsetof_sun_path], 0);
114                Ok(SocketAddrAny::Unix(SocketAddrUnix::new(bytes)?))
115            }
116        }
117        #[cfg(target_os = "linux")]
118        c::AF_XDP => {
119            if len < size_of::<c::sockaddr_xdp>() {
120                return Err(io::Errno::INVAL);
121            }
122            let decode = &*storage.cast::<c::sockaddr_xdp>();
123            Ok(SocketAddrAny::Xdp(SocketAddrXdp::new(
124                SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
125                u32::from_be(decode.sxdp_ifindex),
126                u32::from_be(decode.sxdp_queue_id),
127                u32::from_be(decode.sxdp_shared_umem_fd),
128            )))
129        }
130        _ => Err(io::Errno::NOTSUP),
131    }
132}
133
134/// Read an optional socket address returned from the OS.
135///
136/// # Safety
137///
138/// `storage` must point to a valid socket address returned from the OS.
139pub(crate) unsafe fn maybe_read_sockaddr_os(
140    storage: *const c::sockaddr,
141    len: usize,
142) -> Option<SocketAddrAny> {
143    if len == 0 {
144        None
145    } else {
146        Some(read_sockaddr_os(storage, len))
147    }
148}
149
150/// Read a socket address returned from the OS.
151///
152/// # Safety
153///
154/// `storage` must point to a valid socket address returned from the OS.
155pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) -> SocketAddrAny {
156    let offsetof_sun_path = super::addr::offsetof_sun_path();
157
158    assert!(len >= size_of::<c::sa_family_t>());
159    match read_ss_family(storage).into() {
160        c::AF_INET => {
161            assert!(len >= size_of::<c::sockaddr_in>());
162            let decode = &*storage.cast::<c::sockaddr_in>();
163            SocketAddrAny::V4(SocketAddrV4::new(
164                Ipv4Addr::from(u32::from_be(decode.sin_addr.s_addr)),
165                u16::from_be(decode.sin_port),
166            ))
167        }
168        c::AF_INET6 => {
169            assert!(len >= size_of::<c::sockaddr_in6>());
170            let decode = &*storage.cast::<c::sockaddr_in6>();
171            SocketAddrAny::V6(SocketAddrV6::new(
172                Ipv6Addr::from(decode.sin6_addr.in6_u.u6_addr8),
173                u16::from_be(decode.sin6_port),
174                u32::from_be(decode.sin6_flowinfo),
175                decode.sin6_scope_id,
176            ))
177        }
178        c::AF_UNIX => {
179            assert!(len >= offsetof_sun_path);
180            if len == offsetof_sun_path {
181                SocketAddrAny::Unix(SocketAddrUnix::new(&[][..]).unwrap())
182            } else {
183                let decode = &*storage.cast::<c::sockaddr_un>();
184
185                // On Linux check for Linux's [abstract namespace].
186                //
187                // [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
188                if decode.sun_path[0] == 0 {
189                    let bytes = &decode.sun_path[1..len - offsetof_sun_path];
190
191                    // SAFETY: Convert `&[c_char]` to `&[u8]`.
192                    let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
193
194                    return SocketAddrAny::Unix(SocketAddrUnix::new_abstract_name(bytes).unwrap());
195                }
196
197                // Otherwise we expect a NUL-terminated filesystem path.
198                assert_eq!(decode.sun_path[len - 1 - offsetof_sun_path], 0);
199
200                let bytes = &decode.sun_path[..len - 1 - offsetof_sun_path];
201
202                // SAFETY: Convert `&[c_char]` to `&[u8]`.
203                let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
204
205                SocketAddrAny::Unix(SocketAddrUnix::new(bytes).unwrap())
206            }
207        }
208        #[cfg(target_os = "linux")]
209        c::AF_XDP => {
210            assert!(len >= size_of::<c::sockaddr_xdp>());
211            let decode = &*storage.cast::<c::sockaddr_xdp>();
212            SocketAddrAny::Xdp(SocketAddrXdp::new(
213                SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
214                u32::from_be(decode.sxdp_ifindex),
215                u32::from_be(decode.sxdp_queue_id),
216                u32::from_be(decode.sxdp_shared_umem_fd),
217            ))
218        }
219        other => unimplemented!("{:?}", other),
220    }
221}