Diffstat (limited to 'rust/kernel')
145 files changed, 44453 insertions, 456 deletions
diff --git a/rust/kernel/.gitignore b/rust/kernel/.gitignore new file mode 100644 index 000000000000..f636ad95aaf3 --- /dev/null +++ b/rust/kernel/.gitignore @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: GPL-2.0 + +/generated_arch_static_branch_asm.rs +/generated_arch_warn_asm.rs +/generated_arch_reachable_asm.rs diff --git a/rust/kernel/acpi.rs b/rust/kernel/acpi.rs new file mode 100644 index 000000000000..9b8efa623130 --- /dev/null +++ b/rust/kernel/acpi.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Advanced Configuration and Power Interface abstractions. + +use crate::{ + bindings, + device_id::{RawDeviceId, RawDeviceIdIndex}, + prelude::*, +}; + +/// IdTable type for ACPI drivers. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// An ACPI device id. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::acpi_device_id); + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `acpi_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::acpi_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::acpi_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +impl DeviceId { + const ACPI_ID_LEN: usize = 16; + + /// Create a new device id from an ACPI 'id' string. + #[inline(always)] + pub const fn new(id: &'static CStr) -> Self { + let src = id.to_bytes_with_nul(); + build_assert!(src.len() <= Self::ACPI_ID_LEN, "ID exceeds 16 bytes"); + let mut acpi: bindings::acpi_device_id = pin_init::zeroed(); + let mut i = 0; + while i < src.len() { + acpi.id[i] = src[i]; + i += 1; + } + + Self(acpi) + } +} + +/// Create an ACPI `IdTable` with an "alias" for modpost. +#[macro_export] +macro_rules! acpi_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::acpi::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("acpi", $module_table_name, $table_name); + }; +} diff --git a/rust/kernel/alloc.rs b/rust/kernel/alloc.rs new file mode 100644 index 000000000000..e38720349dcf --- /dev/null +++ b/rust/kernel/alloc.rs @@ -0,0 +1,269 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Implementation of the kernel's memory allocation infrastructure. + +pub mod allocator; +pub mod kbox; +pub mod kvec; +pub mod layout; + +pub use self::kbox::Box; +pub use self::kbox::KBox; +pub use self::kbox::KVBox; +pub use self::kbox::VBox; + +pub use self::kvec::IntoIter; +pub use self::kvec::KVVec; +pub use self::kvec::KVec; +pub use self::kvec::VVec; +pub use self::kvec::Vec; + +/// Indicates an allocation error. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct AllocError; + +use crate::error::{code::EINVAL, Result}; +use core::{alloc::Layout, ptr::NonNull}; + +/// Flags to be used when allocating memory. +/// +/// They can be combined with the operators `|`, `&`, and `!`. +/// +/// Values can be used from the [`flags`] module. +#[derive(Clone, Copy, PartialEq)] +pub struct Flags(u32); + +impl Flags { + /// Get the raw representation of this flag. + pub(crate) fn as_raw(self) -> u32 { + self.0 + } + + /// Check whether `flags` is contained in `self`. 
+ pub fn contains(self, flags: Flags) -> bool { + (self & flags) == flags + } +} + +impl core::ops::BitOr for Flags { + type Output = Self; + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } +} + +impl core::ops::BitAnd for Flags { + type Output = Self; + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } +} + +impl core::ops::Not for Flags { + type Output = Self; + fn not(self) -> Self::Output { + Self(!self.0) + } +} + +/// Allocation flags. +/// +/// These are meant to be used in functions that can allocate memory. +pub mod flags { + use super::Flags; + + /// Zeroes out the allocated memory. + /// + /// This is normally or'd with other flags. + pub const __GFP_ZERO: Flags = Flags(bindings::__GFP_ZERO); + + /// Allow the allocation to be in high memory. + /// + /// Allocations in high memory may not be mapped into the kernel's address space, so this can't + /// be used with `kmalloc` and other similar methods. + /// + /// This is normally or'd with other flags. + pub const __GFP_HIGHMEM: Flags = Flags(bindings::__GFP_HIGHMEM); + + /// Users can not sleep and need the allocation to succeed. + /// + /// A lower watermark is applied to allow access to "atomic reserves". The current + /// implementation doesn't support NMI and few other strict non-preemptive contexts (e.g. + /// `raw_spin_lock`). The same applies to [`GFP_NOWAIT`]. + pub const GFP_ATOMIC: Flags = Flags(bindings::GFP_ATOMIC); + + /// Typical for kernel-internal allocations. The caller requires `ZONE_NORMAL` or a lower zone + /// for direct access but can direct reclaim. + pub const GFP_KERNEL: Flags = Flags(bindings::GFP_KERNEL); + + /// The same as [`GFP_KERNEL`], except the allocation is accounted to kmemcg. + pub const GFP_KERNEL_ACCOUNT: Flags = Flags(bindings::GFP_KERNEL_ACCOUNT); + + /// For kernel allocations that should not stall for direct reclaim, start physical IO or + /// use any filesystem callback. It is very likely to fail to allocate memory, even for very + /// small allocations. + pub const GFP_NOWAIT: Flags = Flags(bindings::GFP_NOWAIT); + + /// Suppresses allocation failure reports. + /// + /// This is normally or'd with other flags. + pub const __GFP_NOWARN: Flags = Flags(bindings::__GFP_NOWARN); +} + +/// Non Uniform Memory Access (NUMA) node identifier. +#[derive(Clone, Copy, PartialEq)] +pub struct NumaNode(i32); + +impl NumaNode { + /// Create a new NUMA node identifier (non-negative integer). + /// + /// Returns [`EINVAL`] if a negative id or an id exceeding [`bindings::MAX_NUMNODES`] is + /// specified. + pub fn new(node: i32) -> Result<Self> { + // MAX_NUMNODES never exceeds 2**10 because NODES_SHIFT is 0..10. + if node < 0 || node >= bindings::MAX_NUMNODES as i32 { + return Err(EINVAL); + } + Ok(Self(node)) + } +} + +/// Specify necessary constant to pass the information to Allocator that the caller doesn't care +/// about the NUMA node to allocate memory from. +impl NumaNode { + /// No node preference. + pub const NO_NODE: NumaNode = NumaNode(bindings::NUMA_NO_NODE); +} + +/// The kernel's [`Allocator`] trait. +/// +/// An implementation of [`Allocator`] can allocate, re-allocate and free memory buffers described +/// via [`Layout`]. +/// +/// [`Allocator`] is designed to be implemented as a ZST; [`Allocator`] functions do not operate on +/// an object instance. 
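As an aside on the `Flags` operators and GFP constants above: masks compose with `|` just as in C, and `contains` tests whether every bit of the operand is set. A minimal doctest-style sketch (assuming the flag constants are in scope, as they are via the kernel prelude):

```rust
use kernel::alloc::flags::{GFP_KERNEL, __GFP_NOWARN, __GFP_ZERO};

// Request zeroed memory in a sleepable (process) context.
let flags = GFP_KERNEL | __GFP_ZERO;

// `contains` checks that every bit of the operand is also set in `self`.
assert!(flags.contains(GFP_KERNEL));
assert!(flags.contains(__GFP_ZERO));
assert!(!flags.contains(__GFP_NOWARN));
```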
+/// +/// In order to be able to support `#[derive(CoercePointee)]` later on, we need to avoid a design +/// that requires an `Allocator` to be instantiated, hence its functions must not contain any kind +/// of `self` parameter. +/// +/// # Safety +/// +/// - A memory allocation returned from an allocator must remain valid until it is explicitly freed. +/// +/// - Any pointer to a valid memory allocation must be valid to be passed to any other [`Allocator`] +/// function of the same type. +/// +/// - Implementers must ensure that all trait functions abide by the guarantees documented in the +/// `# Guarantees` sections. +pub unsafe trait Allocator { + /// The minimum alignment satisfied by all allocations from this allocator. + /// + /// # Guarantees + /// + /// Any pointer allocated by this allocator is guaranteed to be aligned to `MIN_ALIGN` even if + /// the requested layout has a smaller alignment. + const MIN_ALIGN: usize; + + /// Allocate memory based on `layout`, `flags` and `nid`. + /// + /// On success, returns a buffer represented as `NonNull<[u8]>` that satisfies the layout + /// constraints (i.e. minimum size and alignment as specified by `layout`). + /// + /// This function is equivalent to `realloc` when called with `None`. + /// + /// # Guarantees + /// + /// When the return value is `Ok(ptr)`, then `ptr` is + /// - valid for reads and writes for `layout.size()` bytes, until it is passed to + /// [`Allocator::free`] or [`Allocator::realloc`], + /// - aligned to `layout.align()`, + /// + /// Additionally, `Flags` are honored as documented in + /// <https://docs.kernel.org/core-api/mm-api.html#mm-api-gfp-flags>. + fn alloc(layout: Layout, flags: Flags, nid: NumaNode) -> Result<NonNull<[u8]>, AllocError> { + // SAFETY: Passing `None` to `realloc` is valid by its safety requirements and asks for a + // new memory allocation. + unsafe { Self::realloc(None, layout, Layout::new::<()>(), flags, nid) } + } + + /// Re-allocate an existing memory allocation to satisfy the requested `layout` and + /// a specific NUMA node request to allocate the memory for. + /// + /// Systems employing a Non Uniform Memory Access (NUMA) architecture contain collections of + /// hardware resources including processors, memory, and I/O buses, that comprise what is + /// commonly known as a NUMA node. + /// + /// `nid` stands for NUMA id, i. e. NUMA node identifier, which is a non-negative integer + /// if a node needs to be specified, or [`NumaNode::NO_NODE`] if the caller doesn't care. + /// + /// If the requested size is zero, `realloc` behaves equivalent to `free`. + /// + /// If the requested size is larger than the size of the existing allocation, a successful call + /// to `realloc` guarantees that the new or grown buffer has at least `Layout::size` bytes, but + /// may also be larger. + /// + /// If the requested size is smaller than the size of the existing allocation, `realloc` may or + /// may not shrink the buffer; this is implementation specific to the allocator. + /// + /// On allocation failure, the existing buffer, if any, remains valid. + /// + /// The buffer is represented as `NonNull<[u8]>`. + /// + /// # Safety + /// + /// - If `ptr == Some(p)`, then `p` must point to an existing and valid memory allocation + /// created by this [`Allocator`]; if `old_layout` is zero-sized `p` does not need to be a + /// pointer returned by this [`Allocator`]. + /// - `ptr` is allowed to be `None`; in this case a new memory allocation is created and + /// `old_layout` is ignored. 
+ /// - `old_layout` must match the `Layout` the allocation has been created with. + /// + /// # Guarantees + /// + /// This function has the same guarantees as [`Allocator::alloc`]. When `ptr == Some(p)`, then + /// it additionally guarantees that: + /// - the contents of the memory pointed to by `p` are preserved up to the lesser of the new + /// and old size, i.e. `ret_ptr[0..min(layout.size(), old_layout.size())] == + /// p[0..min(layout.size(), old_layout.size())]`. + /// - when the return value is `Err(AllocError)`, then `ptr` is still valid. + unsafe fn realloc( + ptr: Option<NonNull<u8>>, + layout: Layout, + old_layout: Layout, + flags: Flags, + nid: NumaNode, + ) -> Result<NonNull<[u8]>, AllocError>; + + /// Free an existing memory allocation. + /// + /// # Safety + /// + /// - `ptr` must point to an existing and valid memory allocation created by this [`Allocator`]; + /// if `old_layout` is zero-sized `p` does not need to be a pointer returned by this + /// [`Allocator`]. + /// - `layout` must match the `Layout` the allocation has been created with. + /// - The memory allocation at `ptr` must never again be read from or written to. + unsafe fn free(ptr: NonNull<u8>, layout: Layout) { + // SAFETY: The caller guarantees that `ptr` points at a valid allocation created by this + // allocator. We are passing a `Layout` with the smallest possible alignment, so it is + // smaller than or equal to the alignment previously used with this allocation. + let _ = unsafe { + Self::realloc( + Some(ptr), + Layout::new::<()>(), + layout, + Flags(0), + NumaNode::NO_NODE, + ) + }; + } +} + +/// Returns a properly aligned dangling pointer from the given `layout`. +pub(crate) fn dangling_from_layout(layout: Layout) -> NonNull<u8> { + let ptr = layout.align() as *mut u8; + + // SAFETY: `layout.align()` (and hence `ptr`) is guaranteed to be non-zero. + unsafe { NonNull::new_unchecked(ptr) } +} diff --git a/rust/kernel/alloc/allocator.rs b/rust/kernel/alloc/allocator.rs new file mode 100644 index 000000000000..63bfb91b3671 --- /dev/null +++ b/rust/kernel/alloc/allocator.rs @@ -0,0 +1,308 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Allocator support. +//! +//! Documentation for the kernel's memory allocators can found in the "Memory Allocation Guide" +//! linked below. For instance, this includes the concept of "get free page" (GFP) flags and the +//! typical application of the different kernel allocators. +//! +//! Reference: <https://docs.kernel.org/core-api/memory-allocation.html> + +use super::Flags; +use core::alloc::Layout; +use core::ptr; +use core::ptr::NonNull; + +use crate::alloc::{AllocError, Allocator, NumaNode}; +use crate::bindings; +use crate::page; + +const ARCH_KMALLOC_MINALIGN: usize = bindings::ARCH_KMALLOC_MINALIGN; + +mod iter; +pub use self::iter::VmallocPageIter; + +/// The contiguous kernel allocator. +/// +/// `Kmalloc` is typically used for physically contiguous allocations up to page size, but also +/// supports larger allocations up to `bindings::KMALLOC_MAX_SIZE`, which is hardware specific. +/// +/// For more details see [self]. +pub struct Kmalloc; + +/// The virtually contiguous kernel allocator. +/// +/// `Vmalloc` allocates pages from the page level allocator and maps them into the contiguous kernel +/// virtual space. It is typically used for large allocations. The memory allocated with this +/// allocator is not physically contiguous. +/// +/// For more details see [self]. +pub struct Vmalloc; + +/// The kvmalloc kernel allocator. 
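For context on `dangling_from_layout` above: zero-sized allocations never touch memory, so the allocator returns a pointer whose address is simply the required alignment, which is non-null and well aligned by construction. A standalone sketch of that arithmetic using only `core`:

```rust
use core::alloc::Layout;
use core::ptr::NonNull;

let layout = Layout::new::<u64>();
// The alignment is a non-zero power of two, so it doubles as a non-null,
// properly aligned address for a dangling pointer.
let dangling = NonNull::new(layout.align() as *mut u8).unwrap();
assert_eq!(dangling.as_ptr() as usize % layout.align(), 0);
```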
+/// +/// `KVmalloc` attempts to allocate memory with `Kmalloc` first, but falls back to `Vmalloc` upon +/// failure. This allocator is typically used when the size for the requested allocation is not +/// known and may exceed the capabilities of `Kmalloc`. +/// +/// For more details see [self]. +pub struct KVmalloc; + +/// # Invariants +/// +/// One of the following: `krealloc_node_align`, `vrealloc_node_align`, `kvrealloc_node_align`. +struct ReallocFunc( + unsafe extern "C" fn( + *const crate::ffi::c_void, + usize, + crate::ffi::c_ulong, + u32, + crate::ffi::c_int, + ) -> *mut crate::ffi::c_void, +); + +impl ReallocFunc { + // INVARIANT: `krealloc_node_align` satisfies the type invariants. + const KREALLOC: Self = Self(bindings::krealloc_node_align); + + // INVARIANT: `vrealloc_node_align` satisfies the type invariants. + const VREALLOC: Self = Self(bindings::vrealloc_node_align); + + // INVARIANT: `kvrealloc_node_align` satisfies the type invariants. + const KVREALLOC: Self = Self(bindings::kvrealloc_node_align); + + /// # Safety + /// + /// This method has the same safety requirements as [`Allocator::realloc`]. + /// + /// # Guarantees + /// + /// This method has the same guarantees as `Allocator::realloc`. Additionally + /// - it accepts any pointer to a valid memory allocation allocated by this function. + /// - memory allocated by this function remains valid until it is passed to this function. + #[inline] + unsafe fn call( + &self, + ptr: Option<NonNull<u8>>, + layout: Layout, + old_layout: Layout, + flags: Flags, + nid: NumaNode, + ) -> Result<NonNull<[u8]>, AllocError> { + let size = layout.size(); + let ptr = match ptr { + Some(ptr) => { + if old_layout.size() == 0 { + ptr::null() + } else { + ptr.as_ptr() + } + } + None => ptr::null(), + }; + + // SAFETY: + // - `self.0` is one of `krealloc`, `vrealloc`, `kvrealloc` and thus only requires that + // `ptr` is NULL or valid. + // - `ptr` is either NULL or valid by the safety requirements of this function. + // + // GUARANTEE: + // - `self.0` is one of `krealloc`, `vrealloc`, `kvrealloc`. + // - Those functions provide the guarantees of this function. + let raw_ptr = unsafe { + // If `size == 0` and `ptr != NULL` the memory behind the pointer is freed. + self.0(ptr.cast(), size, layout.align(), flags.0, nid.0).cast() + }; + + let ptr = if size == 0 { + crate::alloc::dangling_from_layout(layout) + } else { + NonNull::new(raw_ptr).ok_or(AllocError)? + }; + + Ok(NonNull::slice_from_raw_parts(ptr, size)) + } +} + +impl Kmalloc { + /// Returns a [`Layout`] that makes [`Kmalloc`] fulfill the requested size and alignment of + /// `layout`. + pub fn aligned_layout(layout: Layout) -> Layout { + // Note that `layout.size()` (after padding) is guaranteed to be a multiple of + // `layout.align()` which together with the slab guarantees means that `Kmalloc` will return + // a properly aligned object (see comments in `kmalloc()` for more information). + layout.pad_to_align() + } +} + +// SAFETY: `realloc` delegates to `ReallocFunc::call`, which guarantees that +// - memory remains valid until it is explicitly freed, +// - passing a pointer to a valid memory allocation is OK, +// - `realloc` satisfies the guarantees, since `ReallocFunc::call` has the same. 
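A note on `Kmalloc::aligned_layout` above: `Layout::pad_to_align` rounds the size up to the next multiple of the alignment, and the slab allocator's power-of-two size classes then provide that alignment for the padded size. Illustrated with plain `core` types:

```rust
use core::alloc::Layout;

// 10 bytes at 8-byte alignment: padding rounds the size up to 16, so a
// size-class based allocator hands back a suitably aligned object.
let layout = Layout::from_size_align(10, 8).unwrap().pad_to_align();
assert_eq!((layout.size(), layout.align()), (16, 8));
```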
+unsafe impl Allocator for Kmalloc { + const MIN_ALIGN: usize = ARCH_KMALLOC_MINALIGN; + + #[inline] + unsafe fn realloc( + ptr: Option<NonNull<u8>>, + layout: Layout, + old_layout: Layout, + flags: Flags, + nid: NumaNode, + ) -> Result<NonNull<[u8]>, AllocError> { + let layout = Kmalloc::aligned_layout(layout); + + // SAFETY: `ReallocFunc::call` has the same safety requirements as `Allocator::realloc`. + unsafe { ReallocFunc::KREALLOC.call(ptr, layout, old_layout, flags, nid) } + } +} + +impl Vmalloc { + /// Convert a pointer to a [`Vmalloc`] allocation to a [`page::BorrowedPage`]. + /// + /// # Examples + /// + /// ``` + /// # use core::ptr::{NonNull, from_mut}; + /// # use kernel::{page, prelude::*}; + /// use kernel::alloc::allocator::Vmalloc; + /// + /// let mut vbox = VBox::<[u8; page::PAGE_SIZE]>::new_uninit(GFP_KERNEL)?; + /// + /// { + /// // SAFETY: By the type invariant of `Box` the inner pointer of `vbox` is non-null. + /// let ptr = unsafe { NonNull::new_unchecked(from_mut(&mut *vbox)) }; + /// + /// // SAFETY: + /// // `ptr` is a valid pointer to a `Vmalloc` allocation. + /// // `ptr` is valid for the entire lifetime of `page`. + /// let page = unsafe { Vmalloc::to_page(ptr.cast()) }; + /// + /// // SAFETY: There is no concurrent read or write to the same page. + /// unsafe { page.fill_zero_raw(0, page::PAGE_SIZE)? }; + /// } + /// # Ok::<(), Error>(()) + /// ``` + /// + /// # Safety + /// + /// - `ptr` must be a valid pointer to a [`Vmalloc`] allocation. + /// - `ptr` must remain valid for the entire duration of `'a`. + pub unsafe fn to_page<'a>(ptr: NonNull<u8>) -> page::BorrowedPage<'a> { + // SAFETY: `ptr` is a valid pointer to `Vmalloc` memory. + let page = unsafe { bindings::vmalloc_to_page(ptr.as_ptr().cast()) }; + + // SAFETY: `vmalloc_to_page` returns a valid pointer to a `struct page` for a valid pointer + // to `Vmalloc` memory. + let page = unsafe { NonNull::new_unchecked(page) }; + + // SAFETY: + // - `page` is a valid pointer to a `struct page`, given that by the safety requirements of + // this function `ptr` is a valid pointer to a `Vmalloc` allocation. + // - By the safety requirements of this function `ptr` is valid for the entire lifetime of + // `'a`. + unsafe { page::BorrowedPage::from_raw(page) } + } +} + +// SAFETY: `realloc` delegates to `ReallocFunc::call`, which guarantees that +// - memory remains valid until it is explicitly freed, +// - passing a pointer to a valid memory allocation is OK, +// - `realloc` satisfies the guarantees, since `ReallocFunc::call` has the same. +unsafe impl Allocator for Vmalloc { + const MIN_ALIGN: usize = kernel::page::PAGE_SIZE; + + #[inline] + unsafe fn realloc( + ptr: Option<NonNull<u8>>, + layout: Layout, + old_layout: Layout, + flags: Flags, + nid: NumaNode, + ) -> Result<NonNull<[u8]>, AllocError> { + // SAFETY: If not `None`, `ptr` is guaranteed to point to valid memory, which was previously + // allocated with this `Allocator`. + unsafe { ReallocFunc::VREALLOC.call(ptr, layout, old_layout, flags, nid) } + } +} + +// SAFETY: `realloc` delegates to `ReallocFunc::call`, which guarantees that +// - memory remains valid until it is explicitly freed, +// - passing a pointer to a valid memory allocation is OK, +// - `realloc` satisfies the guarantees, since `ReallocFunc::call` has the same. 
+unsafe impl Allocator for KVmalloc { + const MIN_ALIGN: usize = ARCH_KMALLOC_MINALIGN; + + #[inline] + unsafe fn realloc( + ptr: Option<NonNull<u8>>, + layout: Layout, + old_layout: Layout, + flags: Flags, + nid: NumaNode, + ) -> Result<NonNull<[u8]>, AllocError> { + // `KVmalloc` may use the `Kmalloc` backend, hence we have to enforce a `Kmalloc` + // compatible layout. + let layout = Kmalloc::aligned_layout(layout); + + // SAFETY: If not `None`, `ptr` is guaranteed to point to valid memory, which was previously + // allocated with this `Allocator`. + unsafe { ReallocFunc::KVREALLOC.call(ptr, layout, old_layout, flags, nid) } + } +} + +#[macros::kunit_tests(rust_allocator)] +mod tests { + use super::*; + use core::mem::MaybeUninit; + use kernel::prelude::*; + + #[test] + fn test_alignment() -> Result { + const TEST_SIZE: usize = 1024; + const TEST_LARGE_ALIGN_SIZE: usize = kernel::page::PAGE_SIZE * 4; + + // These two structs are used to test allocating aligned memory. + // they don't need to be accessed, so they're marked as dead_code. + #[expect(dead_code)] + #[repr(align(128))] + struct Blob([u8; TEST_SIZE]); + #[expect(dead_code)] + #[repr(align(8192))] + struct LargeAlignBlob([u8; TEST_LARGE_ALIGN_SIZE]); + + struct TestAlign<T, A: Allocator>(Box<MaybeUninit<T>, A>); + impl<T, A: Allocator> TestAlign<T, A> { + fn new() -> Result<Self> { + Ok(Self(Box::<_, A>::new_uninit(GFP_KERNEL)?)) + } + + fn is_aligned_to(&self, align: usize) -> bool { + assert!(align.is_power_of_two()); + + let addr = self.0.as_ptr() as usize; + addr & (align - 1) == 0 + } + } + + let ta = TestAlign::<Blob, Kmalloc>::new()?; + assert!(ta.is_aligned_to(128)); + + let ta = TestAlign::<LargeAlignBlob, Kmalloc>::new()?; + assert!(ta.is_aligned_to(8192)); + + let ta = TestAlign::<Blob, Vmalloc>::new()?; + assert!(ta.is_aligned_to(128)); + + let ta = TestAlign::<LargeAlignBlob, Vmalloc>::new()?; + assert!(ta.is_aligned_to(8192)); + + let ta = TestAlign::<Blob, KVmalloc>::new()?; + assert!(ta.is_aligned_to(128)); + + let ta = TestAlign::<LargeAlignBlob, KVmalloc>::new()?; + assert!(ta.is_aligned_to(8192)); + + Ok(()) + } +} diff --git a/rust/kernel/alloc/allocator/iter.rs b/rust/kernel/alloc/allocator/iter.rs new file mode 100644 index 000000000000..5759f86029b7 --- /dev/null +++ b/rust/kernel/alloc/allocator/iter.rs @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::Vmalloc; +use crate::page; +use core::marker::PhantomData; +use core::ptr::NonNull; + +/// An [`Iterator`] of [`page::BorrowedPage`] items owned by a [`Vmalloc`] allocation. +/// +/// # Guarantees +/// +/// The pages iterated by the [`Iterator`] appear in the order as they are mapped in the CPU's +/// virtual address space ascendingly. +/// +/// # Invariants +/// +/// - `buf` is a valid and [`page::PAGE_SIZE`] aligned pointer into a [`Vmalloc`] allocation. +/// - `size` is the number of bytes from `buf` until the end of the [`Vmalloc`] allocation `buf` +/// points to. +pub struct VmallocPageIter<'a> { + /// The base address of the [`Vmalloc`] buffer. + buf: NonNull<u8>, + /// The size of the buffer pointed to by `buf` in bytes. + size: usize, + /// The current page index of the [`Iterator`]. 
+ index: usize, + _p: PhantomData<page::BorrowedPage<'a>>, +} + +impl<'a> Iterator for VmallocPageIter<'a> { + type Item = page::BorrowedPage<'a>; + + fn next(&mut self) -> Option<Self::Item> { + let offset = self.index.checked_mul(page::PAGE_SIZE)?; + + // Even though `self.size()` may be smaller than `Self::page_count() * page::PAGE_SIZE`, it + // is always a number between `(Self::page_count() - 1) * page::PAGE_SIZE` and + // `Self::page_count() * page::PAGE_SIZE`, hence the check below is sufficient. + if offset < self.size() { + self.index += 1; + } else { + return None; + } + + // TODO: Use `NonNull::add()` instead, once the minimum supported compiler version is + // bumped to 1.80 or later. + // + // SAFETY: `offset` is in the interval `[0, (self.page_count() - 1) * page::PAGE_SIZE]`, + // hence the resulting pointer is guaranteed to be within the same allocation. + let ptr = unsafe { self.buf.as_ptr().add(offset) }; + + // SAFETY: `ptr` is guaranteed to be non-null given that it is derived from `self.buf`. + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + // SAFETY: + // - `ptr` is a valid pointer to a `Vmalloc` allocation. + // - `ptr` is valid for the duration of `'a`. + Some(unsafe { Vmalloc::to_page(ptr) }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + let remaining = self.page_count().saturating_sub(self.index); + + (remaining, Some(remaining)) + } +} + +impl<'a> VmallocPageIter<'a> { + /// Creates a new [`VmallocPageIter`] instance. + /// + /// # Safety + /// + /// - `buf` must be a [`page::PAGE_SIZE`] aligned pointer into a [`Vmalloc`] allocation. + /// - `buf` must be valid for at least the lifetime of `'a`. + /// - `size` must be the number of bytes from `buf` until the end of the [`Vmalloc`] allocation + /// `buf` points to. + pub unsafe fn new(buf: NonNull<u8>, size: usize) -> Self { + // INVARIANT: By the safety requirements, `buf` is a valid and `page::PAGE_SIZE` aligned + // pointer into a [`Vmalloc`] allocation. + Self { + buf, + size, + index: 0, + _p: PhantomData, + } + } + + /// Returns the size of the backing [`Vmalloc`] allocation in bytes. + /// + /// Note that this is the size the [`Vmalloc`] allocation has been allocated with. Hence, this + /// number may be smaller than `[`Self::page_count`] * [`page::PAGE_SIZE`]`. + #[inline] + pub fn size(&self) -> usize { + self.size + } + + /// Returns the number of pages owned by the backing [`Vmalloc`] allocation. + #[inline] + pub fn page_count(&self) -> usize { + self.size().div_ceil(page::PAGE_SIZE) + } +} diff --git a/rust/kernel/alloc/kbox.rs b/rust/kernel/alloc/kbox.rs new file mode 100644 index 000000000000..622b3529edfc --- /dev/null +++ b/rust/kernel/alloc/kbox.rs @@ -0,0 +1,720 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Implementation of [`Box`]. + +#[allow(unused_imports)] // Used in doc comments. +use super::allocator::{KVmalloc, Kmalloc, Vmalloc, VmallocPageIter}; +use super::{AllocError, Allocator, Flags, NumaNode}; +use core::alloc::Layout; +use core::borrow::{Borrow, BorrowMut}; +use core::marker::PhantomData; +use core::mem::ManuallyDrop; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::ptr::NonNull; +use core::result::Result; + +use crate::ffi::c_void; +use crate::fmt; +use crate::init::InPlaceInit; +use crate::page::AsPageIter; +use crate::types::ForeignOwnable; +use pin_init::{InPlaceWrite, Init, PinInit, ZeroableOption}; + +/// The kernel's [`Box`] type -- a heap allocation for a single value of type `T`. 
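The page accounting in `VmallocPageIter::page_count` above is plain ceiling division over the allocation size. A quick illustration (`PAGE_SIZE` is architecture dependent; 4 KiB is assumed here purely for the example):

```rust
// A 10000-byte Vmalloc buffer spans three 4096-byte pages: two full
// pages plus a partial third, which `div_ceil` accounts for.
const PAGE_SIZE: usize = 4096; // illustrative; kernel::page::PAGE_SIZE is arch-specific
assert_eq!(10000_usize.div_ceil(PAGE_SIZE), 3);
```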
+/// +/// This is the kernel's version of the Rust stdlib's `Box`. There are several differences, +/// for example no `noalias` attribute is emitted and partially moving out of a `Box` is not +/// supported. There are also several API differences, e.g. `Box` always requires an [`Allocator`] +/// implementation to be passed as generic, page [`Flags`] when allocating memory and all functions +/// that may allocate memory are fallible. +/// +/// `Box` works with any of the kernel's allocators, e.g. [`Kmalloc`], [`Vmalloc`] or [`KVmalloc`]. +/// There are aliases for `Box` with these allocators ([`KBox`], [`VBox`], [`KVBox`]). +/// +/// When dropping a [`Box`], the value is also dropped and the heap memory is automatically freed. +/// +/// # Examples +/// +/// ``` +/// let b = KBox::<u64>::new(24_u64, GFP_KERNEL)?; +/// +/// assert_eq!(*b, 24_u64); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// ``` +/// # use kernel::bindings; +/// const SIZE: usize = bindings::KMALLOC_MAX_SIZE as usize + 1; +/// struct Huge([u8; SIZE]); +/// +/// assert!(KBox::<Huge>::new_uninit(GFP_KERNEL | __GFP_NOWARN).is_err()); +/// ``` +/// +/// ``` +/// # use kernel::bindings; +/// const SIZE: usize = bindings::KMALLOC_MAX_SIZE as usize + 1; +/// struct Huge([u8; SIZE]); +/// +/// assert!(KVBox::<Huge>::new_uninit(GFP_KERNEL).is_ok()); +/// ``` +/// +/// [`Box`]es can also be used to store trait objects by coercing their type: +/// +/// ``` +/// trait FooTrait {} +/// +/// struct FooStruct; +/// impl FooTrait for FooStruct {} +/// +/// let _ = KBox::new(FooStruct, GFP_KERNEL)? as KBox<dyn FooTrait>; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Invariants +/// +/// `self.0` is always properly aligned and either points to memory allocated with `A` or, for +/// zero-sized types, is a dangling, well aligned pointer. +#[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] +pub struct Box<#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, pointee)] T: ?Sized, A: Allocator>( + NonNull<T>, + PhantomData<A>, +); + +// This is to allow coercion from `Box<T, A>` to `Box<U, A>` if `T` can be converted to the +// dynamically-sized type (DST) `U`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, A> core::ops::CoerceUnsized<Box<U, A>> for Box<T, A> +where + T: ?Sized + core::marker::Unsize<U>, + U: ?Sized, + A: Allocator, +{ +} + +// This is to allow `Box<U, A>` to be dispatched on when `Box<T, A>` can be coerced into `Box<U, +// A>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, A> core::ops::DispatchFromDyn<Box<U, A>> for Box<T, A> +where + T: ?Sized + core::marker::Unsize<U>, + U: ?Sized, + A: Allocator, +{ +} + +/// Type alias for [`Box`] with a [`Kmalloc`] allocator. +/// +/// # Examples +/// +/// ``` +/// let b = KBox::new(24_u64, GFP_KERNEL)?; +/// +/// assert_eq!(*b, 24_u64); +/// # Ok::<(), Error>(()) +/// ``` +pub type KBox<T> = Box<T, super::allocator::Kmalloc>; + +/// Type alias for [`Box`] with a [`Vmalloc`] allocator. +/// +/// # Examples +/// +/// ``` +/// let b = VBox::new(24_u64, GFP_KERNEL)?; +/// +/// assert_eq!(*b, 24_u64); +/// # Ok::<(), Error>(()) +/// ``` +pub type VBox<T> = Box<T, super::allocator::Vmalloc>; + +/// Type alias for [`Box`] with a [`KVmalloc`] allocator. 
+/// +/// # Examples +/// +/// ``` +/// let b = KVBox::new(24_u64, GFP_KERNEL)?; +/// +/// assert_eq!(*b, 24_u64); +/// # Ok::<(), Error>(()) +/// ``` +pub type KVBox<T> = Box<T, super::allocator::KVmalloc>; + +// SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>). +unsafe impl<T, A: Allocator> ZeroableOption for Box<T, A> {} + +// SAFETY: `Box` is `Send` if `T` is `Send` because the `Box` owns a `T`. +unsafe impl<T, A> Send for Box<T, A> +where + T: Send + ?Sized, + A: Allocator, +{ +} + +// SAFETY: `Box` is `Sync` if `T` is `Sync` because the `Box` owns a `T`. +unsafe impl<T, A> Sync for Box<T, A> +where + T: Sync + ?Sized, + A: Allocator, +{ +} + +impl<T, A> Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + /// Creates a new `Box<T, A>` from a raw pointer. + /// + /// # Safety + /// + /// For non-ZSTs, `raw` must point at an allocation allocated with `A` that is sufficiently + /// aligned for and holds a valid `T`. The caller passes ownership of the allocation to the + /// `Box`. + /// + /// For ZSTs, `raw` must be a dangling, well aligned pointer. + #[inline] + pub const unsafe fn from_raw(raw: *mut T) -> Self { + // INVARIANT: Validity of `raw` is guaranteed by the safety preconditions of this function. + // SAFETY: By the safety preconditions of this function, `raw` is not a NULL pointer. + Self(unsafe { NonNull::new_unchecked(raw) }, PhantomData) + } + + /// Consumes the `Box<T, A>` and returns a raw pointer. + /// + /// This will not run the destructor of `T` and for non-ZSTs the allocation will stay alive + /// indefinitely. Use [`Box::from_raw`] to recover the [`Box`], drop the value and free the + /// allocation, if any. + /// + /// # Examples + /// + /// ``` + /// let x = KBox::new(24, GFP_KERNEL)?; + /// let ptr = KBox::into_raw(x); + /// // SAFETY: `ptr` comes from a previous call to `KBox::into_raw`. + /// let x = unsafe { KBox::from_raw(ptr) }; + /// + /// assert_eq!(*x, 24); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub fn into_raw(b: Self) -> *mut T { + ManuallyDrop::new(b).0.as_ptr() + } + + /// Consumes and leaks the `Box<T, A>` and returns a mutable reference. + /// + /// See [`Box::into_raw`] for more details. + #[inline] + pub fn leak<'a>(b: Self) -> &'a mut T { + // SAFETY: `Box::into_raw` always returns a properly aligned and dereferenceable pointer + // which points to an initialized instance of `T`. + unsafe { &mut *Box::into_raw(b) } + } +} + +impl<T, A> Box<MaybeUninit<T>, A> +where + A: Allocator, +{ + /// Converts a `Box<MaybeUninit<T>, A>` to a `Box<T, A>`. + /// + /// It is undefined behavior to call this function while the value inside of `b` is not yet + /// fully initialized. + /// + /// # Safety + /// + /// Callers must ensure that the value inside of `b` is in an initialized state. + pub unsafe fn assume_init(self) -> Box<T, A> { + let raw = Self::into_raw(self); + + // SAFETY: `raw` comes from a previous call to `Box::into_raw`. By the safety requirements + // of this function, the value inside the `Box` is in an initialized state. Hence, it is + // safe to reconstruct the `Box` as `Box<T, A>`. + unsafe { Box::from_raw(raw.cast()) } + } + + /// Writes the value and converts to `Box<T, A>`. + pub fn write(mut self, value: T) -> Box<T, A> { + (*self).write(value); + + // SAFETY: We've just initialized `b`'s value. 
+ unsafe { self.assume_init() } + } +} + +impl<T, A> Box<T, A> +where + A: Allocator, +{ + /// Creates a new `Box<T, A>` and initializes its contents with `x`. + /// + /// New memory is allocated with `A`. The allocation may fail, in which case an error is + /// returned. For ZSTs no memory is allocated. + pub fn new(x: T, flags: Flags) -> Result<Self, AllocError> { + let b = Self::new_uninit(flags)?; + Ok(Box::write(b, x)) + } + + /// Creates a new `Box<T, A>` with uninitialized contents. + /// + /// New memory is allocated with `A`. The allocation may fail, in which case an error is + /// returned. For ZSTs no memory is allocated. + /// + /// # Examples + /// + /// ``` + /// let b = KBox::<u64>::new_uninit(GFP_KERNEL)?; + /// let b = KBox::write(b, 24); + /// + /// assert_eq!(*b, 24_u64); + /// # Ok::<(), Error>(()) + /// ``` + pub fn new_uninit(flags: Flags) -> Result<Box<MaybeUninit<T>, A>, AllocError> { + let layout = Layout::new::<MaybeUninit<T>>(); + let ptr = A::alloc(layout, flags, NumaNode::NO_NODE)?; + + // INVARIANT: `ptr` is either a dangling pointer or points to memory allocated with `A`, + // which is sufficient in size and alignment for storing a `T`. + Ok(Box(ptr.cast(), PhantomData)) + } + + /// Constructs a new `Pin<Box<T, A>>`. If `T` does not implement [`Unpin`], then `x` will be + /// pinned in memory and can't be moved. + #[inline] + pub fn pin(x: T, flags: Flags) -> Result<Pin<Box<T, A>>, AllocError> + where + A: 'static, + { + Ok(Self::new(x, flags)?.into()) + } + + /// Construct a pinned slice of elements `Pin<Box<[T], A>>`. + /// + /// This is a convenient means for creation of e.g. slices of structures containing spinlocks + /// or mutexes. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::{new_spinlock, SpinLock}; + /// + /// struct Inner { + /// a: u32, + /// b: u32, + /// } + /// + /// #[pin_data] + /// struct Example { + /// c: u32, + /// #[pin] + /// d: SpinLock<Inner>, + /// } + /// + /// impl Example { + /// fn new() -> impl PinInit<Self, Error> { + /// try_pin_init!(Self { + /// c: 10, + /// d <- new_spinlock!(Inner { a: 20, b: 30 }), + /// }) + /// } + /// } + /// + /// // Allocate a boxed slice of 10 `Example`s. + /// let s = KBox::pin_slice( + /// | _i | Example::new(), + /// 10, + /// GFP_KERNEL + /// )?; + /// + /// assert_eq!(s[5].c, 10); + /// assert_eq!(s[3].d.lock().a, 20); + /// # Ok::<(), Error>(()) + /// ``` + pub fn pin_slice<Func, Item, E>( + mut init: Func, + len: usize, + flags: Flags, + ) -> Result<Pin<Box<[T], A>>, E> + where + Func: FnMut(usize) -> Item, + Item: PinInit<T, E>, + E: From<AllocError>, + { + let mut buffer = super::Vec::<T, A>::with_capacity(len, flags)?; + for i in 0..len { + let ptr = buffer.spare_capacity_mut().as_mut_ptr().cast(); + // SAFETY: + // - `ptr` is a valid pointer to uninitialized memory. + // - `ptr` is not used if an error is returned. + // - `ptr` won't be moved until it is dropped, i.e. it is pinned. + unsafe { init(i).__pinned_init(ptr)? }; + + // SAFETY: + // - `i + 1 <= len`, hence we don't exceed the capacity, due to the call to + // `with_capacity()` above. + // - The new value at index `buffer.len()` is the only element being added here, and + // it has been initialized above by `init(i).__pinned_init(ptr)`. + unsafe { buffer.inc_len(1) }; + } + + let (ptr, _, _) = buffer.into_raw_parts(); + let slice = core::ptr::slice_from_raw_parts_mut(ptr, len); + + // SAFETY: `slice` points to an allocation allocated with `A` (`buffer`) and holds a valid + // `[T]`.
+ Ok(Pin::from(unsafe { Box::from_raw(slice) })) + } + + /// Convert a [`Box<T,A>`] to a [`Pin<Box<T,A>>`]. If `T` does not implement + /// [`Unpin`], then `x` will be pinned in memory and can't be moved. + pub fn into_pin(this: Self) -> Pin<Self> { + this.into() + } + + /// Forgets the contents (does not run the destructor), but keeps the allocation. + fn forget_contents(this: Self) -> Box<MaybeUninit<T>, A> { + let ptr = Self::into_raw(this); + + // SAFETY: `ptr` is valid, because it came from `Box::into_raw`. + unsafe { Box::from_raw(ptr.cast()) } + } + + /// Drops the contents, but keeps the allocation. + /// + /// # Examples + /// + /// ``` + /// let value = KBox::new([0; 32], GFP_KERNEL)?; + /// assert_eq!(*value, [0; 32]); + /// let value = KBox::drop_contents(value); + /// // Now we can re-use `value`: + /// let value = KBox::write(value, [1; 32]); + /// assert_eq!(*value, [1; 32]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn drop_contents(this: Self) -> Box<MaybeUninit<T>, A> { + let ptr = this.0.as_ptr(); + + // SAFETY: `ptr` is valid, because it came from `this`. After this call we never access the + // value stored in `this` again. + unsafe { core::ptr::drop_in_place(ptr) }; + + Self::forget_contents(this) + } + + /// Moves the `Box`'s value out of the `Box` and consumes the `Box`. + pub fn into_inner(b: Self) -> T { + // SAFETY: By the type invariant `&*b` is valid for `read`. + let value = unsafe { core::ptr::read(&*b) }; + let _ = Self::forget_contents(b); + value + } +} + +impl<T, A> From<Box<T, A>> for Pin<Box<T, A>> +where + T: ?Sized, + A: Allocator, +{ + /// Converts a `Box<T, A>` into a `Pin<Box<T, A>>`. If `T` does not implement [`Unpin`], then + /// `*b` will be pinned in memory and can't be moved. + /// + /// This moves `b` into `Pin` without moving `*b` or allocating and copying any memory. + fn from(b: Box<T, A>) -> Self { + // SAFETY: The value wrapped inside a `Pin<Box<T, A>>` cannot be moved or replaced as long + // as `T` does not implement `Unpin`. + unsafe { Pin::new_unchecked(b) } + } +} + +impl<T, A> InPlaceWrite<T> for Box<MaybeUninit<T>, A> +where + A: Allocator + 'static, +{ + type Initialized = Box<T, A>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { Box::assume_init(self) }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { Box::assume_init(self) }.into()) + } +} + +impl<T, A> InPlaceInit<T> for Box<T, A> +where + A: Allocator + 'static, +{ + type PinnedSelf = Pin<Self>; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Pin<Self>, E> + where + E: From<AllocError>, + { + Box::<_, A>::new_uninit(flags)?.write_pin_init(init) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + Box::<_, A>::new_uninit(flags)?.write_init(init) + } +} + +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `T` allocated by `A`. 
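The `InPlaceWrite`/`InPlaceInit` impls above are what enable in-place pinned construction for `Box`. A minimal sketch of the typical call path, assuming the `pin-init` macros and `kernel::sync` primitives from the prelude (the `Counter` type is hypothetical):

```rust
use kernel::prelude::*;
use kernel::sync::{new_mutex, Mutex};

#[pin_data]
struct Counter {
    #[pin]
    value: Mutex<u64>,
}

// `KBox::pin_init` routes through `try_pin_init` above: the uninitialized
// box is allocated first, then `value` is initialized in place and pinned.
let counter = KBox::pin_init(
    pin_init!(Counter { value <- new_mutex!(0) }),
    GFP_KERNEL,
)?;
assert_eq!(*counter.value.lock(), 0);
# Ok::<(), Error>(())
```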
+unsafe impl<T: 'static, A> ForeignOwnable for Box<T, A> +where + A: Allocator, +{ + const FOREIGN_ALIGN: usize = if core::mem::align_of::<T>() < A::MIN_ALIGN { + A::MIN_ALIGN + } else { + core::mem::align_of::<T>() + }; + + type Borrowed<'a> = &'a T; + type BorrowedMut<'a> = &'a mut T; + + fn into_foreign(self) -> *mut c_void { + Box::into_raw(self).cast() + } + + unsafe fn from_foreign(ptr: *mut c_void) -> Self { + // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous + // call to `Self::into_foreign`. + unsafe { Box::from_raw(ptr.cast()) } + } + + unsafe fn borrow<'a>(ptr: *mut c_void) -> &'a T { + // SAFETY: The safety requirements of this method ensure that the object remains alive and + // immutable for the duration of 'a. + unsafe { &*ptr.cast() } + } + + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> &'a mut T { + let ptr = ptr.cast(); + // SAFETY: The safety requirements of this method ensure that the pointer is valid and that + // nothing else will access the value for the duration of 'a. + unsafe { &mut *ptr } + } +} + +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `T` allocated by `A`. +unsafe impl<T: 'static, A> ForeignOwnable for Pin<Box<T, A>> +where + A: Allocator, +{ + const FOREIGN_ALIGN: usize = <Box<T, A> as ForeignOwnable>::FOREIGN_ALIGN; + type Borrowed<'a> = Pin<&'a T>; + type BorrowedMut<'a> = Pin<&'a mut T>; + + fn into_foreign(self) -> *mut c_void { + // SAFETY: We are still treating the box as pinned. + Box::into_raw(unsafe { Pin::into_inner_unchecked(self) }).cast() + } + + unsafe fn from_foreign(ptr: *mut c_void) -> Self { + // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous + // call to `Self::into_foreign`. + unsafe { Pin::new_unchecked(Box::from_raw(ptr.cast())) } + } + + unsafe fn borrow<'a>(ptr: *mut c_void) -> Pin<&'a T> { + // SAFETY: The safety requirements for this function ensure that the object is still alive, + // so it is safe to dereference the raw pointer. + // The safety requirements of `from_foreign` also ensure that the object remains alive for + // the lifetime of the returned value. + let r = unsafe { &*ptr.cast() }; + + // SAFETY: This pointer originates from a `Pin<Box<T>>`. + unsafe { Pin::new_unchecked(r) } + } + + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> Pin<&'a mut T> { + let ptr = ptr.cast(); + // SAFETY: The safety requirements for this function ensure that the object is still alive, + // so it is safe to dereference the raw pointer. + // The safety requirements of `from_foreign` also ensure that the object remains alive for + // the lifetime of the returned value. + let r = unsafe { &mut *ptr }; + + // SAFETY: This pointer originates from a `Pin<Box<T>>`. + unsafe { Pin::new_unchecked(r) } + } +} + +impl<T, A> Deref for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + type Target = T; + + fn deref(&self) -> &T { + // SAFETY: `self.0` is always properly aligned, dereferenceable and points to an initialized + // instance of `T`. + unsafe { self.0.as_ref() } + } +} + +impl<T, A> DerefMut for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn deref_mut(&mut self) -> &mut T { + // SAFETY: `self.0` is always properly aligned, dereferenceable and points to an initialized + // instance of `T`. + unsafe { self.0.as_mut() } + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// # use kernel::alloc::KBox; +/// struct Foo<B: Borrow<u32>>(B); +/// +/// // Owned instance. 
+/// let owned = Foo(1); +/// +/// // Owned instance using `KBox`. +/// let owned_kbox = Foo(KBox::new(1, GFP_KERNEL)?); +/// +/// let i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> Borrow<T> for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn borrow(&self) -> &T { + self.deref() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// # use kernel::alloc::KBox; +/// struct Foo<B: BorrowMut<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `KBox`. +/// let owned_kbox = Foo(KBox::new(1, GFP_KERNEL)?); +/// +/// let mut i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&mut i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> BorrowMut<T> for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn borrow_mut(&mut self) -> &mut T { + self.deref_mut() + } +} + +impl<T, A> fmt::Display for Box<T, A> +where + T: ?Sized + fmt::Display, + A: Allocator, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <T as fmt::Display>::fmt(&**self, f) + } +} + +impl<T, A> fmt::Debug for Box<T, A> +where + T: ?Sized + fmt::Debug, + A: Allocator, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <T as fmt::Debug>::fmt(&**self, f) + } +} + +impl<T, A> Drop for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn drop(&mut self) { + let layout = Layout::for_value::<T>(self); + + // SAFETY: The pointer in `self.0` is guaranteed to be valid by the type invariant. + unsafe { core::ptr::drop_in_place::<T>(self.deref_mut()) }; + + // SAFETY: + // - `self.0` was previously allocated with `A`. + // - `layout` is equal to the `Layout` `self.0` was allocated with. + unsafe { A::free(self.0.cast(), layout) }; + } +} + +/// # Examples +/// +/// ``` +/// # use kernel::prelude::*; +/// use kernel::alloc::allocator::VmallocPageIter; +/// use kernel::page::{AsPageIter, PAGE_SIZE}; +/// +/// let mut vbox = VBox::new((), GFP_KERNEL)?; +/// +/// assert!(vbox.page_iter().next().is_none()); +/// +/// let mut vbox = VBox::<[u8; PAGE_SIZE]>::new_uninit(GFP_KERNEL)?; +/// +/// let page = vbox.page_iter().next().expect("At least one page should be available.\n"); +/// +/// // SAFETY: There is no concurrent read or write to the same page. +/// unsafe { page.fill_zero_raw(0, PAGE_SIZE)? }; +/// # Ok::<(), Error>(()) +/// ``` +impl<T> AsPageIter for VBox<T> { + type Iter<'a> + = VmallocPageIter<'a> + where + T: 'a; + + fn page_iter(&mut self) -> Self::Iter<'_> { + let ptr = self.0.cast(); + let size = core::mem::size_of::<T>(); + + // SAFETY: + // - `ptr` is a valid pointer to the beginning of a `Vmalloc` allocation. + // - `ptr` is guaranteed to be valid for the lifetime of `'a`. + // - `size` is the size of the `Vmalloc` allocation `ptr` points to. + unsafe { VmallocPageIter::new(ptr, size) } + } +} diff --git a/rust/kernel/alloc/kvec.rs b/rust/kernel/alloc/kvec.rs new file mode 100644 index 000000000000..ac8d6f763ae8 --- /dev/null +++ b/rust/kernel/alloc/kvec.rs @@ -0,0 +1,1401 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Implementation of [`Vec`].
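Worth noting from the `Drop` impl for `Box` just above: dropping a `Box` first runs the value's destructor in place via `drop_in_place` and only then returns the memory to `A`. A doctest-style sketch showing the destructor actually runs (the `Flag` type is hypothetical, for illustration only):

```rust
use core::sync::atomic::{AtomicBool, Ordering};

static DROPPED: AtomicBool = AtomicBool::new(false);

struct Flag;
impl Drop for Flag {
    fn drop(&mut self) {
        DROPPED.store(true, Ordering::Relaxed);
    }
}

let b = KBox::new(Flag, GFP_KERNEL)?;
// `drop_in_place::<Flag>` runs first, then `Kmalloc::free` releases the memory.
drop(b);
assert!(DROPPED.load(Ordering::Relaxed));
# Ok::<(), Error>(())
```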
+ +use super::{ + allocator::{KVmalloc, Kmalloc, Vmalloc, VmallocPageIter}, + layout::ArrayLayout, + AllocError, Allocator, Box, Flags, NumaNode, +}; +use crate::{ + fmt, + page::AsPageIter, // +}; +use core::{ + borrow::{Borrow, BorrowMut}, + marker::PhantomData, + mem::{ManuallyDrop, MaybeUninit}, + ops::Deref, + ops::DerefMut, + ops::Index, + ops::IndexMut, + ptr, + ptr::NonNull, + slice, + slice::SliceIndex, +}; + +mod errors; +pub use self::errors::{InsertError, PushError, RemoveError}; + +/// Create a [`KVec`] containing the arguments. +/// +/// New memory is allocated with `GFP_KERNEL`. +/// +/// # Examples +/// +/// ``` +/// let mut v = kernel::kvec![]; +/// v.push(1, GFP_KERNEL)?; +/// assert_eq!(v, [1]); +/// +/// let mut v = kernel::kvec![1; 3]?; +/// v.push(4, GFP_KERNEL)?; +/// assert_eq!(v, [1, 1, 1, 4]); +/// +/// let mut v = kernel::kvec![1, 2, 3]?; +/// v.push(4, GFP_KERNEL)?; +/// assert_eq!(v, [1, 2, 3, 4]); +/// +/// # Ok::<(), Error>(()) +/// ``` +#[macro_export] +macro_rules! kvec { + () => ( + $crate::alloc::KVec::new() + ); + ($elem:expr; $n:expr) => ( + $crate::alloc::KVec::from_elem($elem, $n, GFP_KERNEL) + ); + ($($x:expr),+ $(,)?) => ( + match $crate::alloc::KBox::new_uninit(GFP_KERNEL) { + Ok(b) => Ok($crate::alloc::KVec::from($crate::alloc::KBox::write(b, [$($x),+]))), + Err(e) => Err(e), + } + ); +} + +/// The kernel's [`Vec`] type. +/// +/// A contiguous growable array type with contents allocated with the kernel's allocators (e.g. +/// [`Kmalloc`], [`Vmalloc`] or [`KVmalloc`]), written `Vec<T, A>`. +/// +/// For non-zero-sized values, a [`Vec`] will use the given allocator `A` for its allocation. For +/// the most common allocators the type aliases [`KVec`], [`VVec`] and [`KVVec`] exist. +/// +/// For zero-sized types the [`Vec`]'s pointer must be `dangling_mut::<T>`; no memory is allocated. +/// +/// Generally, [`Vec`] consists of a pointer that represents the vector's backing buffer, the +/// capacity of the vector (the number of elements that currently fit into the vector), its length +/// (the number of elements that are currently stored in the vector) and the `Allocator` type used +/// to allocate (and free) the backing buffer. +/// +/// A [`Vec`] can be deconstructed into and (re-)constructed from its previously named raw parts +/// and manually modified. +/// +/// [`Vec`]'s backing buffer gets, if required, automatically increased (re-allocated) when elements +/// are added to the vector. +/// +/// # Invariants +/// +/// - `self.ptr` is always properly aligned and either points to memory allocated with `A` or, for +/// zero-sized types, is a dangling, well aligned pointer. +/// +/// - `self.len` always represents the exact number of elements stored in the vector. +/// +/// - `self.layout` represents the absolute number of elements that can be stored within the vector +/// without re-allocation. For ZSTs `self.layout`'s capacity is zero. However, it is legal for the +/// backing buffer to be larger than `layout`. +/// +/// - `self.len()` is always less than or equal to `self.capacity()`. +/// +/// - The `Allocator` type `A` of the vector is the exact same `Allocator` type the backing buffer +/// was allocated with (and must be freed with). +pub struct Vec<T, A: Allocator> { + ptr: NonNull<T>, + /// Represents the actual buffer size as `cap` times `size_of::<T>` bytes. + /// + /// Note: This isn't quite the same as `Self::capacity`, which in contrast returns the number of + /// elements we can still store without reallocating. 
+ layout: ArrayLayout<T>, + len: usize, + _p: PhantomData<A>, +} + +/// Type alias for [`Vec`] with a [`Kmalloc`] allocator. +/// +/// # Examples +/// +/// ``` +/// let mut v = KVec::new(); +/// v.push(1, GFP_KERNEL)?; +/// assert_eq!(&v, &[1]); +/// +/// # Ok::<(), Error>(()) +/// ``` +pub type KVec<T> = Vec<T, Kmalloc>; + +/// Type alias for [`Vec`] with a [`Vmalloc`] allocator. +/// +/// # Examples +/// +/// ``` +/// let mut v = VVec::new(); +/// v.push(1, GFP_KERNEL)?; +/// assert_eq!(&v, &[1]); +/// +/// # Ok::<(), Error>(()) +/// ``` +pub type VVec<T> = Vec<T, Vmalloc>; + +/// Type alias for [`Vec`] with a [`KVmalloc`] allocator. +/// +/// # Examples +/// +/// ``` +/// let mut v = KVVec::new(); +/// v.push(1, GFP_KERNEL)?; +/// assert_eq!(&v, &[1]); +/// +/// # Ok::<(), Error>(()) +/// ``` +pub type KVVec<T> = Vec<T, KVmalloc>; + +// SAFETY: `Vec` is `Send` if `T` is `Send` because `Vec` owns its elements. +unsafe impl<T, A> Send for Vec<T, A> +where + T: Send, + A: Allocator, +{ +} + +// SAFETY: `Vec` is `Sync` if `T` is `Sync` because `Vec` owns its elements. +unsafe impl<T, A> Sync for Vec<T, A> +where + T: Sync, + A: Allocator, +{ +} + +impl<T, A> Vec<T, A> +where + A: Allocator, +{ + #[inline] + const fn is_zst() -> bool { + core::mem::size_of::<T>() == 0 + } + + /// Returns the number of elements that can be stored within the vector without allocating + /// additional memory. + pub const fn capacity(&self) -> usize { + if const { Self::is_zst() } { + usize::MAX + } else { + self.layout.len() + } + } + + /// Returns the number of elements stored within the vector. + #[inline] + pub const fn len(&self) -> usize { + self.len + } + + /// Increments `self.len` by `additional`. + /// + /// # Safety + /// + /// - `additional` must be less than or equal to `self.capacity - self.len`. + /// - All elements within the interval [`self.len`,`self.len + additional`) must be initialized. + #[inline] + pub const unsafe fn inc_len(&mut self, additional: usize) { + // Guaranteed by the type invariant to never underflow. + debug_assert!(additional <= self.capacity() - self.len()); + // INVARIANT: By the safety requirements of this method this represents the exact number of + // elements stored within `self`. + self.len += additional; + } + + /// Decreases `self.len` by `count`. + /// + /// Returns a mutable slice to the elements forgotten by the vector. It is the caller's + /// responsibility to drop these elements if necessary. + /// + /// # Safety + /// + /// - `count` must be less than or equal to `self.len`. + unsafe fn dec_len(&mut self, count: usize) -> &mut [T] { + debug_assert!(count <= self.len()); + // INVARIANT: We relinquish ownership of the elements within the range `[self.len - count, + // self.len)`, hence the updated value of `set.len` represents the exact number of elements + // stored within `self`. + self.len -= count; + // SAFETY: The memory after `self.len()` is guaranteed to contain `count` initialized + // elements of type `T`. + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr().add(self.len), count) } + } + + /// Returns a slice of the entire vector. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// v.push(1, GFP_KERNEL)?; + /// v.push(2, GFP_KERNEL)?; + /// assert_eq!(v.as_slice(), &[1, 2]); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub fn as_slice(&self) -> &[T] { + self + } + + /// Returns a mutable slice of the entire vector. 
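One consequence of `is_zst` and `capacity` above: a vector of zero-sized elements never allocates, so its capacity is reported as `usize::MAX`. A quick sketch:

```rust
// Zero-sized elements occupy no memory, so any number of them "fits"
// without the backing buffer ever being allocated.
let v = KVec::<()>::new();
assert_eq!(v.capacity(), usize::MAX);
assert_eq!(v.len(), 0);
```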
+ #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + self + } + + /// Returns a mutable raw pointer to the vector's backing buffer, or, if `T` is a ZST, a + /// dangling raw pointer. + #[inline] + pub fn as_mut_ptr(&mut self) -> *mut T { + self.ptr.as_ptr() + } + + /// Returns a raw pointer to the vector's backing buffer, or, if `T` is a ZST, a dangling raw + /// pointer. + #[inline] + pub const fn as_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + /// Returns `true` if the vector contains no elements, `false` otherwise. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// assert!(v.is_empty()); + /// + /// v.push(1, GFP_KERNEL); + /// assert!(!v.is_empty()); + /// ``` + #[inline] + pub const fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Creates a new, empty `Vec<T, A>`. + /// + /// This method does not allocate by itself. + #[inline] + pub const fn new() -> Self { + // INVARIANT: Since this is a new, empty `Vec` with no backing memory yet, + // - `ptr` is a properly aligned dangling pointer for type `T`, + // - `layout` is an empty `ArrayLayout` (zero capacity) + // - `len` is zero, since no elements can be or have been stored, + // - `A` is always valid. + Self { + ptr: NonNull::dangling(), + layout: ArrayLayout::empty(), + len: 0, + _p: PhantomData::<A>, + } + } + + /// Returns a slice of `MaybeUninit<T>` for the remaining spare capacity of the vector. + pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] { + // SAFETY: + // - `self.len` is smaller than `self.capacity` by the type invariant and hence, the + // resulting pointer is guaranteed to be part of the same allocated object. + // - `self.len` can not overflow `isize`. + let ptr = unsafe { self.as_mut_ptr().add(self.len) }.cast::<MaybeUninit<T>>(); + + // SAFETY: The memory between `self.len` and `self.capacity` is guaranteed to be allocated + // and valid, but uninitialized. + unsafe { slice::from_raw_parts_mut(ptr, self.capacity() - self.len) } + } + + /// Appends an element to the back of the [`Vec`] instance. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// v.push(1, GFP_KERNEL)?; + /// assert_eq!(&v, &[1]); + /// + /// v.push(2, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 2]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> { + self.reserve(1, flags)?; + // SAFETY: The call to `reserve` was successful, so the capacity is at least one greater + // than the length. + unsafe { self.push_within_capacity_unchecked(v) }; + Ok(()) + } + + /// Appends an element to the back of the [`Vec`] instance without reallocating. + /// + /// Fails if the vector does not have capacity for the new element. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::with_capacity(10, GFP_KERNEL)?; + /// for i in 0..10 { + /// v.push_within_capacity(i)?; + /// } + /// + /// assert!(v.push_within_capacity(10).is_err()); + /// # Ok::<(), Error>(()) + /// ``` + pub fn push_within_capacity(&mut self, v: T) -> Result<(), PushError<T>> { + if self.len() < self.capacity() { + // SAFETY: The length is less than the capacity. + unsafe { self.push_within_capacity_unchecked(v) }; + Ok(()) + } else { + Err(PushError(v)) + } + } + + /// Appends an element to the back of the [`Vec`] instance without reallocating. + /// + /// # Safety + /// + /// The length must be less than the capacity. 
+ unsafe fn push_within_capacity_unchecked(&mut self, v: T) {
+ let spare = self.spare_capacity_mut();
+
+ // SAFETY: By the safety requirements, `spare` is non-empty.
+ unsafe { spare.get_unchecked_mut(0) }.write(v);
+
+ // SAFETY: We just initialised the first spare entry, so it is safe to increase the length
+ // by 1. We also know that the new length is <= capacity because the caller guarantees that
+ // the length is less than the capacity at the beginning of this function.
+ unsafe { self.inc_len(1) };
+ }
+
+ /// Inserts an element at the given index in the [`Vec`] instance.
+ ///
+ /// Fails if the vector does not have capacity for the new element, or if the index is out
+ /// of bounds.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use kernel::alloc::kvec::InsertError;
+ ///
+ /// let mut v = KVec::with_capacity(5, GFP_KERNEL)?;
+ /// for i in 0..5 {
+ /// v.insert_within_capacity(0, i)?;
+ /// }
+ ///
+ /// assert!(matches!(v.insert_within_capacity(0, 5), Err(InsertError::OutOfCapacity(_))));
+ /// assert!(matches!(v.insert_within_capacity(1000, 5), Err(InsertError::IndexOutOfBounds(_))));
+ /// assert_eq!(v, [4, 3, 2, 1, 0]);
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn insert_within_capacity(
+ &mut self,
+ index: usize,
+ element: T,
+ ) -> Result<(), InsertError<T>> {
+ let len = self.len();
+ if index > len {
+ return Err(InsertError::IndexOutOfBounds(element));
+ }
+
+ if len >= self.capacity() {
+ return Err(InsertError::OutOfCapacity(element));
+ }
+
+ // SAFETY: This is in bounds since `index <= len < capacity`.
+ let p = unsafe { self.as_mut_ptr().add(index) };
+ // INVARIANT: This breaks the Vec invariants by making `index` contain an invalid element,
+ // but we restore the invariants below.
+ // SAFETY: Both the src and dst ranges end no later than one element after the length.
+ // Since the length is less than the capacity, both ranges are in bounds of the allocation.
+ unsafe { ptr::copy(p, p.add(1), len - index) };
+ // INVARIANT: This restores the Vec invariants.
+ // SAFETY: The pointer is in-bounds of the allocation.
+ unsafe { ptr::write(p, element) };
+ // SAFETY: Index `len` contains a valid element due to the above copy and write.
+ unsafe { self.inc_len(1) };
+ Ok(())
+ }
+
+ /// Removes the last element from a vector and returns it, or `None` if it is empty.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let mut v = KVec::new();
+ /// v.push(1, GFP_KERNEL)?;
+ /// v.push(2, GFP_KERNEL)?;
+ /// assert_eq!(&v, &[1, 2]);
+ ///
+ /// assert_eq!(v.pop(), Some(2));
+ /// assert_eq!(v.pop(), Some(1));
+ /// assert_eq!(v.pop(), None);
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn pop(&mut self) -> Option<T> {
+ if self.is_empty() {
+ return None;
+ }
+
+ let removed: *mut T = {
+ // SAFETY: We just checked that the length is at least one.
+ let slice = unsafe { self.dec_len(1) };
+ // SAFETY: The argument to `dec_len` was 1 so this returns a slice of length 1.
+ unsafe { slice.get_unchecked_mut(0) }
+ };
+
+ // SAFETY: The guarantees of `dec_len` allow us to take ownership of this value.
+ Some(unsafe { removed.read() })
+ }
+
+ /// Removes the element at the given index.
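+ ///
+ /// All elements after the removed one are shifted one slot towards the front, so this is an
+ /// O(n) operation in the number of elements after index `i`.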
+ /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// assert_eq!(v.remove(1)?, 2); + /// assert_eq!(v, [1, 3]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn remove(&mut self, i: usize) -> Result<T, RemoveError> { + let value = { + let value_ref = self.get(i).ok_or(RemoveError)?; + // INVARIANT: This breaks the invariants by invalidating the value at index `i`, but we + // restore the invariants below. + // SAFETY: The value at index `i` is valid, because otherwise we would have already + // failed with `RemoveError`. + unsafe { ptr::read(value_ref) } + }; + + // SAFETY: We checked that `i` is in-bounds. + let p = unsafe { self.as_mut_ptr().add(i) }; + + // INVARIANT: After this call, the invalid value is at the last slot, so the Vec invariants + // are restored after the below call to `dec_len(1)`. + // SAFETY: `p.add(1).add(self.len - i - 1)` is `i+1+len-i-1 == len` elements after the + // beginning of the vector, so this is in-bounds of the vector's allocation. + unsafe { ptr::copy(p.add(1), p, self.len - i - 1) }; + + // SAFETY: Since the check at the beginning of this call did not fail with `RemoveError`, + // the length is at least one. + unsafe { self.dec_len(1) }; + + Ok(value) + } + + /// Creates a new [`Vec`] instance with at least the given capacity. + /// + /// # Examples + /// + /// ``` + /// let v = KVec::<u32>::with_capacity(20, GFP_KERNEL)?; + /// + /// assert!(v.capacity() >= 20); + /// # Ok::<(), Error>(()) + /// ``` + pub fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError> { + let mut v = Vec::new(); + + v.reserve(capacity, flags)?; + + Ok(v) + } + + /// Creates a `Vec<T, A>` from a pointer, a length and a capacity using the allocator `A`. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// v.reserve(1, GFP_KERNEL)?; + /// + /// let (mut ptr, mut len, cap) = v.into_raw_parts(); + /// + /// // SAFETY: We've just reserved memory for another element. + /// unsafe { ptr.add(len).write(4) }; + /// len += 1; + /// + /// // SAFETY: We only wrote an additional element at the end of the `KVec`'s buffer and + /// // correspondingly increased the length of the `KVec` by one. Otherwise, we construct it + /// // from the exact same raw parts. + /// let v = unsafe { KVec::from_raw_parts(ptr, len, cap) }; + /// + /// assert_eq!(v, [1, 2, 3, 4]); + /// + /// # Ok::<(), Error>(()) + /// ``` + /// + /// # Safety + /// + /// If `T` is a ZST: + /// + /// - `ptr` must be a dangling, well aligned pointer. + /// + /// Otherwise: + /// + /// - `ptr` must have been allocated with the allocator `A`. + /// - `ptr` must satisfy or exceed the alignment requirements of `T`. + /// - `ptr` must point to memory with a size of at least `size_of::<T>() * capacity` bytes. + /// - The allocated size in bytes must not be larger than `isize::MAX`. + /// - `length` must be less than or equal to `capacity`. + /// - The first `length` elements must be initialized values of type `T`. + /// + /// It is also valid to create an empty `Vec` passing a dangling pointer for `ptr` and zero for + /// `cap` and `len`. + pub unsafe fn from_raw_parts(ptr: *mut T, length: usize, capacity: usize) -> Self { + let layout = if Self::is_zst() { + ArrayLayout::empty() + } else { + // SAFETY: By the safety requirements of this function, `capacity * size_of::<T>()` is + // smaller than `isize::MAX`. 
+ unsafe { ArrayLayout::new_unchecked(capacity) } + }; + + // INVARIANT: For ZSTs, we store an empty `ArrayLayout`, all other type invariants are + // covered by the safety requirements of this function. + Self { + // SAFETY: By the safety requirements, `ptr` is either dangling or pointing to a valid + // memory allocation, allocated with `A`. + ptr: unsafe { NonNull::new_unchecked(ptr) }, + layout, + len: length, + _p: PhantomData::<A>, + } + } + + /// Consumes the `Vec<T, A>` and returns its raw components `pointer`, `length` and `capacity`. + /// + /// This will not run the destructor of the contained elements and for non-ZSTs the allocation + /// will stay alive indefinitely. Use [`Vec::from_raw_parts`] to recover the [`Vec`], drop the + /// elements and free the allocation, if any. + pub fn into_raw_parts(self) -> (*mut T, usize, usize) { + let mut me = ManuallyDrop::new(self); + let len = me.len(); + let capacity = me.capacity(); + let ptr = me.as_mut_ptr(); + (ptr, len, capacity) + } + + /// Clears the vector, removing all values. + /// + /// Note that this method has no effect on the allocated capacity + /// of the vector. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// + /// v.clear(); + /// + /// assert!(v.is_empty()); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + } + + /// Ensures that the capacity exceeds the length by at least `additional` elements. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// v.push(1, GFP_KERNEL)?; + /// + /// v.reserve(10, GFP_KERNEL)?; + /// let cap = v.capacity(); + /// assert!(cap >= 10); + /// + /// v.reserve(10, GFP_KERNEL)?; + /// let new_cap = v.capacity(); + /// assert_eq!(new_cap, cap); + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError> { + let len = self.len(); + let cap = self.capacity(); + + if cap - len >= additional { + return Ok(()); + } + + if Self::is_zst() { + // The capacity is already `usize::MAX` for ZSTs, we can't go higher. + return Err(AllocError); + } + + // We know that `cap <= isize::MAX` because of the type invariants of `Self`. So the + // multiplication by two won't overflow. + let new_cap = core::cmp::max(cap * 2, len.checked_add(additional).ok_or(AllocError)?); + let layout = ArrayLayout::new(new_cap).map_err(|_| AllocError)?; + + // SAFETY: + // - `ptr` is valid because it's either `None` or comes from a previous call to + // `A::realloc`. + // - `self.layout` matches the `ArrayLayout` of the preceding allocation. + let ptr = unsafe { + A::realloc( + Some(self.ptr.cast()), + layout.into(), + self.layout.into(), + flags, + NumaNode::NO_NODE, + )? + }; + + // INVARIANT: + // - `layout` is some `ArrayLayout::<T>`, + // - `ptr` has been created by `A::realloc` from `layout`. + self.ptr = ptr.cast(); + self.layout = layout; + + Ok(()) + } + + /// Shortens the vector, setting the length to `len` and drops the removed values. + /// If `len` is greater than or equal to the current length, this does nothing. + /// + /// This has no effect on the capacity and will not allocate. 
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let mut v = kernel::kvec![1, 2, 3]?;
+ /// v.truncate(1);
+ /// assert_eq!(v.len(), 1);
+ /// assert_eq!(&v, &[1]);
+ ///
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn truncate(&mut self, len: usize) {
+ if let Some(count) = self.len().checked_sub(len) {
+ // SAFETY: `count` is `self.len() - len` so it is guaranteed to be less than or
+ // equal to `self.len()`.
+ let ptr: *mut [T] = unsafe { self.dec_len(count) };
+
+ // SAFETY: the contract of `dec_len` guarantees that the elements in `ptr` are
+ // valid elements whose ownership has been transferred to the caller.
+ unsafe { ptr::drop_in_place(ptr) };
+ }
+ }
+
+ /// Takes ownership of all items in this vector without consuming the allocation.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let mut v = kernel::kvec![0, 1, 2, 3]?;
+ ///
+ /// for (i, j) in v.drain_all().enumerate() {
+ /// assert_eq!(i, j);
+ /// }
+ ///
+ /// assert!(v.capacity() >= 4);
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn drain_all(&mut self) -> DrainAll<'_, T> {
+ // SAFETY: This does not underflow the length.
+ let elems = unsafe { self.dec_len(self.len()) };
+ // INVARIANT: The first `len` elements of the spare capacity are valid values, and as we
+ // just set the length to zero, we may transfer ownership to the `DrainAll` object.
+ DrainAll {
+ elements: elems.iter_mut(),
+ }
+ }
+
+ /// Removes all elements that don't match the provided closure.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let mut v = kernel::kvec![1, 2, 3, 4]?;
+ /// v.retain(|i| *i % 2 == 0);
+ /// assert_eq!(v, [2, 4]);
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) {
+ let mut num_kept = 0;
+ let mut next_to_check = 0;
+ while let Some(to_check) = self.get_mut(next_to_check) {
+ if f(to_check) {
+ self.swap(num_kept, next_to_check);
+ num_kept += 1;
+ }
+ next_to_check += 1;
+ }
+ self.truncate(num_kept);
+ }
+}
+
+impl<T: Clone, A: Allocator> Vec<T, A> {
+ /// Extend the vector by `n` clones of `value`.
+ pub fn extend_with(&mut self, n: usize, value: T, flags: Flags) -> Result<(), AllocError> {
+ if n == 0 {
+ return Ok(());
+ }
+
+ self.reserve(n, flags)?;
+
+ let spare = self.spare_capacity_mut();
+
+ for item in spare.iter_mut().take(n - 1) {
+ item.write(value.clone());
+ }
+
+ // We can write the last element directly without cloning needlessly.
+ spare[n - 1].write(value);
+
+ // SAFETY:
+ // - `self.len() + n <= self.capacity()` due to the call to reserve above,
+ // - the loop and the line above initialized the next `n` elements.
+ unsafe { self.inc_len(n) };
+
+ Ok(())
+ }
+
+ /// Pushes clones of the elements of the given slice into the [`Vec`] instance.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let mut v = KVec::new();
+ /// v.push(1, GFP_KERNEL)?;
+ ///
+ /// v.extend_from_slice(&[20, 30, 40], GFP_KERNEL)?;
+ /// assert_eq!(&v, &[1, 20, 30, 40]);
+ ///
+ /// v.extend_from_slice(&[50, 60], GFP_KERNEL)?;
+ /// assert_eq!(&v, &[1, 20, 30, 40, 50, 60]);
+ /// # Ok::<(), Error>(())
+ /// ```
+ pub fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError> {
+ self.reserve(other.len(), flags)?;
+ for (slot, item) in core::iter::zip(self.spare_capacity_mut(), other) {
+ slot.write(item.clone());
+ }
+
+ // SAFETY:
+ // - `other.len()` spare entries have just been initialized, so it is safe to increase
+ // the length by the same number.
+ // - `self.len() + other.len() <= self.capacity()` is guaranteed by the preceding `reserve`
+ // call.
+ unsafe { self.inc_len(other.len()) }; + Ok(()) + } + + /// Create a new `Vec<T, A>` and extend it by `n` clones of `value`. + pub fn from_elem(value: T, n: usize, flags: Flags) -> Result<Self, AllocError> { + let mut v = Self::with_capacity(n, flags)?; + + v.extend_with(n, value, flags)?; + + Ok(v) + } + + /// Resizes the [`Vec`] so that `len` is equal to `new_len`. + /// + /// If `new_len` is smaller than `len`, the `Vec` is [`Vec::truncate`]d. + /// If `new_len` is larger, each new slot is filled with clones of `value`. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// v.resize(1, 42, GFP_KERNEL)?; + /// assert_eq!(&v, &[1]); + /// + /// v.resize(3, 42, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 42, 42]); + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn resize(&mut self, new_len: usize, value: T, flags: Flags) -> Result<(), AllocError> { + match new_len.checked_sub(self.len()) { + Some(n) => self.extend_with(n, value, flags), + None => { + self.truncate(new_len); + Ok(()) + } + } + } +} + +impl<T, A> Drop for Vec<T, A> +where + A: Allocator, +{ + fn drop(&mut self) { + // SAFETY: `self.as_mut_ptr` is guaranteed to be valid by the type invariant. + unsafe { + ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut( + self.as_mut_ptr(), + self.len, + )) + }; + + // SAFETY: + // - `self.ptr` was previously allocated with `A`. + // - `self.layout` matches the `ArrayLayout` of the preceding allocation. + unsafe { A::free(self.ptr.cast(), self.layout.into()) }; + } +} + +impl<T, A, const N: usize> From<Box<[T; N], A>> for Vec<T, A> +where + A: Allocator, +{ + fn from(b: Box<[T; N], A>) -> Vec<T, A> { + let len = b.len(); + let ptr = Box::into_raw(b); + + // SAFETY: + // - `b` has been allocated with `A`, + // - `ptr` fulfills the alignment requirements for `T`, + // - `ptr` points to memory with at least a size of `size_of::<T>() * len`, + // - all elements within `b` are initialized values of `T`, + // - `len` does not exceed `isize::MAX`. + unsafe { Vec::from_raw_parts(ptr.cast(), len, len) } + } +} + +impl<T, A: Allocator> Default for Vec<T, A> { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl<T: fmt::Debug, A: Allocator> fmt::Debug for Vec<T, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T, A> Deref for Vec<T, A> +where + A: Allocator, +{ + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + // SAFETY: The memory behind `self.as_ptr()` is guaranteed to contain `self.len` + // initialized elements of type `T`. + unsafe { slice::from_raw_parts(self.as_ptr(), self.len) } + } +} + +impl<T, A> DerefMut for Vec<T, A> +where + A: Allocator, +{ + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + // SAFETY: The memory behind `self.as_ptr()` is guaranteed to contain `self.len` + // initialized elements of type `T`. + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) } + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// struct Foo<B: Borrow<[u32]>>(B); +/// +/// // Owned array. +/// let owned_array = Foo([1, 2, 3]); +/// +/// // Owned vector. +/// let owned_vec = Foo(KVec::from_elem(0, 3, GFP_KERNEL)?); +/// +/// let arr = [1, 2, 3]; +/// // Borrowed slice from `arr`. 
+/// let borrowed_slice = Foo(&arr[..]); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> Borrow<[T]> for Vec<T, A> +where + A: Allocator, +{ + fn borrow(&self) -> &[T] { + self.as_slice() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// struct Foo<B: BorrowMut<[u32]>>(B); +/// +/// // Owned array. +/// let owned_array = Foo([1, 2, 3]); +/// +/// // Owned vector. +/// let owned_vec = Foo(KVec::from_elem(0, 3, GFP_KERNEL)?); +/// +/// let mut arr = [1, 2, 3]; +/// // Borrowed slice from `arr`. +/// let borrowed_slice = Foo(&mut arr[..]); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> BorrowMut<[T]> for Vec<T, A> +where + A: Allocator, +{ + fn borrow_mut(&mut self) -> &mut [T] { + self.as_mut_slice() + } +} + +impl<T: Eq, A> Eq for Vec<T, A> where A: Allocator {} + +impl<T, I: SliceIndex<[T]>, A> Index<I> for Vec<T, A> +where + A: Allocator, +{ + type Output = I::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&**self, index) + } +} + +impl<T, I: SliceIndex<[T]>, A> IndexMut<I> for Vec<T, A> +where + A: Allocator, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut Self::Output { + IndexMut::index_mut(&mut **self, index) + } +} + +macro_rules! impl_slice_eq { + ($([$($vars:tt)*] $lhs:ty, $rhs:ty,)*) => { + $( + impl<T, U, $($vars)*> PartialEq<$rhs> for $lhs + where + T: PartialEq<U>, + { + #[inline] + fn eq(&self, other: &$rhs) -> bool { self[..] == other[..] } + } + )* + } +} + +impl_slice_eq! { + [A1: Allocator, A2: Allocator] Vec<T, A1>, Vec<U, A2>, + [A: Allocator] Vec<T, A>, &[U], + [A: Allocator] Vec<T, A>, &mut [U], + [A: Allocator] &[T], Vec<U, A>, + [A: Allocator] &mut [T], Vec<U, A>, + [A: Allocator] Vec<T, A>, [U], + [A: Allocator] [T], Vec<U, A>, + [A: Allocator, const N: usize] Vec<T, A>, [U; N], + [A: Allocator, const N: usize] Vec<T, A>, &[U; N], +} + +impl<'a, T, A> IntoIterator for &'a Vec<T, A> +where + A: Allocator, +{ + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T, A: Allocator> IntoIterator for &'a mut Vec<T, A> +where + A: Allocator, +{ + type Item = &'a mut T; + type IntoIter = slice::IterMut<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + +/// # Examples +/// +/// ``` +/// # use kernel::prelude::*; +/// use kernel::alloc::allocator::VmallocPageIter; +/// use kernel::page::{AsPageIter, PAGE_SIZE}; +/// +/// let mut vec = VVec::<u8>::new(); +/// +/// assert!(vec.page_iter().next().is_none()); +/// +/// vec.reserve(PAGE_SIZE, GFP_KERNEL)?; +/// +/// let page = vec.page_iter().next().expect("At least one page should be available.\n"); +/// +/// // SAFETY: There is no concurrent read or write to the same page. +/// unsafe { page.fill_zero_raw(0, PAGE_SIZE)? }; +/// # Ok::<(), Error>(()) +/// ``` +impl<T> AsPageIter for VVec<T> { + type Iter<'a> + = VmallocPageIter<'a> + where + T: 'a; + + fn page_iter(&mut self) -> Self::Iter<'_> { + let ptr = self.ptr.cast(); + let size = self.layout.size(); + + // SAFETY: + // - `ptr` is a valid pointer to the beginning of a `Vmalloc` allocation. + // - `ptr` is guaranteed to be valid for the lifetime of `'a`. + // - `size` is the size of the `Vmalloc` allocation `ptr` points to. + unsafe { VmallocPageIter::new(ptr, size) } + } +} + +/// An [`Iterator`] implementation for [`Vec`] that moves elements out of a vector. +/// +/// This structure is created by the [`Vec::into_iter`] method on [`Vec`] (provided by the +/// [`IntoIterator`] trait). 
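+///
+/// Dropping an [`IntoIter`] drops all remaining elements and frees the backing buffer.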
+///
+/// # Examples
+///
+/// ```
+/// let v = kernel::kvec![0, 1, 2]?;
+/// let iter = v.into_iter();
+///
+/// # Ok::<(), Error>(())
+/// ```
+pub struct IntoIter<T, A: Allocator> {
+ ptr: *mut T,
+ buf: NonNull<T>,
+ len: usize,
+ layout: ArrayLayout<T>,
+ _p: PhantomData<A>,
+}
+
+impl<T, A> IntoIter<T, A>
+where
+ A: Allocator,
+{
+ fn into_raw_parts(self) -> (*mut T, NonNull<T>, usize, usize) {
+ let me = ManuallyDrop::new(self);
+ let ptr = me.ptr;
+ let buf = me.buf;
+ let len = me.len;
+ let cap = me.layout.len();
+ (ptr, buf, len, cap)
+ }
+
+ /// Same as `Iterator::collect` but specialized for `Vec`'s `IntoIter`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let v = kernel::kvec![1, 2, 3]?;
+ /// let mut it = v.into_iter();
+ ///
+ /// assert_eq!(it.next(), Some(1));
+ ///
+ /// let v = it.collect(GFP_KERNEL);
+ /// assert_eq!(v, [2, 3]);
+ ///
+ /// # Ok::<(), Error>(())
+ /// ```
+ ///
+ /// # Implementation details
+ ///
+ /// Currently, we can't implement `FromIterator`. There are a couple of issues with this trait
+ /// in the kernel, namely:
+ ///
+ /// - Rust's specialization feature is unstable. This prevents us from optimizing for the
+ /// special case where `I::IntoIter` equals `Vec`'s `IntoIter` type.
+ /// - We also can't use `I::IntoIter`'s type ID to work around this, since `FromIterator`
+ /// doesn't require this type to be `'static`.
+ /// - `FromIterator::from_iter` returns `Self` instead of `Result<Self, AllocError>`, hence
+ /// we can't properly handle allocation failures.
+ /// - Neither `Iterator::collect` nor `FromIterator::from_iter` can handle additional allocation
+ /// flags.
+ ///
+ /// Instead, provide `IntoIter::collect`, such that we can at least convert an `IntoIter` into
+ /// a `Vec` again.
+ ///
+ /// Note that `IntoIter::collect` re-uses the existing backing buffer; `flags` is only needed
+ /// in case the backing buffer is shrunk to the actual count of elements.
+ pub fn collect(self, flags: Flags) -> Vec<T, A> {
+ let old_layout = self.layout;
+ let (mut ptr, buf, len, mut cap) = self.into_raw_parts();
+ let has_advanced = ptr != buf.as_ptr();
+
+ if has_advanced {
+ // Copy the remaining elements to the beginning of the buffer.
+ //
+ // SAFETY:
+ // - `ptr` is valid for reads of `len * size_of::<T>()` bytes,
+ // - `buf.as_ptr()` is valid for writes of `len * size_of::<T>()` bytes,
+ // - `ptr` and `buf.as_ptr()` are not subject to aliasing restrictions relative to
+ // each other,
+ // - both `ptr` and `buf.as_ptr()` are properly aligned.
+ unsafe { ptr::copy(ptr, buf.as_ptr(), len) };
+ ptr = buf.as_ptr();
+
+ // SAFETY: `len` is guaranteed to be smaller than or equal to `self.layout.len()` by
+ // the type invariant.
+ let layout = unsafe { ArrayLayout::<T>::new_unchecked(len) };
+
+ // SAFETY: `buf` points to the start of the backing buffer and `len` is guaranteed by
+ // the type invariant to be smaller than or equal to `cap`. Depending on `realloc` this
+ // operation may shrink the buffer or leave it as it is.
+ ptr = match unsafe {
+ A::realloc(
+ Some(buf.cast()),
+ layout.into(),
+ old_layout.into(),
+ flags,
+ NumaNode::NO_NODE,
+ )
+ } {
+ // If we fail to shrink, which likely can't even happen, continue with the existing
+ // buffer.
+ Err(_) => ptr,
+ Ok(ptr) => {
+ cap = len;
+ ptr.as_ptr().cast()
+ }
+ };
+ }
+
+ // SAFETY: If the iterator has been advanced, the advanced elements have been copied to
+ // the beginning of the buffer and `len` has been adjusted accordingly.
+ // + // - `ptr` is guaranteed to point to the start of the backing buffer. + // - `cap` is either the original capacity or, after shrinking the buffer, equal to `len`. + // - `alloc` is guaranteed to be unchanged since `into_iter` has been called on the original + // `Vec`. + unsafe { Vec::from_raw_parts(ptr, len, cap) } + } +} + +impl<T, A> Iterator for IntoIter<T, A> +where + A: Allocator, +{ + type Item = T; + + /// # Examples + /// + /// ``` + /// let v = kernel::kvec![1, 2, 3]?; + /// let mut it = v.into_iter(); + /// + /// assert_eq!(it.next(), Some(1)); + /// assert_eq!(it.next(), Some(2)); + /// assert_eq!(it.next(), Some(3)); + /// assert_eq!(it.next(), None); + /// + /// # Ok::<(), Error>(()) + /// ``` + fn next(&mut self) -> Option<T> { + if self.len == 0 { + return None; + } + + let current = self.ptr; + + // SAFETY: We can't overflow; decreasing `self.len` by one every time we advance `self.ptr` + // by one guarantees that. + unsafe { self.ptr = self.ptr.add(1) }; + + self.len -= 1; + + // SAFETY: `current` is guaranteed to point at a valid element within the buffer. + Some(unsafe { current.read() }) + } + + /// # Examples + /// + /// ``` + /// let v: KVec<u32> = kernel::kvec![1, 2, 3]?; + /// let mut iter = v.into_iter(); + /// let size = iter.size_hint().0; + /// + /// iter.next(); + /// assert_eq!(iter.size_hint().0, size - 1); + /// + /// iter.next(); + /// assert_eq!(iter.size_hint().0, size - 2); + /// + /// iter.next(); + /// assert_eq!(iter.size_hint().0, size - 3); + /// + /// # Ok::<(), Error>(()) + /// ``` + fn size_hint(&self) -> (usize, Option<usize>) { + (self.len, Some(self.len)) + } +} + +impl<T, A> Drop for IntoIter<T, A> +where + A: Allocator, +{ + fn drop(&mut self) { + // SAFETY: `self.ptr` is guaranteed to be valid by the type invariant. + unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr, self.len)) }; + + // SAFETY: + // - `self.buf` was previously allocated with `A`. + // - `self.layout` matches the `ArrayLayout` of the preceding allocation. + unsafe { A::free(self.buf.cast(), self.layout.into()) }; + } +} + +impl<T, A> IntoIterator for Vec<T, A> +where + A: Allocator, +{ + type Item = T; + type IntoIter = IntoIter<T, A>; + + /// Consumes the `Vec<T, A>` and creates an `Iterator`, which moves each value out of the + /// vector (from start to end). + /// + /// # Examples + /// + /// ``` + /// let v = kernel::kvec![1, 2]?; + /// let mut v_iter = v.into_iter(); + /// + /// let first_element: Option<u32> = v_iter.next(); + /// + /// assert_eq!(first_element, Some(1)); + /// assert_eq!(v_iter.next(), Some(2)); + /// assert_eq!(v_iter.next(), None); + /// + /// # Ok::<(), Error>(()) + /// ``` + /// + /// ``` + /// let v = kernel::kvec![]; + /// let mut v_iter = v.into_iter(); + /// + /// let first_element: Option<u32> = v_iter.next(); + /// + /// assert_eq!(first_element, None); + /// + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + fn into_iter(self) -> Self::IntoIter { + let buf = self.ptr; + let layout = self.layout; + let (ptr, len, _) = self.into_raw_parts(); + + IntoIter { + ptr, + buf, + len, + layout, + _p: PhantomData::<A>, + } + } +} + +/// An iterator that owns all items in a vector, but does not own its allocation. +/// +/// # Invariants +/// +/// Every `&mut T` returned by the iterator references a `T` that the iterator may take ownership +/// of. 
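+///
+/// # Examples
+///
+/// A brief sketch of partially consuming the iterator:
+///
+/// ```
+/// let mut v = kernel::kvec![1, 2, 3]?;
+///
+/// let mut drain = v.drain_all();
+/// assert_eq!(drain.next(), Some(1));
+///
+/// // Dropping the iterator drops the remaining elements.
+/// drop(drain);
+/// assert!(v.is_empty());
+/// # Ok::<(), Error>(())
+/// ```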
+pub struct DrainAll<'vec, T> { + elements: slice::IterMut<'vec, T>, +} + +impl<'vec, T> Iterator for DrainAll<'vec, T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + let elem: *mut T = self.elements.next()?; + // SAFETY: By the type invariants, we may take ownership of this value. + Some(unsafe { elem.read() }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.elements.size_hint() + } +} + +impl<'vec, T> Drop for DrainAll<'vec, T> { + fn drop(&mut self) { + if core::mem::needs_drop::<T>() { + let iter = core::mem::take(&mut self.elements); + let ptr: *mut [T] = iter.into_slice(); + // SAFETY: By the type invariants, we own these values so we may destroy them. + unsafe { ptr::drop_in_place(ptr) }; + } + } +} + +#[macros::kunit_tests(rust_kvec)] +mod tests { + use super::*; + use crate::prelude::*; + + #[test] + fn test_kvec_retain() { + /// Verify correctness for one specific function. + #[expect(clippy::needless_range_loop)] + fn verify(c: &[bool]) { + let mut vec1: KVec<usize> = KVec::with_capacity(c.len(), GFP_KERNEL).unwrap(); + let mut vec2: KVec<usize> = KVec::with_capacity(c.len(), GFP_KERNEL).unwrap(); + + for i in 0..c.len() { + vec1.push_within_capacity(i).unwrap(); + if c[i] { + vec2.push_within_capacity(i).unwrap(); + } + } + + vec1.retain(|i| c[*i]); + + assert_eq!(vec1, vec2); + } + + /// Add one to a binary integer represented as a boolean array. + fn add(value: &mut [bool]) { + let mut carry = true; + for v in value { + let new_v = carry != *v; + carry = carry && *v; + *v = new_v; + } + } + + // This boolean array represents a function from index to boolean. We check that `retain` + // behaves correctly for all possible boolean arrays of every possible length less than + // ten. + let mut func = KVec::with_capacity(10, GFP_KERNEL).unwrap(); + for len in 0..10 { + for _ in 0u32..1u32 << len { + verify(&func); + add(&mut func); + } + func.push_within_capacity(false).unwrap(); + } + } +} diff --git a/rust/kernel/alloc/kvec/errors.rs b/rust/kernel/alloc/kvec/errors.rs new file mode 100644 index 000000000000..e7de5049ee47 --- /dev/null +++ b/rust/kernel/alloc/kvec/errors.rs @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Errors for the [`Vec`] type. + +use kernel::fmt; +use kernel::prelude::*; + +/// Error type for [`Vec::push_within_capacity`]. +pub struct PushError<T>(pub T); + +impl<T> fmt::Debug for PushError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not enough capacity") + } +} + +impl<T> From<PushError<T>> for Error { + fn from(_: PushError<T>) -> Error { + // Returning ENOMEM isn't appropriate because the system is not out of memory. The vector + // is just full and we are refusing to resize it. + EINVAL + } +} + +/// Error type for [`Vec::remove`]. +pub struct RemoveError; + +impl fmt::Debug for RemoveError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Index out of bounds") + } +} + +impl From<RemoveError> for Error { + fn from(_: RemoveError) -> Error { + EINVAL + } +} + +/// Error type for [`Vec::insert_within_capacity`]. +pub enum InsertError<T> { + /// The value could not be inserted because the index is out of bounds. + IndexOutOfBounds(T), + /// The value could not be inserted because the vector is out of capacity. 
+ OutOfCapacity(T), +} + +impl<T> fmt::Debug for InsertError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InsertError::IndexOutOfBounds(_) => write!(f, "Index out of bounds"), + InsertError::OutOfCapacity(_) => write!(f, "Not enough capacity"), + } + } +} + +impl<T> From<InsertError<T>> for Error { + fn from(_: InsertError<T>) -> Error { + EINVAL + } +} diff --git a/rust/kernel/alloc/layout.rs b/rust/kernel/alloc/layout.rs new file mode 100644 index 000000000000..9f8be72feb7a --- /dev/null +++ b/rust/kernel/alloc/layout.rs @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Memory layout. +//! +//! Custom layout types extending or improving [`Layout`]. + +use core::{alloc::Layout, marker::PhantomData}; + +/// Error when constructing an [`ArrayLayout`]. +pub struct LayoutError; + +/// A layout for an array `[T; n]`. +/// +/// # Invariants +/// +/// - `len * size_of::<T>() <= isize::MAX`. +pub struct ArrayLayout<T> { + len: usize, + _phantom: PhantomData<fn() -> T>, +} + +impl<T> Clone for ArrayLayout<T> { + fn clone(&self) -> Self { + *self + } +} +impl<T> Copy for ArrayLayout<T> {} + +const ISIZE_MAX: usize = isize::MAX as usize; + +impl<T> ArrayLayout<T> { + /// Creates a new layout for `[T; 0]`. + pub const fn empty() -> Self { + // INVARIANT: `0 * size_of::<T>() <= isize::MAX`. + Self { + len: 0, + _phantom: PhantomData, + } + } + + /// Creates a new layout for `[T; len]`. + /// + /// # Errors + /// + /// When `len * size_of::<T>()` overflows or when `len * size_of::<T>() > isize::MAX`. + /// + /// # Examples + /// + /// ``` + /// # use kernel::alloc::layout::{ArrayLayout, LayoutError}; + /// let layout = ArrayLayout::<i32>::new(15)?; + /// assert_eq!(layout.len(), 15); + /// + /// // Errors because `len * size_of::<T>()` overflows. + /// let layout = ArrayLayout::<i32>::new(isize::MAX as usize); + /// assert!(layout.is_err()); + /// + /// // Errors because `len * size_of::<i32>() > isize::MAX`, + /// // even though `len < isize::MAX`. + /// let layout = ArrayLayout::<i32>::new(isize::MAX as usize / 2); + /// assert!(layout.is_err()); + /// + /// # Ok::<(), Error>(()) + /// ``` + pub const fn new(len: usize) -> Result<Self, LayoutError> { + match len.checked_mul(core::mem::size_of::<T>()) { + Some(size) if size <= ISIZE_MAX => { + // INVARIANT: We checked above that `len * size_of::<T>() <= isize::MAX`. + Ok(Self { + len, + _phantom: PhantomData, + }) + } + _ => Err(LayoutError), + } + } + + /// Creates a new layout for `[T; len]`. + /// + /// # Safety + /// + /// `len` must be a value, for which `len * size_of::<T>() <= isize::MAX` is true. + pub const unsafe fn new_unchecked(len: usize) -> Self { + // INVARIANT: By the safety requirements of this function + // `len * size_of::<T>() <= isize::MAX`. + Self { + len, + _phantom: PhantomData, + } + } + + /// Returns the number of array elements represented by this layout. + pub const fn len(&self) -> usize { + self.len + } + + /// Returns `true` when no array elements are represented by this layout. + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the size of the [`ArrayLayout`] in bytes. + pub const fn size(&self) -> usize { + self.len() * core::mem::size_of::<T>() + } +} + +impl<T> From<ArrayLayout<T>> for Layout { + fn from(value: ArrayLayout<T>) -> Self { + let res = Layout::array::<T>(value.len); + // SAFETY: By the type invariant of `ArrayLayout` we have + // `len * size_of::<T>() <= isize::MAX` and thus the result must be `Ok`. 
+ unsafe { res.unwrap_unchecked() } + } +} diff --git a/rust/kernel/allocator.rs b/rust/kernel/allocator.rs deleted file mode 100644 index 397a3dd57a9b..000000000000 --- a/rust/kernel/allocator.rs +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0 - -//! Allocator support. - -use core::alloc::{GlobalAlloc, Layout}; -use core::ptr; - -use crate::bindings; - -struct KernelAllocator; - -unsafe impl GlobalAlloc for KernelAllocator { - unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - // `krealloc()` is used instead of `kmalloc()` because the latter is - // an inline function and cannot be bound to as a result. - unsafe { bindings::krealloc(ptr::null(), layout.size(), bindings::GFP_KERNEL) as *mut u8 } - } - - unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) { - unsafe { - bindings::kfree(ptr as *const core::ffi::c_void); - } - } -} - -#[global_allocator] -static ALLOCATOR: KernelAllocator = KernelAllocator; - -// `rustc` only generates these for some crate types. Even then, we would need -// to extract the object file that has them from the archive. For the moment, -// let's generate them ourselves instead. -// -// Note that `#[no_mangle]` implies exported too, nowadays. -#[no_mangle] -fn __rust_alloc(size: usize, _align: usize) -> *mut u8 { - unsafe { bindings::krealloc(core::ptr::null(), size, bindings::GFP_KERNEL) as *mut u8 } -} - -#[no_mangle] -fn __rust_dealloc(ptr: *mut u8, _size: usize, _align: usize) { - unsafe { bindings::kfree(ptr as *const core::ffi::c_void) }; -} - -#[no_mangle] -fn __rust_realloc(ptr: *mut u8, _old_size: usize, _align: usize, new_size: usize) -> *mut u8 { - unsafe { - bindings::krealloc( - ptr as *const core::ffi::c_void, - new_size, - bindings::GFP_KERNEL, - ) as *mut u8 - } -} - -#[no_mangle] -fn __rust_alloc_zeroed(size: usize, _align: usize) -> *mut u8 { - unsafe { - bindings::krealloc( - core::ptr::null(), - size, - bindings::GFP_KERNEL | bindings::__GFP_ZERO, - ) as *mut u8 - } -} diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs new file mode 100644 index 000000000000..56f3c180e8f6 --- /dev/null +++ b/rust/kernel/auxiliary.rs @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the auxiliary bus. +//! +//! C header: [`include/linux/auxiliary_bus.h`](srctree/include/linux/auxiliary_bus.h) + +use crate::{ + bindings, container_of, device, + device_id::{RawDeviceId, RawDeviceIdIndex}, + devres::Devres, + driver, + error::{from_result, to_result, Result}, + prelude::*, + types::Opaque, + ThisModule, +}; +use core::{ + marker::PhantomData, + mem::offset_of, + ptr::{addr_of_mut, NonNull}, +}; + +/// An adapter for the registration of auxiliary drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::auxiliary_driver; + + unsafe fn register( + adrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: It's safe to set the fields of `struct auxiliary_driver` on initialization. + unsafe { + (*adrv.get()).name = name.as_char_ptr(); + (*adrv.get()).probe = Some(Self::probe_callback); + (*adrv.get()).remove = Some(Self::remove_callback); + (*adrv.get()).id_table = T::ID_TABLE.as_ptr(); + } + + // SAFETY: `adrv` is guaranteed to be a valid `RegType`. 
+ to_result(unsafe {
+ bindings::__auxiliary_driver_register(adrv.get(), module.0, name.as_char_ptr())
+ })
+ }
+
+ unsafe fn unregister(adrv: &Opaque<Self::RegType>) {
+ // SAFETY: `adrv` is guaranteed to be a valid `RegType`.
+ unsafe { bindings::auxiliary_driver_unregister(adrv.get()) }
+ }
+}
+
+impl<T: Driver + 'static> Adapter<T> {
+ extern "C" fn probe_callback(
+ adev: *mut bindings::auxiliary_device,
+ id: *const bindings::auxiliary_device_id,
+ ) -> c_int {
+ // SAFETY: The auxiliary bus only ever calls the probe callback with a valid pointer to a
+ // `struct auxiliary_device`.
+ //
+ // INVARIANT: `adev` is valid for the duration of `probe_callback()`.
+ let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() };
+
+ // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct auxiliary_device_id`
+ // and does not add additional invariants, so it's safe to transmute.
+ let id = unsafe { &*id.cast::<DeviceId>() };
+ let info = T::ID_TABLE.info(id.index());
+
+ from_result(|| {
+ let data = T::probe(adev, info);
+
+ adev.as_ref().set_drvdata(data)?;
+ Ok(0)
+ })
+ }
+
+ extern "C" fn remove_callback(adev: *mut bindings::auxiliary_device) {
+ // SAFETY: The auxiliary bus only ever calls the remove callback with a valid pointer to a
+ // `struct auxiliary_device`.
+ //
+ // INVARIANT: `adev` is valid for the duration of `remove_callback()`.
+ let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() };
+
+ // SAFETY: `remove_callback` is only ever called after a successful call to
+ // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
+ // and stored a `Pin<KBox<T>>`.
+ drop(unsafe { adev.as_ref().drvdata_obtain::<T>() });
+ }
+}
+
+/// Declares a kernel module that exposes a single auxiliary driver.
+#[macro_export]
+macro_rules! module_auxiliary_driver {
+ ($($f:tt)*) => {
+ $crate::module_driver!(<T>, $crate::auxiliary::Adapter<T>, { $($f)* });
+ };
+}
+
+/// Abstraction for `bindings::auxiliary_device_id`.
+#[repr(transparent)]
+#[derive(Clone, Copy)]
+pub struct DeviceId(bindings::auxiliary_device_id);
+
+impl DeviceId {
+ /// Create a new [`DeviceId`] from a name.
+ pub const fn new(modname: &'static CStr, name: &'static CStr) -> Self {
+ let name = name.to_bytes_with_nul();
+ let modname = modname.to_bytes_with_nul();
+
+ // TODO: Replace with `bindings::auxiliary_device_id::default()` once stabilized for
+ // `const`.
+ //
+ // SAFETY: FFI type is valid to be zero-initialized.
+ let mut id: bindings::auxiliary_device_id = unsafe { core::mem::zeroed() };
+
+ let mut i = 0;
+ while i < modname.len() {
+ id.name[i] = modname[i];
+ i += 1;
+ }
+
+ // Reuse the space of the NULL terminator.
+ id.name[i - 1] = b'.';
+
+ let mut j = 0;
+ while j < name.len() {
+ id.name[i] = name[j];
+ i += 1;
+ j += 1;
+ }
+
+ Self(id)
+ }
+}
+
+// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `auxiliary_device_id` and does not add
+// additional invariants, so it's safe to transmute to `RawType`.
+unsafe impl RawDeviceId for DeviceId {
+ type RawType = bindings::auxiliary_device_id;
+}
+
+// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field.
+unsafe impl RawDeviceIdIndex for DeviceId {
+ const DRIVER_DATA_OFFSET: usize =
+ core::mem::offset_of!(bindings::auxiliary_device_id, driver_data);
+
+ fn index(&self) -> usize {
+ self.0.driver_data
+ }
+}
+
+/// IdTable type for auxiliary drivers.
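+///
+/// Tables of this type are typically constructed once via the `auxiliary_device_table!` macro
+/// below and then referenced from [`Driver::ID_TABLE`].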
+pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>;
+
+/// Create an auxiliary `IdTable` with an "alias" for modpost.
+#[macro_export]
+macro_rules! auxiliary_device_table {
+ ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => {
+ const $table_name: $crate::device_id::IdArray<
+ $crate::auxiliary::DeviceId,
+ $id_info_type,
+ { $table_data.len() },
+ > = $crate::device_id::IdArray::new($table_data);
+
+ $crate::module_device_table!("auxiliary", $module_table_name, $table_name);
+ };
+}
+
+/// The auxiliary driver trait.
+///
+/// Drivers must implement this trait in order to get an auxiliary driver registered.
+pub trait Driver {
+ /// The type holding information about each device id supported by the driver.
+ ///
+ /// TODO: Use associated_type_defaults once stabilized:
+ ///
+ /// type IdInfo: 'static = ();
+ type IdInfo: 'static;
+
+ /// The table of device ids supported by the driver.
+ const ID_TABLE: IdTable<Self::IdInfo>;
+
+ /// Auxiliary driver probe.
+ ///
+ /// Called when an auxiliary device is matched to a corresponding driver.
+ fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>;
+}
+
+/// The auxiliary device representation.
+///
+/// This structure represents the Rust abstraction for a C `struct auxiliary_device`. The
+/// implementation abstracts the usage of an already existing C `struct auxiliary_device` within
+/// Rust code that we get passed from the C side.
+///
+/// # Invariants
+///
+/// A [`Device`] instance represents a valid `struct auxiliary_device` created by the C portion of
+/// the kernel.
+#[repr(transparent)]
+pub struct Device<Ctx: device::DeviceContext = device::Normal>(
+ Opaque<bindings::auxiliary_device>,
+ PhantomData<Ctx>,
+);
+
+impl<Ctx: device::DeviceContext> Device<Ctx> {
+ fn as_raw(&self) -> *mut bindings::auxiliary_device {
+ self.0.get()
+ }
+
+ /// Returns the auxiliary device's id.
+ pub fn id(&self) -> u32 {
+ // SAFETY: By the type invariant `self.as_raw()` is a valid pointer to a
+ // `struct auxiliary_device`.
+ unsafe { (*self.as_raw()).id }
+ }
+}
+
+impl Device<device::Bound> {
+ /// Returns a bound reference to the parent [`device::Device`].
+ pub fn parent(&self) -> &device::Device<device::Bound> {
+ let parent = (**self).parent();
+
+ // SAFETY: A bound auxiliary device always has a bound parent device.
+ unsafe { parent.as_bound() }
+ }
+}
+
+impl Device {
+ /// Returns a reference to the parent [`device::Device`].
+ pub fn parent(&self) -> &device::Device {
+ // SAFETY: A `struct auxiliary_device` always has a parent.
+ unsafe { self.as_ref().parent().unwrap_unchecked() }
+ }
+
+ extern "C" fn release(dev: *mut bindings::device) {
+ // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
+ // embedded in `struct auxiliary_device`.
+ let adev = unsafe { container_of!(dev, bindings::auxiliary_device, dev) };
+
+ // SAFETY: `adev` points to the memory that has been allocated in `Registration::new`, via
+ // `KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)`.
+ let _ = unsafe { KBox::<Opaque<bindings::auxiliary_device>>::from_raw(adev.cast()) };
+ }
+}
+
+// SAFETY: `auxiliary::Device` is a transparent wrapper of `struct auxiliary_device`.
+// The offset is guaranteed to point to a valid device field inside `auxiliary::Device`.
+unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> {
+ const OFFSET: usize = offset_of!(bindings::auxiliary_device, dev);
+}
+
+// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic
+// argument.
+kernel::impl_device_context_deref!(unsafe { Device });
+kernel::impl_device_context_into_aref!(Device);
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl crate::sync::aref::AlwaysRefCounted for Device {
+ fn inc_ref(&self) {
+ // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+ unsafe { bindings::get_device(self.as_ref().as_raw()) };
+ }
+
+ unsafe fn dec_ref(obj: NonNull<Self>) {
+ // CAST: `Self` is a transparent wrapper of `bindings::auxiliary_device`.
+ let adev: *mut bindings::auxiliary_device = obj.cast().as_ptr();
+
+ // SAFETY: By the type invariant of `Self`, `adev` is a pointer to a valid
+ // `struct auxiliary_device`.
+ let dev = unsafe { addr_of_mut!((*adev).dev) };
+
+ // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+ unsafe { bindings::put_device(dev) }
+ }
+}
+
+impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> {
+ fn as_ref(&self) -> &device::Device<Ctx> {
+ // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+ // `struct auxiliary_device`.
+ let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
+
+ // SAFETY: `dev` points to a valid `struct device`.
+ unsafe { device::Device::from_raw(dev) }
+ }
+}
+
+// SAFETY: A `Device` is always reference-counted and can be released from any thread.
+unsafe impl Send for Device {}
+
+// SAFETY: `Device` can be shared among threads because all methods of `Device`
+// (i.e. `Device<Normal>`) are thread safe.
+unsafe impl Sync for Device {}
+
+/// The registration of an auxiliary device.
+///
+/// This type represents the registration of a [`struct auxiliary_device`]. When its parent device
+/// is unbound, the corresponding auxiliary device will be unregistered from the system.
+///
+/// # Invariants
+///
+/// `self.0` always holds a valid pointer to an initialized and registered
+/// [`struct auxiliary_device`].
+pub struct Registration(NonNull<bindings::auxiliary_device>);
+
+impl Registration {
+ /// Create and register a new auxiliary device.
+ pub fn new<'a>(
+ parent: &'a device::Device<device::Bound>,
+ name: &'a CStr,
+ id: u32,
+ modname: &'a CStr,
+ ) -> impl PinInit<Devres<Self>, Error> + 'a {
+ pin_init::pin_init_scope(move || {
+ let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?;
+ let adev = boxed.get();
+
+ // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization.
+ unsafe {
+ (*adev).dev.parent = parent.as_raw();
+ (*adev).dev.release = Some(Device::release);
+ (*adev).name = name.as_char_ptr();
+ (*adev).id = id;
+ }
+
+ // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`,
+ // which has not been initialized yet.
+ unsafe { bindings::auxiliary_device_init(adev) };
+
+ // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be
+ // freed by `Device::release` when the last reference to the `struct auxiliary_device`
+ // is dropped.
+ let _ = KBox::into_raw(boxed);
+
+ // SAFETY:
+ // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which
+ // has been initialized,
+ // - `modname.as_char_ptr()` is a NULL terminated string.
+ let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; + if ret != 0 { + // SAFETY: `adev` is guaranteed to be a valid pointer to a + // `struct auxiliary_device`, which has been initialized. + unsafe { bindings::auxiliary_device_uninit(adev) }; + + return Err(Error::from_errno(ret)); + } + + // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is + // called, which happens in `Self::drop()`. + Ok(Devres::new( + parent, + // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated + // successfully. + Self(unsafe { NonNull::new_unchecked(adev) }), + )) + }) + } +} + +impl Drop for Registration { + fn drop(&mut self) { + // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // `struct auxiliary_device`. + unsafe { bindings::auxiliary_device_delete(self.0.as_ptr()) }; + + // This drops the reference we acquired through `auxiliary_device_init()`. + // + // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // `struct auxiliary_device`. + unsafe { bindings::auxiliary_device_uninit(self.0.as_ptr()) }; + } +} + +// SAFETY: A `Registration` of a `struct auxiliary_device` can be released from any thread. +unsafe impl Send for Registration {} + +// SAFETY: `Registration` does not expose any methods or fields that need synchronization. +unsafe impl Sync for Registration {} diff --git a/rust/kernel/bitmap.rs b/rust/kernel/bitmap.rs new file mode 100644 index 000000000000..83d7dea99137 --- /dev/null +++ b/rust/kernel/bitmap.rs @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2025 Google LLC. + +//! Rust API for bitmap. +//! +//! C headers: [`include/linux/bitmap.h`](srctree/include/linux/bitmap.h). + +use crate::alloc::{AllocError, Flags}; +use crate::bindings; +#[cfg(not(CONFIG_RUST_BITMAP_HARDENED))] +use crate::pr_err; +use core::ptr::NonNull; + +/// Represents a C bitmap. Wraps underlying C bitmap API. +/// +/// # Invariants +/// +/// Must reference a `[c_ulong]` long enough to fit `data.len()` bits. +#[cfg_attr(CONFIG_64BIT, repr(align(8)))] +#[cfg_attr(not(CONFIG_64BIT), repr(align(4)))] +pub struct Bitmap { + data: [()], +} + +impl Bitmap { + /// Borrows a C bitmap. + /// + /// # Safety + /// + /// * `ptr` holds a non-null address of an initialized array of `unsigned long` + /// that is large enough to hold `nbits` bits. + /// * the array must not be freed for the lifetime of this [`Bitmap`] + /// * concurrent access only happens through atomic operations + pub unsafe fn from_raw<'a>(ptr: *const usize, nbits: usize) -> &'a Bitmap { + let data: *const [()] = core::ptr::slice_from_raw_parts(ptr.cast(), nbits); + // INVARIANT: `data` references an initialized array that can hold `nbits` bits. + // SAFETY: + // The caller guarantees that `data` (derived from `ptr` and `nbits`) + // points to a valid, initialized, and appropriately sized memory region + // that will not be freed for the lifetime 'a. + // We are casting `*const [()]` to `*const Bitmap`. The `Bitmap` + // struct is a ZST with a `data: [()]` field. This means its layout + // is compatible with a slice of `()`, and effectively it's a "thin pointer" + // (its size is 0 and alignment is 1). The `slice_from_raw_parts` + // function correctly encodes the length (number of bits, not elements) + // into the metadata of the fat pointer. Therefore, dereferencing this + // pointer as `&Bitmap` is safe given the caller's guarantees. 
+ unsafe { &*(data as *const Bitmap) } + } + + /// Borrows a C bitmap exclusively. + /// + /// # Safety + /// + /// * `ptr` holds a non-null address of an initialized array of `unsigned long` + /// that is large enough to hold `nbits` bits. + /// * the array must not be freed for the lifetime of this [`Bitmap`] + /// * no concurrent access may happen. + pub unsafe fn from_raw_mut<'a>(ptr: *mut usize, nbits: usize) -> &'a mut Bitmap { + let data: *mut [()] = core::ptr::slice_from_raw_parts_mut(ptr.cast(), nbits); + // INVARIANT: `data` references an initialized array that can hold `nbits` bits. + // SAFETY: + // The caller guarantees that `data` (derived from `ptr` and `nbits`) + // points to a valid, initialized, and appropriately sized memory region + // that will not be freed for the lifetime 'a. + // Furthermore, the caller guarantees no concurrent access will happen, + // which upholds the exclusivity requirement for a mutable reference. + // Similar to `from_raw`, casting `*mut [()]` to `*mut Bitmap` is + // safe because `Bitmap` is a ZST with a `data: [()]` field, + // making its layout compatible with a slice of `()`. + unsafe { &mut *(data as *mut Bitmap) } + } + + /// Returns a raw pointer to the backing [`Bitmap`]. + pub fn as_ptr(&self) -> *const usize { + core::ptr::from_ref::<Bitmap>(self).cast::<usize>() + } + + /// Returns a mutable raw pointer to the backing [`Bitmap`]. + pub fn as_mut_ptr(&mut self) -> *mut usize { + core::ptr::from_mut::<Bitmap>(self).cast::<usize>() + } + + /// Returns length of this [`Bitmap`]. + #[expect(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.data.len() + } +} + +/// Holds either a pointer to array of `unsigned long` or a small bitmap. +#[repr(C)] +union BitmapRepr { + bitmap: usize, + ptr: NonNull<usize>, +} + +macro_rules! bitmap_assert { + ($cond:expr, $($arg:tt)+) => { + #[cfg(CONFIG_RUST_BITMAP_HARDENED)] + assert!($cond, $($arg)*); + } +} + +macro_rules! bitmap_assert_return { + ($cond:expr, $($arg:tt)+) => { + #[cfg(CONFIG_RUST_BITMAP_HARDENED)] + assert!($cond, $($arg)*); + + #[cfg(not(CONFIG_RUST_BITMAP_HARDENED))] + if !($cond) { + pr_err!($($arg)*); + return + } + } +} + +/// Represents an owned bitmap. +/// +/// Wraps underlying C bitmap API. See [`Bitmap`] for available +/// methods. +/// +/// # Examples +/// +/// Basic usage +/// +/// ``` +/// use kernel::alloc::flags::GFP_KERNEL; +/// use kernel::bitmap::BitmapVec; +/// +/// let mut b = BitmapVec::new(16, GFP_KERNEL)?; +/// +/// assert_eq!(16, b.len()); +/// for i in 0..16 { +/// if i % 4 == 0 { +/// b.set_bit(i); +/// } +/// } +/// assert_eq!(Some(0), b.next_bit(0)); +/// assert_eq!(Some(1), b.next_zero_bit(0)); +/// assert_eq!(Some(4), b.next_bit(1)); +/// assert_eq!(Some(5), b.next_zero_bit(4)); +/// assert_eq!(Some(12), b.last_bit()); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Invariants +/// +/// * `nbits` is `<= MAX_LEN`. +/// * if `nbits <= MAX_INLINE_LEN`, then `repr` is a `usize`. +/// * otherwise, `repr` holds a non-null pointer to an initialized +/// array of `unsigned long` that is large enough to hold `nbits` bits. +pub struct BitmapVec { + /// Representation of bitmap. + repr: BitmapRepr, + /// Length of this bitmap. Must be `<= MAX_LEN`. + nbits: usize, +} + +impl core::ops::Deref for BitmapVec { + type Target = Bitmap; + + fn deref(&self) -> &Bitmap { + let ptr = if self.nbits <= BitmapVec::MAX_INLINE_LEN { + // SAFETY: Bitmap is represented inline. 
+ #[allow(unused_unsafe, reason = "Safe since Rust 1.92.0")]
+ unsafe {
+ core::ptr::addr_of!(self.repr.bitmap)
+ }
+ } else {
+ // SAFETY: Bitmap is represented as array of `unsigned long`.
+ unsafe { self.repr.ptr.as_ptr() }
+ };
+
+ // SAFETY: We got the right pointer and invariants of [`Bitmap`] hold.
+ // An inline bitmap is treated like an array with single element.
+ unsafe { Bitmap::from_raw(ptr, self.nbits) }
+ }
+}
+
+impl core::ops::DerefMut for BitmapVec {
+ fn deref_mut(&mut self) -> &mut Bitmap {
+ let ptr = if self.nbits <= BitmapVec::MAX_INLINE_LEN {
+ // SAFETY: Bitmap is represented inline.
+ #[allow(unused_unsafe, reason = "Safe since Rust 1.92.0")]
+ unsafe {
+ core::ptr::addr_of_mut!(self.repr.bitmap)
+ }
+ } else {
+ // SAFETY: Bitmap is represented as array of `unsigned long`.
+ unsafe { self.repr.ptr.as_ptr() }
+ };
+
+ // SAFETY: We got the right pointer and invariants of [`BitmapVec`] hold.
+ // An inline bitmap is treated like an array with single element.
+ unsafe { Bitmap::from_raw_mut(ptr, self.nbits) }
+ }
+}
+
+/// Enable ownership transfer to other threads.
+///
+/// SAFETY: We own the underlying bitmap representation.
+unsafe impl Send for BitmapVec {}
+
+/// Enable unsynchronized concurrent access to [`BitmapVec`] through shared references.
+///
+/// SAFETY: `deref()` will return a reference to a [`Bitmap`]. Its methods
+/// that take immutable references are either atomic or read-only.
+unsafe impl Sync for BitmapVec {}
+
+impl Drop for BitmapVec {
+ fn drop(&mut self) {
+ if self.nbits <= BitmapVec::MAX_INLINE_LEN {
+ return;
+ }
+ // SAFETY: `self.repr.ptr` was returned by the C `bitmap_zalloc`.
+ //
+ // INVARIANT: there is no other use of `self.repr.ptr` after this
+ // call and the value is being dropped so the broken invariant is
+ // not observable on function exit.
+ unsafe { bindings::bitmap_free(self.repr.ptr.as_ptr()) };
+ }
+}
+
+impl BitmapVec {
+ /// The maximum possible length of a `BitmapVec`.
+ pub const MAX_LEN: usize = i32::MAX as usize;
+
+ /// The maximum length that uses the inline representation.
+ pub const MAX_INLINE_LEN: usize = usize::BITS as usize;
+
+ /// Constructs the longest possible inline [`BitmapVec`].
+ #[inline]
+ pub fn new_inline() -> Self {
+ // INVARIANT: `nbits <= MAX_INLINE_LEN`, so an inline bitmap is the right repr.
+ BitmapVec {
+ repr: BitmapRepr { bitmap: 0 },
+ nbits: BitmapVec::MAX_INLINE_LEN,
+ }
+ }
+
+ /// Constructs a new [`BitmapVec`].
+ ///
+ /// Fails with [`AllocError`] when the [`BitmapVec`] could not be allocated. This
+ /// includes the case when `nbits` is greater than `MAX_LEN`.
+ #[inline]
+ pub fn new(nbits: usize, flags: Flags) -> Result<Self, AllocError> {
+ if nbits <= BitmapVec::MAX_INLINE_LEN {
+ return Ok(BitmapVec {
+ repr: BitmapRepr { bitmap: 0 },
+ nbits,
+ });
+ }
+ if nbits > Self::MAX_LEN {
+ return Err(AllocError);
+ }
+ let nbits_u32 = u32::try_from(nbits).unwrap();
+ // SAFETY: `MAX_INLINE_LEN < nbits` and `nbits <= MAX_LEN`.
+ let ptr = unsafe { bindings::bitmap_zalloc(nbits_u32, flags.as_raw()) };
+ let ptr = NonNull::new(ptr).ok_or(AllocError)?;
+ // INVARIANT: `ptr` returned by C `bitmap_zalloc` and `nbits` checked.
+ Ok(BitmapVec {
+ repr: BitmapRepr { ptr },
+ nbits,
+ })
+ }
+
+ /// Returns length of this [`BitmapVec`].
+ #[allow(clippy::len_without_is_empty)]
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.nbits
+ }
+
+ /// Fills this `Bitmap` with random bits.
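+ ///
+ /// Only available when `CONFIG_FIND_BIT_BENCHMARK_RUST` is enabled.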
+ #[cfg(CONFIG_FIND_BIT_BENCHMARK_RUST)] + pub fn fill_random(&mut self) { + // SAFETY: `self.as_mut_ptr` points to either an array of the + // appropriate length or one usize. + unsafe { + bindings::get_random_bytes( + self.as_mut_ptr().cast::<ffi::c_void>(), + usize::div_ceil(self.nbits, bindings::BITS_PER_LONG as usize) + * bindings::BITS_PER_LONG as usize + / 8, + ); + } + } +} + +impl Bitmap { + /// Set bit with index `index`. + /// + /// ATTENTION: `set_bit` is non-atomic, which differs from the naming + /// convention in C code. The corresponding C function is `__set_bit`. + /// + /// If CONFIG_RUST_BITMAP_HARDENED is not enabled and `index` is greater than + /// or equal to `self.nbits`, does nothing. + /// + /// # Panics + /// + /// Panics if CONFIG_RUST_BITMAP_HARDENED is enabled and `index` is greater than + /// or equal to `self.nbits`. + #[inline] + pub fn set_bit(&mut self, index: usize) { + bitmap_assert_return!( + index < self.len(), + "Bit `index` must be < {}, was {}", + self.len(), + index + ); + // SAFETY: Bit `index` is within bounds. + unsafe { bindings::__set_bit(index, self.as_mut_ptr()) }; + } + + /// Set bit with index `index`, atomically. + /// + /// This is a relaxed atomic operation (no implied memory barriers). + /// + /// ATTENTION: The naming convention differs from C, where the corresponding + /// function is called `set_bit`. + /// + /// If CONFIG_RUST_BITMAP_HARDENED is not enabled and `index` is greater than + /// or equal to `self.len()`, does nothing. + /// + /// # Panics + /// + /// Panics if CONFIG_RUST_BITMAP_HARDENED is enabled and `index` is greater than + /// or equal to `self.len()`. + #[inline] + pub fn set_bit_atomic(&self, index: usize) { + bitmap_assert_return!( + index < self.len(), + "Bit `index` must be < {}, was {}", + self.len(), + index + ); + // SAFETY: `index` is within bounds and the caller has ensured that + // there is no mix of non-atomic and atomic operations. + unsafe { bindings::set_bit(index, self.as_ptr().cast_mut()) }; + } + + /// Clear `index` bit. + /// + /// ATTENTION: `clear_bit` is non-atomic, which differs from the naming + /// convention in C code. The corresponding C function is `__clear_bit`. + /// + /// If CONFIG_RUST_BITMAP_HARDENED is not enabled and `index` is greater than + /// or equal to `self.len()`, does nothing. + /// + /// # Panics + /// + /// Panics if CONFIG_RUST_BITMAP_HARDENED is enabled and `index` is greater than + /// or equal to `self.len()`. + #[inline] + pub fn clear_bit(&mut self, index: usize) { + bitmap_assert_return!( + index < self.len(), + "Bit `index` must be < {}, was {}", + self.len(), + index + ); + // SAFETY: `index` is within bounds. + unsafe { bindings::__clear_bit(index, self.as_mut_ptr()) }; + } + + /// Clear `index` bit, atomically. + /// + /// This is a relaxed atomic operation (no implied memory barriers). + /// + /// ATTENTION: The naming convention differs from C, where the corresponding + /// function is called `clear_bit`. + /// + /// If CONFIG_RUST_BITMAP_HARDENED is not enabled and `index` is greater than + /// or equal to `self.len()`, does nothing. + /// + /// # Panics + /// + /// Panics if CONFIG_RUST_BITMAP_HARDENED is enabled and `index` is greater than + /// or equal to `self.len()`. 
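+ ///
+ /// # Examples
+ ///
+ /// A short sketch of pairing the atomic set and clear operations (the
+ /// index used here is illustrative only):
+ ///
+ /// ```
+ /// use kernel::alloc::{AllocError, flags::GFP_KERNEL};
+ /// use kernel::bitmap::BitmapVec;
+ ///
+ /// let b = BitmapVec::new(16, GFP_KERNEL)?;
+ ///
+ /// // Atomic operations take `&self`, so no `mut` binding is required.
+ /// b.set_bit_atomic(5);
+ /// assert_eq!(Some(5), b.next_bit(0));
+ ///
+ /// b.clear_bit_atomic(5);
+ /// assert_eq!(None, b.next_bit(0));
+ /// # Ok::<(), AllocError>(())
+ /// ```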
+ #[inline]
+ pub fn clear_bit_atomic(&self, index: usize) {
+ bitmap_assert_return!(
+ index < self.len(),
+ "Bit `index` must be < {}, was {}",
+ self.len(),
+ index
+ );
+ // SAFETY: `index` is within bounds and the caller has ensured that
+ // there is no mix of non-atomic and atomic operations.
+ unsafe { bindings::clear_bit(index, self.as_ptr().cast_mut()) };
+ }
+
+ /// Copy `src` into this [`Bitmap`] and set any remaining bits to zero.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use kernel::alloc::{AllocError, flags::GFP_KERNEL};
+ /// use kernel::bitmap::BitmapVec;
+ ///
+ /// let mut long_bitmap = BitmapVec::new(256, GFP_KERNEL)?;
+ ///
+ /// assert_eq!(None, long_bitmap.last_bit());
+ ///
+ /// let mut short_bitmap = BitmapVec::new(16, GFP_KERNEL)?;
+ ///
+ /// short_bitmap.set_bit(7);
+ /// long_bitmap.copy_and_extend(&short_bitmap);
+ /// assert_eq!(Some(7), long_bitmap.last_bit());
+ ///
+ /// # Ok::<(), AllocError>(())
+ /// ```
+ #[inline]
+ pub fn copy_and_extend(&mut self, src: &Bitmap) {
+ let len = core::cmp::min(src.len(), self.len());
+ // SAFETY: access to `self` and `src` is within bounds.
+ unsafe {
+ bindings::bitmap_copy_and_extend(
+ self.as_mut_ptr(),
+ src.as_ptr(),
+ len as u32,
+ self.len() as u32,
+ )
+ };
+ }
+
+ /// Finds last set bit.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use kernel::alloc::{AllocError, flags::GFP_KERNEL};
+ /// use kernel::bitmap::BitmapVec;
+ ///
+ /// let bitmap = BitmapVec::new(64, GFP_KERNEL)?;
+ ///
+ /// match bitmap.last_bit() {
+ /// Some(idx) => {
+ /// pr_info!("The last bit has index {idx}.\n");
+ /// }
+ /// None => {
+ /// pr_info!("All bits in this bitmap are 0.\n");
+ /// }
+ /// }
+ /// # Ok::<(), AllocError>(())
+ /// ```
+ #[inline]
+ pub fn last_bit(&self) -> Option<usize> {
+ // SAFETY: `_find_last_bit` access is within bounds due to invariant.
+ let index = unsafe { bindings::_find_last_bit(self.as_ptr(), self.len()) };
+ if index >= self.len() {
+ None
+ } else {
+ Some(index)
+ }
+ }
+
+ /// Finds next set bit, starting from `start`.
+ ///
+ /// Returns `None` if `start` is greater than or equal to `self.len()`.
+ #[inline]
+ pub fn next_bit(&self, start: usize) -> Option<usize> {
+ bitmap_assert!(
+ start < self.len(),
+ "`start` must be < {}, was {}",
+ self.len(),
+ start
+ );
+ // SAFETY: `_find_next_bit` tolerates out-of-bounds arguments and returns a
+ // value larger than or equal to `self.len()` in that case.
+ let index = unsafe { bindings::_find_next_bit(self.as_ptr(), self.len(), start) };
+ if index >= self.len() {
+ None
+ } else {
+ Some(index)
+ }
+ }
+
+ /// Finds next zero bit, starting from `start`.
+ ///
+ /// Returns `None` if `start` is greater than or equal to `self.len()`.
+ #[inline]
+ pub fn next_zero_bit(&self, start: usize) -> Option<usize> {
+ bitmap_assert!(
+ start < self.len(),
+ "`start` must be < {}, was {}",
+ self.len(),
+ start
+ );
+ // SAFETY: `_find_next_zero_bit` tolerates out-of-bounds arguments and returns a
+ // value larger than or equal to `self.len()` in that case.
+ let index = unsafe { bindings::_find_next_zero_bit(self.as_ptr(), self.len(), start) };
+ if index >= self.len() {
+ None
+ } else {
+ Some(index)
+ }
+ }
+}
+
+use macros::kunit_tests;
+
+#[kunit_tests(rust_kernel_bitmap)]
+mod tests {
+ use super::*;
+ use kernel::alloc::flags::GFP_KERNEL;
+
+ #[test]
+ fn bitmap_borrow() {
+ let fake_bitmap: [usize; 2] = [0, 0];
+ let fake_bitmap_len = 2 * usize::BITS as usize;
+ // SAFETY: `fake_bitmap` is an array of expected length.
+ let b = unsafe { Bitmap::from_raw(fake_bitmap.as_ptr(), fake_bitmap_len) };
+ assert_eq!(fake_bitmap_len, b.len());
+ assert_eq!(None, b.next_bit(0));
+ }
+
+ #[test]
+ fn bitmap_copy() {
+ let fake_bitmap: usize = 0xFF;
+ // SAFETY: `fake_bitmap` can be used as a one-element array of expected length.
+ let b = unsafe { Bitmap::from_raw(core::ptr::addr_of!(fake_bitmap), 8) };
+ assert_eq!(8, b.len());
+ assert_eq!(None, b.next_zero_bit(0));
+ }
+
+ #[test]
+ fn bitmap_vec_new() -> Result<(), AllocError> {
+ let b = BitmapVec::new(0, GFP_KERNEL)?;
+ assert_eq!(0, b.len());
+
+ let b = BitmapVec::new(3, GFP_KERNEL)?;
+ assert_eq!(3, b.len());
+
+ let b = BitmapVec::new(1024, GFP_KERNEL)?;
+ assert_eq!(1024, b.len());
+
+ // Requesting too large values results in [`AllocError`].
+ let res = BitmapVec::new(1 << 31, GFP_KERNEL);
+ assert!(res.is_err());
+ Ok(())
+ }
+
+ #[test]
+ fn bitmap_set_clear_find() -> Result<(), AllocError> {
+ let mut b = BitmapVec::new(128, GFP_KERNEL)?;
+
+ // Zero-initialized
+ assert_eq!(None, b.next_bit(0));
+ assert_eq!(Some(0), b.next_zero_bit(0));
+ assert_eq!(None, b.last_bit());
+
+ b.set_bit(17);
+
+ assert_eq!(Some(17), b.next_bit(0));
+ assert_eq!(Some(17), b.next_bit(17));
+ assert_eq!(None, b.next_bit(18));
+ assert_eq!(Some(17), b.last_bit());
+
+ b.set_bit(107);
+
+ assert_eq!(Some(17), b.next_bit(0));
+ assert_eq!(Some(17), b.next_bit(17));
+ assert_eq!(Some(107), b.next_bit(18));
+ assert_eq!(Some(107), b.last_bit());
+
+ b.clear_bit(17);
+
+ assert_eq!(Some(107), b.next_bit(0));
+ assert_eq!(Some(107), b.last_bit());
+ Ok(())
+ }
+
+ #[test]
+ fn owned_bitmap_out_of_bounds() -> Result<(), AllocError> {
+ // TODO: Kunit #[test]s do not support `cfg` yet,
+ // so we add it here in the body.
+ #[cfg(not(CONFIG_RUST_BITMAP_HARDENED))]
+ {
+ let mut b = BitmapVec::new(128, GFP_KERNEL)?;
+ b.set_bit(2048);
+ b.set_bit_atomic(2048);
+ b.clear_bit(2048);
+ b.clear_bit_atomic(2048);
+ assert_eq!(None, b.next_bit(2048));
+ assert_eq!(None, b.next_zero_bit(2048));
+ assert_eq!(None, b.last_bit());
+ }
+ Ok(())
+ }
+
+ // TODO: uncomment once kunit supports `#[should_panic]` and `cfg`.
+ // #[cfg(CONFIG_RUST_BITMAP_HARDENED)]
+ // #[test]
+ // #[should_panic]
+ // fn owned_bitmap_out_of_bounds() -> Result<(), AllocError> {
+ // let mut b = BitmapVec::new(128, GFP_KERNEL)?;
+ //
+ // b.set_bit(2048);
+ // }
+
+ #[test]
+ fn bitmap_copy_and_extend() -> Result<(), AllocError> {
+ let mut long_bitmap = BitmapVec::new(256, GFP_KERNEL)?;
+
+ long_bitmap.set_bit(3);
+ long_bitmap.set_bit(200);
+
+ let mut short_bitmap = BitmapVec::new(32, GFP_KERNEL)?;
+
+ short_bitmap.set_bit(17);
+
+ long_bitmap.copy_and_extend(&short_bitmap);
+
+ // Previous bits have been cleared.
+ assert_eq!(Some(17), long_bitmap.next_bit(0));
+ assert_eq!(Some(17), long_bitmap.last_bit());
+ Ok(())
+ }
+}
diff --git a/rust/kernel/bits.rs b/rust/kernel/bits.rs
new file mode 100644
index 000000000000..553d50265883
--- /dev/null
+++ b/rust/kernel/bits.rs
@@ -0,0 +1,203 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Bit manipulation macros.
+//!
+//! C header: [`include/linux/bits.h`](srctree/include/linux/bits.h)
+
+use crate::prelude::*;
+use core::ops::RangeInclusive;
+use macros::paste;
+
+macro_rules! impl_bit_fn {
+ (
+ $ty:ty
+ ) => {
+ paste! {
+ /// Computes `1 << n` if `n` is in bounds, i.e.: if `n` is smaller than
+ /// the maximum number of bits supported by the type.
+ ///
+ /// Returns [`None`] otherwise.
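+ ///
+ /// For example, `checked_bit_u8(7)` is `Some(128)`, while
+ /// `checked_bit_u8(8)` is [`None`], since `u8` has only 8 bits.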
+ #[inline] + pub fn [<checked_bit_ $ty>](n: u32) -> Option<$ty> { + (1 as $ty).checked_shl(n) + } + + /// Computes `1 << n` by performing a compile-time assertion that `n` is + /// in bounds. + /// + /// This version is the default and should be used if `n` is known at + /// compile time. + #[inline] + pub const fn [<bit_ $ty>](n: u32) -> $ty { + build_assert!(n < <$ty>::BITS); + (1 as $ty) << n + } + } + }; +} + +impl_bit_fn!(u64); +impl_bit_fn!(u32); +impl_bit_fn!(u16); +impl_bit_fn!(u8); + +macro_rules! impl_genmask_fn { + ( + $ty:ty, + $(#[$genmask_checked_ex:meta])*, + $(#[$genmask_ex:meta])* + ) => { + paste! { + /// Creates a contiguous bitmask for the given range by validating + /// the range at runtime. + /// + /// Returns [`None`] if the range is invalid, i.e.: if the start is + /// greater than the end or if the range is outside of the + /// representable range for the type. + $(#[$genmask_checked_ex])* + #[inline] + pub fn [<genmask_checked_ $ty>](range: RangeInclusive<u32>) -> Option<$ty> { + let start = *range.start(); + let end = *range.end(); + + if start > end { + return None; + } + + let high = [<checked_bit_ $ty>](end)?; + let low = [<checked_bit_ $ty>](start)?; + Some((high | (high - 1)) & !(low - 1)) + } + + /// Creates a compile-time contiguous bitmask for the given range by + /// performing a compile-time assertion that the range is valid. + /// + /// This version is the default and should be used if the range is known + /// at compile time. + $(#[$genmask_ex])* + #[inline] + pub const fn [<genmask_ $ty>](range: RangeInclusive<u32>) -> $ty { + let start = *range.start(); + let end = *range.end(); + + build_assert!(start <= end); + + let high = [<bit_ $ty>](end); + let low = [<bit_ $ty>](start); + (high | (high - 1)) & !(low - 1) + } + } + }; +} + +impl_genmask_fn!( + u64, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u64; + /// assert_eq!(genmask_checked_u64(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u64(0..=63), Some(u64::MAX)); + /// assert_eq!(genmask_checked_u64(21..=39), Some(0x0000_00ff_ffe0_0000)); + /// + /// // `80` is out of the supported bit range. + /// assert_eq!(genmask_checked_u64(21..=80), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u64(15..=8), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u64; + /// assert_eq!(genmask_u64(21..=39), 0x0000_00ff_ffe0_0000); + /// assert_eq!(genmask_u64(0..=0), 0b1); + /// assert_eq!(genmask_u64(0..=63), u64::MAX); + /// ``` +); + +impl_genmask_fn!( + u32, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u32; + /// assert_eq!(genmask_checked_u32(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u32(0..=31), Some(u32::MAX)); + /// assert_eq!(genmask_checked_u32(21..=31), Some(0xffe0_0000)); + /// + /// // `40` is out of the supported bit range. + /// assert_eq!(genmask_checked_u32(21..=40), None); + /// + /// // Invalid range where the start is bigger than the end. 
+ /// assert_eq!(genmask_checked_u32(15..=8), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u32; + /// assert_eq!(genmask_u32(21..=31), 0xffe0_0000); + /// assert_eq!(genmask_u32(0..=0), 0b1); + /// assert_eq!(genmask_u32(0..=31), u32::MAX); + /// ``` +); + +impl_genmask_fn!( + u16, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u16; + /// assert_eq!(genmask_checked_u16(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u16(0..=15), Some(u16::MAX)); + /// assert_eq!(genmask_checked_u16(6..=15), Some(0xffc0)); + /// + /// // `20` is out of the supported bit range. + /// assert_eq!(genmask_checked_u16(6..=20), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u16(10..=5), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u16; + /// assert_eq!(genmask_u16(6..=15), 0xffc0); + /// assert_eq!(genmask_u16(0..=0), 0b1); + /// assert_eq!(genmask_u16(0..=15), u16::MAX); + /// ``` +); + +impl_genmask_fn!( + u8, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u8; + /// assert_eq!(genmask_checked_u8(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u8(0..=7), Some(u8::MAX)); + /// assert_eq!(genmask_checked_u8(6..=7), Some(0xc0)); + /// + /// // `10` is out of the supported bit range. + /// assert_eq!(genmask_checked_u8(6..=10), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u8(5..=2), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u8; + /// assert_eq!(genmask_u8(6..=7), 0xc0); + /// assert_eq!(genmask_u8(0..=0), 0b1); + /// assert_eq!(genmask_u8(0..=7), u8::MAX); + /// ``` +); diff --git a/rust/kernel/block.rs b/rust/kernel/block.rs new file mode 100644 index 000000000000..32c8d865afb6 --- /dev/null +++ b/rust/kernel/block.rs @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Types for working with the block layer. + +pub mod mq; + +/// Bit mask for masking out [`SECTOR_SIZE`]. +pub const SECTOR_MASK: u32 = bindings::SECTOR_MASK; + +/// Sectors are size `1 << SECTOR_SHIFT`. +pub const SECTOR_SHIFT: u32 = bindings::SECTOR_SHIFT; + +/// Size of a sector. +pub const SECTOR_SIZE: u32 = bindings::SECTOR_SIZE; + +/// The difference between the size of a page and the size of a sector, +/// expressed as a power of two. +pub const PAGE_SECTORS_SHIFT: u32 = bindings::PAGE_SECTORS_SHIFT; diff --git a/rust/kernel/block/mq.rs b/rust/kernel/block/mq.rs new file mode 100644 index 000000000000..1fd0d54dd549 --- /dev/null +++ b/rust/kernel/block/mq.rs @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! This module provides types for implementing block drivers that interface the +//! blk-mq subsystem. +//! +//! To implement a block device driver, a Rust module must do the following: +//! +//! - Implement [`Operations`] for a type `T`. +//! - Create a [`TagSet<T>`]. +//! - Create a [`GenDisk<T>`], via the [`GenDiskBuilder`]. +//! - Add the disk to the system by calling [`GenDiskBuilder::build`] passing in +//! the `TagSet` reference. +//! +//! The types available in this module that have direct C counterparts are: +//! +//! - The [`TagSet`] type that abstracts the C type `struct tag_set`. +//! - The [`GenDisk`] type that abstracts the C type `struct gendisk`. +//! 
- The [`Request`] type that abstracts the C type `struct request`.
+//!
+//! The kernel will interface with the block device driver by calling the method
+//! implementations of the `Operations` trait.
+//!
+//! IO requests are passed to the driver as [`kernel::sync::aref::ARef<Request>`]
+//! instances. The `Request` type is a wrapper around the C `struct request`.
+//! The driver must mark end of processing by calling one of the
+//! `Request::end` methods. Failure to do so can lead to deadlock or timeout
+//! errors. Please note that the C function `blk_mq_start_request` is implicitly
+//! called when the request is queued with the driver.
+//!
+//! The `TagSet` is responsible for creating and maintaining a mapping between
+//! `Request`s and integer ids as well as carrying a pointer to the vtable
+//! generated by `Operations`. This mapping is useful for associating
+//! completions from hardware with the correct `Request` instance. The `TagSet`
+//! determines the maximum queue depth by setting the number of `Request`
+//! instances available to the driver, and it determines the number of queues to
+//! instantiate for the driver. If possible, a driver should allocate one queue
+//! per core, to keep queue data local to a core.
+//!
+//! One `TagSet` instance can be shared between multiple `GenDisk` instances.
+//! This can be useful when implementing drivers where one piece of hardware
+//! with one set of IO resources is represented to the user as multiple disks.
+//!
+//! One significant difference between block device drivers implemented with
+//! these Rust abstractions and drivers implemented in C, is that the Rust
+//! drivers have to own a reference count on the `Request` type when the IO is
+//! in flight. This is to ensure that the C `struct request` instances backing
+//! the Rust `Request` instances are live while the Rust driver holds a
+//! reference to the `Request`. In addition, the conversion of an integer tag to
+//! a `Request` via the `TagSet` would not be sound without this bookkeeping.
+//!
+//! [`GenDisk`]: gen_disk::GenDisk
+//! [`GenDisk<T>`]: gen_disk::GenDisk
+//! [`GenDiskBuilder`]: gen_disk::GenDiskBuilder
+//! [`GenDiskBuilder::build`]: gen_disk::GenDiskBuilder::build
+//!
+//! # Examples
+//!
+//! ```rust
+//! use kernel::{
+//! alloc::flags,
+//! block::mq::*,
+//! new_mutex,
+//! prelude::*,
+//! sync::{aref::ARef, Arc, Mutex},
+//! };
+//!
+//! struct MyBlkDevice;
+//!
+//! #[vtable]
+//! impl Operations for MyBlkDevice {
+//! type QueueData = ();
+//!
+//! fn queue_rq(_queue_data: (), rq: ARef<Request<Self>>, _is_last: bool) -> Result {
+//! Request::end_ok(rq);
+//! Ok(())
+//! }
+//!
+//! fn commit_rqs(_queue_data: ()) {}
+//!
+//! fn complete(rq: ARef<Request<Self>>) {
+//! Request::end_ok(rq)
+//! .map_err(|_e| kernel::error::code::EIO)
+//! .expect("Fatal error - expected to be able to end request");
+//! }
+//! }
+//!
+//! let tagset: Arc<TagSet<MyBlkDevice>> =
+//! Arc::pin_init(TagSet::new(1, 256, 1), flags::GFP_KERNEL)?;
+//! let mut disk = gen_disk::GenDiskBuilder::new()
+//! .capacity_sectors(4096)
+//! .build(fmt!("myblk"), tagset, ())?;
+//!
+//! # Ok::<(), kernel::error::Error>(())
+//! ```
+
+pub mod gen_disk;
+mod operations;
+mod request;
+mod tag_set;
+
+pub use operations::Operations;
+pub use request::Request;
+pub use tag_set::TagSet;
diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs
new file mode 100644
index 000000000000..1ce815c8cdab
--- /dev/null
+++ b/rust/kernel/block/mq/gen_disk.rs
@@ -0,0 +1,226 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Generic disk abstraction.
+//!
+//! C header: [`include/linux/blkdev.h`](srctree/include/linux/blkdev.h)
+//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h)
+
+use crate::{
+ bindings,
+ block::mq::{Operations, TagSet},
+ error::{self, from_err_ptr, Result},
+ fmt::{self, Write},
+ prelude::*,
+ static_lock_class,
+ str::NullTerminatedFormatter,
+ sync::Arc,
+ types::{ForeignOwnable, ScopeGuard},
+};
+
+/// A builder for [`GenDisk`].
+///
+/// Use this struct to configure and add a new [`GenDisk`] to the VFS.
+pub struct GenDiskBuilder {
+ rotational: bool,
+ logical_block_size: u32,
+ physical_block_size: u32,
+ capacity_sectors: u64,
+}
+
+impl Default for GenDiskBuilder {
+ fn default() -> Self {
+ Self {
+ rotational: false,
+ logical_block_size: bindings::PAGE_SIZE as u32,
+ physical_block_size: bindings::PAGE_SIZE as u32,
+ capacity_sectors: 0,
+ }
+ }
+}
+
+impl GenDiskBuilder {
+ /// Create a new instance.
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Set the rotational media attribute for the device to be built.
+ pub fn rotational(mut self, rotational: bool) -> Self {
+ self.rotational = rotational;
+ self
+ }
+
+ /// Validate block size by verifying that it is between 512 and `PAGE_SIZE`,
+ /// and that it is a power of two.
+ pub fn validate_block_size(size: u32) -> Result {
+ if !(512..=bindings::PAGE_SIZE as u32).contains(&size) || !size.is_power_of_two() {
+ Err(error::code::EINVAL)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Set the logical block size of the device to be built.
+ ///
+ /// This method will check that block size is a power of two and between 512
+ /// and `PAGE_SIZE`. If not, an error is returned and the block size is not set.
+ ///
+ /// This is the smallest unit the storage device can address. It is
+ /// typically 4096 bytes.
+ pub fn logical_block_size(mut self, block_size: u32) -> Result<Self> {
+ Self::validate_block_size(block_size)?;
+ self.logical_block_size = block_size;
+ Ok(self)
+ }
+
+ /// Set the physical block size of the device to be built.
+ ///
+ /// This method will check that block size is a power of two and between 512
+ /// and `PAGE_SIZE`. If not, an error is returned and the block size is not set.
+ ///
+ /// This is the smallest unit a physical storage device can write
+ /// atomically. It is usually the same as the logical block size but may be
+ /// bigger. One example is SATA drives with 4096 byte physical block size
+ /// that expose a 512 byte logical block size to the operating system.
+ pub fn physical_block_size(mut self, block_size: u32) -> Result<Self> {
+ Self::validate_block_size(block_size)?;
+ self.physical_block_size = block_size;
+ Ok(self)
+ }
+
+ /// Set the capacity of the device to be built, in sectors (512 bytes).
+ pub fn capacity_sectors(mut self, capacity: u64) -> Self {
+ self.capacity_sectors = capacity;
+ self
+ }
+
+ /// Build a new `GenDisk` and add it to the VFS.
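+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of a call site; it assumes a `tagset` and a driver
+ /// type as in the [module level documentation] and is not a complete
+ /// driver:
+ ///
+ /// ```ignore
+ /// let disk = GenDiskBuilder::new()
+ /// .rotational(false)
+ /// .logical_block_size(512)?
+ /// .physical_block_size(4096)?
+ /// .capacity_sectors(4096)
+ /// .build(fmt!("myblk"), tagset, ())?;
+ /// ```
+ ///
+ /// [module level documentation]: kernel::block::mq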
+ pub fn build<T: Operations>( + self, + name: fmt::Arguments<'_>, + tagset: Arc<TagSet<T>>, + queue_data: T::QueueData, + ) -> Result<GenDisk<T>> { + let data = queue_data.into_foreign(); + let recover_data = ScopeGuard::new(|| { + // SAFETY: T::QueueData was created by the call to `into_foreign()` above + drop(unsafe { T::QueueData::from_foreign(data) }); + }); + + // SAFETY: `bindings::queue_limits` contain only fields that are valid when zeroed. + let mut lim: bindings::queue_limits = unsafe { core::mem::zeroed() }; + + lim.logical_block_size = self.logical_block_size; + lim.physical_block_size = self.physical_block_size; + if self.rotational { + lim.features = bindings::BLK_FEAT_ROTATIONAL; + } + + // SAFETY: `tagset.raw_tag_set()` points to a valid and initialized tag set + let gendisk = from_err_ptr(unsafe { + bindings::__blk_mq_alloc_disk( + tagset.raw_tag_set(), + &mut lim, + data, + static_lock_class!().as_ptr(), + ) + })?; + + const TABLE: bindings::block_device_operations = bindings::block_device_operations { + submit_bio: None, + open: None, + release: None, + ioctl: None, + compat_ioctl: None, + check_events: None, + unlock_native_capacity: None, + getgeo: None, + set_read_only: None, + swap_slot_free_notify: None, + report_zones: None, + devnode: None, + alternative_gpt_sector: None, + get_unique_id: None, + // TODO: Set to THIS_MODULE. Waiting for const_refs_to_static feature to + // be merged (unstable in rustc 1.78 which is staged for linux 6.10) + // <https://github.com/rust-lang/rust/issues/119618> + owner: core::ptr::null_mut(), + pr_ops: core::ptr::null_mut(), + free_disk: None, + poll_bio: None, + }; + + // SAFETY: `gendisk` is a valid pointer as we initialized it above + unsafe { (*gendisk).fops = &TABLE }; + + let mut writer = NullTerminatedFormatter::new( + // SAFETY: `gendisk` points to a valid and initialized instance. We + // have exclusive access, since the disk is not added to the VFS + // yet. + unsafe { &mut (*gendisk).disk_name }, + ) + .ok_or(EINVAL)?; + writer.write_fmt(name)?; + + // SAFETY: `gendisk` points to a valid and initialized instance of + // `struct gendisk`. `set_capacity` takes a lock to synchronize this + // operation, so we will not race. + unsafe { bindings::set_capacity(gendisk, self.capacity_sectors) }; + + crate::error::to_result( + // SAFETY: `gendisk` points to a valid and initialized instance of + // `struct gendisk`. + unsafe { + bindings::device_add_disk(core::ptr::null_mut(), gendisk, core::ptr::null_mut()) + }, + )?; + + recover_data.dismiss(); + + // INVARIANT: `gendisk` was initialized above. + // INVARIANT: `gendisk` was added to the VFS via `device_add_disk` above. + // INVARIANT: `gendisk.queue.queue_data` is set to `data` in the call to + // `__blk_mq_alloc_disk` above. + Ok(GenDisk { + _tagset: tagset, + gendisk, + }) + } +} + +/// A generic block device. +/// +/// # Invariants +/// +/// - `gendisk` must always point to an initialized and valid `struct gendisk`. +/// - `gendisk` was added to the VFS through a call to +/// `bindings::device_add_disk`. +/// - `self.gendisk.queue.queuedata` is initialized by a call to `ForeignOwnable::into_foreign`. +pub struct GenDisk<T: Operations> { + _tagset: Arc<TagSet<T>>, + gendisk: *mut bindings::gendisk, +} + +// SAFETY: `GenDisk` is an owned pointer to a `struct gendisk` and an `Arc` to a +// `TagSet` It is safe to send this to other threads as long as T is Send. 
+unsafe impl<T: Operations + Send> Send for GenDisk<T> {}
+
+impl<T: Operations> Drop for GenDisk<T> {
+ fn drop(&mut self) {
+ // SAFETY: By type invariant of `Self`, `self.gendisk` points to a valid
+ // and initialized instance of `struct gendisk`, and `queuedata` was
+ // initialized with the result of a call to
+ // `ForeignOwnable::into_foreign`.
+ let queue_data = unsafe { (*(*self.gendisk).queue).queuedata };
+
+ // SAFETY: By type invariant, `self.gendisk` points to a valid and
+ // initialized instance of `struct gendisk`, and it was previously added
+ // to the VFS.
+ unsafe { bindings::del_gendisk(self.gendisk) };
+
+ // SAFETY: `queue.queuedata` was created by `GenDiskBuilder::build` with
+ // a call to `ForeignOwnable::into_foreign` to create `queuedata`.
+ // `ForeignOwnable::from_foreign` is only called here.
+ drop(unsafe { T::QueueData::from_foreign(queue_data) });
+ }
+}
diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
new file mode 100644
index 000000000000..8ad46129a52c
--- /dev/null
+++ b/rust/kernel/block/mq/operations.rs
@@ -0,0 +1,286 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! This module provides an interface for blk-mq drivers to implement.
+//!
+//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h)

+use crate::{
+ bindings,
+ block::mq::{request::RequestDataWrapper, Request},
+ error::{from_result, Result},
+ prelude::*,
+ sync::{aref::ARef, Refcount},
+ types::ForeignOwnable,
+};
+use core::marker::PhantomData;
+
+type ForeignBorrowed<'a, T> = <T as ForeignOwnable>::Borrowed<'a>;
+
+/// Implement this trait to interface blk-mq as block devices.
+///
+/// To implement a block device driver, implement this trait as described in the
+/// [module level documentation]. The kernel will use the implementation of the
+/// functions defined in this trait to interface a block device driver. Note:
+/// There is no need for an `exit_request()` implementation, because the `drop`
+/// implementation of the [`Request`] type will be invoked automatically by
+/// the C/Rust glue logic.
+///
+/// [module level documentation]: kernel::block::mq
+#[macros::vtable]
+pub trait Operations: Sized {
+ /// Data associated with the `struct request_queue` that is allocated for
+ /// the `GenDisk` associated with this `Operations` implementation.
+ type QueueData: ForeignOwnable;
+
+ /// Called by the kernel to queue a request with the driver. If `is_last` is
+ /// `false`, the driver is allowed to defer committing the request.
+ fn queue_rq(
+ queue_data: ForeignBorrowed<'_, Self::QueueData>,
+ rq: ARef<Request<Self>>,
+ is_last: bool,
+ ) -> Result;
+
+ /// Called by the kernel to indicate that queued requests should be submitted.
+ fn commit_rqs(queue_data: ForeignBorrowed<'_, Self::QueueData>);
+
+ /// Called by the kernel when the request is completed.
+ fn complete(rq: ARef<Request<Self>>);
+
+ /// Called by the kernel to poll the device for completed requests. Only
+ /// used for poll queues.
+ fn poll() -> bool {
+ build_error!(crate::error::VTABLE_DEFAULT_ERROR)
+ }
+}
+
+/// A vtable for blk-mq to interact with a block device driver.
+///
+/// A `bindings::blk_mq_ops` vtable is constructed from pointers to the `extern
+/// "C"` functions of this struct, exposed through the `OperationsVTable::VTABLE`.
+///
+/// For general documentation of these methods, see the kernel source
+/// documentation related to `struct blk_mq_ops` in
+/// [`include/linux/blk-mq.h`].
+///
+/// [`include/linux/blk-mq.h`]: srctree/include/linux/blk-mq.h
+pub(crate) struct OperationsVTable<T: Operations>(PhantomData<T>);
+
+impl<T: Operations> OperationsVTable<T> {
+ /// This function is called by the C kernel. A pointer to this function is
+ /// installed in the `blk_mq_ops` vtable for the driver.
+ ///
+ /// # Safety
+ ///
+ /// - The caller of this function must ensure that the pointee of `bd` is
+ /// valid for reads for the duration of this function.
+ /// - This function must be called for an initialized and live `hctx`. That
+ /// is, `Self::init_hctx_callback` was called and
+ /// `Self::exit_hctx_callback()` was not yet called.
+ /// - `(*bd).rq` must point to an initialized and live `bindings::request`.
+ /// That is, `Self::init_request_callback` was called but
+ /// `Self::exit_request_callback` was not yet called for the request.
+ /// - `(*bd).rq` must be owned by the driver. That is, the block layer must
+ /// promise to not access the request until the driver calls
+ /// `bindings::blk_mq_end_request` for the request.
+ unsafe extern "C" fn queue_rq_callback(
+ hctx: *mut bindings::blk_mq_hw_ctx,
+ bd: *const bindings::blk_mq_queue_data,
+ ) -> bindings::blk_status_t {
+ // SAFETY: `bd.rq` is valid as required by the safety requirement for
+ // this function.
+ let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
+
+ // One refcount for the ARef, one for being in flight
+ request.wrapper_ref().refcount().set(2);
+
+ // SAFETY:
+ // - We own a refcount that we took above. We pass that to `ARef`.
+ // - By the safety requirements of this function, `request` is a valid
+ // `struct request` and the private data is properly initialized.
+ // - `rq` will be alive until `blk_mq_end_request` is called and is
+ // reference counted by `ARef` until then.
+ let rq = unsafe { Request::aref_from_raw((*bd).rq) };
+
+ // SAFETY: `hctx` is valid as required by this function.
+ let queue_data = unsafe { (*(*hctx).queue).queuedata };
+
+ // SAFETY: `queue.queuedata` was created by `GenDiskBuilder::build` with
+ // a call to `ForeignOwnable::into_foreign` to create `queuedata`.
+ // `ForeignOwnable::from_foreign` is only called when the tagset is
+ // dropped, which happens after we are dropped.
+ let queue_data = unsafe { T::QueueData::borrow(queue_data) };
+
+ // SAFETY: We have exclusive access and we just set the refcount above.
+ unsafe { Request::start_unchecked(&rq) };
+
+ let ret = T::queue_rq(
+ queue_data,
+ rq,
+ // SAFETY: `bd` is valid as required by the safety requirement for
+ // this function.
+ unsafe { (*bd).last },
+ );
+
+ if let Err(e) = ret {
+ e.to_blk_status()
+ } else {
+ bindings::BLK_STS_OK as bindings::blk_status_t
+ }
+ }
+
+ /// This function is called by the C kernel. A pointer to this function is
+ /// installed in the `blk_mq_ops` vtable for the driver.
+ ///
+ /// # Safety
+ ///
+ /// This function may only be called by blk-mq C infrastructure. The caller
+ /// must ensure that `hctx` is valid.
+ unsafe extern "C" fn commit_rqs_callback(hctx: *mut bindings::blk_mq_hw_ctx) {
+ // SAFETY: `hctx` is valid as required by this function.
+ let queue_data = unsafe { (*(*hctx).queue).queuedata };
+
+ // SAFETY: `queue.queuedata` was created by `GenDiskBuilder::build` with
+ // a call to `ForeignOwnable::into_foreign()` to create `queuedata`.
+ // `ForeignOwnable::from_foreign()` is only called when the tagset is
+ // dropped, which happens after we are dropped.
+ let queue_data = unsafe { T::QueueData::borrow(queue_data) }; + T::commit_rqs(queue_data) + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. `rq` must + /// point to a valid request that has been marked as completed. The pointee + /// of `rq` must be valid for write for the duration of this function. + unsafe extern "C" fn complete_callback(rq: *mut bindings::request) { + // SAFETY: This function can only be dispatched through + // `Request::complete`. We leaked a refcount then which we pick back up + // now. + let aref = unsafe { Request::aref_from_raw(rq) }; + T::complete(aref); + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn poll_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _iob: *mut bindings::io_comp_batch, + ) -> crate::ffi::c_int { + T::poll().into() + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. This + /// function may only be called once before `exit_hctx_callback` is called + /// for the same context. + unsafe extern "C" fn init_hctx_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _tagset_data: *mut crate::ffi::c_void, + _hctx_idx: crate::ffi::c_uint, + ) -> crate::ffi::c_int { + from_result(|| Ok(0)) + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn exit_hctx_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _hctx_idx: crate::ffi::c_uint, + ) { + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// - This function may only be called by blk-mq C infrastructure. + /// - `_set` must point to an initialized `TagSet<T>`. + /// - `rq` must point to an initialized `bindings::request`. + /// - The allocation pointed to by `rq` must be at the size of `Request` + /// plus the size of `RequestDataWrapper`. + unsafe extern "C" fn init_request_callback( + _set: *mut bindings::blk_mq_tag_set, + rq: *mut bindings::request, + _hctx_idx: crate::ffi::c_uint, + _numa_node: crate::ffi::c_uint, + ) -> crate::ffi::c_int { + from_result(|| { + // SAFETY: By the safety requirements of this function, `rq` points + // to a valid allocation. + let pdu = unsafe { Request::wrapper_ptr(rq.cast::<Request<T>>()) }; + + // SAFETY: The refcount field is allocated but not initialized, so + // it is valid for writes. + unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) }; + + Ok(0) + }) + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// - This function may only be called by blk-mq C infrastructure. + /// - `_set` must point to an initialized `TagSet<T>`. + /// - `rq` must point to an initialized and valid `Request`. 
+ unsafe extern "C" fn exit_request_callback(
+ _set: *mut bindings::blk_mq_tag_set,
+ rq: *mut bindings::request,
+ _hctx_idx: crate::ffi::c_uint,
+ ) {
+ // SAFETY: The tagset invariants guarantee that all requests are allocated with extra memory
+ // for the request data.
+ let pdu = unsafe { bindings::blk_mq_rq_to_pdu(rq) }.cast::<RequestDataWrapper>();
+
+ // SAFETY: `pdu` is valid for read and write and is properly initialised.
+ unsafe { core::ptr::drop_in_place(pdu) };
+ }
+
+ const VTABLE: bindings::blk_mq_ops = bindings::blk_mq_ops {
+ queue_rq: Some(Self::queue_rq_callback),
+ queue_rqs: None,
+ commit_rqs: Some(Self::commit_rqs_callback),
+ get_budget: None,
+ put_budget: None,
+ set_rq_budget_token: None,
+ get_rq_budget_token: None,
+ timeout: None,
+ poll: if T::HAS_POLL {
+ Some(Self::poll_callback)
+ } else {
+ None
+ },
+ complete: Some(Self::complete_callback),
+ init_hctx: Some(Self::init_hctx_callback),
+ exit_hctx: Some(Self::exit_hctx_callback),
+ init_request: Some(Self::init_request_callback),
+ exit_request: Some(Self::exit_request_callback),
+ cleanup_rq: None,
+ busy: None,
+ map_queues: None,
+ #[cfg(CONFIG_BLK_DEBUG_FS)]
+ show_rq: None,
+ };
+
+ pub(crate) const fn build() -> &'static bindings::blk_mq_ops {
+ &Self::VTABLE
+ }
+}
diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs
new file mode 100644
index 000000000000..ce3e30c81cb5
--- /dev/null
+++ b/rust/kernel/block/mq/request.rs
@@ -0,0 +1,257 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! This module provides a wrapper for the C `struct request` type.
+//!
+//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h)
+
+use crate::{
+ bindings,
+ block::mq::Operations,
+ error::Result,
+ sync::{
+ aref::{ARef, AlwaysRefCounted},
+ atomic::Relaxed,
+ Refcount,
+ },
+ types::Opaque,
+};
+use core::{marker::PhantomData, ptr::NonNull};
+
+/// A wrapper around a blk-mq [`struct request`]. This represents an IO request.
+///
+/// # Implementation details
+///
+/// There are four states for a request that the Rust bindings care about:
+///
+/// 1. Request is owned by block layer (refcount 0).
+/// 2. Request is owned by driver but with zero [`ARef`]s in existence
+/// (refcount 1).
+/// 3. Request is owned by driver with exactly one [`ARef`] in existence
+/// (refcount 2).
+/// 4. Request is owned by driver with more than one [`ARef`] in existence
+/// (refcount > 2).
+///
+/// We need to track 1 and 2 to ensure we fail tag to request conversions for
+/// requests that are not owned by the driver.
+///
+/// We need to track 3 and 4 to ensure that it is safe to end the request and hand
+/// back ownership to the block layer.
+///
+/// Note that the driver can still obtain a new `ARef` even if there are no
+/// `ARef`s in existence, by using `tag_to_rq`, hence the need to distinguish
+/// between states 2 and 3.
+///
+/// The states are tracked through the private `refcount` field of
+/// `RequestDataWrapper`. This structure lives in the private data area of the C
+/// [`struct request`].
+///
+/// # Invariants
+///
+/// * `self.0` is a valid [`struct request`] created by the C portion of the
+/// kernel.
+/// * The private data area associated with this request must be an initialized
+/// and valid `RequestDataWrapper`.
+/// * `self` is reference counted by atomic modification of
+/// `self.wrapper_ref().refcount()`.
+///
+/// [`struct request`]: srctree/include/linux/blk-mq.h
+///
+#[repr(transparent)]
+pub struct Request<T>(Opaque<bindings::request>, PhantomData<T>);
+
+impl<T: Operations> Request<T> {
+ /// Create an [`ARef<Request>`] from a [`struct request`] pointer.
+ ///
+ /// # Safety
+ ///
+ /// * The caller must own a refcount on `ptr` that is transferred to the
+ /// returned [`ARef`].
+ /// * The type invariants for [`Request`] must hold for the pointee of `ptr`.
+ ///
+ /// [`struct request`]: srctree/include/linux/blk-mq.h
+ pub(crate) unsafe fn aref_from_raw(ptr: *mut bindings::request) -> ARef<Self> {
+ // INVARIANT: By the safety requirements of this function, invariants are upheld.
+ // SAFETY: By the safety requirement of this function, we own a
+ // reference count that we can pass to `ARef`.
+ unsafe { ARef::from_raw(NonNull::new_unchecked(ptr.cast())) }
+ }
+
+ /// Notify the block layer that a request is going to be processed now.
+ ///
+ /// The block layer uses this hook to do proper initializations such as
+ /// starting the timeout timer. It is a requirement that block device
+ /// drivers call this function when starting to process a request.
+ ///
+ /// # Safety
+ ///
+ /// The caller must have exclusive ownership of `self`, that is
+ /// `self.wrapper_ref().refcount() == 2`.
+ pub(crate) unsafe fn start_unchecked(this: &ARef<Self>) {
+ // SAFETY: By type invariant, `self.0` is a valid `struct request` and
+ // we have exclusive access.
+ unsafe { bindings::blk_mq_start_request(this.0.get()) };
+ }
+
+ /// Try to take exclusive ownership of `this` by dropping the refcount to 0.
+ /// This fails if `this` is not the only [`ARef`] pointing to the underlying
+ /// [`Request`].
+ ///
+ /// If the operation is successful, [`Ok`] is returned with a pointer to the
+ /// C [`struct request`]. If the operation fails, `this` is returned in the
+ /// [`Err`] variant.
+ ///
+ /// [`struct request`]: srctree/include/linux/blk-mq.h
+ fn try_set_end(this: ARef<Self>) -> Result<*mut bindings::request, ARef<Self>> {
+ // To hand back the ownership, we need the current refcount to be 2.
+ // Since we can race with `TagSet::tag_to_rq`, this needs to atomically reduce
+ // refcount to 0. `Refcount` does not provide a way to do this, so use the underlying
+ // atomics directly.
+ if let Err(_old) = this
+ .wrapper_ref()
+ .refcount()
+ .as_atomic()
+ .cmpxchg(2, 0, Relaxed)
+ {
+ return Err(this);
+ }
+
+ let request_ptr = this.0.get();
+ core::mem::forget(this);
+
+ Ok(request_ptr)
+ }
+
+ /// Notify the block layer that the request has been completed without errors.
+ ///
+ /// This function will return [`Err`] if `this` is not the only [`ARef`]
+ /// referencing the request.
+ pub fn end_ok(this: ARef<Self>) -> Result<(), ARef<Self>> {
+ let request_ptr = Self::try_set_end(this)?;
+
+ // SAFETY: By type invariant, `this.0` was a valid `struct request`. The
+ // success of the call to `try_set_end` guarantees that there are no
+ // `ARef`s pointing to this request. Therefore it is safe to hand it
+ // back to the block layer.
+ unsafe {
+ bindings::blk_mq_end_request(
+ request_ptr,
+ bindings::BLK_STS_OK as bindings::blk_status_t,
+ )
+ };
+
+ Ok(())
+ }
+
+ /// Complete the request by scheduling `Operations::complete` for
+ /// execution.
+ ///
+ /// The function may be scheduled locally via SoftIRQ, or remotely via IPI.
+ /// See `blk_mq_complete_request_remote` in [`blk-mq.c`] for details.
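+ ///
+ /// A typical driver calls [`Request::complete`] from its completion
+ /// interrupt handler and performs the actual [`Request::end_ok`] call in
+ /// [`Operations::complete`], as sketched in the module level example.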
+ ///
+ /// [`blk-mq.c`]: srctree/block/blk-mq.c
+ pub fn complete(this: ARef<Self>) {
+ let ptr = ARef::into_raw(this).cast::<bindings::request>().as_ptr();
+ // SAFETY: By type invariant, `self.0` is a valid `struct request`.
+ if !unsafe { bindings::blk_mq_complete_request_remote(ptr) } {
+ // SAFETY: We released a refcount above that we can reclaim here.
+ let this = unsafe { Request::aref_from_raw(ptr) };
+ T::complete(this);
+ }
+ }
+
+ /// Return a pointer to the [`RequestDataWrapper`] stored in the private area
+ /// of the request structure.
+ ///
+ /// # Safety
+ ///
+ /// - `this` must point to a valid allocation of at least the size of
+ /// [`Self`] plus the size of [`RequestDataWrapper`].
+ pub(crate) unsafe fn wrapper_ptr(this: *mut Self) -> NonNull<RequestDataWrapper> {
+ let request_ptr = this.cast::<bindings::request>();
+ // SAFETY: By safety requirements for this function, `this` is a
+ // valid allocation.
+ let wrapper_ptr =
+ unsafe { bindings::blk_mq_rq_to_pdu(request_ptr).cast::<RequestDataWrapper>() };
+ // SAFETY: By C API contract, `wrapper_ptr` points to a valid allocation
+ // and is not null.
+ unsafe { NonNull::new_unchecked(wrapper_ptr) }
+ }
+
+ /// Return a reference to the [`RequestDataWrapper`] stored in the private
+ /// area of the request structure.
+ pub(crate) fn wrapper_ref(&self) -> &RequestDataWrapper {
+ // SAFETY: By type invariant, `self.0` is a valid allocation. Further,
+ // the private data associated with this request is initialized and
+ // valid. The existence of `&self` guarantees that the private data is
+ // valid as a shared reference.
+ unsafe { Self::wrapper_ptr(core::ptr::from_ref(self).cast_mut()).as_ref() }
+ }
+}
+
+/// A wrapper around data stored in the private area of the C [`struct request`].
+///
+/// [`struct request`]: srctree/include/linux/blk-mq.h
+pub(crate) struct RequestDataWrapper {
+ /// The Rust request refcount has the following states:
+ ///
+ /// - 0: The request is owned by C block layer.
+ /// - 1: The request is owned by Rust abstractions but there are no [`ARef`] references to it.
+ /// - 2+: There are [`ARef`] references to the request.
+ refcount: Refcount,
+}
+
+impl RequestDataWrapper {
+ /// Return a reference to the refcount of the request that is embedding
+ /// `self`.
+ pub(crate) fn refcount(&self) -> &Refcount {
+ &self.refcount
+ }
+
+ /// Return a pointer to the refcount of the request that is embedding the
+ /// pointee of `this`.
+ ///
+ /// # Safety
+ ///
+ /// - `this` must point to a live allocation of at least the size of `Self`.
+ pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut Refcount {
+ // SAFETY: Because of the safety requirements of this function, the
+ // field projection is safe.
+ unsafe { &raw mut (*this).refcount }
+ }
+}
+
+// SAFETY: Exclusive access is thread-safe for `Request`. `Request` has no `&mut
+// self` methods and `&self` methods that mutate `self` are internally
+// synchronized.
+unsafe impl<T: Operations> Send for Request<T> {}
+
+// SAFETY: Shared access is thread-safe for `Request`. `&self` methods that
+// mutate `self` are internally synchronized.
+unsafe impl<T: Operations> Sync for Request<T> {}
+
+// SAFETY: All instances of `Request<T>` are reference counted. This
+// implementation of `AlwaysRefCounted` ensures that increments to the ref
+// count keep the object alive in memory at least until a matching reference
+// count decrement is executed.
+unsafe impl<T: Operations> AlwaysRefCounted for Request<T> {
+ fn inc_ref(&self) {
+ self.wrapper_ref().refcount().inc();
+ }
+
+ unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
+ // SAFETY: The type invariants of `ARef` guarantee that `obj` is valid
+ // for read.
+ let wrapper_ptr = unsafe { Self::wrapper_ptr(obj.as_ptr()).as_ptr() };
+ // SAFETY: The type invariant of `Request` guarantees that the private
+ // data area is initialized and valid.
+ let refcount = unsafe { &*RequestDataWrapper::refcount_ptr(wrapper_ptr) };
+
+ #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
+ let is_zero = refcount.dec_and_test();
+
+ #[cfg(CONFIG_DEBUG_MISC)]
+ if is_zero {
+ panic!("Request reached refcount zero in Rust abstractions");
+ }
+ }
+}
diff --git a/rust/kernel/block/mq/tag_set.rs b/rust/kernel/block/mq/tag_set.rs
new file mode 100644
index 000000000000..c3cf56d52bee
--- /dev/null
+++ b/rust/kernel/block/mq/tag_set.rs
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! This module provides the `TagSet` struct to wrap the C `struct blk_mq_tag_set`.
+//!
+//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h)
+
+use core::pin::Pin;
+
+use crate::{
+ bindings,
+ block::mq::{operations::OperationsVTable, request::RequestDataWrapper, Operations},
+ error::{self, Result},
+ prelude::try_pin_init,
+ types::Opaque,
+};
+use core::{convert::TryInto, marker::PhantomData};
+use pin_init::{pin_data, pinned_drop, PinInit};
+
+/// A wrapper for the C `struct blk_mq_tag_set`.
+///
+/// `struct blk_mq_tag_set` contains a `struct list_head` and so must be pinned.
+///
+/// # Invariants
+///
+/// - `inner` is initialized and valid.
+#[pin_data(PinnedDrop)]
+#[repr(transparent)]
+pub struct TagSet<T: Operations> {
+ #[pin]
+ inner: Opaque<bindings::blk_mq_tag_set>,
+ _p: PhantomData<T>,
+}
+
+impl<T: Operations> TagSet<T> {
+ /// Tries to create a new tag set.
+ pub fn new(
+ nr_hw_queues: u32,
+ num_tags: u32,
+ num_maps: u32,
+ ) -> impl PinInit<Self, error::Error> {
+ // SAFETY: `blk_mq_tag_set` only contains integers and pointers, which
+ // all are allowed to be 0.
+ let tag_set: bindings::blk_mq_tag_set = unsafe { core::mem::zeroed() };
+ let tag_set: Result<_> = core::mem::size_of::<RequestDataWrapper>()
+ .try_into()
+ .map(|cmd_size| {
+ bindings::blk_mq_tag_set {
+ ops: OperationsVTable::<T>::build(),
+ nr_hw_queues,
+ timeout: 0, // 0 means the default, which is 30 * HZ (30 seconds) in C
+ numa_node: bindings::NUMA_NO_NODE,
+ queue_depth: num_tags,
+ cmd_size,
+ flags: 0,
+ driver_data: core::ptr::null_mut::<crate::ffi::c_void>(),
+ nr_maps: num_maps,
+ ..tag_set
+ }
+ })
+ .map(Opaque::new)
+ .map_err(|e| e.into());
+
+ try_pin_init!(TagSet {
+ inner <- tag_set.pin_chain(|tag_set| {
+ // SAFETY: we do not move out of `tag_set`.
+ let tag_set: &mut Opaque<_> = unsafe { Pin::get_unchecked_mut(tag_set) };
+ // SAFETY: `tag_set` is a reference to an initialized `blk_mq_tag_set`.
+ error::to_result(unsafe { bindings::blk_mq_alloc_tag_set(tag_set.get()) })
+ }),
+ _p: PhantomData,
+ })
+ }
+
+ /// Return the pointer to the wrapped `struct blk_mq_tag_set`.
+ pub(crate) fn raw_tag_set(&self) -> *mut bindings::blk_mq_tag_set {
+ self.inner.get()
+ }
+}
+
+#[pinned_drop]
+impl<T: Operations> PinnedDrop for TagSet<T> {
+ fn drop(self: Pin<&mut Self>) {
+ // SAFETY: By type invariant `inner` is valid and has been properly
+ // initialized during construction.
+ unsafe { bindings::blk_mq_free_tag_set(self.inner.get()) }; + } +} diff --git a/rust/kernel/bug.rs b/rust/kernel/bug.rs new file mode 100644 index 000000000000..36aef43e5ebe --- /dev/null +++ b/rust/kernel/bug.rs @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024, 2025 FUJITA Tomonori <fujita.tomonori@gmail.com> + +//! Support for BUG and WARN functionality. +//! +//! C header: [`include/asm-generic/bug.h`](srctree/include/asm-generic/bug.h) + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] +#[cfg(CONFIG_DEBUG_BUGVERBOSE)] +macro_rules! warn_flags { + ($flags:expr) => { + const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; + const _FILE: &[u8] = file!().as_bytes(); + // Plus one for null-terminator. + static FILE: [u8; _FILE.len() + 1] = { + let mut bytes = [0; _FILE.len() + 1]; + let mut i = 0; + while i < _FILE.len() { + bytes[i] = _FILE[i]; + i += 1; + } + bytes + }; + + // SAFETY: + // - `file`, `line`, `flags`, and `size` are all compile-time constants or + // symbols, preventing any invalid memory access. + // - The asm block has no side effects and does not modify any registers + // or memory. It is purely for embedding metadata into the ELF section. + unsafe { + $crate::asm!( + concat!( + "/* {size} */", + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_warn_asm.rs")), + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_reachable_asm.rs"))); + file = sym FILE, + line = const line!(), + flags = const FLAGS, + size = const ::core::mem::size_of::<$crate::bindings::bug_entry>(), + ); + } + } +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] +#[cfg(not(CONFIG_DEBUG_BUGVERBOSE))] +macro_rules! warn_flags { + ($flags:expr) => { + const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; + + // SAFETY: + // - `flags` and `size` are all compile-time constants, preventing + // any invalid memory access. + // - The asm block has no side effects and does not modify any registers + // or memory. It is purely for embedding metadata into the ELF section. + unsafe { + $crate::asm!( + concat!( + "/* {size} */", + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_warn_asm.rs")), + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_reachable_asm.rs"))); + flags = const FLAGS, + size = const ::core::mem::size_of::<$crate::bindings::bug_entry>(), + ); + } + } +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, CONFIG_UML))] +macro_rules! warn_flags { + ($flags:expr) => { + // SAFETY: It is always safe to call `warn_slowpath_fmt()` + // with a valid null-terminated string. + unsafe { + $crate::bindings::warn_slowpath_fmt( + $crate::c_str!(::core::file!()).as_char_ptr(), + line!() as $crate::ffi::c_int, + $flags as $crate::ffi::c_uint, + ::core::ptr::null(), + ); + } + }; +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, any(CONFIG_LOONGARCH, CONFIG_ARM)))] +macro_rules! warn_flags { + ($flags:expr) => { + // SAFETY: It is always safe to call `WARN_ON()`. + unsafe { $crate::bindings::WARN_ON(true) } + }; +} + +#[macro_export] +#[doc(hidden)] +#[cfg(not(CONFIG_BUG))] +macro_rules! warn_flags { + ($flags:expr) => {}; +} + +#[doc(hidden)] +pub const fn bugflag_taint(value: u32) -> u32 { + value << 8 +} + +/// Report a warning if `cond` is true and return the condition's evaluation result. +#[macro_export] +macro_rules! 
warn_on { + ($cond:expr) => {{ + let cond = $cond; + if cond { + const WARN_ON_FLAGS: u32 = $crate::bug::bugflag_taint($crate::bindings::TAINT_WARN); + + $crate::warn_flags!(WARN_ON_FLAGS); + } + cond + }}; +} diff --git a/rust/kernel/build_assert.rs b/rust/kernel/build_assert.rs index 659542393c09..6331b15d7c4d 100644 --- a/rust/kernel/build_assert.rs +++ b/rust/kernel/build_assert.rs @@ -2,6 +2,9 @@ //! Build-time assert. +#[doc(hidden)] +pub use build_error::build_error; + /// Fails the build if the code path calling `build_error!` can possibly be executed. /// /// If the macro is executed in const context, `build_error!` will panic. @@ -11,7 +14,6 @@ /// # Examples /// /// ``` -/// # use kernel::build_error; /// #[inline] /// fn foo(a: usize) -> usize { /// a.checked_add(1).unwrap_or_else(|| build_error!("overflow")) @@ -23,10 +25,10 @@ #[macro_export] macro_rules! build_error { () => {{ - $crate::build_error("") + $crate::build_assert::build_error("") }}; ($msg:expr) => {{ - $crate::build_error($msg) + $crate::build_assert::build_error($msg) }}; } @@ -67,16 +69,18 @@ macro_rules! build_error { /// assert!(n > 1); // Run-time check /// } /// ``` +/// +/// [`static_assert!`]: crate::static_assert! #[macro_export] macro_rules! build_assert { ($cond:expr $(,)?) => {{ if !$cond { - $crate::build_error(concat!("assertion failed: ", stringify!($cond))); + $crate::build_assert::build_error(concat!("assertion failed: ", stringify!($cond))); } }}; ($cond:expr, $msg:expr) => {{ if !$cond { - $crate::build_error($msg); + $crate::build_assert::build_error($msg); } }}; } diff --git a/rust/kernel/clk.rs b/rust/kernel/clk.rs new file mode 100644 index 000000000000..c1cfaeaa36a2 --- /dev/null +++ b/rust/kernel/clk.rs @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Clock abstractions. +//! +//! C header: [`include/linux/clk.h`](srctree/include/linux/clk.h) +//! +//! Reference: <https://docs.kernel.org/driver-api/clk.html> + +use crate::ffi::c_ulong; + +/// The frequency unit. +/// +/// Represents a frequency in hertz, wrapping a [`c_ulong`] value. 
+/// +/// # Examples +/// +/// ``` +/// use kernel::clk::Hertz; +/// +/// let hz = 1_000_000_000; +/// let rate = Hertz(hz); +/// +/// assert_eq!(rate.as_hz(), hz); +/// assert_eq!(rate, Hertz(hz)); +/// assert_eq!(rate, Hertz::from_khz(hz / 1_000)); +/// assert_eq!(rate, Hertz::from_mhz(hz / 1_000_000)); +/// assert_eq!(rate, Hertz::from_ghz(hz / 1_000_000_000)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct Hertz(pub c_ulong); + +impl Hertz { + const KHZ_TO_HZ: c_ulong = 1_000; + const MHZ_TO_HZ: c_ulong = 1_000_000; + const GHZ_TO_HZ: c_ulong = 1_000_000_000; + + /// Create a new instance from kilohertz (kHz) + pub const fn from_khz(khz: c_ulong) -> Self { + Self(khz * Self::KHZ_TO_HZ) + } + + /// Create a new instance from megahertz (MHz) + pub const fn from_mhz(mhz: c_ulong) -> Self { + Self(mhz * Self::MHZ_TO_HZ) + } + + /// Create a new instance from gigahertz (GHz) + pub const fn from_ghz(ghz: c_ulong) -> Self { + Self(ghz * Self::GHZ_TO_HZ) + } + + /// Get the frequency in hertz + pub const fn as_hz(&self) -> c_ulong { + self.0 + } + + /// Get the frequency in kilohertz + pub const fn as_khz(&self) -> c_ulong { + self.0 / Self::KHZ_TO_HZ + } + + /// Get the frequency in megahertz + pub const fn as_mhz(&self) -> c_ulong { + self.0 / Self::MHZ_TO_HZ + } + + /// Get the frequency in gigahertz + pub const fn as_ghz(&self) -> c_ulong { + self.0 / Self::GHZ_TO_HZ + } +} + +impl From<Hertz> for c_ulong { + fn from(freq: Hertz) -> Self { + freq.0 + } +} + +#[cfg(CONFIG_COMMON_CLK)] +mod common_clk { + use super::Hertz; + use crate::{ + device::Device, + error::{from_err_ptr, to_result, Result}, + prelude::*, + }; + + use core::{ops::Deref, ptr}; + + /// A reference-counted clock. + /// + /// Rust abstraction for the C [`struct clk`]. + /// + /// # Invariants + /// + /// A [`Clk`] instance holds either a pointer to a valid [`struct clk`] created by the C + /// portion of the kernel or a NULL pointer. + /// + /// Instances of this type are reference-counted. Calling [`Clk::get`] ensures that the + /// allocation remains valid for the lifetime of the [`Clk`]. + /// + /// # Examples + /// + /// The following example demonstrates how to obtain and configure a clock for a device. + /// + /// ``` + /// use kernel::c_str; + /// use kernel::clk::{Clk, Hertz}; + /// use kernel::device::Device; + /// use kernel::error::Result; + /// + /// fn configure_clk(dev: &Device) -> Result { + /// let clk = Clk::get(dev, Some(c_str!("apb_clk")))?; + /// + /// clk.prepare_enable()?; + /// + /// let expected_rate = Hertz::from_ghz(1); + /// + /// if clk.rate() != expected_rate { + /// clk.set_rate(expected_rate)?; + /// } + /// + /// clk.disable_unprepare(); + /// Ok(()) + /// } + /// ``` + /// + /// [`struct clk`]: https://docs.kernel.org/driver-api/clk.html + #[repr(transparent)] + pub struct Clk(*mut bindings::clk); + + impl Clk { + /// Gets [`Clk`] corresponding to a [`Device`] and a connection id. + /// + /// Equivalent to the kernel's [`clk_get`] API. + /// + /// [`clk_get`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_get + pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { + let con_id = name.map_or(ptr::null(), |n| n.as_char_ptr()); + + // SAFETY: It is safe to call [`clk_get`] for a valid device pointer. + // + // INVARIANT: The reference-count is decremented when [`Clk`] goes out of scope. + Ok(Self(from_err_ptr(unsafe { + bindings::clk_get(dev.as_raw(), con_id) + })?)) + } + + /// Obtain the raw [`struct clk`] pointer. 
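+ ///
+ /// By the type invariants, the returned pointer is either NULL or valid
+ /// for as long as `self` is alive.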
+ #[inline] + pub fn as_raw(&self) -> *mut bindings::clk { + self.0 + } + + /// Enable the clock. + /// + /// Equivalent to the kernel's [`clk_enable`] API. + /// + /// [`clk_enable`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_enable + #[inline] + pub fn enable(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_enable`]. + to_result(unsafe { bindings::clk_enable(self.as_raw()) }) + } + + /// Disable the clock. + /// + /// Equivalent to the kernel's [`clk_disable`] API. + /// + /// [`clk_disable`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_disable + #[inline] + pub fn disable(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_disable`]. + unsafe { bindings::clk_disable(self.as_raw()) }; + } + + /// Prepare the clock. + /// + /// Equivalent to the kernel's [`clk_prepare`] API. + /// + /// [`clk_prepare`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_prepare + #[inline] + pub fn prepare(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_prepare`]. + to_result(unsafe { bindings::clk_prepare(self.as_raw()) }) + } + + /// Unprepare the clock. + /// + /// Equivalent to the kernel's [`clk_unprepare`] API. + /// + /// [`clk_unprepare`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_unprepare + #[inline] + pub fn unprepare(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_unprepare`]. + unsafe { bindings::clk_unprepare(self.as_raw()) }; + } + + /// Prepare and enable the clock. + /// + /// Equivalent to calling [`Clk::prepare`] followed by [`Clk::enable`]. + #[inline] + pub fn prepare_enable(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_prepare_enable`]. + to_result(unsafe { bindings::clk_prepare_enable(self.as_raw()) }) + } + + /// Disable and unprepare the clock. + /// + /// Equivalent to calling [`Clk::disable`] followed by [`Clk::unprepare`]. + #[inline] + pub fn disable_unprepare(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_disable_unprepare`]. + unsafe { bindings::clk_disable_unprepare(self.as_raw()) }; + } + + /// Get clock's rate. + /// + /// Equivalent to the kernel's [`clk_get_rate`] API. + /// + /// [`clk_get_rate`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_get_rate + #[inline] + pub fn rate(&self) -> Hertz { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_get_rate`]. + Hertz(unsafe { bindings::clk_get_rate(self.as_raw()) }) + } + + /// Set clock's rate. + /// + /// Equivalent to the kernel's [`clk_set_rate`] API. + /// + /// [`clk_set_rate`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_set_rate + #[inline] + pub fn set_rate(&self, rate: Hertz) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_set_rate`]. + to_result(unsafe { bindings::clk_set_rate(self.as_raw(), rate.as_hz()) }) + } + } + + impl Drop for Clk { + fn drop(&mut self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for [`clk_put`]. + unsafe { bindings::clk_put(self.as_raw()) }; + } + } + + /// A reference-counted optional clock. + /// + /// A lightweight wrapper around an optional [`Clk`]. An [`OptionalClk`] represents a [`Clk`] + /// that a driver can function without but may improve performance or enable additional + /// features when available. 
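The `enable`/`disable` and `prepare`/`unprepare` pairs above must stay balanced, and `Drop` only releases the reference taken by `Clk::get` (via `clk_put`), not any prepare or enable state. A hedged sketch of keeping the pairing balanced on an error path; `bring_up` is a hypothetical helper, not part of this API:

```rust
use kernel::clk::{Clk, Hertz};
use kernel::device::Device;
use kernel::prelude::*;

/// Hypothetical helper: acquire a clock, bring it up at `rate`, and roll
/// back the prepare/enable on failure so the state stays balanced.
fn bring_up(dev: &Device, rate: Hertz) -> Result<Clk> {
    let clk = Clk::get(dev, None)?;
    clk.prepare_enable()?;

    if let Err(e) = clk.set_rate(rate) {
        // Undo `prepare_enable()`; dropping `clk` alone would only
        // release the reference, leaving the clock enabled.
        clk.disable_unprepare();
        return Err(e);
    }

    Ok(clk)
}
```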
+ /// + /// # Invariants + /// + /// An [`OptionalClk`] instance encapsulates a [`Clk`] with either a valid [`struct clk`] or + /// `NULL` pointer. + /// + /// Instances of this type are reference-counted. Calling [`OptionalClk::get`] ensures that the + /// allocation remains valid for the lifetime of the [`OptionalClk`]. + /// + /// # Examples + /// + /// The following example demonstrates how to obtain and configure an optional clock for a + /// device. The code functions correctly whether or not the clock is available. + /// + /// ``` + /// use kernel::c_str; + /// use kernel::clk::{OptionalClk, Hertz}; + /// use kernel::device::Device; + /// use kernel::error::Result; + /// + /// fn configure_clk(dev: &Device) -> Result { + /// let clk = OptionalClk::get(dev, Some(c_str!("apb_clk")))?; + /// + /// clk.prepare_enable()?; + /// + /// let expected_rate = Hertz::from_ghz(1); + /// + /// if clk.rate() != expected_rate { + /// clk.set_rate(expected_rate)?; + /// } + /// + /// clk.disable_unprepare(); + /// Ok(()) + /// } + /// ``` + /// + /// [`struct clk`]: https://docs.kernel.org/driver-api/clk.html + pub struct OptionalClk(Clk); + + impl OptionalClk { + /// Gets [`OptionalClk`] corresponding to a [`Device`] and a connection id. + /// + /// Equivalent to the kernel's [`clk_get_optional`] API. + /// + /// [`clk_get_optional`]: + /// https://docs.kernel.org/core-api/kernel-api.html#c.clk_get_optional + pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { + let con_id = name.map_or(ptr::null(), |n| n.as_char_ptr()); + + // SAFETY: It is safe to call [`clk_get_optional`] for a valid device pointer. + // + // INVARIANT: The reference-count is decremented when [`OptionalClk`] goes out of + // scope. + Ok(Self(Clk(from_err_ptr(unsafe { + bindings::clk_get_optional(dev.as_raw(), con_id) + })?))) + } + } + + // Make [`OptionalClk`] behave like [`Clk`]. + impl Deref for OptionalClk { + type Target = Clk; + + fn deref(&self) -> &Clk { + &self.0 + } + } +} + +#[cfg(CONFIG_COMMON_CLK)] +pub use common_clk::*; diff --git a/rust/kernel/configfs.rs b/rust/kernel/configfs.rs new file mode 100644 index 000000000000..466fb7f40762 --- /dev/null +++ b/rust/kernel/configfs.rs @@ -0,0 +1,1043 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! configfs interface: Userspace-driven Kernel Object Configuration +//! +//! configfs is an in-memory pseudo file system for configuration of kernel +//! modules. Please see the [C documentation] for details and intended use of +//! configfs. +//! +//! This module does not support the following configfs features: +//! +//! - Items. All group children are groups. +//! - Symlink support. +//! - `disconnect_notify` hook. +//! - Default groups. +//! +//! See the [`rust_configfs.rs`] sample for a full example use of this module. +//! +//! C header: [`include/linux/configfs.h`](srctree/include/linux/configfs.h) +//! +//! # Examples +//! +//! ```ignore +//! use kernel::alloc::flags; +//! use kernel::c_str; +//! use kernel::configfs_attrs; +//! use kernel::configfs; +//! use kernel::new_mutex; +//! use kernel::page::PAGE_SIZE; +//! use kernel::sync::Mutex; +//! use kernel::ThisModule; +//! +//! #[pin_data] +//! struct RustConfigfs { +//! #[pin] +//! config: configfs::Subsystem<Configuration>, +//! } +//! +//! impl kernel::InPlaceModule for RustConfigfs { +//! fn init(_module: &'static ThisModule) -> impl PinInit<Self, Error> { +//! pr_info!("Rust configfs sample (init)\n"); +//! +//! let item_type = configfs_attrs! { +//! container: configfs::Subsystem<Configuration>, +//! 
data: Configuration, +//! attributes: [ +//! message: 0, +//! bar: 1, +//! ], +//! }; +//! +//! try_pin_init!(Self { +//! config <- configfs::Subsystem::new( +//! c_str!("rust_configfs"), item_type, Configuration::new() +//! ), +//! }) +//! } +//! } +//! +//! #[pin_data] +//! struct Configuration { +//! message: &'static CStr, +//! #[pin] +//! bar: Mutex<(KBox<[u8; PAGE_SIZE]>, usize)>, +//! } +//! +//! impl Configuration { +//! fn new() -> impl PinInit<Self, Error> { +//! try_pin_init!(Self { +//! message: c_str!("Hello World\n"), +//! bar <- new_mutex!((KBox::new([0; PAGE_SIZE], flags::GFP_KERNEL)?, 0)), +//! }) +//! } +//! } +//! +//! #[vtable] +//! impl configfs::AttributeOperations<0> for Configuration { +//! type Data = Configuration; +//! +//! fn show(container: &Configuration, page: &mut [u8; PAGE_SIZE]) -> Result<usize> { +//! pr_info!("Show message\n"); +//! let data = container.message; +//! page[0..data.len()].copy_from_slice(data); +//! Ok(data.len()) +//! } +//! } +//! +//! #[vtable] +//! impl configfs::AttributeOperations<1> for Configuration { +//! type Data = Configuration; +//! +//! fn show(container: &Configuration, page: &mut [u8; PAGE_SIZE]) -> Result<usize> { +//! pr_info!("Show bar\n"); +//! let guard = container.bar.lock(); +//! let data = guard.0.as_slice(); +//! let len = guard.1; +//! page[0..len].copy_from_slice(&data[0..len]); +//! Ok(len) +//! } +//! +//! fn store(container: &Configuration, page: &[u8]) -> Result { +//! pr_info!("Store bar\n"); +//! let mut guard = container.bar.lock(); +//! guard.0[0..page.len()].copy_from_slice(page); +//! guard.1 = page.len(); +//! Ok(()) +//! } +//! } +//! ``` +//! +//! [C documentation]: srctree/Documentation/filesystems/configfs.rst +//! [`rust_configfs.rs`]: srctree/samples/rust/rust_configfs.rs + +use crate::alloc::flags; +use crate::container_of; +use crate::page::PAGE_SIZE; +use crate::prelude::*; +use crate::str::CString; +use crate::sync::Arc; +use crate::sync::ArcBorrow; +use crate::types::Opaque; +use core::cell::UnsafeCell; +use core::marker::PhantomData; + +/// A configfs subsystem. +/// +/// This is the top level entrypoint for a configfs hierarchy. To register +/// with configfs, embed a field of this type into your kernel module struct. +#[pin_data(PinnedDrop)] +pub struct Subsystem<Data> { + #[pin] + subsystem: Opaque<bindings::configfs_subsystem>, + #[pin] + data: Data, +} + +// SAFETY: We do not provide any operations on `Subsystem`. +unsafe impl<Data> Sync for Subsystem<Data> {} + +// SAFETY: Ownership of `Subsystem` can safely be transferred to other threads. +unsafe impl<Data> Send for Subsystem<Data> {} + +impl<Data> Subsystem<Data> { + /// Create an initializer for a [`Subsystem`]. + /// + /// The subsystem will appear in configfs as a directory name given by + /// `name`. The attributes available in directory are specified by + /// `item_type`. + pub fn new( + name: &'static CStr, + item_type: &'static ItemType<Subsystem<Data>, Data>, + data: impl PinInit<Data, Error>, + ) -> impl PinInit<Self, Error> { + try_pin_init!(Self { + subsystem <- pin_init::init_zeroed().chain( + |place: &mut Opaque<bindings::configfs_subsystem>| { + // SAFETY: We initialized the required fields of `place.group` above. + unsafe { + bindings::config_group_init_type_name( + &mut (*place.get()).su_group, + name.as_char_ptr(), + item_type.as_ptr(), + ) + }; + + // SAFETY: `place.su_mutex` is valid for use as a mutex. 
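+ // The subsystem mutex is taken by the configfs core while it
+ // manipulates items below this subsystem, so it must be initialized
+ // before `configfs_register_subsystem()` runs in `pin_chain` below.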
+ unsafe { + bindings::__mutex_init( + &mut (*place.get()).su_mutex, + kernel::optional_name!().as_char_ptr(), + kernel::static_lock_class!().as_ptr(), + ) + } + Ok(()) + } + ), + data <- data, + }) + .pin_chain(|this| { + crate::error::to_result( + // SAFETY: We initialized `this.subsystem` according to C API contract above. + unsafe { bindings::configfs_register_subsystem(this.subsystem.get()) }, + ) + }) + } +} + +#[pinned_drop] +impl<Data> PinnedDrop for Subsystem<Data> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: We registered `self.subsystem` in the initializer returned by `Self::new`. + unsafe { bindings::configfs_unregister_subsystem(self.subsystem.get()) }; + // SAFETY: We initialized the mutex in `Subsystem::new`. + unsafe { bindings::mutex_destroy(&raw mut (*self.subsystem.get()).su_mutex) }; + } +} + +/// Trait that allows offset calculations for structs that embed a +/// `bindings::config_group`. +/// +/// Users of the configfs API should not need to implement this trait. +/// +/// # Safety +/// +/// - Implementers of this trait must embed a `bindings::config_group`. +/// - Methods must be implemented according to method documentation. +pub unsafe trait HasGroup<Data> { + /// Return the address of the `bindings::config_group` embedded in [`Self`]. + /// + /// # Safety + /// + /// - `this` must be a valid allocation of at least the size of [`Self`]. + unsafe fn group(this: *const Self) -> *const bindings::config_group; + + /// Return the address of the [`Self`] that `group` is embedded in. + /// + /// # Safety + /// + /// - `group` must point to the `bindings::config_group` that is embedded in + /// [`Self`]. + unsafe fn container_of(group: *const bindings::config_group) -> *const Self; +} + +// SAFETY: `Subsystem<Data>` embeds a field of type `bindings::config_group` +// within the `subsystem` field. +unsafe impl<Data> HasGroup<Data> for Subsystem<Data> { + unsafe fn group(this: *const Self) -> *const bindings::config_group { + // SAFETY: By impl and function safety requirement this projection is in bounds. + unsafe { &raw const (*(*this).subsystem.get()).su_group } + } + + unsafe fn container_of(group: *const bindings::config_group) -> *const Self { + // SAFETY: By impl and function safety requirement this projection is in bounds. + let c_subsys_ptr = unsafe { container_of!(group, bindings::configfs_subsystem, su_group) }; + let opaque_ptr = c_subsys_ptr.cast::<Opaque<bindings::configfs_subsystem>>(); + // SAFETY: By impl and function safety requirement, `opaque_ptr` and the + // pointer it returns, are within the same allocation. + unsafe { container_of!(opaque_ptr, Subsystem<Data>, subsystem) } + } +} + +/// A configfs group. +/// +/// To add a subgroup to configfs, pass this type as `ctype` to +/// [`crate::configfs_attrs`] when creating a group in [`GroupOperations::make_group`]. +#[pin_data] +pub struct Group<Data> { + #[pin] + group: Opaque<bindings::config_group>, + #[pin] + data: Data, +} + +impl<Data> Group<Data> { + /// Create an initializer for a new group. + /// + /// When instantiated, the group will appear as a directory with the name + /// given by `name` and it will contain attributes specified by `item_type`. 
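+ ///
+ /// # Examples
+ ///
+ /// A sketch of returning a new subgroup from [`GroupOperations::make_group`];
+ /// `Child` and `CHILD_ITEM_TYPE` are placeholders invented for this example:
+ ///
+ /// ```ignore
+ /// fn make_group(&self, name: &CStr) -> Result<impl PinInit<Group<Child>, Error>> {
+ ///     Ok(Group::new(name.try_into()?, &CHILD_ITEM_TYPE, Child::new()))
+ /// }
+ /// ```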
+ pub fn new(
+ name: CString,
+ item_type: &'static ItemType<Group<Data>, Data>,
+ data: impl PinInit<Data, Error>,
+ ) -> impl PinInit<Self, Error> {
+ try_pin_init!(Self {
+ group <- pin_init::init_zeroed().chain(|v: &mut Opaque<bindings::config_group>| {
+ let place = v.get();
+ let name = name.to_bytes_with_nul().as_ptr();
+ // SAFETY: It is safe to initialize a group once it has been zeroed.
+ unsafe {
+ bindings::config_group_init_type_name(place, name.cast(), item_type.as_ptr())
+ };
+ Ok(())
+ }),
+ data <- data,
+ })
+ }
+}
+
+// SAFETY: `Group<Data>` embeds a field of type `bindings::config_group`
+// within the `group` field.
+unsafe impl<Data> HasGroup<Data> for Group<Data> {
+ unsafe fn group(this: *const Self) -> *const bindings::config_group {
+ Opaque::cast_into(
+ // SAFETY: By impl and function safety requirements this field
+ // projection is within bounds of the allocation.
+ unsafe { &raw const (*this).group },
+ )
+ }
+
+ unsafe fn container_of(group: *const bindings::config_group) -> *const Self {
+ let opaque_ptr = group.cast::<Opaque<bindings::config_group>>();
+ // SAFETY: By impl and function safety requirement, `opaque_ptr` and
+ // the pointer it returns will be in the same allocation.
+ unsafe { container_of!(opaque_ptr, Self, group) }
+ }
+}
+
+/// # Safety
+///
+/// `this` must be a valid pointer.
+///
+/// If `this` does not represent the root group of a configfs subsystem,
+/// `this` must be a pointer to a `bindings::config_group` embedded in a
+/// `Group<Parent>`.
+///
+/// Otherwise, `this` must be a pointer to a `bindings::config_group` that
+/// is embedded in a `bindings::configfs_subsystem` that is embedded in a
+/// `Subsystem<Parent>`.
+unsafe fn get_group_data<'a, Parent>(this: *mut bindings::config_group) -> &'a Parent {
+ // SAFETY: `this` is a valid pointer.
+ let is_root = unsafe { (*this).cg_subsys.is_null() };
+
+ if !is_root {
+ // SAFETY: By C API contract, `this` was returned from a call to
+ // `make_group`. The pointer is known to be embedded within a
+ // `Group<Parent>`.
+ unsafe { &(*Group::<Parent>::container_of(this)).data }
+ } else {
+ // SAFETY: By C API contract, `this` is a pointer to the
+ // `bindings::config_group` field within a `Subsystem<Parent>`.
+ unsafe { &(*Subsystem::container_of(this)).data }
+ }
+}
+
+struct GroupOperationsVTable<Parent, Child>(PhantomData<(Parent, Child)>);
+
+impl<Parent, Child> GroupOperationsVTable<Parent, Child>
+where
+ Parent: GroupOperations<Child = Child>,
+ Child: 'static,
+{
+ /// # Safety
+ ///
+ /// `this` must be a valid pointer.
+ ///
+ /// If `this` does not represent the root group of a configfs subsystem,
+ /// `this` must be a pointer to a `bindings::config_group` embedded in a
+ /// `Group<Parent>`.
+ ///
+ /// Otherwise, `this` must be a pointer to a `bindings::config_group` that
+ /// is embedded in a `bindings::configfs_subsystem` that is embedded in a
+ /// `Subsystem<Parent>`.
+ ///
+ /// `name` must point to a null-terminated string.
+ unsafe extern "C" fn make_group(
+ this: *mut bindings::config_group,
+ name: *const kernel::ffi::c_char,
+ ) -> *mut bindings::config_group {
+ // SAFETY: By function safety requirements of this function, this call
+ // is safe.
+ let parent_data = unsafe { get_group_data(this) };
+
+ let group_init = match Parent::make_group(
+ parent_data,
+ // SAFETY: By function safety requirements, `name` points to a
+ // null-terminated string.
+ unsafe { CStr::from_char_ptr(name) }, + ) { + Ok(init) => init, + Err(e) => return e.to_ptr(), + }; + + let child_group = <Arc<Group<Child>> as InPlaceInit<Group<Child>>>::try_pin_init( + group_init, + flags::GFP_KERNEL, + ); + + match child_group { + Ok(child_group) => { + let child_group_ptr = child_group.into_raw(); + // SAFETY: We allocated the pointee of `child_ptr` above as a + // `Group<Child>`. + unsafe { Group::<Child>::group(child_group_ptr) }.cast_mut() + } + Err(e) => e.to_ptr(), + } + } + + /// # Safety + /// + /// If `this` does not represent the root group of a configfs subsystem, + /// `this` must be a pointer to a `bindings::config_group` embedded in a + /// `Group<Parent>`. + /// + /// Otherwise, `this` must be a pointer to a `bindings::config_group` that + /// is embedded in a `bindings::configfs_subsystem` that is embedded in a + /// `Subsystem<Parent>`. + /// + /// `item` must point to a `bindings::config_item` within a + /// `bindings::config_group` within a `Group<Child>`. + unsafe extern "C" fn drop_item( + this: *mut bindings::config_group, + item: *mut bindings::config_item, + ) { + // SAFETY: By function safety requirements of this function, this call + // is safe. + let parent_data = unsafe { get_group_data(this) }; + + // SAFETY: By function safety requirements, `item` is embedded in a + // `config_group`. + let c_child_group_ptr = unsafe { container_of!(item, bindings::config_group, cg_item) }; + // SAFETY: By function safety requirements, `c_child_group_ptr` is + // embedded within a `Group<Child>`. + let r_child_group_ptr = unsafe { Group::<Child>::container_of(c_child_group_ptr) }; + + if Parent::HAS_DROP_ITEM { + // SAFETY: We called `into_raw` to produce `r_child_group_ptr` in + // `make_group`. + let arc: Arc<Group<Child>> = unsafe { Arc::from_raw(r_child_group_ptr.cast_mut()) }; + + Parent::drop_item(parent_data, arc.as_arc_borrow()); + arc.into_raw(); + } + + // SAFETY: By C API contract, we are required to drop a refcount on + // `item`. + unsafe { bindings::config_item_put(item) }; + } + + const VTABLE: bindings::configfs_group_operations = bindings::configfs_group_operations { + make_item: None, + make_group: Some(Self::make_group), + disconnect_notify: None, + drop_item: Some(Self::drop_item), + is_visible: None, + is_bin_visible: None, + }; + + const fn vtable_ptr() -> *const bindings::configfs_group_operations { + &Self::VTABLE + } +} + +struct ItemOperationsVTable<Container, Data>(PhantomData<(Container, Data)>); + +impl<Data> ItemOperationsVTable<Group<Data>, Data> +where + Data: 'static, +{ + /// # Safety + /// + /// `this` must be a pointer to a `bindings::config_group` embedded in a + /// `Group<Parent>`. + /// + /// This function will destroy the pointee of `this`. The pointee of `this` + /// must not be accessed after the function returns. + unsafe extern "C" fn release(this: *mut bindings::config_item) { + // SAFETY: By function safety requirements, `this` is embedded in a + // `config_group`. + let c_group_ptr = unsafe { kernel::container_of!(this, bindings::config_group, cg_item) }; + // SAFETY: By function safety requirements, `c_group_ptr` is + // embedded within a `Group<Data>`. + let r_group_ptr = unsafe { Group::<Data>::container_of(c_group_ptr) }; + + // SAFETY: We called `into_raw` on `r_group_ptr` in + // `make_group`. 
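+ // The `Arc` reconstructed below adopts the reference that `make_group`
+ // leaked with `into_raw()`, so dropping it releases the `Group` exactly
+ // once, together with its embedded C `config_group`.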
+ let pin_self: Arc<Group<Data>> = unsafe { Arc::from_raw(r_group_ptr.cast_mut()) };
+ drop(pin_self);
+ }
+
+ const VTABLE: bindings::configfs_item_operations = bindings::configfs_item_operations {
+ release: Some(Self::release),
+ allow_link: None,
+ drop_link: None,
+ };
+
+ const fn vtable_ptr() -> *const bindings::configfs_item_operations {
+ &Self::VTABLE
+ }
+}
+
+impl<Data> ItemOperationsVTable<Subsystem<Data>, Data> {
+ const VTABLE: bindings::configfs_item_operations = bindings::configfs_item_operations {
+ release: None,
+ allow_link: None,
+ drop_link: None,
+ };
+
+ const fn vtable_ptr() -> *const bindings::configfs_item_operations {
+ &Self::VTABLE
+ }
+}
+
+/// Operations implemented by configfs groups that can create subgroups.
+///
+/// Implement this trait on structs that embed a [`Subsystem`] or a [`Group`].
+#[vtable]
+pub trait GroupOperations {
+ /// The child data object type.
+ ///
+ /// This group will create subgroups (subdirectories) backed by this kind of
+ /// object.
+ type Child: 'static;
+
+ /// Creates a new subgroup.
+ ///
+ /// The kernel will call this method in response to `mkdir(2)` in the
+ /// directory representing `this`.
+ ///
+ /// To accept the request to create a group, implementations should
+ /// return an initializer of a `Group<Self::Child>`. To prevent creation,
+ /// return a suitable error.
+ fn make_group(&self, name: &CStr) -> Result<impl PinInit<Group<Self::Child>, Error>>;
+
+ /// Prepares the group for removal from configfs.
+ ///
+ /// The kernel will call this method before the directory representing `_child` is removed from
+ /// configfs.
+ ///
+ /// Implementations can use this method to do housekeeping before configfs drops its
+ /// reference to `Child`.
+ ///
+ /// NOTE: "drop" in the name of this function is not related to the Rust drop term. Rather, the
+ /// name is inherited from the callback name in the underlying C code.
+ fn drop_item(&self, _child: ArcBorrow<'_, Group<Self::Child>>) {
+ kernel::build_error!(kernel::error::VTABLE_DEFAULT_ERROR)
+ }
+}
+
+/// A configfs attribute.
+///
+/// An attribute appears as a file in configfs, inside a folder that represents
+/// the group that the attribute belongs to.
+#[repr(transparent)]
+pub struct Attribute<const ID: u64, O, Data> {
+ attribute: Opaque<bindings::configfs_attribute>,
+ _p: PhantomData<(O, Data)>,
+}
+
+// SAFETY: We do not provide any operations on `Attribute`.
+unsafe impl<const ID: u64, O, Data> Sync for Attribute<ID, O, Data> {}
+
+// SAFETY: Ownership of `Attribute` can safely be transferred to other threads.
+unsafe impl<const ID: u64, O, Data> Send for Attribute<ID, O, Data> {}
+
+impl<const ID: u64, O, Data> Attribute<ID, O, Data>
+where
+ O: AttributeOperations<ID, Data = Data>,
+{
+ /// # Safety
+ ///
+ /// `item` must be embedded in a `bindings::config_group`.
+ ///
+ /// If `item` does not represent the root group of a configfs subsystem,
+ /// the group must be embedded in a `Group<Data>`.
+ ///
+ /// Otherwise, the group must be embedded in a
+ /// `bindings::configfs_subsystem` that is embedded in a `Subsystem<Data>`.
+ ///
+ /// `page` must point to a writable buffer of size at least [`PAGE_SIZE`].
+ unsafe extern "C" fn show(
+ item: *mut bindings::config_item,
+ page: *mut kernel::ffi::c_char,
+ ) -> isize {
+ let c_group: *mut bindings::config_group =
+ // SAFETY: By function safety requirements, `item` is embedded in a
+ // `config_group`.
+ unsafe { container_of!(item, bindings::config_group, cg_item) };
+
+ // SAFETY: The function safety requirements for this function satisfy
+ // the conditions for this call.
+ let data: &Data = unsafe { get_group_data(c_group) };
+
+ // SAFETY: By function safety requirements, `page` is writable for `PAGE_SIZE`.
+ let ret = O::show(data, unsafe { &mut *(page.cast::<[u8; PAGE_SIZE]>()) });
+
+ match ret {
+ Ok(size) => size as isize,
+ Err(err) => err.to_errno() as isize,
+ }
+ }
+
+ /// # Safety
+ ///
+ /// `item` must be embedded in a `bindings::config_group`.
+ ///
+ /// If `item` does not represent the root group of a configfs subsystem,
+ /// the group must be embedded in a `Group<Data>`.
+ ///
+ /// Otherwise, the group must be embedded in a
+ /// `bindings::configfs_subsystem` that is embedded in a `Subsystem<Data>`.
+ ///
+ /// `page` must point to a readable buffer of size at least `size`.
+ unsafe extern "C" fn store(
+ item: *mut bindings::config_item,
+ page: *const kernel::ffi::c_char,
+ size: usize,
+ ) -> isize {
+ let c_group: *mut bindings::config_group =
+ // SAFETY: By function safety requirements, `item` is embedded in a
+ // `config_group`.
+ unsafe { container_of!(item, bindings::config_group, cg_item) };
+
+ // SAFETY: The function safety requirements for this function satisfy
+ // the conditions for this call.
+ let data: &Data = unsafe { get_group_data(c_group) };
+
+ let ret = O::store(
+ data,
+ // SAFETY: By function safety requirements, `page` is readable
+ // for at least `size`.
+ unsafe { core::slice::from_raw_parts(page.cast(), size) },
+ );
+
+ match ret {
+ Ok(()) => size as isize,
+ Err(err) => err.to_errno() as isize,
+ }
+ }
+
+ /// Create a new attribute.
+ ///
+ /// The attribute will appear as a file with the name given by `name`.
+ pub const fn new(name: &'static CStr) -> Self {
+ Self {
+ attribute: Opaque::new(bindings::configfs_attribute {
+ ca_name: crate::str::as_char_ptr_in_const_context(name),
+ ca_owner: core::ptr::null_mut(),
+ ca_mode: 0o660,
+ show: Some(Self::show),
+ store: if O::HAS_STORE {
+ Some(Self::store)
+ } else {
+ None
+ },
+ }),
+ _p: PhantomData,
+ }
+ }
+}
+
+/// Operations supported by an attribute.
+///
+/// Implement this trait on a type and pass that type as a generic parameter when
+/// creating an [`Attribute`]. The type carrying the implementation serves no
+/// purpose other than specifying the attribute operations.
+///
+/// This trait must be implemented on the `Data` type for types that
+/// implement `HasGroup<Data>`. The trait must be implemented once for each
+/// attribute of the group. The constant type parameter `ID` maps the
+/// implementation to a specific `Attribute`. `ID` must be passed when declaring
+/// attributes via the [`kernel::configfs_attrs`] macro, to tie
+/// `AttributeOperations` implementations to concrete named attributes.
+#[vtable]
+pub trait AttributeOperations<const ID: u64 = 0> {
+ /// The type of the object that contains the field that is backing the
+ /// attribute for this operation.
+ type Data;
+
+ /// Renders the value of an attribute.
+ ///
+ /// This function is called by the kernel to read the value of an attribute.
+ ///
+ /// Implementations should write the rendering of the attribute to `page`
+ /// and return the number of bytes written.
+ fn show(data: &Self::Data, page: &mut [u8; PAGE_SIZE]) -> Result<usize>;
+
+ /// Stores the value of an attribute.
+ ///
+ /// This function is called by the kernel to update the value of an attribute.
+ ///
+ /// Implementations should parse the value from `page` and update internal
+ /// state to reflect the parsed value.
+ fn store(_data: &Self::Data, _page: &[u8]) -> Result {
+ kernel::build_error!(kernel::error::VTABLE_DEFAULT_ERROR)
+ }
+}
+
+/// A list of attributes.
+///
+/// This type is used to construct a new [`ItemType`]. It represents a list of
+/// [`Attribute`] that will appear in the directory representing a [`Group`].
+/// Users should not directly instantiate this type; rather, they should use the
+/// [`kernel::configfs_attrs`] macro to declare a static set of attributes for a
+/// group.
+///
+/// # Note
+///
+/// Instances of this type are constructed statically at compile time by the
+/// [`kernel::configfs_attrs`] macro.
+#[repr(transparent)]
+pub struct AttributeList<const N: usize, Data>(
+ /// Null-terminated array of pointers to [`Attribute`]. The type is [`c_void`]
+ /// to conform to the C API.
+ UnsafeCell<[*mut kernel::ffi::c_void; N]>,
+ PhantomData<Data>,
+);
+
+// SAFETY: Ownership of `AttributeList` can safely be transferred to other threads.
+unsafe impl<const N: usize, Data> Send for AttributeList<N, Data> {}
+
+// SAFETY: We do not provide any operations on `AttributeList` that need synchronization.
+unsafe impl<const N: usize, Data> Sync for AttributeList<N, Data> {}
+
+impl<const N: usize, Data> AttributeList<N, Data> {
+ /// # Safety
+ ///
+ /// This function must only be called by the [`kernel::configfs_attrs`]
+ /// macro.
+ #[doc(hidden)]
+ pub const unsafe fn new() -> Self {
+ Self(UnsafeCell::new([core::ptr::null_mut(); N]), PhantomData)
+ }
+
+ /// # Safety
+ ///
+ /// The caller must ensure that there are no other concurrent accesses to
+ /// `self`. That is, the caller has exclusive access to `self`.
+ #[doc(hidden)]
+ pub const unsafe fn add<const I: usize, const ID: u64, O>(
+ &'static self,
+ attribute: &'static Attribute<ID, O, Data>,
+ ) where
+ O: AttributeOperations<ID, Data = Data>,
+ {
+ // We need a space at the end of our list for a null terminator.
+ const { assert!(I < N - 1, "Invalid attribute index") };
+
+ // SAFETY: By function safety requirements, we have exclusive access to
+ // `self` and the reference created below will be exclusive.
+ unsafe { (&mut *self.0.get())[I] = core::ptr::from_ref(attribute).cast_mut().cast() };
+ }
+}
+
+/// A representation of the attributes that will appear in a [`Group`] or
+/// [`Subsystem`].
+///
+/// Users should not directly instantiate objects of this type. Rather, they
+/// should use the [`kernel::configfs_attrs`] macro to statically declare the
+/// shape of a [`Group`] or [`Subsystem`].
+#[pin_data]
+pub struct ItemType<Container, Data> {
+ #[pin]
+ item_type: Opaque<bindings::config_item_type>,
+ _p: PhantomData<(Container, Data)>,
+}
+
+// SAFETY: We do not provide any operations on `ItemType` that need synchronization.
+unsafe impl<Container, Data> Sync for ItemType<Container, Data> {}
+
+// SAFETY: Ownership of `ItemType` can safely be transferred to other threads.
+unsafe impl<Container, Data> Send for ItemType<Container, Data> {}
+
+macro_rules!
impl_item_type { + ($tpe:ty) => { + impl<Data> ItemType<$tpe, Data> { + #[doc(hidden)] + pub const fn new_with_child_ctor<const N: usize, Child>( + owner: &'static ThisModule, + attributes: &'static AttributeList<N, Data>, + ) -> Self + where + Data: GroupOperations<Child = Child>, + Child: 'static, + { + Self { + item_type: Opaque::new(bindings::config_item_type { + ct_owner: owner.as_ptr(), + ct_group_ops: GroupOperationsVTable::<Data, Child>::vtable_ptr().cast_mut(), + ct_item_ops: ItemOperationsVTable::<$tpe, Data>::vtable_ptr().cast_mut(), + ct_attrs: core::ptr::from_ref(attributes).cast_mut().cast(), + ct_bin_attrs: core::ptr::null_mut(), + }), + _p: PhantomData, + } + } + + #[doc(hidden)] + pub const fn new<const N: usize>( + owner: &'static ThisModule, + attributes: &'static AttributeList<N, Data>, + ) -> Self { + Self { + item_type: Opaque::new(bindings::config_item_type { + ct_owner: owner.as_ptr(), + ct_group_ops: core::ptr::null_mut(), + ct_item_ops: ItemOperationsVTable::<$tpe, Data>::vtable_ptr().cast_mut(), + ct_attrs: core::ptr::from_ref(attributes).cast_mut().cast(), + ct_bin_attrs: core::ptr::null_mut(), + }), + _p: PhantomData, + } + } + } + }; +} + +impl_item_type!(Subsystem<Data>); +impl_item_type!(Group<Data>); + +impl<Container, Data> ItemType<Container, Data> { + fn as_ptr(&self) -> *const bindings::config_item_type { + self.item_type.get() + } +} + +/// Define a list of configfs attributes statically. +/// +/// Invoking the macro in the following manner: +/// +/// ```ignore +/// let item_type = configfs_attrs! { +/// container: configfs::Subsystem<Configuration>, +/// data: Configuration, +/// child: Child, +/// attributes: [ +/// message: 0, +/// bar: 1, +/// ], +/// }; +/// ``` +/// +/// Expands the following output: +/// +/// ```ignore +/// let item_type = { +/// static CONFIGURATION_MESSAGE_ATTR: kernel::configfs::Attribute< +/// 0, +/// Configuration, +/// Configuration, +/// > = unsafe { +/// kernel::configfs::Attribute::new({ +/// const S: &str = "message\u{0}"; +/// const C: &kernel::str::CStr = match kernel::str::CStr::from_bytes_with_nul( +/// S.as_bytes() +/// ) { +/// Ok(v) => v, +/// Err(_) => { +/// core::panicking::panic_fmt(core::const_format_args!( +/// "string contains interior NUL" +/// )); +/// } +/// }; +/// C +/// }) +/// }; +/// +/// static CONFIGURATION_BAR_ATTR: kernel::configfs::Attribute< +/// 1, +/// Configuration, +/// Configuration +/// > = unsafe { +/// kernel::configfs::Attribute::new({ +/// const S: &str = "bar\u{0}"; +/// const C: &kernel::str::CStr = match kernel::str::CStr::from_bytes_with_nul( +/// S.as_bytes() +/// ) { +/// Ok(v) => v, +/// Err(_) => { +/// core::panicking::panic_fmt(core::const_format_args!( +/// "string contains interior NUL" +/// )); +/// } +/// }; +/// C +/// }) +/// }; +/// +/// const N: usize = (1usize + (1usize + 0usize)) + 1usize; +/// +/// static CONFIGURATION_ATTRS: kernel::configfs::AttributeList<N, Configuration> = +/// unsafe { kernel::configfs::AttributeList::new() }; +/// +/// { +/// const N: usize = 0usize; +/// unsafe { CONFIGURATION_ATTRS.add::<N, 0, _>(&CONFIGURATION_MESSAGE_ATTR) }; +/// } +/// +/// { +/// const N: usize = (1usize + 0usize); +/// unsafe { CONFIGURATION_ATTRS.add::<N, 1, _>(&CONFIGURATION_BAR_ATTR) }; +/// } +/// +/// static CONFIGURATION_TPE: +/// kernel::configfs::ItemType<configfs::Subsystem<Configuration> ,Configuration> +/// = kernel::configfs::ItemType::< +/// configfs::Subsystem<Configuration>, +/// Configuration +/// >::new_with_child_ctor::<N,Child>( +/// 
&THIS_MODULE, +/// &CONFIGURATION_ATTRS +/// ); +/// +/// &CONFIGURATION_TPE +/// } +/// ``` +#[macro_export] +macro_rules! configfs_attrs { + ( + container: $container:ty, + data: $data:ty, + attributes: [ + $($name:ident: $attr:literal),* $(,)? + ] $(,)? + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child(), + @no_child(x), + @attrs($($name $attr)*), + @eat($($name $attr,)*), + @assign(), + @cnt(0usize), + ) + }; + ( + container: $container:ty, + data: $data:ty, + child: $child:ty, + attributes: [ + $($name:ident: $attr:literal),* $(,)? + ] $(,)? + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child($child), + @no_child(), + @attrs($($name $attr)*), + @eat($($name $attr,)*), + @assign(), + @cnt(0usize), + ) + }; + (count: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($aname:ident $aattr:literal)*), + @eat($name:ident $attr:literal, $($rname:ident $rattr:literal,)*), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child($($child)?), + @no_child($($no_child)?), + @attrs($($aname $aattr)*), + @eat($($rname $rattr,)*), + @assign($($assign)* { + const N: usize = $cnt; + // The following macro text expands to a call to `Attribute::add`. + + // SAFETY: By design of this macro, the name of the variable we + // invoke the `add` method on below, is not visible outside of + // the macro expansion. The macro does not operate concurrently + // on this variable, and thus we have exclusive access to the + // variable. + unsafe { + $crate::macros::paste!( + [< $data:upper _ATTRS >] + .add::<N, $attr, _>(&[< $data:upper _ $name:upper _ATTR >]) + ) + }; + }), + @cnt(1usize + $cnt), + ) + }; + (count: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($aname:ident $aattr:literal)*), + @eat(), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => + { + $crate::configfs_attrs!( + final: + @container($container), + @data($data), + @child($($child)?), + @no_child($($no_child)?), + @attrs($($aname $aattr)*), + @assign($($assign)*), + @cnt($cnt), + ) + }; + (final: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($name:ident $attr:literal)*), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => + { + $crate::macros::paste!{ + { + $( + // SAFETY: We are expanding `configfs_attrs`. + static [< $data:upper _ $name:upper _ATTR >]: + $crate::configfs::Attribute<$attr, $data, $data> = + unsafe { + $crate::configfs::Attribute::new(c_str!(::core::stringify!($name))) + }; + )* + + + // We need space for a null terminator. + const N: usize = $cnt + 1usize; + + // SAFETY: We are expanding `configfs_attrs`. + static [< $data:upper _ATTRS >]: + $crate::configfs::AttributeList<N, $data> = + unsafe { $crate::configfs::AttributeList::new() }; + + $($assign)* + + $( + const [<$no_child:upper>]: bool = true; + + static [< $data:upper _TPE >] : $crate::configfs::ItemType<$container, $data> = + $crate::configfs::ItemType::<$container, $data>::new::<N>( + &THIS_MODULE, &[<$ data:upper _ATTRS >] + ); + )? + + $( + static [< $data:upper _TPE >]: + $crate::configfs::ItemType<$container, $data> = + $crate::configfs::ItemType::<$container, $data>:: + new_with_child_ctor::<N, $child>( + &THIS_MODULE, &[<$ data:upper _ATTRS >] + ); + )? 
+ + & [< $data:upper _TPE >] + } + } + }; + +} + +pub use crate::configfs_attrs; diff --git a/rust/kernel/cpu.rs b/rust/kernel/cpu.rs new file mode 100644 index 000000000000..cb6c0338ef5a --- /dev/null +++ b/rust/kernel/cpu.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic CPU definitions. +//! +//! C header: [`include/linux/cpu.h`](srctree/include/linux/cpu.h) + +use crate::{bindings, device::Device, error::Result, prelude::ENODEV}; + +/// Returns the maximum number of possible CPUs in the current system configuration. +#[inline] +pub fn nr_cpu_ids() -> u32 { + #[cfg(any(NR_CPUS_1, CONFIG_FORCE_NR_CPUS))] + { + bindings::NR_CPUS + } + + #[cfg(not(any(NR_CPUS_1, CONFIG_FORCE_NR_CPUS)))] + // SAFETY: `nr_cpu_ids` is a valid global provided by the kernel. + unsafe { + bindings::nr_cpu_ids + } +} + +/// The CPU ID. +/// +/// Represents a CPU identifier as a wrapper around an [`u32`]. +/// +/// # Invariants +/// +/// The CPU ID lies within the range `[0, nr_cpu_ids())`. +/// +/// # Examples +/// +/// ``` +/// use kernel::cpu::CpuId; +/// +/// let cpu = 0; +/// +/// // SAFETY: 0 is always a valid CPU number. +/// let id = unsafe { CpuId::from_u32_unchecked(cpu) }; +/// +/// assert_eq!(id.as_u32(), cpu); +/// assert!(CpuId::from_i32(0).is_some()); +/// assert!(CpuId::from_i32(-1).is_none()); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct CpuId(u32); + +impl CpuId { + /// Creates a new [`CpuId`] from the given `id` without checking bounds. + /// + /// # Safety + /// + /// The caller must ensure that `id` is a valid CPU ID (i.e., `0 <= id < nr_cpu_ids()`). + #[inline] + pub unsafe fn from_i32_unchecked(id: i32) -> Self { + debug_assert!(id >= 0); + debug_assert!((id as u32) < nr_cpu_ids()); + + // INVARIANT: The function safety guarantees `id` is a valid CPU id. + Self(id as u32) + } + + /// Creates a new [`CpuId`] from the given `id`, checking that it is valid. + pub fn from_i32(id: i32) -> Option<Self> { + if id < 0 || id as u32 >= nr_cpu_ids() { + None + } else { + // INVARIANT: `id` has just been checked as a valid CPU ID. + Some(Self(id as u32)) + } + } + + /// Creates a new [`CpuId`] from the given `id` without checking bounds. + /// + /// # Safety + /// + /// The caller must ensure that `id` is a valid CPU ID (i.e., `0 <= id < nr_cpu_ids()`). + #[inline] + pub unsafe fn from_u32_unchecked(id: u32) -> Self { + debug_assert!(id < nr_cpu_ids()); + + // Ensure the `id` fits in an [`i32`] as it's also representable that way. + debug_assert!(id <= i32::MAX as u32); + + // INVARIANT: The function safety guarantees `id` is a valid CPU id. + Self(id) + } + + /// Creates a new [`CpuId`] from the given `id`, checking that it is valid. + pub fn from_u32(id: u32) -> Option<Self> { + if id >= nr_cpu_ids() { + None + } else { + // INVARIANT: `id` has just been checked as a valid CPU ID. + Some(Self(id)) + } + } + + /// Returns CPU number. + #[inline] + pub fn as_u32(&self) -> u32 { + self.0 + } + + /// Returns the ID of the CPU the code is currently running on. + /// + /// The returned value is considered unstable because it may change + /// unexpectedly due to preemption or CPU migration. It should only be + /// used when the context ensures that the task remains on the same CPU + /// or the users could use a stale (yet valid) CPU ID. + #[inline] + pub fn current() -> Self { + // SAFETY: raw_smp_processor_id() always returns a valid CPU ID. 
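+ // `raw_smp_processor_id()` skips the preemption check that
+ // `smp_processor_id()` performs, which matches the "unstable value"
+ // caveat in the documentation above.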
+ unsafe { Self::from_u32_unchecked(bindings::raw_smp_processor_id()) } + } +} + +impl From<CpuId> for u32 { + fn from(id: CpuId) -> Self { + id.as_u32() + } +} + +impl From<CpuId> for i32 { + fn from(id: CpuId) -> Self { + id.as_u32() as i32 + } +} + +/// Creates a new instance of CPU's device. +/// +/// # Safety +/// +/// Reference counting is not implemented for the CPU device in the C code. When a CPU is +/// hot-unplugged, the corresponding CPU device is unregistered, but its associated memory +/// is not freed. +/// +/// Callers must ensure that the CPU device is not used after it has been unregistered. +/// This can be achieved, for example, by registering a CPU hotplug notifier and removing +/// any references to the CPU device within the notifier's callback. +pub unsafe fn from_cpu(cpu: CpuId) -> Result<&'static Device> { + // SAFETY: It is safe to call `get_cpu_device()` for any CPU. + let ptr = unsafe { bindings::get_cpu_device(u32::from(cpu)) }; + if ptr.is_null() { + return Err(ENODEV); + } + + // SAFETY: The pointer returned by `get_cpu_device()`, if not `NULL`, is a valid pointer to + // a `struct device` and is never freed by the C code. + Ok(unsafe { Device::from_raw(ptr) }) +} diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs new file mode 100644 index 000000000000..f968fbd22890 --- /dev/null +++ b/rust/kernel/cpufreq.rs @@ -0,0 +1,1397 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! CPU frequency scaling. +//! +//! This module provides rust abstractions for interacting with the cpufreq subsystem. +//! +//! C header: [`include/linux/cpufreq.h`](srctree/include/linux/cpufreq.h) +//! +//! Reference: <https://docs.kernel.org/admin-guide/pm/cpufreq.html> + +use crate::{ + clk::Hertz, + cpu::CpuId, + cpumask, + device::{Bound, Device}, + devres, + error::{code::*, from_err_ptr, from_result, to_result, Result, VTABLE_DEFAULT_ERROR}, + ffi::{c_char, c_ulong}, + prelude::*, + types::ForeignOwnable, + types::Opaque, +}; + +#[cfg(CONFIG_COMMON_CLK)] +use crate::clk::Clk; + +use core::{ + cell::UnsafeCell, + marker::PhantomData, + ops::{Deref, DerefMut}, + pin::Pin, + ptr, +}; + +use macros::vtable; + +/// Maximum length of CPU frequency driver's name. +const CPUFREQ_NAME_LEN: usize = bindings::CPUFREQ_NAME_LEN as usize; + +/// Default transition latency value in nanoseconds. +pub const DEFAULT_TRANSITION_LATENCY_NS: u32 = bindings::CPUFREQ_DEFAULT_TRANSITION_LATENCY_NS; + +/// CPU frequency driver flags. +pub mod flags { + /// Driver needs to update internal limits even if frequency remains unchanged. + pub const NEED_UPDATE_LIMITS: u16 = 1 << 0; + + /// Platform where constants like `loops_per_jiffy` are unaffected by frequency changes. + pub const CONST_LOOPS: u16 = 1 << 1; + + /// Register driver as a thermal cooling device automatically. + pub const IS_COOLING_DEV: u16 = 1 << 2; + + /// Supports multiple clock domains with per-policy governors in `cpu/cpuN/cpufreq/`. + pub const HAVE_GOVERNOR_PER_POLICY: u16 = 1 << 3; + + /// Allows post-change notifications outside of the `target()` routine. + pub const ASYNC_NOTIFICATION: u16 = 1 << 4; + + /// Ensure CPU starts at a valid frequency from the driver's freq-table. + pub const NEED_INITIAL_FREQ_CHECK: u16 = 1 << 5; + + /// Disallow governors with `dynamic_switching` capability. + pub const NO_AUTO_DYNAMIC_SWITCHING: u16 = 1 << 6; +} + +/// Relations from the C code. +const CPUFREQ_RELATION_L: u32 = 0; +const CPUFREQ_RELATION_H: u32 = 1; +const CPUFREQ_RELATION_C: u32 = 2; + +/// Can be used with any of the above values. 
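Stepping back to `cpu::from_cpu()` above: because the CPU device is not reference counted, the lookup is `unsafe` and the hotplug caveat falls on the caller. A hedged usage sketch, with the hotplug synchronization the safety contract requires elided and assumed:

```rust
use kernel::cpu::{self, CpuId};
use kernel::prelude::*;

fn cpu0_device() -> Result {
    let cpu0 = CpuId::from_u32(0).ok_or(ENODEV)?;

    // SAFETY: this sketch assumes the caller guarantees CPU 0 is not
    // hot-unplugged while the returned reference is in use.
    let _dev = unsafe { cpu::from_cpu(cpu0) }?;

    Ok(())
}
```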
+const CPUFREQ_RELATION_E: u32 = 1 << 2;
+
+/// CPU frequency selection relations.
+///
+/// CPU frequency selection relations, each optionally marked as "efficient".
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum Relation {
+ /// Select the lowest frequency at or above target.
+ Low(bool),
+ /// Select the highest frequency below or at target.
+ High(bool),
+ /// Select the closest frequency to the target.
+ Close(bool),
+}
+
+impl Relation {
+ // Construct from a C-compatible `u32` value.
+ fn new(val: u32) -> Result<Self> {
+ let efficient = val & CPUFREQ_RELATION_E != 0;
+
+ Ok(match val & !CPUFREQ_RELATION_E {
+ CPUFREQ_RELATION_L => Self::Low(efficient),
+ CPUFREQ_RELATION_H => Self::High(efficient),
+ CPUFREQ_RELATION_C => Self::Close(efficient),
+ _ => return Err(EINVAL),
+ })
+ }
+}
+
+impl From<Relation> for u32 {
+ // Convert to a C-compatible `u32` value.
+ fn from(rel: Relation) -> Self {
+ let (mut val, efficient) = match rel {
+ Relation::Low(e) => (CPUFREQ_RELATION_L, e),
+ Relation::High(e) => (CPUFREQ_RELATION_H, e),
+ Relation::Close(e) => (CPUFREQ_RELATION_C, e),
+ };
+
+ if efficient {
+ val |= CPUFREQ_RELATION_E;
+ }
+
+ val
+ }
+}
+
+/// Policy data.
+///
+/// Rust abstraction for the C `struct cpufreq_policy_data`.
+///
+/// # Invariants
+///
+/// A [`PolicyData`] instance always corresponds to a valid C `struct cpufreq_policy_data`.
+///
+/// The callers must ensure that the `struct cpufreq_policy_data` is valid for access and remains
+/// valid for the lifetime of the returned reference.
+#[repr(transparent)]
+pub struct PolicyData(Opaque<bindings::cpufreq_policy_data>);
+
+impl PolicyData {
+ /// Creates a mutable reference to an existing `struct cpufreq_policy_data` pointer.
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime
+ /// of the returned reference.
+ #[inline]
+ pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpufreq_policy_data) -> &'a mut Self {
+ // SAFETY: Guaranteed by the safety requirements of the function.
+ //
+ // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the
+ // lifetime of the returned reference.
+ unsafe { &mut *ptr.cast() }
+ }
+
+ /// Returns a raw pointer to the underlying C `cpufreq_policy_data`.
+ #[inline]
+ pub fn as_raw(&self) -> *mut bindings::cpufreq_policy_data {
+ let this: *const Self = self;
+ this.cast_mut().cast()
+ }
+
+ /// Wrapper for `cpufreq_generic_frequency_table_verify`.
+ #[inline]
+ pub fn generic_verify(&self) -> Result {
+ // SAFETY: By the type invariant, the pointer stored in `self` is valid.
+ to_result(unsafe { bindings::cpufreq_generic_frequency_table_verify(self.as_raw()) })
+ }
+}
+
+/// The frequency table index.
+///
+/// Represents an index within a frequency table.
+///
+/// # Invariants
+///
+/// The index must correspond to a valid entry in the [`Table`] it is used for.
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
+pub struct TableIndex(usize);
+
+impl TableIndex {
+ /// Creates an instance of [`TableIndex`].
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that `index` corresponds to a valid entry in the [`Table`] it is used
+ /// for.
+ pub unsafe fn new(index: usize) -> Self {
+ // INVARIANT: The caller ensures that `index` corresponds to a valid entry in the [`Table`].
+ Self(index)
+ }
+}
+
+impl From<TableIndex> for usize {
+ #[inline]
+ fn from(index: TableIndex) -> Self {
+ index.0
+ }
+}
+
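Since `Relation::new` and the `From<Relation> for u32` conversion above are exact inverses over the efficient bit, a round trip preserves the value. A small conceptual check (`Relation::new` is crate-internal, so this is illustrative rather than a public-API doctest):

```rust
use kernel::cpufreq::Relation;

// CPUFREQ_RELATION_C (2) with the "efficient" bit CPUFREQ_RELATION_E
// (1 << 2) set on top.
let raw = u32::from(Relation::Close(true));
assert_eq!(raw, 2 | (1 << 2));

// Decoding recovers both the base relation and the efficient flag.
assert_eq!(Relation::new(raw).unwrap(), Relation::Close(true));
```

+/// CPU frequency table.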
+/// +/// Rust abstraction for the C `struct cpufreq_frequency_table`. +/// +/// # Invariants +/// +/// A [`Table`] instance always corresponds to a valid C `struct cpufreq_frequency_table`. +/// +/// The callers must ensure that the `struct cpufreq_frequency_table` is valid for access and +/// remains valid for the lifetime of the returned reference. +/// +/// # Examples +/// +/// The following example demonstrates how to read a frequency value from [`Table`]. +/// +/// ``` +/// use kernel::cpufreq::{Policy, TableIndex}; +/// +/// fn show_freq(policy: &Policy) -> Result { +/// let table = policy.freq_table()?; +/// +/// // SAFETY: Index is a valid entry in the table. +/// let index = unsafe { TableIndex::new(0) }; +/// +/// pr_info!("The frequency at index 0 is: {:?}\n", table.freq(index)?); +/// pr_info!("The flags at index 0 is: {}\n", table.flags(index)); +/// pr_info!("The data at index 0 is: {}\n", table.data(index)); +/// Ok(()) +/// } +/// ``` +#[repr(transparent)] +pub struct Table(Opaque<bindings::cpufreq_frequency_table>); + +impl Table { + /// Creates a reference to an existing C `struct cpufreq_frequency_table` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpufreq_frequency_table) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Returns the raw mutable pointer to the C `struct cpufreq_frequency_table`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::cpufreq_frequency_table { + let this: *const Self = self; + this.cast_mut().cast() + } + + /// Returns frequency at `index` in the [`Table`]. + #[inline] + pub fn freq(&self, index: TableIndex) -> Result<Hertz> { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + Ok(Hertz::from_khz(unsafe { + (*self.as_raw().add(index.into())).frequency.try_into()? + })) + } + + /// Returns flags at `index` in the [`Table`]. + #[inline] + pub fn flags(&self, index: TableIndex) -> u32 { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + unsafe { (*self.as_raw().add(index.into())).flags } + } + + /// Returns data at `index` in the [`Table`]. + #[inline] + pub fn data(&self, index: TableIndex) -> u32 { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + unsafe { (*self.as_raw().add(index.into())).driver_data } + } +} + +/// CPU frequency table owned and pinned in memory, created from a [`TableBuilder`]. +pub struct TableBox { + entries: Pin<KVec<bindings::cpufreq_frequency_table>>, +} + +impl TableBox { + /// Constructs a new [`TableBox`] from a [`KVec`] of entries. + /// + /// # Errors + /// + /// Returns `EINVAL` if the entries list is empty. + #[inline] + fn new(entries: KVec<bindings::cpufreq_frequency_table>) -> Result<Self> { + if entries.is_empty() { + return Err(EINVAL); + } + + Ok(Self { + // Pin the entries to memory, since we are passing its pointer to the C code. 
+ entries: Pin::new(entries),
+ })
+ }
+
+ /// Returns a raw pointer to the underlying C `cpufreq_frequency_table`.
+ #[inline]
+ fn as_raw(&self) -> *const bindings::cpufreq_frequency_table {
+ // The pointer is valid until the table gets dropped.
+ self.entries.as_ptr()
+ }
+}
+
+impl Deref for TableBox {
+ type Target = Table;
+
+ fn deref(&self) -> &Self::Target {
+ // SAFETY: The caller owns `TableBox`, so it is safe to deref.
+ unsafe { Self::Target::from_raw(self.as_raw()) }
+ }
+}
+
+/// CPU frequency table builder.
+///
+/// This is used by the CPU frequency drivers to build a frequency table dynamically.
+///
+/// # Examples
+///
+/// The following example demonstrates how to create a CPU frequency table.
+///
+/// ```
+/// use kernel::cpufreq::{TableBuilder, TableIndex};
+/// use kernel::clk::Hertz;
+///
+/// let mut builder = TableBuilder::new();
+///
+/// // Adds a few entries to the table.
+/// builder.add(Hertz::from_mhz(700), 0, 1).unwrap();
+/// builder.add(Hertz::from_mhz(800), 2, 3).unwrap();
+/// builder.add(Hertz::from_mhz(900), 4, 5).unwrap();
+/// builder.add(Hertz::from_ghz(1), 6, 7).unwrap();
+///
+/// let table = builder.to_table().unwrap();
+///
+/// // SAFETY: Index values correspond to valid entries in the table.
+/// let (index0, index2) = unsafe { (TableIndex::new(0), TableIndex::new(2)) };
+///
+/// assert_eq!(table.freq(index0), Ok(Hertz::from_mhz(700)));
+/// assert_eq!(table.flags(index0), 0);
+/// assert_eq!(table.data(index0), 1);
+///
+/// assert_eq!(table.freq(index2), Ok(Hertz::from_mhz(900)));
+/// assert_eq!(table.flags(index2), 4);
+/// assert_eq!(table.data(index2), 5);
+/// ```
+#[derive(Default)]
+#[repr(transparent)]
+pub struct TableBuilder {
+ entries: KVec<bindings::cpufreq_frequency_table>,
+}
+
+impl TableBuilder {
+ /// Creates a new instance of [`TableBuilder`].
+ #[inline]
+ pub fn new() -> Self {
+ Self {
+ entries: KVec::new(),
+ }
+ }
+
+ /// Adds a new entry to the table.
+ pub fn add(&mut self, freq: Hertz, flags: u32, driver_data: u32) -> Result {
+ // Adds the new entry at the end of the vector.
+ Ok(self.entries.push(
+ bindings::cpufreq_frequency_table {
+ flags,
+ driver_data,
+ frequency: freq.as_khz() as u32,
+ },
+ GFP_KERNEL,
+ )?)
+ }
+
+ /// Consumes the [`TableBuilder`] and returns [`TableBox`].
+ pub fn to_table(mut self) -> Result<TableBox> {
+ // Add last entry to the table.
+ self.add(Hertz(c_ulong::MAX), 0, 0)?;
+
+ TableBox::new(self.entries)
+ }
+}
+
+/// CPU frequency policy.
+///
+/// Rust abstraction for the C `struct cpufreq_policy`.
+///
+/// # Invariants
+///
+/// A [`Policy`] instance always corresponds to a valid C `struct cpufreq_policy`.
+///
+/// The callers must ensure that the `struct cpufreq_policy` is valid for access and remains valid
+/// for the lifetime of the returned reference.
+///
+/// # Examples
+///
+/// The following example demonstrates how to update a CPU frequency policy.
+///
+/// ```
+/// use kernel::cpufreq::{DEFAULT_TRANSITION_LATENCY_NS, Policy};
+///
+/// fn update_policy(policy: &mut Policy) {
+/// policy
+/// .set_dvfs_possible_from_any_cpu(true)
+/// .set_fast_switch_possible(true)
+/// .set_transition_latency_ns(DEFAULT_TRANSITION_LATENCY_NS);
+///
+/// pr_info!("The policy details are: {:?}\n", (policy.cpu(), policy.cur()));
+/// }
+/// ```
+#[repr(transparent)]
+pub struct Policy(Opaque<bindings::cpufreq_policy>);
+
+impl Policy {
+ /// Creates a reference to an existing `struct cpufreq_policy` pointer.
+ /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpufreq_policy) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Creates a mutable reference to an existing `struct cpufreq_policy` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpufreq_policy) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Returns a raw mutable pointer to the C `struct cpufreq_policy`. + #[inline] + fn as_raw(&self) -> *mut bindings::cpufreq_policy { + let this: *const Self = self; + this.cast_mut().cast() + } + + #[inline] + fn as_ref(&self) -> &bindings::cpufreq_policy { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { &*self.as_raw() } + } + + #[inline] + fn as_mut_ref(&mut self) -> &mut bindings::cpufreq_policy { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { &mut *self.as_raw() } + } + + /// Returns the primary CPU for the [`Policy`]. + #[inline] + pub fn cpu(&self) -> CpuId { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + unsafe { CpuId::from_u32_unchecked(self.as_ref().cpu) } + } + + /// Returns the minimum frequency for the [`Policy`]. + #[inline] + pub fn min(&self) -> Hertz { + Hertz::from_khz(self.as_ref().min as usize) + } + + /// Set the minimum frequency for the [`Policy`]. + #[inline] + pub fn set_min(&mut self, min: Hertz) -> &mut Self { + self.as_mut_ref().min = min.as_khz() as u32; + self + } + + /// Returns the maximum frequency for the [`Policy`]. + #[inline] + pub fn max(&self) -> Hertz { + Hertz::from_khz(self.as_ref().max as usize) + } + + /// Set the maximum frequency for the [`Policy`]. + #[inline] + pub fn set_max(&mut self, max: Hertz) -> &mut Self { + self.as_mut_ref().max = max.as_khz() as u32; + self + } + + /// Returns the current frequency for the [`Policy`]. + #[inline] + pub fn cur(&self) -> Hertz { + Hertz::from_khz(self.as_ref().cur as usize) + } + + /// Returns the suspend frequency for the [`Policy`]. + #[inline] + pub fn suspend_freq(&self) -> Hertz { + Hertz::from_khz(self.as_ref().suspend_freq as usize) + } + + /// Sets the suspend frequency for the [`Policy`]. + #[inline] + pub fn set_suspend_freq(&mut self, freq: Hertz) -> &mut Self { + self.as_mut_ref().suspend_freq = freq.as_khz() as u32; + self + } + + /// Provides a wrapper to the generic suspend routine. + #[inline] + pub fn generic_suspend(&mut self) -> Result { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + to_result(unsafe { bindings::cpufreq_generic_suspend(self.as_mut_ref()) }) + } + + /// Provides a wrapper to the generic get routine. + #[inline] + pub fn generic_get(&self) -> Result<u32> { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. 
+ Ok(unsafe { bindings::cpufreq_generic_get(u32::from(self.cpu())) }) + } + + /// Provides a wrapper to the register with energy model using the OPP core. + #[cfg(CONFIG_PM_OPP)] + #[inline] + pub fn register_em_opp(&mut self) { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { bindings::cpufreq_register_em_with_opp(self.as_mut_ref()) }; + } + + /// Gets [`cpumask::Cpumask`] for a cpufreq [`Policy`]. + #[inline] + pub fn cpus(&mut self) -> &mut cpumask::Cpumask { + // SAFETY: The pointer to `cpus` is valid for writing and remains valid for the lifetime of + // the returned reference. + unsafe { cpumask::CpumaskVar::from_raw_mut(&mut self.as_mut_ref().cpus) } + } + + /// Sets clock for the [`Policy`]. + /// + /// # Safety + /// + /// The caller must guarantee that the returned [`Clk`] is not dropped while it is getting used + /// by the C code. + #[cfg(CONFIG_COMMON_CLK)] + pub unsafe fn set_clk(&mut self, dev: &Device, name: Option<&CStr>) -> Result<Clk> { + let clk = Clk::get(dev, name)?; + self.as_mut_ref().clk = clk.as_raw(); + Ok(clk) + } + + /// Allows / disallows frequency switching code to run on any CPU. + #[inline] + pub fn set_dvfs_possible_from_any_cpu(&mut self, val: bool) -> &mut Self { + self.as_mut_ref().dvfs_possible_from_any_cpu = val; + self + } + + /// Returns if fast switching of frequencies is possible or not. + #[inline] + pub fn fast_switch_possible(&self) -> bool { + self.as_ref().fast_switch_possible + } + + /// Enables / disables fast frequency switching. + #[inline] + pub fn set_fast_switch_possible(&mut self, val: bool) -> &mut Self { + self.as_mut_ref().fast_switch_possible = val; + self + } + + /// Sets transition latency (in nanoseconds) for the [`Policy`]. + #[inline] + pub fn set_transition_latency_ns(&mut self, latency_ns: u32) -> &mut Self { + self.as_mut_ref().cpuinfo.transition_latency = latency_ns; + self + } + + /// Sets cpuinfo `min_freq`. + #[inline] + pub fn set_cpuinfo_min_freq(&mut self, min_freq: Hertz) -> &mut Self { + self.as_mut_ref().cpuinfo.min_freq = min_freq.as_khz() as u32; + self + } + + /// Sets cpuinfo `max_freq`. + #[inline] + pub fn set_cpuinfo_max_freq(&mut self, max_freq: Hertz) -> &mut Self { + self.as_mut_ref().cpuinfo.max_freq = max_freq.as_khz() as u32; + self + } + + /// Set `transition_delay_us`, i.e. the minimum time between successive frequency change + /// requests. + #[inline] + pub fn set_transition_delay_us(&mut self, transition_delay_us: u32) -> &mut Self { + self.as_mut_ref().transition_delay_us = transition_delay_us; + self + } + + /// Returns reference to the CPU frequency [`Table`] for the [`Policy`]. + pub fn freq_table(&self) -> Result<&Table> { + if self.as_ref().freq_table.is_null() { + return Err(EINVAL); + } + + // SAFETY: The `freq_table` is guaranteed to be valid for reading and remains valid for the + // lifetime of the returned reference. + Ok(unsafe { Table::from_raw(self.as_ref().freq_table) }) + } + + /// Sets the CPU frequency [`Table`] for the [`Policy`]. + /// + /// # Safety + /// + /// The caller must guarantee that the [`Table`] is not dropped while it is getting used by the + /// C code. + #[inline] + pub unsafe fn set_freq_table(&mut self, table: &Table) -> &mut Self { + self.as_mut_ref().freq_table = table.as_raw(); + self + } + + /// Returns the [`Policy`]'s private data. 
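+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of reading back the data stored by the driver's `init` callback,
+ /// assuming the driver uses `Arc<u32>` as its [`Driver::PData`] (a real driver would use
+ /// its own type here):
+ ///
+ /// ```
+ /// use kernel::cpufreq::Policy;
+ /// use kernel::sync::Arc;
+ ///
+ /// fn log_data(policy: &mut Policy) {
+ /// // Borrows the foreign-owned data without taking ownership back.
+ /// if let Some(data) = policy.data::<Arc<u32>>() {
+ /// pr_info!("driver data: {}\n", *data);
+ /// }
+ /// }
+ /// ```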
+ pub fn data<T: ForeignOwnable>(&mut self) -> Option<<T>::Borrowed<'_>> { + if self.as_ref().driver_data.is_null() { + None + } else { + // SAFETY: The data is earlier set from [`set_data`]. + Some(unsafe { T::borrow(self.as_ref().driver_data.cast()) }) + } + } + + /// Sets the private data of the [`Policy`] using a foreign-ownable wrapper. + /// + /// # Errors + /// + /// Returns `EBUSY` if private data is already set. + fn set_data<T: ForeignOwnable>(&mut self, data: T) -> Result { + if self.as_ref().driver_data.is_null() { + // Transfer the ownership of the data to the foreign interface. + self.as_mut_ref().driver_data = <T as ForeignOwnable>::into_foreign(data).cast(); + Ok(()) + } else { + Err(EBUSY) + } + } + + /// Clears and returns ownership of the private data. + fn clear_data<T: ForeignOwnable>(&mut self) -> Option<T> { + if self.as_ref().driver_data.is_null() { + None + } else { + let data = Some( + // SAFETY: The data is earlier set by us from [`set_data`]. It is safe to take + // back the ownership of the data from the foreign interface. + unsafe { <T as ForeignOwnable>::from_foreign(self.as_ref().driver_data.cast()) }, + ); + self.as_mut_ref().driver_data = ptr::null_mut(); + data + } + } +} + +/// CPU frequency policy created from a CPU number. +/// +/// This struct represents the CPU frequency policy obtained for a specific CPU, providing safe +/// access to the underlying `cpufreq_policy` and ensuring proper cleanup when the `PolicyCpu` is +/// dropped. +struct PolicyCpu<'a>(&'a mut Policy); + +impl<'a> PolicyCpu<'a> { + fn from_cpu(cpu: CpuId) -> Result<Self> { + // SAFETY: It is safe to call `cpufreq_cpu_get` for any valid CPU. + let ptr = from_err_ptr(unsafe { bindings::cpufreq_cpu_get(u32::from(cpu)) })?; + + Ok(Self( + // SAFETY: The `ptr` is guaranteed to be valid and remains valid for the lifetime of + // the returned reference. + unsafe { Policy::from_raw_mut(ptr) }, + )) + } +} + +impl<'a> Deref for PolicyCpu<'a> { + type Target = Policy; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> DerefMut for PolicyCpu<'a> { + fn deref_mut(&mut self) -> &mut Policy { + self.0 + } +} + +impl<'a> Drop for PolicyCpu<'a> { + fn drop(&mut self) { + // SAFETY: The underlying pointer is guaranteed to be valid for the lifetime of `self`. + unsafe { bindings::cpufreq_cpu_put(self.0.as_raw()) }; + } +} + +/// CPU frequency driver. +/// +/// Implement this trait to provide a CPU frequency driver and its callbacks. +/// +/// Reference: <https://docs.kernel.org/cpu-freq/cpu-drivers.html> +#[vtable] +pub trait Driver { + /// Driver's name. + const NAME: &'static CStr; + + /// Driver's flags. + const FLAGS: u16; + + /// Boost support. + const BOOST_ENABLED: bool; + + /// Policy specific data. + /// + /// Require that `PData` implements `ForeignOwnable`. We guarantee to never move the underlying + /// wrapped data structure. + type PData: ForeignOwnable; + + /// Driver's `init` callback. + fn init(policy: &mut Policy) -> Result<Self::PData>; + + /// Driver's `exit` callback. + fn exit(_policy: &mut Policy, _data: Option<Self::PData>) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `online` callback. + fn online(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `offline` callback. + fn offline(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `suspend` callback. 
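+ ///
+ /// Called by the cpufreq core during system suspend; a common implementation simply
+ /// forwards to [`Policy::generic_suspend`], which switches the CPU to
+ /// [`Policy::suspend_freq`].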
+ fn suspend(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `resume` callback. + fn resume(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `ready` callback. + fn ready(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `verify` callback. + fn verify(data: &mut PolicyData) -> Result; + + /// Driver's `setpolicy` callback. + fn setpolicy(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target` callback. + fn target(_policy: &mut Policy, _target_freq: u32, _relation: Relation) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target_index` callback. + fn target_index(_policy: &mut Policy, _index: TableIndex) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `fast_switch` callback. + fn fast_switch(_policy: &mut Policy, _target_freq: u32) -> u32 { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `adjust_perf` callback. + fn adjust_perf(_policy: &mut Policy, _min_perf: usize, _target_perf: usize, _capacity: usize) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `get_intermediate` callback. + fn get_intermediate(_policy: &mut Policy, _index: TableIndex) -> u32 { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target_intermediate` callback. + fn target_intermediate(_policy: &mut Policy, _index: TableIndex) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `get` callback. + fn get(_policy: &mut Policy) -> Result<u32> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `update_limits` callback. + fn update_limits(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `bios_limit` callback. + fn bios_limit(_policy: &mut Policy, _limit: &mut u32) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `set_boost` callback. + fn set_boost(_policy: &mut Policy, _state: i32) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `register_em` callback. + fn register_em(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } +} + +/// CPU frequency driver Registration. +/// +/// # Examples +/// +/// The following example demonstrates how to register a cpufreq driver. +/// +/// ``` +/// use kernel::{ +/// cpufreq, +/// c_str, +/// device::{Core, Device}, +/// macros::vtable, +/// of, platform, +/// sync::Arc, +/// }; +/// struct SampleDevice; +/// +/// #[derive(Default)] +/// struct SampleDriver; +/// +/// #[vtable] +/// impl cpufreq::Driver for SampleDriver { +/// const NAME: &'static CStr = c_str!("cpufreq-sample"); +/// const FLAGS: u16 = cpufreq::flags::NEED_INITIAL_FREQ_CHECK | cpufreq::flags::IS_COOLING_DEV; +/// const BOOST_ENABLED: bool = true; +/// +/// type PData = Arc<SampleDevice>; +/// +/// fn init(policy: &mut cpufreq::Policy) -> Result<Self::PData> { +/// // Initialize here +/// Ok(Arc::new(SampleDevice, GFP_KERNEL)?) 
+/// } +/// +/// fn exit(_policy: &mut cpufreq::Policy, _data: Option<Self::PData>) -> Result { +/// Ok(()) +/// } +/// +/// fn suspend(policy: &mut cpufreq::Policy) -> Result { +/// policy.generic_suspend() +/// } +/// +/// fn verify(data: &mut cpufreq::PolicyData) -> Result { +/// data.generic_verify() +/// } +/// +/// fn target_index(policy: &mut cpufreq::Policy, index: cpufreq::TableIndex) -> Result { +/// // Update CPU frequency +/// Ok(()) +/// } +/// +/// fn get(policy: &mut cpufreq::Policy) -> Result<u32> { +/// policy.generic_get() +/// } +/// } +/// +/// impl platform::Driver for SampleDriver { +/// type IdInfo = (); +/// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; +/// +/// fn probe( +/// pdev: &platform::Device<Core>, +/// _id_info: Option<&Self::IdInfo>, +/// ) -> impl PinInit<Self, Error> { +/// cpufreq::Registration::<SampleDriver>::new_foreign_owned(pdev.as_ref())?; +/// Ok(Self {}) +/// } +/// } +/// ``` +#[repr(transparent)] +pub struct Registration<T: Driver>(KBox<UnsafeCell<bindings::cpufreq_driver>>, PhantomData<T>); + +/// SAFETY: `Registration` doesn't offer any methods or access to fields when shared between threads +/// or CPUs, so it is safe to share it. +unsafe impl<T: Driver> Sync for Registration<T> {} + +#[allow(clippy::non_send_fields_in_send_ty)] +/// SAFETY: Registration with and unregistration from the cpufreq subsystem can happen from any +/// thread. +unsafe impl<T: Driver> Send for Registration<T> {} + +impl<T: Driver> Registration<T> { + const VTABLE: bindings::cpufreq_driver = bindings::cpufreq_driver { + name: Self::copy_name(T::NAME), + boost_enabled: T::BOOST_ENABLED, + flags: T::FLAGS, + + // Initialize mandatory callbacks. + init: Some(Self::init_callback), + verify: Some(Self::verify_callback), + + // Initialize optional callbacks based on the traits of `T`. 
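+ // The `#[vtable]` macro generates a `HAS_*` constant for every optional trait method,
+ // so each C callback below is populated only when `T` actually implements it.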
+ setpolicy: if T::HAS_SETPOLICY { + Some(Self::setpolicy_callback) + } else { + None + }, + target: if T::HAS_TARGET { + Some(Self::target_callback) + } else { + None + }, + target_index: if T::HAS_TARGET_INDEX { + Some(Self::target_index_callback) + } else { + None + }, + fast_switch: if T::HAS_FAST_SWITCH { + Some(Self::fast_switch_callback) + } else { + None + }, + adjust_perf: if T::HAS_ADJUST_PERF { + Some(Self::adjust_perf_callback) + } else { + None + }, + get_intermediate: if T::HAS_GET_INTERMEDIATE { + Some(Self::get_intermediate_callback) + } else { + None + }, + target_intermediate: if T::HAS_TARGET_INTERMEDIATE { + Some(Self::target_intermediate_callback) + } else { + None + }, + get: if T::HAS_GET { + Some(Self::get_callback) + } else { + None + }, + update_limits: if T::HAS_UPDATE_LIMITS { + Some(Self::update_limits_callback) + } else { + None + }, + bios_limit: if T::HAS_BIOS_LIMIT { + Some(Self::bios_limit_callback) + } else { + None + }, + online: if T::HAS_ONLINE { + Some(Self::online_callback) + } else { + None + }, + offline: if T::HAS_OFFLINE { + Some(Self::offline_callback) + } else { + None + }, + exit: if T::HAS_EXIT { + Some(Self::exit_callback) + } else { + None + }, + suspend: if T::HAS_SUSPEND { + Some(Self::suspend_callback) + } else { + None + }, + resume: if T::HAS_RESUME { + Some(Self::resume_callback) + } else { + None + }, + ready: if T::HAS_READY { + Some(Self::ready_callback) + } else { + None + }, + set_boost: if T::HAS_SET_BOOST { + Some(Self::set_boost_callback) + } else { + None + }, + register_em: if T::HAS_REGISTER_EM { + Some(Self::register_em_callback) + } else { + None + }, + ..pin_init::zeroed() + }; + + const fn copy_name(name: &'static CStr) -> [c_char; CPUFREQ_NAME_LEN] { + let src = name.to_bytes_with_nul(); + let mut dst = [0; CPUFREQ_NAME_LEN]; + + build_assert!(src.len() <= CPUFREQ_NAME_LEN); + + let mut i = 0; + while i < src.len() { + dst[i] = src[i]; + i += 1; + } + + dst + } + + /// Registers a CPU frequency driver with the cpufreq core. + pub fn new() -> Result<Self> { + // We can't use `&Self::VTABLE` directly because the cpufreq core modifies some fields in + // the C `struct cpufreq_driver`, which requires a mutable reference. + let mut drv = KBox::new(UnsafeCell::new(Self::VTABLE), GFP_KERNEL)?; + + // SAFETY: `drv` is guaranteed to be valid for the lifetime of `Registration`. + to_result(unsafe { bindings::cpufreq_register_driver(drv.get_mut()) })?; + + Ok(Self(drv, PhantomData)) + } + + /// Same as [`Registration::new`], but does not return a [`Registration`] instance. + /// + /// Instead the [`Registration`] is owned by [`devres::register`] and will be dropped, once the + /// device is detached. + pub fn new_foreign_owned(dev: &Device<Bound>) -> Result + where + T: 'static, + { + devres::register(dev, Self::new()?, GFP_KERNEL) + } +} + +/// CPU frequency driver callbacks. +impl<T: Driver> Registration<T> { + /// Driver's `init` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn init_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + let data = T::init(policy)?; + policy.set_data(data)?; + Ok(0) + }) + } + + /// Driver's `exit` callback. 
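+ ///
+ /// Takes back ownership of the per-policy data installed by `init_callback` and hands it
+ /// to [`Driver::exit`] for driver-specific cleanup.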
+ /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn exit_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + let data = policy.clear_data(); + let _ = T::exit(policy, data); + } + + /// Driver's `online` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn online_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::online(policy).map(|()| 0) + }) + } + + /// Driver's `offline` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn offline_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::offline(policy).map(|()| 0) + }) + } + + /// Driver's `suspend` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn suspend_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::suspend(policy).map(|()| 0) + }) + } + + /// Driver's `resume` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn resume_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::resume(policy).map(|()| 0) + }) + } + + /// Driver's `ready` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn ready_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::ready(policy); + } + + /// Driver's `verify` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn verify_callback(ptr: *mut bindings::cpufreq_policy_data) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let data = unsafe { PolicyData::from_raw_mut(ptr) }; + T::verify(data).map(|()| 0) + }) + } + + /// Driver's `setpolicy` callback. 
+ /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn setpolicy_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::setpolicy(policy).map(|()| 0) + }) + } + + /// Driver's `target` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_callback( + ptr: *mut bindings::cpufreq_policy, + target_freq: c_uint, + relation: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::target(policy, target_freq, Relation::new(relation)?).map(|()| 0) + }) + } + + /// Driver's `target_index` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_index_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. + let index = unsafe { TableIndex::new(index as usize) }; + + T::target_index(policy, index).map(|()| 0) + }) + } + + /// Driver's `fast_switch` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn fast_switch_callback( + ptr: *mut bindings::cpufreq_policy, + target_freq: c_uint, + ) -> c_uint { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::fast_switch(policy, target_freq) + } + + /// Driver's `adjust_perf` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + unsafe extern "C" fn adjust_perf_callback( + cpu: c_uint, + min_perf: c_ulong, + target_perf: c_ulong, + capacity: c_ulong, + ) { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_u32_unchecked(cpu) }; + + if let Ok(mut policy) = PolicyCpu::from_cpu(cpu_id) { + T::adjust_perf(&mut policy, min_perf, target_perf, capacity); + } + } + + /// Driver's `get_intermediate` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn get_intermediate_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_uint { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. 
+ let index = unsafe { TableIndex::new(index as usize) }; + + T::get_intermediate(policy, index) + } + + /// Driver's `target_intermediate` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_intermediate_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. + let index = unsafe { TableIndex::new(index as usize) }; + + T::target_intermediate(policy, index).map(|()| 0) + }) + } + + /// Driver's `get` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + unsafe extern "C" fn get_callback(cpu: c_uint) -> c_uint { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_u32_unchecked(cpu) }; + + PolicyCpu::from_cpu(cpu_id).map_or(0, |mut policy| T::get(&mut policy).map_or(0, |f| f)) + } + + /// Driver's `update_limit` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn update_limits_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::update_limits(policy); + } + + /// Driver's `bios_limit` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn bios_limit_callback(cpu: c_int, limit: *mut c_uint) -> c_int { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_i32_unchecked(cpu) }; + + from_result(|| { + let mut policy = PolicyCpu::from_cpu(cpu_id)?; + + // SAFETY: `limit` is guaranteed by the C code to be valid. + T::bios_limit(&mut policy, &mut (unsafe { *limit })).map(|()| 0) + }) + } + + /// Driver's `set_boost` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn set_boost_callback( + ptr: *mut bindings::cpufreq_policy, + state: c_int, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::set_boost(policy, state).map(|()| 0) + }) + } + + /// Driver's `register_em` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn register_em_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::register_em(policy); + } +} + +impl<T: Driver> Drop for Registration<T> { + /// Unregisters with the cpufreq core. 
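+ ///
+ /// The cpufreq core stops calling into the driver's vtable before unregistration
+ /// completes, so the backing allocation is only freed afterwards.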
+ fn drop(&mut self) { + // SAFETY: `self.0` is guaranteed to be valid for the lifetime of `Registration`. + unsafe { bindings::cpufreq_unregister_driver(self.0.get_mut()) }; + } +} diff --git a/rust/kernel/cpumask.rs b/rust/kernel/cpumask.rs new file mode 100644 index 000000000000..c1d17826ae7b --- /dev/null +++ b/rust/kernel/cpumask.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! CPU Mask abstractions. +//! +//! C header: [`include/linux/cpumask.h`](srctree/include/linux/cpumask.h) + +use crate::{ + alloc::{AllocError, Flags}, + cpu::CpuId, + prelude::*, + types::Opaque, +}; + +#[cfg(CONFIG_CPUMASK_OFFSTACK)] +use core::ptr::{self, NonNull}; + +use core::ops::{Deref, DerefMut}; + +/// A CPU Mask. +/// +/// Rust abstraction for the C `struct cpumask`. +/// +/// # Invariants +/// +/// A [`Cpumask`] instance always corresponds to a valid C `struct cpumask`. +/// +/// The callers must ensure that the `struct cpumask` is valid for access and +/// remains valid for the lifetime of the returned reference. +/// +/// # Examples +/// +/// The following example demonstrates how to update a [`Cpumask`]. +/// +/// ``` +/// use kernel::bindings; +/// use kernel::cpu::CpuId; +/// use kernel::cpumask::Cpumask; +/// +/// fn set_clear_cpu(ptr: *mut bindings::cpumask, set_cpu: CpuId, clear_cpu: CpuId) { +/// // SAFETY: The `ptr` is valid for writing and remains valid for the lifetime of the +/// // returned reference. +/// let mask = unsafe { Cpumask::as_mut_ref(ptr) }; +/// +/// mask.set(set_cpu); +/// mask.clear(clear_cpu); +/// } +/// ``` +#[repr(transparent)] +pub struct Cpumask(Opaque<bindings::cpumask>); + +impl Cpumask { + /// Creates a mutable reference to an existing `struct cpumask` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_mut_ref<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Creates a reference to an existing `struct cpumask` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_ref<'a>(ptr: *const bindings::cpumask) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Obtain the raw `struct cpumask` pointer. + pub fn as_raw(&self) -> *mut bindings::cpumask { + let this: *const Self = self; + this.cast_mut().cast() + } + + /// Set `cpu` in the cpumask. + /// + /// ATTENTION: Contrary to C, this Rust `set()` method is non-atomic. + /// This mismatches kernel naming convention and corresponds to the C + /// function `__cpumask_set_cpu()`. + #[inline] + pub fn set(&mut self, cpu: CpuId) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `__cpumask_set_cpu`. + unsafe { bindings::__cpumask_set_cpu(u32::from(cpu), self.as_raw()) }; + } + + /// Clear `cpu` in the cpumask. + /// + /// ATTENTION: Contrary to C, this Rust `clear()` method is non-atomic. 
+ /// This mismatches kernel naming convention and corresponds to the C + /// function `__cpumask_clear_cpu()`. + #[inline] + pub fn clear(&mut self, cpu: CpuId) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to + // `__cpumask_clear_cpu`. + unsafe { bindings::__cpumask_clear_cpu(i32::from(cpu), self.as_raw()) }; + } + + /// Test `cpu` in the cpumask. + /// + /// Equivalent to the kernel's `cpumask_test_cpu` API. + #[inline] + pub fn test(&self, cpu: CpuId) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_test_cpu`. + unsafe { bindings::cpumask_test_cpu(i32::from(cpu), self.as_raw()) } + } + + /// Set all CPUs in the cpumask. + /// + /// Equivalent to the kernel's `cpumask_setall` API. + #[inline] + pub fn setall(&mut self) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_setall`. + unsafe { bindings::cpumask_setall(self.as_raw()) }; + } + + /// Checks if cpumask is empty. + /// + /// Equivalent to the kernel's `cpumask_empty` API. + #[inline] + pub fn empty(&self) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_empty`. + unsafe { bindings::cpumask_empty(self.as_raw()) } + } + + /// Checks if cpumask is full. + /// + /// Equivalent to the kernel's `cpumask_full` API. + #[inline] + pub fn full(&self) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_full`. + unsafe { bindings::cpumask_full(self.as_raw()) } + } + + /// Get weight of the cpumask. + /// + /// Equivalent to the kernel's `cpumask_weight` API. + #[inline] + pub fn weight(&self) -> u32 { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_weight`. + unsafe { bindings::cpumask_weight(self.as_raw()) } + } + + /// Copy cpumask. + /// + /// Equivalent to the kernel's `cpumask_copy` API. + #[inline] + pub fn copy(&self, dstp: &mut Self) { + // SAFETY: By the type invariant, `Self::as_raw` is a valid argument to `cpumask_copy`. + unsafe { bindings::cpumask_copy(dstp.as_raw(), self.as_raw()) }; + } +} + +/// A CPU Mask pointer. +/// +/// Rust abstraction for the C `struct cpumask_var_t`. +/// +/// # Invariants +/// +/// A [`CpumaskVar`] instance always corresponds to a valid C `struct cpumask_var_t`. +/// +/// The callers must ensure that the `struct cpumask_var_t` is valid for access and remains valid +/// for the lifetime of [`CpumaskVar`]. +/// +/// # Examples +/// +/// The following example demonstrates how to create and update a [`CpumaskVar`]. 
+/// +/// ``` +/// use kernel::cpu::CpuId; +/// use kernel::cpumask::CpumaskVar; +/// +/// let mut mask = CpumaskVar::new_zero(GFP_KERNEL).unwrap(); +/// +/// assert!(mask.empty()); +/// let mut count = 0; +/// +/// let cpu2 = CpuId::from_u32(2); +/// if let Some(cpu) = cpu2 { +/// mask.set(cpu); +/// assert!(mask.test(cpu)); +/// count += 1; +/// } +/// +/// let cpu3 = CpuId::from_u32(3); +/// if let Some(cpu) = cpu3 { +/// mask.set(cpu); +/// assert!(mask.test(cpu)); +/// count += 1; +/// } +/// +/// assert_eq!(mask.weight(), count); +/// +/// let mask2 = CpumaskVar::try_clone(&mask).unwrap(); +/// +/// if let Some(cpu) = cpu2 { +/// assert!(mask2.test(cpu)); +/// } +/// +/// if let Some(cpu) = cpu3 { +/// assert!(mask2.test(cpu)); +/// } +/// assert_eq!(mask2.weight(), count); +/// ``` +#[repr(transparent)] +pub struct CpumaskVar { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: NonNull<Cpumask>, + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask, +} + +impl CpumaskVar { + /// Creates a zero-initialized instance of the [`CpumaskVar`]. + pub fn new_zero(_flags: Flags) -> Result<Self, AllocError> { + Ok(Self { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: { + let mut ptr: *mut bindings::cpumask = ptr::null_mut(); + + // SAFETY: It is safe to call this method as the reference to `ptr` is valid. + // + // INVARIANT: The associated memory is freed when the `CpumaskVar` goes out of + // scope. + unsafe { bindings::zalloc_cpumask_var(&mut ptr, _flags.as_raw()) }; + NonNull::new(ptr.cast()).ok_or(AllocError)? + }, + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask(Opaque::zeroed()), + }) + } + + /// Creates an instance of the [`CpumaskVar`]. + /// + /// # Safety + /// + /// The caller must ensure that the returned [`CpumaskVar`] is properly initialized before + /// getting used. + pub unsafe fn new(_flags: Flags) -> Result<Self, AllocError> { + Ok(Self { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: { + let mut ptr: *mut bindings::cpumask = ptr::null_mut(); + + // SAFETY: It is safe to call this method as the reference to `ptr` is valid. + // + // INVARIANT: The associated memory is freed when the `CpumaskVar` goes out of + // scope. + unsafe { bindings::alloc_cpumask_var(&mut ptr, _flags.as_raw()) }; + NonNull::new(ptr.cast()).ok_or(AllocError)? + }, + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask(Opaque::uninit()), + }) + } + + /// Creates a mutable reference to an existing `struct cpumask_var_t` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpumask_var_t) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Creates a reference to an existing `struct cpumask_var_t` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpumask_var_t) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Clones cpumask. 
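+ ///
+ /// # Examples
+ ///
+ /// A short sketch of cloning one mask into a newly allocated [`CpumaskVar`]:
+ ///
+ /// ```
+ /// use kernel::cpumask::CpumaskVar;
+ ///
+ /// let mask = CpumaskVar::new_zero(GFP_KERNEL).unwrap();
+ /// let clone = CpumaskVar::try_clone(&mask).unwrap();
+ ///
+ /// assert_eq!(mask.weight(), clone.weight());
+ /// ```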
+ pub fn try_clone(cpumask: &Cpumask) -> Result<Self> { + // SAFETY: The returned cpumask_var is initialized right after this call. + let mut cpumask_var = unsafe { Self::new(GFP_KERNEL) }?; + + cpumask.copy(&mut cpumask_var); + Ok(cpumask_var) + } +} + +// Make [`CpumaskVar`] behave like a pointer to [`Cpumask`]. +impl Deref for CpumaskVar { + type Target = Cpumask; + + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + fn deref(&self) -> &Self::Target { + // SAFETY: The caller owns CpumaskVar, so it is safe to deref the cpumask. + unsafe { &*self.ptr.as_ptr() } + } + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + fn deref(&self) -> &Self::Target { + &self.mask + } +} + +impl DerefMut for CpumaskVar { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + fn deref_mut(&mut self) -> &mut Cpumask { + // SAFETY: The caller owns CpumaskVar, so it is safe to deref the cpumask. + unsafe { self.ptr.as_mut() } + } + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + fn deref_mut(&mut self) -> &mut Cpumask { + &mut self.mask + } +} + +impl Drop for CpumaskVar { + fn drop(&mut self) { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `free_cpumask_var`. + unsafe { + bindings::free_cpumask_var(self.as_raw()) + }; + } +} diff --git a/rust/kernel/cred.rs b/rust/kernel/cred.rs new file mode 100644 index 000000000000..ffa156b9df37 --- /dev/null +++ b/rust/kernel/cred.rs @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Credentials management. +//! +//! C header: [`include/linux/cred.h`](srctree/include/linux/cred.h). +//! +//! Reference: <https://www.kernel.org/doc/html/latest/security/credentials.html> + +use crate::{bindings, sync::aref::AlwaysRefCounted, task::Kuid, types::Opaque}; + +/// Wraps the kernel's `struct cred`. +/// +/// Credentials are used for various security checks in the kernel. +/// +/// Most fields of credentials are immutable. When things have their credentials changed, that +/// happens by replacing the credential instead of changing an existing credential. See the [kernel +/// documentation][ref] for more info on this. +/// +/// # Invariants +/// +/// Instances of this type are always ref-counted, that is, a call to `get_cred` ensures that the +/// allocation remains valid at least until the matching call to `put_cred`. +/// +/// [ref]: https://www.kernel.org/doc/html/latest/security/credentials.html +#[repr(transparent)] +pub struct Credential(Opaque<bindings::cred>); + +// SAFETY: +// - `Credential::dec_ref` can be called from any thread. +// - It is okay to send ownership of `Credential` across thread boundaries. +unsafe impl Send for Credential {} + +// SAFETY: It's OK to access `Credential` through shared references from other threads because +// we're either accessing properties that don't change or that are properly synchronised by C code. +unsafe impl Sync for Credential {} + +impl Credential { + /// Creates a reference to a [`Credential`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`Credential`] reference. + #[inline] + pub unsafe fn from_ptr<'a>(ptr: *const bindings::cred) -> &'a Credential { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `Credential` type being transparent makes the cast ok. + unsafe { &*ptr.cast() } + } + + /// Returns a raw pointer to the inner credential. 
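+ ///
+ /// The pointer is valid for as long as the underlying credential stays alive, i.e. while
+ /// a reference count is held on it.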
+ #[inline] + pub fn as_ptr(&self) -> *const bindings::cred { + self.0.get() + } + + /// Get the id for this security context. + #[inline] + pub fn get_secid(&self) -> u32 { + let mut secid = 0; + // SAFETY: The invariants of this type ensures that the pointer is valid. + unsafe { bindings::security_cred_getsecid(self.0.get(), &mut secid) }; + secid + } + + /// Returns the effective UID of the given credential. + #[inline] + pub fn euid(&self) -> Kuid { + // SAFETY: By the type invariant, we know that `self.0` is valid. Furthermore, the `euid` + // field of a credential is never changed after initialization, so there is no potential + // for data races. + Kuid::from_raw(unsafe { (*self.0.get()).euid }) + } +} + +// SAFETY: The type invariants guarantee that `Credential` is always ref-counted. +unsafe impl AlwaysRefCounted for Credential { + #[inline] + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_cred(self.0.get()) }; + } + + #[inline] + unsafe fn dec_ref(obj: core::ptr::NonNull<Credential>) { + // SAFETY: The safety requirements guarantee that the refcount is nonzero. The cast is okay + // because `Credential` has the same representation as `struct cred`. + unsafe { bindings::put_cred(obj.cast().as_ptr()) }; + } +} diff --git a/rust/kernel/debugfs.rs b/rust/kernel/debugfs.rs new file mode 100644 index 000000000000..facad81e8290 --- /dev/null +++ b/rust/kernel/debugfs.rs @@ -0,0 +1,700 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (C) 2025 Google LLC. + +//! DebugFS Abstraction +//! +//! C header: [`include/linux/debugfs.h`](srctree/include/linux/debugfs.h) + +// When DebugFS is disabled, many parameters are dead. Linting for this isn't helpful. +#![cfg_attr(not(CONFIG_DEBUG_FS), allow(unused_variables))] + +use crate::fmt; +use crate::prelude::*; +use crate::str::CStr; +#[cfg(CONFIG_DEBUG_FS)] +use crate::sync::Arc; +use crate::uaccess::UserSliceReader; +use core::marker::PhantomData; +use core::marker::PhantomPinned; +#[cfg(CONFIG_DEBUG_FS)] +use core::mem::ManuallyDrop; +use core::ops::Deref; + +mod traits; +pub use traits::{BinaryReader, BinaryReaderMut, BinaryWriter, Reader, Writer}; + +mod callback_adapters; +use callback_adapters::{FormatAdapter, NoWriter, WritableAdapter}; +mod file_ops; +use file_ops::{ + BinaryReadFile, BinaryReadWriteFile, BinaryWriteFile, FileOps, ReadFile, ReadWriteFile, + WriteFile, +}; +#[cfg(CONFIG_DEBUG_FS)] +mod entry; +#[cfg(CONFIG_DEBUG_FS)] +use entry::Entry; + +/// Owning handle to a DebugFS directory. +/// +/// The directory in the filesystem represented by [`Dir`] will be removed when handle has been +/// dropped *and* all children have been removed. +// If we have a parent, we hold a reference to it in the `Entry`. This prevents the `dentry` +// we point to from being cleaned up if our parent `Dir`/`Entry` is dropped before us. +// +// The `None` option indicates that the `Arc` could not be allocated, so our children would not be +// able to refer to us. In this case, we need to silently fail. All future child directories/files +// will silently fail as well. +#[derive(Clone)] +pub struct Dir(#[cfg(CONFIG_DEBUG_FS)] Option<Arc<Entry<'static>>>); + +impl Dir { + /// Create a new directory in DebugFS. If `parent` is [`None`], it will be created at the root. 
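+ ///
+ /// Allocation failures are deliberately swallowed: the returned [`Dir`] is inert, and any
+ /// file or subdirectory created under it silently becomes a no-op, matching how debugfs
+ /// errors are usually ignored on the C side.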
+ fn create(name: &CStr, parent: Option<&Dir>) -> Self { + #[cfg(CONFIG_DEBUG_FS)] + { + let parent_entry = match parent { + // If the parent couldn't be allocated, just early-return + Some(Dir(None)) => return Self(None), + Some(Dir(Some(entry))) => Some(entry.clone()), + None => None, + }; + Self( + // If Arc creation fails, the `Entry` will be dropped, so the directory will be + // cleaned up. + Arc::new(Entry::dynamic_dir(name, parent_entry), GFP_KERNEL).ok(), + ) + } + #[cfg(not(CONFIG_DEBUG_FS))] + Self() + } + + /// Creates a DebugFS file which will own the data produced by the initializer provided in + /// `data`. + fn create_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + file_ops: &'static FileOps<T>, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Sync + 'static, + { + let scope = Scope::<T>::new(data, move |data| { + #[cfg(CONFIG_DEBUG_FS)] + if let Some(parent) = &self.0 { + // SAFETY: Because data derives from a scope, and our entry will be dropped before + // the data is dropped, it is guaranteed to outlive the entry we return. + unsafe { Entry::dynamic_file(name, parent.clone(), data, file_ops) } + } else { + Entry::empty() + } + }); + try_pin_init! { + File { + scope <- scope + } ? E + } + } + + /// Create a new directory in DebugFS at the root. + /// + /// # Examples + /// + /// ``` + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// let debugfs = Dir::new(c_str!("parent")); + /// ``` + pub fn new(name: &CStr) -> Self { + Dir::create(name, None) + } + + /// Creates a subdirectory within this directory. + /// + /// # Examples + /// + /// ``` + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// let parent = Dir::new(c_str!("parent")); + /// let child = parent.subdir(c_str!("child")); + /// ``` + pub fn subdir(&self, name: &CStr) -> Self { + Dir::create(name, Some(self)) + } + + /// Creates a read-only file in this directory. + /// + /// The file's contents are produced by invoking [`Writer::write`] on the value initialized by + /// `data`. + /// + /// # Examples + /// + /// ``` + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// # use kernel::prelude::*; + /// # let dir = Dir::new(c_str!("my_debugfs_dir")); + /// let file = KBox::pin_init(dir.read_only_file(c_str!("foo"), 200), GFP_KERNEL)?; + /// // "my_debugfs_dir/foo" now contains the number 200. + /// // The file is removed when `file` is dropped. + /// # Ok::<(), Error>(()) + /// ``` + pub fn read_only_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Writer + Send + Sync + 'static, + { + let file_ops = &<T as ReadFile<_>>::FILE_OPS; + self.create_file(name, data, file_ops) + } + + /// Creates a read-only binary file in this directory. + /// + /// The file's contents are produced by invoking [`BinaryWriter::write_to_slice`] on the value + /// initialized by `data`. 
+ /// + /// # Examples + /// + /// ``` + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// # use kernel::prelude::*; + /// # let dir = Dir::new(c_str!("my_debugfs_dir")); + /// let file = KBox::pin_init(dir.read_binary_file(c_str!("foo"), [0x1, 0x2]), GFP_KERNEL)?; + /// # Ok::<(), Error>(()) + /// ``` + pub fn read_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryWriter + Send + Sync + 'static, + { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a read-only file in this directory, with contents from a callback. + /// + /// `f` must be a function item or a non-capturing closure. + /// This is statically asserted and not a safety requirement. + /// + /// # Examples + /// + /// ``` + /// # use core::sync::atomic::{AtomicU32, Ordering}; + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// # use kernel::prelude::*; + /// # let dir = Dir::new(c_str!("foo")); + /// let file = KBox::pin_init( + /// dir.read_callback_file(c_str!("bar"), + /// AtomicU32::new(3), + /// &|val, f| { + /// let out = val.load(Ordering::Relaxed); + /// writeln!(f, "{out:#010x}") + /// }), + /// GFP_KERNEL)?; + /// // Reading "foo/bar" will show "0x00000003". + /// file.store(10, Ordering::Relaxed); + /// // Reading "foo/bar" will now show "0x0000000a". + /// # Ok::<(), Error>(()) + /// ``` + pub fn read_callback_file<'a, T, E: 'a, F>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + _f: &'static F, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Send + Sync + 'static, + F: Fn(&T, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + { + let file_ops = <FormatAdapter<T, F>>::FILE_OPS.adapt(); + self.create_file(name, data, file_ops) + } + + /// Creates a read-write file in this directory. + /// + /// Reading the file uses the [`Writer`] implementation. + /// Writing to the file uses the [`Reader`] implementation. + pub fn read_write_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Writer + Reader + Send + Sync + 'static, + { + let file_ops = &<T as ReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, file_ops) + } + + /// Creates a read-write binary file in this directory. + /// + /// Reading the file uses the [`BinaryWriter`] implementation. + /// Writing to the file uses the [`BinaryReader`] implementation. + pub fn read_write_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryWriter + BinaryReader + Send + Sync + 'static, + { + let file_ops = &<T as BinaryReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, file_ops) + } + + /// Creates a read-write file in this directory, with logic from callbacks. + /// + /// Reading from the file is handled by `f`. Writing to the file is handled by `w`. + /// + /// `f` and `w` must be function items or non-capturing closures. + /// This is statically asserted and not a safety requirement. 
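+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of a read-write counter file (the names and the reset-on-write
+ /// semantics are illustrative; a real driver would parse the user input):
+ ///
+ /// ```
+ /// # use core::sync::atomic::{AtomicU32, Ordering};
+ /// # use kernel::c_str;
+ /// # use kernel::debugfs::Dir;
+ /// # use kernel::prelude::*;
+ /// # let dir = Dir::new(c_str!("foo"));
+ /// let file = KBox::pin_init(
+ /// dir.read_write_callback_file(c_str!("bar"),
+ /// AtomicU32::new(3),
+ /// &|val, f| {
+ /// let out = val.load(Ordering::Relaxed);
+ /// writeln!(f, "{out}")
+ /// },
+ /// &|val, _reader| {
+ /// // Writing anything resets the counter.
+ /// val.store(0, Ordering::Relaxed);
+ /// Ok(())
+ /// }),
+ /// GFP_KERNEL)?;
+ /// // Reading "foo/bar" shows "3"; writing to it resets the value to 0.
+ /// # Ok::<(), Error>(())
+ /// ```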
+ pub fn read_write_callback_file<'a, T, E: 'a, F, W>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + _f: &'static F, + _w: &'static W, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Send + Sync + 'static, + F: Fn(&T, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + W: Fn(&T, &mut UserSliceReader) -> Result + Send + Sync, + { + let file_ops = + <WritableAdapter<FormatAdapter<T, F>, W> as file_ops::ReadWriteFile<_>>::FILE_OPS + .adapt() + .adapt(); + self.create_file(name, data, file_ops) + } + + /// Creates a write-only file in this directory. + /// + /// The file owns its backing data. Writing to the file uses the [`Reader`] + /// implementation. + /// + /// The file is removed when the returned [`File`] is dropped. + pub fn write_only_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Reader + Send + Sync + 'static, + { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a write-only binary file in this directory. + /// + /// The file owns its backing data. Writing to the file uses the [`BinaryReader`] + /// implementation. + /// + /// The file is removed when the returned [`File`] is dropped. + pub fn write_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryReader + Send + Sync + 'static, + { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a write-only file in this directory, with write logic from a callback. + /// + /// `w` must be a function item or a non-capturing closure. + /// This is statically asserted and not a safety requirement. + pub fn write_callback_file<'a, T, E: 'a, W>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + _w: &'static W, + ) -> impl PinInit<File<T>, E> + 'a + where + T: Send + Sync + 'static, + W: Fn(&T, &mut UserSliceReader) -> Result + Send + Sync, + { + let file_ops = <WritableAdapter<NoWriter<T>, W> as WriteFile<_>>::FILE_OPS + .adapt() + .adapt(); + self.create_file(name, data, file_ops) + } + + // While this function is safe, it is intentionally not public because it's a bit of a + // footgun. + // + // Unless you also extract the `entry` later and schedule it for `Drop` at the appropriate + // time, a `ScopedDir` with a `Dir` parent will never be deleted. + fn scoped_dir<'data>(&self, name: &CStr) -> ScopedDir<'data, 'static> { + #[cfg(CONFIG_DEBUG_FS)] + { + let parent_entry = match &self.0 { + None => return ScopedDir::empty(), + Some(entry) => entry.clone(), + }; + ScopedDir { + entry: ManuallyDrop::new(Entry::dynamic_dir(name, Some(parent_entry))), + _phantom: PhantomData, + } + } + #[cfg(not(CONFIG_DEBUG_FS))] + ScopedDir::empty() + } + + /// Creates a new scope, which is a directory associated with some data `T`. + /// + /// The created directory will be a subdirectory of `self`. The `init` closure is called to + /// populate the directory with files and subdirectories. These files can reference the data + /// stored in the scope. + /// + /// The entire directory tree created within the scope will be removed when the returned + /// `Scope` handle is dropped. 
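+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of a scope exposing two read-only views of the same backing data. The
+ /// directory and file names are illustrative, and the plain integers are assumed to render
+ /// through their `Display` implementation:
+ ///
+ /// ```
+ /// # use kernel::c_str;
+ /// # use kernel::debugfs::Dir;
+ /// # use kernel::prelude::*;
+ /// let dir = Dir::new(c_str!("parent"));
+ /// let scope = KBox::pin_init(
+ /// dir.scope((1u32, 2u32), c_str!("state"), |data, dir| {
+ /// dir.read_only_file(c_str!("first"), &data.0);
+ /// dir.read_only_file(c_str!("second"), &data.1);
+ /// }),
+ /// GFP_KERNEL)?;
+ /// // "parent/state/first" and "parent/state/second" exist until `scope` is dropped.
+ /// # Ok::<(), Error>(())
+ /// ```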
+ pub fn scope<'a, T: 'a, E: 'a, F>( + &'a self, + data: impl PinInit<T, E> + 'a, + name: &'a CStr, + init: F, + ) -> impl PinInit<Scope<T>, E> + 'a + where + F: for<'data, 'dir> FnOnce(&'data T, &'dir ScopedDir<'data, 'dir>) + 'a, + { + Scope::new(data, |data| { + let scoped = self.scoped_dir(name); + init(data, &scoped); + scoped.into_entry() + }) + } +} + +#[pin_data] +/// Handle to a DebugFS scope, which ensures that attached `data` will outlive the DebugFS entry +/// without moving. +/// +/// This is internally used to back [`File`], and used in the API to represent the attachment +/// of a directory lifetime to a data structure which may be jointly accessed by a number of +/// different files. +/// +/// When dropped, a `Scope` will remove all directories and files in the filesystem backed by the +/// attached data structure prior to releasing the attached data. +pub struct Scope<T> { + // This order is load-bearing for drops - `_entry` must be dropped before `data`. + #[cfg(CONFIG_DEBUG_FS)] + _entry: Entry<'static>, + #[pin] + data: T, + // Even if `T` is `Unpin`, we still can't allow it to be moved. + #[pin] + _pin: PhantomPinned, +} + +#[pin_data] +/// Handle to a DebugFS file, owning its backing data. +/// +/// When dropped, the DebugFS file will be removed and the attached data will be dropped. +pub struct File<T> { + #[pin] + scope: Scope<T>, +} + +#[cfg(not(CONFIG_DEBUG_FS))] +impl<'b, T: 'b> Scope<T> { + fn new<E: 'b, F>(data: impl PinInit<T, E> + 'b, init: F) -> impl PinInit<Self, E> + 'b + where + F: for<'a> FnOnce(&'a T) + 'b, + { + try_pin_init! { + Self { + data <- data, + _pin: PhantomPinned + } ? E + } + .pin_chain(|scope| { + init(&scope.data); + Ok(()) + }) + } +} + +#[cfg(CONFIG_DEBUG_FS)] +impl<'b, T: 'b> Scope<T> { + fn entry_mut(self: Pin<&mut Self>) -> &mut Entry<'static> { + // SAFETY: _entry is not structurally pinned. + unsafe { &mut Pin::into_inner_unchecked(self)._entry } + } + + fn new<E: 'b, F>(data: impl PinInit<T, E> + 'b, init: F) -> impl PinInit<Self, E> + 'b + where + F: for<'a> FnOnce(&'a T) -> Entry<'static> + 'b, + { + try_pin_init! { + Self { + _entry: Entry::empty(), + data <- data, + _pin: PhantomPinned + } ? E + } + .pin_chain(|scope| { + *scope.entry_mut() = init(&scope.data); + Ok(()) + }) + } +} + +impl<'a, T: 'a> Scope<T> { + /// Creates a new scope, which is a directory at the root of the debugfs filesystem, + /// associated with some data `T`. + /// + /// The `init` closure is called to populate the directory with files and subdirectories. These + /// files can reference the data stored in the scope. + /// + /// The entire directory tree created within the scope will be removed when the returned + /// `Scope` handle is dropped. + pub fn dir<E: 'a, F>( + data: impl PinInit<T, E> + 'a, + name: &'a CStr, + init: F, + ) -> impl PinInit<Self, E> + 'a + where + F: for<'data, 'dir> FnOnce(&'data T, &'dir ScopedDir<'data, 'dir>) + 'a, + { + Scope::new(data, |data| { + let scoped = ScopedDir::new(name); + init(data, &scoped); + scoped.into_entry() + }) + } +} + +impl<T> Deref for Scope<T> { + type Target = T; + fn deref(&self) -> &T { + &self.data + } +} + +impl<T> Deref for File<T> { + type Target = T; + fn deref(&self) -> &T { + &self.scope + } +} + +/// A handle to a directory which will live at most `'dir`, accessing data that will live for at +/// least `'data`. +/// +/// Dropping a ScopedDir will not delete or clean it up, this is expected to occur through dropping +/// the `Scope` that created it. 
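+///
+/// A [`ScopedDir`] is only handed out inside the `init` closures passed to [`Dir::scope`] and
+/// [`Scope::dir`] (or derived from one via [`ScopedDir::dir`]), which bounds `'dir` and ensures
+/// files created through it cannot outlive the owning [`Scope`].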
+pub struct ScopedDir<'data, 'dir> { + #[cfg(CONFIG_DEBUG_FS)] + entry: ManuallyDrop<Entry<'dir>>, + _phantom: PhantomData<fn(&'data ()) -> &'dir ()>, +} + +impl<'data, 'dir> ScopedDir<'data, 'dir> { + /// Creates a subdirectory inside this `ScopedDir`. + /// + /// The returned directory handle cannot outlive this one. + pub fn dir<'dir2>(&'dir2 self, name: &CStr) -> ScopedDir<'data, 'dir2> { + #[cfg(not(CONFIG_DEBUG_FS))] + let _ = name; + ScopedDir { + #[cfg(CONFIG_DEBUG_FS)] + entry: ManuallyDrop::new(Entry::dir(name, Some(&*self.entry))), + _phantom: PhantomData, + } + } + + fn create_file<T: Sync>(&self, name: &CStr, data: &'data T, vtable: &'static FileOps<T>) { + #[cfg(CONFIG_DEBUG_FS)] + core::mem::forget(Entry::file(name, &self.entry, data, vtable)); + } + + /// Creates a read-only file in this directory. + /// + /// The file's contents are produced by invoking [`Writer::write`]. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn read_only_file<T: Writer + Send + Sync + 'static>(&self, name: &CStr, data: &'data T) { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a read-only binary file in this directory. + /// + /// The file's contents are produced by invoking [`BinaryWriter::write_to_slice`]. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. + pub fn read_binary_file<T: BinaryWriter + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a read-only file in this directory, with contents from a callback. + /// + /// The file contents are generated by calling `f` with `data`. + /// + /// + /// `f` must be a function item or a non-capturing closure. + /// This is statically asserted and not a safety requirement. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn read_callback_file<T, F>(&self, name: &CStr, data: &'data T, _f: &'static F) + where + T: Send + Sync + 'static, + F: Fn(&T, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + { + let vtable = <FormatAdapter<T, F> as ReadFile<_>>::FILE_OPS.adapt(); + self.create_file(name, data, vtable) + } + + /// Creates a read-write file in this directory. + /// + /// Reading the file uses the [`Writer`] implementation on `data`. Writing to the file uses + /// the [`Reader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn read_write_file<T: Writer + Reader + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + let vtable = &<T as ReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, vtable) + } + + /// Creates a read-write binary file in this directory. + /// + /// Reading the file uses the [`BinaryWriter`] implementation on `data`. Writing to the file + /// uses the [`BinaryReader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. 
+ pub fn read_write_binary_file<T: BinaryWriter + BinaryReader + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + let vtable = &<T as BinaryReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, vtable) + } + + /// Creates a read-write file in this directory, with logic from callbacks. + /// + /// Reading from the file is handled by `f`. Writing to the file is handled by `w`. + /// + /// `f` and `w` must be function items or non-capturing closures. + /// This is statically asserted and not a safety requirement. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn read_write_callback_file<T, F, W>( + &self, + name: &CStr, + data: &'data T, + _f: &'static F, + _w: &'static W, + ) where + T: Send + Sync + 'static, + F: Fn(&T, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + W: Fn(&T, &mut UserSliceReader) -> Result + Send + Sync, + { + let vtable = <WritableAdapter<FormatAdapter<T, F>, W> as ReadWriteFile<_>>::FILE_OPS + .adapt() + .adapt(); + self.create_file(name, data, vtable) + } + + /// Creates a write-only file in this directory. + /// + /// Writing to the file uses the [`Reader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn write_only_file<T: Reader + Send + Sync + 'static>(&self, name: &CStr, data: &'data T) { + let vtable = &<T as WriteFile<_>>::FILE_OPS; + self.create_file(name, data, vtable) + } + + /// Creates a write-only binary file in this directory. + /// + /// Writing to the file uses the [`BinaryReader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. + pub fn write_binary_file<T: BinaryReader + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + self.create_file(name, data, &T::FILE_OPS) + } + + /// Creates a write-only file in this directory, with write logic from a callback. + /// + /// Writing to the file is handled by `w`. + /// + /// `w` must be a function item or a non-capturing closure. + /// This is statically asserted and not a safety requirement. + /// + /// This function does not produce an owning handle to the file. The created + /// file is removed when the [`Scope`] that this directory belongs + /// to is dropped. + pub fn write_only_callback_file<T, W>(&self, name: &CStr, data: &'data T, _w: &'static W) + where + T: Send + Sync + 'static, + W: Fn(&T, &mut UserSliceReader) -> Result + Send + Sync, + { + let vtable = &<WritableAdapter<NoWriter<T>, W> as WriteFile<_>>::FILE_OPS + .adapt() + .adapt(); + self.create_file(name, data, vtable) + } + + fn empty() -> Self { + ScopedDir { + #[cfg(CONFIG_DEBUG_FS)] + entry: ManuallyDrop::new(Entry::empty()), + _phantom: PhantomData, + } + } + #[cfg(CONFIG_DEBUG_FS)] + fn into_entry(self) -> Entry<'dir> { + ManuallyDrop::into_inner(self.entry) + } + #[cfg(not(CONFIG_DEBUG_FS))] + fn into_entry(self) {} +} + +impl<'data> ScopedDir<'data, 'static> { + // This is safe, but intentionally not exported due to footgun status. A ScopedDir with no + // parent will never be released by default, and needs to have its entry extracted and used + // somewhere. 
+ fn new(name: &CStr) -> ScopedDir<'data, 'static> { + ScopedDir { + #[cfg(CONFIG_DEBUG_FS)] + entry: ManuallyDrop::new(Entry::dir(name, None)), + _phantom: PhantomData, + } + } +} diff --git a/rust/kernel/debugfs/callback_adapters.rs b/rust/kernel/debugfs/callback_adapters.rs new file mode 100644 index 000000000000..a260d8dee051 --- /dev/null +++ b/rust/kernel/debugfs/callback_adapters.rs @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (C) 2025 Google LLC. + +//! Adapters which allow the user to supply a write or read implementation as a value rather +//! than a trait implementation. If provided, it will override the trait implementation. + +use super::{Reader, Writer}; +use crate::fmt; +use crate::prelude::*; +use crate::uaccess::UserSliceReader; +use core::marker::PhantomData; +use core::ops::Deref; + +/// # Safety +/// +/// To implement this trait, it must be safe to cast a `&Self` to a `&Inner`. +/// It is intended for use in unstacking adapters out of `FileOps` backings. +pub(crate) unsafe trait Adapter { + type Inner; +} + +/// Adapter to implement `Reader` via a callback with the same representation as `T`. +/// +/// * Layer it on top of `WriterAdapter` if you want to add a custom callback for `write`. +/// * Layer it on top of `NoWriter` to pass through any support present on the underlying type. +/// +/// # Invariants +/// +/// If an instance for `WritableAdapter<_, W>` is constructed, `W` is inhabited. +#[repr(transparent)] +pub(crate) struct WritableAdapter<D, W> { + inner: D, + _writer: PhantomData<W>, +} + +// SAFETY: Stripping off the adapter only removes constraints +unsafe impl<D, W> Adapter for WritableAdapter<D, W> { + type Inner = D; +} + +impl<D: Writer, W> Writer for WritableAdapter<D, W> { + fn write(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.write(fmt) + } +} + +impl<D: Deref, W> Reader for WritableAdapter<D, W> +where + W: Fn(&D::Target, &mut UserSliceReader) -> Result + Send + Sync + 'static, +{ + fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { + // SAFETY: WritableAdapter<_, W> can only be constructed if W is inhabited + let w: &W = unsafe { materialize_zst() }; + w(self.inner.deref(), reader) + } +} + +/// Adapter to implement `Writer` via a callback with the same representation as `T`. +/// +/// # Invariants +/// +/// If an instance for `FormatAdapter<_, F>` is constructed, `F` is inhabited. +#[repr(transparent)] +pub(crate) struct FormatAdapter<D, F> { + inner: D, + _formatter: PhantomData<F>, +} + +impl<D, F> Deref for FormatAdapter<D, F> { + type Target = D; + fn deref(&self) -> &D { + &self.inner + } +} + +impl<D, F> Writer for FormatAdapter<D, F> +where + F: Fn(&D, &mut fmt::Formatter<'_>) -> fmt::Result + 'static, +{ + fn write(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + // SAFETY: FormatAdapter<_, F> can only be constructed if F is inhabited + let f: &F = unsafe { materialize_zst() }; + f(&self.inner, fmt) + } +} + +// SAFETY: Stripping off the adapter only removes constraints +unsafe impl<D, F> Adapter for FormatAdapter<D, F> { + type Inner = D; +} + +#[repr(transparent)] +pub(crate) struct NoWriter<D> { + inner: D, +} + +// SAFETY: Stripping off the adapter only removes constraints +unsafe impl<D> Adapter for NoWriter<D> { + type Inner = D; +} + +impl<D> Deref for NoWriter<D> { + type Target = D; + fn deref(&self) -> &D { + &self.inner + } +} + +/// For types with a unique value, produce a static reference to it. 
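+///
+/// This is useful for invoking a function item or non-capturing closure given only its type:
+/// such types are zero-sized, so once the type is known to be inhabited, a dangling (but
+/// aligned) pointer is a valid reference to its single value.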
+/// +/// # Safety +/// +/// The caller asserts that F is inhabited +unsafe fn materialize_zst<F>() -> &'static F { + const { assert!(core::mem::size_of::<F>() == 0) }; + let zst_dangle: core::ptr::NonNull<F> = core::ptr::NonNull::dangling(); + // SAFETY: While the pointer is dangling, it is a dangling pointer to a ZST, based on the + // assertion above. The type is also inhabited, by the caller's assertion. This means + // we can materialize it. + unsafe { zst_dangle.as_ref() } +} diff --git a/rust/kernel/debugfs/entry.rs b/rust/kernel/debugfs/entry.rs new file mode 100644 index 000000000000..706cb7f73d6c --- /dev/null +++ b/rust/kernel/debugfs/entry.rs @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (C) 2025 Google LLC. + +use crate::debugfs::file_ops::FileOps; +use crate::ffi::c_void; +use crate::str::{CStr, CStrExt as _}; +use crate::sync::Arc; +use core::marker::PhantomData; + +/// Owning handle to a DebugFS entry. +/// +/// # Invariants +/// +/// The wrapped pointer will always be `NULL`, an error, or an owned DebugFS `dentry`. +pub(crate) struct Entry<'a> { + entry: *mut bindings::dentry, + // If we were created with an owning parent, this is the keep-alive + _parent: Option<Arc<Entry<'static>>>, + // If we were created with a non-owning parent, this prevents us from outliving it + _phantom: PhantomData<&'a ()>, +} + +// SAFETY: [`Entry`] is just a `dentry` under the hood, which the API promises can be transferred +// between threads. +unsafe impl Send for Entry<'_> {} + +// SAFETY: All the C functions we call on the `dentry` pointer are threadsafe. +unsafe impl Sync for Entry<'_> {} + +impl Entry<'static> { + pub(crate) fn dynamic_dir(name: &CStr, parent: Option<Arc<Self>>) -> Self { + let parent_ptr = match &parent { + Some(entry) => entry.as_ptr(), + None => core::ptr::null_mut(), + }; + // SAFETY: The invariants of this function's arguments ensure the safety of this call. + // * `name` is a valid C string by the invariants of `&CStr`. + // * `parent_ptr` is either `NULL` (if `parent` is `None`), or a pointer to a valid + // `dentry` by our invariant. `debugfs_create_dir` handles `NULL` pointers correctly. + let entry = unsafe { bindings::debugfs_create_dir(name.as_char_ptr(), parent_ptr) }; + + Entry { + entry, + _parent: parent, + _phantom: PhantomData, + } + } + + /// # Safety + /// + /// * `data` must outlive the returned `Entry`. + pub(crate) unsafe fn dynamic_file<T>( + name: &CStr, + parent: Arc<Self>, + data: &T, + file_ops: &'static FileOps<T>, + ) -> Self { + // SAFETY: The invariants of this function's arguments ensure the safety of this call. + // * `name` is a valid C string by the invariants of `&CStr`. + // * `parent.as_ptr()` is a pointer to a valid `dentry` by invariant. + // * The caller guarantees that `data` will outlive the returned `Entry`. + // * The guarantees on `FileOps` assert the vtable will be compatible with the data we have + // provided. + let entry = unsafe { + bindings::debugfs_create_file_full( + name.as_char_ptr(), + file_ops.mode(), + parent.as_ptr(), + core::ptr::from_ref(data) as *mut c_void, + core::ptr::null(), + &**file_ops, + ) + }; + + Entry { + entry, + _parent: Some(parent), + _phantom: PhantomData, + } + } +} + +impl<'a> Entry<'a> { + pub(crate) fn dir(name: &CStr, parent: Option<&'a Entry<'_>>) -> Self { + let parent_ptr = match &parent { + Some(entry) => entry.as_ptr(), + None => core::ptr::null_mut(), + }; + // SAFETY: The invariants of this function's arguments ensure the safety of this call. 
+ // * `name` is a valid C string by the invariants of `&CStr`. + // * `parent_ptr` is either `NULL` (if `parent` is `None`), or a pointer to a valid + // `dentry` (because `parent` is a valid reference to an `Entry`). The lifetime `'a` + // ensures that the parent outlives this entry. + let entry = unsafe { bindings::debugfs_create_dir(name.as_char_ptr(), parent_ptr) }; + + Entry { + entry, + _parent: None, + _phantom: PhantomData, + } + } + + pub(crate) fn file<T>( + name: &CStr, + parent: &'a Entry<'_>, + data: &'a T, + file_ops: &FileOps<T>, + ) -> Self { + // SAFETY: The invariants of this function's arguments ensure the safety of this call. + // * `name` is a valid C string by the invariants of `&CStr`. + // * `parent.as_ptr()` is a pointer to a valid `dentry` because we have `&'a Entry`. + // * `data` is a valid pointer to `T` for lifetime `'a`. + // * The returned `Entry` has lifetime `'a`, so it cannot outlive `parent` or `data`. + // * The caller guarantees that `vtable` is compatible with `data`. + // * The guarantees on `FileOps` assert the vtable will be compatible with the data we have + // provided. + let entry = unsafe { + bindings::debugfs_create_file_full( + name.as_char_ptr(), + file_ops.mode(), + parent.as_ptr(), + core::ptr::from_ref(data) as *mut c_void, + core::ptr::null(), + &**file_ops, + ) + }; + + Entry { + entry, + _parent: None, + _phantom: PhantomData, + } + } +} + +impl Entry<'_> { + /// Constructs a placeholder DebugFS [`Entry`]. + pub(crate) fn empty() -> Self { + Self { + entry: core::ptr::null_mut(), + _parent: None, + _phantom: PhantomData, + } + } + + /// Returns the pointer representation of the DebugFS directory. + /// + /// # Guarantees + /// + /// Due to the type invariant, the value returned from this function will always be an error + /// code, NULL, or a live DebugFS directory. If it is live, it will remain live at least as + /// long as this entry lives. + pub(crate) fn as_ptr(&self) -> *mut bindings::dentry { + self.entry + } +} + +impl Drop for Entry<'_> { + fn drop(&mut self) { + // SAFETY: `debugfs_remove` can take `NULL`, error values, and legal DebugFS dentries. + // `as_ptr` guarantees that the pointer is of this form. + unsafe { bindings::debugfs_remove(self.as_ptr()) } + } +} diff --git a/rust/kernel/debugfs/file_ops.rs b/rust/kernel/debugfs/file_ops.rs new file mode 100644 index 000000000000..8a0442d6dd7a --- /dev/null +++ b/rust/kernel/debugfs/file_ops.rs @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (C) 2025 Google LLC. + +use super::{BinaryReader, BinaryWriter, Reader, Writer}; +use crate::debugfs::callback_adapters::Adapter; +use crate::fmt; +use crate::fs::file; +use crate::prelude::*; +use crate::seq_file::SeqFile; +use crate::seq_print; +use crate::uaccess::UserSlice; +use core::marker::PhantomData; + +#[cfg(CONFIG_DEBUG_FS)] +use core::ops::Deref; + +/// # Invariant +/// +/// `FileOps<T>` will always contain an `operations` which is safe to use for a file backed +/// off an inode which has a pointer to a `T` in its private data that is safe to convert +/// into a reference. +pub(super) struct FileOps<T> { + #[cfg(CONFIG_DEBUG_FS)] + operations: bindings::file_operations, + #[cfg(CONFIG_DEBUG_FS)] + mode: u16, + _phantom: PhantomData<T>, +} + +impl<T> FileOps<T> { + /// # Safety + /// + /// The caller asserts that the provided `operations` is safe to use for a file whose + /// inode has a pointer to `T` in its private data that is safe to convert into a reference. 
+    const unsafe fn new(operations: bindings::file_operations, mode: u16) -> Self {
+        Self {
+            #[cfg(CONFIG_DEBUG_FS)]
+            operations,
+            #[cfg(CONFIG_DEBUG_FS)]
+            mode,
+            _phantom: PhantomData,
+        }
+    }
+
+    #[cfg(CONFIG_DEBUG_FS)]
+    pub(crate) const fn mode(&self) -> u16 {
+        self.mode
+    }
+}
+
+impl<T: Adapter> FileOps<T> {
+    pub(super) const fn adapt(&self) -> &FileOps<T::Inner> {
+        // SAFETY: `Adapter` asserts that `T` can be legally cast to `T::Inner`.
+        unsafe { core::mem::transmute(self) }
+    }
+}
+
+#[cfg(CONFIG_DEBUG_FS)]
+impl<T> Deref for FileOps<T> {
+    type Target = bindings::file_operations;
+
+    fn deref(&self) -> &Self::Target {
+        &self.operations
+    }
+}
+
+struct WriterAdapter<T>(T);
+
+impl<'a, T: Writer> fmt::Display for WriterAdapter<&'a T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.write(f)
+    }
+}
+
+/// Implements `open` for `file_operations` via `single_open` to fill out a `seq_file`.
+///
+/// # Safety
+///
+/// * `inode`'s private pointer must point to a value of type `T` which will outlive the `inode`
+///   and will not have any unique references alias it during the call.
+/// * `file` must point to a live, not-yet-initialized file object.
+unsafe extern "C" fn writer_open<T: Writer + Sync>(
+    inode: *mut bindings::inode,
+    file: *mut bindings::file,
+) -> c_int {
+    // SAFETY: The caller ensures that `inode` is a valid pointer.
+    let data = unsafe { (*inode).i_private };
+    // SAFETY:
+    // * `file` is acceptable by caller precondition.
+    // * `writer_act` will be called on a `seq_file` with private data set to the third argument,
+    //   so we meet its safety requirements.
+    // * The `data` pointer passed in the third argument is a valid `T` pointer that outlives
+    //   this call by caller preconditions.
+    unsafe { bindings::single_open(file, Some(writer_act::<T>), data) }
+}
+
+/// Prints private data stashed in a `seq_file` to that `seq_file`.
+///
+/// # Safety
+///
+/// `seq` must point to a live `seq_file` whose private data is a valid pointer to a `T` which may
+/// not have any unique references alias it during the call.
+unsafe extern "C" fn writer_act<T: Writer + Sync>(
+    seq: *mut bindings::seq_file,
+    _: *mut c_void,
+) -> c_int {
+    // SAFETY: By caller precondition, this pointer is a valid pointer to a `T`, and
+    // there are not and will not be any unique references until we are done.
+    let data = unsafe { &*((*seq).private.cast::<T>()) };
+    // SAFETY: By caller precondition, `seq_file` points to a live `seq_file`, so we can lift
+    // it.
+    let seq_file = unsafe { SeqFile::from_raw(seq) };
+    seq_print!(seq_file, "{}", WriterAdapter(data));
+    0
+}
+
+// Work around lack of generic const items.
+pub(crate) trait ReadFile<T> {
+    const FILE_OPS: FileOps<T>;
+}
+
+impl<T: Writer + Sync> ReadFile<T> for T {
+    const FILE_OPS: FileOps<T> = {
+        let operations = bindings::file_operations {
+            read: Some(bindings::seq_read),
+            llseek: Some(bindings::seq_lseek),
+            release: Some(bindings::single_release),
+            open: Some(writer_open::<Self>),
+            // SAFETY: `file_operations` supports zeroes in all fields.
+            ..unsafe { core::mem::zeroed() }
+        };
+        // SAFETY: `operations` is all stock `seq_file` implementations except for `writer_open`.
+        // `open`'s only requirement beyond what is provided to all open functions is that the
+        // inode's data pointer must point to a `T` that will outlive it, which matches the
+        // `FileOps` requirements.
+ unsafe { FileOps::new(operations, 0o400) } + }; +} + +fn read<T: Reader + Sync>(data: &T, buf: *const c_char, count: usize) -> isize { + let mut reader = UserSlice::new(UserPtr::from_ptr(buf as *mut c_void), count).reader(); + + if let Err(e) = data.read_from_slice(&mut reader) { + return e.to_errno() as isize; + } + + count as isize +} + +/// # Safety +/// +/// `file` must be a valid pointer to a `file` struct. +/// The `private_data` of the file must contain a valid pointer to a `seq_file` whose +/// `private` data in turn points to a `T` that implements `Reader`. +/// `buf` must be a valid user-space buffer. +pub(crate) unsafe extern "C" fn write<T: Reader + Sync>( + file: *mut bindings::file, + buf: *const c_char, + count: usize, + _ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: The file was opened with `single_open`, which sets `private_data` to a `seq_file`. + let seq = unsafe { &mut *((*file).private_data.cast::<bindings::seq_file>()) }; + // SAFETY: By caller precondition, this pointer is live and points to a value of type `T`. + let data = unsafe { &*(seq.private as *const T) }; + read(data, buf, count) +} + +// A trait to get the file operations for a type. +pub(crate) trait ReadWriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: Writer + Reader + Sync> ReadWriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + open: Some(writer_open::<T>), + read: Some(bindings::seq_read), + write: Some(write::<T>), + llseek: Some(bindings::seq_lseek), + release: Some(bindings::single_release), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + // SAFETY: `operations` is all stock `seq_file` implementations except for `writer_open` + // and `write`. + // `writer_open`'s only requirement beyond what is provided to all open functions is that + // the inode's data pointer must point to a `T` that will outlive it, which matches the + // `FileOps` requirements. + // `write` only requires that the file's private data pointer points to `seq_file` + // which points to a `T` that will outlive it, which matches what `writer_open` + // provides. + unsafe { FileOps::new(operations, 0o600) } + }; +} + +/// # Safety +/// +/// `inode` must be a valid pointer to an `inode` struct. +/// `file` must be a valid pointer to a `file` struct. +unsafe extern "C" fn write_only_open( + inode: *mut bindings::inode, + file: *mut bindings::file, +) -> c_int { + // SAFETY: The caller ensures that `inode` and `file` are valid pointers. + unsafe { (*file).private_data = (*inode).i_private }; + 0 +} + +/// # Safety +/// +/// * `file` must be a valid pointer to a `file` struct. +/// * The `private_data` of the file must contain a valid pointer to a `T` that implements +/// `Reader`. +/// * `buf` must be a valid user-space buffer. +pub(crate) unsafe extern "C" fn write_only_write<T: Reader + Sync>( + file: *mut bindings::file, + buf: *const c_char, + count: usize, + _ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: The caller ensures that `file` is a valid pointer and that `private_data` holds a + // valid pointer to `T`. 
+ let data = unsafe { &*((*file).private_data as *const T) }; + read(data, buf, count) +} + +pub(crate) trait WriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: Reader + Sync> WriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + open: Some(write_only_open), + write: Some(write_only_write::<T>), + llseek: Some(bindings::noop_llseek), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + // SAFETY: + // * `write_only_open` populates the file private data with the inode private data + // * `write_only_write`'s only requirement is that the private data of the file point to + // a `T` and be legal to convert to a shared reference, which `write_only_open` + // satisfies. + unsafe { FileOps::new(operations, 0o200) } + }; +} + +extern "C" fn blob_read<T: BinaryWriter>( + file: *mut bindings::file, + buf: *mut c_char, + count: usize, + ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: + // - `file` is a valid pointer to a `struct file`. + // - The type invariant of `FileOps` guarantees that `private_data` points to a valid `T`. + let this = unsafe { &*((*file).private_data.cast::<T>()) }; + + // SAFETY: + // - `ppos` is a valid `file::Offset` pointer. + // - We have exclusive access to `ppos`. + let pos: &mut file::Offset = unsafe { &mut *ppos }; + + let mut writer = UserSlice::new(UserPtr::from_ptr(buf.cast()), count).writer(); + + let ret = || -> Result<isize> { + let written = this.write_to_slice(&mut writer, pos)?; + + Ok(written.try_into()?) + }(); + + match ret { + Ok(n) => n, + Err(e) => e.to_errno() as isize, + } +} + +/// Representation of [`FileOps`] for read only binary files. +pub(crate) trait BinaryReadFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryWriter + Sync> BinaryReadFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + read: Some(blob_read::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_read()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o400) } + }; +} + +extern "C" fn blob_write<T: BinaryReader>( + file: *mut bindings::file, + buf: *const c_char, + count: usize, + ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: + // - `file` is a valid pointer to a `struct file`. + // - The type invariant of `FileOps` guarantees that `private_data` points to a valid `T`. + let this = unsafe { &*((*file).private_data.cast::<T>()) }; + + // SAFETY: + // - `ppos` is a valid `file::Offset` pointer. + // - We have exclusive access to `ppos`. + let pos: &mut file::Offset = unsafe { &mut *ppos }; + + let mut reader = UserSlice::new(UserPtr::from_ptr(buf.cast_mut().cast()), count).reader(); + + let ret = || -> Result<isize> { + let read = this.read_from_slice(&mut reader, pos)?; + + Ok(read.try_into()?) + }(); + + match ret { + Ok(n) => n, + Err(e) => e.to_errno() as isize, + } +} + +/// Representation of [`FileOps`] for write only binary files. 
+pub(crate) trait BinaryWriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryReader + Sync> BinaryWriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + write: Some(blob_write::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_write()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o200) } + }; +} + +/// Representation of [`FileOps`] for read/write binary files. +pub(crate) trait BinaryReadWriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryWriter + BinaryReader + Sync> BinaryReadWriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + read: Some(blob_read::<T>), + write: Some(blob_write::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_read()` re-creates a reference to `T` from the `struct file`'s private data. + // - `blob_write()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o600) } + }; +} diff --git a/rust/kernel/debugfs/traits.rs b/rust/kernel/debugfs/traits.rs new file mode 100644 index 000000000000..3eee60463fd5 --- /dev/null +++ b/rust/kernel/debugfs/traits.rs @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (C) 2025 Google LLC. + +//! Traits for rendering or updating values exported to DebugFS. + +use crate::alloc::Allocator; +use crate::fmt; +use crate::fs::file; +use crate::prelude::*; +use crate::sync::atomic::{Atomic, AtomicBasicOps, AtomicType, Relaxed}; +use crate::sync::Arc; +use crate::sync::Mutex; +use crate::transmute::{AsBytes, FromBytes}; +use crate::uaccess::{UserSliceReader, UserSliceWriter}; +use core::ops::{Deref, DerefMut}; +use core::str::FromStr; + +/// A trait for types that can be written into a string. +/// +/// This works very similarly to `Debug`, and is automatically implemented if `Debug` is +/// implemented for a type. It is also implemented for any writable type inside a `Mutex`. +/// +/// The derived implementation of `Debug` [may +/// change](https://doc.rust-lang.org/std/fmt/trait.Debug.html#stability) +/// between Rust versions, so if stability is key for your use case, please implement `Writer` +/// explicitly instead. +pub trait Writer { + /// Formats the value using the given formatter. 
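+    ///
+    /// # Examples
+    ///
+    /// A sketch of a manual implementation for a hypothetical type whose output must remain
+    /// stable across Rust versions (so the derived `Debug` format is deliberately avoided):
+    ///
+    /// ```ignore
+    /// struct PortState {
+    ///     up: bool,
+    /// }
+    ///
+    /// impl Writer for PortState {
+    ///     fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    ///         // A fixed, hand-written format.
+    ///         writeln!(f, "up: {}", self.up)
+    ///     }
+    /// }
+    /// ```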
+    fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
+}
+
+impl<T: Writer> Writer for Mutex<T> {
+    fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.lock().write(f)
+    }
+}
+
+impl<T: fmt::Debug> Writer for T {
+    fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(f, "{self:?}")
+    }
+}
+
+/// Trait for types that can be written out as binary.
+pub trait BinaryWriter {
+    /// Writes the binary form of `self` into `writer`.
+    ///
+    /// `offset` is the requested offset into the binary representation of `self`.
+    ///
+    /// On success, returns the number of bytes written into `writer`.
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize>;
+}
+
+// Base implementation for any `T: AsBytes`.
+impl<T: AsBytes> BinaryWriter for T {
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        writer.write_slice_file(self.as_bytes(), offset)
+    }
+}
+
+// Delegate for `Mutex<T>`: Support a `T` with an outer mutex.
+impl<T: BinaryWriter> BinaryWriter for Mutex<T> {
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        let guard = self.lock();
+
+        guard.write_to_slice(writer, offset)
+    }
+}
+
+// Delegate for `Box<T, A>`: Support a `Box<T, A>` with no lock or an inner lock.
+impl<T, A> BinaryWriter for Box<T, A>
+where
+    T: BinaryWriter,
+    A: Allocator,
+{
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        self.deref().write_to_slice(writer, offset)
+    }
+}
+
+// Delegate for `Pin<Box<T, A>>`: Support a `Pin<Box<T, A>>` with no lock or an inner lock.
+impl<T, A> BinaryWriter for Pin<Box<T, A>>
+where
+    T: BinaryWriter,
+    A: Allocator,
+{
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        self.deref().write_to_slice(writer, offset)
+    }
+}
+
+// Delegate for `Arc<T>`: Support an `Arc<T>` with no lock or an inner lock.
+impl<T> BinaryWriter for Arc<T>
+where
+    T: BinaryWriter,
+{
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        self.deref().write_to_slice(writer, offset)
+    }
+}
+
+// Delegate for `Vec<T, A>`.
+impl<T, A> BinaryWriter for Vec<T, A>
+where
+    T: AsBytes,
+    A: Allocator,
+{
+    fn write_to_slice(
+        &self,
+        writer: &mut UserSliceWriter,
+        offset: &mut file::Offset,
+    ) -> Result<usize> {
+        let slice = self.as_slice();
+
+        // SAFETY: `T: AsBytes` allows us to treat `&[T]` as `&[u8]`.
+        let buffer = unsafe {
+            core::slice::from_raw_parts(slice.as_ptr().cast(), core::mem::size_of_val(slice))
+        };
+
+        writer.write_slice_file(buffer, offset)
+    }
+}
+
+/// A trait for types that can be updated from a user slice.
+///
+/// This works similarly to `FromStr`, but operates on a `UserSliceReader` rather than a `&str`.
+///
+/// It is automatically implemented for all atomic integers, and for any type that implements
+/// `FromStr` when wrapped in a `Mutex`.
+pub trait Reader {
+    /// Updates the value from the given user slice.
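+    ///
+    /// # Examples
+    ///
+    /// A sketch of a manual implementation that treats any write as a trigger and ignores the
+    /// payload (the type is illustrative):
+    ///
+    /// ```ignore
+    /// struct ResetCounter(Atomic<u64>);
+    ///
+    /// impl Reader for ResetCounter {
+    ///     fn read_from_slice(&self, _reader: &mut UserSliceReader) -> Result {
+    ///         // Any write, regardless of content, resets the counter.
+    ///         self.0.store(0, Relaxed);
+    ///         Ok(())
+    ///     }
+    /// }
+    /// ```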
+ fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result; +} + +impl<T: FromStr + Unpin> Reader for Mutex<T> { + fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { + let mut buf = [0u8; 128]; + if reader.len() > buf.len() { + return Err(EINVAL); + } + let n = reader.len(); + reader.read_slice(&mut buf[..n])?; + + let s = core::str::from_utf8(&buf[..n]).map_err(|_| EINVAL)?; + let val = s.trim().parse::<T>().map_err(|_| EINVAL)?; + *self.lock() = val; + Ok(()) + } +} + +impl<T: AtomicType + FromStr> Reader for Atomic<T> +where + T::Repr: AtomicBasicOps, +{ + fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { + let mut buf = [0u8; 21]; // Enough for a 64-bit number. + if reader.len() > buf.len() { + return Err(EINVAL); + } + let n = reader.len(); + reader.read_slice(&mut buf[..n])?; + + let s = core::str::from_utf8(&buf[..n]).map_err(|_| EINVAL)?; + let val = s.trim().parse::<T>().map_err(|_| EINVAL)?; + self.store(val, Relaxed); + Ok(()) + } +} + +/// Trait for types that can be constructed from a binary representation. +/// +/// See also [`BinaryReader`] for interior mutability. +pub trait BinaryReaderMut { + /// Reads the binary form of `self` from `reader`. + /// + /// Same as [`BinaryReader::read_from_slice`], but takes a mutable reference. + /// + /// `offset` is the requested offset into the binary representation of `self`. + /// + /// On success, returns the number of bytes read from `reader`. + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize>; +} + +// Base implementation for any `T: AsBytes + FromBytes`. +impl<T: AsBytes + FromBytes> BinaryReaderMut for T { + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + reader.read_slice_file(self.as_bytes_mut(), offset) + } +} + +// Delegate for `Box<T, A>`: Support a `Box<T, A>` with an outer lock. +impl<T: ?Sized + BinaryReaderMut, A: Allocator> BinaryReaderMut for Box<T, A> { + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref_mut().read_from_slice_mut(reader, offset) + } +} + +// Delegate for `Vec<T, A>`: Support a `Vec<T, A>` with an outer lock. +impl<T, A> BinaryReaderMut for Vec<T, A> +where + T: AsBytes + FromBytes, + A: Allocator, +{ + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + let slice = self.as_mut_slice(); + + // SAFETY: `T: AsBytes + FromBytes` allows us to treat `&mut [T]` as `&mut [u8]`. + let buffer = unsafe { + core::slice::from_raw_parts_mut( + slice.as_mut_ptr().cast(), + core::mem::size_of_val(slice), + ) + }; + + reader.read_slice_file(buffer, offset) + } +} + +/// Trait for types that can be constructed from a binary representation. +/// +/// See also [`BinaryReaderMut`] for the mutable version. +pub trait BinaryReader { + /// Reads the binary form of `self` from `reader`. + /// + /// `offset` is the requested offset into the binary representation of `self`. + /// + /// On success, returns the number of bytes read from `reader`. + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize>; +} + +// Delegate for `Mutex<T>`: Support a `T` with an outer `Mutex`. 
+impl<T: BinaryReaderMut + Unpin> BinaryReader for Mutex<T> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + let mut this = self.lock(); + + this.read_from_slice_mut(reader, offset) + } +} + +// Delegate for `Box<T, A>`: Support a `Box<T, A>` with an inner lock. +impl<T: ?Sized + BinaryReader, A: Allocator> BinaryReader for Box<T, A> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} + +// Delegate for `Pin<Box<T, A>>`: Support a `Pin<Box<T, A>>` with an inner lock. +impl<T: ?Sized + BinaryReader, A: Allocator> BinaryReader for Pin<Box<T, A>> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} + +// Delegate for `Arc<T>`: Support an `Arc<T>` with an inner lock. +impl<T: ?Sized + BinaryReader> BinaryReader for Arc<T> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs new file mode 100644 index 000000000000..c79be2e2bfe3 --- /dev/null +++ b/rust/kernel/device.rs @@ -0,0 +1,928 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic devices that are part of the kernel's driver model. +//! +//! C header: [`include/linux/device.h`](srctree/include/linux/device.h) + +use crate::{ + bindings, fmt, + prelude::*, + sync::aref::ARef, + types::{ForeignOwnable, Opaque}, +}; +use core::{any::TypeId, marker::PhantomData, ptr}; + +#[cfg(CONFIG_PRINTK)] +use crate::c_str; +use crate::str::CStrExt as _; + +pub mod property; + +// Assert that we can `read()` / `write()` a `TypeId` instance from / into `struct driver_type`. +static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_of::<TypeId>()); + +/// The core representation of a device in the kernel's driver model. +/// +/// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either +/// exist as temporary reference (see also [`Device::from_raw`]), which is only valid within a +/// certain scope or as [`ARef<Device>`], owning a dedicated reference count. +/// +/// # Device Types +/// +/// A [`Device`] can represent either a bus device or a class device. +/// +/// ## Bus Devices +/// +/// A bus device is a [`Device`] that is associated with a physical or virtual bus. Examples of +/// buses include PCI, USB, I2C, and SPI. Devices attached to a bus are registered with a specific +/// bus type, which facilitates matching devices with appropriate drivers based on IDs or other +/// identifying information. Bus devices are visible in sysfs under `/sys/bus/<bus-name>/devices/`. +/// +/// ## Class Devices +/// +/// A class device is a [`Device`] that is associated with a logical category of functionality +/// rather than a physical bus. Examples of classes include block devices, network interfaces, sound +/// cards, and input devices. Class devices are grouped under a common class and exposed to +/// userspace via entries in `/sys/class/<class-name>/`. +/// +/// # Device Context +/// +/// [`Device`] references are generic over a [`DeviceContext`], which represents the type state of +/// a [`Device`]. +/// +/// As the name indicates, this type state represents the context of the scope the [`Device`] +/// reference is valid in. 
For instance, the [`Bound`] context guarantees that the [`Device`] is
+/// bound to a driver for the entire duration of the existence of a [`Device<Bound>`] reference.
+///
+/// Other [`DeviceContext`] types besides [`Bound`] are [`Normal`], [`Core`] and [`CoreInternal`].
+///
+/// Unless selected otherwise, [`Device`] defaults to the [`Normal`] [`DeviceContext`], which by
+/// itself has no additional requirements.
+///
+/// It is always up to the caller of [`Device::from_raw`] to select the correct [`DeviceContext`]
+/// type for the corresponding scope the [`Device`] reference is created in.
+///
+/// All [`DeviceContext`] types other than [`Normal`] are intended to be used with
+/// [bus devices](#bus-devices) only.
+///
+/// # Implementing Bus Devices
+///
+/// This section provides a guideline to implement bus specific devices, such as [`pci::Device`] or
+/// [`platform::Device`].
+///
+/// A bus specific device should be defined as follows.
+///
+/// ```ignore
+/// #[repr(transparent)]
+/// pub struct Device<Ctx: device::DeviceContext = device::Normal>(
+///     Opaque<bindings::bus_device_type>,
+///     PhantomData<Ctx>,
+/// );
+/// ```
+///
+/// Since devices are reference counted, [`AlwaysRefCounted`] should be implemented for `Device`
+/// (i.e. `Device<Normal>`). Note that [`AlwaysRefCounted`] must not be implemented for any other
+/// [`DeviceContext`], since all other device context types are only valid within a certain scope.
+///
+/// In order to be able to implement the [`DeviceContext`] dereference hierarchy, bus device
+/// implementations should call the [`impl_device_context_deref`] macro as shown below.
+///
+/// ```ignore
+/// // SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s
+/// // generic argument.
+/// kernel::impl_device_context_deref!(unsafe { Device });
+/// ```
+///
+/// In order to convert from any [`Device<Ctx>`] to [`ARef<Device>`], bus devices can use the
+/// following macro call.
+///
+/// ```ignore
+/// kernel::impl_device_context_into_aref!(Device);
+/// ```
+///
+/// Bus devices should also provide the following [`AsRef`] implementation, such that users can
+/// easily derive a generic [`Device`] reference.
+///
+/// ```ignore
+/// impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> {
+///     fn as_ref(&self) -> &device::Device<Ctx> {
+///         ...
+///     }
+/// }
+/// ```
+///
+/// # Implementing Class Devices
+///
+/// Class device implementations require less infrastructure and depend slightly more on the
+/// specific subsystem.
+///
+/// An example implementation for a class device could look like this.
+///
+/// ```ignore
+/// #[repr(C)]
+/// pub struct Device<T: class::Driver> {
+///     dev: Opaque<bindings::class_device_type>,
+///     data: T::Data,
+/// }
+/// ```
+///
+/// This class device uses the sub-classing pattern to embed the driver's private data within the
+/// allocation of the class device. For this to be possible, the class device is generic over the
+/// class specific `Driver` trait implementation.
+///
+/// Just like any device, class devices are reference counted and should hence implement
+/// [`AlwaysRefCounted`] for `Device`.
+///
+/// Class devices should also provide the following [`AsRef`] implementation, such that users can
+/// easily derive a generic [`Device`] reference.
+///
+/// ```ignore
+/// impl<T: class::Driver> AsRef<device::Device> for Device<T> {
+///     fn as_ref(&self) -> &device::Device {
+///         ...
+/// } +/// } +/// ``` +/// +/// An example for a class device implementation is +#[cfg_attr(CONFIG_DRM = "y", doc = "[`drm::Device`](kernel::drm::Device).")] +#[cfg_attr(not(CONFIG_DRM = "y"), doc = "`drm::Device`.")] +/// +/// # Invariants +/// +/// A `Device` instance represents a valid `struct device` created by the C portion of the kernel. +/// +/// Instances of this type are always reference-counted, that is, a call to `get_device` ensures +/// that the allocation remains valid at least until the matching call to `put_device`. +/// +/// `bindings::device::release` is valid to be called from any thread, hence `ARef<Device>` can be +/// dropped from any thread. +/// +/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +/// [`impl_device_context_deref`]: kernel::impl_device_context_deref +/// [`pci::Device`]: kernel::pci::Device +/// [`platform::Device`]: kernel::platform::Device +#[repr(transparent)] +pub struct Device<Ctx: DeviceContext = Normal>(Opaque<bindings::device>, PhantomData<Ctx>); + +impl Device { + /// Creates a new reference-counted abstraction instance of an existing `struct device` pointer. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct device` `ptr` points to + /// can't drop to zero, for the duration of this function call. + /// + /// It must also be ensured that `bindings::device::release` can be called from any thread. + /// While not officially documented, this should be the case for any `struct device`. + pub unsafe fn get_device(ptr: *mut bindings::device) -> ARef<Self> { + // SAFETY: By the safety requirements ptr is valid + unsafe { Self::from_raw(ptr) }.into() + } + + /// Convert a [`&Device`](Device) into a [`&Device<Bound>`](Device<Bound>). + /// + /// # Safety + /// + /// The caller is responsible to ensure that the returned [`&Device<Bound>`](Device<Bound>) + /// only lives as long as it can be guaranteed that the [`Device`] is actually bound. + pub unsafe fn as_bound(&self) -> &Device<Bound> { + let ptr = core::ptr::from_ref(self); + + // CAST: By the safety requirements the caller is responsible to guarantee that the + // returned reference only lives as long as the device is actually bound. + let ptr = ptr.cast(); + + // SAFETY: + // - `ptr` comes from `from_ref(self)` above, hence it's guaranteed to be valid. + // - Any valid `Device` pointer is also a valid pointer for `Device<Bound>`. + unsafe { &*ptr } + } +} + +impl Device<CoreInternal> { + fn set_type_id<T: 'static>(&self) { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let private = unsafe { (*self.as_raw()).p }; + + // SAFETY: For a bound device (implied by the `CoreInternal` device context), `private` is + // guaranteed to be a valid pointer to a `struct device_private`. + let driver_type = unsafe { &raw mut (*private).driver_type }; + + // SAFETY: `driver_type` is valid for (unaligned) writes of a `TypeId`. + unsafe { + driver_type + .cast::<TypeId>() + .write_unaligned(TypeId::of::<T>()) + }; + } + + /// Store a pointer to the bound driver's private data. + pub fn set_drvdata<T: 'static>(&self, data: impl PinInit<T, Error>) -> Result { + let data = KBox::pin_init(data, GFP_KERNEL)?; + + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. 
+ unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) }; + self.set_type_id::<T>(); + + Ok(()) + } + + /// Take ownership of the private data stored in this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; + + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) }; + + // SAFETY: + // - By the safety requirements of this function, `ptr` comes from a previous call to + // `into_foreign()`. + // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` + // in `into_foreign()`. + unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) } + } + + /// Borrow the driver's private data bound to this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before + /// [`Device::drvdata_obtain`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { + // SAFETY: `drvdata_unchecked()` has the exact same safety requirements as the ones + // required by this method. + unsafe { self.drvdata_unchecked() } + } +} + +impl Device<Bound> { + /// Borrow the driver's private data bound to this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before + /// [`Device::drvdata_obtain`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; + + // SAFETY: + // - By the safety requirements of this function, `ptr` comes from a previous call to + // `into_foreign()`. + // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` + // in `into_foreign()`. + unsafe { Pin::<KBox<T>>::borrow(ptr.cast()) } + } + + fn match_type_id<T: 'static>(&self) -> Result { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let private = unsafe { (*self.as_raw()).p }; + + // SAFETY: For a bound device, `private` is guaranteed to be a valid pointer to a + // `struct device_private`. + let driver_type = unsafe { &raw mut (*private).driver_type }; + + // SAFETY: + // - `driver_type` is valid for (unaligned) reads of a `TypeId`. + // - A bound device guarantees that `driver_type` contains a valid `TypeId` value. + let type_id = unsafe { driver_type.cast::<TypeId>().read_unaligned() }; + + if type_id != TypeId::of::<T>() { + return Err(EINVAL); + } + + Ok(()) + } + + /// Access a driver's private data. + /// + /// Returns a pinned reference to the driver's private data or [`EINVAL`] if it doesn't match + /// the asserted type `T`. 
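+    ///
+    /// # Examples
+    ///
+    /// A sketch for a driver that previously stored a `DriverData` (an illustrative type) via
+    /// [`Device::set_drvdata`]:
+    ///
+    /// ```ignore
+    /// fn resume(dev: &Device<Bound>) -> Result {
+    ///     // `EINVAL` if the stored type is not `DriverData`; `ENOENT` if nothing is stored.
+    ///     let data: Pin<&DriverData> = dev.drvdata::<DriverData>()?;
+    ///     // ... use `data` ...
+    ///     Ok(())
+    /// }
+    /// ```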
+ pub fn drvdata<T: 'static>(&self) -> Result<Pin<&T>> { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + if unsafe { bindings::dev_get_drvdata(self.as_raw()) }.is_null() { + return Err(ENOENT); + } + + self.match_type_id::<T>()?; + + // SAFETY: + // - The above check of `dev_get_drvdata()` guarantees that we are called after + // `set_drvdata()` and before `drvdata_obtain()`. + // - We've just checked that the type of the driver's private data is in fact `T`. + Ok(unsafe { self.drvdata_unchecked() }) + } +} + +impl<Ctx: DeviceContext> Device<Ctx> { + /// Obtain the raw `struct device *`. + pub(crate) fn as_raw(&self) -> *mut bindings::device { + self.0.get() + } + + /// Returns a reference to the parent device, if any. + #[cfg_attr(not(CONFIG_AUXILIARY_BUS), expect(dead_code))] + pub(crate) fn parent(&self) -> Option<&Device> { + // SAFETY: + // - By the type invariant `self.as_raw()` is always valid. + // - The parent device is only ever set at device creation. + let parent = unsafe { (*self.as_raw()).parent }; + + if parent.is_null() { + None + } else { + // SAFETY: + // - Since `parent` is not NULL, it must be a valid pointer to a `struct device`. + // - `parent` is valid for the lifetime of `self`, since a `struct device` holds a + // reference count of its parent. + Some(unsafe { Device::from_raw(parent) }) + } + } + + /// Convert a raw C `struct device` pointer to a `&'a Device`. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct device` `ptr` points to + /// can't drop to zero, for the duration of this function call and the entire duration when the + /// returned reference exists. + pub unsafe fn from_raw<'a>(ptr: *mut bindings::device) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + unsafe { &*ptr.cast() } + } + + /// Prints an emergency-level message (level 0) prefixed with device information. + /// + /// More details are available from [`dev_emerg`]. + /// + /// [`dev_emerg`]: crate::dev_emerg + pub fn pr_emerg(&self, args: fmt::Arguments<'_>) { + // SAFETY: `klevel` is null-terminated, uses one of the kernel constants. + unsafe { self.printk(bindings::KERN_EMERG, args) }; + } + + /// Prints an alert-level message (level 1) prefixed with device information. + /// + /// More details are available from [`dev_alert`]. + /// + /// [`dev_alert`]: crate::dev_alert + pub fn pr_alert(&self, args: fmt::Arguments<'_>) { + // SAFETY: `klevel` is null-terminated, uses one of the kernel constants. + unsafe { self.printk(bindings::KERN_ALERT, args) }; + } + + /// Prints a critical-level message (level 2) prefixed with device information. + /// + /// More details are available from [`dev_crit`]. + /// + /// [`dev_crit`]: crate::dev_crit + pub fn pr_crit(&self, args: fmt::Arguments<'_>) { + // SAFETY: `klevel` is null-terminated, uses one of the kernel constants. + unsafe { self.printk(bindings::KERN_CRIT, args) }; + } + + /// Prints an error-level message (level 3) prefixed with device information. + /// + /// More details are available from [`dev_err`]. + /// + /// [`dev_err`]: crate::dev_err + pub fn pr_err(&self, args: fmt::Arguments<'_>) { + // SAFETY: `klevel` is null-terminated, uses one of the kernel constants. + unsafe { self.printk(bindings::KERN_ERR, args) }; + } + + /// Prints a warning-level message (level 4) prefixed with device information. 
+    ///
+    /// More details are available from [`dev_warn`].
+    ///
+    /// [`dev_warn`]: crate::dev_warn
+    pub fn pr_warn(&self, args: fmt::Arguments<'_>) {
+        // SAFETY: `klevel` is null-terminated, uses one of the kernel constants.
+        unsafe { self.printk(bindings::KERN_WARNING, args) };
+    }
+
+    /// Prints a notice-level message (level 5) prefixed with device information.
+    ///
+    /// More details are available from [`dev_notice`].
+    ///
+    /// [`dev_notice`]: crate::dev_notice
+    pub fn pr_notice(&self, args: fmt::Arguments<'_>) {
+        // SAFETY: `klevel` is null-terminated, uses one of the kernel constants.
+        unsafe { self.printk(bindings::KERN_NOTICE, args) };
+    }
+
+    /// Prints an info-level message (level 6) prefixed with device information.
+    ///
+    /// More details are available from [`dev_info`].
+    ///
+    /// [`dev_info`]: crate::dev_info
+    pub fn pr_info(&self, args: fmt::Arguments<'_>) {
+        // SAFETY: `klevel` is null-terminated, uses one of the kernel constants.
+        unsafe { self.printk(bindings::KERN_INFO, args) };
+    }
+
+    /// Prints a debug-level message (level 7) prefixed with device information.
+    ///
+    /// More details are available from [`dev_dbg`].
+    ///
+    /// [`dev_dbg`]: crate::dev_dbg
+    pub fn pr_dbg(&self, args: fmt::Arguments<'_>) {
+        if cfg!(debug_assertions) {
+            // SAFETY: `klevel` is null-terminated, uses one of the kernel constants.
+            unsafe { self.printk(bindings::KERN_DEBUG, args) };
+        }
+    }
+
+    /// Prints the provided message to the console.
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that `klevel` is null-terminated; in particular, it must be one of
+    /// the `KERN_*` constants, for example, `KERN_CRIT`, `KERN_ALERT`, etc.
+    #[cfg_attr(not(CONFIG_PRINTK), allow(unused_variables))]
+    unsafe fn printk(&self, klevel: &[u8], msg: fmt::Arguments<'_>) {
+        // SAFETY: `klevel` is null-terminated and one of the kernel constants. `self.as_raw`
+        // is valid because `self` is valid. The "%pA" format string expects a pointer to
+        // `fmt::Arguments`, which is what we're passing as the last argument.
+        #[cfg(CONFIG_PRINTK)]
+        unsafe {
+            bindings::_dev_printk(
+                klevel.as_ptr().cast::<crate::ffi::c_char>(),
+                self.as_raw(),
+                c_str!("%pA").as_char_ptr(),
+                core::ptr::from_ref(&msg).cast::<crate::ffi::c_void>(),
+            )
+        };
+    }
+
+    /// Obtain the [`FwNode`](property::FwNode) corresponding to this [`Device`].
+    pub fn fwnode(&self) -> Option<&property::FwNode> {
+        // SAFETY: `self` is valid.
+        let fwnode_handle = unsafe { bindings::__dev_fwnode(self.as_raw()) };
+        if fwnode_handle.is_null() {
+            return None;
+        }
+        // SAFETY: `fwnode_handle` is valid. Its lifetime is tied to `&self`. We
+        // return a reference instead of an `ARef<FwNode>` because `dev_fwnode()`
+        // doesn't increment the refcount. It is safe to cast from a
+        // `struct fwnode_handle*` to a `*const FwNode` because `FwNode` is
+        // defined as a `#[repr(transparent)]` wrapper around `fwnode_handle`.
+        Some(unsafe { &*fwnode_handle.cast() })
+    }
+}
+
+// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic
+// argument.
+kernel::impl_device_context_deref!(unsafe { Device });
+kernel::impl_device_context_into_aref!(Device);
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl crate::sync::aref::AlwaysRefCounted for Device {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+ unsafe { bindings::get_device(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::put_device(obj.cast().as_ptr()) } + } +} + +// SAFETY: As by the type invariant `Device` can be sent to any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all immutable methods are protected by the +// synchronization in `struct device`. +unsafe impl Sync for Device {} + +/// Marker trait for the context or scope of a bus specific device. +/// +/// [`DeviceContext`] is a marker trait for types representing the context of a bus specific +/// [`Device`]. +/// +/// The specific device context types are: [`CoreInternal`], [`Core`], [`Bound`] and [`Normal`]. +/// +/// [`DeviceContext`] types are hierarchical, which means that there is a strict hierarchy that +/// defines which [`DeviceContext`] type can be derived from another. For instance, any +/// [`Device<Core>`] can dereference to a [`Device<Bound>`]. +/// +/// The following enumeration illustrates the dereference hierarchy of [`DeviceContext`] types. +/// +/// - [`CoreInternal`] => [`Core`] => [`Bound`] => [`Normal`] +/// +/// Bus devices can automatically implement the dereference hierarchy by using +/// [`impl_device_context_deref`]. +/// +/// Note that the guarantee for a [`Device`] reference to have a certain [`DeviceContext`] comes +/// from the specific scope the [`Device`] reference is valid in. +/// +/// [`impl_device_context_deref`]: kernel::impl_device_context_deref +pub trait DeviceContext: private::Sealed {} + +/// The [`Normal`] context is the default [`DeviceContext`] of any [`Device`]. +/// +/// The normal context does not indicate any specific context. Any `Device<Ctx>` is also a valid +/// [`Device<Normal>`]. It is the only [`DeviceContext`] for which it is valid to implement +/// [`AlwaysRefCounted`] for. +/// +/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +pub struct Normal; + +/// The [`Core`] context is the context of a bus specific device when it appears as argument of +/// any bus specific callback, such as `probe()`. +/// +/// The core context indicates that the [`Device<Core>`] reference's scope is limited to the bus +/// callback it appears in. It is intended to be used for synchronization purposes. Bus device +/// implementations can implement methods for [`Device<Core>`], such that they can only be called +/// from bus callbacks. +pub struct Core; + +/// Semantically the same as [`Core`], but reserved for internal usage of the corresponding bus +/// abstraction. +/// +/// The internal core context is intended to be used in exactly the same way as the [`Core`] +/// context, with the difference that this [`DeviceContext`] is internal to the corresponding bus +/// abstraction. +/// +/// This context mainly exists to share generic [`Device`] infrastructure that should only be called +/// from bus callbacks with bus abstractions, but without making them accessible for drivers. +pub struct CoreInternal; + +/// The [`Bound`] context is the [`DeviceContext`] of a bus specific device when it is guaranteed to +/// be bound to a driver. +/// +/// The bound context indicates that for the entire duration of the lifetime of a [`Device<Bound>`] +/// reference, the [`Device`] is guaranteed to be bound to a driver. 
+///
+/// Some APIs, such as [`dma::CoherentAllocation`] or [`Devres`] rely on the [`Device`] to be bound,
+/// which can be proven with the [`Bound`] device context.
+///
+/// Any abstraction that can guarantee a scope where the corresponding bus device is bound, should
+/// provide a [`Device<Bound>`] reference to its users for this scope. This allows users to benefit
+/// from optimizations for accessing device resources, see also [`Devres::access`].
+///
+/// [`Devres`]: kernel::devres::Devres
+/// [`Devres::access`]: kernel::devres::Devres::access
+/// [`dma::CoherentAllocation`]: kernel::dma::CoherentAllocation
+pub struct Bound;
+
+mod private {
+    pub trait Sealed {}
+
+    impl Sealed for super::Bound {}
+    impl Sealed for super::Core {}
+    impl Sealed for super::CoreInternal {}
+    impl Sealed for super::Normal {}
+}
+
+impl DeviceContext for Bound {}
+impl DeviceContext for Core {}
+impl DeviceContext for CoreInternal {}
+impl DeviceContext for Normal {}
+
+/// Convert device references to bus device references.
+///
+/// Bus devices can implement this trait to allow abstractions to provide the bus device in
+/// class device callbacks.
+///
+/// This must not be used by drivers and is intended for bus and class device abstractions only.
+///
+/// # Safety
+///
+/// `AsBusDevice::OFFSET` must be the offset of the embedded base `struct device` field within a
+/// bus device structure.
+pub unsafe trait AsBusDevice<Ctx: DeviceContext>: AsRef<Device<Ctx>> {
+    /// The relative offset to the device field.
+    ///
+    /// Use the [`core::mem::offset_of!`] macro to compute this offset, in order to avoid
+    /// breakage when the layout of the bus device structure changes.
+    const OFFSET: usize;
+
+    /// Convert a reference to [`Device`] into `Self`.
+    ///
+    /// # Safety
+    ///
+    /// `dev` must be contained in `Self`.
+    unsafe fn from_device(dev: &Device<Ctx>) -> &Self
+    where
+        Self: Sized,
+    {
+        let raw = dev.as_raw();
+        // SAFETY: `raw - Self::OFFSET` is guaranteed by the safety requirements
+        // to be a valid pointer to `Self`.
+        unsafe { &*raw.byte_sub(Self::OFFSET).cast::<Self>() }
+    }
+}
+
+/// # Safety
+///
+/// The type given as `$device` must be a transparent wrapper of a type that doesn't depend on the
+/// generic argument of `$device`.
+#[doc(hidden)]
+#[macro_export]
+macro_rules! __impl_device_context_deref {
+    (unsafe { $device:ident, $src:ty => $dst:ty }) => {
+        impl ::core::ops::Deref for $device<$src> {
+            type Target = $device<$dst>;
+
+            fn deref(&self) -> &Self::Target {
+                let ptr: *const Self = self;
+
+                // CAST: `$device<$src>` and `$device<$dst>` transparently wrap the same type by the
+                // safety requirement of the macro.
+                let ptr = ptr.cast::<Self::Target>();
+
+                // SAFETY: `ptr` was derived from `&self`.
+                unsafe { &*ptr }
+            }
+        }
+    };
+}
+
+/// Implement [`core::ops::Deref`] traits for allowed [`DeviceContext`] conversions of a (bus
+/// specific) device.
+///
+/// # Safety
+///
+/// The type given as `$device` must be a transparent wrapper of a type that doesn't depend on the
+/// generic argument of `$device`.
+#[macro_export]
+macro_rules! impl_device_context_deref {
+    (unsafe { $device:ident }) => {
+        // SAFETY: This macro has the exact same safety requirement as
+        // `__impl_device_context_deref!`.
+        ::kernel::__impl_device_context_deref!(unsafe {
+            $device,
+            $crate::device::CoreInternal => $crate::device::Core
+        });
+
+        // SAFETY: This macro has the exact same safety requirement as
+        // `__impl_device_context_deref!`.
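+        // This instantiation provides the `Core` => `Bound` step of the dereference hierarchy.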
+ ::kernel::__impl_device_context_deref!(unsafe { + $device, + $crate::device::Core => $crate::device::Bound + }); + + // SAFETY: This macro has the exact same safety requirement as + // `__impl_device_context_deref!`. + ::kernel::__impl_device_context_deref!(unsafe { + $device, + $crate::device::Bound => $crate::device::Normal + }); + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __impl_device_context_into_aref { + ($src:ty, $device:tt) => { + impl ::core::convert::From<&$device<$src>> for $crate::sync::aref::ARef<$device> { + fn from(dev: &$device<$src>) -> Self { + (&**dev).into() + } + } + }; +} + +/// Implement [`core::convert::From`], such that all `&Device<Ctx>` can be converted to an +/// `ARef<Device>`. +#[macro_export] +macro_rules! impl_device_context_into_aref { + ($device:tt) => { + ::kernel::__impl_device_context_into_aref!($crate::device::CoreInternal, $device); + ::kernel::__impl_device_context_into_aref!($crate::device::Core, $device); + ::kernel::__impl_device_context_into_aref!($crate::device::Bound, $device); + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! dev_printk { + ($method:ident, $dev:expr, $($f:tt)*) => { + { + ($dev).$method($crate::prelude::fmt!($($f)*)); + } + } +} + +/// Prints an emergency-level message (level 0) prefixed with device information. +/// +/// This level should be used if the system is unusable. +/// +/// Equivalent to the kernel's `dev_emerg` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_emerg!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_emerg { + ($($f:tt)*) => { $crate::dev_printk!(pr_emerg, $($f)*); } +} + +/// Prints an alert-level message (level 1) prefixed with device information. +/// +/// This level should be used if action must be taken immediately. +/// +/// Equivalent to the kernel's `dev_alert` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_alert!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_alert { + ($($f:tt)*) => { $crate::dev_printk!(pr_alert, $($f)*); } +} + +/// Prints a critical-level message (level 2) prefixed with device information. +/// +/// This level should be used in critical conditions. +/// +/// Equivalent to the kernel's `dev_crit` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_crit!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! 
dev_crit { + ($($f:tt)*) => { $crate::dev_printk!(pr_crit, $($f)*); } +} + +/// Prints an error-level message (level 3) prefixed with device information. +/// +/// This level should be used in error conditions. +/// +/// Equivalent to the kernel's `dev_err` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_err!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_err { + ($($f:tt)*) => { $crate::dev_printk!(pr_err, $($f)*); } +} + +/// Prints a warning-level message (level 4) prefixed with device information. +/// +/// This level should be used in warning conditions. +/// +/// Equivalent to the kernel's `dev_warn` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_warn!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_warn { + ($($f:tt)*) => { $crate::dev_printk!(pr_warn, $($f)*); } +} + +/// Prints a notice-level message (level 5) prefixed with device information. +/// +/// This level should be used in normal but significant conditions. +/// +/// Equivalent to the kernel's `dev_notice` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_notice!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_notice { + ($($f:tt)*) => { $crate::dev_printk!(pr_notice, $($f)*); } +} + +/// Prints an info-level message (level 6) prefixed with device information. +/// +/// This level should be used for informational messages. +/// +/// Equivalent to the kernel's `dev_info` macro. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. +/// +/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::Device; +/// +/// fn example(dev: &Device) { +/// dev_info!(dev, "hello {}\n", "there"); +/// } +/// ``` +#[macro_export] +macro_rules! dev_info { + ($($f:tt)*) => { $crate::dev_printk!(pr_info, $($f)*); } +} + +/// Prints a debug-level message (level 7) prefixed with device information. +/// +/// This level should be used for debug messages. +/// +/// Equivalent to the kernel's `dev_dbg` macro, except that it doesn't support dynamic debug yet. +/// +/// Mimics the interface of [`std::print!`]. More information about the syntax is available from +/// [`core::fmt`] and [`std::format!`]. 
+///
+/// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html
+/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::device::Device;
+///
+/// fn example(dev: &Device) {
+///     dev_dbg!(dev, "hello {}\n", "there");
+/// }
+/// ```
+#[macro_export]
+macro_rules! dev_dbg {
+    ($($f:tt)*) => { $crate::dev_printk!(pr_dbg, $($f)*); }
+}
diff --git a/rust/kernel/device/property.rs b/rust/kernel/device/property.rs
new file mode 100644
index 000000000000..3a332a8c53a9
--- /dev/null
+++ b/rust/kernel/device/property.rs
@@ -0,0 +1,632 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Unified device property interface.
+//!
+//! C header: [`include/linux/property.h`](srctree/include/linux/property.h)
+
+use core::{mem::MaybeUninit, ptr};
+
+use super::private::Sealed;
+use crate::{
+    alloc::KVec,
+    bindings,
+    error::{to_result, Result},
+    fmt,
+    prelude::*,
+    str::{CStr, CString},
+    types::{ARef, Opaque},
+};
+
+/// A reference-counted fwnode_handle.
+///
+/// This structure represents the Rust abstraction for a
+/// C `struct fwnode_handle`. This implementation abstracts the usage of an
+/// already existing C `struct fwnode_handle` within Rust code that we get
+/// passed from the C side.
+///
+/// # Invariants
+///
+/// A `FwNode` instance represents a valid `struct fwnode_handle` created by the
+/// C portion of the kernel.
+///
+/// Instances of this type are always reference-counted, that is, a call to
+/// `fwnode_handle_get` ensures that the allocation remains valid at least until
+/// the matching call to `fwnode_handle_put`.
+#[repr(transparent)]
+pub struct FwNode(Opaque<bindings::fwnode_handle>);
+
+impl FwNode {
+    /// # Safety
+    ///
+    /// Callers must ensure that:
+    /// - The reference count was incremented at least once.
+    /// - They relinquish that increment. That is, if there is only one
+    ///   increment, callers must not use the underlying object anymore -- it is
+    ///   only safe to do so via the newly created `ARef<FwNode>`.
+    unsafe fn from_raw(raw: *mut bindings::fwnode_handle) -> ARef<Self> {
+        // SAFETY: As per the safety requirements of this function:
+        // - `NonNull::new_unchecked`:
+        //   - `raw` is not null.
+        // - `ARef::from_raw`:
+        //   - `raw` has an incremented refcount.
+        //   - that increment is relinquished, i.e. it won't be decremented
+        //     elsewhere.
+        // CAST: It is safe to cast from a `*mut fwnode_handle` to
+        // `*mut FwNode`, because `FwNode` is defined as a
+        // `#[repr(transparent)]` wrapper around `fwnode_handle`.
+        unsafe { ARef::from_raw(ptr::NonNull::new_unchecked(raw.cast())) }
+    }
+
+    /// Obtain the raw `struct fwnode_handle *`.
+    pub(crate) fn as_raw(&self) -> *mut bindings::fwnode_handle {
+        self.0.get()
+    }
+
+    /// Returns `true` if `&self` is an OF node, `false` otherwise.
+    pub fn is_of_node(&self) -> bool {
+        // SAFETY: The type invariant of `Self` guarantees that `self.as_raw()` is a pointer to a
+        // valid `struct fwnode_handle`.
+        unsafe { bindings::is_of_node(self.as_raw()) }
+    }
+
+    /// Returns an object that implements [`Display`](fmt::Display) for
+    /// printing the name of a node.
+    ///
+    /// This is an alternative to the default `Display` implementation, which
+    /// prints the full path.
+    pub fn display_name(&self) -> impl fmt::Display + '_ {
+        struct FwNodeDisplayName<'a>(&'a FwNode);
+
+        impl fmt::Display for FwNodeDisplayName<'_> {
+            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                // SAFETY: `self` is valid by its type invariant.
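+                // `fwnode_get_name()` may return a null pointer; this is handled right below.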
+                let name = unsafe { bindings::fwnode_get_name(self.0.as_raw()) };
+                if name.is_null() {
+                    return Ok(());
+                }
+                // SAFETY:
+                // - `fwnode_get_name` returns null or a valid C string.
+                // - `name` was checked to be non-null.
+                let name = unsafe { CStr::from_char_ptr(name) };
+                fmt::Display::fmt(name, f)
+            }
+        }
+
+        FwNodeDisplayName(self)
+    }
+
+    /// Checks whether the property `name` is present.
+    pub fn property_present(&self, name: &CStr) -> bool {
+        // SAFETY: By the invariant of `CStr`, `name` is null-terminated.
+        unsafe { bindings::fwnode_property_present(self.as_raw().cast_const(), name.as_char_ptr()) }
+    }
+
+    /// Returns the boolean value of firmware property `name`.
+    pub fn property_read_bool(&self, name: &CStr) -> bool {
+        // SAFETY:
+        // - `name` is non-null and null-terminated.
+        // - `self.as_raw()` is valid because `self` is valid.
+        unsafe { bindings::fwnode_property_read_bool(self.as_raw(), name.as_char_ptr()) }
+    }
+
+    /// Returns the index of the string `match_str` within the firmware string
+    /// property `name`.
+    pub fn property_match_string(&self, name: &CStr, match_str: &CStr) -> Result<usize> {
+        // SAFETY:
+        // - `name` and `match_str` are non-null and null-terminated.
+        // - `self.as_raw` is valid because `self` is valid.
+        let ret = unsafe {
+            bindings::fwnode_property_match_string(
+                self.as_raw(),
+                name.as_char_ptr(),
+                match_str.as_char_ptr(),
+            )
+        };
+        to_result(ret)?;
+        Ok(ret as usize)
+    }
+
+    /// Returns the integer array values of firmware property `name` in a [`KVec`].
+    pub fn property_read_array_vec<'fwnode, 'name, T: PropertyInt>(
+        &'fwnode self,
+        name: &'name CStr,
+        len: usize,
+    ) -> Result<PropertyGuard<'fwnode, 'name, KVec<T>>> {
+        let mut val: KVec<T> = KVec::with_capacity(len, GFP_KERNEL)?;
+
+        let res = T::read_array_from_fwnode_property(self, name, val.spare_capacity_mut());
+        let res = match res {
+            Ok(_) => {
+                // SAFETY:
+                // - `len` is equal to `val.capacity - val.len`, because
+                //   `val.capacity` is `len` and `val.len` is zero.
+                // - All elements within the interval [`0`, `len`) were initialized
+                //   by `read_array_from_fwnode_property`.
+                unsafe { val.inc_len(len) };
+                Ok(val)
+            }
+            Err(e) => Err(e),
+        };
+        Ok(PropertyGuard {
+            inner: res,
+            fwnode: self,
+            name,
+        })
+    }
+
+    /// Returns the length of the integer array of firmware property `name`.
+    pub fn property_count_elem<T: PropertyInt>(&self, name: &CStr) -> Result<usize> {
+        T::read_array_len_from_fwnode_property(self, name)
+    }
+
+    /// Returns the value of firmware property `name`.
+    ///
+    /// This method is generic over the type of value to read. The types that
+    /// can be read are strings, integers and arrays of integers.
+    ///
+    /// Reading a [`KVec`] of integers is done with the separate
+    /// method [`Self::property_read_array_vec`], because it takes an
+    /// additional `len` argument.
+    ///
+    /// Reading a boolean is done with the separate method
+    /// [`Self::property_read_bool`], because this operation is infallible.
+    ///
+    /// For more precise documentation about what types can be read, see
+    /// the [implementors of Property][Property#implementors] and [its
+    /// implementations on foreign types][Property#foreign-impls].
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use kernel::{c_str, device::{Device, property::FwNode}, str::CString};
+    /// fn examples(dev: &Device) -> Result {
+    ///     let fwnode = dev.fwnode().ok_or(ENOENT)?;
+    ///     let b: u32 = fwnode.property_read(c_str!("some-number")).required_by(dev)?;
+    ///     if let Some(s) = fwnode.property_read::<CString>(c_str!("some-str")).optional() {
+    ///         // ...
+    ///     }
+    ///     Ok(())
+    /// }
+    /// ```
+    pub fn property_read<'fwnode, 'name, T: Property>(
+        &'fwnode self,
+        name: &'name CStr,
+    ) -> PropertyGuard<'fwnode, 'name, T> {
+        PropertyGuard {
+            inner: T::read_from_fwnode_property(self, name),
+            fwnode: self,
+            name,
+        }
+    }
+
+    /// Returns the first matching named child node handle.
+    pub fn get_child_by_name(&self, name: &CStr) -> Option<ARef<Self>> {
+        // SAFETY: `self` and `name` are valid by their type invariants.
+        let child =
+            unsafe { bindings::fwnode_get_named_child_node(self.as_raw(), name.as_char_ptr()) };
+        if child.is_null() {
+            return None;
+        }
+        // SAFETY:
+        // - `fwnode_get_named_child_node` returns a pointer with its refcount
+        //   incremented.
+        // - That increment is relinquished, i.e. the underlying object is not
+        //   used anymore except via the newly created `ARef`.
+        Some(unsafe { Self::from_raw(child) })
+    }
+
+    /// Returns an iterator over a node's children.
+    pub fn children<'a>(&'a self) -> impl Iterator<Item = ARef<FwNode>> + 'a {
+        let mut prev: Option<ARef<FwNode>> = None;
+
+        core::iter::from_fn(move || {
+            let prev_ptr = match prev.take() {
+                None => ptr::null_mut(),
+                Some(prev) => {
+                    // We will pass `prev` to `fwnode_get_next_child_node`,
+                    // which decrements its refcount, so we use
+                    // `ARef::into_raw` to avoid decrementing the refcount
+                    // twice.
+                    let prev = ARef::into_raw(prev);
+                    prev.as_ptr().cast()
+                }
+            };
+            // SAFETY:
+            // - `self.as_raw()` is valid by its type invariant.
+            // - `prev_ptr` may be null, which is allowed and corresponds to
+            //   getting the first child. Otherwise, `prev_ptr` is valid, as it
+            //   is the stored return value from the previous invocation.
+            // - `prev_ptr` has its refcount incremented.
+            // - The increment of `prev_ptr` is relinquished, i.e. the
+            //   underlying object won't be used anymore.
+            let next = unsafe { bindings::fwnode_get_next_child_node(self.as_raw(), prev_ptr) };
+            if next.is_null() {
+                return None;
+            }
+            // SAFETY:
+            // - `next` is valid because `fwnode_get_next_child_node` returns a
+            //   pointer with its refcount incremented.
+            // - That increment is relinquished, i.e. the underlying object
+            //   won't be used anymore, except via the newly created
+            //   `ARef<Self>`.
+            let next = unsafe { FwNode::from_raw(next) };
+            prev = Some(next.clone());
+            Some(next)
+        })
+    }
+
+    /// Finds a reference with arguments.
+    pub fn property_get_reference_args(
+        &self,
+        prop: &CStr,
+        nargs: NArgs<'_>,
+        index: u32,
+    ) -> Result<FwNodeReferenceArgs> {
+        let mut out_args = FwNodeReferenceArgs::default();
+
+        let (nargs_prop, nargs) = match nargs {
+            NArgs::Prop(nargs_prop) => (nargs_prop.as_char_ptr(), 0),
+            NArgs::N(nargs) => (ptr::null(), nargs),
+        };
+
+        // SAFETY:
+        // - `self.0.get()` is valid.
+        // - `prop.as_char_ptr()` is valid and zero-terminated.
+        // - `nargs_prop` is valid and zero-terminated if `nargs`
+        //   is zero, otherwise it is allowed to be a null-pointer.
+        // - The function upholds the type invariants of `out_args`,
+        //   namely:
+        //   - It may fill the field `fwnode` with a valid pointer,
+        //     in which case its refcount is incremented.
+ // - It may modify the field `nargs`, in which case it + // initializes at least as many elements in `args`. + let ret = unsafe { + bindings::fwnode_property_get_reference_args( + self.0.get(), + prop.as_char_ptr(), + nargs_prop, + nargs, + index, + &mut out_args.0, + ) + }; + to_result(ret)?; + + Ok(out_args) + } +} + +/// The number of arguments to request [`FwNodeReferenceArgs`]. +pub enum NArgs<'a> { + /// The name of the property of the reference indicating the number of + /// arguments. + Prop(&'a CStr), + /// The known number of arguments. + N(u32), +} + +/// The return value of [`FwNode::property_get_reference_args`]. +/// +/// This structure represents the Rust abstraction for a C +/// `struct fwnode_reference_args` which was initialized by the C side. +/// +/// # Invariants +/// +/// If the field `fwnode` is valid, it owns an increment of its refcount. +/// +/// The field `args` contains at least as many initialized elements as indicated +/// by the field `nargs`. +#[repr(transparent)] +#[derive(Default)] +pub struct FwNodeReferenceArgs(bindings::fwnode_reference_args); + +impl Drop for FwNodeReferenceArgs { + fn drop(&mut self) { + if !self.0.fwnode.is_null() { + // SAFETY: + // - By the type invariants of `FwNodeReferenceArgs`, its field + // `fwnode` owns an increment of its refcount. + // - That increment is relinquished. The underlying object won't be + // used anymore because we are dropping it. + let _ = unsafe { FwNode::from_raw(self.0.fwnode) }; + } + } +} + +impl FwNodeReferenceArgs { + /// Returns the slice of reference arguments. + pub fn as_slice(&self) -> &[u64] { + // SAFETY: As per the safety invariant of `FwNodeReferenceArgs`, `nargs` + // is the minimum number of elements in `args` that is valid. + unsafe { core::slice::from_raw_parts(self.0.args.as_ptr(), self.0.nargs as usize) } + } + + /// Returns the number of reference arguments. + pub fn len(&self) -> usize { + self.0.nargs as usize + } + + /// Returns `true` if there are no reference arguments. + pub fn is_empty(&self) -> bool { + self.0.nargs == 0 + } +} + +impl fmt::Debug for FwNodeReferenceArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.as_slice()) + } +} + +// SAFETY: Instances of `FwNode` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for FwNode { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the + // refcount is non-zero. + unsafe { bindings::fwnode_handle_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is + // non-zero. + unsafe { bindings::fwnode_handle_put(obj.cast().as_ptr()) } + } +} + +enum Node<'a> { + Borrowed(&'a FwNode), + Owned(ARef<FwNode>), +} + +impl fmt::Display for FwNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // The logic here is the same as the one in lib/vsprintf.c + // (fwnode_full_name_string). + + // SAFETY: `self.as_raw()` is valid by its type invariant. + let num_parents = unsafe { bindings::fwnode_count_parents(self.as_raw()) }; + + for depth in (0..=num_parents).rev() { + let fwnode = if depth == 0 { + Node::Borrowed(self) + } else { + // SAFETY: `self.as_raw()` is valid. + let ptr = unsafe { bindings::fwnode_get_nth_parent(self.as_raw(), depth) }; + // SAFETY: + // - The depth passed to `fwnode_get_nth_parent` is + // within the valid range, so the returned pointer is + // not null. 
+ // - The reference count was incremented by + // `fwnode_get_nth_parent`. + // - That increment is relinquished to + // `FwNode::from_raw`. + Node::Owned(unsafe { FwNode::from_raw(ptr) }) + }; + // Take a reference to the owned or borrowed `FwNode`. + let fwnode: &FwNode = match &fwnode { + Node::Borrowed(f) => f, + Node::Owned(f) => f, + }; + + // SAFETY: `fwnode` is valid by its type invariant. + let prefix = unsafe { bindings::fwnode_get_name_prefix(fwnode.as_raw()) }; + if !prefix.is_null() { + // SAFETY: `fwnode_get_name_prefix` returns null or a + // valid C string. + let prefix = unsafe { CStr::from_char_ptr(prefix) }; + fmt::Display::fmt(prefix, f)?; + } + fmt::Display::fmt(&fwnode.display_name(), f)?; + } + + Ok(()) + } +} + +/// Implemented for types that can be read as properties. +/// +/// This is implemented for strings, integers and arrays of integers. It's used +/// to make [`FwNode::property_read`] generic over the type of property being +/// read. There are also two dedicated methods to read other types, because they +/// require more specialized function signatures: +/// - [`property_read_bool`](FwNode::property_read_bool) +/// - [`property_read_array_vec`](FwNode::property_read_array_vec) +/// +/// It must be public, because it appears in the signatures of other public +/// functions, but its methods shouldn't be used outside the kernel crate. +pub trait Property: Sized + Sealed { + /// Used to make [`FwNode::property_read`] generic. + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self>; +} + +impl Sealed for CString {} + +impl Property for CString { + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> { + let mut str: *mut u8 = ptr::null_mut(); + let pstr: *mut _ = &mut str; + + // SAFETY: + // - `name` is non-null and null-terminated. + // - `fwnode.as_raw` is valid because `fwnode` is valid. + let ret = unsafe { + bindings::fwnode_property_read_string(fwnode.as_raw(), name.as_char_ptr(), pstr.cast()) + }; + to_result(ret)?; + + // SAFETY: + // - `pstr` is a valid pointer to a NUL-terminated C string. + // - It is valid for at least as long as `fwnode`, but it's only used + // within the current function. + // - The memory it points to is not mutated during that time. + let str = unsafe { CStr::from_char_ptr(*pstr) }; + Ok(str.try_into()?) + } +} + +/// Implemented for all integers that can be read as properties. +/// +/// This helper trait is needed on top of the existing [`Property`] +/// trait to associate the integer types of various sizes with their +/// corresponding `fwnode_property_read_*_array` functions. +/// +/// It must be public, because it appears in the signatures of other public +/// functions, but its methods shouldn't be used outside the kernel crate. +pub trait PropertyInt: Copy + Sealed { + /// Reads a property array. + fn read_array_from_fwnode_property<'a>( + fwnode: &FwNode, + name: &CStr, + out: &'a mut [MaybeUninit<Self>], + ) -> Result<&'a mut [Self]>; + + /// Reads the length of a property array. + fn read_array_len_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<usize>; +} +// This macro generates implementations of the traits `Property` and +// `PropertyInt` for integers of various sizes. Its input is a list +// of pairs separated by commas. The first element of the pair is the +// type of the integer, the second one is the name of its corresponding +// `fwnode_property_read_*_array` function. +macro_rules! impl_property_for_int { + ($($int:ty: $f:ident),* $(,)?) 
=> { $(
+    impl Sealed for $int {}
+    impl<const N: usize> Sealed for [$int; N] {}
+
+    impl PropertyInt for $int {
+        fn read_array_from_fwnode_property<'a>(
+            fwnode: &FwNode,
+            name: &CStr,
+            out: &'a mut [MaybeUninit<Self>],
+        ) -> Result<&'a mut [Self]> {
+            // SAFETY:
+            // - `fwnode`, `name` and `out` are all valid by their type
+            //   invariants.
+            // - `out.len()` is a valid bound for the memory pointed to by
+            //   `out.as_mut_ptr()`.
+            // CAST: It's ok to cast from `*mut MaybeUninit<$int>` to a
+            // `*mut $int` because they have the same memory layout.
+            let ret = unsafe {
+                bindings::$f(
+                    fwnode.as_raw(),
+                    name.as_char_ptr(),
+                    out.as_mut_ptr().cast(),
+                    out.len(),
+                )
+            };
+            to_result(ret)?;
+            // SAFETY: Transmuting from `&'a mut [MaybeUninit<Self>]` to
+            // `&'a mut [Self]` is sound, because the previous call to a
+            // `fwnode_property_read_*_array` function (which didn't fail)
+            // fully initialized the slice.
+            Ok(unsafe { core::mem::transmute::<&mut [MaybeUninit<Self>], &mut [Self]>(out) })
+        }
+
+        fn read_array_len_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<usize> {
+            // SAFETY:
+            // - `fwnode` and `name` are valid by their type invariants.
+            // - It's ok to pass a null pointer to the
+            //   `fwnode_property_read_*_array` functions if `nval` is zero.
+            //   This will return the length of the array.
+            let ret = unsafe {
+                bindings::$f(
+                    fwnode.as_raw(),
+                    name.as_char_ptr(),
+                    ptr::null_mut(),
+                    0,
+                )
+            };
+            to_result(ret)?;
+            Ok(ret as usize)
+        }
+    }
+
+    impl Property for $int {
+        fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> {
+            let val: [_; 1] = <[$int; 1]>::read_from_fwnode_property(fwnode, name)?;
+            Ok(val[0])
+        }
+    }
+
+    impl<const N: usize> Property for [$int; N] {
+        fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> {
+            let mut val: [MaybeUninit<$int>; N] = [const { MaybeUninit::uninit() }; N];
+
+            <$int>::read_array_from_fwnode_property(fwnode, name, &mut val)?;
+
+            // SAFETY: `val` is always initialized when
+            // `fwnode_property_read_*_array` is successful.
+            Ok(val.map(|v| unsafe { v.assume_init() }))
+        }
+    }
+    )* };
+}
+impl_property_for_int! {
+    u8: fwnode_property_read_u8_array,
+    u16: fwnode_property_read_u16_array,
+    u32: fwnode_property_read_u32_array,
+    u64: fwnode_property_read_u64_array,
+    i8: fwnode_property_read_u8_array,
+    i16: fwnode_property_read_u16_array,
+    i32: fwnode_property_read_u32_array,
+    i64: fwnode_property_read_u64_array,
+}
+
+/// A helper for reading device properties.
+///
+/// Use [`Self::required_by`] if a missing property is considered a bug and
+/// [`Self::optional`] otherwise.
+///
+/// For convenience, [`Self::or`] and [`Self::or_default`] are provided.
+pub struct PropertyGuard<'fwnode, 'name, T> {
+    /// The result of reading the property.
+    inner: Result<T>,
+    /// The fwnode of the property, used for logging in the "required" case.
+    fwnode: &'fwnode FwNode,
+    /// The name of the property, used for logging in the "required" case.
+    name: &'name CStr,
+}
+
+impl<T> PropertyGuard<'_, '_, T> {
+    /// Access the property, indicating it is required.
+    ///
+    /// If the property is not present, the error is automatically logged. If a
+    /// missing property is not an error, use [`Self::optional`] instead. The
+    /// device is required so that the log message can be associated with it.
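+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch; the property name `clock-frequency` and the function are illustrative:
+    ///
+    /// ```
+    /// # use kernel::{c_str, device::Device};
+    /// fn read_freq(dev: &Device) -> Result<u32> {
+    ///     let fwnode = dev.fwnode().ok_or(ENOENT)?;
+    ///     // Logs an error against `dev` if the property is missing.
+    ///     fwnode.property_read::<u32>(c_str!("clock-frequency")).required_by(dev)
+    /// }
+    /// ```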
+    pub fn required_by(self, dev: &super::Device) -> Result<T> {
+        if self.inner.is_err() {
+            dev_err!(
+                dev,
+                "{}: property '{}' is missing\n",
+                self.fwnode,
+                self.name
+            );
+        }
+        self.inner
+    }
+
+    /// Access the property, indicating it is optional.
+    ///
+    /// In contrast to [`Self::required_by`], no error message is logged if
+    /// the property is not present.
+    pub fn optional(self) -> Option<T> {
+        self.inner.ok()
+    }
+
+    /// Access the property or the specified default value.
+    ///
+    /// Do not pass a sentinel value as default to detect a missing property.
+    /// Use [`Self::required_by`] or [`Self::optional`] instead.
+    pub fn or(self, default: T) -> T {
+        self.inner.unwrap_or(default)
+    }
+}
+
+impl<T: Default> PropertyGuard<'_, '_, T> {
+    /// Access the property or a default value.
+    ///
+    /// Use [`Self::or`] to specify a custom default value.
+    pub fn or_default(self) -> T {
+        self.inner.unwrap_or_default()
+    }
+}
diff --git a/rust/kernel/device_id.rs b/rust/kernel/device_id.rs
new file mode 100644
index 000000000000..62c42da12e9d
--- /dev/null
+++ b/rust/kernel/device_id.rs
@@ -0,0 +1,206 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Generic implementation of device IDs.
+//!
+//! Each bus / subsystem that matches device and driver through a bus / subsystem specific ID is
+//! expected to implement [`RawDeviceId`].
+
+use core::mem::MaybeUninit;
+
+/// Marker trait to indicate a Rust device ID type represents a corresponding C device ID type.
+///
+/// This is meant to be implemented by buses/subsystems so that they can use [`IdTable`] to
+/// guarantee (at compile-time) zero-termination of device id tables provided by drivers.
+///
+/// # Safety
+///
+/// Implementers must ensure that `Self` is layout-compatible with [`RawDeviceId::RawType`];
+/// i.e. it's safe to transmute `Self` to `RawType`.
+///
+/// This requirement is needed so `IdArray::new` can convert `Self` to `RawType` when building
+/// the ID table.
+///
+/// Ideally, this should be achieved using a const function that does conversion instead of
+/// transmute; however, const trait functions rely on the `const_trait_impl` unstable feature,
+/// which is broken/gone in Rust 1.73.
+pub unsafe trait RawDeviceId {
+    /// The raw type that holds the device id.
+    ///
+    /// Id tables created from [`Self`] are going to hold this type in its zero-terminated array.
+    type RawType: Copy;
+}
+
+/// Extension trait for [`RawDeviceId`] for devices that embed an index or context value.
+///
+/// This is typically used when the device ID struct includes a field like `driver_data`
+/// that is used to store a pointer-sized value (e.g., an index or context pointer).
+///
+/// # Safety
+///
+/// Implementers must ensure that `DRIVER_DATA_OFFSET` is the correct offset (in bytes) to
+/// the context/data field (e.g., the `driver_data` field) within the raw device ID structure.
+/// This field must be correctly sized to hold a `usize`.
+///
+/// Ideally, the data should be added during `Self` to `RawType` conversion,
+/// but there's currently no way to do it when using traits in const.
+pub unsafe trait RawDeviceIdIndex: RawDeviceId {
+    /// The offset (in bytes) to the context/data field in the raw device ID.
+    const DRIVER_DATA_OFFSET: usize;
+
+    /// The index stored at `DRIVER_DATA_OFFSET` of the implementor of the [`RawDeviceIdIndex`]
+    /// trait.
+    fn index(&self) -> usize;
+}
+
+/// A zero-terminated device id array.
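+///
+/// The sentinel is an all-zeroes entry appended after the `N` real IDs, matching the
+/// zero-terminated table layout that the C device model expects.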
+#[repr(C)]
+pub struct RawIdArray<T: RawDeviceId, const N: usize> {
+    ids: [T::RawType; N],
+    sentinel: MaybeUninit<T::RawType>,
+}
+
+impl<T: RawDeviceId, const N: usize> RawIdArray<T, N> {
+    #[doc(hidden)]
+    pub const fn size(&self) -> usize {
+        core::mem::size_of::<Self>()
+    }
+}
+
+/// A zero-terminated device id array, followed by context data.
+#[repr(C)]
+pub struct IdArray<T: RawDeviceId, U, const N: usize> {
+    raw_ids: RawIdArray<T, N>,
+    id_infos: [U; N],
+}
+
+impl<T: RawDeviceId, U, const N: usize> IdArray<T, U, N> {
+    /// Creates a new instance of the array.
+    ///
+    /// The contents are derived from the given identifiers and context information.
+    ///
+    /// # Safety
+    ///
+    /// Passing `None` as `data_offset` is always safe.
+    /// If `data_offset` is `Some(data_offset)`, then:
+    /// - `data_offset` must be the correct offset (in bytes) to the context/data field
+    ///   (e.g., the `driver_data` field) within the raw device ID structure.
+    /// - The field at `data_offset` must be correctly sized to hold a `usize`.
+    const unsafe fn build(ids: [(T, U); N], data_offset: Option<usize>) -> Self {
+        let mut raw_ids = [const { MaybeUninit::<T::RawType>::uninit() }; N];
+        let mut infos = [const { MaybeUninit::uninit() }; N];
+
+        let mut i = 0usize;
+        while i < N {
+            // SAFETY: by the safety requirement of `RawDeviceId`, we're guaranteed that `T` is
+            // layout-wise compatible with `RawType`.
+            raw_ids[i] = unsafe { core::mem::transmute_copy(&ids[i].0) };
+            if let Some(data_offset) = data_offset {
+                // SAFETY: by the safety requirement of this function, this would be effectively
+                // `raw_ids[i].driver_data = i;`.
+                unsafe {
+                    raw_ids[i]
+                        .as_mut_ptr()
+                        .byte_add(data_offset)
+                        .cast::<usize>()
+                        .write(i);
+                }
+            }
+
+            // SAFETY: this is effectively a move: `infos[i] = ids[i].1`. We make a copy here but
+            // later forget `ids`.
+            infos[i] = MaybeUninit::new(unsafe { core::ptr::read(&ids[i].1) });
+            i += 1;
+        }
+
+        core::mem::forget(ids);
+
+        Self {
+            raw_ids: RawIdArray {
+                // SAFETY: this is effectively `array_assume_init`, which is unstable, so we use
+                // `transmute_copy` instead. We have initialized all elements of `raw_ids` so this
+                // `array_assume_init` is safe.
+                ids: unsafe { core::mem::transmute_copy(&raw_ids) },
+                sentinel: MaybeUninit::zeroed(),
+            },
+            // SAFETY: We have initialized all elements of `infos` so this `array_assume_init` is
+            // safe.
+            id_infos: unsafe { core::mem::transmute_copy(&infos) },
+        }
+    }
+
+    /// Creates a new instance of the array without writing index values.
+    ///
+    /// The contents are derived from the given identifiers and context information.
+    /// If the device implements [`RawDeviceIdIndex`], consider using [`IdArray::new`] instead.
+    pub const fn new_without_index(ids: [(T, U); N]) -> Self {
+        // SAFETY: Calling `Self::build` with `offset = None` is always safe,
+        // because no raw memory writes are performed in this case.
+        unsafe { Self::build(ids, None) }
+    }
+
+    /// Reference to the contained [`RawIdArray`].
+    pub const fn raw_ids(&self) -> &RawIdArray<T, N> {
+        &self.raw_ids
+    }
+}
+
+impl<T: RawDeviceId + RawDeviceIdIndex, U, const N: usize> IdArray<T, U, N> {
+    /// Creates a new instance of the array.
+    ///
+    /// The contents are derived from the given identifiers and context information.
+    pub const fn new(ids: [(T, U); N]) -> Self {
+        // SAFETY: by the safety requirement of `RawDeviceIdIndex`,
+        // `T::DRIVER_DATA_OFFSET` is guaranteed to be the correct offset (in bytes) to
+        // a field within `T::RawType`.
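+        // `RawDeviceIdIndex` also guarantees that this field is correctly sized to hold a
+        // `usize`, so `Self::build` may write the entry's index at that offset.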
+        unsafe { Self::build(ids, Some(T::DRIVER_DATA_OFFSET)) }
+    }
+}
+
+/// A device id table.
+///
+/// This trait is only implemented by `IdArray`.
+///
+/// The purpose of this trait is to allow `&'static dyn IdTable<T, U>` to be used when the `N` in
+/// `IdArray` doesn't matter.
+pub trait IdTable<T: RawDeviceId, U> {
+    /// Obtain the pointer to the ID table.
+    fn as_ptr(&self) -> *const T::RawType;
+
+    /// Obtain a reference to the bus specific device ID at `index`.
+    fn id(&self, index: usize) -> &T::RawType;
+
+    /// Obtain a reference to the driver-specific information at `index`.
+    fn info(&self, index: usize) -> &U;
+}
+
+impl<T: RawDeviceId, U, const N: usize> IdTable<T, U> for IdArray<T, U, N> {
+    fn as_ptr(&self) -> *const T::RawType {
+        // This cannot be `self.ids.as_ptr()`, as the return pointer must have correct provenance
+        // to access the sentinel.
+        core::ptr::from_ref(self).cast()
+    }
+
+    fn id(&self, index: usize) -> &T::RawType {
+        &self.raw_ids.ids[index]
+    }
+
+    fn info(&self, index: usize) -> &U {
+        &self.id_infos[index]
+    }
+}
+
+/// Create device table alias for modpost.
+#[macro_export]
+macro_rules! module_device_table {
+    ($table_type: literal, $module_table_name:ident, $table_name:ident) => {
+        #[rustfmt::skip]
+        #[export_name =
+            concat!("__mod_device_table__", line!(),
+                    "__kmod_", module_path!(),
+                    "__", $table_type,
+                    "__", stringify!($table_name))
+        ]
+        static $module_table_name: [::core::mem::MaybeUninit<u8>; $table_name.raw_ids().size()] =
+            unsafe { ::core::mem::transmute_copy($table_name.raw_ids()) };
+    };
+}
diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs
new file mode 100644
index 000000000000..835d9c11948e
--- /dev/null
+++ b/rust/kernel/devres.rs
@@ -0,0 +1,379 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Devres abstraction
+//!
+//! [`Devres`] represents an abstraction for the kernel devres (device resource management)
+//! implementation.
+
+use crate::{
+    alloc::Flags,
+    bindings,
+    device::{Bound, Device},
+    error::{to_result, Error, Result},
+    ffi::c_void,
+    prelude::*,
+    revocable::{Revocable, RevocableGuard},
+    sync::{aref::ARef, rcu, Completion},
+    types::{ForeignOwnable, Opaque, ScopeGuard},
+};
+
+use pin_init::Wrapper;
+
+/// [`Devres`] inner data accessed from [`Devres::callback`].
+#[pin_data]
+struct Inner<T: Send> {
+    #[pin]
+    data: Revocable<T>,
+    /// Tracks whether [`Devres::callback`] has been completed.
+    #[pin]
+    devm: Completion,
+    /// Tracks whether revoking [`Self::data`] has been completed.
+    #[pin]
+    revoke: Completion,
+}
+
+/// This abstraction is meant to be used by subsystems to containerize [`Device`] bound resources to
+/// manage their lifetime.
+///
+/// [`Device`] bound resources should be freed when either the resource goes out of scope or the
+/// [`Device`] is unbound respectively, depending on what happens first. In any case, it is always
+/// guaranteed that revoking the device resource is completed before the corresponding [`Device`]
+/// is unbound.
+///
+/// To achieve that [`Devres`] registers a devres callback on creation, which is called once the
+/// [`Device`] is unbound, revoking access to the encapsulated resource (see also [`Revocable`]).
+///
+/// After the [`Device`] has been unbound it is not possible to access the encapsulated resource
+/// anymore.
+///
+/// [`Devres`] users should make sure to simply free the corresponding backing resource in `T`'s
+/// [`Drop`] implementation.
+///
+/// # Examples
+///
+/// ```no_run
+/// use kernel::{
+///     bindings,
+///     device::{
+///         Bound,
+///         Device,
+///     },
+///     devres::Devres,
+///     io::{
+///         Io,
+///         IoRaw,
+///         PhysAddr,
+///     },
+/// };
+/// use core::{ffi::c_void, ops::Deref};
+///
+/// // See also [`pci::Bar`] for a real example.
+/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>);
+///
+/// impl<const SIZE: usize> IoMem<SIZE> {
+///     /// # Safety
+///     ///
+///     /// [`paddr`, `paddr` + `SIZE`) must be a valid MMIO region that is mappable into the CPUs
+///     /// virtual address space.
+///     unsafe fn new(paddr: usize) -> Result<Self> {
+///         // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is
+///         // valid for `ioremap`.
+///         let addr = unsafe { bindings::ioremap(paddr as PhysAddr, SIZE) };
+///         if addr.is_null() {
+///             return Err(ENOMEM);
+///         }
+///
+///         Ok(IoMem(IoRaw::new(addr as usize, SIZE)?))
+///     }
+/// }
+///
+/// impl<const SIZE: usize> Drop for IoMem<SIZE> {
+///     fn drop(&mut self) {
+///         // SAFETY: `self.0.addr()` is guaranteed to be properly mapped by `Self::new`.
+///         unsafe { bindings::iounmap(self.0.addr() as *mut c_void); };
+///     }
+/// }
+///
+/// impl<const SIZE: usize> Deref for IoMem<SIZE> {
+///     type Target = Io<SIZE>;
+///
+///     fn deref(&self) -> &Self::Target {
+///         // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`.
+///         unsafe { Io::from_raw(&self.0) }
+///     }
+/// }
+/// # fn no_run(dev: &Device<Bound>) -> Result<(), Error> {
+/// // SAFETY: Invalid usage for example purposes.
+/// let iomem = unsafe { IoMem::<{ core::mem::size_of::<u32>() }>::new(0xBAAAAAAD)? };
+/// let devres = KBox::pin_init(Devres::new(dev, iomem), GFP_KERNEL)?;
+///
+/// let res = devres.try_access().ok_or(ENXIO)?;
+/// res.write8(0x42, 0x0);
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Invariants
+///
+/// `Self::inner` is guaranteed to be initialized and is always accessed read-only.
+#[pin_data(PinnedDrop)]
+pub struct Devres<T: Send> {
+    dev: ARef<Device>,
+    /// Pointer to [`Self::devres_callback`].
+    ///
+    /// Has to be stored, since Rust does not guarantee to always return the same address for a
+    /// function. However, the C API uses the address as a key.
+    callback: unsafe extern "C" fn(*mut c_void),
+    /// Contains all the fields shared with [`Self::callback`].
+    // TODO: Replace with `UnsafePinned`, once available.
+    //
+    // Subsequently, the `drop_in_place()` in `Devres::drop` and `Devres::new` as well as the
+    // explicit `Send` and `Sync` impls can be removed.
+    #[pin]
+    inner: Opaque<Inner<T>>,
+    _add_action: (),
+}
+
+impl<T: Send> Devres<T> {
+    /// Creates a new [`Devres`] instance of the given `data`.
+    ///
+    /// The `data` encapsulated within the returned [`Devres`] instance will be revoked (see
+    /// [`Revocable`]) once the device is detached.
+    pub fn new<'a, E>(
+        dev: &'a Device<Bound>,
+        data: impl PinInit<T, E> + 'a,
+    ) -> impl PinInit<Self, Error> + 'a
+    where
+        T: 'a,
+        Error: From<E>,
+    {
+        try_pin_init!(&this in Self {
+            dev: dev.into(),
+            callback: Self::devres_callback,
+            // INVARIANT: `inner` is properly initialized.
+            inner <- Opaque::pin_init(try_pin_init!(Inner {
+                devm <- Completion::new(),
+                revoke <- Completion::new(),
+                data <- Revocable::new(data),
+            })),
+            // TODO: Replace with "initializer code blocks" [1] once available.
+            //
+            // [1] https://github.com/Rust-for-Linux/pin-init/pull/69
+            _add_action: {
+                // SAFETY: `this` is a valid pointer to uninitialized memory.
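+                // Projecting to the `inner` field therefore also yields a valid (still
+                // uninitialized) pointer within the same allocation.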
+                let inner = unsafe { &raw mut (*this.as_ptr()).inner };
+
+                // SAFETY:
+                // - `dev.as_raw()` is a pointer to a valid bound device.
+                // - `inner` is guaranteed to be valid for the duration of the lifetime of `Self`.
+                // - `devm_add_action()` is guaranteed not to call `callback` until `this` has been
+                //   properly initialized, because we require `dev` (i.e. the *bound* device) to
+                //   live at least as long as the returned `impl PinInit<Self, Error>`.
+                to_result(unsafe {
+                    bindings::devm_add_action(dev.as_raw(), Some(*callback), inner.cast())
+                }).inspect_err(|_| {
+                    let inner = Opaque::cast_into(inner);
+
+                    // SAFETY: `inner` is a valid pointer to an `Inner<T>` and valid for both reads
+                    // and writes.
+                    unsafe { core::ptr::drop_in_place(inner) };
+                })?;
+            },
+        })
+    }
+
+    fn inner(&self) -> &Inner<T> {
+        // SAFETY: By the type invariants of `Self`, `inner` is properly initialized and always
+        // accessed read-only.
+        unsafe { &*self.inner.get() }
+    }
+
+    fn data(&self) -> &Revocable<T> {
+        &self.inner().data
+    }
+
+    #[allow(clippy::missing_safety_doc)]
+    unsafe extern "C" fn devres_callback(ptr: *mut kernel::ffi::c_void) {
+        // SAFETY: In `Self::new` we've passed a valid pointer to `Inner` to `devm_add_action()`,
+        // hence `ptr` must be a valid pointer to `Inner`.
+        let inner = unsafe { &*ptr.cast::<Inner<T>>() };
+
+        // Ensure that `inner` can't be used anymore after we signal completion of this callback.
+        let inner = ScopeGuard::new_with_data(inner, |inner| inner.devm.complete_all());
+
+        if !inner.data.revoke() {
+            // If `revoke()` returns false, it means that `Devres::drop` already started revoking
+            // `data` for us. Hence we have to wait until `Devres::drop` signals that it
+            // completed revoking `data`.
+            inner.revoke.wait_for_completion();
+        }
+    }
+
+    fn remove_action(&self) -> bool {
+        // SAFETY:
+        // - `self.dev` is a valid `Device`,
+        // - the `action` and `data` pointers are the exact same ones as given to
+        //   `devm_add_action()` previously,
+        (unsafe {
+            bindings::devm_remove_action_nowarn(
+                self.dev.as_raw(),
+                Some(self.callback),
+                core::ptr::from_ref(self.inner()).cast_mut().cast(),
+            )
+        } == 0)
+    }
+
+    /// Returns a reference to the [`Device`] this [`Devres`] instance has been created with.
+    pub fn device(&self) -> &Device {
+        &self.dev
+    }
+
+    /// Obtain `&'a T`, bypassing the [`Revocable`].
+    ///
+    /// This method allows one to directly obtain a `&'a T`, bypassing the [`Revocable`], by
+    /// presenting a `&'a Device<Bound>` of the same [`Device`] this [`Devres`] instance has been
+    /// created with.
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if `dev` does not match the same [`Device`] this [`Devres`] instance
+    /// has been created with.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # #![cfg(CONFIG_PCI)]
+    /// # use kernel::{device::Core, devres::Devres, pci};
+    ///
+    /// fn from_core(dev: &pci::Device<Core>, devres: Devres<pci::Bar<0x4>>) -> Result {
+    ///     let bar = devres.access(dev.as_ref())?;
+    ///
+    ///     let _ = bar.read32(0x0);
+    ///
+    ///     // might_sleep()
+    ///
+    ///     bar.write32(0x42, 0x0);
+    ///
+    ///     Ok(())
+    /// }
+    /// ```
+    pub fn access<'a>(&'a self, dev: &'a Device<Bound>) -> Result<&'a T> {
+        if self.dev.as_raw() != dev.as_raw() {
+            return Err(EINVAL);
+        }
+
+        // SAFETY: `dev` being the same device as the device this `Devres` has been created for
+        // proves that `self.data` hasn't been revoked and is guaranteed to not be revoked as long
+        // as `dev` lives; `dev` lives at least as long as `self`.
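+        // Hence the preconditions of `Revocable::access` hold for the entire lifetime `'a`.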
+        Ok(unsafe { self.data().access() })
+    }
+
+    /// [`Devres`] accessor for [`Revocable::try_access`].
+    pub fn try_access(&self) -> Option<RevocableGuard<'_, T>> {
+        self.data().try_access()
+    }
+
+    /// [`Devres`] accessor for [`Revocable::try_access_with`].
+    pub fn try_access_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Option<R> {
+        self.data().try_access_with(f)
+    }
+
+    /// [`Devres`] accessor for [`Revocable::try_access_with_guard`].
+    pub fn try_access_with_guard<'a>(&'a self, guard: &'a rcu::Guard) -> Option<&'a T> {
+        self.data().try_access_with_guard(guard)
+    }
+}
+
+// SAFETY: `Devres` can be sent to any task, if `T: Send`.
+unsafe impl<T: Send> Send for Devres<T> {}
+
+// SAFETY: `Devres` can be shared with any task, if `T: Sync`.
+unsafe impl<T: Send + Sync> Sync for Devres<T> {}
+
+#[pinned_drop]
+impl<T: Send> PinnedDrop for Devres<T> {
+    fn drop(self: Pin<&mut Self>) {
+        // SAFETY: When `drop` runs, it is guaranteed that nobody is accessing the revocable data
+        // anymore, hence it is safe not to wait for the grace period to finish.
+        if unsafe { self.data().revoke_nosync() } {
+            // We revoked `self.data` before the devres action did, hence try to remove it.
+            if !self.remove_action() {
+                // We could not remove the devres action, which means that it now runs concurrently,
+                // hence signal that `self.data` has been revoked by us successfully.
+                self.inner().revoke.complete_all();
+
+                // Wait for `Self::devres_callback` to be done using this object.
+                self.inner().devm.wait_for_completion();
+            }
+        } else {
+            // `Self::devres_callback` revokes `self.data` for us, hence wait for it to be done
+            // using this object.
+            self.inner().devm.wait_for_completion();
+        }
+
+        // INVARIANT: At this point it is guaranteed that `inner` can't be accessed any more.
+        //
+        // SAFETY: `inner` is valid for dropping.
+        unsafe { core::ptr::drop_in_place(self.inner.get()) };
+    }
+}
+
+/// Consume `data` and [`Drop::drop`] `data` once `dev` is unbound.
+fn register_foreign<P>(dev: &Device<Bound>, data: P) -> Result
+where
+    P: ForeignOwnable + Send + 'static,
+{
+    let ptr = data.into_foreign();
+
+    #[allow(clippy::missing_safety_doc)]
+    unsafe extern "C" fn callback<P: ForeignOwnable>(ptr: *mut kernel::ffi::c_void) {
+        // SAFETY: `ptr` is the pointer to the `ForeignOwnable` leaked above and hence valid.
+        drop(unsafe { P::from_foreign(ptr.cast()) });
+    }
+
+    // SAFETY:
+    // - `dev.as_raw()` is a pointer to a valid and bound device.
+    // - `ptr` is a valid pointer to the `ForeignOwnable` that devres takes ownership of.
+    to_result(unsafe {
+        // `devm_add_action_or_reset()` also calls `callback` on failure, such that the
+        // `ForeignOwnable` is released eventually.
+        bindings::devm_add_action_or_reset(dev.as_raw(), Some(callback::<P>), ptr.cast())
+    })
+}
+
+/// Encapsulate `data` in a [`KBox`] and [`Drop::drop`] `data` once `dev` is unbound.
+///
+/// # Examples
+///
+/// ```no_run
+/// use kernel::{device::{Bound, Device}, devres};
+///
+/// /// Registration of e.g. a class device, IRQ, etc.
+/// struct Registration; +/// +/// impl Registration { +/// fn new() -> Self { +/// // register +/// +/// Self +/// } +/// } +/// +/// impl Drop for Registration { +/// fn drop(&mut self) { +/// // unregister +/// } +/// } +/// +/// fn from_bound_context(dev: &Device<Bound>) -> Result { +/// devres::register(dev, Registration::new(), GFP_KERNEL) +/// } +/// ``` +pub fn register<T, E>(dev: &Device<Bound>, data: impl PinInit<T, E>, flags: Flags) -> Result +where + T: Send + 'static, + Error: From<E>, +{ + let data = KBox::pin_init(data, flags)?; + + register_foreign(dev, data) +} diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs new file mode 100644 index 000000000000..84d3c67269e8 --- /dev/null +++ b/rust/kernel/dma.rs @@ -0,0 +1,748 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Direct memory access (DMA). +//! +//! C header: [`include/linux/dma-mapping.h`](srctree/include/linux/dma-mapping.h) + +use crate::{ + bindings, build_assert, device, + device::{Bound, Core}, + error::{to_result, Result}, + prelude::*, + sync::aref::ARef, + transmute::{AsBytes, FromBytes}, +}; +use core::ptr::NonNull; + +/// DMA address type. +/// +/// Represents a bus address used for Direct Memory Access (DMA) operations. +/// +/// This is an alias of the kernel's `dma_addr_t`, which may be `u32` or `u64` depending on +/// `CONFIG_ARCH_DMA_ADDR_T_64BIT`. +/// +/// Note that this may be `u64` even on 32-bit architectures. +pub type DmaAddress = bindings::dma_addr_t; + +/// Trait to be implemented by DMA capable bus devices. +/// +/// The [`dma::Device`](Device) trait should be implemented by bus specific device representations, +/// where the underlying bus is DMA capable, such as [`pci::Device`](::kernel::pci::Device) or +/// [`platform::Device`](::kernel::platform::Device). +pub trait Device: AsRef<device::Device<Core>> { + /// Set up the device's DMA streaming addressing capabilities. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_mask(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { bindings::dma_set_mask(self.as_ref().as_raw(), mask.value()) }) + } + + /// Set up the device's DMA coherent addressing capabilities. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_coherent_mask(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { bindings::dma_set_coherent_mask(self.as_ref().as_raw(), mask.value()) }) + } + + /// Set up the device's DMA addressing capabilities. + /// + /// This is a combination of [`Device::dma_set_mask`] and [`Device::dma_set_coherent_mask`]. 
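+    ///
+    /// A minimal sketch of calling it during `probe()`; the 64-bit mask, the function name and
+    /// the assumption that no DMA allocations or mappings exist yet are illustrative:
+    ///
+    /// ```
+    /// # use kernel::dma::{Device as DmaDevice, DmaMask};
+    /// fn setup_dma(dev: &impl DmaDevice) -> Result {
+    ///     // SAFETY: Assumed to be called from `probe()`, before any DMA allocation or mapping
+    ///     // primitive is used (assumption of this sketch).
+    ///     unsafe { dev.dma_set_mask_and_coherent(DmaMask::new::<64>()) }
+    /// }
+    /// ```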
+ /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_mask_and_coherent(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { + bindings::dma_set_mask_and_coherent(self.as_ref().as_raw(), mask.value()) + }) + } +} + +/// A DMA mask that holds a bitmask with the lowest `n` bits set. +/// +/// Use [`DmaMask::new`] or [`DmaMask::try_new`] to construct a value. Values +/// are guaranteed to never exceed the bit width of `u64`. +/// +/// This is the Rust equivalent of the C macro `DMA_BIT_MASK()`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DmaMask(u64); + +impl DmaMask { + /// Constructs a `DmaMask` with the lowest `n` bits set to `1`. + /// + /// For `n <= 64`, sets exactly the lowest `n` bits. + /// For `n > 64`, results in a build error. + /// + /// # Examples + /// + /// ``` + /// use kernel::dma::DmaMask; + /// + /// let mask0 = DmaMask::new::<0>(); + /// assert_eq!(mask0.value(), 0); + /// + /// let mask1 = DmaMask::new::<1>(); + /// assert_eq!(mask1.value(), 0b1); + /// + /// let mask64 = DmaMask::new::<64>(); + /// assert_eq!(mask64.value(), u64::MAX); + /// + /// // Build failure. + /// // let mask_overflow = DmaMask::new::<100>(); + /// ``` + #[inline] + pub const fn new<const N: u32>() -> Self { + let Ok(mask) = Self::try_new(N) else { + build_error!("Invalid DMA Mask."); + }; + + mask + } + + /// Constructs a `DmaMask` with the lowest `n` bits set to `1`. + /// + /// For `n <= 64`, sets exactly the lowest `n` bits. + /// For `n > 64`, returns [`EINVAL`]. + /// + /// # Examples + /// + /// ``` + /// use kernel::dma::DmaMask; + /// + /// let mask0 = DmaMask::try_new(0)?; + /// assert_eq!(mask0.value(), 0); + /// + /// let mask1 = DmaMask::try_new(1)?; + /// assert_eq!(mask1.value(), 0b1); + /// + /// let mask64 = DmaMask::try_new(64)?; + /// assert_eq!(mask64.value(), u64::MAX); + /// + /// let mask_overflow = DmaMask::try_new(100); + /// assert!(mask_overflow.is_err()); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub const fn try_new(n: u32) -> Result<Self> { + Ok(Self(match n { + 0 => 0, + 1..=64 => u64::MAX >> (64 - n), + _ => return Err(EINVAL), + })) + } + + /// Returns the underlying `u64` bitmask value. + #[inline] + pub const fn value(&self) -> u64 { + self.0 + } +} + +/// Possible attributes associated with a DMA mapping. +/// +/// They can be combined with the operators `|`, `&`, and `!`. +/// +/// Values can be used from the [`attrs`] module. +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::{Bound, Device}; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// # fn test(dev: &Device<Bound>) -> Result { +/// let attribs = DMA_ATTR_FORCE_CONTIGUOUS | DMA_ATTR_NO_WARN; +/// let c: CoherentAllocation<u64> = +/// CoherentAllocation::alloc_attrs(dev, 4, GFP_KERNEL, attribs)?; +/// # Ok::<(), Error>(()) } +/// ``` +#[derive(Clone, Copy, PartialEq)] +#[repr(transparent)] +pub struct Attrs(u32); + +impl Attrs { + /// Get the raw representation of this attribute. 
+    pub(crate) fn as_raw(self) -> crate::ffi::c_ulong {
+        self.0 as crate::ffi::c_ulong
+    }
+
+    /// Check whether `flags` is contained in `self`.
+    pub fn contains(self, flags: Attrs) -> bool {
+        (self & flags) == flags
+    }
+}
+
+impl core::ops::BitOr for Attrs {
+    type Output = Self;
+    fn bitor(self, rhs: Self) -> Self::Output {
+        Self(self.0 | rhs.0)
+    }
+}
+
+impl core::ops::BitAnd for Attrs {
+    type Output = Self;
+    fn bitand(self, rhs: Self) -> Self::Output {
+        Self(self.0 & rhs.0)
+    }
+}
+
+impl core::ops::Not for Attrs {
+    type Output = Self;
+    fn not(self) -> Self::Output {
+        Self(!self.0)
+    }
+}
+
+/// DMA mapping attributes.
+pub mod attrs {
+    use super::Attrs;
+
+    /// Specifies that reads and writes to the mapping may be weakly ordered, that is, reads
+    /// and writes may pass each other.
+    pub const DMA_ATTR_WEAK_ORDERING: Attrs = Attrs(bindings::DMA_ATTR_WEAK_ORDERING);
+
+    /// Specifies that writes to the mapping may be buffered to improve performance.
+    pub const DMA_ATTR_WRITE_COMBINE: Attrs = Attrs(bindings::DMA_ATTR_WRITE_COMBINE);
+
+    /// Lets the platform avoid creating a kernel virtual mapping for the allocated buffer.
+    pub const DMA_ATTR_NO_KERNEL_MAPPING: Attrs = Attrs(bindings::DMA_ATTR_NO_KERNEL_MAPPING);
+
+    /// Allows platform code to skip synchronization of the CPU cache for the given buffer,
+    /// assuming that it has already been transferred to the 'device' domain.
+    pub const DMA_ATTR_SKIP_CPU_SYNC: Attrs = Attrs(bindings::DMA_ATTR_SKIP_CPU_SYNC);
+
+    /// Forces contiguous allocation of the buffer in physical memory.
+    pub const DMA_ATTR_FORCE_CONTIGUOUS: Attrs = Attrs(bindings::DMA_ATTR_FORCE_CONTIGUOUS);
+
+    /// Hints to the DMA-mapping subsystem that it's probably not worth the time to try to
+    /// allocate memory in a way that gives better TLB efficiency.
+    pub const DMA_ATTR_ALLOC_SINGLE_PAGES: Attrs = Attrs(bindings::DMA_ATTR_ALLOC_SINGLE_PAGES);
+
+    /// This tells the DMA-mapping subsystem to suppress allocation failure reports (similarly to
+    /// `__GFP_NOWARN`).
+    pub const DMA_ATTR_NO_WARN: Attrs = Attrs(bindings::DMA_ATTR_NO_WARN);
+
+    /// Indicates that the buffer is fully accessible at an elevated privilege level (and
+    /// ideally inaccessible or at least read-only at lesser-privileged levels).
+    pub const DMA_ATTR_PRIVILEGED: Attrs = Attrs(bindings::DMA_ATTR_PRIVILEGED);
+
+    /// Indicates that the buffer is MMIO memory.
+    pub const DMA_ATTR_MMIO: Attrs = Attrs(bindings::DMA_ATTR_MMIO);
+}
+
+/// DMA data direction.
+///
+/// Corresponds to the C [`enum dma_data_direction`].
+///
+/// [`enum dma_data_direction`]: srctree/include/linux/dma-direction.h
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
+#[repr(u32)]
+pub enum DataDirection {
+    /// The DMA mapping is for bidirectional data transfer.
+    ///
+    /// This is used when the buffer can be both read from and written to by the device.
+    /// The cache for the corresponding memory region is both flushed and invalidated.
+    Bidirectional = Self::const_cast(bindings::dma_data_direction_DMA_BIDIRECTIONAL),
+
+    /// The DMA mapping is for data transfer from memory to the device (write).
+    ///
+    /// The CPU has prepared data in the buffer, and the device will read it.
+    /// The cache for the corresponding memory region is flushed before device access.
+    ToDevice = Self::const_cast(bindings::dma_data_direction_DMA_TO_DEVICE),
+
+    /// The DMA mapping is for data transfer from the device to memory (read).
+    ///
+    /// The device will write data into the buffer for the CPU to read.
+    /// The cache for the corresponding memory region is invalidated before CPU access.
+    FromDevice = Self::const_cast(bindings::dma_data_direction_DMA_FROM_DEVICE),
+
+    /// The DMA mapping is not for data transfer.
+    ///
+    /// This is primarily for debugging purposes. With this direction, the DMA mapping API
+    /// will not perform any cache coherency operations.
+    None = Self::const_cast(bindings::dma_data_direction_DMA_NONE),
+}
+
+impl DataDirection {
+    /// Casts the bindgen-generated enum type to a `u32` at compile time.
+    ///
+    /// This function will cause a compile-time error if the underlying value of the
+    /// C enum is out of bounds for `u32`.
+    const fn const_cast(val: bindings::dma_data_direction) -> u32 {
+        // CAST: The C standard allows compilers to choose different integer types for enums.
+        // To safely check the value, we cast it to a wide signed integer type (`i128`)
+        // which can hold any standard C integer enum type without truncation.
+        let wide_val = val as i128;
+
+        // Check if the value is outside the valid range for the target type `u32`.
+        // CAST: `u32::MAX` is cast to `i128` to match the type of `wide_val` for the comparison.
+        if wide_val < 0 || wide_val > u32::MAX as i128 {
+            // Trigger a compile-time error in a const context.
+            build_error!("C enum value is out of bounds for the target type `u32`.");
+        }
+
+        // CAST: This cast is valid because the check above guarantees that `wide_val`
+        // is within the representable range of `u32`.
+        wide_val as u32
+    }
+}
+
+impl From<DataDirection> for bindings::dma_data_direction {
+    /// Returns the raw representation of [`enum dma_data_direction`].
+    fn from(direction: DataDirection) -> Self {
+        // CAST: `direction as u32` gets the underlying representation of our `#[repr(u32)]` enum.
+        // The subsequent cast to `Self` (the bindgen type) assumes the C enum is compatible
+        // with the enum variants of `DataDirection`, which is a valid assumption given our
+        // compile-time checks.
+        direction as u32 as Self
+    }
+}
+
+/// An abstraction of the `dma_alloc_coherent` API.
+///
+/// This is an abstraction around the `dma_alloc_coherent` API which is used to allocate and map
+/// large coherent DMA regions.
+///
+/// A [`CoherentAllocation`] instance contains a pointer to the allocated region (in the
+/// processor's virtual address space) and the device address which can be given to the device
+/// as the DMA address base of the region. The region is released once [`CoherentAllocation`]
+/// is dropped.
+///
+/// # Invariants
+///
+/// - For the lifetime of an instance of [`CoherentAllocation`], the `cpu_addr` is a valid pointer
+///   to an allocated region of coherent memory and `dma_handle` is the DMA address base of the
+///   region.
+/// - The size in bytes of the allocation is equal to `size_of::<T> * count`.
+/// - `size_of::<T> * count` fits into a `usize`.
+// TODO
+//
+// DMA allocations potentially carry device resources (e.g. IOMMU mappings), hence for soundness
+// reasons a DMA allocation would need to be embedded in a `Devres` container, in order to ensure
+// that device resources can never survive device unbind.
+//
+// However, it is neither desirable nor necessary to protect the allocated memory of the DMA
+// allocation from surviving device unbind; it would require RCU read side critical sections to
+// access the memory, which may require subsequent unnecessary copies.
+// +// Hence, find a way to revoke the device resources of a `CoherentAllocation`, but not the +// entire `CoherentAllocation` including the allocated memory itself. +pub struct CoherentAllocation<T: AsBytes + FromBytes> { + dev: ARef<device::Device>, + dma_handle: DmaAddress, + count: usize, + cpu_addr: NonNull<T>, + dma_attrs: Attrs, +} + +impl<T: AsBytes + FromBytes> CoherentAllocation<T> { + /// Allocates a region of `size_of::<T> * count` of coherent memory. + /// + /// # Examples + /// + /// ``` + /// # use kernel::device::{Bound, Device}; + /// use kernel::dma::{attrs::*, CoherentAllocation}; + /// + /// # fn test(dev: &Device<Bound>) -> Result { + /// let c: CoherentAllocation<u64> = + /// CoherentAllocation::alloc_attrs(dev, 4, GFP_KERNEL, DMA_ATTR_NO_WARN)?; + /// # Ok::<(), Error>(()) } + /// ``` + pub fn alloc_attrs( + dev: &device::Device<Bound>, + count: usize, + gfp_flags: kernel::alloc::Flags, + dma_attrs: Attrs, + ) -> Result<CoherentAllocation<T>> { + build_assert!( + core::mem::size_of::<T>() > 0, + "It doesn't make sense for the allocated type to be a ZST" + ); + + let size = count + .checked_mul(core::mem::size_of::<T>()) + .ok_or(EOVERFLOW)?; + let mut dma_handle = 0; + // SAFETY: Device pointer is guaranteed as valid by the type invariant on `Device`. + let addr = unsafe { + bindings::dma_alloc_attrs( + dev.as_raw(), + size, + &mut dma_handle, + gfp_flags.as_raw(), + dma_attrs.as_raw(), + ) + }; + let addr = NonNull::new(addr).ok_or(ENOMEM)?; + // INVARIANT: + // - We just successfully allocated a coherent region which is accessible for + // `count` elements, hence the cpu address is valid. We also hold a refcounted reference + // to the device. + // - The allocated `size` is equal to `size_of::<T> * count`. + // - The allocated `size` fits into a `usize`. + Ok(Self { + dev: dev.into(), + dma_handle, + count, + cpu_addr: addr.cast(), + dma_attrs, + }) + } + + /// Performs the same functionality as [`CoherentAllocation::alloc_attrs`], except the + /// `dma_attrs` is 0 by default. + pub fn alloc_coherent( + dev: &device::Device<Bound>, + count: usize, + gfp_flags: kernel::alloc::Flags, + ) -> Result<CoherentAllocation<T>> { + CoherentAllocation::alloc_attrs(dev, count, gfp_flags, Attrs(0)) + } + + /// Returns the number of elements `T` in this allocation. + /// + /// Note that this is not the size of the allocation in bytes, which is provided by + /// [`Self::size`]. + pub fn count(&self) -> usize { + self.count + } + + /// Returns the size in bytes of this allocation. + pub fn size(&self) -> usize { + // INVARIANT: The type invariant of `Self` guarantees that `size_of::<T> * count` fits into + // a `usize`. + self.count * core::mem::size_of::<T>() + } + + /// Returns the base address to the allocated region in the CPU's virtual address space. + pub fn start_ptr(&self) -> *const T { + self.cpu_addr.as_ptr() + } + + /// Returns the base address to the allocated region in the CPU's virtual address space as + /// a mutable pointer. + pub fn start_ptr_mut(&mut self) -> *mut T { + self.cpu_addr.as_ptr() + } + + /// Returns a DMA handle which may be given to the device as the DMA address base of + /// the region. + pub fn dma_handle(&self) -> DmaAddress { + self.dma_handle + } + + /// Returns a DMA handle starting at `offset` (in units of `T`) which may be given to the + /// device as the DMA address base of the region. + /// + /// Returns `EINVAL` if `offset` is not within the bounds of the allocation. 
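+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch, assuming an existing allocation of at least three elements:
+    ///
+    /// ```
+    /// # fn test(alloc: &kernel::dma::CoherentAllocation<u64>) -> Result {
+    /// // DMA address of the third element, i.e. `dma_handle() + 2 * size_of::<u64>()`.
+    /// let handle = alloc.dma_handle_with_offset(2)?;
+    /// assert_eq!(handle, alloc.dma_handle() + 16);
+    /// # Ok::<(), Error>(()) }
+    /// ```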
+    pub fn dma_handle_with_offset(&self, offset: usize) -> Result<DmaAddress> {
+        if offset >= self.count {
+            Err(EINVAL)
+        } else {
+            // INVARIANT: The type invariant of `Self` guarantees that `size_of::<T> * count` fits
+            // into a `usize`, and `offset` is less than `count`.
+            Ok(self.dma_handle + (offset * core::mem::size_of::<T>()) as DmaAddress)
+        }
+    }
+
+    /// Common helper to validate a range applied from the allocated region in the CPU's virtual
+    /// address space.
+    fn validate_range(&self, offset: usize, count: usize) -> Result {
+        if offset.checked_add(count).ok_or(EOVERFLOW)? > self.count {
+            return Err(EINVAL);
+        }
+        Ok(())
+    }
+
+    /// Returns the data from the region starting from `offset` as a slice.
+    /// `offset` and `count` are in units of `T`, not the number of bytes.
+    ///
+    /// For ring-buffer style read/write access, or for use cases where a pointer to the live data
+    /// is needed, [`CoherentAllocation::start_ptr`] or [`CoherentAllocation::start_ptr_mut`] can
+    /// be used instead.
+    ///
+    /// # Safety
+    ///
+    /// * Callers must ensure that the device does not read/write to/from memory while the returned
+    ///   slice is live.
+    /// * Callers must ensure that this call does not race with a write to the same region while
+    ///   the returned slice is live.
+    pub unsafe fn as_slice(&self, offset: usize, count: usize) -> Result<&[T]> {
+        self.validate_range(offset, count)?;
+        // SAFETY:
+        // - The pointer is valid due to type invariant on `CoherentAllocation`,
+        //   we've just checked that the range and index is within bounds. The immutability of the
+        //   data is also guaranteed by the safety requirements of the function.
+        // - `offset + count` can't overflow since it is smaller than `self.count` and we've
+        //   checked that `self.count` won't overflow early in the constructor.
+        Ok(unsafe { core::slice::from_raw_parts(self.start_ptr().add(offset), count) })
+    }
+
+    /// Performs the same functionality as [`CoherentAllocation::as_slice`], except that a mutable
+    /// slice is returned.
+    ///
+    /// # Safety
+    ///
+    /// * Callers must ensure that the device does not read/write to/from memory while the returned
+    ///   slice is live.
+    /// * Callers must ensure that this call does not race with a read or write to the same region
+    ///   while the returned slice is live.
+    pub unsafe fn as_slice_mut(&mut self, offset: usize, count: usize) -> Result<&mut [T]> {
+        self.validate_range(offset, count)?;
+        // SAFETY:
+        // - The pointer is valid due to type invariant on `CoherentAllocation`,
+        //   we've just checked that the range and index is within bounds. Exclusive access to the
+        //   data is also guaranteed by the safety requirements of the function.
+        // - `offset + count` can't overflow since it is smaller than `self.count` and we've
+        //   checked that `self.count` won't overflow early in the constructor.
+        Ok(unsafe { core::slice::from_raw_parts_mut(self.start_ptr_mut().add(offset), count) })
+    }
+
+    /// Writes data to the region starting from `offset`. `offset` is in units of `T`, not the
+    /// number of bytes.
+    ///
+    /// # Safety
+    ///
+    /// * Callers must ensure that the device does not read/write to/from the region while this
+    ///   write is being performed.
+    /// * Callers must ensure that this call does not race with a read or write to the same region
+    ///   that overlaps with this write.
+ /// + /// # Examples + /// + /// ``` + /// # fn test(alloc: &mut kernel::dma::CoherentAllocation<u8>) -> Result { + /// let somedata: [u8; 4] = [0xf; 4]; + /// let buf: &[u8] = &somedata; + /// // SAFETY: There is no concurrent HW operation on the device and no other R/W access to the + /// // region. + /// unsafe { alloc.write(buf, 0)?; } + /// # Ok::<(), Error>(()) } + /// ``` + pub unsafe fn write(&mut self, src: &[T], offset: usize) -> Result { + self.validate_range(offset, src.len())?; + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation` + // and we've just checked that the range and index is within bounds. + // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + unsafe { + core::ptr::copy_nonoverlapping( + src.as_ptr(), + self.start_ptr_mut().add(offset), + src.len(), + ) + }; + Ok(()) + } + + /// Returns a pointer to an element from the region with bounds checking. `offset` is in + /// units of `T`, not the number of bytes. + /// + /// Public but hidden since it should only be used from [`dma_read`] and [`dma_write`] macros. + #[doc(hidden)] + pub fn item_from_index(&self, offset: usize) -> Result<*mut T> { + if offset >= self.count { + return Err(EINVAL); + } + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation` + // and we've just checked that the range and index is within bounds. + // - `offset` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + Ok(unsafe { self.cpu_addr.as_ptr().add(offset) }) + } + + /// Reads the value of `field` and ensures that its type is [`FromBytes`]. + /// + /// # Safety + /// + /// This must be called from the [`dma_read`] macro which ensures that the `field` pointer is + /// validated beforehand. + /// + /// Public but hidden since it should only be used from [`dma_read`] macro. + #[doc(hidden)] + pub unsafe fn field_read<F: FromBytes>(&self, field: *const F) -> F { + // SAFETY: + // - By the safety requirements field is valid. + // - Using read_volatile() here is not sound as per the usual rules, the usage here is + // a special exception with the following notes in place. When dealing with a potential + // race from a hardware or code outside kernel (e.g. user-space program), we need that + // read on a valid memory is not UB. Currently read_volatile() is used for this, and the + // rationale behind is that it should generate the same code as READ_ONCE() which the + // kernel already relies on to avoid UB on data races. Note that the usage of + // read_volatile() is limited to this particular case, it cannot be used to prevent + // the UB caused by racing between two kernel functions nor do they provide atomicity. + unsafe { field.read_volatile() } + } + + /// Writes a value to `field` and ensures that its type is [`AsBytes`]. + /// + /// # Safety + /// + /// This must be called from the [`dma_write`] macro which ensures that the `field` pointer is + /// validated beforehand. + /// + /// Public but hidden since it should only be used from [`dma_write`] macro. + #[doc(hidden)] + pub unsafe fn field_write<F: AsBytes>(&self, field: *mut F, val: F) { + // SAFETY: + // - By the safety requirements field is valid. + // - Using write_volatile() here is not sound as per the usual rules, the usage here is + // a special exception with the following notes in place. 
When dealing with a potential + // race from a hardware or code outside kernel (e.g. user-space program), we need that + // write on a valid memory is not UB. Currently write_volatile() is used for this, and the + // rationale behind is that it should generate the same code as WRITE_ONCE() which the + // kernel already relies on to avoid UB on data races. Note that the usage of + // write_volatile() is limited to this particular case, it cannot be used to prevent + // the UB caused by racing between two kernel functions nor do they provide atomicity. + unsafe { field.write_volatile(val) } + } +} + +/// Note that the device configured to do DMA must be halted before this object is dropped. +impl<T: AsBytes + FromBytes> Drop for CoherentAllocation<T> { + fn drop(&mut self) { + let size = self.count * core::mem::size_of::<T>(); + // SAFETY: Device pointer is guaranteed as valid by the type invariant on `Device`. + // The cpu address, and the dma handle are valid due to the type invariants on + // `CoherentAllocation`. + unsafe { + bindings::dma_free_attrs( + self.dev.as_raw(), + size, + self.start_ptr_mut().cast(), + self.dma_handle, + self.dma_attrs.as_raw(), + ) + } + } +} + +// SAFETY: It is safe to send a `CoherentAllocation` to another thread if `T` +// can be sent to another thread. +unsafe impl<T: AsBytes + FromBytes + Send> Send for CoherentAllocation<T> {} + +/// Reads a field of an item from an allocated region of structs. +/// +/// # Examples +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// struct MyStruct { field: u32, } +/// +/// // SAFETY: All bit patterns are acceptable values for `MyStruct`. +/// unsafe impl kernel::transmute::FromBytes for MyStruct{}; +/// // SAFETY: Instances of `MyStruct` have no uninitialized portions. +/// unsafe impl kernel::transmute::AsBytes for MyStruct{}; +/// +/// # fn test(alloc: &kernel::dma::CoherentAllocation<MyStruct>) -> Result { +/// let whole = kernel::dma_read!(alloc[2]); +/// let field = kernel::dma_read!(alloc[1].field); +/// # Ok::<(), Error>(()) } +/// ``` +#[macro_export] +macro_rules! dma_read { + ($dma:expr, $idx: expr, $($field:tt)*) => {{ + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid pointer and can be + // dereferenced. The compiler also further validates the expression on whether `field` + // is a member of `item` when expanded by the macro. + unsafe { + let ptr_field = ::core::ptr::addr_of!((*item) $($field)*); + ::core::result::Result::Ok( + $crate::dma::CoherentAllocation::field_read(&$dma, ptr_field) + ) + } + })() + }}; + ($dma:ident [ $idx:expr ] $($field:tt)* ) => { + $crate::dma_read!($dma, $idx, $($field)*) + }; + ($($dma:ident).* [ $idx:expr ] $($field:tt)* ) => { + $crate::dma_read!($($dma).*, $idx, $($field)*) + }; +} + +/// Writes to a field of an item from an allocated region of structs. +/// +/// # Examples +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// struct MyStruct { member: u32, } +/// +/// // SAFETY: All bit patterns are acceptable values for `MyStruct`. +/// unsafe impl kernel::transmute::FromBytes for MyStruct{}; +/// // SAFETY: Instances of `MyStruct` have no uninitialized portions. 
+/// unsafe impl kernel::transmute::AsBytes for MyStruct{}; +/// +/// # fn test(alloc: &kernel::dma::CoherentAllocation<MyStruct>) -> Result { +/// kernel::dma_write!(alloc[2].member = 0xf); +/// kernel::dma_write!(alloc[1] = MyStruct { member: 0xf }); +/// # Ok::<(), Error>(()) } +/// ``` +#[macro_export] +macro_rules! dma_write { + ($dma:ident [ $idx:expr ] $($field:tt)*) => {{ + $crate::dma_write!($dma, $idx, $($field)*) + }}; + ($($dma:ident).* [ $idx:expr ] $($field:tt)* ) => {{ + $crate::dma_write!($($dma).*, $idx, $($field)*) + }}; + ($dma:expr, $idx: expr, = $val:expr) => { + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid item. + unsafe { $crate::dma::CoherentAllocation::field_write(&$dma, item, $val) } + ::core::result::Result::Ok(()) + })() + }; + ($dma:expr, $idx: expr, $(.$field:ident)* = $val:expr) => { + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid pointer and can be + // dereferenced. The compiler also further validates the expression on whether `field` + // is a member of `item` when expanded by the macro. + unsafe { + let ptr_field = ::core::ptr::addr_of_mut!((*item) $(.$field)*); + $crate::dma::CoherentAllocation::field_write(&$dma, ptr_field, $val) + } + ::core::result::Result::Ok(()) + })() + }; +} diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs new file mode 100644 index 000000000000..9beae2e3d57e --- /dev/null +++ b/rust/kernel/driver.rs @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic support for drivers of different buses (e.g., PCI, Platform, Amba, etc.). +//! +//! This documentation describes how to implement a bus specific driver API and how to align it with +//! the design of (bus specific) devices. +//! +//! Note: Readers are expected to know the content of the documentation of [`Device`] and +//! [`DeviceContext`]. +//! +//! # Driver Trait +//! +//! The main driver interface is defined by a bus specific driver trait. For instance: +//! +//! ```ignore +//! pub trait Driver: Send { +//! /// The type holding information about each device ID supported by the driver. +//! type IdInfo: 'static; +//! +//! /// The table of OF device ids supported by the driver. +//! const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; +//! +//! /// The table of ACPI device ids supported by the driver. +//! const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; +//! +//! /// Driver probe. +//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; +//! +//! /// Driver unbind (optional). +//! fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { +//! let _ = (dev, this); +//! } +//! } +//! ``` +//! +//! For specific examples see [`auxiliary::Driver`], [`pci::Driver`] and [`platform::Driver`]. +//! +//! The `probe()` callback should return a `impl PinInit<Self, Error>`, i.e. the driver's private +//! data. The bus abstraction should store the pointer in the corresponding bus device. The generic +//! [`Device`] infrastructure provides common helpers for this purpose on its +//! [`Device<CoreInternal>`] implementation. +//! +//! All driver callbacks should provide a reference to the driver's private data. Once the driver +//! 
is unbound from the device, the bus abstraction should take back the ownership of the driver's +//! private data from the corresponding [`Device`] and [`drop`] it. +//! +//! All driver callbacks should provide a [`Device<Core>`] reference (see also [`device::Core`]). +//! +//! # Adapter +//! +//! The adapter implementation of a bus represents the abstraction layer between the C bus +//! callbacks and the Rust bus callbacks. It therefore has to be generic over an implementation of +//! the [driver trait](#driver-trait). +//! +//! ```ignore +//! pub struct Adapter<T: Driver>; +//! ``` +//! +//! There's a common [`Adapter`] trait that can be implemented to inherit common driver +//! infrastructure, such as finding the ID info from an [`of::IdTable`] or [`acpi::IdTable`]. +//! +//! # Driver Registration +//! +//! In order to register C driver types (such as `struct platform_driver`) the [adapter](#adapter) +//! should implement the [`RegistrationOps`] trait. +//! +//! This trait implementation can be used to create the actual registration with the common +//! [`Registration`] type. +//! +//! Typically, bus abstractions want to provide a bus specific `module_bus_driver!` macro, which +//! creates a kernel module with exactly one [`Registration`] for the bus specific adapter. +//! +//! The generic driver infrastructure provides a helper for this with the [`module_driver`] macro. +//! +//! # Device IDs +//! +//! Besides the common device ID types, such as [`of::DeviceId`] and [`acpi::DeviceId`], most buses +//! may need to implement their own device ID types. +//! +//! For this purpose the generic infrastructure in [`device_id`] should be used. +//! +//! [`auxiliary::Driver`]: kernel::auxiliary::Driver +//! [`Core`]: device::Core +//! [`Device`]: device::Device +//! [`Device<Core>`]: device::Device<device::Core> +//! [`Device<CoreInternal>`]: device::Device<device::CoreInternal> +//! [`DeviceContext`]: device::DeviceContext +//! [`device_id`]: kernel::device_id +//! [`module_driver`]: kernel::module_driver +//! [`pci::Driver`]: kernel::pci::Driver +//! [`platform::Driver`]: kernel::platform::Driver + +use crate::error::{Error, Result}; +use crate::{acpi, device, of, str::CStr, try_pin_init, types::Opaque, ThisModule}; +use core::pin::Pin; +use pin_init::{pin_data, pinned_drop, PinInit}; + +/// The [`RegistrationOps`] trait serves as generic interface for subsystems (e.g., PCI, Platform, +/// Amba, etc.) to provide the corresponding subsystem specific implementation to register / +/// unregister a driver of the particular type (`RegType`). +/// +/// For instance, the PCI subsystem would set `RegType` to `bindings::pci_driver` and call +/// `bindings::__pci_register_driver` from `RegistrationOps::register` and +/// `bindings::pci_unregister_driver` from `RegistrationOps::unregister`. +/// +/// # Safety +/// +/// A call to [`RegistrationOps::unregister`] for a given instance of `RegType` is only valid if a +/// preceding call to [`RegistrationOps::register`] has been successful. +pub unsafe trait RegistrationOps { + /// The type that holds information about the registration. This is typically a struct defined + /// by the C portion of the kernel. + type RegType: Default; + + /// Registers a driver. + /// + /// # Safety + /// + /// On success, `reg` must remain pinned and valid until the matching call to + /// [`RegistrationOps::unregister`]. 
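+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of what a bus abstraction might do; `__example_register_driver` is a
+    /// placeholder for the bus' actual C registration function, not a real binding:
+    ///
+    /// ```ignore
+    /// unsafe fn register(
+    ///     reg: &Opaque<Self::RegType>,
+    ///     name: &'static CStr,
+    ///     module: &'static ThisModule,
+    /// ) -> Result {
+    ///     let _ = (name, module); // A real bus would forward these to the C call.
+    ///     // SAFETY: `reg` remains pinned and valid per the safety requirements above.
+    ///     to_result(unsafe { bindings::__example_register_driver(reg.get()) })
+    /// }
+    /// ```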
+    unsafe fn register(
+        reg: &Opaque<Self::RegType>,
+        name: &'static CStr,
+        module: &'static ThisModule,
+    ) -> Result;
+
+    /// Unregisters a driver previously registered with [`RegistrationOps::register`].
+    ///
+    /// # Safety
+    ///
+    /// Must only be called after a preceding successful call to [`RegistrationOps::register`] for
+    /// the same `reg`.
+    unsafe fn unregister(reg: &Opaque<Self::RegType>);
+}
+
+/// A [`Registration`] is a generic type that represents the registration of some driver type
+/// (e.g. `bindings::pci_driver`). Therefore a [`Registration`] must be initialized with a type
+/// that implements the [`RegistrationOps`] trait, such that the generic `T::register` and
+/// `T::unregister` calls result in the subsystem specific registration calls.
+///
+/// Once the `Registration` structure is dropped, the driver is unregistered.
+#[pin_data(PinnedDrop)]
+pub struct Registration<T: RegistrationOps> {
+    #[pin]
+    reg: Opaque<T::RegType>,
+}
+
+// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe
+// to share references to it with multiple threads as nothing can be done.
+unsafe impl<T: RegistrationOps> Sync for Registration<T> {}
+
+// SAFETY: Both registration and unregistration are implemented in C and safe to be performed
+// from any thread, so `Registration` is `Send`.
+unsafe impl<T: RegistrationOps> Send for Registration<T> {}
+
+impl<T: RegistrationOps> Registration<T> {
+    /// Creates a new instance of the registration object.
+    pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            reg <- Opaque::try_ffi_init(|ptr: *mut T::RegType| {
+                // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write.
+                unsafe { ptr.write(T::RegType::default()) };
+
+                // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write, and it has
+                // just been initialised above, so it's also valid for read.
+                let drv = unsafe { &*(ptr as *const Opaque<T::RegType>) };
+
+                // SAFETY: `drv` is guaranteed to be pinned until `T::unregister`.
+                unsafe { T::register(drv, name, module) }
+            }),
+        })
+    }
+}
+
+#[pinned_drop]
+impl<T: RegistrationOps> PinnedDrop for Registration<T> {
+    fn drop(self: Pin<&mut Self>) {
+        // SAFETY: The existence of `self` guarantees that `self.reg` has previously been
+        // successfully registered with `T::register`.
+        unsafe { T::unregister(&self.reg) };
+    }
+}
+
+/// Declares a kernel module that exposes a single driver.
+///
+/// It is meant to be used as a helper by other subsystems so they can more easily expose their
+/// own macros.
+#[macro_export]
+macro_rules! module_driver {
+    (<$gen_type:ident>, $driver_ops:ty, { type: $type:ty, $($f:tt)* }) => {
+        type Ops<$gen_type> = $driver_ops;
+
+        #[$crate::prelude::pin_data]
+        struct DriverModule {
+            #[pin]
+            _driver: $crate::driver::Registration<Ops<$type>>,
+        }
+
+        impl $crate::InPlaceModule for DriverModule {
+            fn init(
+                module: &'static $crate::ThisModule
+            ) -> impl ::pin_init::PinInit<Self, $crate::error::Error> {
+                $crate::try_pin_init!(Self {
+                    _driver <- $crate::driver::Registration::new(
+                        <Self as $crate::ModuleMetadata>::NAME,
+                        module,
+                    ),
+                })
+            }
+        }
+
+        $crate::prelude::module! {
+            type: DriverModule,
+            $($f)*
+        }
+    }
+}
+
+/// The bus independent adapter to match drivers and devices.
+///
+/// This trait should be implemented by the bus specific adapter, which represents the connection
+/// of a device and a driver.
+/// +/// It provides bus independent functions for device / driver interactions. +pub trait Adapter { + /// The type holding driver private data about each device id supported by the driver. + type IdInfo: 'static; + + /// The [`acpi::IdTable`] of the corresponding driver + fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>>; + + /// Returns the driver's private data from the matching entry in the [`acpi::IdTable`], if any. + /// + /// If this returns `None`, it means there is no match with an entry in the [`acpi::IdTable`]. + fn acpi_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + #[cfg(not(CONFIG_ACPI))] + { + let _ = dev; + None + } + + #[cfg(CONFIG_ACPI)] + { + let table = Self::acpi_id_table()?; + + // SAFETY: + // - `table` has static lifetime, hence it's valid for read, + // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. + let raw_id = unsafe { bindings::acpi_match_device(table.as_ptr(), dev.as_raw()) }; + + if raw_id.is_null() { + None + } else { + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct acpi_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<acpi::DeviceId>() }; + + Some(table.info(<acpi::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id))) + } + } + } + + /// The [`of::IdTable`] of the corresponding driver. + fn of_id_table() -> Option<of::IdTable<Self::IdInfo>>; + + /// Returns the driver's private data from the matching entry in the [`of::IdTable`], if any. + /// + /// If this returns `None`, it means there is no match with an entry in the [`of::IdTable`]. + fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + #[cfg(not(CONFIG_OF))] + { + let _ = dev; + None + } + + #[cfg(CONFIG_OF)] + { + let table = Self::of_id_table()?; + + // SAFETY: + // - `table` has static lifetime, hence it's valid for read, + // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. + let raw_id = unsafe { bindings::of_match_device(table.as_ptr(), dev.as_raw()) }; + + if raw_id.is_null() { + None + } else { + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; + + Some( + table.info(<of::DeviceId as crate::device_id::RawDeviceIdIndex>::index( + id, + )), + ) + } + } + } + + /// Returns the driver's private data from the matching entry of any of the ID tables, if any. + /// + /// If this returns `None`, it means that there is no match in any of the ID tables directly + /// associated with a [`device::Device`]. + fn id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + let id = Self::acpi_id_info(dev); + if id.is_some() { + return id; + } + + let id = Self::of_id_info(dev); + if id.is_some() { + return id; + } + + None + } +} diff --git a/rust/kernel/drm/device.rs b/rust/kernel/drm/device.rs new file mode 100644 index 000000000000..3ce8f62a0056 --- /dev/null +++ b/rust/kernel/drm/device.rs @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM device. +//! +//! 
C header: [`include/drm/drm_device.h`](srctree/include/drm/drm_device.h)
+
+use crate::{
+    alloc::allocator::Kmalloc,
+    bindings, device, drm,
+    drm::driver::AllocImpl,
+    error::from_err_ptr,
+    error::Result,
+    prelude::*,
+    sync::aref::{ARef, AlwaysRefCounted},
+    types::Opaque,
+};
+use core::{alloc::Layout, mem, ops::Deref, ptr, ptr::NonNull};
+
+#[cfg(CONFIG_DRM_LEGACY)]
+macro_rules! drm_legacy_fields {
+    ( $($field:ident: $val:expr),* $(,)? ) => {
+        bindings::drm_driver {
+            $( $field: $val ),*,
+            firstopen: None,
+            preclose: None,
+            dma_ioctl: None,
+            dma_quiescent: None,
+            context_dtor: None,
+            irq_handler: None,
+            irq_preinstall: None,
+            irq_postinstall: None,
+            irq_uninstall: None,
+            get_vblank_counter: None,
+            enable_vblank: None,
+            disable_vblank: None,
+            dev_priv_size: 0,
+        }
+    }
+}
+
+#[cfg(not(CONFIG_DRM_LEGACY))]
+macro_rules! drm_legacy_fields {
+    ( $($field:ident: $val:expr),* $(,)? ) => {
+        bindings::drm_driver {
+            $( $field: $val ),*
+        }
+    }
+}
+
+/// A typed DRM device with a specific `drm::Driver` implementation.
+///
+/// The device is always reference-counted.
+///
+/// # Invariants
+///
+/// `self.dev` is a valid instance of a `struct drm_device`.
+#[repr(C)]
+pub struct Device<T: drm::Driver> {
+    dev: Opaque<bindings::drm_device>,
+    data: T::Data,
+}
+
+impl<T: drm::Driver> Device<T> {
+    const VTABLE: bindings::drm_driver = drm_legacy_fields! {
+        load: None,
+        open: Some(drm::File::<T::File>::open_callback),
+        postclose: Some(drm::File::<T::File>::postclose_callback),
+        unload: None,
+        release: Some(Self::release),
+        master_set: None,
+        master_drop: None,
+        debugfs_init: None,
+        gem_create_object: T::Object::ALLOC_OPS.gem_create_object,
+        prime_handle_to_fd: T::Object::ALLOC_OPS.prime_handle_to_fd,
+        prime_fd_to_handle: T::Object::ALLOC_OPS.prime_fd_to_handle,
+        gem_prime_import: T::Object::ALLOC_OPS.gem_prime_import,
+        gem_prime_import_sg_table: T::Object::ALLOC_OPS.gem_prime_import_sg_table,
+        dumb_create: T::Object::ALLOC_OPS.dumb_create,
+        dumb_map_offset: T::Object::ALLOC_OPS.dumb_map_offset,
+        show_fdinfo: None,
+        fbdev_probe: None,
+
+        major: T::INFO.major,
+        minor: T::INFO.minor,
+        patchlevel: T::INFO.patchlevel,
+        name: crate::str::as_char_ptr_in_const_context(T::INFO.name).cast_mut(),
+        desc: crate::str::as_char_ptr_in_const_context(T::INFO.desc).cast_mut(),
+
+        driver_features: drm::driver::FEAT_GEM,
+        ioctls: T::IOCTLS.as_ptr(),
+        num_ioctls: T::IOCTLS.len() as i32,
+        fops: &Self::GEM_FOPS,
+    };
+
+    const GEM_FOPS: bindings::file_operations = drm::gem::create_fops();
+
+    /// Create a new `drm::Device` for a `drm::Driver`.
+    pub fn new(dev: &device::Device, data: impl PinInit<T::Data, Error>) -> Result<ARef<Self>> {
+        // `__drm_dev_alloc` uses `kmalloc()` to allocate memory, hence ensure a `kmalloc()`
+        // compatible `Layout`.
+        let layout = Kmalloc::aligned_layout(Layout::new::<Self>());
+
+        // SAFETY:
+        // - `VTABLE`, as a `const`, is pinned to the read-only section of the compilation,
+        // - `dev` is valid by its type invariants,
+        let raw_drm: *mut Self = unsafe {
+            bindings::__drm_dev_alloc(
+                dev.as_raw(),
+                &Self::VTABLE,
+                layout.size(),
+                mem::offset_of!(Self, dev),
+            )
+        }
+        .cast();
+        let raw_drm = NonNull::new(from_err_ptr(raw_drm)?).ok_or(ENOMEM)?;
+
+        // SAFETY: `raw_drm` is a valid pointer to `Self`.
+        let raw_data = unsafe { ptr::addr_of_mut!((*raw_drm.as_ptr()).data) };
+
+        // SAFETY:
+        // - `raw_data` is a valid pointer to uninitialized memory.
+        // - `raw_data` will not move until it is dropped.
+ unsafe { data.__pinned_init(raw_data) }.inspect_err(|_| { + // SAFETY: `raw_drm` is a valid pointer to `Self`, given that `__drm_dev_alloc` was + // successful. + let drm_dev = unsafe { Self::into_drm_device(raw_drm) }; + + // SAFETY: `__drm_dev_alloc()` was successful, hence `drm_dev` must be valid and the + // refcount must be non-zero. + unsafe { bindings::drm_dev_put(drm_dev) }; + })?; + + // SAFETY: The reference count is one, and now we take ownership of that reference as a + // `drm::Device`. + Ok(unsafe { ARef::from_raw(raw_drm) }) + } + + pub(crate) fn as_raw(&self) -> *mut bindings::drm_device { + self.dev.get() + } + + /// # Safety + /// + /// `ptr` must be a valid pointer to a `struct device` embedded in `Self`. + unsafe fn from_drm_device(ptr: *const bindings::drm_device) -> *mut Self { + // SAFETY: By the safety requirements of this function `ptr` is a valid pointer to a + // `struct drm_device` embedded in `Self`. + unsafe { crate::container_of!(Opaque::cast_from(ptr), Self, dev) }.cast_mut() + } + + /// # Safety + /// + /// `ptr` must be a valid pointer to `Self`. + unsafe fn into_drm_device(ptr: NonNull<Self>) -> *mut bindings::drm_device { + // SAFETY: By the safety requirements of this function, `ptr` is a valid pointer to `Self`. + unsafe { &raw mut (*ptr.as_ptr()).dev }.cast() + } + + /// Not intended to be called externally, except via declare_drm_ioctls!() + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct drm_device` `ptr` points + /// to can't drop to zero, for the duration of this function call and the entire duration when + /// the returned reference exists. + /// + /// Additionally, callers must ensure that the `struct device`, `ptr` is pointing to, is + /// embedded in `Self`. + #[doc(hidden)] + pub unsafe fn from_raw<'a>(ptr: *const bindings::drm_device) -> &'a Self { + // SAFETY: By the safety requirements of this function `ptr` is a valid pointer to a + // `struct drm_device` embedded in `Self`. + let ptr = unsafe { Self::from_drm_device(ptr) }; + + // SAFETY: `ptr` is valid by the safety requirements of this function. + unsafe { &*ptr.cast() } + } + + extern "C" fn release(ptr: *mut bindings::drm_device) { + // SAFETY: `ptr` is a valid pointer to a `struct drm_device` and embedded in `Self`. + let this = unsafe { Self::from_drm_device(ptr) }; + + // SAFETY: + // - When `release` runs it is guaranteed that there is no further access to `this`. + // - `this` is valid for dropping. + unsafe { core::ptr::drop_in_place(this) }; + } +} + +impl<T: drm::Driver> Deref for Device<T> { + type Target = T::Data; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +// SAFETY: DRM device objects are always reference counted and the get/put functions +// satisfy the requirements. +unsafe impl<T: drm::Driver> AlwaysRefCounted for Device<T> { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::drm_dev_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: `obj` is a valid pointer to `Self`. + let drm_dev = unsafe { Self::into_drm_device(obj) }; + + // SAFETY: The safety requirements guarantee that the refcount is non-zero. 
+ unsafe { bindings::drm_dev_put(drm_dev) }; + } +} + +impl<T: drm::Driver> AsRef<device::Device> for Device<T> { + fn as_ref(&self) -> &device::Device { + // SAFETY: `bindings::drm_device::dev` is valid as long as the DRM device itself is valid, + // which is guaranteed by the type invariant. + unsafe { device::Device::from_raw((*self.as_raw()).dev) } + } +} + +// SAFETY: A `drm::Device` can be released from any thread. +unsafe impl<T: drm::Driver> Send for Device<T> {} + +// SAFETY: A `drm::Device` can be shared among threads because all immutable methods are protected +// by the synchronization in `struct drm_device`. +unsafe impl<T: drm::Driver> Sync for Device<T> {} diff --git a/rust/kernel/drm/driver.rs b/rust/kernel/drm/driver.rs new file mode 100644 index 000000000000..f30ee4c6245c --- /dev/null +++ b/rust/kernel/drm/driver.rs @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM driver core. +//! +//! C header: [`include/drm/drm_drv.h`](srctree/include/drm/drm_drv.h) + +use crate::{ + bindings, device, devres, drm, + error::{to_result, Result}, + prelude::*, + sync::aref::ARef, +}; +use macros::vtable; + +/// Driver use the GEM memory manager. This should be set for all modern drivers. +pub(crate) const FEAT_GEM: u32 = bindings::drm_driver_feature_DRIVER_GEM; + +/// Information data for a DRM Driver. +pub struct DriverInfo { + /// Driver major version. + pub major: i32, + /// Driver minor version. + pub minor: i32, + /// Driver patchlevel version. + pub patchlevel: i32, + /// Driver name. + pub name: &'static CStr, + /// Driver description. + pub desc: &'static CStr, +} + +/// Internal memory management operation set, normally created by memory managers (e.g. GEM). +pub struct AllocOps { + pub(crate) gem_create_object: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + size: usize, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) prime_handle_to_fd: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + file_priv: *mut bindings::drm_file, + handle: u32, + flags: u32, + prime_fd: *mut core::ffi::c_int, + ) -> core::ffi::c_int, + >, + pub(crate) prime_fd_to_handle: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + file_priv: *mut bindings::drm_file, + prime_fd: core::ffi::c_int, + handle: *mut u32, + ) -> core::ffi::c_int, + >, + pub(crate) gem_prime_import: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + dma_buf: *mut bindings::dma_buf, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) gem_prime_import_sg_table: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + attach: *mut bindings::dma_buf_attachment, + sgt: *mut bindings::sg_table, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) dumb_create: Option< + unsafe extern "C" fn( + file_priv: *mut bindings::drm_file, + dev: *mut bindings::drm_device, + args: *mut bindings::drm_mode_create_dumb, + ) -> core::ffi::c_int, + >, + pub(crate) dumb_map_offset: Option< + unsafe extern "C" fn( + file_priv: *mut bindings::drm_file, + dev: *mut bindings::drm_device, + handle: u32, + offset: *mut u64, + ) -> core::ffi::c_int, + >, +} + +/// Trait for memory manager implementations. Implemented internally. +pub trait AllocImpl: super::private::Sealed + drm::gem::IntoGEMObject { + /// The [`Driver`] implementation for this [`AllocImpl`]. + type Driver: drm::Driver; + + /// The C callback operations for this memory manager. + const ALLOC_OPS: AllocOps; +} + +/// The DRM `Driver` trait. 
+///
+/// This trait must be implemented by drivers in order to create a `struct drm_device` and
+/// `struct drm_driver` to be registered in the DRM subsystem.
+#[vtable]
+pub trait Driver {
+    /// Context data associated with the DRM driver.
+    type Data: Sync + Send;
+
+    /// The type used to manage memory for this driver.
+    type Object: AllocImpl;
+
+    /// The type used to represent a DRM File (client).
+    type File: drm::file::DriverFile;
+
+    /// Driver metadata.
+    const INFO: DriverInfo;
+
+    /// IOCTL list. See `kernel::drm::ioctl::declare_drm_ioctls!{}`.
+    const IOCTLS: &'static [drm::ioctl::DrmIoctlDescriptor];
+}
+
+/// The registration type of a `drm::Device`.
+///
+/// Once the `Registration` structure is dropped, the device is unregistered.
+pub struct Registration<T: Driver>(ARef<drm::Device<T>>);
+
+impl<T: Driver> Registration<T> {
+    /// Creates a new [`Registration`] and registers it.
+    fn new(drm: &drm::Device<T>, flags: usize) -> Result<Self> {
+        // SAFETY: `drm.as_raw()` is valid by the invariants of `drm::Device`.
+        to_result(unsafe { bindings::drm_dev_register(drm.as_raw(), flags) })?;
+
+        Ok(Self(drm.into()))
+    }
+
+    /// Same as [`Registration::new`], but transfers ownership of the [`Registration`] to
+    /// [`devres::register`].
+    pub fn new_foreign_owned(
+        drm: &drm::Device<T>,
+        dev: &device::Device<device::Bound>,
+        flags: usize,
+    ) -> Result
+    where
+        T: 'static,
+    {
+        if drm.as_ref().as_raw() != dev.as_raw() {
+            return Err(EINVAL);
+        }
+
+        let reg = Registration::<T>::new(drm, flags)?;
+
+        devres::register(dev, reg, GFP_KERNEL)
+    }
+
+    /// Returns a reference to the `Device` instance for this registration.
+    pub fn device(&self) -> &drm::Device<T> {
+        &self.0
+    }
+}
+
+// SAFETY: `Registration` doesn't offer any methods or access to fields when shared between
+// threads, hence it's safe to share it.
+unsafe impl<T: Driver> Sync for Registration<T> {}
+
+// SAFETY: Registration with and unregistration from the DRM subsystem can happen from any thread.
+unsafe impl<T: Driver> Send for Registration<T> {}
+
+impl<T: Driver> Drop for Registration<T> {
+    fn drop(&mut self) {
+        // SAFETY: Safe by the invariant of `ARef<drm::Device<T>>`. The existence of this
+        // `Registration` also guarantees that this `drm::Device` is actually registered.
+        unsafe { bindings::drm_dev_unregister(self.0.as_raw()) };
+    }
+}
diff --git a/rust/kernel/drm/file.rs b/rust/kernel/drm/file.rs
new file mode 100644
index 000000000000..8c46f8d51951
--- /dev/null
+++ b/rust/kernel/drm/file.rs
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: GPL-2.0 OR MIT
+
+//! DRM File objects.
+//!
+//! C header: [`include/drm/drm_file.h`](srctree/include/drm/drm_file.h)

+use crate::{bindings, drm, error::Result, prelude::*, types::Opaque};
+use core::marker::PhantomData;
+use core::pin::Pin;
+
+/// Trait that must be implemented by DRM drivers to represent a DRM File (a client instance).
+pub trait DriverFile {
+    /// The parent `Driver` implementation for this `DriverFile`.
+    type Driver: drm::Driver;
+
+    /// Open a new file (called when a client opens the DRM device).
+    fn open(device: &drm::Device<Self::Driver>) -> Result<Pin<KBox<Self>>>;
+}
+
+/// An open DRM File.
+///
+/// # Invariants
+///
+/// `self.0` is a valid instance of a `struct drm_file`.
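+//
+// A minimal sketch of a driver-side `DriverFile` implementation (illustrative only; `MyDriver`
+// is a placeholder for a type implementing `drm::Driver`):
+//
+//     struct MyFile;
+//
+//     impl drm::file::DriverFile for MyFile {
+//         type Driver = MyDriver;
+//
+//         fn open(_device: &drm::Device<Self::Driver>) -> Result<Pin<KBox<Self>>> {
+//             Ok(KBox::new(MyFile, GFP_KERNEL)?.into())
+//         }
+//     }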
+#[repr(transparent)] +pub struct File<T: DriverFile>(Opaque<bindings::drm_file>, PhantomData<T>); + +impl<T: DriverFile> File<T> { + #[doc(hidden)] + /// Not intended to be called externally, except via declare_drm_ioctls!() + /// + /// # Safety + /// + /// `raw_file` must be a valid pointer to an open `struct drm_file`, opened through `T::open`. + pub unsafe fn from_raw<'a>(ptr: *mut bindings::drm_file) -> &'a File<T> { + // SAFETY: `raw_file` is valid by the safety requirements of this function. + unsafe { &*ptr.cast() } + } + + pub(super) fn as_raw(&self) -> *mut bindings::drm_file { + self.0.get() + } + + fn driver_priv(&self) -> *mut T { + // SAFETY: By the type invariants of `Self`, `self.as_raw()` is always valid. + unsafe { (*self.as_raw()).driver_priv }.cast() + } + + /// Return a pinned reference to the driver file structure. + pub fn inner(&self) -> Pin<&T> { + // SAFETY: By the type invariant the pointer `self.as_raw()` points to a valid and opened + // `struct drm_file`, hence `driver_priv` has been properly initialized by `open_callback`. + unsafe { Pin::new_unchecked(&*(self.driver_priv())) } + } + + /// The open callback of a `struct drm_file`. + pub(crate) extern "C" fn open_callback( + raw_dev: *mut bindings::drm_device, + raw_file: *mut bindings::drm_file, + ) -> core::ffi::c_int { + // SAFETY: A callback from `struct drm_driver::open` guarantees that + // - `raw_dev` is valid pointer to a `struct drm_device`, + // - the corresponding `struct drm_device` has been registered. + let drm = unsafe { drm::Device::from_raw(raw_dev) }; + + // SAFETY: `raw_file` is a valid pointer to a `struct drm_file`. + let file = unsafe { File::<T>::from_raw(raw_file) }; + + let inner = match T::open(drm) { + Err(e) => { + return e.to_errno(); + } + Ok(i) => i, + }; + + // SAFETY: This pointer is treated as pinned, and the Drop guarantee is upheld in + // `postclose_callback()`. + let driver_priv = KBox::into_raw(unsafe { Pin::into_inner_unchecked(inner) }); + + // SAFETY: By the type invariants of `Self`, `self.as_raw()` is always valid. + unsafe { (*file.as_raw()).driver_priv = driver_priv.cast() }; + + 0 + } + + /// The postclose callback of a `struct drm_file`. + pub(crate) extern "C" fn postclose_callback( + _raw_dev: *mut bindings::drm_device, + raw_file: *mut bindings::drm_file, + ) { + // SAFETY: This reference won't escape this function + let file = unsafe { File::<T>::from_raw(raw_file) }; + + // SAFETY: `file.driver_priv` has been created in `open_callback` through `KBox::into_raw`. + let _ = unsafe { KBox::from_raw(file.driver_priv()) }; + } +} + +impl<T: DriverFile> super::private::Sealed for File<T> {} diff --git a/rust/kernel/drm/gem/mod.rs b/rust/kernel/drm/gem/mod.rs new file mode 100644 index 000000000000..a7f682e95c01 --- /dev/null +++ b/rust/kernel/drm/gem/mod.rs @@ -0,0 +1,315 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM GEM API +//! +//! C header: [`include/drm/drm_gem.h`](srctree/include/drm/drm_gem.h) + +use crate::{ + alloc::flags::*, + bindings, drm, + drm::driver::{AllocImpl, AllocOps}, + error::{to_result, Result}, + prelude::*, + sync::aref::{ARef, AlwaysRefCounted}, + types::Opaque, +}; +use core::{ops::Deref, ptr::NonNull}; + +/// A type alias for retrieving a [`Driver`]s [`DriverFile`] implementation from its +/// [`DriverObject`] implementation. 
+/// +/// [`Driver`]: drm::Driver +/// [`DriverFile`]: drm::file::DriverFile +pub type DriverFile<T> = drm::File<<<T as DriverObject>::Driver as drm::Driver>::File>; + +/// GEM object functions, which must be implemented by drivers. +pub trait DriverObject: Sync + Send + Sized { + /// Parent `Driver` for this object. + type Driver: drm::Driver; + + /// Create a new driver data object for a GEM object of a given size. + fn new(dev: &drm::Device<Self::Driver>, size: usize) -> impl PinInit<Self, Error>; + + /// Open a new handle to an existing object, associated with a File. + fn open(_obj: &<Self::Driver as drm::Driver>::Object, _file: &DriverFile<Self>) -> Result { + Ok(()) + } + + /// Close a handle to an existing object, associated with a File. + fn close(_obj: &<Self::Driver as drm::Driver>::Object, _file: &DriverFile<Self>) {} +} + +/// Trait that represents a GEM object subtype +pub trait IntoGEMObject: Sized + super::private::Sealed + AlwaysRefCounted { + /// Returns a reference to the raw `drm_gem_object` structure, which must be valid as long as + /// this owning object is valid. + fn as_raw(&self) -> *mut bindings::drm_gem_object; + + /// Converts a pointer to a `struct drm_gem_object` into a reference to `Self`. + /// + /// # Safety + /// + /// - `self_ptr` must be a valid pointer to `Self`. + /// - The caller promises that holding the immutable reference returned by this function does + /// not violate rust's data aliasing rules and remains valid throughout the lifetime of `'a`. + unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self; +} + +extern "C" fn open_callback<T: DriverObject>( + raw_obj: *mut bindings::drm_gem_object, + raw_file: *mut bindings::drm_file, +) -> core::ffi::c_int { + // SAFETY: `open_callback` is only ever called with a valid pointer to a `struct drm_file`. + let file = unsafe { DriverFile::<T>::from_raw(raw_file) }; + + // SAFETY: `open_callback` is specified in the AllocOps structure for `DriverObject<T>`, + // ensuring that `raw_obj` is contained within a `DriverObject<T>` + let obj = unsafe { <<T::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) }; + + match T::open(obj, file) { + Err(e) => e.to_errno(), + Ok(()) => 0, + } +} + +extern "C" fn close_callback<T: DriverObject>( + raw_obj: *mut bindings::drm_gem_object, + raw_file: *mut bindings::drm_file, +) { + // SAFETY: `open_callback` is only ever called with a valid pointer to a `struct drm_file`. + let file = unsafe { DriverFile::<T>::from_raw(raw_file) }; + + // SAFETY: `close_callback` is specified in the AllocOps structure for `Object<T>`, ensuring + // that `raw_obj` is indeed contained within a `Object<T>`. + let obj = unsafe { <<T::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) }; + + T::close(obj, file); +} + +impl<T: DriverObject> IntoGEMObject for Object<T> { + fn as_raw(&self) -> *mut bindings::drm_gem_object { + self.obj.get() + } + + unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self { + // SAFETY: `obj` is guaranteed to be in an `Object<T>` via the safety contract of this + // function + unsafe { &*crate::container_of!(Opaque::cast_from(self_ptr), Object<T>, obj) } + } +} + +/// Base operations shared by all GEM object classes +pub trait BaseObject: IntoGEMObject { + /// Returns the size of the object in bytes. + fn size(&self) -> usize { + // SAFETY: `self.as_raw()` is guaranteed to be a pointer to a valid `struct drm_gem_object`. 
+        unsafe { (*self.as_raw()).size }
+    }
+
+    /// Creates a new handle for the object associated with a given `File`
+    /// (or returns an existing one).
+    fn create_handle<D, F>(&self, file: &drm::File<F>) -> Result<u32>
+    where
+        Self: AllocImpl<Driver = D>,
+        D: drm::Driver<Object = Self, File = F>,
+        F: drm::file::DriverFile<Driver = D>,
+    {
+        let mut handle: u32 = 0;
+        // SAFETY: The arguments are all valid per the type invariants.
+        to_result(unsafe {
+            bindings::drm_gem_handle_create(file.as_raw().cast(), self.as_raw(), &mut handle)
+        })?;
+        Ok(handle)
+    }
+
+    /// Looks up an object by its handle for a given `File`.
+    fn lookup_handle<D, F>(file: &drm::File<F>, handle: u32) -> Result<ARef<Self>>
+    where
+        Self: AllocImpl<Driver = D>,
+        D: drm::Driver<Object = Self, File = F>,
+        F: drm::file::DriverFile<Driver = D>,
+    {
+        // SAFETY: The arguments are all valid per the type invariants.
+        let ptr = unsafe { bindings::drm_gem_object_lookup(file.as_raw().cast(), handle) };
+        if ptr.is_null() {
+            return Err(ENOENT);
+        }
+
+        // SAFETY:
+        // - A `drm::Driver` can only have a single `File` implementation.
+        // - `file` uses the same `drm::Driver` as `Self`.
+        // - Therefore, we're guaranteed that `ptr` must be a gem object embedded within `Self`.
+        // - And we check that the pointer is non-null before calling from_raw(), ensuring that
+        //   `ptr` is a valid pointer to an initialized `Self`.
+        let obj = unsafe { Self::from_raw(ptr) };
+
+        // SAFETY:
+        // - We take ownership of the reference returned by `drm_gem_object_lookup()`.
+        // - Our `NonNull` comes from an immutable reference, thus ensuring it is a valid pointer
+        //   to `Self`.
+        Ok(unsafe { ARef::from_raw(obj.into()) })
+    }
+
+    /// Creates an mmap offset to map the object from userspace.
+    fn create_mmap_offset(&self) -> Result<u64> {
+        // SAFETY: The arguments are valid per the type invariant.
+        to_result(unsafe { bindings::drm_gem_create_mmap_offset(self.as_raw()) })?;
+
+        // SAFETY: The arguments are valid per the type invariant.
+        Ok(unsafe { bindings::drm_vma_node_offset_addr(&raw mut (*self.as_raw()).vma_node) })
+    }
+}
+
+impl<T: IntoGEMObject> BaseObject for T {}
+
+/// A base GEM object.
+///
+/// # Invariants
+///
+/// - `self.obj` is a valid instance of a `struct drm_gem_object`.
+#[repr(C)]
+#[pin_data]
+pub struct Object<T: DriverObject + Send + Sync> {
+    obj: Opaque<bindings::drm_gem_object>,
+    #[pin]
+    data: T,
+}
+
+impl<T: DriverObject> Object<T> {
+    const OBJECT_FUNCS: bindings::drm_gem_object_funcs = bindings::drm_gem_object_funcs {
+        free: Some(Self::free_callback),
+        open: Some(open_callback::<T>),
+        close: Some(close_callback::<T>),
+        print_info: None,
+        export: None,
+        pin: None,
+        unpin: None,
+        get_sg_table: None,
+        vmap: None,
+        vunmap: None,
+        mmap: None,
+        status: None,
+        vm_ops: core::ptr::null_mut(),
+        evict: None,
+        rss: None,
+    };
+
+    /// Create a new GEM object.
+    pub fn new(dev: &drm::Device<T::Driver>, size: usize) -> Result<ARef<Self>> {
+        let obj: Pin<KBox<Self>> = KBox::pin_init(
+            try_pin_init!(Self {
+                obj: Opaque::new(bindings::drm_gem_object::default()),
+                data <- T::new(dev, size),
+            }),
+            GFP_KERNEL,
+        )?;
+
+        // SAFETY: `obj.as_raw()` is guaranteed to be valid by the initialization above.
+        unsafe { (*obj.as_raw()).funcs = &Self::OBJECT_FUNCS };
+
+        // SAFETY: The arguments are all valid per the type invariants.
+        to_result(unsafe { bindings::drm_gem_object_init(dev.as_raw(), obj.obj.get(), size) })?;
+
+        // SAFETY: We never move out of `Self`.
+        let ptr = KBox::into_raw(unsafe { Pin::into_inner_unchecked(obj) });
+
+        // SAFETY: `ptr` comes from `KBox::into_raw` and hence can't be NULL.
+        let ptr = unsafe { NonNull::new_unchecked(ptr) };
+
+        // SAFETY: We take over the initial reference count from `drm_gem_object_init()`.
+        Ok(unsafe { ARef::from_raw(ptr) })
+    }
+
+    /// Returns the `Device` that owns this GEM object.
+    pub fn dev(&self) -> &drm::Device<T::Driver> {
+        // SAFETY:
+        // - `struct drm_gem_object.dev` is initialized and valid for as long as the GEM
+        //   object lives.
+        // - The device we used for creating the gem object is passed as `&drm::Device<T::Driver>`
+        //   to `Object::<T>::new()`, so we know that `T::Driver` is the right generic parameter
+        //   to use here.
+        unsafe { drm::Device::from_raw((*self.as_raw()).dev) }
+    }
+
+    fn as_raw(&self) -> *mut bindings::drm_gem_object {
+        self.obj.get()
+    }
+
+    extern "C" fn free_callback(obj: *mut bindings::drm_gem_object) {
+        let ptr: *mut Opaque<bindings::drm_gem_object> = obj.cast();
+
+        // SAFETY: All of our objects are of type `Object<T>`.
+        let this = unsafe { crate::container_of!(ptr, Self, obj) };
+
+        // SAFETY: The C code only ever calls this callback with a valid pointer to a `struct
+        // drm_gem_object`.
+        unsafe { bindings::drm_gem_object_release(obj) };
+
+        // SAFETY: All of our objects are allocated via `KBox`, and we're in the
+        // free callback which guarantees this object has zero remaining references,
+        // so we can drop it.
+        let _ = unsafe { KBox::from_raw(this) };
+    }
+}
+
+// SAFETY: Instances of `Object<T>` are always reference-counted.
+unsafe impl<T: DriverObject> crate::types::AlwaysRefCounted for Object<T> {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+        unsafe { bindings::drm_gem_object_get(self.as_raw()) };
+    }
+
+    unsafe fn dec_ref(obj: NonNull<Self>) {
+        // SAFETY: `obj` is a valid pointer to an `Object<T>`.
+        let obj = unsafe { obj.as_ref() };
+
+        // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+        unsafe { bindings::drm_gem_object_put(obj.as_raw()) }
+    }
+}
+
+impl<T: DriverObject> super::private::Sealed for Object<T> {}
+
+impl<T: DriverObject> Deref for Object<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        &self.data
+    }
+}
+
+impl<T: DriverObject> AllocImpl for Object<T> {
+    type Driver = T::Driver;
+
+    const ALLOC_OPS: AllocOps = AllocOps {
+        gem_create_object: None,
+        prime_handle_to_fd: None,
+        prime_fd_to_handle: None,
+        gem_prime_import: None,
+        gem_prime_import_sg_table: None,
+        dumb_create: None,
+        dumb_map_offset: None,
+    };
+}
+
+pub(super) const fn create_fops() -> bindings::file_operations {
+    // SAFETY: Zero is a valid value for every field of `bindings::file_operations` (null
+    // pointers and `None` function pointers), so it is safe to initialize it zeroed.
+    let mut fops: bindings::file_operations = unsafe { core::mem::zeroed() };
+
+    fops.owner = core::ptr::null_mut();
+    fops.open = Some(bindings::drm_open);
+    fops.release = Some(bindings::drm_release);
+    fops.unlocked_ioctl = Some(bindings::drm_ioctl);
+    #[cfg(CONFIG_COMPAT)]
+    {
+        fops.compat_ioctl = Some(bindings::drm_compat_ioctl);
+    }
+    fops.poll = Some(bindings::drm_poll);
+    fops.read = Some(bindings::drm_read);
+    fops.llseek = Some(bindings::noop_llseek);
+    fops.mmap = Some(bindings::drm_gem_mmap);
+    fops.fop_flags = bindings::FOP_UNSIGNED_OFFSET;
+
+    fops
+}
diff --git a/rust/kernel/drm/ioctl.rs b/rust/kernel/drm/ioctl.rs
new file mode 100644
index 000000000000..cf328101dde4
--- /dev/null
+++ b/rust/kernel/drm/ioctl.rs
@@ -0,0 +1,167 @@
+// SPDX-License-Identifier: GPL-2.0 OR MIT
+
+//! DRM IOCTL definitions.
+//!
+//! C header: [`include/drm/drm_ioctl.h`](srctree/include/drm/drm_ioctl.h)
+
+use crate::ioctl;
+
+const BASE: u32 = uapi::DRM_IOCTL_BASE as u32;
+
+/// Construct a DRM ioctl number with no argument.
+#[allow(non_snake_case)]
+#[inline(always)]
+pub const fn IO(nr: u32) -> u32 {
+    ioctl::_IO(BASE, nr)
+}
+
+/// Construct a DRM ioctl number with a read-only argument.
+#[allow(non_snake_case)]
+#[inline(always)]
+pub const fn IOR<T>(nr: u32) -> u32 {
+    ioctl::_IOR::<T>(BASE, nr)
+}
+
+/// Construct a DRM ioctl number with a write-only argument.
+#[allow(non_snake_case)]
+#[inline(always)]
+pub const fn IOW<T>(nr: u32) -> u32 {
+    ioctl::_IOW::<T>(BASE, nr)
+}
+
+/// Construct a DRM ioctl number with a read-write argument.
+#[allow(non_snake_case)]
+#[inline(always)]
+pub const fn IOWR<T>(nr: u32) -> u32 {
+    ioctl::_IOWR::<T>(BASE, nr)
+}
+
+/// Descriptor type for DRM ioctls. Use the `declare_drm_ioctls!{}` macro to construct them.
+pub type DrmIoctlDescriptor = bindings::drm_ioctl_desc;
+
+/// This is for ioctls which are used for rendering and require that the file descriptor is
+/// either for a render node or, if it's a legacy/primary node, that it is authenticated.
+pub const AUTH: u32 = bindings::drm_ioctl_flags_DRM_AUTH;
+
+/// This must be set for any ioctl which can change the modeset or display state. Userspace must
+/// call the ioctl through a primary node, while it is the active master.
+///
+/// Note that read-only modeset ioctls can also be called by unauthenticated clients, or when a
+/// master is not the currently active one.
+pub const MASTER: u32 = bindings::drm_ioctl_flags_DRM_MASTER;
+
+/// Anything that could potentially wreck a master file descriptor needs to have this flag set.
+///
+/// Currently that's only the SETMASTER and DROPMASTER ioctls, which e.g. logind can call to
+/// force a misbehaving master (display compositor) into compliance.
+///
+/// This is equivalent to callers with the SYSADMIN capability.
+pub const ROOT_ONLY: u32 = bindings::drm_ioctl_flags_DRM_ROOT_ONLY;
+
+/// This is used for all ioctls needed for rendering only, for drivers which support render nodes.
+/// This should be all new render drivers, and hence it should always be set for any ioctl with
+/// [`AUTH`] set. Note though that read-only query ioctls might have this set without also setting
+/// [`AUTH`], because they do not require authentication.
+pub const RENDER_ALLOW: u32 = bindings::drm_ioctl_flags_DRM_RENDER_ALLOW;
+
+/// Internal structures used by the `declare_drm_ioctls!{}` macro. Do not use directly.
+#[doc(hidden)]
+pub mod internal {
+    pub use bindings::drm_device;
+    pub use bindings::drm_file;
+    pub use bindings::drm_ioctl_desc;
+}
+
+/// Declare the DRM ioctls for a driver.
+///
+/// Each entry in the list should have the form:
+///
+/// `(ioctl_number, argument_type, flags, user_callback),`
+///
+/// `argument_type` is the type name within the `uapi` crate.
+/// `user_callback` should have the following prototype:
+///
+/// ```ignore
+/// fn foo(device: &kernel::drm::Device<Self>,
+///        data: &mut uapi::argument_type,
+///        file: &kernel::drm::File<Self::File>,
+/// ) -> Result<u32>
+/// ```
+/// where `Self` is the `drm::Driver` implementation these ioctls are being declared within.
+///
+/// # Examples
+///
+/// ```ignore
+/// kernel::declare_drm_ioctls! {
+///     (FOO_GET_PARAM, drm_foo_get_param, ioctl::RENDER_ALLOW, my_get_param_handler),
+/// }
+/// ```
+///
+#[macro_export]
+macro_rules! declare_drm_ioctls {
+    ( $(($cmd:ident, $struct:ident, $flags:expr, $func:expr)),* $(,)? ) => {
+        const IOCTLS: &'static [$crate::drm::ioctl::DrmIoctlDescriptor] = {
+            use $crate::uapi::*;
+            const _: () = {
+                let i: u32 = $crate::uapi::DRM_COMMAND_BASE;
+                // Assert that all the IOCTLs are in the right order and there are no gaps,
+                // and that the size of the specified type is correct.
+                $(
+                    let cmd: u32 = $crate::macros::concat_idents!(DRM_IOCTL_, $cmd);
+                    ::core::assert!(i == $crate::ioctl::_IOC_NR(cmd));
+                    ::core::assert!(core::mem::size_of::<$crate::uapi::$struct>() ==
+                                    $crate::ioctl::_IOC_SIZE(cmd));
+                    let i: u32 = i + 1;
+                )*
+            };
+
+            let ioctls = &[$(
+                $crate::drm::ioctl::internal::drm_ioctl_desc {
+                    cmd: $crate::macros::concat_idents!(DRM_IOCTL_, $cmd) as u32,
+                    func: {
+                        #[allow(non_snake_case)]
+                        unsafe extern "C" fn $cmd(
+                            raw_dev: *mut $crate::drm::ioctl::internal::drm_device,
+                            raw_data: *mut ::core::ffi::c_void,
+                            raw_file: *mut $crate::drm::ioctl::internal::drm_file,
+                        ) -> core::ffi::c_int {
+                            // SAFETY:
+                            // - The DRM core ensures the device lives while callbacks are being
+                            //   called.
+                            // - The DRM device must have been registered when we're called through
+                            //   an IOCTL.
+                            //
+                            // FIXME: Currently there is nothing enforcing that the types of the
+                            // dev/file match the current driver these ioctls are being declared
+                            // for, and it's not clear how to enforce this within the type system.
+                            let dev = $crate::drm::device::Device::from_raw(raw_dev);
+                            // SAFETY: The ioctl argument has size `_IOC_SIZE(cmd)`, which we
+                            // asserted above matches the size of this type, and all bit patterns of
+                            // UAPI structs must be valid.
+                            // The `ioctl` argument is exclusively owned by the handler
+                            // and guaranteed by the C implementation (`drm_ioctl()`) to remain
+                            // valid for the entire lifetime of the reference taken here.
+                            // There is no concurrent access or aliasing; no other references
+                            // to this object exist during this call.
+                            let data = unsafe { &mut *(raw_data.cast::<$crate::uapi::$struct>()) };
+                            // SAFETY: `raw_file` is a valid pointer to a `struct drm_file` for
+                            // the duration of this ioctl call.
+                            let file = unsafe { $crate::drm::File::from_raw(raw_file) };
+
+                            match $func(dev, data, file) {
+                                Err(e) => e.to_errno(),
+                                Ok(i) => i.try_into()
+                                    .unwrap_or($crate::error::code::ERANGE.to_errno()),
+                            }
+                        }
+                        Some($cmd)
+                    },
+                    flags: $flags,
+                    name: $crate::str::as_char_ptr_in_const_context(
+                        $crate::c_str!(::core::stringify!($cmd)),
+                    ),
+                }
+            ),*];
+            ioctls
+        };
+    };
+}
diff --git a/rust/kernel/drm/mod.rs b/rust/kernel/drm/mod.rs
new file mode 100644
index 000000000000..1b82b6945edf
--- /dev/null
+++ b/rust/kernel/drm/mod.rs
@@ -0,0 +1,19 @@
+// SPDX-License-Identifier: GPL-2.0 OR MIT
+
+//! DRM subsystem abstractions.
+
+pub mod device;
+pub mod driver;
+pub mod file;
+pub mod gem;
+pub mod ioctl;
+
+pub use self::device::Device;
+pub use self::driver::Driver;
+pub use self::driver::DriverInfo;
+pub use self::driver::Registration;
+pub use self::file::File;
+
+pub(crate) mod private {
+    pub trait Sealed {}
+}
diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs
index 5b9751d7ff1d..258b12afdcba 100644
--- a/rust/kernel/error.rs
+++ b/rust/kernel/error.rs
@@ -2,25 +2,33 @@
 //! Kernel errors.
 //!
-//! C header: [`include/uapi/asm-generic/errno-base.h`](../../../include/uapi/asm-generic/errno-base.h)
+//! C header: [`include/uapi/asm-generic/errno-base.h`](srctree/include/uapi/asm-generic/errno-base.h)\
+//! C header: [`include/uapi/asm-generic/errno.h`](srctree/include/uapi/asm-generic/errno.h)\
+//! C header: [`include/linux/errno.h`](srctree/include/linux/errno.h)
 
-use alloc::{
-    alloc::{AllocError, LayoutError},
-    collections::TryReserveError,
+use crate::{
+    alloc::{layout::LayoutError, AllocError},
+    fmt,
+    str::CStr,
 };
-use core::convert::From;
+use core::num::NonZeroI32;
 use core::num::TryFromIntError;
 use core::str::Utf8Error;
 
 /// Contains the C-compatible error codes.
+#[rustfmt::skip]
 pub mod code {
     macro_rules! declare_err {
         ($err:tt $(,)?
$($doc:expr),+) => { $( #[doc = $doc] )* - pub const $err: super::Error = super::Error(-(crate::bindings::$err as i32)); + pub const $err: super::Error = + match super::Error::try_from_errno(-(crate::bindings::$err as i32)) { + Some(err) => err, + None => panic!("Invalid errno in `declare_err!`"), + }; }; } @@ -33,7 +41,7 @@ pub mod code { declare_err!(E2BIG, "Argument list too long."); declare_err!(ENOEXEC, "Exec format error."); declare_err!(EBADF, "Bad file number."); - declare_err!(ECHILD, "Exec format error."); + declare_err!(ECHILD, "No child processes."); declare_err!(EAGAIN, "Try again."); declare_err!(ENOMEM, "Out of memory."); declare_err!(EACCES, "Permission denied."); @@ -58,6 +66,27 @@ pub mod code { declare_err!(EPIPE, "Broken pipe."); declare_err!(EDOM, "Math argument out of domain of func."); declare_err!(ERANGE, "Math result not representable."); + declare_err!(EOVERFLOW, "Value too large for defined data type."); + declare_err!(ETIMEDOUT, "Connection timed out."); + declare_err!(ERESTARTSYS, "Restart the system call."); + declare_err!(ERESTARTNOINTR, "System call was interrupted by a signal and will be restarted."); + declare_err!(ERESTARTNOHAND, "Restart if no handler."); + declare_err!(ENOIOCTLCMD, "No ioctl command."); + declare_err!(ERESTART_RESTARTBLOCK, "Restart by calling sys_restart_syscall."); + declare_err!(EPROBE_DEFER, "Driver requests probe retry."); + declare_err!(EOPENSTALE, "Open found a stale dentry."); + declare_err!(ENOPARAM, "Parameter not supported."); + declare_err!(EBADHANDLE, "Illegal NFS file handle."); + declare_err!(ENOTSYNC, "Update synchronization mismatch."); + declare_err!(EBADCOOKIE, "Cookie is stale."); + declare_err!(ENOTSUPP, "Operation is not supported."); + declare_err!(ETOOSMALL, "Buffer or request is too small."); + declare_err!(ESERVERFAULT, "An untranslatable error occurred."); + declare_err!(EBADTYPE, "Type not supported by server."); + declare_err!(EJUKEBOX, "Request initiated, but will not complete before timeout."); + declare_err!(EIOCBQUEUED, "iocb queued, will get completion event."); + declare_err!(ERECALLCONFLICT, "Conflict with recalled state."); + declare_err!(ENOGRACE, "NFS file lock reclaim refused."); } /// Generic integer kernel error. @@ -69,12 +98,120 @@ pub mod code { /// /// The value is a valid `errno` (i.e. `>= -MAX_ERRNO && < 0`). #[derive(Clone, Copy, PartialEq, Eq)] -pub struct Error(core::ffi::c_int); +pub struct Error(NonZeroI32); impl Error { + /// Creates an [`Error`] from a kernel error code. + /// + /// `errno` must be within error code range (i.e. `>= -MAX_ERRNO && < 0`). + /// + /// It is a bug to pass an out-of-range `errno`. [`code::EINVAL`] is returned in such a case. + /// + /// # Examples + /// + /// ``` + /// assert_eq!(Error::from_errno(-1), EPERM); + /// assert_eq!(Error::from_errno(-2), ENOENT); + /// ``` + /// + /// The following calls are considered a bug: + /// + /// ``` + /// assert_eq!(Error::from_errno(0), EINVAL); + /// assert_eq!(Error::from_errno(-1000000), EINVAL); + /// ``` + pub fn from_errno(errno: crate::ffi::c_int) -> Error { + if let Some(error) = Self::try_from_errno(errno) { + error + } else { + // TODO: Make it a `WARN_ONCE` once available. + crate::pr_warn!( + "attempted to create `Error` with out of range `errno`: {}\n", + errno + ); + code::EINVAL + } + } + + /// Creates an [`Error`] from a kernel error code. + /// + /// Returns [`None`] if `errno` is out-of-range. 
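+    ///
+    /// # Examples
+    ///
+    /// A short sketch of the expected behaviour (not a doctest, since this function is private):
+    ///
+    /// ```ignore
+    /// assert_eq!(Error::try_from_errno(-1), Some(EPERM)); // Within `-MAX_ERRNO..0`.
+    /// assert_eq!(Error::try_from_errno(0), None);         // Zero is not an error value.
+    /// assert_eq!(Error::try_from_errno(-100_000), None);  // Below `-MAX_ERRNO`.
+    /// ```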
+ const fn try_from_errno(errno: crate::ffi::c_int) -> Option<Error> { + if errno < -(bindings::MAX_ERRNO as i32) || errno >= 0 { + return None; + } + + // SAFETY: `errno` is checked above to be in a valid range. + Some(unsafe { Error::from_errno_unchecked(errno) }) + } + + /// Creates an [`Error`] from a kernel error code. + /// + /// # Safety + /// + /// `errno` must be within error code range (i.e. `>= -MAX_ERRNO && < 0`). + const unsafe fn from_errno_unchecked(errno: crate::ffi::c_int) -> Error { + // INVARIANT: The contract ensures the type invariant + // will hold. + // SAFETY: The caller guarantees `errno` is non-zero. + Error(unsafe { NonZeroI32::new_unchecked(errno) }) + } + /// Returns the kernel error code. - pub fn to_kernel_errno(self) -> core::ffi::c_int { - self.0 + pub fn to_errno(self) -> crate::ffi::c_int { + self.0.get() + } + + #[cfg(CONFIG_BLOCK)] + pub(crate) fn to_blk_status(self) -> bindings::blk_status_t { + // SAFETY: `self.0` is a valid error due to its invariant. + unsafe { bindings::errno_to_blk_status(self.0.get()) } + } + + /// Returns the error encoded as a pointer. + pub fn to_ptr<T>(self) -> *mut T { + // SAFETY: `self.0` is a valid error due to its invariant. + unsafe { bindings::ERR_PTR(self.0.get() as crate::ffi::c_long).cast() } + } + + /// Returns a string representing the error, if one exists. + #[cfg(not(testlib))] + pub fn name(&self) -> Option<&'static CStr> { + // SAFETY: Just an FFI call, there are no extra safety requirements. + let ptr = unsafe { bindings::errname(-self.0.get()) }; + if ptr.is_null() { + None + } else { + use crate::str::CStrExt as _; + + // SAFETY: The string returned by `errname` is static and `NUL`-terminated. + Some(unsafe { CStr::from_char_ptr(ptr) }) + } + } + + /// Returns a string representing the error, if one exists. + /// + /// When `testlib` is configured, this always returns `None` to avoid the dependency on a + /// kernel function so that tests that use this (e.g., by calling [`Result::unwrap`]) can still + /// run in userspace. + #[cfg(testlib)] + pub fn name(&self) -> Option<&'static CStr> { + None + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.name() { + // Print out number if no name can be found. + None => f.debug_tuple("Error").field(&-self.0).finish(), + Some(name) => f + .debug_tuple( + // SAFETY: These strings are ASCII-only. + unsafe { core::str::from_utf8_unchecked(name.to_bytes()) }, + ) + .finish(), + } } } @@ -96,20 +233,14 @@ impl From<Utf8Error> for Error { } } -impl From<TryReserveError> for Error { - fn from(_: TryReserveError) -> Error { - code::ENOMEM - } -} - impl From<LayoutError> for Error { fn from(_: LayoutError) -> Error { code::ENOMEM } } -impl From<core::fmt::Error> for Error { - fn from(_: core::fmt::Error) -> Error { +impl From<fmt::Error> for Error { + fn from(_: fmt::Error) -> Error { code::EINVAL } } @@ -138,6 +269,257 @@ impl From<core::convert::Infallible> for Error { /// [`Error`] as its error type. /// /// Note that even if a function does not return anything when it succeeds, -/// it should still be modeled as returning a `Result` rather than +/// it should still be modeled as returning a [`Result`] rather than /// just an [`Error`]. -pub type Result<T = ()> = core::result::Result<T, Error>; +/// +/// Calling a function that returns [`Result`] forces the caller to handle +/// the returned [`Result`]. +/// +/// This can be done "manually" by using [`match`]. 
Using [`match`] to decode
+/// the [`Result`] is similar to C, where all of the return value decoding and error handling is
+/// written out explicitly for each error case. Using [`match`], the error and success handling
+/// can be implemented in as much detail as required. For example (inspired by
+/// [`samples/rust/rust_minimal.rs`]):
+///
+/// ```
+/// # #[allow(clippy::single_match)]
+/// fn example() -> Result {
+///     let mut numbers = KVec::new();
+///
+///     match numbers.push(72, GFP_KERNEL) {
+///         Err(e) => {
+///             pr_err!("Error pushing 72: {e:?}");
+///             return Err(e.into());
+///         }
+///         // Do nothing, continue.
+///         Ok(()) => (),
+///     }
+///
+///     match numbers.push(108, GFP_KERNEL) {
+///         Err(e) => {
+///             pr_err!("Error pushing 108: {e:?}");
+///             return Err(e.into());
+///         }
+///         // Do nothing, continue.
+///         Ok(()) => (),
+///     }
+///
+///     match numbers.push(200, GFP_KERNEL) {
+///         Err(e) => {
+///             pr_err!("Error pushing 200: {e:?}");
+///             return Err(e.into());
+///         }
+///         // Do nothing, continue.
+///         Ok(()) => (),
+///     }
+///
+///     Ok(())
+/// }
+/// # example()?;
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// An alternative to be more concise is the [`if let`] syntax:
+///
+/// ```
+/// fn example() -> Result {
+///     let mut numbers = KVec::new();
+///
+///     if let Err(e) = numbers.push(72, GFP_KERNEL) {
+///         pr_err!("Error pushing 72: {e:?}");
+///         return Err(e.into());
+///     }
+///
+///     if let Err(e) = numbers.push(108, GFP_KERNEL) {
+///         pr_err!("Error pushing 108: {e:?}");
+///         return Err(e.into());
+///     }
+///
+///     if let Err(e) = numbers.push(200, GFP_KERNEL) {
+///         pr_err!("Error pushing 200: {e:?}");
+///         return Err(e.into());
+///     }
+///
+///     Ok(())
+/// }
+/// # example()?;
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// Instead of these verbose [`match`]/[`if let`] forms, the [`?`] operator can
+/// be used to handle the [`Result`]. Using the [`?`] operator is often
+/// the best choice to handle [`Result`] in a non-verbose way as done in
+/// [`samples/rust/rust_minimal.rs`]:
+///
+/// ```
+/// fn example() -> Result {
+///     let mut numbers = KVec::new();
+///
+///     numbers.push(72, GFP_KERNEL)?;
+///     numbers.push(108, GFP_KERNEL)?;
+///     numbers.push(200, GFP_KERNEL)?;
+///
+///     Ok(())
+/// }
+/// # example()?;
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// Another possibility is to call [`unwrap()`](Result::unwrap) or
+/// [`expect()`](Result::expect). However, use of these functions is
+/// *heavily discouraged* in the kernel because they trigger a Rust
+/// [`panic!`] if an error happens, which may destabilize the system or
+/// entirely break it as a result -- just like the C [`BUG()`] macro.
+/// Please see the documentation for the C macro [`BUG()`] for guidance
+/// on when to use these functions.
+///
+/// Alternatively, depending on the use case, using [`unwrap_or()`],
+/// [`unwrap_or_else()`], [`unwrap_or_default()`] or [`unwrap_unchecked()`]
+/// might be an option, as well.
+///
+/// For even more details, please see the [Rust documentation].
+/// +/// [`match`]: https://doc.rust-lang.org/reference/expressions/match-expr.html +/// [`samples/rust/rust_minimal.rs`]: srctree/samples/rust/rust_minimal.rs +/// [`if let`]: https://doc.rust-lang.org/reference/expressions/if-expr.html#if-let-expressions +/// [`?`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#the-question-mark-operator +/// [`unwrap()`]: Result::unwrap +/// [`expect()`]: Result::expect +/// [`BUG()`]: https://docs.kernel.org/process/deprecated.html#bug-and-bug-on +/// [`unwrap_or()`]: Result::unwrap_or +/// [`unwrap_or_else()`]: Result::unwrap_or_else +/// [`unwrap_or_default()`]: Result::unwrap_or_default +/// [`unwrap_unchecked()`]: Result::unwrap_unchecked +/// [Rust documentation]: https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html +pub type Result<T = (), E = Error> = core::result::Result<T, E>; + +/// Converts an integer as returned by a C kernel function to a [`Result`]. +/// +/// If the integer is negative, an [`Err`] with an [`Error`] as given by [`Error::from_errno`] is +/// returned. This means the integer must be `>= -MAX_ERRNO`. +/// +/// Otherwise, it returns [`Ok`]. +/// +/// It is a bug to pass an out-of-range negative integer. `Err(EINVAL)` is returned in such a case. +/// +/// # Examples +/// +/// This function may be used to easily perform early returns with the [`?`] operator when working +/// with C APIs within Rust abstractions: +/// +/// ``` +/// # use kernel::error::to_result; +/// # mod bindings { +/// # #![expect(clippy::missing_safety_doc)] +/// # use kernel::prelude::*; +/// # pub(super) unsafe fn f1() -> c_int { 0 } +/// # pub(super) unsafe fn f2() -> c_int { EINVAL.to_errno() } +/// # } +/// fn f() -> Result { +/// // SAFETY: ... +/// to_result(unsafe { bindings::f1() })?; +/// +/// // SAFETY: ... +/// to_result(unsafe { bindings::f2() })?; +/// +/// // ... +/// +/// Ok(()) +/// } +/// # assert_eq!(f(), Err(EINVAL)); +/// ``` +/// +/// [`?`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#the-question-mark-operator +pub fn to_result(err: crate::ffi::c_int) -> Result { + if err < 0 { + Err(Error::from_errno(err)) + } else { + Ok(()) + } +} + +/// Transform a kernel "error pointer" to a normal pointer. +/// +/// Some kernel C API functions return an "error pointer" which optionally +/// embeds an `errno`. Callers are supposed to check the returned pointer +/// for errors. This function performs the check and converts the "error pointer" +/// to a normal pointer in an idiomatic fashion. +/// +/// # Examples +/// +/// ```ignore +/// # use kernel::from_err_ptr; +/// # use kernel::bindings; +/// fn devm_platform_ioremap_resource( +/// pdev: &mut PlatformDevice, +/// index: u32, +/// ) -> Result<*mut kernel::ffi::c_void> { +/// // SAFETY: `pdev` points to a valid platform device. There are no safety requirements +/// // on `index`. +/// from_err_ptr(unsafe { bindings::devm_platform_ioremap_resource(pdev.to_ptr(), index) }) +/// } +/// ``` +pub fn from_err_ptr<T>(ptr: *mut T) -> Result<*mut T> { + // CAST: Casting a pointer to `*const crate::ffi::c_void` is always valid. + let const_ptr: *const crate::ffi::c_void = ptr.cast(); + // SAFETY: The FFI function does not deref the pointer. + if unsafe { bindings::IS_ERR(const_ptr) } { + // SAFETY: The FFI function does not deref the pointer. 
+ let err = unsafe { bindings::PTR_ERR(const_ptr) }; + + #[allow(clippy::unnecessary_cast)] + // CAST: If `IS_ERR()` returns `true`, + // then `PTR_ERR()` is guaranteed to return a + // negative value greater-or-equal to `-bindings::MAX_ERRNO`, + // which always fits in an `i16`, as per the invariant above. + // And an `i16` always fits in an `i32`. So casting `err` to + // an `i32` can never overflow, and is always valid. + // + // SAFETY: `IS_ERR()` ensures `err` is a + // negative value greater-or-equal to `-bindings::MAX_ERRNO`. + return Err(unsafe { Error::from_errno_unchecked(err as crate::ffi::c_int) }); + } + Ok(ptr) +} + +/// Calls a closure returning a [`crate::error::Result<T>`] and converts the result to +/// a C integer result. +/// +/// This is useful when calling Rust functions that return [`crate::error::Result<T>`] +/// from inside `extern "C"` functions that need to return an integer error result. +/// +/// `T` should be convertible from an `i16` via `From<i16>`. +/// +/// # Examples +/// +/// ```ignore +/// # use kernel::from_result; +/// # use kernel::bindings; +/// unsafe extern "C" fn probe_callback( +/// pdev: *mut bindings::platform_device, +/// ) -> kernel::ffi::c_int { +/// from_result(|| { +/// let ptr = devm_alloc(pdev)?; +/// bindings::platform_set_drvdata(pdev, ptr); +/// Ok(0) +/// }) +/// } +/// ``` +pub fn from_result<T, F>(f: F) -> T +where + T: From<i16>, + F: FnOnce() -> Result<T>, +{ + match f() { + Ok(v) => v, + // NO-OVERFLOW: negative `errno`s are no smaller than `-bindings::MAX_ERRNO`, + // `-bindings::MAX_ERRNO` fits in an `i16` as per invariant above, + // therefore a negative `errno` always fits in an `i16` and will not overflow. + Err(e) => T::from(e.to_errno() as i16), + } +} + +/// Error message for calling a default function of a [`#[vtable]`](macros::vtable) trait. +pub const VTABLE_DEFAULT_ERROR: &str = + "This function must not be called, see the #[vtable] documentation."; diff --git a/rust/kernel/faux.rs b/rust/kernel/faux.rs new file mode 100644 index 000000000000..7fe2dd197e37 --- /dev/null +++ b/rust/kernel/faux.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: GPL-2.0-only + +//! Abstractions for the faux bus. +//! +//! This module provides bindings for working with faux devices in kernel modules. +//! +//! C header: [`include/linux/device/faux.h`](srctree/include/linux/device/faux.h) + +use crate::{bindings, device, error::code::*, prelude::*}; +use core::ptr::{addr_of_mut, null, null_mut, NonNull}; + +/// The registration of a faux device. +/// +/// This type represents the registration of a [`struct faux_device`]. When an instance of this type +/// is dropped, its respective faux device will be unregistered from the system. +/// +/// # Invariants +/// +/// `self.0` always holds a valid pointer to an initialized and registered [`struct faux_device`]. +/// +/// [`struct faux_device`]: srctree/include/linux/device/faux.h +pub struct Registration(NonNull<bindings::faux_device>); + +impl Registration { + /// Create and register a new faux device with the given name. + #[inline] + pub fn new(name: &CStr, parent: Option<&device::Device>) -> Result<Self> { + // SAFETY: + // - `name` is copied by this function into its own storage + // - `faux_ops` is safe to leave NULL according to the C API + // - `parent` can be either NULL or a pointer to a `struct device`, and `faux_device_create` + // will take a reference to `parent` using `device_add` - ensuring that it remains valid + // for the lifetime of the faux device. 
+ let dev = unsafe { + bindings::faux_device_create( + name.as_char_ptr(), + parent.map_or(null_mut(), |p| p.as_raw()), + null(), + ) + }; + + // The above function will return either a valid device, or NULL on failure + // INVARIANT: The device will remain registered until faux_device_destroy() is called, which + // happens in our Drop implementation. + Ok(Self(NonNull::new(dev).ok_or(ENODEV)?)) + } + + fn as_raw(&self) -> *mut bindings::faux_device { + self.0.as_ptr() + } +} + +impl AsRef<device::Device> for Registration { + fn as_ref(&self) -> &device::Device { + // SAFETY: The underlying `device` in `faux_device` is guaranteed by the C API to be + // a valid initialized `device`. + unsafe { device::Device::from_raw(addr_of_mut!((*self.as_raw()).dev)) } + } +} + +impl Drop for Registration { + #[inline] + fn drop(&mut self) { + // SAFETY: `self.0` is a valid registered faux_device via our type invariants. + unsafe { bindings::faux_device_destroy(self.as_raw()) } + } +} + +// SAFETY: The faux device API is thread-safe as guaranteed by the device core, as long as +// faux_device_destroy() is guaranteed to only be called once - which is guaranteed by our type not +// having Copy/Clone. +unsafe impl Send for Registration {} + +// SAFETY: The faux device API is thread-safe as guaranteed by the device core, as long as +// faux_device_destroy() is guaranteed to only be called once - which is guaranteed by our type not +// having Copy/Clone. +unsafe impl Sync for Registration {} diff --git a/rust/kernel/firmware.rs b/rust/kernel/firmware.rs new file mode 100644 index 000000000000..71168d8004e2 --- /dev/null +++ b/rust/kernel/firmware.rs @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Firmware abstraction +//! +//! C header: [`include/linux/firmware.h`](srctree/include/linux/firmware.h) + +use crate::{ + bindings, + device::Device, + error::Error, + error::Result, + ffi, + str::{CStr, CStrExt as _}, +}; +use core::ptr::NonNull; + +/// # Invariants +/// +/// One of the following: `bindings::request_firmware`, `bindings::firmware_request_nowarn`, +/// `bindings::firmware_request_platform`, `bindings::request_firmware_direct`. +struct FwFunc( + unsafe extern "C" fn( + *mut *const bindings::firmware, + *const ffi::c_char, + *mut bindings::device, + ) -> i32, +); + +impl FwFunc { + fn request() -> Self { + Self(bindings::request_firmware) + } + + fn request_nowarn() -> Self { + Self(bindings::firmware_request_nowarn) + } +} + +/// Abstraction around a C `struct firmware`. +/// +/// This is a simple abstraction around the C firmware API. Just like with the C API, firmware can +/// be requested. Once requested the abstraction provides direct access to the firmware buffer as +/// `&[u8]`. The firmware is released once [`Firmware`] is dropped. +/// +/// # Invariants +/// +/// The pointer is valid, and has ownership over the instance of `struct firmware`. +/// +/// The `Firmware`'s backing buffer is not modified. 
+///
+/// # Examples
+///
+/// ```no_run
+/// # use kernel::{device::Device, firmware::Firmware};
+///
+/// # fn no_run() -> Result<(), Error> {
+/// # // SAFETY: *NOT* safe, just for the example to get an `ARef<Device>` instance
+/// # let dev = unsafe { Device::get_device(core::ptr::null_mut()) };
+///
+/// let fw = Firmware::request(c"path/to/firmware.bin", &dev)?;
+/// let blob = fw.data();
+///
+/// # Ok(())
+/// # }
+/// ```
+pub struct Firmware(NonNull<bindings::firmware>);
+
+impl Firmware {
+    fn request_internal(name: &CStr, dev: &Device, func: FwFunc) -> Result<Self> {
+        let mut fw: *mut bindings::firmware = core::ptr::null_mut();
+        let pfw: *mut *mut bindings::firmware = &mut fw;
+        let pfw: *mut *const bindings::firmware = pfw.cast();
+
+        // SAFETY: `pfw` is a valid pointer to a NULL initialized `bindings::firmware` pointer.
+        // `name` and `dev` are valid as by their type invariants.
+        let ret = unsafe { func.0(pfw, name.as_char_ptr(), dev.as_raw()) };
+        if ret != 0 {
+            return Err(Error::from_errno(ret));
+        }
+
+        // SAFETY: `func` not bailing out with a non-zero error code guarantees that `fw` is a
+        // valid pointer to `bindings::firmware`.
+        Ok(Firmware(unsafe { NonNull::new_unchecked(fw) }))
+    }
+
+    /// Send a firmware request and wait for it. See also `bindings::request_firmware`.
+    pub fn request(name: &CStr, dev: &Device) -> Result<Self> {
+        Self::request_internal(name, dev, FwFunc::request())
+    }
+
+    /// Send a request for an optional firmware module. See also
+    /// `bindings::firmware_request_nowarn`.
+    pub fn request_nowarn(name: &CStr, dev: &Device) -> Result<Self> {
+        Self::request_internal(name, dev, FwFunc::request_nowarn())
+    }
+
+    fn as_raw(&self) -> *mut bindings::firmware {
+        self.0.as_ptr()
+    }
+
+    /// Returns the size of the requested firmware in bytes.
+    pub fn size(&self) -> usize {
+        // SAFETY: `self.as_raw()` is valid by the type invariant.
+        unsafe { (*self.as_raw()).size }
+    }
+
+    /// Returns the requested firmware as `&[u8]`.
+    pub fn data(&self) -> &[u8] {
+        // SAFETY: `self.as_raw()` is valid by the type invariant. Additionally,
+        // `bindings::firmware` guarantees, if successfully requested, that
+        // `bindings::firmware::data` has a size of `bindings::firmware::size` bytes.
+        unsafe { core::slice::from_raw_parts((*self.as_raw()).data, self.size()) }
+    }
+}
+
+impl Drop for Firmware {
+    fn drop(&mut self) {
+        // SAFETY: `self.as_raw()` is valid by the type invariant.
+        unsafe { bindings::release_firmware(self.as_raw()) };
+    }
+}
+
+// SAFETY: `Firmware` only holds a pointer to a C `struct firmware`, which is safe to be used from
+// any thread.
+unsafe impl Send for Firmware {}
+
+// SAFETY: `Firmware` only holds a pointer to a C `struct firmware`, references to which are safe to
+// be used from any thread.
+unsafe impl Sync for Firmware {}
+
+/// Create firmware .modinfo entries.
+///
+/// This macro is the counterpart of the C macro `MODULE_FIRMWARE()`, but instead of taking
+/// simple string literals, which is already covered by the `firmware` field of
+/// [`crate::prelude::module!`], it allows the caller to pass a builder type, based on
+/// [`ModInfoBuilder`], which can create the firmware modinfo strings in a more flexible way.
+///
+/// Drivers should extend [`ModInfoBuilder`] with their own driver-specific builder type.
+///
+/// The `builder` argument must be a type which implements the following function.
+///
+/// `const fn create(module_name: &'static CStr) -> ModInfoBuilder`
+///
+/// `create` should pass the `module_name` to the [`ModInfoBuilder`] and, with its help,
+/// construct the corresponding firmware modinfo.
+///
+/// Typically, such contracts would be enforced by a trait; however, traits do not (yet) support
+/// const functions.
+///
+/// # Examples
+///
+/// ```
+/// # mod module_firmware_test {
+/// # use kernel::firmware;
+/// # use kernel::prelude::*;
+/// #
+/// # struct MyModule;
+/// #
+/// # impl kernel::Module for MyModule {
+/// #     fn init(_module: &'static ThisModule) -> Result<Self> {
+/// #         Ok(Self)
+/// #     }
+/// # }
+/// #
+/// #
+/// struct Builder<const N: usize>;
+///
+/// impl<const N: usize> Builder<N> {
+///     const DIR: &'static str = "vendor/chip/";
+///     const FILES: [&'static str; 3] = [ "foo", "bar", "baz" ];
+///
+///     const fn create(module_name: &'static kernel::str::CStr) -> firmware::ModInfoBuilder<N> {
+///         let mut builder = firmware::ModInfoBuilder::new(module_name);
+///
+///         let mut i = 0;
+///         while i < Self::FILES.len() {
+///             builder = builder.new_entry()
+///                 .push(Self::DIR)
+///                 .push(Self::FILES[i])
+///                 .push(".bin");
+///
+///             i += 1;
+///         }
+///
+///         builder
+///     }
+/// }
+///
+/// module! {
+///     type: MyModule,
+///     name: "module_firmware_test",
+///     authors: ["Rust for Linux"],
+///     description: "module_firmware! test module",
+///     license: "GPL",
+/// }
+///
+/// kernel::module_firmware!(Builder);
+/// # }
+/// ```
+#[macro_export]
+macro_rules! module_firmware {
+    // The argument is the builder type without the const generic, since it is inferred within
+    // this macro. Hence, we can neither use `expr` nor `ty`.
+    ($($builder:tt)*) => {
+        const _: () = {
+            const __MODULE_FIRMWARE_PREFIX: &'static $crate::str::CStr = if cfg!(MODULE) {
+                c""
+            } else {
+                <LocalModule as $crate::ModuleMetadata>::NAME
+            };
+
+            #[link_section = ".modinfo"]
+            #[used(compiler)]
+            static __MODULE_FIRMWARE: [u8; $($builder)*::create(__MODULE_FIRMWARE_PREFIX)
+                .build_length()] = $($builder)*::create(__MODULE_FIRMWARE_PREFIX).build();
+        };
+    };
+}
+
+/// Builder for firmware module info.
+///
+/// [`ModInfoBuilder`] is a helper component to flexibly compose firmware path strings for the
+/// .modinfo section in const context.
+///
+/// Therefore the [`ModInfoBuilder`] provides the methods [`ModInfoBuilder::new_entry`] and
+/// [`ModInfoBuilder::push`], where the latter is used to push path components and the former to
+/// mark the beginning of a new path string.
+///
+/// [`ModInfoBuilder`] is meant to be used in combination with [`kernel::module_firmware!`].
+///
+/// The const generic `N` as well as the `module_name` parameter of [`ModInfoBuilder::new`] are
+/// internal implementation details and are supplied through the above macro.
+pub struct ModInfoBuilder<const N: usize> {
+    buf: [u8; N],
+    n: usize,
+    module_name: &'static CStr,
+}
+
+impl<const N: usize> ModInfoBuilder<N> {
+    /// Create an empty builder instance.
+    pub const fn new(module_name: &'static CStr) -> Self {
+        Self {
+            buf: [0; N],
+            n: 0,
+            module_name,
+        }
+    }
+
+    const fn push_internal(mut self, bytes: &[u8]) -> Self {
+        let mut j = 0;
+
+        if N == 0 {
+            self.n += bytes.len();
+            return self;
+        }
+
+        while j < bytes.len() {
+            if self.n < N {
+                self.buf[self.n] = bytes[j];
+            }
+            self.n += 1;
+            j += 1;
+        }
+        self
+    }
+
+    /// Push an additional path component.
+    ///
+    /// Append path components to the [`ModInfoBuilder`] instance.
Paths need to be separated
+    /// with [`ModInfoBuilder::new_entry`].
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::firmware::ModInfoBuilder;
+    ///
+    /// # const DIR: &str = "vendor/chip/";
+    /// # const fn no_run<const N: usize>(builder: ModInfoBuilder<N>) {
+    /// let builder = builder.new_entry()
+    ///     .push(DIR)
+    ///     .push("foo.bin")
+    ///     .new_entry()
+    ///     .push(DIR)
+    ///     .push("bar.bin");
+    /// # }
+    /// ```
+    pub const fn push(self, s: &str) -> Self {
+        // Check whether there has been an initial call to `new_entry()`.
+        if N != 0 && self.n == 0 {
+            crate::build_error!("Must call new_entry() before push().");
+        }
+
+        self.push_internal(s.as_bytes())
+    }
+
+    const fn push_module_name(self) -> Self {
+        let mut this = self;
+        let module_name = this.module_name;
+
+        if !this.module_name.is_empty() {
+            this = this.push_internal(module_name.to_bytes_with_nul());
+
+            if N != 0 {
+                // Re-use the space taken by the NULL terminator and swap it with the '.'
+                // separator.
+                this.buf[this.n - 1] = b'.';
+            }
+        }
+
+        this
+    }
+
+    /// Prepare the [`ModInfoBuilder`] for the next entry.
+    ///
+    /// This method acts as a separator between module firmware path entries.
+    ///
+    /// Must be called before constructing a new entry with subsequent calls to
+    /// [`ModInfoBuilder::push`].
+    ///
+    /// See [`ModInfoBuilder::push`] for an example.
+    pub const fn new_entry(self) -> Self {
+        self.push_internal(b"\0")
+            .push_module_name()
+            .push_internal(b"firmware=")
+    }
+
+    /// Build the byte array.
+    pub const fn build(self) -> [u8; N] {
+        // Add the final NULL terminator.
+        let this = self.push_internal(b"\0");
+
+        if this.n == N {
+            this.buf
+        } else {
+            crate::build_error!("Length mismatch.");
+        }
+    }
+}
+
+impl ModInfoBuilder<0> {
+    /// Return the length of the byte array to build.
+    pub const fn build_length(self) -> usize {
+        // Compensate for the NULL terminator added by `build`.
+        self.n + 1
+    }
+}
diff --git a/rust/kernel/fmt.rs b/rust/kernel/fmt.rs
new file mode 100644
index 000000000000..84d634201d90
--- /dev/null
+++ b/rust/kernel/fmt.rs
@@ -0,0 +1,92 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Formatting utilities.
+//!
+//! This module is intended to be used in place of `core::fmt` in kernel code.
+
+pub use core::fmt::{Arguments, Debug, Error, Formatter, Result, Write};
+
+/// Internal adapter used to allow implementations of formatting traits for foreign types.
+///
+/// It is inserted automatically by the [`fmt!`] macro and is not meant to be used directly.
+///
+/// [`fmt!`]: crate::prelude::fmt!
+#[doc(hidden)]
+pub struct Adapter<T>(pub T);
+
+macro_rules! impl_fmt_adapter_forward {
+    ($($trait:ident),* $(,)?) => {
+        $(
+            impl<T: $trait> $trait for Adapter<T> {
+                fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+                    let Self(t) = self;
+                    $trait::fmt(t, f)
+                }
+            }
+        )*
+    };
+}
+
+use core::fmt::{Binary, LowerExp, LowerHex, Octal, Pointer, UpperExp, UpperHex};
+impl_fmt_adapter_forward!(Debug, LowerHex, UpperHex, Octal, Binary, Pointer, LowerExp, UpperExp);
+
+/// A copy of [`core::fmt::Display`] that allows us to implement it for foreign types.
+///
+/// Types should implement this trait rather than [`core::fmt::Display`]. Together with the
+/// [`Adapter`] type and [`fmt!`] macro, it allows for formatting foreign types (e.g. types from
+/// core) which do not implement [`core::fmt::Display`] directly.
+///
+/// [`fmt!`]: crate::prelude::fmt!
+pub trait Display {
+    /// Same as [`core::fmt::Display::fmt`].
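+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of implementing this trait for a custom type; `Celsius` is a
+    /// hypothetical example type:
+    ///
+    /// ```ignore
+    /// use kernel::fmt::{self, Display, Formatter};
+    ///
+    /// struct Celsius(i32);
+    ///
+    /// impl Display for Celsius {
+    ///     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+    ///         // Delegate to the `core` formatting machinery for the inner integer.
+    ///         write!(f, "{} degrees Celsius", self.0)
+    ///     }
+    /// }
+    /// ```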
+ fn fmt(&self, f: &mut Formatter<'_>) -> Result; +} + +impl<T: ?Sized + Display> Display for &T { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + Display::fmt(*self, f) + } +} + +impl<T: ?Sized + Display> core::fmt::Display for Adapter<&T> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let Self(t) = self; + Display::fmt(t, f) + } +} + +macro_rules! impl_display_forward { + ($( + $( { $($generics:tt)* } )? $ty:ty $( { where $($where:tt)* } )? + ),* $(,)?) => { + $( + impl$($($generics)*)? Display for $ty $(where $($where)*)? { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + core::fmt::Display::fmt(self, f) + } + } + )* + }; +} + +impl_display_forward!( + bool, + char, + core::panic::PanicInfo<'_>, + Arguments<'_>, + i128, + i16, + i32, + i64, + i8, + isize, + str, + u128, + u16, + u32, + u64, + u8, + usize, + {<T: ?Sized>} crate::sync::Arc<T> {where crate::sync::Arc<T>: core::fmt::Display}, + {<T: ?Sized>} crate::sync::UniqueArc<T> {where crate::sync::UniqueArc<T>: core::fmt::Display}, +); diff --git a/rust/kernel/fs.rs b/rust/kernel/fs.rs new file mode 100644 index 000000000000..6ba6bdf143cb --- /dev/null +++ b/rust/kernel/fs.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Kernel file systems. +//! +//! C headers: [`include/linux/fs.h`](srctree/include/linux/fs.h) + +pub mod file; +pub use self::file::{File, LocalFile}; + +mod kiocb; +pub use self::kiocb::Kiocb; diff --git a/rust/kernel/fs/file.rs b/rust/kernel/fs/file.rs new file mode 100644 index 000000000000..23ee689bd240 --- /dev/null +++ b/rust/kernel/fs/file.rs @@ -0,0 +1,473 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Files and file descriptors. +//! +//! C headers: [`include/linux/fs.h`](srctree/include/linux/fs.h) and +//! [`include/linux/file.h`](srctree/include/linux/file.h) + +use crate::{ + bindings, + cred::Credential, + error::{code::*, to_result, Error, Result}, + fmt, + sync::aref::{ARef, AlwaysRefCounted}, + types::{NotThreadSafe, Opaque}, +}; +use core::ptr; + +/// Primitive type representing the offset within a [`File`]. +/// +/// Type alias for `bindings::loff_t`. +pub type Offset = bindings::loff_t; + +/// Flags associated with a [`File`]. +pub mod flags { + /// File is opened in append mode. + pub const O_APPEND: u32 = bindings::O_APPEND; + + /// Signal-driven I/O is enabled. + pub const O_ASYNC: u32 = bindings::FASYNC; + + /// Close-on-exec flag is set. + pub const O_CLOEXEC: u32 = bindings::O_CLOEXEC; + + /// File was created if it didn't already exist. + pub const O_CREAT: u32 = bindings::O_CREAT; + + /// Direct I/O is enabled for this file. + pub const O_DIRECT: u32 = bindings::O_DIRECT; + + /// File must be a directory. + pub const O_DIRECTORY: u32 = bindings::O_DIRECTORY; + + /// Like [`O_SYNC`] except metadata is not synced. + pub const O_DSYNC: u32 = bindings::O_DSYNC; + + /// Ensure that this file is created with the `open(2)` call. + pub const O_EXCL: u32 = bindings::O_EXCL; + + /// Large file size enabled (`off64_t` over `off_t`). + pub const O_LARGEFILE: u32 = bindings::O_LARGEFILE; + + /// Do not update the file last access time. + pub const O_NOATIME: u32 = bindings::O_NOATIME; + + /// File should not be used as process's controlling terminal. + pub const O_NOCTTY: u32 = bindings::O_NOCTTY; + + /// If basename of path is a symbolic link, fail open. + pub const O_NOFOLLOW: u32 = bindings::O_NOFOLLOW; + + /// File is using nonblocking I/O. + pub const O_NONBLOCK: u32 = bindings::O_NONBLOCK; + + /// File is using nonblocking I/O. 
+ /// + /// This is effectively the same flag as [`O_NONBLOCK`] on all architectures + /// except SPARC64. + pub const O_NDELAY: u32 = bindings::O_NDELAY; + + /// Used to obtain a path file descriptor. + pub const O_PATH: u32 = bindings::O_PATH; + + /// Write operations on this file will flush data and metadata. + pub const O_SYNC: u32 = bindings::O_SYNC; + + /// This file is an unnamed temporary regular file. + pub const O_TMPFILE: u32 = bindings::O_TMPFILE; + + /// File should be truncated to length 0. + pub const O_TRUNC: u32 = bindings::O_TRUNC; + + /// Bitmask for access mode flags. + /// + /// # Examples + /// + /// ``` + /// use kernel::fs::file; + /// # fn do_something() {} + /// # let flags = 0; + /// if (flags & file::flags::O_ACCMODE) == file::flags::O_RDONLY { + /// do_something(); + /// } + /// ``` + pub const O_ACCMODE: u32 = bindings::O_ACCMODE; + + /// File is read only. + pub const O_RDONLY: u32 = bindings::O_RDONLY; + + /// File is write only. + pub const O_WRONLY: u32 = bindings::O_WRONLY; + + /// File can be both read and written. + pub const O_RDWR: u32 = bindings::O_RDWR; +} + +/// Wraps the kernel's `struct file`. Thread safe. +/// +/// This represents an open file rather than a file on a filesystem. Processes generally reference +/// open files using file descriptors. However, file descriptors are not the same as files. A file +/// descriptor is just an integer that corresponds to a file, and a single file may be referenced +/// by multiple file descriptors. +/// +/// # Refcounting +/// +/// Instances of this type are reference-counted. The reference count is incremented by the +/// `fget`/`get_file` functions and decremented by `fput`. The Rust type `ARef<File>` represents a +/// pointer that owns a reference count on the file. +/// +/// Whenever a process opens a file descriptor (fd), it stores a pointer to the file in its fd +/// table (`struct files_struct`). This pointer owns a reference count to the file, ensuring the +/// file isn't prematurely deleted while the file descriptor is open. In Rust terminology, the +/// pointers in `struct files_struct` are `ARef<File>` pointers. +/// +/// ## Light refcounts +/// +/// Whenever a process has an fd to a file, it may use something called a "light refcount" as a +/// performance optimization. Light refcounts are acquired by calling `fdget` and released with +/// `fdput`. The idea behind light refcounts is that if the fd is not closed between the calls to +/// `fdget` and `fdput`, then the refcount cannot hit zero during that time, as the `struct +/// files_struct` holds a reference until the fd is closed. This means that it's safe to access the +/// file even if `fdget` does not increment the refcount. +/// +/// The requirement that the fd is not closed during a light refcount applies globally across all +/// threads - not just on the thread using the light refcount. For this reason, light refcounts are +/// only used when the `struct files_struct` is not shared with other threads, since this ensures +/// that other unrelated threads cannot suddenly start using the fd and close it. Therefore, +/// calling `fdget` on a shared `struct files_struct` creates a normal refcount instead of a light +/// refcount. +/// +/// Light reference counts must be released with `fdput` before the system call returns to +/// userspace. This means that if you wait until the current system call returns to userspace, then +/// all light refcounts that existed at the time have gone away. 
+/// +/// ### The file position +/// +/// Each `struct file` has a position integer, which is protected by the `f_pos_lock` mutex. +/// However, if the `struct file` is not shared, then the kernel may avoid taking the lock as a +/// performance optimization. +/// +/// The condition for avoiding the `f_pos_lock` mutex is different from the condition for using +/// `fdget`. With `fdget`, you may avoid incrementing the refcount as long as the current fd table +/// is not shared; it is okay if there are other fd tables that also reference the same `struct +/// file`. However, `fdget_pos` can only avoid taking the `f_pos_lock` if the entire `struct file` +/// is not shared, as different processes with an fd to the same `struct file` share the same +/// position. +/// +/// To represent files that are not thread safe due to this optimization, the [`LocalFile`] type is +/// used. +/// +/// ## Rust references +/// +/// The reference type `&File` is similar to light refcounts: +/// +/// * `&File` references don't own a reference count. They can only exist as long as the reference +/// count stays positive, and can only be created when there is some mechanism in place to ensure +/// this. +/// +/// * The Rust borrow-checker normally ensures this by enforcing that the `ARef<File>` from which +/// a `&File` is created outlives the `&File`. +/// +/// * Using the unsafe [`File::from_raw_file`] means that it is up to the caller to ensure that the +/// `&File` only exists while the reference count is positive. +/// +/// * You can think of `fdget` as using an fd to look up an `ARef<File>` in the `struct +/// files_struct` and create an `&File` from it. The "fd cannot be closed" rule is like the Rust +/// rule "the `ARef<File>` must outlive the `&File`". +/// +/// # Invariants +/// +/// * All instances of this type are refcounted using the `f_count` field. +/// * There must not be any active calls to `fdget_pos` on this file that did not take the +/// `f_pos_lock` mutex. +#[repr(transparent)] +pub struct File { + inner: Opaque<bindings::file>, +} + +// SAFETY: This file is known to not have any active `fdget_pos` calls that did not take the +// `f_pos_lock` mutex, so it is safe to transfer it between threads. +unsafe impl Send for File {} + +// SAFETY: This file is known to not have any active `fdget_pos` calls that did not take the +// `f_pos_lock` mutex, so it is safe to access its methods from several threads in parallel. +unsafe impl Sync for File {} + +// SAFETY: The type invariants guarantee that `File` is always ref-counted. This implementation +// makes `ARef<File>` own a normal refcount. +unsafe impl AlwaysRefCounted for File { + #[inline] + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_file(self.as_ptr()) }; + } + + #[inline] + unsafe fn dec_ref(obj: ptr::NonNull<File>) { + // SAFETY: To call this method, the caller passes us ownership of a normal refcount, so we + // may drop it. The cast is okay since `File` has the same representation as `struct file`. + unsafe { bindings::fput(obj.cast().as_ptr()) } + } +} + +/// Wraps the kernel's `struct file`. Not thread safe. +/// +/// This type represents a file that is not known to be safe to transfer across thread boundaries. +/// To obtain a thread-safe [`File`], use the [`assume_no_fdget_pos`] conversion. +/// +/// See the documentation for [`File`] for more information. 
+/// +/// # Invariants +/// +/// * All instances of this type are refcounted using the `f_count` field. +/// * If there is an active call to `fdget_pos` that did not take the `f_pos_lock` mutex, then it +/// must be on the same thread as this file. +/// +/// [`assume_no_fdget_pos`]: LocalFile::assume_no_fdget_pos +#[repr(transparent)] +pub struct LocalFile { + inner: Opaque<bindings::file>, +} + +// SAFETY: The type invariants guarantee that `LocalFile` is always ref-counted. This implementation +// makes `ARef<LocalFile>` own a normal refcount. +unsafe impl AlwaysRefCounted for LocalFile { + #[inline] + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_file(self.as_ptr()) }; + } + + #[inline] + unsafe fn dec_ref(obj: ptr::NonNull<LocalFile>) { + // SAFETY: To call this method, the caller passes us ownership of a normal refcount, so we + // may drop it. The cast is okay since `LocalFile` has the same representation as + // `struct file`. + unsafe { bindings::fput(obj.cast().as_ptr()) } + } +} + +impl LocalFile { + /// Constructs a new `struct file` wrapper from a file descriptor. + /// + /// The file descriptor belongs to the current process, and there might be active local calls + /// to `fdget_pos` on the same file. + /// + /// To obtain an `ARef<File>`, use the [`assume_no_fdget_pos`] function to convert. + /// + /// [`assume_no_fdget_pos`]: LocalFile::assume_no_fdget_pos + #[inline] + pub fn fget(fd: u32) -> Result<ARef<LocalFile>, BadFdError> { + // SAFETY: FFI call, there are no requirements on `fd`. + let ptr = ptr::NonNull::new(unsafe { bindings::fget(fd) }).ok_or(BadFdError)?; + + // SAFETY: `bindings::fget` created a refcount, and we pass ownership of it to the `ARef`. + // + // INVARIANT: This file is in the fd table on this thread, so either all `fdget_pos` calls + // are on this thread, or the file is shared, in which case `fdget_pos` calls took the + // `f_pos_lock` mutex. + Ok(unsafe { ARef::from_raw(ptr.cast()) }) + } + + /// Creates a reference to a [`LocalFile`] from a valid pointer. + /// + /// # Safety + /// + /// * The caller must ensure that `ptr` points at a valid file and that the file's refcount is + /// positive for the duration of `'a`. + /// * The caller must ensure that if there is an active call to `fdget_pos` that did not take + /// the `f_pos_lock` mutex, then that call is on the current thread. + #[inline] + pub unsafe fn from_raw_file<'a>(ptr: *const bindings::file) -> &'a LocalFile { + // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the + // duration of `'a`. The cast is okay because `LocalFile` is `repr(transparent)`. + // + // INVARIANT: The caller guarantees that there are no problematic `fdget_pos` calls. + unsafe { &*ptr.cast() } + } + + /// Assume that there are no active `fdget_pos` calls that prevent us from sharing this file. + /// + /// This makes it safe to transfer this file to other threads. No checks are performed, and + /// using it incorrectly may lead to a data race on the file position if the file is shared + /// with another thread. + /// + /// This method is intended to be used together with [`LocalFile::fget`] when the caller knows + /// statically that there are no `fdget_pos` calls on the current thread. For example, you + /// might use it when calling `fget` from an ioctl, since ioctls usually do not touch the file + /// position. 
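+    ///
+    /// # Examples
+    ///
+    /// A sketch of the intended pattern, assuming the surrounding code runs in an ioctl
+    /// handler that never touches the file position:
+    ///
+    /// ```ignore
+    /// let local = LocalFile::fget(fd)?;
+    /// // SAFETY: We are in an ioctl handler, which does not use `fdget_pos`, so there are no
+    /// // active `fdget_pos` calls on the current thread.
+    /// let file = unsafe { LocalFile::assume_no_fdget_pos(local) };
+    /// ```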
+ /// + /// # Safety + /// + /// There must not be any active `fdget_pos` calls on the current thread. + #[inline] + pub unsafe fn assume_no_fdget_pos(me: ARef<LocalFile>) -> ARef<File> { + // INVARIANT: There are no `fdget_pos` calls on the current thread, and by the type + // invariants, if there is a `fdget_pos` call on another thread, then it took the + // `f_pos_lock` mutex. + // + // SAFETY: `LocalFile` and `File` have the same layout. + unsafe { ARef::from_raw(ARef::into_raw(me).cast()) } + } + + /// Returns a raw pointer to the inner C struct. + #[inline] + pub fn as_ptr(&self) -> *mut bindings::file { + self.inner.get() + } + + /// Returns the credentials of the task that originally opened the file. + pub fn cred(&self) -> &Credential { + // SAFETY: It's okay to read the `f_cred` field without synchronization because `f_cred` is + // never changed after initialization of the file. + let ptr = unsafe { (*self.as_ptr()).f_cred }; + + // SAFETY: The signature of this function ensures that the caller will only access the + // returned credential while the file is still valid, and the C side ensures that the + // credential stays valid at least as long as the file. + unsafe { Credential::from_ptr(ptr) } + } + + /// Returns the flags associated with the file. + /// + /// The flags are a combination of the constants in [`flags`]. + #[inline] + pub fn flags(&self) -> u32 { + // This `read_volatile` is intended to correspond to a READ_ONCE call. + // + // SAFETY: The file is valid because the shared reference guarantees a nonzero refcount. + // + // FIXME(read_once): Replace with `read_once` when available on the Rust side. + unsafe { core::ptr::addr_of!((*self.as_ptr()).f_flags).read_volatile() } + } +} + +impl File { + /// Creates a reference to a [`File`] from a valid pointer. + /// + /// # Safety + /// + /// * The caller must ensure that `ptr` points at a valid file and that the file's refcount is + /// positive for the duration of `'a`. + /// * The caller must ensure that if there are active `fdget_pos` calls on this file, then they + /// took the `f_pos_lock` mutex. + #[inline] + pub unsafe fn from_raw_file<'a>(ptr: *const bindings::file) -> &'a File { + // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the + // duration of `'a`. The cast is okay because `File` is `repr(transparent)`. + // + // INVARIANT: The caller guarantees that there are no problematic `fdget_pos` calls. + unsafe { &*ptr.cast() } + } +} + +// Make LocalFile methods available on File. +impl core::ops::Deref for File { + type Target = LocalFile; + #[inline] + fn deref(&self) -> &LocalFile { + // SAFETY: The caller provides a `&File`, and since it is a reference, it must point at a + // valid file for the desired duration. + // + // By the type invariants, there are no `fdget_pos` calls that did not take the + // `f_pos_lock` mutex. + unsafe { LocalFile::from_raw_file(core::ptr::from_ref(self).cast()) } + } +} + +/// A file descriptor reservation. +/// +/// This allows the creation of a file descriptor in two steps: first, we reserve a slot for it, +/// then we commit or drop the reservation. The first step may fail (e.g., the current process ran +/// out of available slots), but commit and drop never fail (and are mutually exclusive). +/// +/// Dropping the reservation happens in the destructor of this type. +/// +/// # Invariants +/// +/// The fd stored in this struct must correspond to a reserved file descriptor of the current task. 
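+///
+/// # Examples
+///
+/// A sketch of the reserve-then-commit pattern; where the `ARef<File>` comes from is up to the
+/// caller:
+///
+/// ```ignore
+/// fn install_file(file: ARef<File>) -> Result<u32> {
+///     let reservation =
+///         FileDescriptorReservation::get_unused_fd_flags(kernel::fs::file::flags::O_CLOEXEC)?;
+///     let fd = reservation.reserved_fd();
+///     // Committing never fails and transfers ownership of `file` to the fd table.
+///     reservation.fd_install(file);
+///     Ok(fd)
+/// }
+/// ```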
+pub struct FileDescriptorReservation {
+    fd: u32,
+    /// Prevent values of this type from being moved to a different task.
+    ///
+    /// The `fd_install` and `put_unused_fd` functions assume that the value of `current` is
+    /// unchanged since the call to `get_unused_fd_flags`. By adding this marker to this type, we
+    /// prevent it from being moved across task boundaries, which ensures that `current` does not
+    /// change while this value exists.
+    _not_send: NotThreadSafe,
+}
+
+impl FileDescriptorReservation {
+    /// Creates a new file descriptor reservation.
+    #[inline]
+    pub fn get_unused_fd_flags(flags: u32) -> Result<Self> {
+        // SAFETY: FFI call, there are no safety requirements on `flags`.
+        let fd: i32 = unsafe { bindings::get_unused_fd_flags(flags) };
+        to_result(fd)?;
+
+        Ok(Self {
+            fd: fd as u32,
+            _not_send: NotThreadSafe,
+        })
+    }
+
+    /// Returns the file descriptor number that was reserved.
+    #[inline]
+    pub fn reserved_fd(&self) -> u32 {
+        self.fd
+    }
+
+    /// Commits the reservation.
+    ///
+    /// The previously reserved file descriptor is bound to `file`. This method consumes the
+    /// [`FileDescriptorReservation`], so it will not be usable after this call.
+    #[inline]
+    pub fn fd_install(self, file: ARef<File>) {
+        // SAFETY: `self.fd` was previously returned by `get_unused_fd_flags`. We have not yet used
+        // the fd, so it is still valid, and `current` still refers to the same task, as this type
+        // cannot be moved across task boundaries.
+        //
+        // Furthermore, the file pointer is guaranteed to own a refcount by its type invariants,
+        // and we take ownership of that refcount by not running the destructor below.
+        // Additionally, the file is known to not have any non-shared `fdget_pos` calls, so even if
+        // this process starts using the file position, this will not result in a data race on the
+        // file position.
+        unsafe { bindings::fd_install(self.fd, file.as_ptr()) };
+
+        // `fd_install` consumes both the file descriptor and the file reference, so we cannot run
+        // the destructors.
+        core::mem::forget(self);
+        core::mem::forget(file);
+    }
+}
+
+impl Drop for FileDescriptorReservation {
+    #[inline]
+    fn drop(&mut self) {
+        // SAFETY: By the type invariants of this type, `self.fd` was previously returned by
+        // `get_unused_fd_flags`. We have not yet used the fd, so it is still valid, and `current`
+        // still refers to the same task, as this type cannot be moved across task boundaries.
+        unsafe { bindings::put_unused_fd(self.fd) };
+    }
+}
+
+/// Represents the [`EBADF`] error code.
+///
+/// Used for methods that can only fail with [`EBADF`].
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct BadFdError;
+
+impl From<BadFdError> for Error {
+    #[inline]
+    fn from(_: BadFdError) -> Error {
+        EBADF
+    }
+}
+
+impl fmt::Debug for BadFdError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.pad("EBADF")
+    }
+}
diff --git a/rust/kernel/fs/kiocb.rs b/rust/kernel/fs/kiocb.rs
new file mode 100644
index 000000000000..84c936cd69b0
--- /dev/null
+++ b/rust/kernel/fs/kiocb.rs
@@ -0,0 +1,68 @@
+// SPDX-License-Identifier: GPL-2.0
+
+// Copyright (C) 2024 Google LLC.
+
+//! Kernel IO callbacks.
+//!
+//! C headers: [`include/linux/fs.h`](srctree/include/linux/fs.h)
+
+use core::marker::PhantomData;
+use core::ptr::NonNull;
+use kernel::types::ForeignOwnable;
+
+/// Wrapper for the kernel's `struct kiocb`.
+///
+/// Currently this abstraction is incomplete and is essentially just a tuple containing a
+/// reference to a file and a file position.
+/// +/// The type `T` represents the filesystem or driver specific data associated with the file. +/// +/// # Invariants +/// +/// `inner` points at a valid `struct kiocb` whose file has the type `T` as its private data. +pub struct Kiocb<'a, T> { + inner: NonNull<bindings::kiocb>, + _phantom: PhantomData<&'a T>, +} + +impl<'a, T: ForeignOwnable> Kiocb<'a, T> { + /// Create a `Kiocb` from a raw pointer. + /// + /// # Safety + /// + /// The pointer must reference a valid `struct kiocb` for the duration of `'a`. The private + /// data of the file must be `T`. + pub unsafe fn from_raw(kiocb: *mut bindings::kiocb) -> Self { + Self { + // SAFETY: If a pointer is valid it is not null. + inner: unsafe { NonNull::new_unchecked(kiocb) }, + _phantom: PhantomData, + } + } + + /// Access the underlying `struct kiocb` directly. + pub fn as_raw(&self) -> *mut bindings::kiocb { + self.inner.as_ptr() + } + + /// Get the filesystem or driver specific data associated with the file. + pub fn file(&self) -> <T as ForeignOwnable>::Borrowed<'a> { + // SAFETY: We have shared access to this kiocb and hence the underlying file, so we can + // read the file's private data. + let private = unsafe { (*(*self.as_raw()).ki_filp).private_data }; + // SAFETY: The kiocb has shared access to the private data. + unsafe { <T as ForeignOwnable>::borrow(private) } + } + + /// Gets the current value of `ki_pos`. + pub fn ki_pos(&self) -> i64 { + // SAFETY: We have shared access to the kiocb, so we can read its `ki_pos` field. + unsafe { (*self.as_raw()).ki_pos } + } + + /// Gets a mutable reference to the `ki_pos` field. + pub fn ki_pos_mut(&mut self) -> &mut i64 { + // SAFETY: We have exclusive access to the kiocb, so we can write to `ki_pos`. + unsafe { &mut (*self.as_raw()).ki_pos } + } +} diff --git a/rust/kernel/generated_arch_reachable_asm.rs.S b/rust/kernel/generated_arch_reachable_asm.rs.S new file mode 100644 index 000000000000..3886a9ad3a99 --- /dev/null +++ b/rust/kernel/generated_arch_reachable_asm.rs.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include <linux/bug.h> + +// Cut here. + +::kernel::concat_literals!(ARCH_WARN_REACHABLE) diff --git a/rust/kernel/generated_arch_static_branch_asm.rs.S b/rust/kernel/generated_arch_static_branch_asm.rs.S new file mode 100644 index 000000000000..2afb638708db --- /dev/null +++ b/rust/kernel/generated_arch_static_branch_asm.rs.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include <linux/jump_label.h> + +// Cut here. + +::kernel::concat_literals!(ARCH_STATIC_BRANCH_ASM("{symb} + {off} + {branch}", "{l_yes}")) diff --git a/rust/kernel/generated_arch_warn_asm.rs.S b/rust/kernel/generated_arch_warn_asm.rs.S new file mode 100644 index 000000000000..409eb4c2d3a1 --- /dev/null +++ b/rust/kernel/generated_arch_warn_asm.rs.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include <linux/bug.h> + +// Cut here. + +::kernel::concat_literals!(ARCH_WARN_ASM("{file}", "{line}", "{flags}", "{size}")) diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs new file mode 100644 index 000000000000..491e6cc25cf4 --- /dev/null +++ b/rust/kernel/i2c.rs @@ -0,0 +1,586 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! I2C Driver subsystem + +// I2C Driver abstractions. 
+use crate::{
+    acpi,
+    container_of,
+    device,
+    device_id::{
+        RawDeviceId,
+        RawDeviceIdIndex, //
+    },
+    devres::Devres,
+    driver,
+    error::*,
+    of,
+    prelude::*,
+    types::{
+        AlwaysRefCounted,
+        Opaque, //
+    }, //
+};
+
+use core::{
+    marker::PhantomData,
+    mem::offset_of,
+    ptr::{
+        from_ref,
+        NonNull, //
+    }, //
+};
+
+use kernel::types::ARef;
+
+/// An I2C device id.
+#[repr(transparent)]
+#[derive(Clone, Copy)]
+pub struct DeviceId(bindings::i2c_device_id);
+
+impl DeviceId {
+    const I2C_NAME_SIZE: usize = 20;
+
+    /// Create a new device id from an I2C 'id' string.
+    #[inline(always)]
+    pub const fn new(id: &'static CStr) -> Self {
+        let src = id.to_bytes_with_nul();
+        build_assert!(src.len() <= Self::I2C_NAME_SIZE, "ID exceeds 20 bytes");
+        let mut i2c: bindings::i2c_device_id = pin_init::zeroed();
+        let mut i = 0;
+        while i < src.len() {
+            i2c.name[i] = src[i];
+            i += 1;
+        }
+
+        Self(i2c)
+    }
+}
+
+// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `i2c_device_id` and does not add
+// additional invariants, so it's safe to transmute to `RawType`.
+unsafe impl RawDeviceId for DeviceId {
+    type RawType = bindings::i2c_device_id;
+}
+
+// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field.
+unsafe impl RawDeviceIdIndex for DeviceId {
+    const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::i2c_device_id, driver_data);
+
+    fn index(&self) -> usize {
+        self.0.driver_data
+    }
+}
+
+/// IdTable type for I2C drivers.
+pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>;
+
+/// Create an I2C `IdTable` with its alias for modpost.
+#[macro_export]
+macro_rules! i2c_device_table {
+    ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => {
+        const $table_name: $crate::device_id::IdArray<
+            $crate::i2c::DeviceId,
+            $id_info_type,
+            { $table_data.len() },
+        > = $crate::device_id::IdArray::new($table_data);
+
+        $crate::module_device_table!("i2c", $module_table_name, $table_name);
+    };
+}
+
+/// An adapter for the registration of I2C drivers.
+pub struct Adapter<T: Driver>(T);
+
+// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if
+// a preceding call to `register` has been successful.
+unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> {
+    type RegType = bindings::i2c_driver;
+
+    unsafe fn register(
+        idrv: &Opaque<Self::RegType>,
+        name: &'static CStr,
+        module: &'static ThisModule,
+    ) -> Result {
+        build_assert!(
+            T::ACPI_ID_TABLE.is_some() || T::OF_ID_TABLE.is_some() || T::I2C_ID_TABLE.is_some(),
+            "At least one of ACPI/OF/Legacy tables must be present when registering an i2c driver"
+        );
+
+        let i2c_table = match T::I2C_ID_TABLE {
+            Some(table) => table.as_ptr(),
+            None => core::ptr::null(),
+        };
+
+        let of_table = match T::OF_ID_TABLE {
+            Some(table) => table.as_ptr(),
+            None => core::ptr::null(),
+        };
+
+        let acpi_table = match T::ACPI_ID_TABLE {
+            Some(table) => table.as_ptr(),
+            None => core::ptr::null(),
+        };
+
+        // SAFETY: It's safe to set the fields of `struct i2c_driver` on initialization.
+        unsafe {
+            (*idrv.get()).driver.name = name.as_char_ptr();
+            (*idrv.get()).probe = Some(Self::probe_callback);
+            (*idrv.get()).remove = Some(Self::remove_callback);
+            (*idrv.get()).shutdown = Some(Self::shutdown_callback);
+            (*idrv.get()).id_table = i2c_table;
+            (*idrv.get()).driver.of_match_table = of_table;
+            (*idrv.get()).driver.acpi_match_table = acpi_table;
+        }
+
+        // SAFETY: `idrv` is guaranteed to be a valid `RegType`.
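+        // `module.0` is the raw `struct module` pointer wrapped by `ThisModule`; C's
+        // `i2c_register_driver()` takes it as the owning module of this driver.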
+        to_result(unsafe { bindings::i2c_register_driver(module.0, idrv.get()) })
+    }
+
+    unsafe fn unregister(idrv: &Opaque<Self::RegType>) {
+        // SAFETY: `idrv` is guaranteed to be a valid `RegType`.
+        unsafe { bindings::i2c_del_driver(idrv.get()) }
+    }
+}
+
+impl<T: Driver + 'static> Adapter<T> {
+    extern "C" fn probe_callback(idev: *mut bindings::i2c_client) -> kernel::ffi::c_int {
+        // SAFETY: The I2C bus only ever calls the probe callback with a valid pointer to a
+        // `struct i2c_client`.
+        //
+        // INVARIANT: `idev` is valid for the duration of `probe_callback()`.
+        let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() };
+
+        let info =
+            Self::i2c_id_info(idev).or_else(|| <Self as driver::Adapter>::id_info(idev.as_ref()));
+
+        from_result(|| {
+            let data = T::probe(idev, info);
+
+            idev.as_ref().set_drvdata(data)?;
+            Ok(0)
+        })
+    }
+
+    extern "C" fn remove_callback(idev: *mut bindings::i2c_client) {
+        // SAFETY: `idev` is a valid pointer to a `struct i2c_client`.
+        let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() };
+
+        // SAFETY: `remove_callback` is only ever called after a successful call to
+        // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
+        // and stored a `Pin<KBox<T>>`.
+        let data = unsafe { idev.as_ref().drvdata_obtain::<T>() };
+
+        T::unbind(idev, data.as_ref());
+    }
+
+    extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) {
+        // SAFETY: `shutdown_callback` is only ever called for a valid `idev`.
+        let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() };
+
+        // SAFETY: `shutdown_callback` is only ever called after a successful call to
+        // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
+        // and stored a `Pin<KBox<T>>`.
+        let data = unsafe { idev.as_ref().drvdata_obtain::<T>() };
+
+        T::shutdown(idev, data.as_ref());
+    }
+
+    /// The [`i2c::IdTable`] of the corresponding driver.
+    fn i2c_id_table() -> Option<IdTable<<Self as driver::Adapter>::IdInfo>> {
+        T::I2C_ID_TABLE
+    }
+
+    /// Returns the driver's private data from the matching entry in the [`i2c::IdTable`], if any.
+    ///
+    /// If this returns `None`, it means there is no match with an entry in the [`i2c::IdTable`].
+    fn i2c_id_info(dev: &I2cClient) -> Option<&'static <Self as driver::Adapter>::IdInfo> {
+        let table = Self::i2c_id_table()?;
+
+        // SAFETY:
+        // - `table` has static lifetime, hence it's valid for reads
+        // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`.
+        let raw_id = unsafe { bindings::i2c_match_id(table.as_ptr(), dev.as_raw()) };
+
+        if raw_id.is_null() {
+            return None;
+        }
+
+        // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct i2c_device_id` and
+        // does not add additional invariants, so it's safe to transmute.
+        let id = unsafe { &*raw_id.cast::<DeviceId>() };
+
+        Some(table.info(<DeviceId as RawDeviceIdIndex>::index(id)))
+    }
+}
+
+impl<T: Driver + 'static> driver::Adapter for Adapter<T> {
+    type IdInfo = T::IdInfo;
+
+    fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> {
+        T::OF_ID_TABLE
+    }
+
+    fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>> {
+        T::ACPI_ID_TABLE
+    }
+}
+
+/// Declares a kernel module that exposes a single i2c driver.
+///
+/// # Examples
+///
+/// ```ignore
+/// kernel::module_i2c_driver! {
+///     type: MyDriver,
+///     name: "Module name",
+///     authors: ["Author name"],
+///     description: "Description",
+///     license: "GPL v2",
+/// }
+/// ```
+#[macro_export]
+macro_rules! module_i2c_driver {
+    ($($f:tt)*) => {
+        $crate::module_driver!(<T>, $crate::i2c::Adapter<T>, { $($f)* });
+    };
+}
+
+/// The i2c driver trait.
+///
+/// Drivers must implement this trait in order to get an i2c driver registered.
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::{acpi, bindings, c_str, device::Core, i2c, of};
+///
+/// struct MyDriver;
+///
+/// kernel::acpi_device_table!(
+///     ACPI_TABLE,
+///     MODULE_ACPI_TABLE,
+///     <MyDriver as i2c::Driver>::IdInfo,
+///     [
+///         (acpi::DeviceId::new(c_str!("LNUXBEEF")), ())
+///     ]
+/// );
+///
+/// kernel::i2c_device_table!(
+///     I2C_TABLE,
+///     MODULE_I2C_TABLE,
+///     <MyDriver as i2c::Driver>::IdInfo,
+///     [
+///         (i2c::DeviceId::new(c_str!("rust_driver_i2c")), ())
+///     ]
+/// );
+///
+/// kernel::of_device_table!(
+///     OF_TABLE,
+///     MODULE_OF_TABLE,
+///     <MyDriver as i2c::Driver>::IdInfo,
+///     [
+///         (of::DeviceId::new(c_str!("test,device")), ())
+///     ]
+/// );
+///
+/// impl i2c::Driver for MyDriver {
+///     type IdInfo = ();
+///     const I2C_ID_TABLE: Option<i2c::IdTable<Self::IdInfo>> = Some(&I2C_TABLE);
+///     const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE);
+///     const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE);
+///
+///     fn probe(
+///         _idev: &i2c::I2cClient<Core>,
+///         _id_info: Option<&Self::IdInfo>,
+///     ) -> impl PinInit<Self, Error> {
+///         Err(ENODEV)
+///     }
+///
+///     fn shutdown(_idev: &i2c::I2cClient<Core>, _this: Pin<&Self>) {
+///     }
+/// }
+/// ```
+pub trait Driver: Send {
+    /// The type holding information about each device id supported by the driver.
+    // TODO: Use `associated_type_defaults` once stabilized:
+    //
+    // ```
+    // type IdInfo: 'static = ();
+    // ```
+    type IdInfo: 'static;
+
+    /// The table of I2C device ids supported by the driver.
+    const I2C_ID_TABLE: Option<IdTable<Self::IdInfo>> = None;
+
+    /// The table of OF device ids supported by the driver.
+    const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None;
+
+    /// The table of ACPI device ids supported by the driver.
+    const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None;
+
+    /// I2C driver probe.
+    ///
+    /// Called when a new i2c client is added or discovered.
+    /// Implementers should attempt to initialize the client here.
+    fn probe(
+        dev: &I2cClient<device::Core>,
+        id_info: Option<&Self::IdInfo>,
+    ) -> impl PinInit<Self, Error>;
+
+    /// I2C driver shutdown.
+    ///
+    /// Called by the kernel during system reboot or power-off to allow the [`Driver`] to bring the
+    /// [`I2cClient`] into a safe state. Implementing this callback is optional.
+    ///
+    /// Typical actions include stopping transfers, disabling interrupts, or resetting the hardware
+    /// to prevent undesired behavior during shutdown.
+    ///
+    /// This callback is distinct from final resource cleanup, as the driver instance remains valid
+    /// after it returns. Any deallocation or teardown of driver-owned resources should instead be
+    /// handled in `Self::drop`.
+    fn shutdown(dev: &I2cClient<device::Core>, this: Pin<&Self>) {
+        let _ = (dev, this);
+    }
+
+    /// I2C driver unbind.
+    ///
+    /// Called when the [`I2cClient`] is unbound from its bound [`Driver`]. Implementing this
+    /// callback is optional.
+    ///
+    /// This callback serves as a place for drivers to perform teardown operations that require a
+    /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O
+    /// operations to gracefully tear down the device.
+    ///
+    /// Otherwise, release operations for driver resources should be performed in `Self::drop`.
+    fn unbind(dev: &I2cClient<device::Core>, this: Pin<&Self>) {
+        let _ = (dev, this);
+    }
+}
+
+/// The i2c adapter representation.
+///
+/// This structure represents the Rust abstraction for a C `struct i2c_adapter`. The
+/// implementation abstracts the usage of an existing C `struct i2c_adapter` that
+/// gets passed from the C side.
+///
+/// # Invariants
+///
+/// An [`I2cAdapter`] instance represents a valid `struct i2c_adapter` created by the C portion of
+/// the kernel.
+#[repr(transparent)]
+pub struct I2cAdapter<Ctx: device::DeviceContext = device::Normal>(
+    Opaque<bindings::i2c_adapter>,
+    PhantomData<Ctx>,
+);
+
+impl<Ctx: device::DeviceContext> I2cAdapter<Ctx> {
+    fn as_raw(&self) -> *mut bindings::i2c_adapter {
+        self.0.get()
+    }
+}
+
+impl I2cAdapter {
+    /// Returns the I2C Adapter index.
+    #[inline]
+    pub fn index(&self) -> i32 {
+        // SAFETY: `self.as_raw` is a valid pointer to a `struct i2c_adapter`.
+        unsafe { (*self.as_raw()).nr }
+    }
+
+    /// Gets a pointer to an `i2c_adapter` by index.
+    pub fn get(index: i32) -> Result<ARef<Self>> {
+        // SAFETY: The kernel guarantees that `i2c_get_adapter(index)` returns either a
+        // valid pointer or NULL; `NonNull::new` performs the required NULL check.
+        let adapter = NonNull::new(unsafe { bindings::i2c_get_adapter(index) }).ok_or(ENODEV)?;
+
+        // SAFETY: `adapter` is non-null and points to a live `i2c_adapter`.
+        // `I2cAdapter` is #[repr(transparent)], so this cast is valid.
+        Ok(unsafe { (&*adapter.as_ptr().cast::<I2cAdapter<device::Normal>>()).into() })
+    }
+}
+
+// SAFETY: `I2cAdapter` is a transparent wrapper of a type that doesn't depend on
+// `I2cAdapter`'s generic argument.
+kernel::impl_device_context_deref!(unsafe { I2cAdapter });
+kernel::impl_device_context_into_aref!(I2cAdapter);
+
+// SAFETY: Instances of `I2cAdapter` are always reference-counted.
+unsafe impl crate::types::AlwaysRefCounted for I2cAdapter {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+        unsafe { bindings::i2c_get_adapter(self.index()) };
+    }
+
+    unsafe fn dec_ref(obj: NonNull<Self>) {
+        // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+        unsafe { bindings::i2c_put_adapter(obj.as_ref().as_raw()) }
+    }
+}
+
+/// The i2c board info representation.
+///
+/// This structure represents the Rust abstraction for a C `struct i2c_board_info` structure,
+/// which is used for manual I2C client creation.
+#[repr(transparent)]
+pub struct I2cBoardInfo(bindings::i2c_board_info);
+
+impl I2cBoardInfo {
+    const I2C_TYPE_SIZE: usize = 20;
+    /// Create a new [`I2cBoardInfo`] for a kernel driver.
+    #[inline(always)]
+    pub const fn new(type_: &'static CStr, addr: u16) -> Self {
+        let src = type_.to_bytes_with_nul();
+        build_assert!(src.len() <= Self::I2C_TYPE_SIZE, "Type exceeds 20 bytes");
+        let mut i2c_board_info: bindings::i2c_board_info = pin_init::zeroed();
+        let mut i: usize = 0;
+        while i < src.len() {
+            i2c_board_info.type_[i] = src[i];
+            i += 1;
+        }
+
+        i2c_board_info.addr = addr;
+        Self(i2c_board_info)
+    }
+
+    fn as_raw(&self) -> *const bindings::i2c_board_info {
+        from_ref(&self.0)
+    }
+}
+
+/// The i2c client representation.
+///
+/// This structure represents the Rust abstraction for a C `struct i2c_client`. The
+/// implementation abstracts the usage of an existing C `struct i2c_client` that
+/// gets passed from the C side.
+///
+/// # Invariants
+///
+/// An [`I2cClient`] instance represents a valid `struct i2c_client` created by the C portion of
+/// the kernel.
+#[repr(transparent)]
+pub struct I2cClient<Ctx: device::DeviceContext = device::Normal>(
+    Opaque<bindings::i2c_client>,
+    PhantomData<Ctx>,
+);
+
+impl<Ctx: device::DeviceContext> I2cClient<Ctx> {
+    fn as_raw(&self) -> *mut bindings::i2c_client {
+        self.0.get()
+    }
+}
+
+// SAFETY: `I2cClient` is a transparent wrapper of `struct i2c_client`.
+// The offset is guaranteed to point to a valid device field inside `I2cClient`.
+unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for I2cClient<Ctx> {
+    const OFFSET: usize = offset_of!(bindings::i2c_client, dev);
+}
+
+// SAFETY: `I2cClient` is a transparent wrapper of a type that doesn't depend on
+// `I2cClient`'s generic argument.
+kernel::impl_device_context_deref!(unsafe { I2cClient });
+kernel::impl_device_context_into_aref!(I2cClient);
+
+// SAFETY: Instances of `I2cClient` are always reference-counted.
+unsafe impl AlwaysRefCounted for I2cClient {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+        unsafe { bindings::get_device(self.as_ref().as_raw()) };
+    }
+
+    unsafe fn dec_ref(obj: NonNull<Self>) {
+        // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+        unsafe { bindings::put_device(&raw mut (*obj.as_ref().as_raw()).dev) }
+    }
+}
+
+impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for I2cClient<Ctx> {
+    fn as_ref(&self) -> &device::Device<Ctx> {
+        let raw = self.as_raw();
+        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+        // `struct i2c_client`.
+        let dev = unsafe { &raw mut (*raw).dev };
+
+        // SAFETY: `dev` points to a valid `struct device`.
+        unsafe { device::Device::from_raw(dev) }
+    }
+}
+
+impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &I2cClient<Ctx> {
+    type Error = kernel::error::Error;
+
+    fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> {
+        // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a
+        // `struct device`.
+        if unsafe { bindings::i2c_verify_client(dev.as_raw()).is_null() } {
+            return Err(EINVAL);
+        }
+
+        // SAFETY: We've just verified that the type of `dev` equals
+        // `bindings::i2c_client_type`, hence `dev` must be embedded in a valid
+        // `struct i2c_client` as guaranteed by the corresponding C code.
+        let idev = unsafe { container_of!(dev.as_raw(), bindings::i2c_client, dev) };
+
+        // SAFETY: `idev` is a valid pointer to a `struct i2c_client`.
+        Ok(unsafe { &*idev.cast() })
+    }
+}
+
+// SAFETY: An `I2cClient` is always reference-counted and can be released from any thread.
+unsafe impl Send for I2cClient {}
+
+// SAFETY: `I2cClient` can be shared among threads because all methods of `I2cClient`
+// (i.e. `I2cClient<Normal>`) are thread safe.
+unsafe impl Sync for I2cClient {}
+
+/// The registration of an i2c client device.
+///
+/// This type represents the registration of a [`struct i2c_client`]. When an instance of this
+/// type is dropped, its respective i2c client device will be unregistered from the system.
+///
+/// # Invariants
+///
+/// `self.0` always holds a valid pointer to an initialized and registered
+/// [`struct i2c_client`].
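+///
+/// # Examples
+///
+/// A sketch of instantiating a client device on a bound parent device; the adapter index,
+/// device name, address and the `parent` reference are illustrative only:
+///
+/// ```ignore
+/// let adapter = I2cAdapter::get(0)?;
+/// let info = I2cBoardInfo::new(c_str!("rust-client"), 0x50);
+/// // `Devres` ties the lifetime of the registration to the bound parent device.
+/// let registration = KBox::pin_init(Registration::new(&adapter, &info, parent), GFP_KERNEL)?;
+/// ```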
+#[repr(transparent)]
+pub struct Registration(NonNull<bindings::i2c_client>);
+
+impl Registration {
+    /// The C `i2c_new_client_device` function wrapper for manual I2C client creation.
+    pub fn new<'a>(
+        i2c_adapter: &I2cAdapter,
+        i2c_board_info: &I2cBoardInfo,
+        parent_dev: &'a device::Device<device::Bound>,
+    ) -> impl PinInit<Devres<Self>, Error> + 'a {
+        Devres::new(parent_dev, Self::try_new(i2c_adapter, i2c_board_info))
+    }
+
+    fn try_new(i2c_adapter: &I2cAdapter, i2c_board_info: &I2cBoardInfo) -> Result<Self> {
+        // SAFETY: The kernel guarantees that `i2c_new_client_device()` returns either a valid
+        // pointer or an `ERR_PTR`. `from_err_ptr` converts the latter into an `Err`, and the
+        // subsequent `NonNull::new` checks for NULL.
+        let raw_dev = from_err_ptr(unsafe {
+            bindings::i2c_new_client_device(i2c_adapter.as_raw(), i2c_board_info.as_raw())
+        })?;
+
+        let dev_ptr = NonNull::new(raw_dev).ok_or(ENODEV)?;
+
+        Ok(Self(dev_ptr))
+    }
+}
+
+impl Drop for Registration {
+    fn drop(&mut self) {
+        // SAFETY: `Drop` is only called for a valid `Registration`, which by invariant
+        // always contains a non-null pointer to an `i2c_client`.
+        unsafe { bindings::i2c_unregister_device(self.0.as_ptr()) }
+    }
+}
+
+// SAFETY: A `Registration` of a `struct i2c_client` can be released from any thread.
+unsafe impl Send for Registration {}
+
+// SAFETY: `Registration` offers no interior mutability (no mutation through `&self`
+// and no mutable access is exposed).
+unsafe impl Sync for Registration {}
diff --git a/rust/kernel/id_pool.rs b/rust/kernel/id_pool.rs
new file mode 100644
index 000000000000..384753fe0e44
--- /dev/null
+++ b/rust/kernel/id_pool.rs
@@ -0,0 +1,295 @@
+// SPDX-License-Identifier: GPL-2.0
+
+// Copyright (C) 2025 Google LLC.
+
+//! Rust API for an ID pool backed by a [`BitmapVec`].
+
+use crate::alloc::{AllocError, Flags};
+use crate::bitmap::BitmapVec;
+
+/// Represents a dynamic ID pool backed by a [`BitmapVec`].
+///
+/// Clients acquire and release IDs from unset bits in a bitmap.
+///
+/// The capacity of the ID pool may be adjusted by users as
+/// needed. The API supports the scenario where users need precise control
+/// over the time of allocation of a new backing bitmap, which may require
+/// releasing a spinlock.
+/// Due to concurrent updates, all operations re-verify whether the grow or
+/// shrink is still valid.
+///
+/// # Examples
+///
+/// Basic usage
+///
+/// ```
+/// use kernel::alloc::AllocError;
+/// use kernel::id_pool::{IdPool, UnusedId};
+///
+/// let mut pool = IdPool::with_capacity(64, GFP_KERNEL)?;
+/// for i in 0..64 {
+///     assert_eq!(i, pool.find_unused_id(i).ok_or(ENOSPC)?.acquire());
+/// }
+///
+/// pool.release_id(23);
+/// assert_eq!(23, pool.find_unused_id(0).ok_or(ENOSPC)?.acquire());
+///
+/// assert!(pool.find_unused_id(0).is_none()); // time to realloc.
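+/// // All 64 IDs are taken, so a larger backing bitmap must be allocated before
+/// // another ID can be handed out.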
+/// let resizer = pool.grow_request().ok_or(ENOSPC)?.realloc(GFP_KERNEL)?; +/// pool.grow(resizer); +/// +/// assert_eq!(pool.find_unused_id(0).ok_or(ENOSPC)?.acquire(), 64); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Releasing spinlock to grow the pool +/// +/// ```no_run +/// use kernel::alloc::{AllocError, flags::GFP_KERNEL}; +/// use kernel::sync::{new_spinlock, SpinLock}; +/// use kernel::id_pool::IdPool; +/// +/// fn get_id_maybe_realloc(guarded_pool: &SpinLock<IdPool>) -> Result<usize, AllocError> { +/// let mut pool = guarded_pool.lock(); +/// loop { +/// match pool.find_unused_id(0) { +/// Some(index) => return Ok(index.acquire()), +/// None => { +/// let alloc_request = pool.grow_request(); +/// drop(pool); +/// let resizer = alloc_request.ok_or(AllocError)?.realloc(GFP_KERNEL)?; +/// pool = guarded_pool.lock(); +/// pool.grow(resizer) +/// } +/// } +/// } +/// } +/// ``` +pub struct IdPool { + map: BitmapVec, +} + +/// Indicates that an [`IdPool`] should change to a new target size. +pub struct ReallocRequest { + num_ids: usize, +} + +/// Contains a [`BitmapVec`] of a size suitable for reallocating [`IdPool`]. +pub struct PoolResizer { + new: BitmapVec, +} + +impl ReallocRequest { + /// Allocates a new backing [`BitmapVec`] for [`IdPool`]. + /// + /// This method only prepares reallocation and does not complete it. + /// Reallocation will complete after passing the [`PoolResizer`] to the + /// [`IdPool::grow`] or [`IdPool::shrink`] operation, which will check + /// that reallocation still makes sense. + pub fn realloc(&self, flags: Flags) -> Result<PoolResizer, AllocError> { + let new = BitmapVec::new(self.num_ids, flags)?; + Ok(PoolResizer { new }) + } +} + +impl IdPool { + /// Constructs a new [`IdPool`]. + /// + /// The pool will have a capacity of [`MAX_INLINE_LEN`]. + /// + /// [`MAX_INLINE_LEN`]: BitmapVec::MAX_INLINE_LEN + #[inline] + pub fn new() -> Self { + Self { + map: BitmapVec::new_inline(), + } + } + + /// Constructs a new [`IdPool`] with space for a specific number of bits. + /// + /// A capacity below [`MAX_INLINE_LEN`] is adjusted to [`MAX_INLINE_LEN`]. + /// + /// [`MAX_INLINE_LEN`]: BitmapVec::MAX_INLINE_LEN + #[inline] + pub fn with_capacity(num_ids: usize, flags: Flags) -> Result<Self, AllocError> { + let num_ids = usize::max(num_ids, BitmapVec::MAX_INLINE_LEN); + let map = BitmapVec::new(num_ids, flags)?; + Ok(Self { map }) + } + + /// Returns how many IDs this pool can currently have. + #[inline] + pub fn capacity(&self) -> usize { + self.map.len() + } + + /// Returns a [`ReallocRequest`] if the [`IdPool`] can be shrunk, [`None`] otherwise. + /// + /// The capacity of an [`IdPool`] cannot be shrunk below [`MAX_INLINE_LEN`]. + /// + /// [`MAX_INLINE_LEN`]: BitmapVec::MAX_INLINE_LEN + /// + /// # Examples + /// + /// ``` + /// use kernel::{ + /// alloc::AllocError, + /// bitmap::BitmapVec, + /// id_pool::{ + /// IdPool, + /// ReallocRequest, + /// }, + /// }; + /// + /// let mut pool = IdPool::with_capacity(1024, GFP_KERNEL)?; + /// let alloc_request = pool.shrink_request().ok_or(AllocError)?; + /// let resizer = alloc_request.realloc(GFP_KERNEL)?; + /// pool.shrink(resizer); + /// assert_eq!(pool.capacity(), BitmapVec::MAX_INLINE_LEN); + /// # Ok::<(), AllocError>(()) + /// ``` + #[inline] + pub fn shrink_request(&self) -> Option<ReallocRequest> { + let cap = self.capacity(); + // Shrinking below `MAX_INLINE_LEN` is never possible. 
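+        // At that size the bitmap is stored inline rather than heap-allocated, so there
+        // is no backing allocation left to reclaim.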
+        if cap <= BitmapVec::MAX_INLINE_LEN {
+            return None;
+        }
+        // Determine if the bitmap can shrink based on the position of
+        // its last set bit. If the bit is within the first quarter of
+        // the bitmap then shrinking is possible. In this case, the
+        // bitmap should shrink to half its current size.
+        let Some(bit) = self.map.last_bit() else {
+            return Some(ReallocRequest {
+                num_ids: BitmapVec::MAX_INLINE_LEN,
+            });
+        };
+        if bit >= (cap / 4) {
+            return None;
+        }
+        let num_ids = usize::max(BitmapVec::MAX_INLINE_LEN, cap / 2);
+        Some(ReallocRequest { num_ids })
+    }
+
+    /// Shrinks pool by using a new [`BitmapVec`], if still possible.
+    #[inline]
+    pub fn shrink(&mut self, mut resizer: PoolResizer) {
+        // Between request to shrink that led to allocation of `resizer` and now,
+        // bits may have changed.
+        // Verify that shrinking is still possible. In case shrinking to
+        // the size of `resizer` is no longer possible, do nothing,
+        // drop `resizer` and move on.
+        let Some(updated) = self.shrink_request() else {
+            return;
+        };
+        if updated.num_ids > resizer.new.len() {
+            return;
+        }
+
+        resizer.new.copy_and_extend(&self.map);
+        self.map = resizer.new;
+    }
+
+    /// Returns a [`ReallocRequest`] for growing this [`IdPool`], if possible.
+    ///
+    /// The capacity of an [`IdPool`] cannot be grown above [`MAX_LEN`].
+    ///
+    /// [`MAX_LEN`]: BitmapVec::MAX_LEN
+    #[inline]
+    pub fn grow_request(&self) -> Option<ReallocRequest> {
+        let num_ids = self.capacity() * 2;
+        if num_ids > BitmapVec::MAX_LEN {
+            return None;
+        }
+        Some(ReallocRequest { num_ids })
+    }
+
+    /// Grows pool by using a new [`BitmapVec`], if still necessary.
+    ///
+    /// The `resizer` argument has to be obtained by calling [`Self::grow_request`]
+    /// on this object and performing a [`ReallocRequest::realloc`].
+    #[inline]
+    pub fn grow(&mut self, mut resizer: PoolResizer) {
+        // Between request to grow that led to allocation of `resizer` and now,
+        // another thread may have already grown the capacity.
+        // In this case, do nothing, drop `resizer` and move on.
+        if resizer.new.len() <= self.capacity() {
+            return;
+        }
+
+        resizer.new.copy_and_extend(&self.map);
+        self.map = resizer.new;
+    }
+
+    /// Finds an unused ID in the bitmap.
+    ///
+    /// Upon success, returns its index. Otherwise, returns [`None`]
+    /// to indicate that a [`Self::grow_request`] is needed.
+    #[inline]
+    #[must_use]
+    pub fn find_unused_id(&mut self, offset: usize) -> Option<UnusedId<'_>> {
+        // INVARIANT: `next_zero_bit()` returns None or an integer less than `map.len()`
+        Some(UnusedId {
+            id: self.map.next_zero_bit(offset)?,
+            pool: self,
+        })
+    }
+
+    /// Releases an ID.
+    #[inline]
+    pub fn release_id(&mut self, id: usize) {
+        self.map.clear_bit(id);
+    }
+}
+
+/// Represents an unused id in an [`IdPool`].
+///
+/// # Invariants
+///
+/// The value of `id` is less than `pool.map.len()`.
+pub struct UnusedId<'pool> {
+    id: usize,
+    pool: &'pool mut IdPool,
+}
+
+impl<'pool> UnusedId<'pool> {
+    /// Get the unused id as a `usize`.
+    ///
+    /// Be aware that the id has not yet been acquired in the pool. The
+    /// [`acquire`] method must be called to prevent others from taking the id.
+    ///
+    /// [`acquire`]: UnusedId::acquire()
+    #[inline]
+    #[must_use]
+    pub fn as_usize(&self) -> usize {
+        self.id
+    }
+
+    /// Get the unused id as a `u32`.
+    ///
+    /// Be aware that the id has not yet been acquired in the pool. The
+    /// [`acquire`] method must be called to prevent others from taking the id.
+    ///
+    /// [`acquire`]: UnusedId::acquire()
+    #[inline]
+    #[must_use]
+    pub fn as_u32(&self) -> u32 {
+        // CAST: By the type invariants:
+        // `self.id < pool.map.len() <= BitmapVec::MAX_LEN = i32::MAX`.
+        self.id as u32
+    }
+
+    /// Acquire the unused id.
+    #[inline]
+    pub fn acquire(self) -> usize {
+        self.pool.map.set_bit(self.id);
+        self.id
+    }
+}
+
+impl Default for IdPool {
+    #[inline]
+    fn default() -> Self {
+        Self::new()
+    }
+}
diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs
new file mode 100644
index 000000000000..899b9a962762
--- /dev/null
+++ b/rust/kernel/init.rs
@@ -0,0 +1,296 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Extensions to the [`pin-init`] crate.
+//!
+//! Most `struct`s from the [`sync`] module need to be pinned, because they contain self-referential
+//! `struct`s from C. [Pinning][pinning] is Rust's way of ensuring data does not move.
+//!
+//! The [`pin-init`] crate is the way such structs are initialized on the Rust side. Please refer
+//! to its documentation to better understand how to use it. Additionally, there are many examples
+//! throughout the kernel, such as the types from the [`sync`] module, as well as the ones
+//! presented below.
+//!
+//! [`sync`]: crate::sync
+//! [pinning]: https://doc.rust-lang.org/std/pin/index.html
+//! [`pin-init`]: https://rust.docs.kernel.org/pin_init/
+//!
+//! # [`Opaque<T>`]
+//!
+//! For the special case where initializing a field is a single FFI-function call that cannot fail,
+//! there exists the helper function [`Opaque::ffi_init`]. This function initializes a single
+//! [`Opaque<T>`] field by just delegating to the supplied closure. You can use it in
+//! combination with [`pin_init!`].
+//!
+//! [`Opaque<T>`]: crate::types::Opaque
+//! [`Opaque::ffi_init`]: crate::types::Opaque::ffi_init
+//! [`pin_init!`]: pin_init::pin_init
+//!
+//! # Examples
+//!
+//! ## General Examples
+//!
+//! ```rust
+//! # #![expect(clippy::undocumented_unsafe_blocks)]
+//! use kernel::types::Opaque;
+//! use pin_init::pin_init_from_closure;
+//!
+//! // assume we have some `raw_foo` type in C:
+//! #[repr(C)]
+//! struct RawFoo([u8; 16]);
+//! extern "C" {
+//!     fn init_foo(_: *mut RawFoo);
+//! }
+//!
+//! #[pin_data]
+//! struct Foo {
+//!     #[pin]
+//!     raw: Opaque<RawFoo>,
+//! }
+//!
+//! impl Foo {
+//!     fn setup(self: Pin<&mut Self>) {
+//!         pr_info!("Setting up foo\n");
+//!     }
+//! }
+//!
+//! let foo = pin_init!(Foo {
+//!     raw <- unsafe {
+//!         Opaque::ffi_init(|s| {
+//!             // note that this cannot fail.
+//!             init_foo(s);
+//!         })
+//!     },
+//! }).pin_chain(|foo| {
+//!     foo.setup();
+//!     Ok(())
+//! });
+//! ```
+//!
+//! ```rust
+//! use kernel::{prelude::*, types::Opaque};
+//! use core::{ptr::addr_of_mut, marker::PhantomPinned, pin::Pin};
+//! # mod bindings {
+//! #     #![expect(non_camel_case_types, clippy::missing_safety_doc)]
+//! #     pub struct foo;
+//! #     pub unsafe fn init_foo(_ptr: *mut foo) {}
+//! #     pub unsafe fn destroy_foo(_ptr: *mut foo) {}
+//! #     pub unsafe fn enable_foo(_ptr: *mut foo, _flags: u32) -> i32 { 0 }
+//! # }
+//! /// # Invariants
+//! ///
+//! /// `foo` is always initialized
+//! #[pin_data(PinnedDrop)]
+//! pub struct RawFoo {
+//!     #[pin]
+//!     foo: Opaque<bindings::foo>,
+//!     #[pin]
+//!     _p: PhantomPinned,
+//! }
+//!
+//! impl RawFoo {
+//!     pub fn new(flags: u32) -> impl PinInit<Self, Error> {
+//!         // SAFETY:
+//!         // - when the closure returns `Ok(())`, then it has successfully initialized and
+//!         //   enabled `foo`,
+//!         // - when it returns `Err(e)`, then it has cleaned up before
+//!         unsafe {
pin_init::pin_init_from_closure(move |slot: *mut Self| { +//! // `slot` contains uninit memory, avoid creating a reference. +//! let foo = addr_of_mut!((*slot).foo); +//! +//! // Initialize the `foo` +//! bindings::init_foo(Opaque::cast_into(foo)); +//! +//! // Try to enable it. +//! let err = bindings::enable_foo(Opaque::cast_into(foo), flags); +//! if err != 0 { +//! // Enabling has failed, first clean up the foo and then return the error. +//! bindings::destroy_foo(Opaque::cast_into(foo)); +//! return Err(Error::from_errno(err)); +//! } +//! +//! // All fields of `RawFoo` have been initialized, since `_p` is a ZST. +//! Ok(()) +//! }) +//! } +//! } +//! } +//! +//! #[pinned_drop] +//! impl PinnedDrop for RawFoo { +//! fn drop(self: Pin<&mut Self>) { +//! // SAFETY: Since `foo` is initialized, destroying is safe. +//! unsafe { bindings::destroy_foo(self.foo.get()) }; +//! } +//! } +//! ``` + +use crate::{ + alloc::{AllocError, Flags}, + error::{self, Error}, +}; +use pin_init::{init_from_closure, pin_init_from_closure, Init, PinInit}; + +/// Smart pointer that can initialize memory in-place. +pub trait InPlaceInit<T>: Sized { + /// Pinned version of `Self`. + /// + /// If a type already implicitly pins its pointee, `Pin<Self>` is unnecessary. In this case use + /// `Self`, otherwise just use `Pin<Self>`. + type PinnedSelf; + + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>; + + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self::PinnedSelf> + where + Error: From<E>, + { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + pin_init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) + }; + Self::try_pin_init(init, flags) + } + + /// Use the given initializer to in-place initialize a `T`. + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>; + + /// Use the given initializer to in-place initialize a `T`. + fn init<E>(init: impl Init<T, E>, flags: Flags) -> error::Result<Self> + where + Error: From<E>, + { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) + }; + Self::try_init(init, flags) + } +} + +/// Construct an in-place fallible initializer for `struct`s. +/// +/// This macro defaults the error to [`Error`]. If you need [`Infallible`], then use +/// [`init!`]. +/// +/// The syntax is identical to [`try_pin_init!`]. If you want to specify a custom error, +/// append `? $type` after the `struct` initializer. +/// The safety caveats from [`try_pin_init!`] also apply: +/// - `unsafe` code must guarantee either full initialization or return an error and allow +/// deallocation of the memory. +/// - the fields are initialized in the order given in the initializer. +/// - no references to fields are allowed to be created inside of the initializer. 
+/// +/// # Examples +/// +/// ```rust +/// use kernel::error::Error; +/// use pin_init::init_zeroed; +/// struct BigBuf { +/// big: KBox<[u8; 1024 * 1024 * 1024]>, +/// small: [u8; 1024 * 1024], +/// } +/// +/// impl BigBuf { +/// fn new() -> impl Init<Self, Error> { +/// try_init!(Self { +/// big: KBox::init(init_zeroed(), GFP_KERNEL)?, +/// small: [0; 1024 * 1024], +/// }? Error) +/// } +/// } +/// ``` +/// +/// [`Infallible`]: core::convert::Infallible +/// [`init!`]: pin_init::init +/// [`try_pin_init!`]: crate::try_pin_init! +/// [`Error`]: crate::error::Error +#[macro_export] +macro_rules! try_init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }) => { + ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $crate::error::Error) + }; + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }? $err:ty) => { + ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $err) + }; +} + +/// Construct an in-place, fallible pinned initializer for `struct`s. +/// +/// If the initialization can complete without error (or [`Infallible`]), then use [`pin_init!`]. +/// +/// You can use the `?` operator or use `return Err(err)` inside the initializer to stop +/// initialization and return the error. +/// +/// IMPORTANT: if you have `unsafe` code inside of the initializer you have to ensure that when +/// initialization fails, the memory can be safely deallocated without any further modifications. +/// +/// This macro defaults the error to [`Error`]. +/// +/// The syntax is identical to [`pin_init!`] with the following exception: you can append `? $type` +/// after the `struct` initializer to specify the error type you want to use. +/// +/// # Examples +/// +/// ```rust +/// # #![feature(new_uninit)] +/// use kernel::error::Error; +/// use pin_init::init_zeroed; +/// #[pin_data] +/// struct BigBuf { +/// big: KBox<[u8; 1024 * 1024 * 1024]>, +/// small: [u8; 1024 * 1024], +/// ptr: *mut u8, +/// } +/// +/// impl BigBuf { +/// fn new() -> impl PinInit<Self, Error> { +/// try_pin_init!(Self { +/// big: KBox::init(init_zeroed(), GFP_KERNEL)?, +/// small: [0; 1024 * 1024], +/// ptr: core::ptr::null_mut(), +/// }? Error) +/// } +/// } +/// ``` +/// +/// [`Infallible`]: core::convert::Infallible +/// [`pin_init!`]: pin_init::pin_init +/// [`Error`]: crate::error::Error +#[macro_export] +macro_rules! try_pin_init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }) => { + ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $crate::error::Error) + }; + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }? $err:ty) => { + ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $err) + }; +} diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs new file mode 100644 index 000000000000..98e8b84e68d1 --- /dev/null +++ b/rust/kernel/io.rs @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Memory-mapped IO. +//! +//! C header: [`include/asm-generic/io.h`](srctree/include/asm-generic/io.h) + +use crate::{ + bindings, + prelude::*, // +}; + +pub mod mem; +pub mod poll; +pub mod resource; + +pub use resource::Resource; + +/// Physical address type. +/// +/// This is a type alias to either `u32` or `u64` depending on the config option +/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures. 
+pub type PhysAddr = bindings::phys_addr_t;
+
+/// Resource Size type.
+///
+/// This is a type alias to either `u32` or `u64` depending on the config option
+/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures.
+pub type ResourceSize = bindings::resource_size_t;
+
+/// Raw representation of an MMIO region.
+///
+/// By itself, the existence of an instance of this structure does not provide any guarantees that
+/// the represented MMIO region exists or is properly mapped.
+///
+/// Instead, the bus specific MMIO implementation must convert this raw representation into an `Io`
+/// instance providing the actual memory accessors. Only by the conversion into an `Io` structure
+/// are any guarantees given.
+pub struct IoRaw<const SIZE: usize = 0> {
+    addr: usize,
+    maxsize: usize,
+}
+
+impl<const SIZE: usize> IoRaw<SIZE> {
+    /// Returns a new `IoRaw` instance on success, an error otherwise.
+    pub fn new(addr: usize, maxsize: usize) -> Result<Self> {
+        if maxsize < SIZE {
+            return Err(EINVAL);
+        }
+
+        Ok(Self { addr, maxsize })
+    }
+
+    /// Returns the base address of the MMIO region.
+    #[inline]
+    pub fn addr(&self) -> usize {
+        self.addr
+    }
+
+    /// Returns the maximum size of the MMIO region.
+    #[inline]
+    pub fn maxsize(&self) -> usize {
+        self.maxsize
+    }
+}
+
+/// IO-mapped memory region.
+///
+/// The creator (usually a subsystem / bus such as PCI) is responsible for creating the
+/// mapping, performing an additional region request etc.
+///
+/// # Invariants
+///
+/// `addr` is the start of a valid I/O mapped memory region of length `maxsize`.
+///
+/// # Examples
+///
+/// ```no_run
+/// use kernel::{
+///     bindings,
+///     ffi::c_void,
+///     io::{
+///         Io,
+///         IoRaw,
+///         PhysAddr,
+///     },
+/// };
+/// use core::ops::Deref;
+///
+/// // See also [`pci::Bar`] for a real example.
+/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>);
+///
+/// impl<const SIZE: usize> IoMem<SIZE> {
+///     /// # Safety
+///     ///
+///     /// [`paddr`, `paddr` + `SIZE`) must be a valid MMIO region that is mappable into the CPU's
+///     /// virtual address space.
+///     unsafe fn new(paddr: usize) -> Result<Self> {
+///         // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is
+///         // valid for `ioremap`.
+///         let addr = unsafe { bindings::ioremap(paddr as PhysAddr, SIZE) };
+///         if addr.is_null() {
+///             return Err(ENOMEM);
+///         }
+///
+///         Ok(IoMem(IoRaw::new(addr as usize, SIZE)?))
+///     }
+/// }
+///
+/// impl<const SIZE: usize> Drop for IoMem<SIZE> {
+///     fn drop(&mut self) {
+///         // SAFETY: `self.0.addr()` is guaranteed to be properly mapped by `Self::new`.
+///         unsafe { bindings::iounmap(self.0.addr() as *mut c_void); };
+///     }
+/// }
+///
+/// impl<const SIZE: usize> Deref for IoMem<SIZE> {
+///     type Target = Io<SIZE>;
+///
+///     fn deref(&self) -> &Self::Target {
+///         // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`.
+///         unsafe { Io::from_raw(&self.0) }
+///     }
+/// }
+///
+/// # fn no_run() -> Result<(), Error> {
+/// // SAFETY: Invalid usage for example purposes.
+/// let iomem = unsafe { IoMem::<{ core::mem::size_of::<u32>() }>::new(0xBAAAAAAD)? };
+/// iomem.write32(0x42, 0x0);
+/// assert!(iomem.try_write32(0x42, 0x0).is_ok());
+/// assert!(iomem.try_write32(0x42, 0x4).is_err());
+/// # Ok(())
+/// # }
+/// ```
+#[repr(transparent)]
+pub struct Io<const SIZE: usize = 0>(IoRaw<SIZE>);
+
+macro_rules! define_read {
+    ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident -> $type_name:ty) => {
+        /// Read IO data from a given offset known at compile time.
+        ///
+        /// Bounds checks are performed at compile time, hence if the offset is not known at
+        /// compile time, the build will fail.
+        $(#[$attr])*
+        #[inline]
+        pub fn $name(&self, offset: usize) -> $type_name {
+            let addr = self.io_addr_assert::<$type_name>(offset);
+
+            // SAFETY: By the type invariant `addr` is a valid address for MMIO operations.
+            unsafe { bindings::$c_fn(addr as *const c_void) }
+        }
+
+        /// Read IO data from a given offset.
+        ///
+        /// Bounds checks are performed at runtime; the call fails if the offset (plus the type
+        /// size) is out of bounds.
+        $(#[$attr])*
+        pub fn $try_name(&self, offset: usize) -> Result<$type_name> {
+            let addr = self.io_addr::<$type_name>(offset)?;
+
+            // SAFETY: By the type invariant `addr` is a valid address for MMIO operations.
+            Ok(unsafe { bindings::$c_fn(addr as *const c_void) })
+        }
+    };
+}
+
+macro_rules! define_write {
+    ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident <- $type_name:ty) => {
+        /// Write IO data to a given offset known at compile time.
+        ///
+        /// Bounds checks are performed at compile time, hence if the offset is not known at
+        /// compile time, the build will fail.
+        $(#[$attr])*
+        #[inline]
+        pub fn $name(&self, value: $type_name, offset: usize) {
+            let addr = self.io_addr_assert::<$type_name>(offset);
+
+            // SAFETY: By the type invariant `addr` is a valid address for MMIO operations.
+            unsafe { bindings::$c_fn(value, addr as *mut c_void) }
+        }
+
+        /// Write IO data to a given offset.
+        ///
+        /// Bounds checks are performed at runtime; the call fails if the offset (plus the type
+        /// size) is out of bounds.
+        $(#[$attr])*
+        pub fn $try_name(&self, value: $type_name, offset: usize) -> Result {
+            let addr = self.io_addr::<$type_name>(offset)?;
+
+            // SAFETY: By the type invariant `addr` is a valid address for MMIO operations.
+            unsafe { bindings::$c_fn(value, addr as *mut c_void) }
+            Ok(())
+        }
+    };
+}
+
+impl<const SIZE: usize> Io<SIZE> {
+    /// Converts an `IoRaw` into an `Io` instance, providing the accessors to the MMIO mapping.
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that `addr` is the start of a valid I/O mapped memory region of size
+    /// `maxsize`.
+    pub unsafe fn from_raw(raw: &IoRaw<SIZE>) -> &Self {
+        // SAFETY: `Io` is a transparent wrapper around `IoRaw`.
+        unsafe { &*core::ptr::from_ref(raw).cast() }
+    }
+
+    /// Returns the base address of this mapping.
+    #[inline]
+    pub fn addr(&self) -> usize {
+        self.0.addr()
+    }
+
+    /// Returns the maximum size of this mapping.
+    #[inline]
+    pub fn maxsize(&self) -> usize {
+        self.0.maxsize()
+    }
+
+    #[inline]
+    const fn offset_valid<U>(offset: usize, size: usize) -> bool {
+        let type_size = core::mem::size_of::<U>();
+        if let Some(end) = offset.checked_add(type_size) {
+            end <= size && offset % type_size == 0
+        } else {
+            false
+        }
+    }
+
+    #[inline]
+    fn io_addr<U>(&self, offset: usize) -> Result<usize> {
+        if !Self::offset_valid::<U>(offset, self.maxsize()) {
+            return Err(EINVAL);
+        }
+
+        // Probably no need to check, since the safety requirements of `Self::from_raw` guarantee
+        // that this can't overflow.
+ self.addr().checked_add(offset).ok_or(EINVAL) + } + + #[inline] + fn io_addr_assert<U>(&self, offset: usize) -> usize { + build_assert!(Self::offset_valid::<U>(offset, SIZE)); + + self.addr() + offset + } + + define_read!(read8, try_read8, readb -> u8); + define_read!(read16, try_read16, readw -> u16); + define_read!(read32, try_read32, readl -> u32); + define_read!( + #[cfg(CONFIG_64BIT)] + read64, + try_read64, + readq -> u64 + ); + + define_read!(read8_relaxed, try_read8_relaxed, readb_relaxed -> u8); + define_read!(read16_relaxed, try_read16_relaxed, readw_relaxed -> u16); + define_read!(read32_relaxed, try_read32_relaxed, readl_relaxed -> u32); + define_read!( + #[cfg(CONFIG_64BIT)] + read64_relaxed, + try_read64_relaxed, + readq_relaxed -> u64 + ); + + define_write!(write8, try_write8, writeb <- u8); + define_write!(write16, try_write16, writew <- u16); + define_write!(write32, try_write32, writel <- u32); + define_write!( + #[cfg(CONFIG_64BIT)] + write64, + try_write64, + writeq <- u64 + ); + + define_write!(write8_relaxed, try_write8_relaxed, writeb_relaxed <- u8); + define_write!(write16_relaxed, try_write16_relaxed, writew_relaxed <- u16); + define_write!(write32_relaxed, try_write32_relaxed, writel_relaxed <- u32); + define_write!( + #[cfg(CONFIG_64BIT)] + write64_relaxed, + try_write64_relaxed, + writeq_relaxed <- u64 + ); +} diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs new file mode 100644 index 000000000000..b03b82cd531b --- /dev/null +++ b/rust/kernel/io/mem.rs @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic memory-mapped IO. + +use core::ops::Deref; + +use crate::{ + c_str, + device::{ + Bound, + Device, // + }, + devres::Devres, + io::{ + self, + resource::{ + Region, + Resource, // + }, + Io, + IoRaw, // + }, + prelude::*, +}; + +/// An IO request for a specific device and resource. +pub struct IoRequest<'a> { + device: &'a Device<Bound>, + resource: &'a Resource, +} + +impl<'a> IoRequest<'a> { + /// Creates a new [`IoRequest`] instance. + /// + /// # Safety + /// + /// Callers must ensure that `resource` is valid for `device` during the + /// lifetime `'a`. + pub(crate) unsafe fn new(device: &'a Device<Bound>, resource: &'a Resource) -> Self { + IoRequest { device, resource } + } + + /// Maps an [`IoRequest`] where the size is known at compile time. + /// + /// This uses the [`ioremap()`] C API. + /// + /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device + /// + /// # Examples + /// + /// The following example uses a [`kernel::platform::Device`] for + /// illustration purposes. + /// + /// ```no_run + /// use kernel::{bindings, c_str, platform, of, device::Core}; + /// struct SampleDriver; + /// + /// impl platform::Driver for SampleDriver { + /// # type IdInfo = (); + /// + /// fn probe( + /// pdev: &platform::Device<Core>, + /// info: Option<&Self::IdInfo>, + /// ) -> impl PinInit<Self, Error> { + /// let offset = 0; // Some offset. + /// + /// // If the size is known at compile time, use [`Self::iomap_sized`]. + /// // + /// // No runtime checks will apply when reading and writing. + /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; + /// let iomem = request.iomap_sized::<42>(); + /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; + /// + /// let io = iomem.access(pdev.as_ref())?; + /// + /// // Read and write a 32-bit value at `offset`. 
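+    ///         // With `SIZE` known at compile time, these accesses are bounds-checked
+    ///         // at build time, so the fallible `try_*` variants are not needed here.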
+    ///         let data = io.read32_relaxed(offset);
+    ///
+    ///         io.write32_relaxed(data, offset);
+    ///
+    /// # Ok(SampleDriver)
+    ///     }
+    /// }
+    /// ```
+    pub fn iomap_sized<const SIZE: usize>(self) -> impl PinInit<Devres<IoMem<SIZE>>, Error> + 'a {
+        IoMem::new(self)
+    }
+
+    /// Same as [`Self::iomap_sized`] but with exclusive access to the
+    /// underlying region.
+    ///
+    /// This uses the [`ioremap()`] C API.
+    ///
+    /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device
+    pub fn iomap_exclusive_sized<const SIZE: usize>(
+        self,
+    ) -> impl PinInit<Devres<ExclusiveIoMem<SIZE>>, Error> + 'a {
+        ExclusiveIoMem::new(self)
+    }
+
+    /// Maps an [`IoRequest`] where the size is not known at compile time.
+    ///
+    /// This uses the [`ioremap()`] C API.
+    ///
+    /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device
+    ///
+    /// # Examples
+    ///
+    /// The following example uses a [`kernel::platform::Device`] for
+    /// illustration purposes.
+    ///
+    /// ```no_run
+    /// use kernel::{bindings, c_str, platform, of, device::Core};
+    /// struct SampleDriver;
+    ///
+    /// impl platform::Driver for SampleDriver {
+    /// # type IdInfo = ();
+    ///
+    ///     fn probe(
+    ///         pdev: &platform::Device<Core>,
+    ///         info: Option<&Self::IdInfo>,
+    ///     ) -> impl PinInit<Self, Error> {
+    ///         let offset = 0; // Some offset.
+    ///
+    ///         // Unlike [`Self::iomap_sized`], here the size of the memory region
+    ///         // is not known at compile time, so only the `try_read*` and `try_write*`
+    ///         // family of functions should be used, leading to runtime checks on every
+    ///         // access.
+    ///         let request = pdev.io_request_by_index(0).ok_or(ENODEV)?;
+    ///         let iomem = request.iomap();
+    ///         let iomem = KBox::pin_init(iomem, GFP_KERNEL)?;
+    ///
+    ///         let io = iomem.access(pdev.as_ref())?;
+    ///
+    ///         let data = io.try_read32_relaxed(offset)?;
+    ///
+    ///         io.try_write32_relaxed(data, offset)?;
+    ///
+    /// # Ok(SampleDriver)
+    ///     }
+    /// }
+    /// ```
+    pub fn iomap(self) -> impl PinInit<Devres<IoMem<0>>, Error> + 'a {
+        Self::iomap_sized::<0>(self)
+    }
+
+    /// Same as [`Self::iomap`] but with exclusive access to the underlying
+    /// region.
+    pub fn iomap_exclusive(self) -> impl PinInit<Devres<ExclusiveIoMem<0>>, Error> + 'a {
+        Self::iomap_exclusive_sized::<0>(self)
+    }
+}
+
+/// An exclusive memory-mapped IO region.
+///
+/// # Invariants
+///
+/// - [`ExclusiveIoMem`] has exclusive access to the underlying [`IoMem`].
+pub struct ExclusiveIoMem<const SIZE: usize> {
+    /// The underlying `IoMem` instance.
+    iomem: IoMem<SIZE>,
+
+    /// The region abstraction. This represents exclusive access to the
+    /// range represented by the underlying `iomem`.
+    ///
+    /// This field is needed for ownership of the region.
+    _region: Region,
+}
+
+impl<const SIZE: usize> ExclusiveIoMem<SIZE> {
+    /// Creates a new `ExclusiveIoMem` instance.
+    fn ioremap(resource: &Resource) -> Result<Self> {
+        let start = resource.start();
+        let size = resource.size();
+        let name = resource.name().unwrap_or(c_str!(""));
+
+        let region = resource
+            .request_region(
+                start,
+                size,
+                name.to_cstring()?,
+                io::resource::Flags::IORESOURCE_MEM,
+            )
+            .ok_or(EBUSY)?;
+
+        let iomem = IoMem::ioremap(resource)?;
+
+        let iomem = ExclusiveIoMem {
+            iomem,
+            _region: region,
+        };
+
+        Ok(iomem)
+    }
+
+    /// Creates a new `ExclusiveIoMem` instance from a previously acquired [`IoRequest`].
+ pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a {
+ let dev = io_request.device;
+ let res = io_request.resource;
+
+ Devres::new(dev, Self::ioremap(res))
+ }
+}
+
+impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> {
+ type Target = Io<SIZE>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.iomem
+ }
+}
+
+/// A generic memory-mapped IO region.
+///
+/// Accesses to the underlying region are checked either at compile time, if the
+/// region's size is known at that point, or at runtime otherwise.
+///
+/// # Invariants
+///
+/// [`IoMem`] always holds an [`IoRaw`] instance that holds a valid pointer to the
+/// start of the I/O memory mapped region.
+pub struct IoMem<const SIZE: usize = 0> {
+ io: IoRaw<SIZE>,
+}
+
+impl<const SIZE: usize> IoMem<SIZE> {
+ fn ioremap(resource: &Resource) -> Result<Self> {
+ // Note: Some ioremap() implementations use types that depend on the CPU
+ // word width rather than the bus address width.
+ //
+ // TODO: Properly address this in the C code to avoid this `try_into`.
+ let size = resource.size().try_into()?;
+ if size == 0 {
+ return Err(EINVAL);
+ }
+
+ let res_start = resource.start();
+
+ let addr = if resource
+ .flags()
+ .contains(io::resource::Flags::IORESOURCE_MEM_NONPOSTED)
+ {
+ // SAFETY:
+ // - `res_start` and `size` are read from a presumably valid `struct resource`.
+ // - `size` is known not to be zero at this point.
+ unsafe { bindings::ioremap_np(res_start, size) }
+ } else {
+ // SAFETY:
+ // - `res_start` and `size` are read from a presumably valid `struct resource`.
+ // - `size` is known not to be zero at this point.
+ unsafe { bindings::ioremap(res_start, size) }
+ };
+
+ if addr.is_null() {
+ return Err(ENOMEM);
+ }
+
+ let io = IoRaw::new(addr as usize, size)?;
+ let io = IoMem { io };
+
+ Ok(io)
+ }
+
+ /// Creates a new `IoMem` instance from a previously acquired [`IoRequest`].
+ pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a {
+ let dev = io_request.device;
+ let res = io_request.resource;
+
+ Devres::new(dev, Self::ioremap(res))
+ }
+}
+
+impl<const SIZE: usize> Drop for IoMem<SIZE> {
+ fn drop(&mut self) {
+ // SAFETY: Safe as per the invariant of `Io`.
+ unsafe { bindings::iounmap(self.io.addr() as *mut c_void) }
+ }
+}
+
+impl<const SIZE: usize> Deref for IoMem<SIZE> {
+ type Target = Io<SIZE>;
+
+ fn deref(&self) -> &Self::Target {
+ // SAFETY: Safe as per the invariant of `IoMem`.
+ unsafe { Io::from_raw(&self.io) }
+ }
+}
diff --git a/rust/kernel/io/poll.rs b/rust/kernel/io/poll.rs
new file mode 100644
index 000000000000..b1a2570364f4
--- /dev/null
+++ b/rust/kernel/io/poll.rs
@@ -0,0 +1,173 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! IO polling.
+//!
+//! C header: [`include/linux/iopoll.h`](srctree/include/linux/iopoll.h).
+
+use crate::{
+ prelude::*,
+ processor::cpu_relax,
+ task::might_sleep,
+ time::{
+ delay::{
+ fsleep,
+ udelay, //
+ },
+ Delta,
+ Instant,
+ Monotonic, //
+ },
+};
+
+/// Polls periodically until a condition is met, an error occurs,
+/// or the timeout is reached.
+///
+/// The function repeatedly executes the given operation `op` closure and
+/// checks its result using the condition closure `cond`.
+///
+/// If `cond` returns `true`, the function returns successfully with
+/// the result of `op`. Otherwise, it waits for a duration specified
+/// by `sleep_delta` before executing `op` again.
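+/// The sleep itself is performed with [`fsleep`], which selects an
+/// appropriate sleep mechanism for the requested duration.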
+/// +/// This process continues until either `op` returns an error, `cond` +/// returns `true`, or the timeout specified by `timeout_delta` is +/// reached. +/// +/// This function can only be used in a nonatomic context. +/// +/// # Errors +/// +/// If `op` returns an error, then that error is returned directly. +/// +/// If the timeout specified by `timeout_delta` is reached, then +/// `Err(ETIMEDOUT)` is returned. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::io::{Io, poll::read_poll_timeout}; +/// use kernel::time::Delta; +/// +/// const HW_READY: u16 = 0x01; +/// +/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// read_poll_timeout( +/// // The `op` closure reads the value of a specific status register. +/// || io.try_read16(0x1000), +/// // The `cond` closure takes a reference to the value returned by `op` +/// // and checks whether the hardware is ready. +/// |val: &u16| *val == HW_READY, +/// Delta::from_millis(50), +/// Delta::from_secs(3), +/// )?; +/// Ok(()) +/// } +/// ``` +#[track_caller] +pub fn read_poll_timeout<Op, Cond, T>( + mut op: Op, + mut cond: Cond, + sleep_delta: Delta, + timeout_delta: Delta, +) -> Result<T> +where + Op: FnMut() -> Result<T>, + Cond: FnMut(&T) -> bool, +{ + let start: Instant<Monotonic> = Instant::now(); + + // Unlike the C version, we always call `might_sleep()` unconditionally, + // as conditional calls are error-prone. We clearly separate + // `read_poll_timeout()` and `read_poll_timeout_atomic()` to aid + // tools like klint. + might_sleep(); + + loop { + let val = op()?; + if cond(&val) { + // Unlike the C version, we immediately return. + // We know the condition is met so we don't need to check again. + return Ok(val); + } + + if start.elapsed() > timeout_delta { + // Unlike the C version, we immediately return. + // We have just called `op()` so we don't need to call it again. + return Err(ETIMEDOUT); + } + + if !sleep_delta.is_zero() { + fsleep(sleep_delta); + } + + // `fsleep()` could be a busy-wait loop so we always call `cpu_relax()`. + cpu_relax(); + } +} + +/// Polls periodically until a condition is met, an error occurs, +/// or the attempt limit is reached. +/// +/// The function repeatedly executes the given operation `op` closure and +/// checks its result using the condition closure `cond`. +/// +/// If `cond` returns `true`, the function returns successfully with the result of `op`. +/// Otherwise, it performs a busy wait for a duration specified by `delay_delta` +/// before executing `op` again. +/// +/// This process continues until either `op` returns an error, `cond` +/// returns `true`, or the attempt limit specified by `retry` is reached. +/// +/// # Errors +/// +/// If `op` returns an error, then that error is returned directly. +/// +/// If the attempt limit specified by `retry` is reached, then +/// `Err(ETIMEDOUT)` is returned. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::io::{poll::read_poll_timeout_atomic, Io}; +/// use kernel::time::Delta; +/// +/// const HW_READY: u16 = 0x01; +/// +/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// read_poll_timeout_atomic( +/// // The `op` closure reads the value of a specific status register. +/// || io.try_read16(0x1000), +/// // The `cond` closure takes a reference to the value returned by `op` +/// // and checks whether the hardware is ready. 
+/// |val: &u16| *val == HW_READY, +/// Delta::from_micros(50), +/// 1000, +/// )?; +/// Ok(()) +/// } +/// ``` +pub fn read_poll_timeout_atomic<Op, Cond, T>( + mut op: Op, + mut cond: Cond, + delay_delta: Delta, + retry: usize, +) -> Result<T> +where + Op: FnMut() -> Result<T>, + Cond: FnMut(&T) -> bool, +{ + for _ in 0..retry { + let val = op()?; + if cond(&val) { + return Ok(val); + } + + if !delay_delta.is_zero() { + udelay(delay_delta); + } + + cpu_relax(); + } + + Err(ETIMEDOUT) +} diff --git a/rust/kernel/io/resource.rs b/rust/kernel/io/resource.rs new file mode 100644 index 000000000000..56cfde97ce87 --- /dev/null +++ b/rust/kernel/io/resource.rs @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for [system +//! resources](https://docs.kernel.org/core-api/kernel-api.html#resources-management). +//! +//! C header: [`include/linux/ioport.h`](srctree/include/linux/ioport.h) + +use core::{ + ops::Deref, + ptr::NonNull, // +}; + +use crate::{ + prelude::*, + str::CString, + types::Opaque, // +}; + +pub use super::{ + PhysAddr, + ResourceSize, // +}; + +/// A region allocated from a parent [`Resource`]. +/// +/// # Invariants +/// +/// - `self.0` points to a valid `bindings::resource` that was obtained through +/// `bindings::__request_region`. +pub struct Region { + /// The resource returned when the region was requested. + resource: NonNull<bindings::resource>, + /// The name that was passed in when the region was requested. We need to + /// store it for ownership reasons. + _name: CString, +} + +impl Deref for Region { + type Target = Resource; + + fn deref(&self) -> &Self::Target { + // SAFETY: Safe as per the invariant of `Region`. + unsafe { Resource::from_raw(self.resource.as_ptr()) } + } +} + +impl Drop for Region { + fn drop(&mut self) { + let (flags, start, size) = { + let res = &**self; + (res.flags(), res.start(), res.size()) + }; + + let release_fn = if flags.contains(Flags::IORESOURCE_MEM) { + bindings::release_mem_region + } else { + bindings::release_region + }; + + // SAFETY: Safe as per the invariant of `Region`. + unsafe { release_fn(start, size) }; + } +} + +// SAFETY: `Region` only holds a pointer to a C `struct resource`, which is safe to be used from +// any thread. +unsafe impl Send for Region {} + +// SAFETY: `Region` only holds a pointer to a C `struct resource`, references to which are +// safe to be used from any thread. +unsafe impl Sync for Region {} + +/// A resource abstraction. +/// +/// # Invariants +/// +/// [`Resource`] is a transparent wrapper around a valid `bindings::resource`. +#[repr(transparent)] +pub struct Resource(Opaque<bindings::resource>); + +impl Resource { + /// Creates a reference to a [`Resource`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that for the duration of 'a, the pointer will + /// point at a valid `bindings::resource`. + /// + /// The caller must also ensure that the [`Resource`] is only accessed via the + /// returned reference for the duration of 'a. + pub(crate) const unsafe fn from_raw<'a>(ptr: *mut bindings::resource) -> &'a Self { + // SAFETY: Self is a transparent wrapper around `Opaque<bindings::resource>`. + unsafe { &*ptr.cast() } + } + + /// Requests a resource region. + /// + /// Exclusive access will be given and the region will be marked as busy. + /// Further calls to [`Self::request_region`] will return [`None`] if + /// the region, or a part of it, is already in use. 
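+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch, assuming `res` is a valid memory [`Resource`]; the
+ /// helper and region name are illustrative only.
+ ///
+ /// ```no_run
+ /// # use kernel::io::resource::{Flags, Region, Resource};
+ /// # use kernel::prelude::*;
+ /// # fn reserve(res: &Resource) -> Result<Region> {
+ /// let name = kernel::str::CString::try_from_fmt(fmt!("my-driver"))?;
+ /// // `None` here means the range (or part of it) is already busy.
+ /// res.request_region(res.start(), res.size(), name, Flags::IORESOURCE_MEM)
+ /// .ok_or(EBUSY)
+ /// # }
+ /// ```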
+ pub fn request_region(
+ &self,
+ start: PhysAddr,
+ size: ResourceSize,
+ name: CString,
+ flags: Flags,
+ ) -> Option<Region> {
+ // SAFETY:
+ // - Safe as per the invariant of `Resource`.
+ // - `__request_region` will store a reference to the name, but that is
+ // safe as we own it and it will not be dropped until the `Region` is
+ // dropped.
+ let region = unsafe {
+ bindings::__request_region(
+ self.0.get(),
+ start,
+ size,
+ name.as_char_ptr(),
+ flags.0 as c_int,
+ )
+ };
+
+ Some(Region {
+ resource: NonNull::new(region)?,
+ _name: name,
+ })
+ }
+
+ /// Returns the size of the resource.
+ pub fn size(&self) -> ResourceSize {
+ let inner = self.0.get();
+ // SAFETY: Safe as per the invariants of `Resource`.
+ unsafe { bindings::resource_size(inner) }
+ }
+
+ /// Returns the start address of the resource.
+ pub fn start(&self) -> PhysAddr {
+ let inner = self.0.get();
+ // SAFETY: Safe as per the invariants of `Resource`.
+ unsafe { (*inner).start }
+ }
+
+ /// Returns the name of the resource.
+ pub fn name(&self) -> Option<&CStr> {
+ let inner = self.0.get();
+
+ // SAFETY: Safe as per the invariants of `Resource`.
+ let name = unsafe { (*inner).name };
+
+ if name.is_null() {
+ return None;
+ }
+
+ // SAFETY: In the C code, `resource::name` either contains a null
+ // pointer or points to a valid NUL-terminated C string, and at this
+ // point we know it is not null, so we can safely convert it to a
+ // `CStr`.
+ Some(unsafe { CStr::from_char_ptr(name) })
+ }
+
+ /// Returns the flags associated with the resource.
+ pub fn flags(&self) -> Flags {
+ let inner = self.0.get();
+ // SAFETY: Safe as per the invariants of `Resource`.
+ let flags = unsafe { (*inner).flags };
+
+ Flags(flags)
+ }
+}
+
+// SAFETY: `Resource` only holds a pointer to a C `struct resource`, which is
+// safe to be used from any thread.
+unsafe impl Send for Resource {}
+
+// SAFETY: `Resource` only holds a pointer to a C `struct resource`, references
+// to which are safe to be used from any thread.
+unsafe impl Sync for Resource {}
+
+/// Resource flags as stored in the C `struct resource::flags` field.
+///
+/// They can be combined with the operators `|`, `&`, and `!`.
+///
+/// Values can be used from the associated constants such as
+/// [`Flags::IORESOURCE_IO`].
+#[derive(Clone, Copy, PartialEq)]
+pub struct Flags(c_ulong);
+
+impl Flags {
+ /// Check whether `flags` is contained in `self`.
+ pub fn contains(self, flags: Flags) -> bool {
+ (self & flags) == flags
+ }
+}
+
+impl core::ops::BitOr for Flags {
+ type Output = Self;
+ fn bitor(self, rhs: Self) -> Self::Output {
+ Self(self.0 | rhs.0)
+ }
+}
+
+impl core::ops::BitAnd for Flags {
+ type Output = Self;
+ fn bitand(self, rhs: Self) -> Self::Output {
+ Self(self.0 & rhs.0)
+ }
+}
+
+impl core::ops::Not for Flags {
+ type Output = Self;
+ fn not(self) -> Self::Output {
+ Self(!self.0)
+ }
+}
+
+impl Flags {
+ /// PCI/ISA I/O ports.
+ pub const IORESOURCE_IO: Flags = Flags::new(bindings::IORESOURCE_IO);
+
+ /// Resource is software muxed.
+ pub const IORESOURCE_MUXED: Flags = Flags::new(bindings::IORESOURCE_MUXED);
+
+ /// Resource represents a memory region.
+ pub const IORESOURCE_MEM: Flags = Flags::new(bindings::IORESOURCE_MEM);
+
+ /// Resource represents a memory region that must be ioremapped using `ioremap_np`.
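+ /// Writes to such a region are issued as non-posted transactions, i.e.
+ /// each write must be acknowledged by the device before execution
+ /// continues.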
+ pub const IORESOURCE_MEM_NONPOSTED: Flags = Flags::new(bindings::IORESOURCE_MEM_NONPOSTED); + + const fn new(value: u32) -> Self { + crate::build_assert!(value as u64 <= c_ulong::MAX as u64); + Flags(value as c_ulong) + } +} diff --git a/rust/kernel/ioctl.rs b/rust/kernel/ioctl.rs new file mode 100644 index 000000000000..2fc7662339e5 --- /dev/null +++ b/rust/kernel/ioctl.rs @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! `ioctl()` number definitions. +//! +//! C header: [`include/asm-generic/ioctl.h`](srctree/include/asm-generic/ioctl.h) + +#![expect(non_snake_case)] + +use crate::build_assert; + +/// Build an ioctl number, analogous to the C macro of the same name. +#[inline(always)] +const fn _IOC(dir: u32, ty: u32, nr: u32, size: usize) -> u32 { + build_assert!(dir <= uapi::_IOC_DIRMASK); + build_assert!(ty <= uapi::_IOC_TYPEMASK); + build_assert!(nr <= uapi::_IOC_NRMASK); + build_assert!(size <= (uapi::_IOC_SIZEMASK as usize)); + + (dir << uapi::_IOC_DIRSHIFT) + | (ty << uapi::_IOC_TYPESHIFT) + | (nr << uapi::_IOC_NRSHIFT) + | ((size as u32) << uapi::_IOC_SIZESHIFT) +} + +/// Build an ioctl number for an argumentless ioctl. +#[inline(always)] +pub const fn _IO(ty: u32, nr: u32) -> u32 { + _IOC(uapi::_IOC_NONE, ty, nr, 0) +} + +/// Build an ioctl number for a read-only ioctl. +#[inline(always)] +pub const fn _IOR<T>(ty: u32, nr: u32) -> u32 { + _IOC(uapi::_IOC_READ, ty, nr, core::mem::size_of::<T>()) +} + +/// Build an ioctl number for a write-only ioctl. +#[inline(always)] +pub const fn _IOW<T>(ty: u32, nr: u32) -> u32 { + _IOC(uapi::_IOC_WRITE, ty, nr, core::mem::size_of::<T>()) +} + +/// Build an ioctl number for a read-write ioctl. +#[inline(always)] +pub const fn _IOWR<T>(ty: u32, nr: u32) -> u32 { + _IOC( + uapi::_IOC_READ | uapi::_IOC_WRITE, + ty, + nr, + core::mem::size_of::<T>(), + ) +} + +/// Get the ioctl direction from an ioctl number. +pub const fn _IOC_DIR(nr: u32) -> u32 { + (nr >> uapi::_IOC_DIRSHIFT) & uapi::_IOC_DIRMASK +} + +/// Get the ioctl type from an ioctl number. +pub const fn _IOC_TYPE(nr: u32) -> u32 { + (nr >> uapi::_IOC_TYPESHIFT) & uapi::_IOC_TYPEMASK +} + +/// Get the ioctl number from an ioctl number. +pub const fn _IOC_NR(nr: u32) -> u32 { + (nr >> uapi::_IOC_NRSHIFT) & uapi::_IOC_NRMASK +} + +/// Get the ioctl size from an ioctl number. +pub const fn _IOC_SIZE(nr: u32) -> usize { + ((nr >> uapi::_IOC_SIZESHIFT) & uapi::_IOC_SIZEMASK) as usize +} diff --git a/rust/kernel/iov.rs b/rust/kernel/iov.rs new file mode 100644 index 000000000000..43bae8923c46 --- /dev/null +++ b/rust/kernel/iov.rs @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2025 Google LLC. + +//! IO vectors. +//! +//! C headers: [`include/linux/iov_iter.h`](srctree/include/linux/iov_iter.h), +//! [`include/linux/uio.h`](srctree/include/linux/uio.h) + +use crate::{ + alloc::{Allocator, Flags}, + bindings, + prelude::*, + types::Opaque, +}; +use core::{marker::PhantomData, mem::MaybeUninit, ptr, slice}; + +const ITER_SOURCE: bool = bindings::ITER_SOURCE != 0; +const ITER_DEST: bool = bindings::ITER_DEST != 0; + +// Compile-time assertion for the above constants. +const _: () = { + build_assert!( + ITER_SOURCE != ITER_DEST, + "ITER_DEST and ITER_SOURCE should be different." + ); +}; + +/// An IO vector that acts as a source of data. +/// +/// The data may come from many different sources. This includes both things in kernel-space and +/// reading from userspace. 
It's not necessarily the case that the data source is immutable, so
+/// rewinding the IO vector to read the same data twice is not guaranteed to result in the same
+/// bytes. It's also possible that the data source is mapped in a thread-local manner using e.g.
+/// `kmap_local_page()`, so this type is not `Send` to ensure that the mapping is read from the
+/// right context in that scenario.
+///
+/// # Invariants
+///
+/// Must hold a valid `struct iov_iter` with `data_source` set to `ITER_SOURCE`. For the duration
+/// of `'data`, it must be safe to read from this IO vector using the standard C methods for this
+/// purpose.
+#[repr(transparent)]
+pub struct IovIterSource<'data> {
+ iov: Opaque<bindings::iov_iter>,
+ /// Represent to the type system that this value contains a pointer to readable data it does
+ /// not own.
+ _source: PhantomData<&'data [u8]>,
+}
+
+impl<'data> IovIterSource<'data> {
+ /// Obtain an `IovIterSource` from a raw pointer.
+ ///
+ /// # Safety
+ ///
+ /// * The referenced `struct iov_iter` must be valid and must only be accessed through the
+ /// returned reference for the duration of `'iov`.
+ /// * The referenced `struct iov_iter` must have `data_source` set to `ITER_SOURCE`.
+ /// * For the duration of `'data`, it must be safe to read from this IO vector using the
+ /// standard C methods for this purpose.
+ #[track_caller]
+ #[inline]
+ pub unsafe fn from_raw<'iov>(ptr: *mut bindings::iov_iter) -> &'iov mut IovIterSource<'data> {
+ // SAFETY: The caller ensures that `ptr` is valid.
+ let data_source = unsafe { (*ptr).data_source };
+ assert_eq!(data_source, ITER_SOURCE);
+
+ // SAFETY: The caller ensures the type invariants for the right durations, and
+ // `IovIterSource` is layout compatible with `struct iov_iter`.
+ unsafe { &mut *ptr.cast::<IovIterSource<'data>>() }
+ }
+
+ /// Access this as a raw `struct iov_iter`.
+ #[inline]
+ pub fn as_raw(&mut self) -> *mut bindings::iov_iter {
+ self.iov.get()
+ }
+
+ /// Returns the number of bytes available in this IO vector.
+ ///
+ /// Note that this may overestimate the number of bytes. For example, reading from userspace
+ /// memory could fail with `EFAULT`, which will be treated as the end of the IO vector.
+ #[inline]
+ pub fn len(&self) -> usize {
+ // SAFETY: We have shared access to this IO vector, so we can read its `count` field.
+ unsafe {
+ (*self.iov.get())
+ .__bindgen_anon_1
+ .__bindgen_anon_1
+ .as_ref()
+ .count
+ }
+ }
+
+ /// Returns whether there are no bytes left in this IO vector.
+ ///
+ /// This may return `false` even if there are no more bytes available. For example, reading
+ /// from userspace memory could fail with `EFAULT`, which will be treated as the end of the IO
+ /// vector.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+
+ /// Advance this IO vector by `bytes` bytes.
+ ///
+ /// If `bytes` is larger than the size of this IO vector, it is advanced to the end.
+ #[inline]
+ pub fn advance(&mut self, bytes: usize) {
+ // SAFETY: By the type invariants, `self.iov` is a valid IO vector.
+ unsafe { bindings::iov_iter_advance(self.as_raw(), bytes) };
+ }
+
+ /// Advance this IO vector backwards by `bytes` bytes.
+ ///
+ /// # Safety
+ ///
+ /// The IO vector must not be reverted to before its beginning.
+ #[inline]
+ pub unsafe fn revert(&mut self, bytes: usize) {
+ // SAFETY: By the type invariants, `self.iov` is a valid IO vector, and the caller
+ // ensures that `bytes` is in bounds.
+ unsafe { bindings::iov_iter_revert(self.as_raw(), bytes) }; + } + + /// Read data from this IO vector. + /// + /// Returns the number of bytes that have been copied. + #[inline] + pub fn copy_from_iter(&mut self, out: &mut [u8]) -> usize { + // SAFETY: `Self::copy_from_iter_raw` guarantees that it will not write any uninitialized + // bytes in the provided buffer, so `out` is still a valid `u8` slice after this call. + let out = unsafe { &mut *(ptr::from_mut(out) as *mut [MaybeUninit<u8>]) }; + + self.copy_from_iter_raw(out).len() + } + + /// Read data from this IO vector and append it to a vector. + /// + /// Returns the number of bytes that have been copied. + #[inline] + pub fn copy_from_iter_vec<A: Allocator>( + &mut self, + out: &mut Vec<u8, A>, + flags: Flags, + ) -> Result<usize> { + out.reserve(self.len(), flags)?; + let len = self.copy_from_iter_raw(out.spare_capacity_mut()).len(); + // SAFETY: + // - `len` is the length of a subslice of the spare capacity, so `len` is at most the + // length of the spare capacity. + // - `Self::copy_from_iter_raw` guarantees that the first `len` bytes of the spare capacity + // have been initialized. + unsafe { out.inc_len(len) }; + Ok(len) + } + + /// Read data from this IO vector into potentially uninitialized memory. + /// + /// Returns the sub-slice of the output that has been initialized. If the returned slice is + /// shorter than the input buffer, then the entire IO vector has been read. + /// + /// This will never write uninitialized bytes to the provided buffer. + #[inline] + pub fn copy_from_iter_raw(&mut self, out: &mut [MaybeUninit<u8>]) -> &mut [u8] { + let capacity = out.len(); + let out = out.as_mut_ptr().cast::<u8>(); + + // GUARANTEES: The C API guarantees that it does not write uninitialized bytes to the + // provided buffer. + // SAFETY: + // * By the type invariants, it is still valid to read from this IO vector. + // * `out` is valid for writing for `capacity` bytes because it comes from a slice of + // that length. + let len = unsafe { bindings::_copy_from_iter(out.cast(), capacity, self.as_raw()) }; + + // SAFETY: The underlying C api guarantees that initialized bytes have been written to the + // first `len` bytes of the spare capacity. + unsafe { slice::from_raw_parts_mut(out, len) } + } +} + +/// An IO vector that acts as a destination for data. +/// +/// IO vectors support many different types of destinations. This includes both buffers in +/// kernel-space and writing to userspace. It's possible that the destination buffer is mapped in a +/// thread-local manner using e.g. `kmap_local_page()`, so this type is not `Send` to ensure that +/// the mapping is written to the right context in that scenario. +/// +/// # Invariants +/// +/// Must hold a valid `struct iov_iter` with `data_source` set to `ITER_DEST`. For the duration of +/// `'data`, it must be safe to write to this IO vector using the standard C methods for this +/// purpose. +#[repr(transparent)] +pub struct IovIterDest<'data> { + iov: Opaque<bindings::iov_iter>, + /// Represent to the type system that this value contains a pointer to writable data it does + /// not own. + _source: PhantomData<&'data mut [u8]>, +} + +impl<'data> IovIterDest<'data> { + /// Obtain an `IovIterDest` from a raw pointer. + /// + /// # Safety + /// + /// * The referenced `struct iov_iter` must be valid and must only be accessed through the + /// returned reference for the duration of `'iov`. + /// * The referenced `struct iov_iter` must have `data_source` set to `ITER_DEST`. 
+ /// * For the duration of `'data`, it must be safe to write to this IO vector using the
+ /// standard C methods for this purpose.
+ #[track_caller]
+ #[inline]
+ pub unsafe fn from_raw<'iov>(ptr: *mut bindings::iov_iter) -> &'iov mut IovIterDest<'data> {
+ // SAFETY: The caller ensures that `ptr` is valid.
+ let data_source = unsafe { (*ptr).data_source };
+ assert_eq!(data_source, ITER_DEST);
+
+ // SAFETY: The caller ensures the type invariants for the right durations, and
+ // `IovIterDest` is layout compatible with `struct iov_iter`.
+ unsafe { &mut *ptr.cast::<IovIterDest<'data>>() }
+ }
+
+ /// Access this as a raw `struct iov_iter`.
+ #[inline]
+ pub fn as_raw(&mut self) -> *mut bindings::iov_iter {
+ self.iov.get()
+ }
+
+ /// Returns the number of bytes available in this IO vector.
+ ///
+ /// Note that this may overestimate the number of bytes. For example, reading from userspace
+ /// memory could fail with `EFAULT`, which will be treated as the end of the IO vector.
+ #[inline]
+ pub fn len(&self) -> usize {
+ // SAFETY: We have shared access to this IO vector, so we can read its `count` field.
+ unsafe {
+ (*self.iov.get())
+ .__bindgen_anon_1
+ .__bindgen_anon_1
+ .as_ref()
+ .count
+ }
+ }
+
+ /// Returns whether there are no bytes left in this IO vector.
+ ///
+ /// This may return `false` even if there are no more bytes available. For example, reading
+ /// from userspace memory could fail with `EFAULT`, which will be treated as the end of the IO
+ /// vector.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+
+ /// Advance this IO vector by `bytes` bytes.
+ ///
+ /// If `bytes` is larger than the size of this IO vector, it is advanced to the end.
+ #[inline]
+ pub fn advance(&mut self, bytes: usize) {
+ // SAFETY: By the type invariants, `self.iov` is a valid IO vector.
+ unsafe { bindings::iov_iter_advance(self.as_raw(), bytes) };
+ }
+
+ /// Advance this IO vector backwards by `bytes` bytes.
+ ///
+ /// # Safety
+ ///
+ /// The IO vector must not be reverted to before its beginning.
+ #[inline]
+ pub unsafe fn revert(&mut self, bytes: usize) {
+ // SAFETY: By the type invariants, `self.iov` is a valid IO vector, and the caller
+ // ensures that `bytes` is in bounds.
+ unsafe { bindings::iov_iter_revert(self.as_raw(), bytes) };
+ }
+
+ /// Write data to this IO vector.
+ ///
+ /// Returns the number of bytes that were written. If this is shorter than the provided slice,
+ /// then no more bytes can be written.
+ #[inline]
+ pub fn copy_to_iter(&mut self, input: &[u8]) -> usize {
+ // SAFETY:
+ // * By the type invariants, it is still valid to write to this IO vector.
+ // * `input` is valid for `input.len()` bytes.
+ unsafe { bindings::_copy_to_iter(input.as_ptr().cast(), input.len(), self.as_raw()) }
+ }
+
+ /// Utility for implementing `read_iter` given the full contents of the file.
+ ///
+ /// The full contents of the file being read from are represented by `contents`. This call
+ /// will write the appropriate sub-slice of `contents` and update the file position in `ppos`
+ /// so that the file will appear to contain `contents` even if it takes multiple reads to read
+ /// the entire file.
+ #[inline]
+ pub fn simple_read_from_buffer(&mut self, ppos: &mut i64, contents: &[u8]) -> Result<usize> {
+ if *ppos < 0 {
+ return Err(EINVAL);
+ }
+ let Ok(pos) = usize::try_from(*ppos) else {
+ return Ok(0);
+ };
+ if pos >= contents.len() {
+ return Ok(0);
+ }
+
+ // BOUNDS: We just checked that `pos < contents.len()` above.
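+ // `copy_to_iter` may write fewer bytes than requested (e.g. on a failed
+ // userspace access); the short count is reflected in `*ppos` and in the
+ // returned value.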
+ let num_written = self.copy_to_iter(&contents[pos..]);
+
+ // OVERFLOW: `pos+num_written <= contents.len() <= isize::MAX <= i64::MAX`.
+ *ppos = (pos + num_written) as i64;
+
+ Ok(num_written)
+ }
+}
diff --git a/rust/kernel/irq.rs b/rust/kernel/irq.rs
new file mode 100644
index 000000000000..20abd4056655
--- /dev/null
+++ b/rust/kernel/irq.rs
@@ -0,0 +1,24 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! IRQ abstractions.
+//!
+//! An IRQ is an interrupt request from a device. It is used to get the CPU's
+//! attention so it can service a hardware event in a timely manner.
+//!
+//! The current abstractions handle IRQ requests and handlers, i.e. they allow
+//! drivers to register a handler for a given IRQ line.
+//!
+//! C header: [`include/linux/interrupt.h`](srctree/include/linux/interrupt.h)
+
+/// Flags to be used when registering IRQ handlers.
+mod flags;
+
+/// IRQ allocation and handling.
+mod request;
+
+pub use flags::Flags;
+
+pub use request::{
+ Handler, IrqRequest, IrqReturn, Registration, ThreadedHandler, ThreadedIrqReturn,
+ ThreadedRegistration,
+};
diff --git a/rust/kernel/irq/flags.rs b/rust/kernel/irq/flags.rs
new file mode 100644
index 000000000000..adfde96ec47c
--- /dev/null
+++ b/rust/kernel/irq/flags.rs
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: GPL-2.0
+// SPDX-FileCopyrightText: Copyright 2025 Collabora ltd.
+
+use crate::bindings;
+use crate::prelude::*;
+
+/// Flags to be used when registering IRQ handlers.
+///
+/// Flags can be used to request specific behaviors when registering an IRQ
+/// handler, and can be combined using the `|`, `&`, and `!` operators to
+/// further control the system's behavior.
+///
+/// A common use case is to register a shared interrupt, as sharing the line
+/// between devices is increasingly common in modern systems and is even
+/// required for some buses. This requires setting [`Flags::SHARED`] when
+/// requesting the interrupt. Other use cases include setting the trigger type
+/// through `Flags::TRIGGER_*`, which determines when the interrupt fires, or
+/// controlling whether the interrupt is masked after the handler runs by using
+/// [`Flags::ONESHOT`].
+///
+/// If an invalid combination of flags is provided, the system will refuse to
+/// register the handler, and lower layers will enforce certain flags when
+/// necessary. This means, for example, that all the
+/// [`crate::irq::Registration`] for a shared interrupt have to agree on
+/// [`Flags::SHARED`] and on the same trigger type, if set.
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub struct Flags(c_ulong);
+
+impl Flags {
+ /// Use the interrupt line as already configured.
+ pub const TRIGGER_NONE: Flags = Flags::new(bindings::IRQF_TRIGGER_NONE);
+
+ /// The interrupt is triggered when the signal goes from low to high.
+ pub const TRIGGER_RISING: Flags = Flags::new(bindings::IRQF_TRIGGER_RISING);
+
+ /// The interrupt is triggered when the signal goes from high to low.
+ pub const TRIGGER_FALLING: Flags = Flags::new(bindings::IRQF_TRIGGER_FALLING);
+
+ /// The interrupt is triggered while the signal is held high.
+ pub const TRIGGER_HIGH: Flags = Flags::new(bindings::IRQF_TRIGGER_HIGH);
+
+ /// The interrupt is triggered while the signal is held low.
+ pub const TRIGGER_LOW: Flags = Flags::new(bindings::IRQF_TRIGGER_LOW);
+
+ /// Allow sharing the IRQ among several devices.
+ pub const SHARED: Flags = Flags::new(bindings::IRQF_SHARED);
+
+ /// Set by callers when they expect sharing mismatches to occur.
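+ /// When set, the kernel will not print an error if the sharing mismatch
+ /// actually happens.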
+ pub const PROBE_SHARED: Flags = Flags::new(bindings::IRQF_PROBE_SHARED);
+
+ /// Flag to mark this interrupt as a timer interrupt.
+ pub const TIMER: Flags = Flags::new(bindings::IRQF_TIMER);
+
+ /// Interrupt is per CPU.
+ pub const PERCPU: Flags = Flags::new(bindings::IRQF_PERCPU);
+
+ /// Flag to exclude this interrupt from irq balancing.
+ pub const NOBALANCING: Flags = Flags::new(bindings::IRQF_NOBALANCING);
+
+ /// Interrupt is used for polling (only the interrupt that is registered
+ /// first in a shared interrupt is considered for performance reasons).
+ pub const IRQPOLL: Flags = Flags::new(bindings::IRQF_IRQPOLL);
+
+ /// Interrupt is not re-enabled after the hardirq handler finished. Used by
+ /// threaded interrupts which need to keep the irq line disabled until the
+ /// threaded handler has been run.
+ pub const ONESHOT: Flags = Flags::new(bindings::IRQF_ONESHOT);
+
+ /// Do not disable this IRQ during suspend. Does not guarantee that this
+ /// interrupt will wake the system from a suspended state.
+ pub const NO_SUSPEND: Flags = Flags::new(bindings::IRQF_NO_SUSPEND);
+
+ /// Force enable it on resume even if [`Flags::NO_SUSPEND`] is set.
+ pub const FORCE_RESUME: Flags = Flags::new(bindings::IRQF_FORCE_RESUME);
+
+ /// Interrupt cannot be threaded.
+ pub const NO_THREAD: Flags = Flags::new(bindings::IRQF_NO_THREAD);
+
+ /// Resume IRQ early during syscore instead of at device resume time.
+ pub const EARLY_RESUME: Flags = Flags::new(bindings::IRQF_EARLY_RESUME);
+
+ /// If the IRQ is shared with a [`Flags::NO_SUSPEND`] user, execute this
+ /// interrupt handler after suspending interrupts. For system wakeup devices
+ /// users need to implement wakeup detection in their interrupt handlers.
+ pub const COND_SUSPEND: Flags = Flags::new(bindings::IRQF_COND_SUSPEND);
+
+ /// Don't enable IRQ or NMI automatically when users request it. Users will
+ /// enable it explicitly by `enable_irq` or `enable_nmi` later.
+ pub const NO_AUTOEN: Flags = Flags::new(bindings::IRQF_NO_AUTOEN);
+
+ /// Exclude from runaway detection for IPI and similar handlers, depends on
+ /// `PERCPU`.
+ pub const NO_DEBUG: Flags = Flags::new(bindings::IRQF_NO_DEBUG);
+
+ pub(crate) fn into_inner(self) -> c_ulong {
+ self.0
+ }
+
+ const fn new(value: u32) -> Self {
+ build_assert!(value as u64 <= c_ulong::MAX as u64);
+ Self(value as c_ulong)
+ }
+}
+
+impl core::ops::BitOr for Flags {
+ type Output = Self;
+ fn bitor(self, rhs: Self) -> Self::Output {
+ Self(self.0 | rhs.0)
+ }
+}
+
+impl core::ops::BitAnd for Flags {
+ type Output = Self;
+ fn bitand(self, rhs: Self) -> Self::Output {
+ Self(self.0 & rhs.0)
+ }
+}
+
+impl core::ops::Not for Flags {
+ type Output = Self;
+ fn not(self) -> Self::Output {
+ Self(!self.0)
+ }
+}
diff --git a/rust/kernel/irq/request.rs b/rust/kernel/irq/request.rs
new file mode 100644
index 000000000000..b150563fdef8
--- /dev/null
+++ b/rust/kernel/irq/request.rs
@@ -0,0 +1,507 @@
+// SPDX-License-Identifier: GPL-2.0
+// SPDX-FileCopyrightText: Copyright 2025 Collabora ltd.
+
+//! This module provides types like [`Registration`] and
+//! [`ThreadedRegistration`], which allow users to register handlers for a given
+//! IRQ line.
+
+use core::marker::PhantomPinned;
+
+use crate::alloc::Allocator;
+use crate::device::{Bound, Device};
+use crate::devres::Devres;
+use crate::error::to_result;
+use crate::irq::flags::Flags;
+use crate::prelude::*;
+use crate::str::CStr;
+use crate::sync::Arc;
+
+/// The value that can be returned from a [`Handler`] or a [`ThreadedHandler`].
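+///
+/// On a shared line, returning [`IrqReturn::None`] tells the core that the
+/// interrupt did not come from this device, which is what allows unhandled
+/// (spurious) interrupts to be detected.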
+#[repr(u32)]
+pub enum IrqReturn {
+ /// The interrupt was not from this device or was not handled.
+ None = bindings::irqreturn_IRQ_NONE,
+
+ /// The interrupt was handled by this device.
+ Handled = bindings::irqreturn_IRQ_HANDLED,
+}
+
+/// Callbacks for an IRQ handler.
+pub trait Handler: Sync {
+ /// The hard IRQ handler.
+ ///
+ /// This is executed in interrupt context, hence all corresponding
+ /// limitations do apply.
+ ///
+ /// All work that does not necessarily need to be executed from
+ /// interrupt context should be deferred to a threaded handler.
+ /// See also [`ThreadedRegistration`].
+ fn handle(&self, device: &Device<Bound>) -> IrqReturn;
+}
+
+impl<T: ?Sized + Handler + Send> Handler for Arc<T> {
+ fn handle(&self, device: &Device<Bound>) -> IrqReturn {
+ T::handle(self, device)
+ }
+}
+
+impl<T: ?Sized + Handler, A: Allocator> Handler for Box<T, A> {
+ fn handle(&self, device: &Device<Bound>) -> IrqReturn {
+ T::handle(self, device)
+ }
+}
+
+/// # Invariants
+///
+/// - `self.irq` is the same as the one passed to `request_{threaded}_irq`.
+/// - `cookie` was passed to `request_{threaded}_irq` as the cookie. It is guaranteed to be unique
+/// by the type system, since each call to `new` will return a different instance of
+/// `Registration`.
+#[pin_data(PinnedDrop)]
+struct RegistrationInner {
+ irq: u32,
+ cookie: *mut c_void,
+}
+
+impl RegistrationInner {
+ fn synchronize(&self) {
+ // SAFETY: safe as per the invariants of `RegistrationInner`
+ unsafe { bindings::synchronize_irq(self.irq) };
+ }
+}
+
+#[pinned_drop]
+impl PinnedDrop for RegistrationInner {
+ fn drop(self: Pin<&mut Self>) {
+ // SAFETY:
+ //
+ // Safe as per the invariants of `RegistrationInner` and:
+ //
+ // - The containing struct is `!Unpin` and was initialized using
+ // pin-init, so it occupied the same memory location for the entirety of
+ // its lifetime.
+ //
+ // Notice that this will block until all handlers finish executing,
+ // i.e.: at no point will &self be invalid while the handler is running.
+ unsafe { bindings::free_irq(self.irq, self.cookie) };
+ }
+}
+
+// SAFETY: We only use `inner` on drop, which is called at most once with no
+// concurrent access.
+unsafe impl Sync for RegistrationInner {}
+
+// SAFETY: It is safe to send `RegistrationInner` across threads.
+unsafe impl Send for RegistrationInner {}
+
+/// A request for an IRQ line for a given device.
+///
+/// # Invariants
+///
+/// - `irq` is the number of an interrupt source of `dev`.
+/// - `irq` has not been registered yet.
+pub struct IrqRequest<'a> {
+ dev: &'a Device<Bound>,
+ irq: u32,
+}
+
+impl<'a> IrqRequest<'a> {
+ /// Creates a new IRQ request for the given device and IRQ number.
+ ///
+ /// # Safety
+ ///
+ /// - `irq` should be a valid IRQ number for `dev`.
+ pub(crate) unsafe fn new(dev: &'a Device<Bound>, irq: u32) -> Self {
+ // INVARIANT: `irq` is a valid IRQ number for `dev`.
+ IrqRequest { dev, irq }
+ }
+
+ /// Returns the IRQ number of an [`IrqRequest`].
+ pub fn irq(&self) -> u32 {
+ self.irq
+ }
+}
+
+/// A registration of an IRQ handler for a given IRQ line.
+///
+/// # Examples
+///
+/// The following is an example of using `Registration`. It uses a
+/// [`Completion`] to coordinate between the IRQ
+/// handler and process context. [`Completion`] uses interior mutability, so the
+/// handler can signal with [`Completion::complete_all()`] and the process
+/// context can wait with [`Completion::wait_for_completion()`] even though
+/// there is no way to get a mutable reference to any of the fields in
+/// `Data`.
+///
+/// [`Completion`]: kernel::sync::Completion
+/// [`Completion::complete_all()`]: kernel::sync::Completion::complete_all
+/// [`Completion::wait_for_completion()`]: kernel::sync::Completion::wait_for_completion
+///
+/// ```
+/// use kernel::c_str;
+/// use kernel::device::{Bound, Device};
+/// use kernel::irq::{self, Flags, IrqRequest, IrqReturn, Registration};
+/// use kernel::prelude::*;
+/// use kernel::sync::{Arc, Completion};
+///
+/// // Data shared between process and IRQ context.
+/// #[pin_data]
+/// struct Data {
+/// #[pin]
+/// completion: Completion,
+/// }
+///
+/// impl irq::Handler for Data {
+/// // Executed in IRQ context.
+/// fn handle(&self, _dev: &Device<Bound>) -> IrqReturn {
+/// self.completion.complete_all();
+/// IrqReturn::Handled
+/// }
+/// }
+///
+/// // Registers an IRQ handler for the given IrqRequest.
+/// //
+/// // This runs in process context and assumes `request` was previously acquired from a device.
+/// fn register_irq(
+/// handler: impl PinInit<Data, Error>,
+/// request: IrqRequest<'_>,
+/// ) -> Result<Arc<Registration<Data>>> {
+/// let registration = Registration::new(request, Flags::SHARED, c_str!("my_device"), handler);
+///
+/// let registration = Arc::pin_init(registration, GFP_KERNEL)?;
+///
+/// registration.handler().completion.wait_for_completion();
+///
+/// Ok(registration)
+/// }
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// # Invariants
+///
+/// * We own an irq handler whose cookie is a pointer to `Self`.
+#[pin_data]
+pub struct Registration<T: Handler + 'static> {
+ #[pin]
+ inner: Devres<RegistrationInner>,
+
+ #[pin]
+ handler: T,
+
+ /// Pinned because we need address stability so that we can pass a pointer
+ /// to the callback.
+ #[pin]
+ _pin: PhantomPinned,
+}
+
+impl<T: Handler + 'static> Registration<T> {
+ /// Registers the IRQ handler with the system for the given IRQ number.
+ pub fn new<'a>(
+ request: IrqRequest<'a>,
+ flags: Flags,
+ name: &'static CStr,
+ handler: impl PinInit<T, Error> + 'a,
+ ) -> impl PinInit<Self, Error> + 'a {
+ try_pin_init!(&this in Self {
+ handler <- handler,
+ inner <- Devres::new(
+ request.dev,
+ try_pin_init!(RegistrationInner {
+ // INVARIANT: `this` is a valid pointer to the `Registration` instance
+ cookie: this.as_ptr().cast::<c_void>(),
+ irq: {
+ // SAFETY:
+ // - The callbacks are valid for use with request_irq.
+ // - If this succeeds, the slot is guaranteed to be valid until the
+ // destructor of Self runs, which will deregister the callbacks
+ // before the memory location becomes invalid.
+ // - When request_irq is called, everything that handle_irq_callback will
+ // touch has already been initialized, so it's safe for the callback to
+ // be called immediately.
+ to_result(unsafe {
+ bindings::request_irq(
+ request.irq,
+ Some(handle_irq_callback::<T>),
+ flags.into_inner(),
+ name.as_char_ptr(),
+ this.as_ptr().cast::<c_void>(),
+ )
+ })?;
+ request.irq
+ }
+ })
+ ),
+ _pin: PhantomPinned,
+ })
+ }
+
+ /// Returns a reference to the handler that was registered with the system.
+ pub fn handler(&self) -> &T {
+ &self.handler
+ }
+
+ /// Wait for pending IRQ handlers on other CPUs.
+ ///
+ /// This will attempt to access the inner [`Devres`] container.
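+ ///
+ /// If the device has already been unbound, this fails with [`ENODEV`].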
+ pub fn try_synchronize(&self) -> Result { + let inner = self.inner.try_access().ok_or(ENODEV)?; + inner.synchronize(); + Ok(()) + } + + /// Wait for pending IRQ handlers on other CPUs. + pub fn synchronize(&self, dev: &Device<Bound>) -> Result { + let inner = self.inner.access(dev)?; + inner.synchronize(); + Ok(()) + } +} + +/// # Safety +/// +/// This function should be only used as the callback in `request_irq`. +unsafe extern "C" fn handle_irq_callback<T: Handler>(_irq: i32, ptr: *mut c_void) -> c_uint { + // SAFETY: `ptr` is a pointer to `Registration<T>` set in `Registration::new` + let registration = unsafe { &*(ptr as *const Registration<T>) }; + // SAFETY: The irq callback is removed before the device is unbound, so the fact that the irq + // callback is running implies that the device has not yet been unbound. + let device = unsafe { registration.inner.device().as_bound() }; + + T::handle(®istration.handler, device) as c_uint +} + +/// The value that can be returned from [`ThreadedHandler::handle`]. +#[repr(u32)] +pub enum ThreadedIrqReturn { + /// The interrupt was not from this device or was not handled. + None = bindings::irqreturn_IRQ_NONE, + + /// The interrupt was handled by this device. + Handled = bindings::irqreturn_IRQ_HANDLED, + + /// The handler wants the handler thread to wake up. + WakeThread = bindings::irqreturn_IRQ_WAKE_THREAD, +} + +/// Callbacks for a threaded IRQ handler. +pub trait ThreadedHandler: Sync { + /// The hard IRQ handler. + /// + /// This is executed in interrupt context, hence all corresponding + /// limitations do apply. All work that does not necessarily need to be + /// executed from interrupt context, should be deferred to the threaded + /// handler, i.e. [`ThreadedHandler::handle_threaded`]. + /// + /// The default implementation returns [`ThreadedIrqReturn::WakeThread`]. + #[expect(unused_variables)] + fn handle(&self, device: &Device<Bound>) -> ThreadedIrqReturn { + ThreadedIrqReturn::WakeThread + } + + /// The threaded IRQ handler. + /// + /// This is executed in process context. The kernel creates a dedicated + /// `kthread` for this purpose. + fn handle_threaded(&self, device: &Device<Bound>) -> IrqReturn; +} + +impl<T: ?Sized + ThreadedHandler + Send> ThreadedHandler for Arc<T> { + fn handle(&self, device: &Device<Bound>) -> ThreadedIrqReturn { + T::handle(self, device) + } + + fn handle_threaded(&self, device: &Device<Bound>) -> IrqReturn { + T::handle_threaded(self, device) + } +} + +impl<T: ?Sized + ThreadedHandler, A: Allocator> ThreadedHandler for Box<T, A> { + fn handle(&self, device: &Device<Bound>) -> ThreadedIrqReturn { + T::handle(self, device) + } + + fn handle_threaded(&self, device: &Device<Bound>) -> IrqReturn { + T::handle_threaded(self, device) + } +} + +/// A registration of a threaded IRQ handler for a given IRQ line. +/// +/// Two callbacks are required: one to handle the IRQ, and one to handle any +/// other work in a separate thread. +/// +/// The thread handler is only called if the IRQ handler returns +/// [`ThreadedIrqReturn::WakeThread`]. +/// +/// # Examples +/// +/// The following is an example of using [`ThreadedRegistration`]. It uses a +/// [`Mutex`](kernel::sync::Mutex) to provide interior mutability. 
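+///
+/// A sleeping lock such as [`Mutex`](kernel::sync::Mutex) may be used here
+/// because the threaded handler runs in process context; a hard IRQ handler
+/// would need a non-sleeping primitive such as a spinlock instead.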
+/// +/// ``` +/// use kernel::c_str; +/// use kernel::device::{Bound, Device}; +/// use kernel::irq::{ +/// self, Flags, IrqRequest, IrqReturn, ThreadedHandler, ThreadedIrqReturn, +/// ThreadedRegistration, +/// }; +/// use kernel::prelude::*; +/// use kernel::sync::{Arc, Mutex}; +/// +/// // Declare a struct that will be passed in when the interrupt fires. The u32 +/// // merely serves as an example of some internal data. +/// // +/// // [`irq::ThreadedHandler::handle`] takes `&self`. This example +/// // illustrates how interior mutability can be used when sharing the data +/// // between process context and IRQ context. +/// #[pin_data] +/// struct Data { +/// #[pin] +/// value: Mutex<u32>, +/// } +/// +/// impl ThreadedHandler for Data { +/// // This will run (in a separate kthread) if and only if +/// // [`ThreadedHandler::handle`] returns [`WakeThread`], which it does by +/// // default. +/// fn handle_threaded(&self, _dev: &Device<Bound>) -> IrqReturn { +/// let mut data = self.value.lock(); +/// *data += 1; +/// IrqReturn::Handled +/// } +/// } +/// +/// // Registers a threaded IRQ handler for the given [`IrqRequest`]. +/// // +/// // This is executing in process context and assumes that `request` was +/// // previously acquired from a device. +/// fn register_threaded_irq( +/// handler: impl PinInit<Data, Error>, +/// request: IrqRequest<'_>, +/// ) -> Result<Arc<ThreadedRegistration<Data>>> { +/// let registration = +/// ThreadedRegistration::new(request, Flags::SHARED, c_str!("my_device"), handler); +/// +/// let registration = Arc::pin_init(registration, GFP_KERNEL)?; +/// +/// { +/// // The data can be accessed from process context too. +/// let mut data = registration.handler().value.lock(); +/// *data += 1; +/// } +/// +/// Ok(registration) +/// } +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Invariants +/// +/// * We own an irq handler whose cookie is a pointer to `Self`. +#[pin_data] +pub struct ThreadedRegistration<T: ThreadedHandler + 'static> { + #[pin] + inner: Devres<RegistrationInner>, + + #[pin] + handler: T, + + /// Pinned because we need address stability so that we can pass a pointer + /// to the callback. + #[pin] + _pin: PhantomPinned, +} + +impl<T: ThreadedHandler + 'static> ThreadedRegistration<T> { + /// Registers the IRQ handler with the system for the given IRQ number. + pub fn new<'a>( + request: IrqRequest<'a>, + flags: Flags, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<Self, Error> + 'a { + try_pin_init!(&this in Self { + handler <- handler, + inner <- Devres::new( + request.dev, + try_pin_init!(RegistrationInner { + // INVARIANT: `this` is a valid pointer to the `ThreadedRegistration` instance. + cookie: this.as_ptr().cast::<c_void>(), + irq: { + // SAFETY: + // - The callbacks are valid for use with request_threaded_irq. + // - If this succeeds, the slot is guaranteed to be valid until the + // destructor of Self runs, which will deregister the callbacks + // before the memory location becomes invalid. + // - When request_threaded_irq is called, everything that the two callbacks + // will touch has already been initialized, so it's safe for the + // callbacks to be called immediately. 
+ to_result(unsafe { + bindings::request_threaded_irq( + request.irq, + Some(handle_threaded_irq_callback::<T>), + Some(thread_fn_callback::<T>), + flags.into_inner(), + name.as_char_ptr(), + this.as_ptr().cast::<c_void>(), + ) + })?; + request.irq + } + }) + ), + _pin: PhantomPinned, + }) + } + + /// Returns a reference to the handler that was registered with the system. + pub fn handler(&self) -> &T { + &self.handler + } + + /// Wait for pending IRQ handlers on other CPUs. + /// + /// This will attempt to access the inner [`Devres`] container. + pub fn try_synchronize(&self) -> Result { + let inner = self.inner.try_access().ok_or(ENODEV)?; + inner.synchronize(); + Ok(()) + } + + /// Wait for pending IRQ handlers on other CPUs. + pub fn synchronize(&self, dev: &Device<Bound>) -> Result { + let inner = self.inner.access(dev)?; + inner.synchronize(); + Ok(()) + } +} + +/// # Safety +/// +/// This function should be only used as the callback in `request_threaded_irq`. +unsafe extern "C" fn handle_threaded_irq_callback<T: ThreadedHandler>( + _irq: i32, + ptr: *mut c_void, +) -> c_uint { + // SAFETY: `ptr` is a pointer to `ThreadedRegistration<T>` set in `ThreadedRegistration::new` + let registration = unsafe { &*(ptr as *const ThreadedRegistration<T>) }; + // SAFETY: The irq callback is removed before the device is unbound, so the fact that the irq + // callback is running implies that the device has not yet been unbound. + let device = unsafe { registration.inner.device().as_bound() }; + + T::handle(®istration.handler, device) as c_uint +} + +/// # Safety +/// +/// This function should be only used as the callback in `request_threaded_irq`. +unsafe extern "C" fn thread_fn_callback<T: ThreadedHandler>(_irq: i32, ptr: *mut c_void) -> c_uint { + // SAFETY: `ptr` is a pointer to `ThreadedRegistration<T>` set in `ThreadedRegistration::new` + let registration = unsafe { &*(ptr as *const ThreadedRegistration<T>) }; + // SAFETY: The irq callback is removed before the device is unbound, so the fact that the irq + // callback is running implies that the device has not yet been unbound. + let device = unsafe { registration.inner.device().as_bound() }; + + T::handle_threaded(®istration.handler, device) as c_uint +} diff --git a/rust/kernel/jump_label.rs b/rust/kernel/jump_label.rs new file mode 100644 index 000000000000..4e974c768dbd --- /dev/null +++ b/rust/kernel/jump_label.rs @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Logic for static keys. +//! +//! C header: [`include/linux/jump_label.h`](srctree/include/linux/jump_label.h). + +/// Branch based on a static key. +/// +/// Takes three arguments: +/// +/// * `key` - the path to the static variable containing the `static_key`. +/// * `keytyp` - the type of `key`. +/// * `field` - the name of the field of `key` that contains the `static_key`. +/// +/// # Safety +/// +/// The macro must be used with a real static key defined by C. +#[macro_export] +macro_rules! static_branch_unlikely { + ($key:path, $keytyp:ty, $field:ident) => {{ + let _key: *const $keytyp = ::core::ptr::addr_of!($key); + let _key: *const $crate::bindings::static_key_false = ::core::ptr::addr_of!((*_key).$field); + let _key: *const $crate::bindings::static_key = _key.cast(); + + #[cfg(not(CONFIG_JUMP_LABEL))] + { + $crate::bindings::static_key_count(_key.cast_mut()) > 0 + } + + #[cfg(CONFIG_JUMP_LABEL)] + $crate::jump_label::arch_static_branch! 
{ $key, $keytyp, $field, false }
+ }};
+}
+pub use static_branch_unlikely;
+
+/// Assert that the assembly block evaluates to a string literal.
+#[cfg(CONFIG_JUMP_LABEL)]
+const _: &str = include!(concat!(
+ env!("OBJTREE"),
+ "/rust/kernel/generated_arch_static_branch_asm.rs"
+));
+
+#[macro_export]
+#[doc(hidden)]
+#[cfg(CONFIG_JUMP_LABEL)]
+macro_rules! arch_static_branch {
+ ($key:path, $keytyp:ty, $field:ident, $branch:expr) => {'my_label: {
+ $crate::asm!(
+ include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_static_branch_asm.rs"));
+ l_yes = label {
+ break 'my_label true;
+ },
+ symb = sym $key,
+ off = const ::core::mem::offset_of!($keytyp, $field),
+ branch = const $crate::jump_label::bool_to_int($branch),
+ );
+
+ break 'my_label false;
+ }};
+}
+
+#[cfg(CONFIG_JUMP_LABEL)]
+pub use arch_static_branch;
+
+/// A helper used by inline assembly to pass a boolean as a `const` parameter.
+///
+/// Using this function instead of a cast lets you assert that the input is a boolean, and not some
+/// other type that can also be cast to an integer.
+#[doc(hidden)]
+pub const fn bool_to_int(b: bool) -> i32 {
+ b as i32
+}
diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs
new file mode 100644
index 000000000000..79436509dd73
--- /dev/null
+++ b/rust/kernel/kunit.rs
@@ -0,0 +1,371 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! KUnit-based macros for Rust unit tests.
+//!
+//! C header: [`include/kunit/test.h`](srctree/include/kunit/test.h)
+//!
+//! Reference: <https://docs.kernel.org/dev-tools/kunit/index.html>
+
+use crate::fmt;
+use crate::prelude::*;
+
+#[cfg(CONFIG_PRINTK)]
+use crate::c_str;
+
+/// Prints a KUnit error-level message.
+///
+/// Public but hidden since it should only be used from KUnit generated code.
+#[doc(hidden)]
+pub fn err(args: fmt::Arguments<'_>) {
+ // SAFETY: The format string is null-terminated and the `%pA` specifier matches the argument we
+ // are passing.
+ #[cfg(CONFIG_PRINTK)]
+ unsafe {
+ bindings::_printk(
+ c_str!("\x013%pA").as_char_ptr(),
+ core::ptr::from_ref(&args).cast::<c_void>(),
+ );
+ }
+}
+
+/// Prints a KUnit info-level message.
+///
+/// Public but hidden since it should only be used from KUnit generated code.
+#[doc(hidden)]
+pub fn info(args: fmt::Arguments<'_>) {
+ // SAFETY: The format string is null-terminated and the `%pA` specifier matches the argument we
+ // are passing.
+ #[cfg(CONFIG_PRINTK)]
+ unsafe {
+ bindings::_printk(
+ c_str!("\x016%pA").as_char_ptr(),
+ core::ptr::from_ref(&args).cast::<c_void>(),
+ );
+ }
+}
+
+/// Asserts that a boolean expression is `true` at runtime.
+///
+/// Public but hidden since it should only be used from generated tests.
+///
+/// Unlike the one in `core`, this one does not panic; instead, it is mapped to the KUnit
+/// facilities. See [`assert!`] for more details.
+#[doc(hidden)]
+#[macro_export]
+macro_rules! kunit_assert {
+ ($name:literal, $file:literal, $diff:expr, $condition:expr $(,)?) => {
+ 'out: {
+ // Do nothing if the condition is `true`.
+ if $condition {
+ break 'out;
+ }
+
+ static FILE: &'static $crate::str::CStr = $crate::c_str!($file);
+ static LINE: i32 = ::core::line!() as i32 - $diff;
+ static CONDITION: &'static $crate::str::CStr = $crate::c_str!(stringify!($condition));
+
+ // SAFETY: FFI call without safety requirements.
+ let kunit_test = unsafe { $crate::bindings::kunit_get_current_test() };
+ if kunit_test.is_null() {
+ // The assertion failed but this task is not running a KUnit test, so we cannot call
+ // KUnit, but at least print an error to the kernel log. This may happen if this
+ // macro is called from a spawned thread in a test (see
+ // `scripts/rustdoc_test_gen.rs`) or if some non-test code calls this macro by
+ // mistake (it is hidden to prevent that).
+ //
+ // This mimics KUnit's failed assertion format.
+ $crate::kunit::err($crate::prelude::fmt!(
+ " # {}: ASSERTION FAILED at {FILE}:{LINE}\n",
+ $name
+ ));
+ $crate::kunit::err($crate::prelude::fmt!(
+ " Expected {CONDITION} to be true, but is false\n"
+ ));
+ $crate::kunit::err($crate::prelude::fmt!(
+ " Failure not reported to KUnit since this is a non-KUnit task\n"
+ ));
+ break 'out;
+ }
+
+ #[repr(transparent)]
+ struct Location($crate::bindings::kunit_loc);
+
+ #[repr(transparent)]
+ struct UnaryAssert($crate::bindings::kunit_unary_assert);
+
+ // SAFETY: There is only a static instance and in that one the pointer field points to
+ // an immutable C string.
+ unsafe impl Sync for Location {}
+
+ // SAFETY: There is only a static instance and in that one the pointer field points to
+ // an immutable C string.
+ unsafe impl Sync for UnaryAssert {}
+
+ static LOCATION: Location = Location($crate::bindings::kunit_loc {
+ file: $crate::str::as_char_ptr_in_const_context(FILE),
+ line: LINE,
+ });
+ static ASSERTION: UnaryAssert = UnaryAssert($crate::bindings::kunit_unary_assert {
+ assert: $crate::bindings::kunit_assert {},
+ condition: $crate::str::as_char_ptr_in_const_context(CONDITION),
+ expected_true: true,
+ });
+
+ // SAFETY:
+ // - FFI call.
+ // - The `kunit_test` pointer is valid because we got it from
+ // `kunit_get_current_test()` and it was not null. This means we are in a KUnit
+ // test, and that the pointer can be passed to KUnit functions and assertions.
+ // - The string pointers (`file` and `condition` above) point to null-terminated
+ // strings since they are `CStr`s.
+ // - The function pointer (`format`) points to the proper function.
+ // - The pointers passed will remain valid since they point to `static`s.
+ // - The format string is allowed to be null.
+ // - There are, however, problems with this: first of all, this will end up stopping
+ // the thread, without running destructors. While that is problematic in itself,
+ // it is considered UB to have what is effectively a forced foreign unwind
+ // with `extern "C"` ABI. One could observe the stack that is now gone from
+ // another thread. We should avoid pinning stack variables to prevent library UB,
+ // too. For the moment, given that test failures are reported immediately before the
+ // next test runs, that test failures should be fixed and that KUnit is explicitly
+ // documented as not suitable for production environments, we feel it is reasonable.
+ unsafe {
+ $crate::bindings::__kunit_do_failed_assertion(
+ kunit_test,
+ ::core::ptr::addr_of!(LOCATION.0),
+ $crate::bindings::kunit_assert_type_KUNIT_ASSERTION,
+ ::core::ptr::addr_of!(ASSERTION.0.assert),
+ Some($crate::bindings::kunit_unary_assert_format),
+ ::core::ptr::null(),
+ );
+ }
+
+ // SAFETY: FFI call; the `test` pointer is valid because this hidden macro should only
+ // be called by the generated documentation tests which forward the test pointer given
+ // by KUnit.
+ unsafe { + $crate::bindings::__kunit_abort(kunit_test); + } + } + }; +} + +/// Asserts that two expressions are equal to each other (using [`PartialEq`]). +/// +/// Public but hidden since it should only be used from generated tests. +/// +/// Unlike the one in `core`, this one does not panic; instead, it is mapped to the KUnit +/// facilities. See [`assert!`] for more details. +#[doc(hidden)] +#[macro_export] +macro_rules! kunit_assert_eq { + ($name:literal, $file:literal, $diff:expr, $left:expr, $right:expr $(,)?) => {{ + // For the moment, we just forward to the expression assert because, for binary asserts, + // KUnit supports only a few types (e.g. integers). + $crate::kunit_assert!($name, $file, $diff, $left == $right); + }}; +} + +trait TestResult { + fn is_test_result_ok(&self) -> bool; +} + +impl TestResult for () { + fn is_test_result_ok(&self) -> bool { + true + } +} + +impl<T, E> TestResult for Result<T, E> { + fn is_test_result_ok(&self) -> bool { + self.is_ok() + } +} + +/// Returns whether a test result is to be considered OK. +/// +/// This will be `assert!`ed from the generated tests. +#[doc(hidden)] +#[expect(private_bounds)] +pub fn is_test_result_ok(t: impl TestResult) -> bool { + t.is_test_result_ok() +} + +/// Represents an individual test case. +/// +/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of valid test cases. +/// Use [`kunit_case_null`] to generate such a delimiter. +#[doc(hidden)] +pub const fn kunit_case( + name: &'static kernel::str::CStr, + run_case: unsafe extern "C" fn(*mut kernel::bindings::kunit), +) -> kernel::bindings::kunit_case { + kernel::bindings::kunit_case { + run_case: Some(run_case), + name: kernel::str::as_char_ptr_in_const_context(name), + attr: kernel::bindings::kunit_attributes { + speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, + }, + generate_params: None, + status: kernel::bindings::kunit_status_KUNIT_SUCCESS, + module_name: core::ptr::null_mut(), + log: core::ptr::null_mut(), + param_init: None, + param_exit: None, + } +} + +/// Represents the NULL test case delimiter. +/// +/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of test cases. This +/// function returns such a delimiter. +#[doc(hidden)] +pub const fn kunit_case_null() -> kernel::bindings::kunit_case { + kernel::bindings::kunit_case { + run_case: None, + name: core::ptr::null_mut(), + generate_params: None, + attr: kernel::bindings::kunit_attributes { + speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, + }, + status: kernel::bindings::kunit_status_KUNIT_SUCCESS, + module_name: core::ptr::null_mut(), + log: core::ptr::null_mut(), + param_init: None, + param_exit: None, + } +} + +/// Registers a KUnit test suite. +/// +/// # Safety +/// +/// `test_cases` must be a NULL terminated array of valid test cases, +/// whose lifetime is at least that of the test suite (i.e., static). +/// +/// # Examples +/// +/// ```ignore +/// extern "C" fn test_fn(_test: *mut kernel::bindings::kunit) { +/// let actual = 1 + 1; +/// let expected = 2; +/// assert_eq!(actual, expected); +/// } +/// +/// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [ +/// kernel::kunit::kunit_case(kernel::c_str!("name"), test_fn), +/// kernel::kunit::kunit_case_null(), +/// ]; +/// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES); +/// ``` +#[doc(hidden)] +#[macro_export] +macro_rules! 
kunit_unsafe_test_suite {
+    ($name:ident, $test_cases:ident) => {
+        const _: () = {
+            const KUNIT_TEST_SUITE_NAME: [::kernel::ffi::c_char; 256] = {
+                let name_u8 = ::core::stringify!($name).as_bytes();
+                let mut ret = [0; 256];
+
+                if name_u8.len() > 255 {
+                    panic!(concat!(
+                        "The test suite name `",
+                        ::core::stringify!($name),
+                        "` exceeds the maximum length of 255 bytes."
+                    ));
+                }
+
+                let mut i = 0;
+                while i < name_u8.len() {
+                    ret[i] = name_u8[i] as ::kernel::ffi::c_char;
+                    i += 1;
+                }
+
+                ret
+            };
+
+            static mut KUNIT_TEST_SUITE: ::kernel::bindings::kunit_suite =
+                ::kernel::bindings::kunit_suite {
+                    name: KUNIT_TEST_SUITE_NAME,
+                    #[allow(unused_unsafe)]
+                    // SAFETY: `$test_cases` is passed in by the user, and
+                    // (as documented) must be valid for the lifetime of
+                    // the suite (i.e., static).
+                    test_cases: unsafe {
+                        ::core::ptr::addr_of_mut!($test_cases)
+                            .cast::<::kernel::bindings::kunit_case>()
+                    },
+                    suite_init: None,
+                    suite_exit: None,
+                    init: None,
+                    exit: None,
+                    attr: ::kernel::bindings::kunit_attributes {
+                        speed: ::kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL,
+                    },
+                    status_comment: [0; 256usize],
+                    debugfs: ::core::ptr::null_mut(),
+                    log: ::core::ptr::null_mut(),
+                    suite_init_err: 0,
+                    is_init: false,
+                };
+
+            #[used(compiler)]
+            #[allow(unused_unsafe)]
+            #[cfg_attr(not(target_os = "macos"), link_section = ".kunit_test_suites")]
+            static mut KUNIT_TEST_SUITE_ENTRY: *const ::kernel::bindings::kunit_suite =
+                // SAFETY: `KUNIT_TEST_SUITE` is static.
+                unsafe { ::core::ptr::addr_of_mut!(KUNIT_TEST_SUITE) };
+        };
+    };
+}
+
+/// Returns whether we are currently running a KUnit test.
+///
+/// In some cases, you need to call test-only code from outside the test case, for example, to
+/// create a function mock. This function allows changing behavior depending on whether we are
+/// currently running a KUnit test or not.
+///
+/// # Examples
+///
+/// This example shows how a function can be mocked to return a well-known value while testing:
+///
+/// ```
+/// # use kernel::kunit::in_kunit_test;
+/// fn fn_mock_example(n: i32) -> i32 {
+///     if in_kunit_test() {
+///         return 100;
+///     }
+///
+///     n + 1
+/// }
+///
+/// let mock_res = fn_mock_example(5);
+/// assert_eq!(mock_res, 100);
+/// ```
+pub fn in_kunit_test() -> bool {
+    // SAFETY: `kunit_get_current_test()` is always safe to call (it has fallbacks for
+    // when KUnit is not enabled).
+    !unsafe { bindings::kunit_get_current_test() }.is_null()
+}
+
+#[kunit_tests(rust_kernel_kunit)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn rust_test_kunit_example_test() {
+        assert_eq!(1 + 1, 2);
+    }
+
+    #[test]
+    fn rust_test_kunit_in_kunit_test() {
+        assert!(in_kunit_test());
+    }
+
+    #[test]
+    #[cfg(not(all()))]
+    fn rust_test_kunit_always_disabled_test() {
+        // This test should never run because of the `cfg`.
+        assert!(false);
+    }
+}
diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index 53040fa9e897..f812cf120042 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -6,39 +6,158 @@
 //! usage by Rust code in the kernel and is shared by all of them.
 //!
 //! In other words, all the rest of the Rust code in the kernel (e.g. kernel
-//! modules written in Rust) depends on [`core`], [`alloc`] and this crate.
+//! modules written in Rust) depends on [`core`] and this crate.
 //!
 //! If you need a kernel C API that is not ported or wrapped yet here, then
 //! do so first instead of bypassing this crate.
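+//!
+//! As a rough sketch of how the rest of the kernel's Rust code consumes this crate, a minimal
+//! module might look as follows. This is illustrative only and not compiled here; the exact
+//! `module!` field names may differ between kernel versions:
+//!
+//! ```ignore
+//! use kernel::prelude::*;
+//!
+//! module! {
+//!     type: RustMinimal,
+//!     name: "rust_minimal",
+//!     authors: ["Rust for Linux Contributors"],
+//!     description: "Rust minimal sample",
+//!     license: "GPL",
+//! }
+//!
+//! struct RustMinimal;
+//!
+//! impl kernel::Module for RustMinimal {
+//!     fn init(_module: &'static ThisModule) -> Result<Self> {
+//!         pr_info!("Rust minimal sample (init)\n");
+//!         Ok(RustMinimal)
+//!     }
+//! }
+//! ```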
#![no_std] -#![feature(allocator_api)] -#![feature(core_ffi_c)] +// +// Please see https://github.com/Rust-for-Linux/linux/issues/2 for details on +// the unstable features in use. +// +// Stable since Rust 1.79.0. +#![feature(generic_nonzero)] +#![feature(inline_const)] +#![feature(pointer_is_aligned)] +// +// Stable since Rust 1.80.0. +#![feature(slice_flatten)] +// +// Stable since Rust 1.81.0. +#![feature(lint_reasons)] +// +// Stable since Rust 1.82.0. +#![feature(raw_ref_op)] +// +// Stable since Rust 1.83.0. +#![feature(const_maybe_uninit_as_mut_ptr)] +#![feature(const_mut_refs)] +#![feature(const_option)] +#![feature(const_ptr_write)] +#![feature(const_refs_to_cell)] +// +// Expected to become stable. +#![feature(arbitrary_self_types)] +// +// To be determined. +#![feature(used_with_arg)] +// +// `feature(derive_coerce_pointee)` is expected to become stable. Before Rust +// 1.84.0, it did not exist, so enable the predecessor features. +#![cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, feature(derive_coerce_pointee))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(coerce_unsized))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(dispatch_from_dyn))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(unsize))] +// +// `feature(file_with_nul)` is expected to become stable. Before Rust 1.89.0, it did not exist, so +// enable it conditionally. +#![cfg_attr(CONFIG_RUSTC_HAS_FILE_WITH_NUL, feature(file_with_nul))] // Ensure conditional compilation based on the kernel configuration works; // otherwise we may silently break things like initcall handling. #[cfg(not(CONFIG_RUST))] compile_error!("Missing kernel configuration for conditional compilation"); -#[cfg(not(test))] -#[cfg(not(testlib))] -mod allocator; -mod build_assert; +// Allow proc-macros to refer to `::kernel` inside the `kernel` crate (this crate). 
+extern crate self as kernel; + +pub use ffi; + +pub mod acpi; +pub mod alloc; +#[cfg(CONFIG_AUXILIARY_BUS)] +pub mod auxiliary; +pub mod bitmap; +pub mod bits; +#[cfg(CONFIG_BLOCK)] +pub mod block; +pub mod bug; +#[doc(hidden)] +pub mod build_assert; +pub mod clk; +#[cfg(CONFIG_CONFIGFS_FS)] +pub mod configfs; +pub mod cpu; +#[cfg(CONFIG_CPU_FREQ)] +pub mod cpufreq; +pub mod cpumask; +pub mod cred; +pub mod debugfs; +pub mod device; +pub mod device_id; +pub mod devres; +pub mod dma; +pub mod driver; +#[cfg(CONFIG_DRM = "y")] +pub mod drm; pub mod error; +pub mod faux; +#[cfg(CONFIG_RUST_FW_LOADER_ABSTRACTIONS)] +pub mod firmware; +pub mod fmt; +pub mod fs; +#[cfg(CONFIG_I2C = "y")] +pub mod i2c; +pub mod id_pool; +pub mod init; +pub mod io; +pub mod ioctl; +pub mod iov; +pub mod irq; +pub mod jump_label; +#[cfg(CONFIG_KUNIT)] +pub mod kunit; +pub mod list; +pub mod maple_tree; +pub mod miscdevice; +pub mod mm; +pub mod module_param; +#[cfg(CONFIG_NET)] +pub mod net; +pub mod num; +pub mod of; +#[cfg(CONFIG_PM_OPP)] +pub mod opp; +pub mod page; +#[cfg(CONFIG_PCI)] +pub mod pci; +pub mod pid_namespace; +pub mod platform; pub mod prelude; pub mod print; +pub mod processor; +pub mod ptr; +#[cfg(CONFIG_RUST_PWM_ABSTRACTIONS)] +pub mod pwm; +pub mod rbtree; +pub mod regulator; +pub mod revocable; +pub mod scatterlist; +pub mod security; +pub mod seq_file; +pub mod sizes; +pub mod slice; mod static_assert; #[doc(hidden)] pub mod std_vendor; pub mod str; +pub mod sync; +pub mod task; +pub mod time; +pub mod tracepoint; +pub mod transmute; pub mod types; +pub mod uaccess; +#[cfg(CONFIG_USB = "y")] +pub mod usb; +pub mod workqueue; +pub mod xarray; #[doc(hidden)] pub use bindings; pub use macros; - -#[doc(hidden)] -pub use build_error::build_error; +pub use uapi; /// Prefix to appear before log messages printed from within the `kernel` crate. const __LOG_PREFIX: &[u8] = b"rust_kernel\0"; @@ -46,7 +165,7 @@ const __LOG_PREFIX: &[u8] = b"rust_kernel\0"; /// The top level entrypoint to implementing a kernel module. /// /// For any teardown or cleanup operations, your type may implement [`Drop`]. -pub trait Module: Sized + Sync { +pub trait Module: Sized + Sync + Send { /// Called at module initialization time. /// /// Use this method to perform whatever setup or registration your module @@ -56,9 +175,38 @@ pub trait Module: Sized + Sync { fn init(module: &'static ThisModule) -> error::Result<Self>; } +/// A module that is pinned and initialised in-place. +pub trait InPlaceModule: Sync + Send { + /// Creates an initialiser for the module. + /// + /// It is called when the module is loaded. + fn init(module: &'static ThisModule) -> impl pin_init::PinInit<Self, error::Error>; +} + +impl<T: Module> InPlaceModule for T { + fn init(module: &'static ThisModule) -> impl pin_init::PinInit<Self, error::Error> { + let initer = move |slot: *mut Self| { + let m = <Self as Module>::init(module)?; + + // SAFETY: `slot` is valid for write per the contract with `pin_init_from_closure`. + unsafe { slot.write(m) }; + Ok(()) + }; + + // SAFETY: On success, `initer` always fully initialises an instance of `Self`. + unsafe { pin_init::pin_init_from_closure(initer) } + } +} + +/// Metadata attached to a [`Module`] or [`InPlaceModule`]. +pub trait ModuleMetadata { + /// The name of the module as specified in the `module!` macro. + const NAME: &'static crate::str::CStr; +} + /// Equivalent to `THIS_MODULE` in the C API. 
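+///
+/// This is the type of the `module` argument that is passed to [`Module::init`] and
+/// [`InPlaceModule::init`].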
/// -/// C header: `include/linux/export.h` +/// C header: [`include/linux/init.h`](srctree/include/linux/init.h) pub struct ThisModule(*mut bindings::module); // SAFETY: `THIS_MODULE` may be used from all threads within a module. @@ -73,15 +221,148 @@ impl ThisModule { pub const unsafe fn from_ptr(ptr: *mut bindings::module) -> ThisModule { ThisModule(ptr) } + + /// Access the raw pointer for this module. + /// + /// It is up to the user to use it correctly. + pub const fn as_ptr(&self) -> *mut bindings::module { + self.0 + } } -#[cfg(not(any(testlib, test)))] +#[cfg(not(testlib))] #[panic_handler] fn panic(info: &core::panic::PanicInfo<'_>) -> ! { pr_emerg!("{}\n", info); // SAFETY: FFI call. unsafe { bindings::BUG() }; - // Bindgen currently does not recognize `__noreturn` so `BUG` returns `()` - // instead of `!`. See <https://github.com/rust-lang/rust-bindgen/issues/2094>. - loop {} +} + +/// Produces a pointer to an object from a pointer to one of its fields. +/// +/// If you encounter a type mismatch due to the [`Opaque`] type, then use [`Opaque::cast_into`] or +/// [`Opaque::cast_from`] to resolve the mismatch. +/// +/// [`Opaque`]: crate::types::Opaque +/// [`Opaque::cast_into`]: crate::types::Opaque::cast_into +/// [`Opaque::cast_from`]: crate::types::Opaque::cast_from +/// +/// # Safety +/// +/// The pointer passed to this macro, and the pointer returned by this macro, must both be in +/// bounds of the same allocation. +/// +/// # Examples +/// +/// ``` +/// # use kernel::container_of; +/// struct Test { +/// a: u64, +/// b: u32, +/// } +/// +/// let test = Test { a: 10, b: 20 }; +/// let b_ptr: *const _ = &test.b; +/// // SAFETY: The pointer points at the `b` field of a `Test`, so the resulting pointer will be +/// // in-bounds of the same allocation as `b_ptr`. +/// let test_alias = unsafe { container_of!(b_ptr, Test, b) }; +/// assert!(core::ptr::eq(&test, test_alias)); +/// ``` +#[macro_export] +macro_rules! container_of { + ($field_ptr:expr, $Container:ty, $($fields:tt)*) => {{ + let offset: usize = ::core::mem::offset_of!($Container, $($fields)*); + let field_ptr = $field_ptr; + let container_ptr = field_ptr.byte_sub(offset).cast::<$Container>(); + $crate::assert_same_type(field_ptr, (&raw const (*container_ptr).$($fields)*).cast_mut()); + container_ptr + }} +} + +/// Helper for [`container_of!`]. +#[doc(hidden)] +pub fn assert_same_type<T>(_: T, _: T) {} + +/// Helper for `.rs.S` files. +#[doc(hidden)] +#[macro_export] +macro_rules! concat_literals { + ($( $asm:literal )* ) => { + ::core::concat!($($asm),*) + }; +} + +/// Wrapper around `asm!` configured for use in the kernel. +/// +/// Uses a semicolon to avoid parsing ambiguities, even though this does not match native `asm!` +/// syntax. +// For x86, `asm!` uses intel syntax by default, but we want to use at&t syntax in the kernel. +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[macro_export] +macro_rules! asm { + ($($asm:expr),* ; $($rest:tt)*) => { + ::core::arch::asm!( $($asm)*, options(att_syntax), $($rest)* ) + }; +} + +/// Wrapper around `asm!` configured for use in the kernel. +/// +/// Uses a semicolon to avoid parsing ambiguities, even though this does not match native `asm!` +/// syntax. +// For non-x86 arches we just pass through to `asm!`. +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] +#[macro_export] +macro_rules! asm { + ($($asm:expr),* ; $($rest:tt)*) => { + ::core::arch::asm!( $($asm)*, $($rest)* ) + }; +} + +/// Gets the C string file name of a [`Location`]. 
+/// +/// If `Location::file_as_c_str()` is not available, returns a string that warns about it. +/// +/// [`Location`]: core::panic::Location +/// +/// # Examples +/// +/// ``` +/// # use kernel::file_from_location; +/// +/// #[track_caller] +/// fn foo() { +/// let caller = core::panic::Location::caller(); +/// +/// // Output: +/// // - A path like "rust/kernel/example.rs" if `file_as_c_str()` is available. +/// // - "<Location::file_as_c_str() not supported>" otherwise. +/// let caller_file = file_from_location(caller); +/// +/// // Prints out the message with caller's file name. +/// pr_info!("foo() called in file {caller_file:?}\n"); +/// +/// # if cfg!(CONFIG_RUSTC_HAS_FILE_WITH_NUL) { +/// # assert_eq!(Ok(caller.file()), caller_file.to_str()); +/// # } +/// } +/// +/// # foo(); +/// ``` +#[inline] +pub fn file_from_location<'a>(loc: &'a core::panic::Location<'a>) -> &'a core::ffi::CStr { + #[cfg(CONFIG_RUSTC_HAS_FILE_AS_C_STR)] + { + loc.file_as_c_str() + } + + #[cfg(all(CONFIG_RUSTC_HAS_FILE_WITH_NUL, not(CONFIG_RUSTC_HAS_FILE_AS_C_STR)))] + { + loc.file_with_nul() + } + + #[cfg(not(CONFIG_RUSTC_HAS_FILE_WITH_NUL))] + { + let _ = loc; + c"<Location::file_as_c_str() not supported>" + } } diff --git a/rust/kernel/list.rs b/rust/kernel/list.rs new file mode 100644 index 000000000000..8349ff32fc37 --- /dev/null +++ b/rust/kernel/list.rs @@ -0,0 +1,1206 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! A linked list implementation. + +use crate::sync::ArcBorrow; +use crate::types::Opaque; +use core::iter::{DoubleEndedIterator, FusedIterator}; +use core::marker::PhantomData; +use core::ptr; +use pin_init::PinInit; + +mod impl_list_item_mod; +pub use self::impl_list_item_mod::{ + impl_has_list_links, impl_has_list_links_self_ptr, impl_list_item, HasListLinks, HasSelfPtr, +}; + +mod arc; +pub use self::arc::{impl_list_arc_safe, AtomicTracker, ListArc, ListArcSafe, TryNewListArc}; + +mod arc_field; +pub use self::arc_field::{define_list_arc_field_getter, ListArcField}; + +/// A linked list. +/// +/// All elements in this linked list will be [`ListArc`] references to the value. Since a value can +/// only have one `ListArc` (for each pair of prev/next pointers), this ensures that the same +/// prev/next pointers are not used for several linked lists. +/// +/// # Invariants +/// +/// * If the list is empty, then `first` is null. Otherwise, `first` points at the `ListLinks` +/// field of the first element in the list. +/// * All prev/next pointers in `ListLinks` fields of items in the list are valid and form a cycle. +/// * For every item in the list, the list owns the associated [`ListArc`] reference and has +/// exclusive access to the `ListLinks` field. +/// +/// # Examples +/// +/// Use [`ListLinks`] as the type of the intrusive field. +/// +/// ``` +/// use kernel::list::*; +/// +/// #[pin_data] +/// struct BasicItem { +/// value: i32, +/// #[pin] +/// links: ListLinks, +/// } +/// +/// impl BasicItem { +/// fn new(value: i32) -> Result<ListArc<Self>> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinks::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// impl_list_arc_safe! { +/// impl ListArcSafe<0> for BasicItem { untracked; } +/// } +/// impl_list_item! { +/// impl ListItem<0> for BasicItem { using ListLinks { self.links }; } +/// } +/// +/// // Create a new empty list. +/// let mut list = List::new(); +/// { +/// assert!(list.is_empty()); +/// } +/// +/// // Insert 3 elements using `push_back()`. 
+/// list.push_back(BasicItem::new(15)?); +/// list.push_back(BasicItem::new(10)?); +/// list.push_back(BasicItem::new(30)?); +/// +/// // Iterate over the list to verify the nodes were inserted correctly. +/// // [15, 10, 30] +/// { +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 10); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 30); +/// assert!(iter.next().is_none()); +/// +/// // Verify the length of the list. +/// assert_eq!(list.iter().count(), 3); +/// } +/// +/// // Pop the items from the list using `pop_back()` and verify the content. +/// { +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 30); +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 10); +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 15); +/// } +/// +/// // Insert 3 elements using `push_front()`. +/// list.push_front(BasicItem::new(15)?); +/// list.push_front(BasicItem::new(10)?); +/// list.push_front(BasicItem::new(30)?); +/// +/// // Iterate over the list to verify the nodes were inserted correctly. +/// // [30, 10, 15] +/// { +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 30); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 10); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert!(iter.next().is_none()); +/// +/// // Verify the length of the list. +/// assert_eq!(list.iter().count(), 3); +/// } +/// +/// // Pop the items from the list using `pop_front()` and verify the content. +/// { +/// assert_eq!(list.pop_front().ok_or(EINVAL)?.value, 30); +/// assert_eq!(list.pop_front().ok_or(EINVAL)?.value, 10); +/// } +/// +/// // Push `list2` to `list` through `push_all_back()`. +/// // list: [15] +/// // list2: [25, 35] +/// { +/// let mut list2 = List::new(); +/// list2.push_back(BasicItem::new(25)?); +/// list2.push_back(BasicItem::new(35)?); +/// +/// list.push_all_back(&mut list2); +/// +/// // list: [15, 25, 35] +/// // list2: [] +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 25); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 35); +/// assert!(iter.next().is_none()); +/// assert!(list2.is_empty()); +/// } +/// # Result::<(), Error>::Ok(()) +/// ``` +/// +/// Use [`ListLinksSelfPtr`] as the type of the intrusive field. This allows a list of trait object +/// type. +/// +/// ``` +/// use kernel::list::*; +/// +/// trait Foo { +/// fn foo(&self) -> (&'static str, i32); +/// } +/// +/// #[pin_data] +/// struct DTWrap<T: ?Sized> { +/// #[pin] +/// links: ListLinksSelfPtr<DTWrap<dyn Foo>>, +/// value: T, +/// } +/// +/// impl<T> DTWrap<T> { +/// fn new(value: T) -> Result<ListArc<Self>> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinksSelfPtr::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// impl_list_arc_safe! { +/// impl{T: ?Sized} ListArcSafe<0> for DTWrap<T> { untracked; } +/// } +/// impl_list_item! { +/// impl ListItem<0> for DTWrap<dyn Foo> { using ListLinksSelfPtr { self.links }; } +/// } +/// +/// // Create a new empty list. +/// let mut list = List::<DTWrap<dyn Foo>>::new(); +/// { +/// assert!(list.is_empty()); +/// } +/// +/// struct A(i32); +/// // `A` returns the inner value for `foo`. +/// impl Foo for A { fn foo(&self) -> (&'static str, i32) { ("a", self.0) } } +/// +/// struct B; +/// // `B` always returns 42. 
+/// impl Foo for B { fn foo(&self) -> (&'static str, i32) { ("b", 42) } }
+///
+/// // Insert 3 elements using `push_back()`.
+/// list.push_back(DTWrap::new(A(15))?);
+/// list.push_back(DTWrap::new(A(32))?);
+/// list.push_back(DTWrap::new(B)?);
+///
+/// // Iterate over the list to verify the nodes were inserted correctly.
+/// // [A(15), A(32), B]
+/// {
+///     let mut iter = list.iter();
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("a", 15));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("a", 32));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("b", 42));
+///     assert!(iter.next().is_none());
+///
+///     // Verify the length of the list.
+///     assert_eq!(list.iter().count(), 3);
+/// }
+///
+/// // Pop the items from the list using `pop_back()` and verify the content.
+/// {
+///     assert_eq!(list.pop_back().ok_or(EINVAL)?.value.foo(), ("b", 42));
+///     assert_eq!(list.pop_back().ok_or(EINVAL)?.value.foo(), ("a", 32));
+///     assert_eq!(list.pop_back().ok_or(EINVAL)?.value.foo(), ("a", 15));
+/// }
+///
+/// // Insert 3 elements using `push_front()`.
+/// list.push_front(DTWrap::new(A(15))?);
+/// list.push_front(DTWrap::new(A(32))?);
+/// list.push_front(DTWrap::new(B)?);
+///
+/// // Iterate over the list to verify the nodes were inserted correctly.
+/// // [B, A(32), A(15)]
+/// {
+///     let mut iter = list.iter();
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("b", 42));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("a", 32));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("a", 15));
+///     assert!(iter.next().is_none());
+///
+///     // Verify the length of the list.
+///     assert_eq!(list.iter().count(), 3);
+/// }
+///
+/// // Pop the items from the list using `pop_back()` and verify the content.
+/// {
+///     assert_eq!(list.pop_back().ok_or(EINVAL)?.value.foo(), ("a", 15));
+///     assert_eq!(list.pop_back().ok_or(EINVAL)?.value.foo(), ("a", 32));
+/// }
+///
+/// // Push `list2` to `list` through `push_all_back()`.
+/// // list: [B]
+/// // list2: [B, A(25)]
+/// {
+///     let mut list2 = List::<DTWrap<dyn Foo>>::new();
+///     list2.push_back(DTWrap::new(B)?);
+///     list2.push_back(DTWrap::new(A(25))?);
+///
+///     list.push_all_back(&mut list2);
+///
+///     // list: [B, B, A(25)]
+///     // list2: []
+///     let mut iter = list.iter();
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("b", 42));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("b", 42));
+///     assert_eq!(iter.next().ok_or(EINVAL)?.value.foo(), ("a", 25));
+///     assert!(iter.next().is_none());
+///     assert!(list2.is_empty());
+/// }
+/// # Result::<(), Error>::Ok(())
+/// ```
+pub struct List<T: ?Sized + ListItem<ID>, const ID: u64 = 0> {
+    first: *mut ListLinksFields,
+    _ty: PhantomData<ListArc<T, ID>>,
+}
+
+// SAFETY: This is a container of `ListArc<T, ID>`, and access to the container allows the same
+// type of access to the `ListArc<T, ID>` elements.
+unsafe impl<T, const ID: u64> Send for List<T, ID>
+where
+    ListArc<T, ID>: Send,
+    T: ?Sized + ListItem<ID>,
+{
+}
+// SAFETY: This is a container of `ListArc<T, ID>`, and access to the container allows the same
+// type of access to the `ListArc<T, ID>` elements.
+unsafe impl<T, const ID: u64> Sync for List<T, ID>
+where
+    ListArc<T, ID>: Sync,
+    T: ?Sized + ListItem<ID>,
+{
+}
+
+/// Implemented by types where a [`ListArc<Self>`] can be inserted into a [`List`].
+///
+/// # Safety
+///
+/// Implementers must ensure that they provide the guarantees documented on methods provided by
+/// this trait.
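+///
+/// Rather than implementing this trait by hand, it is usually generated via the
+/// [`impl_list_item!`] macro (together with [`impl_list_arc_safe!`]), as shown in the [`List`]
+/// examples above.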
+///
+/// [`ListArc<Self>`]: ListArc
+pub unsafe trait ListItem<const ID: u64 = 0>: ListArcSafe<ID> {
+    /// Views the [`ListLinks`] for this value.
+    ///
+    /// # Guarantees
+    ///
+    /// If there is a previous call to `prepare_to_insert` and there is no call to `post_remove`
+    /// since the most recent such call, then this returns the same pointer as the one returned by
+    /// the most recent call to `prepare_to_insert`.
+    ///
+    /// Otherwise, the returned pointer points at a read-only [`ListLinks`] with two null pointers.
+    ///
+    /// # Safety
+    ///
+    /// The provided pointer must point at a valid value. (It need not be in an `Arc`.)
+    unsafe fn view_links(me: *const Self) -> *mut ListLinks<ID>;
+
+    /// View the full value given its [`ListLinks`] field.
+    ///
+    /// Can only be used when the value is in a list.
+    ///
+    /// # Guarantees
+    ///
+    /// * Returns the same pointer as the one passed to the most recent call to `prepare_to_insert`.
+    /// * The returned pointer is valid until the next call to `post_remove`.
+    ///
+    /// # Safety
+    ///
+    /// * The provided pointer must originate from the most recent call to `prepare_to_insert`, or
+    ///   from a call to `view_links` that happened after the most recent call to
+    ///   `prepare_to_insert`.
+    /// * Since the most recent call to `prepare_to_insert`, the `post_remove` method must not have
+    ///   been called.
+    unsafe fn view_value(me: *mut ListLinks<ID>) -> *const Self;
+
+    /// This is called when an item is inserted into a [`List`].
+    ///
+    /// # Guarantees
+    ///
+    /// The caller is granted exclusive access to the returned [`ListLinks`] until `post_remove` is
+    /// called.
+    ///
+    /// # Safety
+    ///
+    /// * The provided pointer must point at a valid value in an [`Arc`].
+    /// * Calls to `prepare_to_insert` and `post_remove` on the same value must alternate.
+    /// * The caller must own the [`ListArc`] for this value.
+    /// * The caller must not give up ownership of the [`ListArc`] unless `post_remove` has been
+    ///   called after this call to `prepare_to_insert`.
+    ///
+    /// [`Arc`]: crate::sync::Arc
+    unsafe fn prepare_to_insert(me: *const Self) -> *mut ListLinks<ID>;
+
+    /// This undoes a previous call to `prepare_to_insert`.
+    ///
+    /// # Guarantees
+    ///
+    /// The returned pointer is the pointer that was originally passed to `prepare_to_insert`.
+    ///
+    /// # Safety
+    ///
+    /// The provided pointer must be the pointer returned by the most recent call to
+    /// `prepare_to_insert`.
+    unsafe fn post_remove(me: *mut ListLinks<ID>) -> *const Self;
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+struct ListLinksFields {
+    next: *mut ListLinksFields,
+    prev: *mut ListLinksFields,
+}
+
+/// The prev/next pointers for an item in a linked list.
+///
+/// # Invariants
+///
+/// The fields are null if and only if this item is not in a list.
+#[repr(transparent)]
+pub struct ListLinks<const ID: u64 = 0> {
+    // This type is `!Unpin` for aliasing reasons as the pointers are part of an intrusive linked
+    // list.
+    inner: Opaque<ListLinksFields>,
+}
+
+// SAFETY: The only way to access/modify the pointers inside of `ListLinks<ID>` is via holding the
+// associated `ListArc<T, ID>`. Since that type correctly implements `Send`, it is impossible to
+// move an instance of this type to a different thread if the pointees are `!Send`.
+unsafe impl<const ID: u64> Send for ListLinks<ID> {}
+// SAFETY: The type is opaque so immutable references to a ListLinks are useless. Therefore, it's
+// okay to have immutable access to a ListLinks from several threads at once.
+unsafe impl<const ID: u64> Sync for ListLinks<ID> {}
+
+impl<const ID: u64> ListLinks<ID> {
+    /// Creates a new initializer for this type.
+    pub fn new() -> impl PinInit<Self> {
+        // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will
+        // not be constructed in an `Arc` that already has a `ListArc`.
+        ListLinks {
+            inner: Opaque::new(ListLinksFields {
+                prev: ptr::null_mut(),
+                next: ptr::null_mut(),
+            }),
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `me` must be dereferenceable.
+    #[inline]
+    unsafe fn fields(me: *mut Self) -> *mut ListLinksFields {
+        // SAFETY: The caller promises that the pointer is valid.
+        unsafe { Opaque::cast_into(ptr::addr_of!((*me).inner)) }
+    }
+
+    /// # Safety
+    ///
+    /// `me` must be dereferenceable.
+    #[inline]
+    unsafe fn from_fields(me: *mut ListLinksFields) -> *mut Self {
+        me.cast()
+    }
+}
+
+/// Similar to [`ListLinks`], but also contains a pointer to the full value.
+///
+/// This type can be used instead of [`ListLinks`] to support lists with trait objects.
+#[repr(C)]
+pub struct ListLinksSelfPtr<T: ?Sized, const ID: u64 = 0> {
+    /// The `ListLinks` field inside this value.
+    ///
+    /// This is public so that it can be used with `impl_has_list_links!`.
+    pub inner: ListLinks<ID>,
+    // UnsafeCell is not enough here because we use `Opaque::uninit` as a dummy value, and
+    // `ptr::null()` doesn't work for `T: ?Sized`.
+    self_ptr: Opaque<*const T>,
+}
+
+// SAFETY: The fields of a ListLinksSelfPtr can be moved across thread boundaries.
+unsafe impl<T: ?Sized + Send, const ID: u64> Send for ListLinksSelfPtr<T, ID> {}
+// SAFETY: The type is opaque so immutable references to a ListLinksSelfPtr are useless. Therefore,
+// it's okay to have immutable access to a ListLinks from several threads at once.
+//
+// Note that `inner` being a public field does not prevent this type from being opaque, since
+// `inner` is an opaque type.
+unsafe impl<T: ?Sized + Sync, const ID: u64> Sync for ListLinksSelfPtr<T, ID> {}
+
+impl<T: ?Sized, const ID: u64> ListLinksSelfPtr<T, ID> {
+    /// Creates a new initializer for this type.
+    pub fn new() -> impl PinInit<Self> {
+        // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will
+        // not be constructed in an `Arc` that already has a `ListArc`.
+        Self {
+            inner: ListLinks {
+                inner: Opaque::new(ListLinksFields {
+                    prev: ptr::null_mut(),
+                    next: ptr::null_mut(),
+                }),
+            },
+            self_ptr: Opaque::uninit(),
+        }
+    }
+
+    /// Returns a pointer to the self pointer.
+    ///
+    /// # Safety
+    ///
+    /// The provided pointer must point at a valid struct of type `Self`.
+    pub unsafe fn raw_get_self_ptr(me: *const Self) -> *const Opaque<*const T> {
+        // SAFETY: The caller promises that the pointer is valid.
+        unsafe { ptr::addr_of!((*me).self_ptr) }
+    }
+}
+
+impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> {
+    /// Creates a new empty list.
+    pub const fn new() -> Self {
+        Self {
+            first: ptr::null_mut(),
+            _ty: PhantomData,
+        }
+    }
+
+    /// Returns whether this list is empty.
+    pub fn is_empty(&self) -> bool {
+        self.first.is_null()
+    }
+
+    /// Inserts `item` before `next` in the cycle.
+    ///
+    /// Returns a pointer to the newly inserted element. Never changes `self.first` unless the list
+    /// is empty.
+    ///
+    /// # Safety
+    ///
+    /// * `next` must be an element in this list or null.
+    /// * If `next` is null, then the list must be empty.
+    unsafe fn insert_inner(
+        &mut self,
+        item: ListArc<T, ID>,
+        next: *mut ListLinksFields,
+    ) -> *mut ListLinksFields {
+        let raw_item = ListArc::into_raw(item);
+        // SAFETY:
+        // * We just got `raw_item` from a `ListArc`, so it's in an `Arc`.
+        // * Since we have ownership of the `ListArc`, `post_remove` must have been called after
+        //   the most recent call to `prepare_to_insert`, if any.
+        // * We own the `ListArc`.
+        // * Removing items from this list is always done using `remove_internal_inner`, which
+        //   calls `post_remove` before giving up ownership.
+        let list_links = unsafe { T::prepare_to_insert(raw_item) };
+        // SAFETY: We have not yet called `post_remove`, so `list_links` is still valid.
+        let item = unsafe { ListLinks::fields(list_links) };
+
+        // Check if the list is empty.
+        if next.is_null() {
+            // SAFETY: The caller just gave us ownership of these fields.
+            // INVARIANT: A linked list with one item should be cyclic.
+            unsafe {
+                (*item).next = item;
+                (*item).prev = item;
+            }
+            self.first = item;
+        } else {
+            // SAFETY: By the type invariant, this pointer is valid or null. We just checked that
+            // it's not null, so it must be valid.
+            let prev = unsafe { (*next).prev };
+            // SAFETY: Pointers in a linked list are never dangling, and the caller just gave us
+            // ownership of the fields on `item`.
+            // INVARIANT: This correctly inserts `item` between `prev` and `next`.
+            unsafe {
+                (*item).next = next;
+                (*item).prev = prev;
+                (*prev).next = item;
+                (*next).prev = item;
+            }
+        }
+
+        item
+    }
+
+    /// Add the provided item to the back of the list.
+    pub fn push_back(&mut self, item: ListArc<T, ID>) {
+        // SAFETY:
+        // * `self.first` is null or in the list.
+        // * `self.first` is only null if the list is empty.
+        unsafe { self.insert_inner(item, self.first) };
+    }
+
+    /// Add the provided item to the front of the list.
+    pub fn push_front(&mut self, item: ListArc<T, ID>) {
+        // SAFETY:
+        // * `self.first` is null or in the list.
+        // * `self.first` is only null if the list is empty.
+        let new_elem = unsafe { self.insert_inner(item, self.first) };
+
+        // INVARIANT: `new_elem` is in the list because we just inserted it.
+        self.first = new_elem;
+    }
+
+    /// Removes the last item from this list.
+    pub fn pop_back(&mut self) -> Option<ListArc<T, ID>> {
+        if self.is_empty() {
+            return None;
+        }
+
+        // SAFETY: We just checked that the list is not empty.
+        let last = unsafe { (*self.first).prev };
+        // SAFETY: The last item of this list is in this list.
+        Some(unsafe { self.remove_internal(last) })
+    }
+
+    /// Removes the first item from this list.
+    pub fn pop_front(&mut self) -> Option<ListArc<T, ID>> {
+        if self.is_empty() {
+            return None;
+        }
+
+        // SAFETY: The first item of this list is in this list.
+        Some(unsafe { self.remove_internal(self.first) })
+    }
+
+    /// Removes the provided item from this list and returns it.
+    ///
+    /// This returns `None` if the item is not in the list. (Note that by the safety requirements,
+    /// this means that the item is not in any list.)
+    ///
+    /// When using this method, be careful with using `mem::take` on the same list as that may
+    /// result in violating the safety requirements of this method.
+    ///
+    /// # Safety
+    ///
+    /// `item` must not be in a different linked list (with the same id).
+    pub unsafe fn remove(&mut self, item: &T) -> Option<ListArc<T, ID>> {
+        // SAFETY: The reference to `item` guarantees that the pointer passed to `view_links`
+        // points at a valid value, and the `ListLinks` pointer returned by `view_links` is
+        // dereferenceable per the guarantees of that method.
+        let mut item = unsafe { ListLinks::fields(T::view_links(item)) };
+        // SAFETY: The user provided a reference, and references are never dangling.
+ // + // As for why this is not a data race, there are two cases: + // + // * If `item` is not in any list, then these fields are read-only and null. + // * If `item` is in this list, then we have exclusive access to these fields since we + // have a mutable reference to the list. + // + // In either case, there's no race. + let ListLinksFields { next, prev } = unsafe { *item }; + + debug_assert_eq!(next.is_null(), prev.is_null()); + if !next.is_null() { + // This is really a no-op, but this ensures that `item` is a raw pointer that was + // obtained without going through a pointer->reference->pointer conversion roundtrip. + // This ensures that the list is valid under the more restrictive strict provenance + // ruleset. + // + // SAFETY: We just checked that `next` is not null, and it's not dangling by the + // list invariants. + unsafe { + debug_assert_eq!(item, (*next).prev); + item = (*next).prev; + } + + // SAFETY: We just checked that `item` is in a list, so the caller guarantees that it + // is in this list. The pointers are in the right order. + Some(unsafe { self.remove_internal_inner(item, next, prev) }) + } else { + None + } + } + + /// Removes the provided item from the list. + /// + /// # Safety + /// + /// `item` must point at an item in this list. + unsafe fn remove_internal(&mut self, item: *mut ListLinksFields) -> ListArc<T, ID> { + // SAFETY: The caller promises that this pointer is not dangling, and there's no data race + // since we have a mutable reference to the list containing `item`. + let ListLinksFields { next, prev } = unsafe { *item }; + // SAFETY: The pointers are ok and in the right order. + unsafe { self.remove_internal_inner(item, next, prev) } + } + + /// Removes the provided item from the list. + /// + /// # Safety + /// + /// The `item` pointer must point at an item in this list, and we must have `(*item).next == + /// next` and `(*item).prev == prev`. + unsafe fn remove_internal_inner( + &mut self, + item: *mut ListLinksFields, + next: *mut ListLinksFields, + prev: *mut ListLinksFields, + ) -> ListArc<T, ID> { + // SAFETY: We have exclusive access to the pointers of items in the list, and the prev/next + // pointers are always valid for items in a list. + // + // INVARIANT: There are three cases: + // * If the list has at least three items, then after removing the item, `prev` and `next` + // will be next to each other. + // * If the list has two items, then the remaining item will point at itself. + // * If the list has one item, then `next == prev == item`, so these writes have no + // effect. The list remains unchanged and `item` is still in the list for now. + unsafe { + (*next).prev = prev; + (*prev).next = next; + } + // SAFETY: We have exclusive access to items in the list. + // INVARIANT: `item` is being removed, so the pointers should be null. + unsafe { + (*item).prev = ptr::null_mut(); + (*item).next = ptr::null_mut(); + } + // INVARIANT: There are three cases: + // * If `item` was not the first item, then `self.first` should remain unchanged. + // * If `item` was the first item and there is another item, then we just updated + // `prev->next` to `next`, which is the new first item, and setting `item->next` to null + // did not modify `prev->next`. + // * If `item` was the only item in the list, then `prev == item`, and we just set + // `item->next` to null, so this correctly sets `first` to null now that the list is + // empty. 
+ if self.first == item { + // SAFETY: The `prev` pointer is the value that `item->prev` had when it was in this + // list, so it must be valid. There is no race since `prev` is still in the list and we + // still have exclusive access to the list. + self.first = unsafe { (*prev).next }; + } + + // SAFETY: `item` used to be in the list, so it is dereferenceable by the type invariants + // of `List`. + let list_links = unsafe { ListLinks::from_fields(item) }; + // SAFETY: Any pointer in the list originates from a `prepare_to_insert` call. + let raw_item = unsafe { T::post_remove(list_links) }; + // SAFETY: The above call to `post_remove` guarantees that we can recreate the `ListArc`. + unsafe { ListArc::from_raw(raw_item) } + } + + /// Moves all items from `other` into `self`. + /// + /// The items of `other` are added to the back of `self`, so the last item of `other` becomes + /// the last item of `self`. + pub fn push_all_back(&mut self, other: &mut List<T, ID>) { + // First, we insert the elements into `self`. At the end, we make `other` empty. + if self.is_empty() { + // INVARIANT: All of the elements in `other` become elements of `self`. + self.first = other.first; + } else if !other.is_empty() { + let other_first = other.first; + // SAFETY: The other list is not empty, so this pointer is valid. + let other_last = unsafe { (*other_first).prev }; + let self_first = self.first; + // SAFETY: The self list is not empty, so this pointer is valid. + let self_last = unsafe { (*self_first).prev }; + + // SAFETY: We have exclusive access to both lists, so we can update the pointers. + // INVARIANT: This correctly sets the pointers to merge both lists. We do not need to + // update `self.first` because the first element of `self` does not change. + unsafe { + (*self_first).prev = other_last; + (*other_last).next = self_first; + (*self_last).next = other_first; + (*other_first).prev = self_last; + } + } + + // INVARIANT: The other list is now empty, so update its pointer. + other.first = ptr::null_mut(); + } + + /// Returns a cursor that points before the first element of the list. + pub fn cursor_front(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `self.first` is in this list. + Cursor { + next: self.first, + list: self, + } + } + + /// Returns a cursor that points after the last element in the list. + pub fn cursor_back(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `next` is allowed to be null. + Cursor { + next: core::ptr::null_mut(), + list: self, + } + } + + /// Creates an iterator over the list. + pub fn iter(&self) -> Iter<'_, T, ID> { + // INVARIANT: If the list is empty, both pointers are null. Otherwise, both pointers point + // at the first element of the same list. + Iter { + current: self.first, + stop: self.first, + _ty: PhantomData, + } + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> Default for List<T, ID> { + fn default() -> Self { + List::new() + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> Drop for List<T, ID> { + fn drop(&mut self) { + while let Some(item) = self.pop_front() { + drop(item); + } + } +} + +/// An iterator over a [`List`]. +/// +/// # Invariants +/// +/// * There must be a [`List`] that is immutably borrowed for the duration of `'a`. +/// * The `current` pointer is null or points at a value in that [`List`]. +/// * The `stop` pointer is equal to the `first` field of that [`List`]. 
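+///
+/// # Examples
+///
+/// A minimal borrowing-iteration sketch, reusing a `BasicItem` type like the one defined in the
+/// [`List`] examples above (hypothetical here, hence not compiled):
+///
+/// ```ignore
+/// let mut list = List::new();
+/// list.push_back(BasicItem::new(1)?);
+/// list.push_back(BasicItem::new(2)?);
+///
+/// // `List::iter()` yields `ArcBorrow<'_, BasicItem>` items without removing them.
+/// let sum: i32 = list.iter().map(|item| item.value).sum();
+/// assert_eq!(sum, 3);
+/// ```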
+#[derive(Clone)] +pub struct Iter<'a, T: ?Sized + ListItem<ID>, const ID: u64 = 0> { + current: *mut ListLinksFields, + stop: *mut ListLinksFields, + _ty: PhantomData<&'a ListArc<T, ID>>, +} + +impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Iterator for Iter<'a, T, ID> { + type Item = ArcBorrow<'a, T>; + + fn next(&mut self) -> Option<ArcBorrow<'a, T>> { + if self.current.is_null() { + return None; + } + + let current = self.current; + + // SAFETY: We just checked that `current` is not null, so it is in a list, and hence not + // dangling. There's no race because the iterator holds an immutable borrow to the list. + let next = unsafe { (*current).next }; + // INVARIANT: If `current` was the last element of the list, then this updates it to null. + // Otherwise, we update it to the next element. + self.current = if next != self.stop { + next + } else { + ptr::null_mut() + }; + + // SAFETY: The `current` pointer points at a value in the list. + let item = unsafe { T::view_value(ListLinks::from_fields(current)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the list. However, the `ArcBorrow` is annotated with the iterator's + // lifetime, and the list is immutably borrowed for that lifetime. + // * Values in a list never have a `UniqueArc` reference. + Some(unsafe { ArcBorrow::from_raw(item) }) + } +} + +/// A cursor into a [`List`]. +/// +/// A cursor always rests between two elements in the list. This means that a cursor has a previous +/// and next element, but no current element. It also means that it's possible to have a cursor +/// into an empty list. +/// +/// # Examples +/// +/// ``` +/// use kernel::prelude::*; +/// use kernel::list::{List, ListArc, ListLinks}; +/// +/// #[pin_data] +/// struct ListItem { +/// value: u32, +/// #[pin] +/// links: ListLinks, +/// } +/// +/// impl ListItem { +/// fn new(value: u32) -> Result<ListArc<Self>> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinks::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for ListItem { untracked; } +/// } +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for ListItem { using ListLinks { self.links }; } +/// } +/// +/// // Use a cursor to remove the first element with the given value. +/// fn remove_first(list: &mut List<ListItem>, value: u32) -> Option<ListArc<ListItem>> { +/// let mut cursor = list.cursor_front(); +/// while let Some(next) = cursor.peek_next() { +/// if next.value == value { +/// return Some(next.remove()); +/// } +/// cursor.move_next(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove the last element with the given value. +/// fn remove_last(list: &mut List<ListItem>, value: u32) -> Option<ListArc<ListItem>> { +/// let mut cursor = list.cursor_back(); +/// while let Some(prev) = cursor.peek_prev() { +/// if prev.value == value { +/// return Some(prev.remove()); +/// } +/// cursor.move_prev(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove all elements with the given value. The removed elements are moved to +/// // a new list. 
+/// fn remove_all(list: &mut List<ListItem>, value: u32) -> List<ListItem> {
+///     let mut out = List::new();
+///     let mut cursor = list.cursor_front();
+///     while let Some(next) = cursor.peek_next() {
+///         if next.value == value {
+///             out.push_back(next.remove());
+///         } else {
+///             cursor.move_next();
+///         }
+///     }
+///     out
+/// }
+///
+/// // Use a cursor to insert a value at a specific index. Returns an error if the index is out of
+/// // bounds.
+/// fn insert_at(list: &mut List<ListItem>, new: ListArc<ListItem>, idx: usize) -> Result {
+///     let mut cursor = list.cursor_front();
+///     for _ in 0..idx {
+///         if !cursor.move_next() {
+///             return Err(EINVAL);
+///         }
+///     }
+///     cursor.insert_next(new);
+///     Ok(())
+/// }
+///
+/// // Merge two sorted lists into a single sorted list.
+/// fn merge_sorted(list: &mut List<ListItem>, merge: List<ListItem>) {
+///     let mut cursor = list.cursor_front();
+///     for to_insert in merge {
+///         while let Some(next) = cursor.peek_next() {
+///             if to_insert.value < next.value {
+///                 break;
+///             }
+///             cursor.move_next();
+///         }
+///         cursor.insert_prev(to_insert);
+///     }
+/// }
+///
+/// let mut list = List::new();
+/// list.push_back(ListItem::new(14)?);
+/// list.push_back(ListItem::new(12)?);
+/// list.push_back(ListItem::new(10)?);
+/// list.push_back(ListItem::new(12)?);
+/// list.push_back(ListItem::new(15)?);
+/// list.push_back(ListItem::new(14)?);
+/// assert_eq!(remove_all(&mut list, 12).iter().count(), 2);
+/// // [14, 10, 15, 14]
+/// assert!(remove_first(&mut list, 14).is_some());
+/// // [10, 15, 14]
+/// insert_at(&mut list, ListItem::new(12)?, 2)?;
+/// // [10, 15, 12, 14]
+/// assert!(remove_last(&mut list, 15).is_some());
+/// // [10, 12, 14]
+///
+/// let mut list2 = List::new();
+/// list2.push_back(ListItem::new(11)?);
+/// list2.push_back(ListItem::new(13)?);
+/// merge_sorted(&mut list, list2);
+///
+/// let mut items = list.into_iter();
+/// assert_eq!(items.next().ok_or(EINVAL)?.value, 10);
+/// assert_eq!(items.next().ok_or(EINVAL)?.value, 11);
+/// assert_eq!(items.next().ok_or(EINVAL)?.value, 12);
+/// assert_eq!(items.next().ok_or(EINVAL)?.value, 13);
+/// assert_eq!(items.next().ok_or(EINVAL)?.value, 14);
+/// assert!(items.next().is_none());
+/// # Result::<(), Error>::Ok(())
+/// ```
+///
+/// # Invariants
+///
+/// The `next` pointer is null or points at a value in `list`.
+pub struct Cursor<'a, T: ?Sized + ListItem<ID>, const ID: u64 = 0> {
+    list: &'a mut List<T, ID>,
+    /// Points at the element after this cursor, or null if the cursor is after the last element.
+    next: *mut ListLinksFields,
+}
+
+impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Cursor<'a, T, ID> {
+    /// Returns a pointer to the element before the cursor.
+    ///
+    /// Returns null if there is no element before the cursor.
+    fn prev_ptr(&self) -> *mut ListLinksFields {
+        let mut next = self.next;
+        let first = self.list.first;
+        if next == first {
+            // We are before the first element.
+            return core::ptr::null_mut();
+        }
+
+        if next.is_null() {
+            // We are after the last element, so we need a pointer to the last element, which is
+            // the same as `(*first).prev`.
+            next = first;
+        }
+
+        // SAFETY: `next` can't be null, because then `first` must also be null, but in that case
+        // we would have exited at the `next == first` check. Thus, `next` is an element in the
+        // list, so we can access its `prev` pointer.
+        unsafe { (*next).prev }
+    }
+
+    /// Access the element after this cursor.
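+    ///
+    /// Returns [`None`] if the cursor is after the last element.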
+    pub fn peek_next(&mut self) -> Option<CursorPeek<'_, 'a, T, true, ID>> {
+        if self.next.is_null() {
+            return None;
+        }
+
+        // INVARIANT:
+        // * We just checked that `self.next` is non-null, so it must be in `self.list`.
+        // * `ptr` is equal to `self.next`.
+        Some(CursorPeek {
+            ptr: self.next,
+            cursor: self,
+        })
+    }
+
+    /// Access the element before this cursor.
+    pub fn peek_prev(&mut self) -> Option<CursorPeek<'_, 'a, T, false, ID>> {
+        let prev = self.prev_ptr();
+
+        if prev.is_null() {
+            return None;
+        }
+
+        // INVARIANT:
+        // * We just checked that `prev` is non-null, so it must be in `self.list`.
+        // * `self.prev_ptr()` never returns `self.next`.
+        Some(CursorPeek {
+            ptr: prev,
+            cursor: self,
+        })
+    }
+
+    /// Move the cursor one element forward.
+    ///
+    /// If the cursor is after the last element, then this call does nothing. This call returns
+    /// `true` if the cursor's position was changed.
+    pub fn move_next(&mut self) -> bool {
+        if self.next.is_null() {
+            return false;
+        }
+
+        // SAFETY: `self.next` is an element in the list and we borrow the list mutably, so we can
+        // access the `next` field.
+        let mut next = unsafe { (*self.next).next };
+
+        if next == self.list.first {
+            next = core::ptr::null_mut();
+        }
+
+        // INVARIANT: `next` is either null or the next element after an element in the list.
+        self.next = next;
+        true
+    }
+
+    /// Move the cursor one element backwards.
+    ///
+    /// If the cursor is before the first element, then this call does nothing. This call returns
+    /// `true` if the cursor's position was changed.
+    pub fn move_prev(&mut self) -> bool {
+        if self.next == self.list.first {
+            return false;
+        }
+
+        // INVARIANT: `prev_ptr()` always returns a pointer that is null or in the list.
+        self.next = self.prev_ptr();
+        true
+    }
+
+    /// Inserts an element where the cursor is pointing and gets a pointer to the new element.
+    fn insert_inner(&mut self, item: ListArc<T, ID>) -> *mut ListLinksFields {
+        let ptr = if self.next.is_null() {
+            self.list.first
+        } else {
+            self.next
+        };
+        // SAFETY:
+        // * `ptr` is an element in the list or null.
+        // * If `ptr` is null, then `self.list.first` is null so the list is empty.
+        let item = unsafe { self.list.insert_inner(item, ptr) };
+        if self.next == self.list.first {
+            // INVARIANT: We just inserted `item`, so it's a member of the list.
+            self.list.first = item;
+        }
+        item
+    }
+
+    /// Insert an element at this cursor's location.
+    pub fn insert(mut self, item: ListArc<T, ID>) {
+        // This is identical to `insert_prev`, but consumes the cursor. This is helpful because it
+        // reduces confusion when the last operation on the cursor is an insertion; in that case,
+        // you just want to insert the element at the cursor, and it is confusing that the call
+        // involves the word prev or next.
+        self.insert_inner(item);
+    }
+
+    /// Inserts an element after this cursor.
+    ///
+    /// After insertion, the new element will be after the cursor.
+    pub fn insert_next(&mut self, item: ListArc<T, ID>) {
+        self.next = self.insert_inner(item);
+    }
+
+    /// Inserts an element before this cursor.
+    ///
+    /// After insertion, the new element will be before the cursor.
+    pub fn insert_prev(&mut self, item: ListArc<T, ID>) {
+        self.insert_inner(item);
+    }
+
+    /// Remove the next element from the list.
+    pub fn remove_next(&mut self) -> Option<ListArc<T, ID>> {
+        self.peek_next().map(|v| v.remove())
+    }
+
+    /// Remove the previous element from the list.
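+    ///
+    /// Returns [`None`] if the cursor is before the first element.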
+ pub fn remove_prev(&mut self) -> Option<ListArc<T, ID>> { + self.peek_prev().map(|v| v.remove()) + } +} + +/// References the element in the list next to the cursor. +/// +/// # Invariants +/// +/// * `ptr` is an element in `self.cursor.list`. +/// * `ISNEXT == (self.ptr == self.cursor.next)`. +pub struct CursorPeek<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> { + cursor: &'a mut Cursor<'b, T, ID>, + ptr: *mut ListLinksFields, +} + +impl<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> + CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + /// Remove the element from the list. + pub fn remove(self) -> ListArc<T, ID> { + if ISNEXT { + self.cursor.move_next(); + } + + // INVARIANT: `self.ptr` is not equal to `self.cursor.next` due to the above `move_next` + // call. + // SAFETY: By the type invariants of `Self`, `next` is not null, so `next` is an element of + // `self.cursor.list` by the type invariants of `Cursor`. + unsafe { self.cursor.list.remove_internal(self.ptr) } + } + + /// Access this value as an [`ArcBorrow`]. + pub fn arc(&self) -> ArcBorrow<'_, T> { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `ArcBorrow` holds + // an immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable + // access requires first releasing the immutable borrow on the `CursorPeek`. + // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` + // reference, and `UniqueArc` references must be unique. + unsafe { ArcBorrow::from_raw(me) } + } +} + +impl<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> core::ops::Deref + for CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + // If you change the `ptr` field to have type `ArcBorrow<'a, T>`, it might seem like you could + // get rid of the `CursorPeek::arc` method and change the deref target to `ArcBorrow<'a, T>`. + // However, that doesn't work because 'a is too long. You could obtain an `ArcBorrow<'a, T>` + // and then call `CursorPeek::remove` without giving up the `ArcBorrow<'a, T>`, which would be + // unsound. + type Target = T; + + fn deref(&self) -> &T { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + + // SAFETY: The value cannot be removed from the list for the duration of the lifetime + // annotated on the returned `&T`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `&T` holds an + // immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable access + // requires first releasing the immutable borrow on the `CursorPeek`. 
+        unsafe { &*me }
+    }
+}
+
+impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> FusedIterator for Iter<'a, T, ID> {}
+
+impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> IntoIterator for &'a List<T, ID> {
+    type IntoIter = Iter<'a, T, ID>;
+    type Item = ArcBorrow<'a, T>;
+
+    fn into_iter(self) -> Iter<'a, T, ID> {
+        self.iter()
+    }
+}
+
+/// An owning iterator into a [`List`].
+pub struct IntoIter<T: ?Sized + ListItem<ID>, const ID: u64 = 0> {
+    list: List<T, ID>,
+}
+
+impl<T: ?Sized + ListItem<ID>, const ID: u64> Iterator for IntoIter<T, ID> {
+    type Item = ListArc<T, ID>;
+
+    fn next(&mut self) -> Option<ListArc<T, ID>> {
+        self.list.pop_front()
+    }
+}
+
+impl<T: ?Sized + ListItem<ID>, const ID: u64> FusedIterator for IntoIter<T, ID> {}
+
+impl<T: ?Sized + ListItem<ID>, const ID: u64> DoubleEndedIterator for IntoIter<T, ID> {
+    fn next_back(&mut self) -> Option<ListArc<T, ID>> {
+        self.list.pop_back()
+    }
+}
+
+impl<T: ?Sized + ListItem<ID>, const ID: u64> IntoIterator for List<T, ID> {
+    type IntoIter = IntoIter<T, ID>;
+    type Item = ListArc<T, ID>;
+
+    fn into_iter(self) -> IntoIter<T, ID> {
+        IntoIter { list: self }
+    }
+}
diff --git a/rust/kernel/list/arc.rs b/rust/kernel/list/arc.rs
new file mode 100644
index 000000000000..d92bcf665c89
--- /dev/null
+++ b/rust/kernel/list/arc.rs
@@ -0,0 +1,521 @@
+// SPDX-License-Identifier: GPL-2.0
+
+// Copyright (C) 2024 Google LLC.
+
+//! A wrapper around `Arc` for linked lists.
+
+use crate::alloc::{AllocError, Flags};
+use crate::prelude::*;
+use crate::sync::{Arc, ArcBorrow, UniqueArc};
+use core::marker::PhantomPinned;
+use core::ops::Deref;
+use core::pin::Pin;
+use core::sync::atomic::{AtomicBool, Ordering};
+
+/// Declares that this type has some way to ensure that there is exactly one `ListArc` instance for
+/// this id.
+///
+/// Types that implement this trait should include some kind of logic for keeping track of whether
+/// a [`ListArc`] exists or not. We refer to this logic as "the tracking inside `T`".
+///
+/// We allow the case where the tracking inside `T` thinks that a [`ListArc`] exists, but actually,
+/// there isn't a [`ListArc`]. However, we do not allow the opposite situation where a [`ListArc`]
+/// exists, but the tracking thinks it doesn't. This is because the former can at most result in us
+/// failing to create a [`ListArc`] when the operation could succeed, whereas the latter can result
+/// in the creation of two [`ListArc`] references. Only the latter situation can lead to memory
+/// safety issues.
+///
+/// A consequence of the above is that you may implement the tracking inside `T` by not actually
+/// keeping track of anything. To do this, you always claim that a [`ListArc`] exists, even if
+/// there isn't one. This implementation is allowed by the above rule, but it means that
+/// [`ListArc`] references can only be created if you have ownership of *all* references to the
+/// refcounted object, as you otherwise have no way of knowing whether a [`ListArc`] exists.
+pub trait ListArcSafe<const ID: u64 = 0> {
+    /// Informs the tracking inside this type that it now has a [`ListArc`] reference.
+    ///
+    /// This method may be called even if the tracking inside this type thinks that a `ListArc`
+    /// reference exists. (But only if that's not actually the case.)
+    ///
+    /// # Safety
+    ///
+    /// Must not be called if a [`ListArc`] already exists for this value.
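+    ///
+    /// This is naturally guaranteed while the value is still inside a [`UniqueArc`], since that
+    /// is then the only reference to the value; the `From<UniqueArc<T>>` implementation of
+    /// [`ListArc`] relies on exactly that.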
+    unsafe fn on_create_list_arc_from_unique(self: Pin<&mut Self>);
+
+    /// Informs the tracking inside this type that there is no [`ListArc`] reference anymore.
+    ///
+    /// # Safety
+    ///
+    /// Must only be called if there is no [`ListArc`] reference, but the tracking thinks there is.
+    unsafe fn on_drop_list_arc(&self);
+}
+
+/// Declares that this type is able to safely attempt to create `ListArc`s at any time.
+///
+/// # Safety
+///
+/// The guarantees of `try_new_list_arc` must be upheld.
+pub unsafe trait TryNewListArc<const ID: u64 = 0>: ListArcSafe<ID> {
+    /// Attempts to convert an `Arc<Self>` into a `ListArc<Self>`. Returns `true` if the
+    /// conversion was successful.
+    ///
+    /// This method should not be called directly. Use [`ListArc::try_from_arc`] instead.
+    ///
+    /// # Guarantees
+    ///
+    /// If this call returns `true`, then there is no [`ListArc`] pointing to this value.
+    /// Additionally, this call will have transitioned the tracking inside `Self` from not thinking
+    /// that a [`ListArc`] exists, to thinking that a [`ListArc`] exists.
+    fn try_new_list_arc(&self) -> bool;
+}
+
+/// Declares that this type supports [`ListArc`].
+///
+/// This macro supports a few different strategies for implementing the tracking inside the type:
+///
+/// * The `untracked` strategy does not actually keep track of whether a [`ListArc`] exists. When
+///   using this strategy, the only way to create a [`ListArc`] is using a [`UniqueArc`].
+/// * The `tracked_by` strategy defers the tracking to a field of the struct. The user must specify
+///   which field to defer the tracking to. The field must implement [`ListArcSafe`]. If the field
+///   implements [`TryNewListArc`], then the type will also implement [`TryNewListArc`].
+///
+/// The `tracked_by` strategy is usually used by deferring to a field of type
+/// [`AtomicTracker`]. However, it is also possible to defer the tracking to another struct, also
+/// using this macro.
+#[macro_export]
+macro_rules! impl_list_arc_safe {
+    (impl$({$($generics:tt)*})? ListArcSafe<$num:tt> for $t:ty { untracked; } $($rest:tt)*) => {
+        impl$(<$($generics)*>)? $crate::list::ListArcSafe<$num> for $t {
+            unsafe fn on_create_list_arc_from_unique(self: ::core::pin::Pin<&mut Self>) {}
+            unsafe fn on_drop_list_arc(&self) {}
+        }
+        $crate::list::impl_list_arc_safe! { $($rest)* }
+    };
+
+    (impl$({$($generics:tt)*})? ListArcSafe<$num:tt> for $t:ty {
+        tracked_by $field:ident : $fty:ty;
+    } $($rest:tt)*) => {
+        impl$(<$($generics)*>)? $crate::list::ListArcSafe<$num> for $t {
+            unsafe fn on_create_list_arc_from_unique(self: ::core::pin::Pin<&mut Self>) {
+                ::pin_init::assert_pinned!($t, $field, $fty, inline);
+
+                // SAFETY: This field is structurally pinned as per the above assertion.
+                let field = unsafe {
+                    ::core::pin::Pin::map_unchecked_mut(self, |me| &mut me.$field)
+                };
+                // SAFETY: The caller promises that there is no `ListArc`.
+                unsafe {
+                    <$fty as $crate::list::ListArcSafe<$num>>::on_create_list_arc_from_unique(field)
+                };
+            }
+            unsafe fn on_drop_list_arc(&self) {
+                // SAFETY: The caller promises that there is no `ListArc` reference, and also
+                // promises that the tracking thinks there is a `ListArc` reference.
+                unsafe { <$fty as $crate::list::ListArcSafe<$num>>::on_drop_list_arc(&self.$field) };
+            }
+        }
+        // SAFETY: The guarantees of `try_new_list_arc` are upheld by deferring to the field's
+        // implementation of `TryNewListArc`, which upholds them by its own safety contract.
+        unsafe impl$(<$($generics)*>)?
+            $crate::list::TryNewListArc<$num> for $t
+        where
+            $fty: TryNewListArc<$num>,
+        {
+            fn try_new_list_arc(&self) -> bool {
+                <$fty as $crate::list::TryNewListArc<$num>>::try_new_list_arc(&self.$field)
+            }
+        }
+        $crate::list::impl_list_arc_safe! { $($rest)* }
+    };
+
+    () => {};
+}
+pub use impl_list_arc_safe;
+
+/// A wrapper around [`Arc`] that's guaranteed unique for the given id.
+///
+/// The `ListArc` type can be thought of as a special reference to a refcounted object that owns the
+/// permission to manipulate the `next`/`prev` pointers stored in the refcounted object. By ensuring
+/// that each object has only one `ListArc` reference, the owner of that reference is assured
+/// exclusive access to the `next`/`prev` pointers. When a `ListArc` is inserted into a [`List`],
+/// the [`List`] takes ownership of the `ListArc` reference.
+///
+/// There are various strategies for ensuring that a value has only one `ListArc` reference. The
+/// simplest is to convert a [`UniqueArc`] into a `ListArc`. However, the refcounted object could
+/// also keep track of whether a `ListArc` exists using a boolean, which could allow for the
+/// creation of new `ListArc` references from an [`Arc`] reference. Whatever strategy is used, the
+/// relevant tracking is referred to as "the tracking inside `T`", and the [`ListArcSafe`] trait
+/// (and its subtraits) are used to update the tracking when a `ListArc` is created or destroyed.
+///
+/// Note that we allow the case where the tracking inside `T` thinks that a `ListArc` exists, but
+/// actually, there isn't a `ListArc`. However, we do not allow the opposite situation where a
+/// `ListArc` exists, but the tracking thinks it doesn't. This is because the former can at most
+/// result in us failing to create a `ListArc` when the operation could succeed, whereas the latter
+/// can result in the creation of two `ListArc` references.
+///
+/// While this `ListArc` is unique for the given id, there still might exist normal `Arc`
+/// references to the object.
+///
+/// # Invariants
+///
+/// * Each reference counted object has at most one `ListArc` for each value of `ID`.
+/// * The tracking inside `T` is aware that a `ListArc` reference exists.
+///
+/// [`List`]: crate::list::List
+#[repr(transparent)]
+#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))]
+pub struct ListArc<T, const ID: u64 = 0>
+where
+    T: ListArcSafe<ID> + ?Sized,
+{
+    arc: Arc<T>,
+}
+
+impl<T: ListArcSafe<ID>, const ID: u64> ListArc<T, ID> {
+    /// Constructs a new reference counted instance of `T`.
+    #[inline]
+    pub fn new(contents: T, flags: Flags) -> Result<Self, AllocError> {
+        Ok(Self::from(UniqueArc::new(contents, flags)?))
+    }
+
+    /// Use the given initializer to in-place initialize a `T`.
+    ///
+    /// If `T: !Unpin` it will not be able to move afterwards.
+    // We don't implement `InPlaceInit` because `ListArc` is implicitly pinned. This is similar to
+    // what we do for `Arc`.
+    #[inline]
+    pub fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self, E>
+    where
+        E: From<AllocError>,
+    {
+        Ok(Self::from(UniqueArc::try_pin_init(init, flags)?))
+    }
+
+    /// Use the given initializer to in-place initialize a `T`.
+    ///
+    /// This is equivalent to [`ListArc<T>::pin_init`], since a [`ListArc`] is always pinned.
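+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of constructing a `ListArc` (here via [`ListArc::new`]); `Data` is a
+    /// hypothetical type using the `untracked` strategy.
+    ///
+    /// ```
+    /// use kernel::list::ListArc;
+    ///
+    /// struct Data {
+    ///     value: u32,
+    /// }
+    ///
+    /// kernel::list::impl_list_arc_safe! {
+    ///     impl ListArcSafe<0> for Data { untracked; }
+    /// }
+    ///
+    /// let arc = ListArc::<Data>::new(Data { value: 42 }, GFP_KERNEL)?;
+    /// assert_eq!(arc.value, 42);
+    /// # Ok::<_, Error>(())
+    /// ```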
+ #[inline] + pub fn init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + Ok(Self::from(UniqueArc::try_init(init, flags)?)) + } +} + +impl<T, const ID: u64> From<UniqueArc<T>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Convert a [`UniqueArc`] into a [`ListArc`]. + #[inline] + fn from(unique: UniqueArc<T>) -> Self { + Self::from(Pin::from(unique)) + } +} + +impl<T, const ID: u64> From<Pin<UniqueArc<T>>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Convert a pinned [`UniqueArc`] into a [`ListArc`]. + #[inline] + fn from(mut unique: Pin<UniqueArc<T>>) -> Self { + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { T::on_create_list_arc_from_unique(unique.as_mut()) }; + let arc = Arc::from(unique); + // SAFETY: We just called `on_create_list_arc_from_unique` on an arc without a `ListArc`, + // so we can create a `ListArc`. + unsafe { Self::transmute_from_arc(arc) } + } +} + +impl<T, const ID: u64> ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Creates two `ListArc`s from a [`UniqueArc`]. + /// + /// The two ids must be different. + #[inline] + pub fn pair_from_unique<const ID2: u64>(unique: UniqueArc<T>) -> (Self, ListArc<T, ID2>) + where + T: ListArcSafe<ID2>, + { + Self::pair_from_pin_unique(Pin::from(unique)) + } + + /// Creates two `ListArc`s from a pinned [`UniqueArc`]. + /// + /// The two ids must be different. + #[inline] + pub fn pair_from_pin_unique<const ID2: u64>( + mut unique: Pin<UniqueArc<T>>, + ) -> (Self, ListArc<T, ID2>) + where + T: ListArcSafe<ID2>, + { + build_assert!(ID != ID2); + + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { <T as ListArcSafe<ID>>::on_create_list_arc_from_unique(unique.as_mut()) }; + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { <T as ListArcSafe<ID2>>::on_create_list_arc_from_unique(unique.as_mut()) }; + + let arc1 = Arc::from(unique); + let arc2 = Arc::clone(&arc1); + + // SAFETY: We just called `on_create_list_arc_from_unique` on an arc without a `ListArc` + // for both IDs (which are different), so we can create two `ListArc`s. + unsafe { + ( + Self::transmute_from_arc(arc1), + ListArc::transmute_from_arc(arc2), + ) + } + } + + /// Try to create a new `ListArc`. + /// + /// This fails if this value already has a `ListArc`. + pub fn try_from_arc(arc: Arc<T>) -> Result<Self, Arc<T>> + where + T: TryNewListArc<ID>, + { + if arc.try_new_list_arc() { + // SAFETY: The `try_new_list_arc` method returned true, so we made the tracking think + // that a `ListArc` exists. This lets us create a `ListArc`. + Ok(unsafe { Self::transmute_from_arc(arc) }) + } else { + Err(arc) + } + } + + /// Try to create a new `ListArc`. + /// + /// This fails if this value already has a `ListArc`. + pub fn try_from_arc_borrow(arc: ArcBorrow<'_, T>) -> Option<Self> + where + T: TryNewListArc<ID>, + { + if arc.try_new_list_arc() { + // SAFETY: The `try_new_list_arc` method returned true, so we made the tracking think + // that a `ListArc` exists. This lets us create a `ListArc`. + Some(unsafe { Self::transmute_from_arc(Arc::from(arc)) }) + } else { + None + } + } + + /// Try to create a new `ListArc`. + /// + /// If it's not possible to create a new `ListArc`, then the `Arc` is dropped. This will never + /// run the destructor of the value. 
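+    ///
+    /// # Examples
+    ///
+    /// A sketch using a hypothetical `Tracked` type whose `ListArc` is tracked by an
+    /// [`AtomicTracker`], so that conversion from a plain [`Arc`] can be attempted at any time.
+    ///
+    /// ```
+    /// use kernel::list::{AtomicTracker, ListArc, TryNewListArc};
+    /// use kernel::sync::Arc;
+    ///
+    /// #[pin_data]
+    /// struct Tracked {
+    ///     value: u32,
+    ///     #[pin]
+    ///     tracker: AtomicTracker,
+    /// }
+    ///
+    /// kernel::list::impl_list_arc_safe! {
+    ///     impl ListArcSafe<0> for Tracked { tracked_by tracker: AtomicTracker; }
+    /// }
+    ///
+    /// let arc: Arc<Tracked> = Arc::pin_init(try_pin_init!(Tracked {
+    ///     value: 10,
+    ///     tracker <- AtomicTracker::new(),
+    /// }), GFP_KERNEL)?;
+    ///
+    /// // The first conversion succeeds, because no `ListArc` exists yet.
+    /// let list_arc = ListArc::try_from_arc_or_drop(arc.clone()).expect("no ListArc exists yet");
+    /// assert_eq!(list_arc.value, 10);
+    ///
+    /// // A second conversion fails and drops its `Arc`, because `list_arc` still exists.
+    /// assert!(ListArc::try_from_arc_or_drop(arc.clone()).is_none());
+    /// # Ok::<_, Error>(())
+    /// ```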
+    pub fn try_from_arc_or_drop(arc: Arc<T>) -> Option<Self>
+    where
+        T: TryNewListArc<ID>,
+    {
+        match Self::try_from_arc(arc) {
+            Ok(list_arc) => Some(list_arc),
+            Err(arc) => Arc::into_unique_or_drop(arc).map(Self::from),
+        }
+    }
+
+    /// Transmutes an [`Arc`] into a `ListArc` without updating the tracking inside `T`.
+    ///
+    /// # Safety
+    ///
+    /// * The value must not already have a `ListArc` reference.
+    /// * The tracking inside `T` must think that there is a `ListArc` reference.
+    #[inline]
+    unsafe fn transmute_from_arc(arc: Arc<T>) -> Self {
+        // INVARIANT: By the safety requirements, the invariants on `ListArc` are satisfied.
+        Self { arc }
+    }
+
+    /// Transmutes a `ListArc` into an [`Arc`] without updating the tracking inside `T`.
+    ///
+    /// After this call, the tracking inside `T` will still think that there is a `ListArc`
+    /// reference.
+    #[inline]
+    fn transmute_to_arc(self) -> Arc<T> {
+        // Use a transmute to skip destructor.
+        //
+        // SAFETY: ListArc is repr(transparent).
+        unsafe { core::mem::transmute(self) }
+    }
+
+    /// Convert ownership of this `ListArc` into a raw pointer.
+    ///
+    /// The returned pointer is indistinguishable from pointers returned by [`Arc::into_raw`]. The
+    /// tracking inside `T` will still think that a `ListArc` exists after this call.
+    #[inline]
+    pub fn into_raw(self) -> *const T {
+        Arc::into_raw(Self::transmute_to_arc(self))
+    }
+
+    /// Take ownership of the `ListArc` from a raw pointer.
+    ///
+    /// # Safety
+    ///
+    /// * `ptr` must satisfy the safety requirements of [`Arc::from_raw`].
+    /// * The value must not already have a `ListArc` reference.
+    /// * The tracking inside `T` must think that there is a `ListArc` reference.
+    #[inline]
+    pub unsafe fn from_raw(ptr: *const T) -> Self {
+        // SAFETY: The pointer satisfies the safety requirements for `Arc::from_raw`.
+        let arc = unsafe { Arc::from_raw(ptr) };
+        // SAFETY: The value doesn't already have a `ListArc` reference, but the tracking thinks it
+        // does.
+        unsafe { Self::transmute_from_arc(arc) }
+    }
+
+    /// Converts the `ListArc` into an [`Arc`].
+    #[inline]
+    pub fn into_arc(self) -> Arc<T> {
+        let arc = Self::transmute_to_arc(self);
+        // SAFETY: There is no longer a `ListArc`, but the tracking thinks there is.
+        unsafe { T::on_drop_list_arc(&arc) };
+        arc
+    }
+
+    /// Clone a `ListArc` into an [`Arc`].
+    #[inline]
+    pub fn clone_arc(&self) -> Arc<T> {
+        self.arc.clone()
+    }
+
+    /// Returns a reference to an [`Arc`] from the given [`ListArc`].
+    ///
+    /// This is useful when the argument of a function call is an [`&Arc`] (e.g., in a method
+    /// receiver), but we have a [`ListArc`] instead.
+    ///
+    /// [`&Arc`]: Arc
+    #[inline]
+    pub fn as_arc(&self) -> &Arc<T> {
+        &self.arc
+    }
+
+    /// Returns an [`ArcBorrow`] from the given [`ListArc`].
+    ///
+    /// This is useful when the argument of a function call is an [`ArcBorrow`] (e.g., in a method
+    /// receiver), but we have a [`ListArc`] instead. Getting an [`ArcBorrow`] is free when
+    /// optimised.
+    #[inline]
+    pub fn as_arc_borrow(&self) -> ArcBorrow<'_, T> {
+        self.arc.as_arc_borrow()
+    }
+
+    /// Compare whether two [`ListArc`] pointers reference the same underlying object.
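+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch with a hypothetical `Data` type using the `untracked` strategy.
+    ///
+    /// ```
+    /// use kernel::list::ListArc;
+    ///
+    /// struct Data(u32);
+    ///
+    /// kernel::list::impl_list_arc_safe! {
+    ///     impl ListArcSafe<0> for Data { untracked; }
+    /// }
+    ///
+    /// let a = ListArc::<Data>::new(Data(1), GFP_KERNEL)?;
+    /// let b = ListArc::<Data>::new(Data(1), GFP_KERNEL)?;
+    ///
+    /// // Distinct allocations are not `ptr_eq`, even if they hold the same value.
+    /// assert_eq!(a.0, b.0);
+    /// assert!(ListArc::ptr_eq(&a, &a));
+    /// assert!(!ListArc::ptr_eq(&a, &b));
+    /// # Ok::<_, Error>(())
+    /// ```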
+ #[inline] + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + Arc::ptr_eq(&this.arc, &other.arc) + } +} + +impl<T, const ID: u64> Deref for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + self.arc.deref() + } +} + +impl<T, const ID: u64> Drop for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + #[inline] + fn drop(&mut self) { + // SAFETY: There is no longer a `ListArc`, but the tracking thinks there is by the type + // invariants on `Self`. + unsafe { T::on_drop_list_arc(&self.arc) }; + } +} + +impl<T, const ID: u64> AsRef<Arc<T>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + #[inline] + fn as_ref(&self) -> &Arc<T> { + self.as_arc() + } +} + +// This is to allow coercion from `ListArc<T>` to `ListArc<U>` if `T` can be converted to the +// dynamically-sized type (DST) `U`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, const ID: u64> core::ops::CoerceUnsized<ListArc<U, ID>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + core::marker::Unsize<U> + ?Sized, + U: ListArcSafe<ID> + ?Sized, +{ +} + +// This is to allow `ListArc<U>` to be dispatched on when `ListArc<T>` can be coerced into +// `ListArc<U>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, const ID: u64> core::ops::DispatchFromDyn<ListArc<U, ID>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + core::marker::Unsize<U> + ?Sized, + U: ListArcSafe<ID> + ?Sized, +{ +} + +/// A utility for tracking whether a [`ListArc`] exists using an atomic. +/// +/// # Invariants +/// +/// If the boolean is `false`, then there is no [`ListArc`] for this value. +#[repr(transparent)] +pub struct AtomicTracker<const ID: u64 = 0> { + inner: AtomicBool, + // This value needs to be pinned to justify the INVARIANT: comment in `AtomicTracker::new`. + _pin: PhantomPinned, +} + +impl<const ID: u64> AtomicTracker<ID> { + /// Creates a new initializer for this type. + pub fn new() -> impl PinInit<Self> { + // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will + // not be constructed in an `Arc` that already has a `ListArc`. + Self { + inner: AtomicBool::new(false), + _pin: PhantomPinned, + } + } + + fn project_inner(self: Pin<&mut Self>) -> &mut AtomicBool { + // SAFETY: The `inner` field is not structurally pinned, so we may obtain a mutable + // reference to it even if we only have a pinned reference to `self`. + unsafe { &mut Pin::into_inner_unchecked(self).inner } + } +} + +impl<const ID: u64> ListArcSafe<ID> for AtomicTracker<ID> { + unsafe fn on_create_list_arc_from_unique(self: Pin<&mut Self>) { + // INVARIANT: We just created a ListArc, so the boolean should be true. + *self.project_inner().get_mut() = true; + } + + unsafe fn on_drop_list_arc(&self) { + // INVARIANT: We just dropped a ListArc, so the boolean should be false. + self.inner.store(false, Ordering::Release); + } +} + +// SAFETY: If this method returns `true`, then by the type invariant there is no `ListArc` before +// this call, so it is okay to create a new `ListArc`. +// +// The acquire ordering will synchronize with the release store from the destruction of any +// previous `ListArc`, so if there was a previous `ListArc`, then the destruction of the previous +// `ListArc` happens-before the creation of the new `ListArc`. 
+unsafe impl<const ID: u64> TryNewListArc<ID> for AtomicTracker<ID> {
+    fn try_new_list_arc(&self) -> bool {
+        // INVARIANT: If this method returns true, then the boolean used to be false, and is no
+        // longer false, so it is okay for the caller to create a new [`ListArc`].
+        self.inner
+            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
+            .is_ok()
+    }
+}
diff --git a/rust/kernel/list/arc_field.rs b/rust/kernel/list/arc_field.rs
new file mode 100644
index 000000000000..c4b9dd503982
--- /dev/null
+++ b/rust/kernel/list/arc_field.rs
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: GPL-2.0
+
+// Copyright (C) 2024 Google LLC.
+
+//! A field that is exclusively owned by a [`ListArc`].
+//!
+//! This can be used to have a reference counted struct where one of the reference counted
+//! pointers has exclusive access to a field of the struct.
+//!
+//! [`ListArc`]: crate::list::ListArc
+
+use core::cell::UnsafeCell;
+
+/// A field owned by a specific [`ListArc`].
+///
+/// [`ListArc`]: crate::list::ListArc
+pub struct ListArcField<T, const ID: u64 = 0> {
+    value: UnsafeCell<T>,
+}
+
+// SAFETY: If the inner type is thread-safe, then it's also okay for `ListArcField` to be
+// thread-safe.
+unsafe impl<T: Send + Sync, const ID: u64> Send for ListArcField<T, ID> {}
+// SAFETY: If the inner type is thread-safe, then it's also okay for `ListArcField` to be
+// thread-safe.
+unsafe impl<T: Send + Sync, const ID: u64> Sync for ListArcField<T, ID> {}
+
+impl<T, const ID: u64> ListArcField<T, ID> {
+    /// Creates a new `ListArcField`.
+    pub fn new(value: T) -> Self {
+        Self {
+            value: UnsafeCell::new(value),
+        }
+    }
+
+    /// Access the value when we have exclusive access to the `ListArcField`.
+    ///
+    /// This allows access to the field using a `UniqueArc` instead of a `ListArc`.
+    pub fn get_mut(&mut self) -> &mut T {
+        self.value.get_mut()
+    }
+
+    /// Unsafely assert that you have shared access to the `ListArc` for this field.
+    ///
+    /// # Safety
+    ///
+    /// The caller must have shared access to the `ListArc<ID>` containing the struct with this
+    /// field for the duration of the returned reference.
+    pub unsafe fn assert_ref(&self) -> &T {
+        // SAFETY: The caller has shared access to the `ListArc`, so they also have shared access
+        // to this field.
+        unsafe { &*self.value.get() }
+    }
+
+    /// Unsafely assert that you have mutable access to the `ListArc` for this field.
+    ///
+    /// # Safety
+    ///
+    /// The caller must have mutable access to the `ListArc<ID>` containing the struct with this
+    /// field for the duration of the returned reference.
+    #[expect(clippy::mut_from_ref)]
+    pub unsafe fn assert_mut(&self) -> &mut T {
+        // SAFETY: The caller has exclusive access to the `ListArc`, so they also have exclusive
+        // access to this field.
+        unsafe { &mut *self.value.get() }
+    }
+}
+
+/// Defines getters for a [`ListArcField`].
+#[macro_export]
+macro_rules! define_list_arc_field_getter {
+    ($pub:vis fn $name:ident(&self $(<$id:tt>)?) -> &$typ:ty { $field:ident }
+        $($rest:tt)*
+    ) => {
+        $pub fn $name<'a>(self: &'a $crate::list::ListArc<Self $(, $id)?>) -> &'a $typ {
+            let field = &(&**self).$field;
+            // SAFETY: We have a shared reference to the `ListArc`.
+            unsafe { $crate::list::ListArcField::<$typ $(, $id)?>::assert_ref(field) }
+        }
+
+        $crate::list::define_list_arc_field_getter!($($rest)*);
+    };
+
+    ($pub:vis fn $name:ident(&mut self $(<$id:tt>)?)
-> &mut $typ:ty { $field:ident } + $($rest:tt)* + ) => { + $pub fn $name<'a>(self: &'a mut $crate::list::ListArc<Self $(, $id)?>) -> &'a mut $typ { + let field = &(&**self).$field; + // SAFETY: We have a mutable reference to the `ListArc`. + unsafe { $crate::list::ListArcField::<$typ $(, $id)?>::assert_mut(field) } + } + + $crate::list::define_list_arc_field_getter!($($rest)*); + }; + + () => {}; +} +pub use define_list_arc_field_getter; diff --git a/rust/kernel/list/impl_list_item_mod.rs b/rust/kernel/list/impl_list_item_mod.rs new file mode 100644 index 000000000000..202bc6f97c13 --- /dev/null +++ b/rust/kernel/list/impl_list_item_mod.rs @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Helpers for implementing list traits safely. + +/// Declares that this type has a [`ListLinks<ID>`] field. +/// +/// This trait is only used to help implement [`ListItem`] safely. If [`ListItem`] is implemented +/// manually, then this trait is not needed. Use the [`impl_has_list_links!`] macro to implement +/// this trait. +/// +/// # Safety +/// +/// The methods on this trait must have exactly the behavior that the definitions given below have. +/// +/// [`ListLinks<ID>`]: crate::list::ListLinks +/// [`ListItem`]: crate::list::ListItem +pub unsafe trait HasListLinks<const ID: u64 = 0> { + /// Returns a pointer to the [`ListLinks<ID>`] field. + /// + /// # Safety + /// + /// The provided pointer must point at a valid struct of type `Self`. + /// + /// [`ListLinks<ID>`]: crate::list::ListLinks + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut crate::list::ListLinks<ID>; +} + +/// Implements the [`HasListLinks`] trait for the given type. +#[macro_export] +macro_rules! impl_has_list_links { + ($(impl$({$($generics:tt)*})? + HasListLinks$(<$id:tt>)? + for $self:ty + { self$(.$field:ident)* } + )*) => {$( + // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the + // right type. + unsafe impl$(<$($generics)*>)? $crate::list::HasListLinks$(<$id>)? for $self { + #[inline] + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { + // Statically ensure that `$(.field)*` doesn't follow any pointers. + // + // Cannot be `const` because `$self` may contain generics and E0401 says constants + // "can't use {`Self`,generic parameters} from outer item". + if false { let _: usize = ::core::mem::offset_of!(Self, $($field).*); } + + // SAFETY: The caller promises that the pointer is not dangling. We know that this + // expression doesn't follow any pointers, as the `offset_of!` invocation above + // would otherwise not compile. + unsafe { ::core::ptr::addr_of_mut!((*ptr)$(.$field)*) } + } + } + )*}; +} +pub use impl_has_list_links; + +/// Declares that the [`ListLinks<ID>`] field in this struct is inside a +/// [`ListLinksSelfPtr<T, ID>`]. +/// +/// # Safety +/// +/// The [`ListLinks<ID>`] field of this struct at [`HasListLinks<ID>::raw_get_list_links`] must be +/// inside a [`ListLinksSelfPtr<T, ID>`]. +/// +/// [`ListLinks<ID>`]: crate::list::ListLinks +/// [`ListLinksSelfPtr<T, ID>`]: crate::list::ListLinksSelfPtr +pub unsafe trait HasSelfPtr<T: ?Sized, const ID: u64 = 0> +where + Self: HasListLinks<ID>, +{ +} + +/// Implements the [`HasListLinks`] and [`HasSelfPtr`] traits for the given type. +#[macro_export] +macro_rules! impl_has_list_links_self_ptr { + ($(impl$({$($generics:tt)*})? 
+ HasSelfPtr<$item_type:ty $(, $id:tt)?> + for $self:ty + { self$(.$field:ident)* } + )*) => {$( + // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the + // right type. + unsafe impl$(<$($generics)*>)? $crate::list::HasSelfPtr<$item_type $(, $id)?> for $self {} + + unsafe impl$(<$($generics)*>)? $crate::list::HasListLinks$(<$id>)? for $self { + #[inline] + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { + // SAFETY: The caller promises that the pointer is not dangling. + let ptr: *mut $crate::list::ListLinksSelfPtr<$item_type $(, $id)?> = + unsafe { ::core::ptr::addr_of_mut!((*ptr)$(.$field)*) }; + ptr.cast() + } + } + )*}; +} +pub use impl_has_list_links_self_ptr; + +/// Implements the [`ListItem`] trait for the given type. +/// +/// Requires that the type implements [`HasListLinks`]. Use the [`impl_has_list_links!`] macro to +/// implement that trait. +/// +/// [`ListItem`]: crate::list::ListItem +/// +/// # Examples +/// +/// ``` +/// #[pin_data] +/// struct SimpleListItem { +/// value: u32, +/// #[pin] +/// links: kernel::list::ListLinks, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for SimpleListItem { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for SimpleListItem { using ListLinks { self.links }; } +/// } +/// +/// struct ListLinksHolder { +/// inner: kernel::list::ListLinks, +/// } +/// +/// #[pin_data] +/// struct ComplexListItem<T, U> { +/// value: Result<T, U>, +/// #[pin] +/// links: ListLinksHolder, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl{T, U} ListArcSafe<0> for ComplexListItem<T, U> { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl{T, U} ListItem<0> for ComplexListItem<T, U> { using ListLinks { self.links.inner }; } +/// } +/// ``` +/// +/// ``` +/// #[pin_data] +/// struct SimpleListItem { +/// value: u32, +/// #[pin] +/// links: kernel::list::ListLinksSelfPtr<SimpleListItem>, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for SimpleListItem { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for SimpleListItem { using ListLinksSelfPtr { self.links }; } +/// } +/// +/// struct ListLinksSelfPtrHolder<T, U> { +/// inner: kernel::list::ListLinksSelfPtr<ComplexListItem<T, U>>, +/// } +/// +/// #[pin_data] +/// struct ComplexListItem<T, U> { +/// value: Result<T, U>, +/// #[pin] +/// links: ListLinksSelfPtrHolder<T, U>, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl{T, U} ListArcSafe<0> for ComplexListItem<T, U> { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl{T, U} ListItem<0> for ComplexListItem<T, U> { +/// using ListLinksSelfPtr { self.links.inner }; +/// } +/// } +/// ``` +#[macro_export] +macro_rules! impl_list_item { + ( + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $self:ty { + using ListLinks { self$(.$field:ident)* }; + })* + ) => {$( + $crate::list::impl_has_list_links! { + impl$({$($generics)*})? HasListLinks<$num> for $self { self$(.$field)* } + } + + // SAFETY: See GUARANTEES comment on each method. + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $self { + // GUARANTEES: + // * This returns the same pointer as `prepare_to_insert` because `prepare_to_insert` + // is implemented in terms of `view_links`. + // * By the type invariants of `ListLinks`, the `ListLinks` has two null pointers when + // this value is not in a list. 
+ unsafe fn view_links(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller guarantees that `me` points at a valid value of type `Self`. + unsafe { + <Self as $crate::list::HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) + } + } + + // GUARANTEES: + // * `me` originates from the most recent call to `prepare_to_insert`, which calls + // `raw_get_list_link`, which is implemented using `addr_of_mut!((*self)$(.$field)*)`. + // This method uses `container_of` to perform the inverse operation, so it returns the + // pointer originally passed to `prepare_to_insert`. + // * The pointer remains valid until the next call to `post_remove` because the caller + // of the most recent call to `prepare_to_insert` promised to retain ownership of the + // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot + // be destroyed while a `ListArc` reference exists. + unsafe fn view_value(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it + // points at the field `$field` in a value of type `Self`. Thus, reversing that + // operation is still in-bounds of the allocation. + $crate::container_of!(me, Self, $($field).*) + } + + // GUARANTEES: + // This implementation of `ListItem` will not give out exclusive access to the same + // `ListLinks` several times because calls to `prepare_to_insert` and `post_remove` + // must alternate and exclusive access is given up when `post_remove` is called. + // + // Other invocations of `impl_list_item!` also cannot give out exclusive access to the + // same `ListLinks` because you can only implement `ListItem` once for each value of + // `ID`, and the `ListLinks` fields only work with the specified `ID`. + unsafe fn prepare_to_insert(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value. + unsafe { <Self as $crate::list::ListItem<$num>>::view_links(me) } + } + + // GUARANTEES: + // * `me` originates from the most recent call to `prepare_to_insert`, which calls + // `raw_get_list_link`, which is implemented using `addr_of_mut!((*self)$(.$field)*)`. + // This method uses `container_of` to perform the inverse operation, so it returns the + // pointer originally passed to `prepare_to_insert`. + unsafe fn post_remove(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it + // points at the field `$field` in a value of type `Self`. Thus, reversing that + // operation is still in-bounds of the allocation. + $crate::container_of!(me, Self, $($field).*) + } + } + )*}; + + ( + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $self:ty { + using ListLinksSelfPtr { self$(.$field:ident)* }; + })* + ) => {$( + $crate::list::impl_has_list_links_self_ptr! { + impl$({$($generics)*})? HasSelfPtr<$self> for $self { self$(.$field)* } + } + + // SAFETY: See GUARANTEES comment on each method. + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $self { + // GUARANTEES: + // This implementation of `ListItem` will not give out exclusive access to the same + // `ListLinks` several times because calls to `prepare_to_insert` and `post_remove` + // must alternate and exclusive access is given up when `post_remove` is called. 
+ // + // Other invocations of `impl_list_item!` also cannot give out exclusive access to the + // same `ListLinks` because you can only implement `ListItem` once for each value of + // `ID`, and the `ListLinks` fields only work with the specified `ID`. + unsafe fn prepare_to_insert(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value of type `Self`. + let links_field = unsafe { <Self as $crate::list::ListItem<$num>>::view_links(me) }; + + let container = $crate::container_of!( + links_field, $crate::list::ListLinksSelfPtr<Self, $num>, inner + ); + + // SAFETY: By the same reasoning above, `links_field` is a valid pointer. + let self_ptr = unsafe { + $crate::list::ListLinksSelfPtr::raw_get_self_ptr(container) + }; + + let cell_inner = $crate::types::Opaque::cast_into(self_ptr); + + // SAFETY: This value is not accessed in any other places than `prepare_to_insert`, + // `post_remove`, or `view_value`. By the safety requirements of those methods, + // none of these three methods may be called in parallel with this call to + // `prepare_to_insert`, so this write will not race with any other access to the + // value. + unsafe { ::core::ptr::write(cell_inner, me) }; + + links_field + } + + // GUARANTEES: + // * This returns the same pointer as `prepare_to_insert` because `prepare_to_insert` + // returns the return value of `view_links`. + // * By the type invariants of `ListLinks`, the `ListLinks` has two null pointers when + // this value is not in a list. + unsafe fn view_links(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value of type `Self`. + unsafe { + <Self as $crate::list::HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) + } + } + + // This function is also used as the implementation of `post_remove`, so the caller + // may choose to satisfy the safety requirements of `post_remove` instead of the safety + // requirements for `view_value`. + // + // GUARANTEES: (always) + // * This returns the same pointer as the one passed to the most recent call to + // `prepare_to_insert` since that call wrote that pointer to this location. The value + // is only modified in `prepare_to_insert`, so it has not been modified since the + // most recent call. + // + // GUARANTEES: (only when using the `view_value` safety requirements) + // * The pointer remains valid until the next call to `post_remove` because the caller + // of the most recent call to `prepare_to_insert` promised to retain ownership of the + // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot + // be destroyed while a `ListArc` reference exists. + unsafe fn view_value(links_field: *mut $crate::list::ListLinks<$num>) -> *const Self { + let container = $crate::container_of!( + links_field, $crate::list::ListLinksSelfPtr<Self, $num>, inner + ); + + // SAFETY: By the same reasoning above, `links_field` is a valid pointer. + let self_ptr = unsafe { + $crate::list::ListLinksSelfPtr::raw_get_self_ptr(container) + }; + + let cell_inner = $crate::types::Opaque::cast_into(self_ptr); + + // SAFETY: This is not a data race, because the only function that writes to this + // value is `prepare_to_insert`, but by the safety requirements the + // `prepare_to_insert` method may not be called in parallel with `view_value` or + // `post_remove`. 
+ unsafe { ::core::ptr::read(cell_inner) } + } + + // GUARANTEES: + // The first guarantee of `view_value` is exactly what `post_remove` guarantees. + unsafe fn post_remove(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + // SAFETY: This specific implementation of `view_value` allows the caller to + // promise the safety requirements of `post_remove` instead of the safety + // requirements for `view_value`. + unsafe { <Self as $crate::list::ListItem<$num>>::view_value(me) } + } + } + )*}; +} +pub use impl_list_item; diff --git a/rust/kernel/maple_tree.rs b/rust/kernel/maple_tree.rs new file mode 100644 index 000000000000..e72eec56bf57 --- /dev/null +++ b/rust/kernel/maple_tree.rs @@ -0,0 +1,647 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Maple trees. +//! +//! C header: [`include/linux/maple_tree.h`](srctree/include/linux/maple_tree.h) +//! +//! Reference: <https://docs.kernel.org/core-api/maple_tree.html> + +use core::{ + marker::PhantomData, + ops::{Bound, RangeBounds}, + ptr, +}; + +use kernel::{ + alloc::Flags, + error::to_result, + prelude::*, + types::{ForeignOwnable, Opaque}, +}; + +/// A maple tree optimized for storing non-overlapping ranges. +/// +/// # Invariants +/// +/// Each range in the maple tree owns an instance of `T`. +#[pin_data(PinnedDrop)] +#[repr(transparent)] +pub struct MapleTree<T: ForeignOwnable> { + #[pin] + tree: Opaque<bindings::maple_tree>, + _p: PhantomData<T>, +} + +/// A maple tree with `MT_FLAGS_ALLOC_RANGE` set. +/// +/// All methods on [`MapleTree`] are also accessible on this type. +#[pin_data] +#[repr(transparent)] +pub struct MapleTreeAlloc<T: ForeignOwnable> { + #[pin] + tree: MapleTree<T>, +} + +// Make MapleTree methods usable on MapleTreeAlloc. +impl<T: ForeignOwnable> core::ops::Deref for MapleTreeAlloc<T> { + type Target = MapleTree<T>; + + #[inline] + fn deref(&self) -> &MapleTree<T> { + &self.tree + } +} + +#[inline] +fn to_maple_range(range: impl RangeBounds<usize>) -> Option<(usize, usize)> { + let first = match range.start_bound() { + Bound::Included(start) => *start, + Bound::Excluded(start) => start.checked_add(1)?, + Bound::Unbounded => 0, + }; + + let last = match range.end_bound() { + Bound::Included(end) => *end, + Bound::Excluded(end) => end.checked_sub(1)?, + Bound::Unbounded => usize::MAX, + }; + + if last < first { + return None; + } + + Some((first, last)) +} + +impl<T: ForeignOwnable> MapleTree<T> { + /// Create a new maple tree. + /// + /// The tree will use the regular implementation with a higher branching factor, rather than + /// the allocation tree. + #[inline] + pub fn new() -> impl PinInit<Self> { + pin_init!(MapleTree { + // SAFETY: This initializes a maple tree into a pinned slot. The maple tree will be + // destroyed in Drop before the memory location becomes invalid. + tree <- Opaque::ffi_init(|slot| unsafe { bindings::mt_init_flags(slot, 0) }), + _p: PhantomData, + }) + } + + /// Insert the value at the given index. + /// + /// # Errors + /// + /// If the maple tree already contains a range using the given index, then this call will + /// return an [`InsertErrorKind::Occupied`]. It may also fail if memory allocation fails. + /// + /// # Examples + /// + /// ``` + /// use kernel::maple_tree::{InsertErrorKind, MapleTree}; + /// + /// let tree = KBox::pin_init(MapleTree::<KBox<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = KBox::new(10, GFP_KERNEL)?; + /// let twenty = KBox::new(20, GFP_KERNEL)?; + /// let the_answer = KBox::new(42, GFP_KERNEL)?; + /// + /// // These calls will succeed. 
+ /// tree.insert(100, ten, GFP_KERNEL)?; + /// tree.insert(101, twenty, GFP_KERNEL)?; + /// + /// // This will fail because the index is already in use. + /// assert_eq!( + /// tree.insert(100, the_answer, GFP_KERNEL).unwrap_err().cause, + /// InsertErrorKind::Occupied, + /// ); + /// # Ok::<_, Error>(()) + /// ``` + #[inline] + pub fn insert(&self, index: usize, value: T, gfp: Flags) -> Result<(), InsertError<T>> { + self.insert_range(index..=index, value, gfp) + } + + /// Insert a value to the specified range, failing on overlap. + /// + /// This accepts the usual types of Rust ranges using the `..` and `..=` syntax for exclusive + /// and inclusive ranges respectively. The range must not be empty, and must not overlap with + /// any existing range. + /// + /// # Errors + /// + /// If the maple tree already contains an overlapping range, then this call will return an + /// [`InsertErrorKind::Occupied`]. It may also fail if memory allocation fails or if the + /// requested range is invalid (e.g. empty). + /// + /// # Examples + /// + /// ``` + /// use kernel::maple_tree::{InsertErrorKind, MapleTree}; + /// + /// let tree = KBox::pin_init(MapleTree::<KBox<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = KBox::new(10, GFP_KERNEL)?; + /// let twenty = KBox::new(20, GFP_KERNEL)?; + /// let the_answer = KBox::new(42, GFP_KERNEL)?; + /// let hundred = KBox::new(100, GFP_KERNEL)?; + /// + /// // Insert the value 10 at the indices 100 to 499. + /// tree.insert_range(100..500, ten, GFP_KERNEL)?; + /// + /// // Insert the value 20 at the indices 500 to 1000. + /// tree.insert_range(500..=1000, twenty, GFP_KERNEL)?; + /// + /// // This will fail due to overlap with the previous range on index 1000. + /// assert_eq!( + /// tree.insert_range(1000..1200, the_answer, GFP_KERNEL).unwrap_err().cause, + /// InsertErrorKind::Occupied, + /// ); + /// + /// // When using .. to specify the range, you must be careful to ensure that the range is + /// // non-empty. + /// assert_eq!( + /// tree.insert_range(72..72, hundred, GFP_KERNEL).unwrap_err().cause, + /// InsertErrorKind::InvalidRequest, + /// ); + /// # Ok::<_, Error>(()) + /// ``` + pub fn insert_range<R>(&self, range: R, value: T, gfp: Flags) -> Result<(), InsertError<T>> + where + R: RangeBounds<usize>, + { + let Some((first, last)) = to_maple_range(range) else { + return Err(InsertError { + value, + cause: InsertErrorKind::InvalidRequest, + }); + }; + + let ptr = T::into_foreign(value); + + // SAFETY: The tree is valid, and we are passing a pointer to an owned instance of `T`. + let res = to_result(unsafe { + bindings::mtree_insert_range(self.tree.get(), first, last, ptr, gfp.as_raw()) + }); + + if let Err(err) = res { + // SAFETY: As `mtree_insert_range` failed, it is safe to take back ownership. + let value = unsafe { T::from_foreign(ptr) }; + + let cause = if err == ENOMEM { + InsertErrorKind::AllocError(kernel::alloc::AllocError) + } else if err == EEXIST { + InsertErrorKind::Occupied + } else { + InsertErrorKind::InvalidRequest + }; + Err(InsertError { value, cause }) + } else { + Ok(()) + } + } + + /// Erase the range containing the given index. 
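+    ///
+    /// Returns ownership of the value stored in that range, or `None` if no range contains the
+    /// given index.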
+ /// + /// # Examples + /// + /// ``` + /// use kernel::maple_tree::MapleTree; + /// + /// let tree = KBox::pin_init(MapleTree::<KBox<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = KBox::new(10, GFP_KERNEL)?; + /// let twenty = KBox::new(20, GFP_KERNEL)?; + /// + /// tree.insert_range(100..500, ten, GFP_KERNEL)?; + /// tree.insert(67, twenty, GFP_KERNEL)?; + /// + /// assert_eq!(tree.erase(67).map(|v| *v), Some(20)); + /// assert_eq!(tree.erase(275).map(|v| *v), Some(10)); + /// + /// // The previous call erased the entire range, not just index 275. + /// assert!(tree.erase(127).is_none()); + /// # Ok::<_, Error>(()) + /// ``` + #[inline] + pub fn erase(&self, index: usize) -> Option<T> { + // SAFETY: `self.tree` contains a valid maple tree. + let ret = unsafe { bindings::mtree_erase(self.tree.get(), index) }; + + // SAFETY: If the pointer is not null, then we took ownership of a valid instance of `T` + // from the tree. + unsafe { T::try_from_foreign(ret) } + } + + /// Lock the internal spinlock. + #[inline] + pub fn lock(&self) -> MapleGuard<'_, T> { + // SAFETY: It's safe to lock the spinlock in a maple tree. + unsafe { bindings::spin_lock(self.ma_lock()) }; + + // INVARIANT: We just took the spinlock. + MapleGuard(self) + } + + #[inline] + fn ma_lock(&self) -> *mut bindings::spinlock_t { + // SAFETY: This pointer offset operation stays in-bounds. + let lock_ptr = unsafe { &raw mut (*self.tree.get()).__bindgen_anon_1.ma_lock }; + lock_ptr.cast() + } + + /// Free all `T` instances in this tree. + /// + /// # Safety + /// + /// This frees Rust data referenced by the maple tree without removing it from the maple tree, + /// leaving it in an invalid state. The caller must ensure that this invalid state cannot be + /// observed by the end-user. + unsafe fn free_all_entries(self: Pin<&mut Self>) { + // SAFETY: The caller provides exclusive access to the entire maple tree, so we have + // exclusive access to the entire maple tree despite not holding the lock. + let mut ma_state = unsafe { MaState::new_raw(self.into_ref().get_ref(), 0, usize::MAX) }; + + loop { + // This uses the raw accessor because we're destroying pointers without removing them + // from the maple tree, which is only valid because this is the destructor. + let ptr = ma_state.mas_find_raw(usize::MAX); + if ptr.is_null() { + break; + } + // SAFETY: By the type invariants, this pointer references a valid value of type `T`. + // By the safety requirements, it is okay to free it without removing it from the maple + // tree. + drop(unsafe { T::from_foreign(ptr) }); + } + } +} + +#[pinned_drop] +impl<T: ForeignOwnable> PinnedDrop for MapleTree<T> { + #[inline] + fn drop(mut self: Pin<&mut Self>) { + // We only iterate the tree if the Rust value has a destructor. + if core::mem::needs_drop::<T>() { + // SAFETY: Other than the below `mtree_destroy` call, the tree will not be accessed + // after this call. + unsafe { self.as_mut().free_all_entries() }; + } + + // SAFETY: The tree is valid, and will not be accessed after this call. + unsafe { bindings::mtree_destroy(self.tree.get()) }; + } +} + +/// A reference to a [`MapleTree`] that owns the inner lock. +/// +/// # Invariants +/// +/// This guard owns the inner spinlock. +#[must_use = "if unused, the lock will be immediately unlocked"] +pub struct MapleGuard<'tree, T: ForeignOwnable>(&'tree MapleTree<T>); + +impl<'tree, T: ForeignOwnable> Drop for MapleGuard<'tree, T> { + #[inline] + fn drop(&mut self) { + // SAFETY: By the type invariants, we hold this spinlock. 
+ unsafe { bindings::spin_unlock(self.0.ma_lock()) }; + } +} + +impl<'tree, T: ForeignOwnable> MapleGuard<'tree, T> { + /// Create a [`MaState`] protected by this lock guard. + pub fn ma_state(&mut self, first: usize, end: usize) -> MaState<'_, T> { + // SAFETY: The `MaState` borrows this `MapleGuard`, so it can also borrow the `MapleGuard`s + // read/write permissions to the maple tree. + unsafe { MaState::new_raw(self.0, first, end) } + } + + /// Load the value at the given index. + /// + /// # Examples + /// + /// Read the value while holding the spinlock. + /// + /// ``` + /// use kernel::maple_tree::MapleTree; + /// + /// let tree = KBox::pin_init(MapleTree::<KBox<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = KBox::new(10, GFP_KERNEL)?; + /// let twenty = KBox::new(20, GFP_KERNEL)?; + /// tree.insert(100, ten, GFP_KERNEL)?; + /// tree.insert(200, twenty, GFP_KERNEL)?; + /// + /// let mut lock = tree.lock(); + /// assert_eq!(lock.load(100).map(|v| *v), Some(10)); + /// assert_eq!(lock.load(200).map(|v| *v), Some(20)); + /// assert_eq!(lock.load(300).map(|v| *v), None); + /// # Ok::<_, Error>(()) + /// ``` + /// + /// Increment refcount under the lock, to keep value alive afterwards. + /// + /// ``` + /// use kernel::maple_tree::MapleTree; + /// use kernel::sync::Arc; + /// + /// let tree = KBox::pin_init(MapleTree::<Arc<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = Arc::new(10, GFP_KERNEL)?; + /// let twenty = Arc::new(20, GFP_KERNEL)?; + /// tree.insert(100, ten, GFP_KERNEL)?; + /// tree.insert(200, twenty, GFP_KERNEL)?; + /// + /// // Briefly take the lock to increment the refcount. + /// let value = tree.lock().load(100).map(Arc::from); + /// + /// // At this point, another thread might remove the value. + /// tree.erase(100); + /// + /// // But we can still access it because we took a refcount. + /// assert_eq!(value.map(|v| *v), Some(10)); + /// # Ok::<_, Error>(()) + /// ``` + #[inline] + pub fn load(&mut self, index: usize) -> Option<T::BorrowedMut<'_>> { + // SAFETY: `self.tree` contains a valid maple tree. + let ret = unsafe { bindings::mtree_load(self.0.tree.get(), index) }; + if ret.is_null() { + return None; + } + + // SAFETY: If the pointer is not null, then it references a valid instance of `T`. It is + // safe to borrow the instance mutably because the signature of this function enforces that + // the mutable borrow is not used after the spinlock is dropped. + Some(unsafe { T::borrow_mut(ret) }) + } +} + +impl<T: ForeignOwnable> MapleTreeAlloc<T> { + /// Create a new allocation tree. + pub fn new() -> impl PinInit<Self> { + let tree = pin_init!(MapleTree { + // SAFETY: This initializes a maple tree into a pinned slot. The maple tree will be + // destroyed in Drop before the memory location becomes invalid. + tree <- Opaque::ffi_init(|slot| unsafe { + bindings::mt_init_flags(slot, bindings::MT_FLAGS_ALLOC_RANGE) + }), + _p: PhantomData, + }); + + pin_init!(MapleTreeAlloc { tree <- tree }) + } + + /// Insert an entry with the given size somewhere in the given range. + /// + /// The maple tree will search for a location in the given range where there is space to insert + /// the new range. If there is not enough available space, then an error will be returned. + /// + /// The index of the new range is returned. 
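+    ///
+    /// The search proceeds upward from the start of the given range, so successive allocations
+    /// pack toward its low end, as the example below illustrates.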
+ /// + /// # Examples + /// + /// ``` + /// use kernel::maple_tree::{MapleTreeAlloc, AllocErrorKind}; + /// + /// let tree = KBox::pin_init(MapleTreeAlloc::<KBox<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = KBox::new(10, GFP_KERNEL)?; + /// let twenty = KBox::new(20, GFP_KERNEL)?; + /// let thirty = KBox::new(30, GFP_KERNEL)?; + /// let hundred = KBox::new(100, GFP_KERNEL)?; + /// + /// // Allocate three ranges. + /// let idx1 = tree.alloc_range(100, ten, ..1000, GFP_KERNEL)?; + /// let idx2 = tree.alloc_range(100, twenty, ..1000, GFP_KERNEL)?; + /// let idx3 = tree.alloc_range(100, thirty, ..1000, GFP_KERNEL)?; + /// + /// assert_eq!(idx1, 0); + /// assert_eq!(idx2, 100); + /// assert_eq!(idx3, 200); + /// + /// // This will fail because the remaining space is too small. + /// assert_eq!( + /// tree.alloc_range(800, hundred, ..1000, GFP_KERNEL).unwrap_err().cause, + /// AllocErrorKind::Busy, + /// ); + /// # Ok::<_, Error>(()) + /// ``` + pub fn alloc_range<R>( + &self, + size: usize, + value: T, + range: R, + gfp: Flags, + ) -> Result<usize, AllocError<T>> + where + R: RangeBounds<usize>, + { + let Some((min, max)) = to_maple_range(range) else { + return Err(AllocError { + value, + cause: AllocErrorKind::InvalidRequest, + }); + }; + + let ptr = T::into_foreign(value); + let mut index = 0; + + // SAFETY: The tree is valid, and we are passing a pointer to an owned instance of `T`. + let res = to_result(unsafe { + bindings::mtree_alloc_range( + self.tree.tree.get(), + &mut index, + ptr, + size, + min, + max, + gfp.as_raw(), + ) + }); + + if let Err(err) = res { + // SAFETY: As `mtree_alloc_range` failed, it is safe to take back ownership. + let value = unsafe { T::from_foreign(ptr) }; + + let cause = if err == ENOMEM { + AllocErrorKind::AllocError(kernel::alloc::AllocError) + } else if err == EBUSY { + AllocErrorKind::Busy + } else { + AllocErrorKind::InvalidRequest + }; + Err(AllocError { value, cause }) + } else { + Ok(index) + } + } +} + +/// A helper type used for navigating a [`MapleTree`]. +/// +/// # Invariants +/// +/// For the duration of `'tree`: +/// +/// * The `ma_state` references a valid `MapleTree<T>`. +/// * The `ma_state` has read/write access to the tree. +pub struct MaState<'tree, T: ForeignOwnable> { + state: bindings::ma_state, + _phantom: PhantomData<&'tree mut MapleTree<T>>, +} + +impl<'tree, T: ForeignOwnable> MaState<'tree, T> { + /// Initialize a new `MaState` with the given tree. + /// + /// # Safety + /// + /// The caller must ensure that this `MaState` has read/write access to the maple tree. + #[inline] + unsafe fn new_raw(mt: &'tree MapleTree<T>, first: usize, end: usize) -> Self { + // INVARIANT: + // * Having a reference ensures that the `MapleTree<T>` is valid for `'tree`. + // * The caller ensures that we have read/write access. + Self { + state: bindings::ma_state { + tree: mt.tree.get(), + index: first, + last: end, + node: ptr::null_mut(), + status: bindings::maple_status_ma_start, + min: 0, + max: usize::MAX, + alloc: ptr::null_mut(), + mas_flags: 0, + store_type: bindings::store_type_wr_invalid, + ..Default::default() + }, + _phantom: PhantomData, + } + } + + #[inline] + fn as_raw(&mut self) -> *mut bindings::ma_state { + &raw mut self.state + } + + #[inline] + fn mas_find_raw(&mut self, max: usize) -> *mut c_void { + // SAFETY: By the type invariants, the `ma_state` is active and we have read/write access + // to the tree. + unsafe { bindings::mas_find(self.as_raw(), max) } + } + + /// Find the next entry in the maple tree. 
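+    ///
+    /// On success, the position of the `MaState` advances past the returned entry, so calling
+    /// this method repeatedly iterates over the tree, as the example below illustrates.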
+ /// + /// # Examples + /// + /// Iterate the maple tree. + /// + /// ``` + /// use kernel::maple_tree::MapleTree; + /// use kernel::sync::Arc; + /// + /// let tree = KBox::pin_init(MapleTree::<Arc<i32>>::new(), GFP_KERNEL)?; + /// + /// let ten = Arc::new(10, GFP_KERNEL)?; + /// let twenty = Arc::new(20, GFP_KERNEL)?; + /// tree.insert(100, ten, GFP_KERNEL)?; + /// tree.insert(200, twenty, GFP_KERNEL)?; + /// + /// let mut ma_lock = tree.lock(); + /// let mut iter = ma_lock.ma_state(0, usize::MAX); + /// + /// assert_eq!(iter.find(usize::MAX).map(|v| *v), Some(10)); + /// assert_eq!(iter.find(usize::MAX).map(|v| *v), Some(20)); + /// assert!(iter.find(usize::MAX).is_none()); + /// # Ok::<_, Error>(()) + /// ``` + #[inline] + pub fn find(&mut self, max: usize) -> Option<T::BorrowedMut<'_>> { + let ret = self.mas_find_raw(max); + if ret.is_null() { + return None; + } + + // SAFETY: If the pointer is not null, then it references a valid instance of `T`. It's + // safe to access it mutably as the returned reference borrows this `MaState`, and the + // `MaState` has read/write access to the maple tree. + Some(unsafe { T::borrow_mut(ret) }) + } +} + +/// Error type for failure to insert a new value. +pub struct InsertError<T> { + /// The value that could not be inserted. + pub value: T, + /// The reason for the failure to insert. + pub cause: InsertErrorKind, +} + +/// The reason for the failure to insert. +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub enum InsertErrorKind { + /// There is already a value in the requested range. + Occupied, + /// Failure to allocate memory. + AllocError(kernel::alloc::AllocError), + /// The insertion request was invalid. + InvalidRequest, +} + +impl From<InsertErrorKind> for Error { + #[inline] + fn from(kind: InsertErrorKind) -> Error { + match kind { + InsertErrorKind::Occupied => EEXIST, + InsertErrorKind::AllocError(kernel::alloc::AllocError) => ENOMEM, + InsertErrorKind::InvalidRequest => EINVAL, + } + } +} + +impl<T> From<InsertError<T>> for Error { + #[inline] + fn from(insert_err: InsertError<T>) -> Error { + Error::from(insert_err.cause) + } +} + +/// Error type for failure to insert a new value. +pub struct AllocError<T> { + /// The value that could not be inserted. + pub value: T, + /// The reason for the failure to insert. + pub cause: AllocErrorKind, +} + +/// The reason for the failure to insert. +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum AllocErrorKind { + /// There is not enough space for the requested allocation. + Busy, + /// Failure to allocate memory. + AllocError(kernel::alloc::AllocError), + /// The insertion request was invalid. + InvalidRequest, +} + +impl From<AllocErrorKind> for Error { + #[inline] + fn from(kind: AllocErrorKind) -> Error { + match kind { + AllocErrorKind::Busy => EBUSY, + AllocErrorKind::AllocError(kernel::alloc::AllocError) => ENOMEM, + AllocErrorKind::InvalidRequest => EINVAL, + } + } +} + +impl<T> From<AllocError<T>> for Error { + #[inline] + fn from(insert_err: AllocError<T>) -> Error { + Error::from(insert_err.cause) + } +} diff --git a/rust/kernel/miscdevice.rs b/rust/kernel/miscdevice.rs new file mode 100644 index 000000000000..d698cddcb4a5 --- /dev/null +++ b/rust/kernel/miscdevice.rs @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Miscdevice support. +//! +//! C headers: [`include/linux/miscdevice.h`](srctree/include/linux/miscdevice.h). +//! +//! 
Reference: <https://www.kernel.org/doc/html/latest/driver-api/misc_devices.html>
+
+use crate::{
+    bindings,
+    device::Device,
+    error::{to_result, Error, Result, VTABLE_DEFAULT_ERROR},
+    ffi::{c_int, c_long, c_uint, c_ulong},
+    fs::{File, Kiocb},
+    iov::{IovIterDest, IovIterSource},
+    mm::virt::VmaNew,
+    prelude::*,
+    seq_file::SeqFile,
+    types::{ForeignOwnable, Opaque},
+};
+use core::{marker::PhantomData, mem::MaybeUninit, pin::Pin};
+
+/// Options for creating a misc device.
+#[derive(Copy, Clone)]
+pub struct MiscDeviceOptions {
+    /// The name of the miscdevice.
+    pub name: &'static CStr,
+}
+
+impl MiscDeviceOptions {
+    /// Create a raw `struct miscdevice` ready for registration.
+    pub const fn into_raw<T: MiscDevice>(self) -> bindings::miscdevice {
+        // SAFETY: All zeros is valid for this C type.
+        let mut result: bindings::miscdevice = unsafe { MaybeUninit::zeroed().assume_init() };
+        result.minor = bindings::MISC_DYNAMIC_MINOR as ffi::c_int;
+        result.name = crate::str::as_char_ptr_in_const_context(self.name);
+        result.fops = MiscdeviceVTable::<T>::build();
+        result
+    }
+}
+
+/// A registration of a miscdevice.
+///
+/// # Invariants
+///
+/// - `inner` contains a `struct miscdevice` that is registered using
+///   `misc_register()`.
+/// - This registration remains valid for the entire lifetime of the
+///   [`MiscDeviceRegistration`] instance.
+/// - Deregistration occurs exactly once in [`Drop`] via `misc_deregister()`.
+/// - `inner` wraps a valid, pinned `miscdevice` created using
+///   [`MiscDeviceOptions::into_raw`].
+#[repr(transparent)]
+#[pin_data(PinnedDrop)]
+pub struct MiscDeviceRegistration<T> {
+    #[pin]
+    inner: Opaque<bindings::miscdevice>,
+    _t: PhantomData<T>,
+}
+
+// SAFETY: It is allowed to call `misc_deregister` on a different thread from where you called
+// `misc_register`.
+unsafe impl<T> Send for MiscDeviceRegistration<T> {}
+// SAFETY: All `&self` methods on this type are written to ensure that it is safe to call them in
+// parallel.
+unsafe impl<T> Sync for MiscDeviceRegistration<T> {}
+
+impl<T: MiscDevice> MiscDeviceRegistration<T> {
+    /// Register a misc device.
+    pub fn register(opts: MiscDeviceOptions) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            inner <- Opaque::try_ffi_init(move |slot: *mut bindings::miscdevice| {
+                // SAFETY: The initializer can write to the provided `slot`.
+                unsafe { slot.write(opts.into_raw::<T>()) };
+
+                // SAFETY: We just wrote the misc device options to the slot. The miscdevice will
+                // get unregistered before `slot` is deallocated because the memory is pinned and
+                // the destructor of this type deallocates the memory.
+                // INVARIANT: If this returns `Ok(())`, then the `slot` will contain a registered
+                // misc device.
+                to_result(unsafe { bindings::misc_register(slot) })
+            }),
+            _t: PhantomData,
+        })
+    }
+
+    /// Returns a raw pointer to the misc device.
+    pub fn as_raw(&self) -> *mut bindings::miscdevice {
+        self.inner.get()
+    }
+
+    /// Access the `this_device` field.
+    pub fn device(&self) -> &Device {
+        // SAFETY: This can only be called after a successful register(), which always
+        // initialises `this_device` with a valid device. Furthermore, the signature of this
+        // function tells the borrow-checker that the `&Device` reference must not outlive the
+        // `&MiscDeviceRegistration<T>` used to obtain it, so the last use of the reference must be
+        // before the underlying `struct miscdevice` is destroyed.
+ unsafe { Device::from_raw((*self.as_raw()).this_device) } + } +} + +#[pinned_drop] +impl<T> PinnedDrop for MiscDeviceRegistration<T> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: We know that the device is registered by the type invariants. + unsafe { bindings::misc_deregister(self.inner.get()) }; + } +} + +/// Trait implemented by the private data of an open misc device. +#[vtable] +pub trait MiscDevice: Sized { + /// What kind of pointer should `Self` be wrapped in. + type Ptr: ForeignOwnable + Send + Sync; + + /// Called when the misc device is opened. + /// + /// The returned pointer will be stored as the private data for the file. + fn open(_file: &File, _misc: &MiscDeviceRegistration<Self>) -> Result<Self::Ptr>; + + /// Called when the misc device is released. + fn release(device: Self::Ptr, _file: &File) { + drop(device); + } + + /// Handle for mmap. + /// + /// This function is invoked when a user space process invokes the `mmap` system call on + /// `file`. The function is a callback that is part of the VMA initializer. The kernel will do + /// initial setup of the VMA before calling this function. The function can then interact with + /// the VMA initialization by calling methods of `vma`. If the function does not return an + /// error, the kernel will complete initialization of the VMA according to the properties of + /// `vma`. + fn mmap( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, + _vma: &VmaNew, + ) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Read from this miscdevice. + fn read_iter(_kiocb: Kiocb<'_, Self::Ptr>, _iov: &mut IovIterDest<'_>) -> Result<usize> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Write to this miscdevice. + fn write_iter(_kiocb: Kiocb<'_, Self::Ptr>, _iov: &mut IovIterSource<'_>) -> Result<usize> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Handler for ioctls. + /// + /// The `cmd` argument is usually manipulated using the utilities in [`kernel::ioctl`]. + /// + /// [`kernel::ioctl`]: mod@crate::ioctl + fn ioctl( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, + _cmd: u32, + _arg: usize, + ) -> Result<isize> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Handler for ioctls. + /// + /// Used for 32-bit userspace on 64-bit platforms. + /// + /// This method is optional and only needs to be provided if the ioctl relies on structures + /// that have different layout on 32-bit and 64-bit userspace. If no implementation is + /// provided, then `compat_ptr_ioctl` will be used instead. + #[cfg(CONFIG_COMPAT)] + fn compat_ioctl( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, + _cmd: u32, + _arg: usize, + ) -> Result<isize> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Show info for this fd. + fn show_fdinfo( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _m: &SeqFile, + _file: &File, + ) { + build_error!(VTABLE_DEFAULT_ERROR) + } +} + +/// A vtable for the file operations of a Rust miscdevice. +struct MiscdeviceVTable<T: MiscDevice>(PhantomData<T>); + +impl<T: MiscDevice> MiscdeviceVTable<T> { + /// # Safety + /// + /// `file` and `inode` must be the file and inode for a file that is undergoing initialization. + /// The file must be associated with a `MiscDeviceRegistration<T>`. + unsafe extern "C" fn open(inode: *mut bindings::inode, raw_file: *mut bindings::file) -> c_int { + // SAFETY: The pointers are valid and for a file being opened. 
+        let ret = unsafe { bindings::generic_file_open(inode, raw_file) };
+        if ret != 0 {
+            return ret;
+        }
+
+        // SAFETY: The open call of a file can access the private data.
+        let misc_ptr = unsafe { (*raw_file).private_data };
+
+        // SAFETY: This is a miscdevice, so `misc_open()` set the private data to a pointer to the
+        // associated `struct miscdevice` before calling into this method. Furthermore,
+        // `misc_open()` ensures that the miscdevice can't be unregistered and freed during this
+        // call to `fops_open`.
+        let misc = unsafe { &*misc_ptr.cast::<MiscDeviceRegistration<T>>() };
+
+        // SAFETY:
+        // * This underlying file is valid for (much longer than) the duration of `T::open`.
+        // * There is no active fdget_pos region on the file on this thread.
+        let file = unsafe { File::from_raw_file(raw_file) };
+
+        let ptr = match T::open(file, misc) {
+            Ok(ptr) => ptr,
+            Err(err) => return err.to_errno(),
+        };
+
+        // This overwrites the private data with the value specified by the user, changing the type
+        // of this file's private data. All future accesses to the private data are performed by
+        // other fops_* methods in this file, which all correctly cast the private data to the new
+        // type.
+        //
+        // SAFETY: The open call of a file can access the private data.
+        unsafe { (*raw_file).private_data = ptr.into_foreign() };
+
+        0
+    }
+
+    /// # Safety
+    ///
+    /// `file` and `inode` must be the file and inode for a file that is being released. The file
+    /// must be associated with a `MiscDeviceRegistration<T>`.
+    unsafe extern "C" fn release(_inode: *mut bindings::inode, file: *mut bindings::file) -> c_int {
+        // SAFETY: The release call of a file owns the private data.
+        let private = unsafe { (*file).private_data };
+        // SAFETY: The release call of a file owns the private data.
+        let ptr = unsafe { <T::Ptr as ForeignOwnable>::from_foreign(private) };
+
+        // SAFETY:
+        // * The file is valid for the duration of this call.
+        // * There is no active fdget_pos region on the file on this thread.
+        T::release(ptr, unsafe { File::from_raw_file(file) });
+
+        0
+    }
+
+    /// # Safety
+    ///
+    /// `kiocb` must correspond to a valid file that is associated with a
+    /// `MiscDeviceRegistration<T>`. `iter` must be a valid `struct iov_iter` for writing.
+    unsafe extern "C" fn read_iter(
+        kiocb: *mut bindings::kiocb,
+        iter: *mut bindings::iov_iter,
+    ) -> isize {
+        // SAFETY: The caller provides a valid `struct kiocb` associated with a
+        // `MiscDeviceRegistration<T>` file.
+        let kiocb = unsafe { Kiocb::from_raw(kiocb) };
+        // SAFETY: This is a valid `struct iov_iter` for writing.
+        let iov = unsafe { IovIterDest::from_raw(iter) };
+
+        match T::read_iter(kiocb, iov) {
+            Ok(res) => res as isize,
+            Err(err) => err.to_errno() as isize,
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `kiocb` must correspond to a valid file that is associated with a
+    /// `MiscDeviceRegistration<T>`. `iter` must be a valid `struct iov_iter` for reading.
+    unsafe extern "C" fn write_iter(
+        kiocb: *mut bindings::kiocb,
+        iter: *mut bindings::iov_iter,
+    ) -> isize {
+        // SAFETY: The caller provides a valid `struct kiocb` associated with a
+        // `MiscDeviceRegistration<T>` file.
+        let kiocb = unsafe { Kiocb::from_raw(kiocb) };
+        // SAFETY: This is a valid `struct iov_iter` for reading.
+        let iov = unsafe { IovIterSource::from_raw(iter) };
+
+        match T::write_iter(kiocb, iov) {
+            Ok(res) => res as isize,
+            Err(err) => err.to_errno() as isize,
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`.
+    /// `vma` must be a vma that is currently being mmap'ed with this file.
+    unsafe extern "C" fn mmap(
+        file: *mut bindings::file,
+        vma: *mut bindings::vm_area_struct,
+    ) -> c_int {
+        // SAFETY: The mmap call of a file can access the private data.
+        let private = unsafe { (*file).private_data };
+        // SAFETY: This is a Rust miscdevice, so we call `into_foreign` in `open` and
+        // `from_foreign` in `release`, and `fops_mmap` is guaranteed to be called between those
+        // two operations.
+        let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private.cast()) };
+        // SAFETY: The caller provides a vma that is undergoing initial VMA setup.
+        let area = unsafe { VmaNew::from_raw(vma) };
+        // SAFETY:
+        // * The file is valid for the duration of this call.
+        // * There is no active fdget_pos region on the file on this thread.
+        let file = unsafe { File::from_raw_file(file) };
+
+        match T::mmap(device, file, area) {
+            Ok(()) => 0,
+            Err(err) => err.to_errno(),
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`.
+    unsafe extern "C" fn ioctl(file: *mut bindings::file, cmd: c_uint, arg: c_ulong) -> c_long {
+        // SAFETY: The ioctl call of a file can access the private data.
+        let private = unsafe { (*file).private_data };
+        // SAFETY: Ioctl calls can borrow the private data of the file.
+        let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) };
+
+        // SAFETY:
+        // * The file is valid for the duration of this call.
+        // * There is no active fdget_pos region on the file on this thread.
+        let file = unsafe { File::from_raw_file(file) };
+
+        match T::ioctl(device, file, cmd, arg) {
+            Ok(ret) => ret as c_long,
+            Err(err) => err.to_errno() as c_long,
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`.
+    #[cfg(CONFIG_COMPAT)]
+    unsafe extern "C" fn compat_ioctl(
+        file: *mut bindings::file,
+        cmd: c_uint,
+        arg: c_ulong,
+    ) -> c_long {
+        // SAFETY: The compat ioctl call of a file can access the private data.
+        let private = unsafe { (*file).private_data };
+        // SAFETY: Ioctl calls can borrow the private data of the file.
+        let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) };
+
+        // SAFETY:
+        // * The file is valid for the duration of this call.
+        // * There is no active fdget_pos region on the file on this thread.
+        let file = unsafe { File::from_raw_file(file) };
+
+        match T::compat_ioctl(device, file, cmd, arg) {
+            Ok(ret) => ret as c_long,
+            Err(err) => err.to_errno() as c_long,
+        }
+    }
+
+    /// # Safety
+    ///
+    /// - `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`.
+    /// - `seq_file` must be a valid `struct seq_file` that we can write to.
+    unsafe extern "C" fn show_fdinfo(seq_file: *mut bindings::seq_file, file: *mut bindings::file) {
+        // SAFETY: The show_fdinfo call of a file can access the private data.
+        let private = unsafe { (*file).private_data };
+        // SAFETY: `show_fdinfo` calls can borrow the private data of the file.
+        let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) };
+        // SAFETY:
+        // * The file is valid for the duration of this call.
+        // * There is no active fdget_pos region on the file on this thread.
+ let file = unsafe { File::from_raw_file(file) }; + // SAFETY: The caller ensures that the pointer is valid and exclusive for the duration in + // which this method is called. + let m = unsafe { SeqFile::from_raw(seq_file) }; + + T::show_fdinfo(device, m, file); + } + + const VTABLE: bindings::file_operations = bindings::file_operations { + open: Some(Self::open), + release: Some(Self::release), + mmap: if T::HAS_MMAP { Some(Self::mmap) } else { None }, + read_iter: if T::HAS_READ_ITER { + Some(Self::read_iter) + } else { + None + }, + write_iter: if T::HAS_WRITE_ITER { + Some(Self::write_iter) + } else { + None + }, + unlocked_ioctl: if T::HAS_IOCTL { + Some(Self::ioctl) + } else { + None + }, + #[cfg(CONFIG_COMPAT)] + compat_ioctl: if T::HAS_COMPAT_IOCTL { + Some(Self::compat_ioctl) + } else if T::HAS_IOCTL { + Some(bindings::compat_ptr_ioctl) + } else { + None + }, + show_fdinfo: if T::HAS_SHOW_FDINFO { + Some(Self::show_fdinfo) + } else { + None + }, + // SAFETY: All zeros is a valid value for `bindings::file_operations`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }; + + const fn build() -> &'static bindings::file_operations { + &Self::VTABLE + } +} diff --git a/rust/kernel/mm.rs b/rust/kernel/mm.rs new file mode 100644 index 000000000000..4764d7b68f2a --- /dev/null +++ b/rust/kernel/mm.rs @@ -0,0 +1,297 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Memory management. +//! +//! This module deals with managing the address space of userspace processes. Each process has an +//! instance of [`Mm`], which keeps track of multiple VMAs (virtual memory areas). Each VMA +//! corresponds to a region of memory that the userspace process can access, and the VMA lets you +//! control what happens when userspace reads or writes to that region of memory. +//! +//! C header: [`include/linux/mm.h`](srctree/include/linux/mm.h) + +use crate::{ + bindings, + sync::aref::{ARef, AlwaysRefCounted}, + types::{NotThreadSafe, Opaque}, +}; +use core::{ops::Deref, ptr::NonNull}; + +pub mod virt; +use virt::VmaRef; + +#[cfg(CONFIG_MMU)] +pub use mmput_async::MmWithUserAsync; +mod mmput_async; + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This represents the address space of a userspace process, so each process has one `Mm` +/// instance. It may hold many VMAs internally. +/// +/// There is a counter called `mm_users` that counts the users of the address space; this includes +/// the userspace process itself, but can also include kernel threads accessing the address space. +/// Once `mm_users` reaches zero, this indicates that the address space can be destroyed. To access +/// the address space, you must prevent `mm_users` from reaching zero while you are accessing it. +/// The [`MmWithUser`] type represents an address space where this is guaranteed, and you can +/// create one using [`mmget_not_zero`]. +/// +/// The `ARef<Mm>` smart pointer holds an `mmgrab` refcount. Its destructor may sleep. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmgrab`. +/// +/// [`mmget_not_zero`]: Mm::mmget_not_zero +#[repr(transparent)] +pub struct Mm { + mm: Opaque<bindings::mm_struct>, +} + +// SAFETY: It is safe to call `mmdrop` on another thread than where `mmgrab` was called. +unsafe impl Send for Mm {} +// SAFETY: All methods on `Mm` can be called in parallel from several threads. +unsafe impl Sync for Mm {} + +// SAFETY: By the type invariants, this type is always refcounted. 
+unsafe impl AlwaysRefCounted for Mm { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmgrab(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmdrop(obj.cast().as_ptr()) }; + } +} + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This type is like [`Mm`], but with non-zero `mm_users`. It can only be used when `mm_users` can +/// be proven to be non-zero at compile-time, usually because the relevant code holds an `mmget` +/// refcount. It can be used to access the associated address space. +/// +/// The `ARef<MmWithUser>` smart pointer holds an `mmget` refcount. Its destructor may sleep. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmget`. The value of `mm_users` is non-zero. +#[repr(transparent)] +pub struct MmWithUser { + mm: Mm, +} + +// SAFETY: It is safe to call `mmput` on another thread than where `mmget` was called. +unsafe impl Send for MmWithUser {} +// SAFETY: All methods on `MmWithUser` can be called in parallel from several threads. +unsafe impl Sync for MmWithUser {} + +// SAFETY: By the type invariants, this type is always refcounted. +unsafe impl AlwaysRefCounted for MmWithUser { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmget(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmput(obj.cast().as_ptr()) }; + } +} + +// Make all `Mm` methods available on `MmWithUser`. +impl Deref for MmWithUser { + type Target = Mm; + + #[inline] + fn deref(&self) -> &Mm { + &self.mm + } +} + +// These methods are safe to call even if `mm_users` is zero. +impl Mm { + /// Returns a raw pointer to the inner `mm_struct`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::mm_struct { + self.mm.get() + } + + /// Obtain a reference from a raw pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` points at an `mm_struct`, and that it is not deallocated + /// during the lifetime 'a. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::mm_struct) -> &'a Mm { + // SAFETY: Caller promises that the pointer is valid for 'a. Layouts are compatible due to + // repr(transparent). + unsafe { &*ptr.cast() } + } + + /// Calls `mmget_not_zero` and returns a handle if it succeeds. + #[inline] + pub fn mmget_not_zero(&self) -> Option<ARef<MmWithUser>> { + // SAFETY: The pointer is valid since self is a reference. + let success = unsafe { bindings::mmget_not_zero(self.as_raw()) }; + + if success { + // SAFETY: We just created an `mmget` refcount. + Some(unsafe { ARef::from_raw(NonNull::new_unchecked(self.as_raw().cast())) }) + } else { + None + } + } +} + +// These methods require `mm_users` to be non-zero. +impl MmWithUser { + /// Obtain a reference from a raw pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` points at an `mm_struct`, and that `mm_users` remains + /// non-zero for the duration of the lifetime 'a. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::mm_struct) -> &'a MmWithUser { + // SAFETY: Caller promises that the pointer is valid for 'a. The layout is compatible due + // to repr(transparent). + unsafe { &*ptr.cast() } + } + + /// Attempt to access a vma using the vma read lock. 
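+    ///
+    /// A sketch of the trylock-then-fallback pattern described below (illustrative only):
+    ///
+    /// ```
+    /// use kernel::mm::{virt::vm_flags_t, MmWithUser};
+    ///
+    /// fn vma_flags_at(mm: &MmWithUser, addr: usize) -> Option<vm_flags_t> {
+    ///     if let Some(vma) = mm.lock_vma_under_rcu(addr) {
+    ///         return Some(vma.flags());
+    ///     }
+    ///     // Contended: fall back to the mmap read lock.
+    ///     let guard = mm.mmap_read_lock();
+    ///     guard.vma_lookup(addr).map(|vma| vma.flags())
+    /// }
+    /// ```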
+ /// + /// This is an optimistic trylock operation, so it may fail if there is contention. In that + /// case, you should fall back to taking the mmap read lock. + /// + /// When per-vma locks are disabled, this always returns `None`. + #[inline] + pub fn lock_vma_under_rcu(&self, vma_addr: usize) -> Option<VmaReadGuard<'_>> { + #[cfg(CONFIG_PER_VMA_LOCK)] + { + // SAFETY: Calling `bindings::lock_vma_under_rcu` is always okay given an mm where + // `mm_users` is non-zero. + let vma = unsafe { bindings::lock_vma_under_rcu(self.as_raw(), vma_addr) }; + if !vma.is_null() { + return Some(VmaReadGuard { + // SAFETY: If `lock_vma_under_rcu` returns a non-null ptr, then it points at a + // valid vma. The vma is stable for as long as the vma read lock is held. + vma: unsafe { VmaRef::from_raw(vma) }, + _nts: NotThreadSafe, + }); + } + } + + // Silence warnings about unused variables. + #[cfg(not(CONFIG_PER_VMA_LOCK))] + let _ = vma_addr; + + None + } + + /// Lock the mmap read lock. + #[inline] + pub fn mmap_read_lock(&self) -> MmapReadGuard<'_> { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmap_read_lock(self.as_raw()) }; + + // INVARIANT: We just acquired the read lock. + MmapReadGuard { + mm: self, + _nts: NotThreadSafe, + } + } + + /// Try to lock the mmap read lock. + #[inline] + pub fn mmap_read_trylock(&self) -> Option<MmapReadGuard<'_>> { + // SAFETY: The pointer is valid since self is a reference. + let success = unsafe { bindings::mmap_read_trylock(self.as_raw()) }; + + if success { + // INVARIANT: We just acquired the read lock. + Some(MmapReadGuard { + mm: self, + _nts: NotThreadSafe, + }) + } else { + None + } + } +} + +/// A guard for the mmap read lock. +/// +/// # Invariants +/// +/// This `MmapReadGuard` guard owns the mmap read lock. +pub struct MmapReadGuard<'a> { + mm: &'a MmWithUser, + // `mmap_read_lock` and `mmap_read_unlock` must be called on the same thread + _nts: NotThreadSafe, +} + +impl<'a> MmapReadGuard<'a> { + /// Look up a vma at the given address. + #[inline] + pub fn vma_lookup(&self, vma_addr: usize) -> Option<&virt::VmaRef> { + // SAFETY: By the type invariants we hold the mmap read guard, so we can safely call this + // method. Any value is okay for `vma_addr`. + let vma = unsafe { bindings::vma_lookup(self.mm.as_raw(), vma_addr) }; + + if vma.is_null() { + None + } else { + // SAFETY: We just checked that a vma was found, so the pointer references a valid vma. + // + // Furthermore, the returned vma is still under the protection of the read lock guard + // and can be used while the mmap read lock is still held. That the vma is not used + // after the MmapReadGuard gets dropped is enforced by the borrow-checker. + unsafe { Some(virt::VmaRef::from_raw(vma)) } + } + } +} + +impl Drop for MmapReadGuard<'_> { + #[inline] + fn drop(&mut self) { + // SAFETY: We hold the read lock by the type invariants. + unsafe { bindings::mmap_read_unlock(self.mm.as_raw()) }; + } +} + +/// A guard for the vma read lock. +/// +/// # Invariants +/// +/// This `VmaReadGuard` guard owns the vma read lock. +pub struct VmaReadGuard<'a> { + vma: &'a VmaRef, + // `vma_end_read` must be called on the same thread as where the lock was taken + _nts: NotThreadSafe, +} + +// Make all `VmaRef` methods available on `VmaReadGuard`. 
+impl Deref for VmaReadGuard<'_> { + type Target = VmaRef; + + #[inline] + fn deref(&self) -> &VmaRef { + self.vma + } +} + +impl Drop for VmaReadGuard<'_> { + #[inline] + fn drop(&mut self) { + // SAFETY: We hold the read lock by the type invariants. + unsafe { bindings::vma_end_read(self.vma.as_ptr()) }; + } +} diff --git a/rust/kernel/mm/mmput_async.rs b/rust/kernel/mm/mmput_async.rs new file mode 100644 index 000000000000..b8d2f051225c --- /dev/null +++ b/rust/kernel/mm/mmput_async.rs @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Version of `MmWithUser` using `mmput_async`. +//! +//! This is a separate file from `mm.rs` due to the dependency on `CONFIG_MMU=y`. +#![cfg(CONFIG_MMU)] + +use crate::{ + bindings, + mm::MmWithUser, + sync::aref::{ARef, AlwaysRefCounted}, +}; +use core::{ops::Deref, ptr::NonNull}; + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This type is identical to `MmWithUser` except that it uses `mmput_async` when dropping a +/// refcount. This means that the destructor of `ARef<MmWithUserAsync>` is safe to call in atomic +/// context. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmget`. The value of `mm_users` is non-zero. +#[repr(transparent)] +pub struct MmWithUserAsync { + mm: MmWithUser, +} + +// SAFETY: It is safe to call `mmput_async` on another thread than where `mmget` was called. +unsafe impl Send for MmWithUserAsync {} +// SAFETY: All methods on `MmWithUserAsync` can be called in parallel from several threads. +unsafe impl Sync for MmWithUserAsync {} + +// SAFETY: By the type invariants, this type is always refcounted. +unsafe impl AlwaysRefCounted for MmWithUserAsync { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmget(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmput_async(obj.cast().as_ptr()) }; + } +} + +// Make all `MmWithUser` methods available on `MmWithUserAsync`. +impl Deref for MmWithUserAsync { + type Target = MmWithUser; + + #[inline] + fn deref(&self) -> &MmWithUser { + &self.mm + } +} + +impl MmWithUser { + /// Use `mmput_async` when dropping this refcount. + #[inline] + pub fn into_mmput_async(me: ARef<MmWithUser>) -> ARef<MmWithUserAsync> { + // SAFETY: The layouts and invariants are compatible. + unsafe { ARef::from_raw(ARef::into_raw(me).cast()) } + } +} diff --git a/rust/kernel/mm/virt.rs b/rust/kernel/mm/virt.rs new file mode 100644 index 000000000000..da21d65ccd20 --- /dev/null +++ b/rust/kernel/mm/virt.rs @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Virtual memory. +//! +//! This module deals with managing a single VMA in the address space of a userspace process. Each +//! VMA corresponds to a region of memory that the userspace process can access, and the VMA lets +//! you control what happens when userspace reads or writes to that region of memory. +//! +//! The module has several different Rust types that all correspond to the C type called +//! `vm_area_struct`. The different structs represent what kind of access you have to the VMA, e.g. +//! [`VmaRef`] is used when you hold the mmap or vma read lock. Using the appropriate struct +//! ensures that you can't, for example, accidentally call a function that requires holding the +//! write lock when you only hold the read lock. 
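+//!
+//! For example, an `f_ops->mmap()` hook might configure a new VMA roughly as follows (a
+//! sketch; the flag choices are illustrative):
+//!
+//! ```ignore
+//! use kernel::mm::virt::VmaNew;
+//! use kernel::page::Page;
+//! use kernel::prelude::*;
+//!
+//! fn setup_vma(vma: &VmaNew, page: &Page) -> Result {
+//!     // Refuse mappings that were requested writable, and prevent later
+//!     // mprotect(PROT_WRITE) upgrades.
+//!     vma.try_clear_maywrite()?;
+//!     // `VM_MIXEDMAP` is required before inserting individual pages.
+//!     let mixed = vma.set_mixedmap();
+//!     // Map a single page at the start of the VMA; the page is not consumed.
+//!     mixed.vm_insert_page(vma.start(), page)
+//! }
+//! ```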
+
+use crate::{
+    bindings,
+    error::{code::EINVAL, to_result, Result},
+    mm::MmWithUser,
+    page::Page,
+    types::Opaque,
+};
+
+use core::ops::Deref;
+
+/// A wrapper for the kernel's `struct vm_area_struct` with read access.
+///
+/// It represents an area of virtual memory.
+///
+/// # Invariants
+///
+/// The caller must hold the mmap read lock or the vma read lock.
+#[repr(transparent)]
+pub struct VmaRef {
+    vma: Opaque<bindings::vm_area_struct>,
+}
+
+// Methods you can call when holding the mmap or vma read lock (or stronger). They must be usable
+// no matter what the vma flags are.
+impl VmaRef {
+    /// Access a virtual memory area given a raw pointer.
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that `vma` is valid for the duration of 'a, and that the mmap or vma
+    /// read lock (or stronger) is held for at least the duration of 'a.
+    #[inline]
+    pub unsafe fn from_raw<'a>(vma: *const bindings::vm_area_struct) -> &'a Self {
+        // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a.
+        unsafe { &*vma.cast() }
+    }
+
+    /// Returns a raw pointer to this area.
+    #[inline]
+    pub fn as_ptr(&self) -> *mut bindings::vm_area_struct {
+        self.vma.get()
+    }
+
+    /// Access the underlying `mm_struct`.
+    #[inline]
+    pub fn mm(&self) -> &MmWithUser {
+        // SAFETY: By the type invariants, this `vm_area_struct` is valid and we hold the mmap/vma
+        // read lock or stronger. This implies that the underlying mm has a non-zero value of
+        // `mm_users`.
+        unsafe { MmWithUser::from_raw((*self.as_ptr()).vm_mm) }
+    }
+
+    /// Returns the flags associated with the virtual memory area.
+    ///
+    /// The possible flags are a combination of the constants in [`flags`].
+    #[inline]
+    pub fn flags(&self) -> vm_flags_t {
+        // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this
+        // access is not a data race.
+        unsafe { (*self.as_ptr()).__bindgen_anon_2.vm_flags }
+    }
+
+    /// Returns the (inclusive) start address of the virtual memory area.
+    #[inline]
+    pub fn start(&self) -> usize {
+        // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this
+        // access is not a data race.
+        unsafe { (*self.as_ptr()).__bindgen_anon_1.__bindgen_anon_1.vm_start }
+    }
+
+    /// Returns the (exclusive) end address of the virtual memory area.
+    #[inline]
+    pub fn end(&self) -> usize {
+        // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this
+        // access is not a data race.
+        unsafe { (*self.as_ptr()).__bindgen_anon_1.__bindgen_anon_1.vm_end }
+    }
+
+    /// Zap pages in the given page range.
+    ///
+    /// This clears page table mappings for the range at the leaf level, leaving all other page
+    /// tables intact, and freeing any memory referenced by the VMA in this range. That is,
+    /// anonymous memory is completely freed, file-backed memory has its reference count on page
+    /// cache folios dropped, and any dirty data is still written back to disk as usual.
+    ///
+    /// It may seem odd that we clear only at the leaf level; this is, however, a product of the
+    /// page table structure used to map physical memory into a virtual address space - each
+    /// virtual address actually consists of a series of array indices into page tables, which
+    /// form a hierarchical page table level structure.
+    ///
+    /// As a result, each page table level maps a multiple of the ranges mapped by the level
+    /// below it, and thus spans ever-increasing ranges of pages. At the leaf or PTE level, we
+    /// map the actual physical memory.
+    ///
+    /// It is at the leaf level that a zap operates, as it is the only place we can be certain
+    /// of clearing without impacting any other virtual mappings. It is an implementation detail
+    /// as to whether the kernel goes further in freeing unused page tables, but for the purposes
+    /// of this operation we must only assume that the leaf level is cleared.
+    #[inline]
+    pub fn zap_page_range_single(&self, address: usize, size: usize) {
+        let (end, did_overflow) = address.overflowing_add(size);
+        if did_overflow || address < self.start() || self.end() < end {
+            // TODO: call WARN_ONCE once a Rust version of it is added
+            return;
+        }
+
+        // SAFETY: By the type invariants, the caller has read access to this VMA, which is
+        // sufficient for this method call. This method has no requirements on the vma flags. The
+        // address range is checked to be within the vma.
+        unsafe {
+            bindings::zap_page_range_single(self.as_ptr(), address, size, core::ptr::null_mut())
+        };
+    }
+
+    /// If the [`VM_MIXEDMAP`] flag is set, returns a [`VmaMixedMap`] to this VMA, otherwise
+    /// returns `None`.
+    ///
+    /// This can be used to access methods that require [`VM_MIXEDMAP`] to be set.
+    ///
+    /// [`VM_MIXEDMAP`]: flags::MIXEDMAP
+    #[inline]
+    pub fn as_mixedmap_vma(&self) -> Option<&VmaMixedMap> {
+        if self.flags() & flags::MIXEDMAP != 0 {
+            // SAFETY: We just checked that `VM_MIXEDMAP` is set. All other requirements are
+            // satisfied by the type invariants of `VmaRef`.
+            Some(unsafe { VmaMixedMap::from_raw(self.as_ptr()) })
+        } else {
+            None
+        }
+    }
+}
+
+/// A wrapper for the kernel's `struct vm_area_struct` with read access and [`VM_MIXEDMAP`] set.
+///
+/// It represents an area of virtual memory.
+///
+/// This struct is identical to [`VmaRef`] except that it must only be used when the
+/// [`VM_MIXEDMAP`] flag is set on the vma.
+///
+/// # Invariants
+///
+/// The caller must hold the mmap read lock or the vma read lock. The `VM_MIXEDMAP` flag must be
+/// set.
+///
+/// [`VM_MIXEDMAP`]: flags::MIXEDMAP
+#[repr(transparent)]
+pub struct VmaMixedMap {
+    vma: VmaRef,
+}
+
+// Make all `VmaRef` methods available on `VmaMixedMap`.
+impl Deref for VmaMixedMap {
+    type Target = VmaRef;
+
+    #[inline]
+    fn deref(&self) -> &VmaRef {
+        &self.vma
+    }
+}
+
+impl VmaMixedMap {
+    /// Access a virtual memory area given a raw pointer.
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that `vma` is valid for the duration of 'a, and that the mmap read lock
+    /// (or stronger) is held for at least the duration of 'a. The `VM_MIXEDMAP` flag must be set.
+    #[inline]
+    pub unsafe fn from_raw<'a>(vma: *const bindings::vm_area_struct) -> &'a Self {
+        // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a.
+        unsafe { &*vma.cast() }
+    }
+
+    /// Maps a single page at the given address within the virtual memory area.
+    ///
+    /// This operation does not take ownership of the page.
+    #[inline]
+    pub fn vm_insert_page(&self, address: usize, page: &Page) -> Result {
+        // SAFETY: By the type invariant of `Self`, the caller has read access and has verified
+        // that `VM_MIXEDMAP` is set. By the invariant on `Page`, the page has order 0.
+        to_result(unsafe { bindings::vm_insert_page(self.as_ptr(), address, page.as_ptr()) })
+    }
+}
+
+/// A configuration object for setting up a VMA in an `f_ops->mmap()` hook.
+///
+/// The `f_ops->mmap()` hook is called when a new VMA is being created, and the hook is able to
+/// configure the VMA in various ways to fit the driver that owns it. 
Using `VmaNew` indicates that +/// you are allowed to perform operations on the VMA that can only be performed before the VMA is +/// fully initialized. +/// +/// # Invariants +/// +/// For the duration of 'a, the referenced vma must be undergoing initialization in an +/// `f_ops->mmap()` hook. +#[repr(transparent)] +pub struct VmaNew { + vma: VmaRef, +} + +// Make all `VmaRef` methods available on `VmaNew`. +impl Deref for VmaNew { + type Target = VmaRef; + + #[inline] + fn deref(&self) -> &VmaRef { + &self.vma + } +} + +impl VmaNew { + /// Access a virtual memory area given a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `vma` is undergoing initial vma setup for the duration of 'a. + #[inline] + pub unsafe fn from_raw<'a>(vma: *mut bindings::vm_area_struct) -> &'a Self { + // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a. + unsafe { &*vma.cast() } + } + + /// Internal method for updating the vma flags. + /// + /// # Safety + /// + /// This must not be used to set the flags to an invalid value. + #[inline] + unsafe fn update_flags(&self, set: vm_flags_t, unset: vm_flags_t) { + let mut flags = self.flags(); + flags |= set; + flags &= !unset; + + // SAFETY: This is not a data race: the vma is undergoing initial setup, so it's not yet + // shared. Additionally, `VmaNew` is `!Sync`, so it cannot be used to write in parallel. + // The caller promises that this does not set the flags to an invalid value. + unsafe { (*self.as_ptr()).__bindgen_anon_2.vm_flags = flags }; + } + + /// Set the `VM_MIXEDMAP` flag on this vma. + /// + /// This enables the vma to contain both `struct page` and pure PFN pages. Returns a reference + /// that can be used to call `vm_insert_page` on the vma. + #[inline] + pub fn set_mixedmap(&self) -> &VmaMixedMap { + // SAFETY: We don't yet provide a way to set VM_PFNMAP, so this cannot put the flags in an + // invalid state. + unsafe { self.update_flags(flags::MIXEDMAP, 0) }; + + // SAFETY: We just set `VM_MIXEDMAP` on the vma. + unsafe { VmaMixedMap::from_raw(self.vma.as_ptr()) } + } + + /// Set the `VM_IO` flag on this vma. + /// + /// This is used for memory mapped IO and similar. The flag tells other parts of the kernel to + /// avoid looking at the pages. For memory mapped IO this is useful as accesses to the pages + /// could have side effects. + #[inline] + pub fn set_io(&self) { + // SAFETY: Setting the VM_IO flag is always okay. + unsafe { self.update_flags(flags::IO, 0) }; + } + + /// Set the `VM_DONTEXPAND` flag on this vma. + /// + /// This prevents the vma from being expanded with `mremap()`. + #[inline] + pub fn set_dontexpand(&self) { + // SAFETY: Setting the VM_DONTEXPAND flag is always okay. + unsafe { self.update_flags(flags::DONTEXPAND, 0) }; + } + + /// Set the `VM_DONTCOPY` flag on this vma. + /// + /// This prevents the vma from being copied on fork. This option is only permanent if `VM_IO` + /// is set. + #[inline] + pub fn set_dontcopy(&self) { + // SAFETY: Setting the VM_DONTCOPY flag is always okay. + unsafe { self.update_flags(flags::DONTCOPY, 0) }; + } + + /// Set the `VM_DONTDUMP` flag on this vma. + /// + /// This prevents the vma from being included in core dumps. This option is only permanent if + /// `VM_IO` is set. + #[inline] + pub fn set_dontdump(&self) { + // SAFETY: Setting the VM_DONTDUMP flag is always okay. + unsafe { self.update_flags(flags::DONTDUMP, 0) }; + } + + /// Returns whether `VM_READ` is set. 
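+    ///
+    /// For example, an `f_ops->mmap()` hook that must not expose its pages to reads might
+    /// reject readable mappings and clear `VM_MAYREAD` (a sketch):
+    ///
+    /// ```ignore
+    /// if vma.readable() {
+    ///     return Err(EINVAL);
+    /// }
+    /// // Also prevent mprotect() from making the mapping readable later.
+    /// vma.try_clear_mayread()?;
+    /// ```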
+ /// + /// This flag indicates whether userspace is mapping this vma as readable. + #[inline] + pub fn readable(&self) -> bool { + (self.flags() & flags::READ) != 0 + } + + /// Try to clear the `VM_MAYREAD` flag, failing if `VM_READ` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma readable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYREAD` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_mayread(&self) -> Result { + if self.readable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYREAD` is okay when `VM_READ` is not set. + unsafe { self.update_flags(0, flags::MAYREAD) }; + Ok(()) + } + + /// Returns whether `VM_WRITE` is set. + /// + /// This flag indicates whether userspace is mapping this vma as writable. + #[inline] + pub fn writable(&self) -> bool { + (self.flags() & flags::WRITE) != 0 + } + + /// Try to clear the `VM_MAYWRITE` flag, failing if `VM_WRITE` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma writable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYWRITE` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_maywrite(&self) -> Result { + if self.writable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYWRITE` is okay when `VM_WRITE` is not set. + unsafe { self.update_flags(0, flags::MAYWRITE) }; + Ok(()) + } + + /// Returns whether `VM_EXEC` is set. + /// + /// This flag indicates whether userspace is mapping this vma as executable. + #[inline] + pub fn executable(&self) -> bool { + (self.flags() & flags::EXEC) != 0 + } + + /// Try to clear the `VM_MAYEXEC` flag, failing if `VM_EXEC` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma executable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYEXEC` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_mayexec(&self) -> Result { + if self.executable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYEXEC` is okay when `VM_EXEC` is not set. + unsafe { self.update_flags(0, flags::MAYEXEC) }; + Ok(()) + } +} + +/// The integer type used for vma flags. +#[doc(inline)] +pub use bindings::vm_flags_t; + +/// All possible flags for [`VmaRef`]. +pub mod flags { + use super::vm_flags_t; + use crate::bindings; + + /// No flags are set. + pub const NONE: vm_flags_t = bindings::VM_NONE as vm_flags_t; + + /// Mapping allows reads. + pub const READ: vm_flags_t = bindings::VM_READ as vm_flags_t; + + /// Mapping allows writes. + pub const WRITE: vm_flags_t = bindings::VM_WRITE as vm_flags_t; + + /// Mapping allows execution. + pub const EXEC: vm_flags_t = bindings::VM_EXEC as vm_flags_t; + + /// Mapping is shared. + pub const SHARED: vm_flags_t = bindings::VM_SHARED as vm_flags_t; + + /// Mapping may be updated to allow reads. + pub const MAYREAD: vm_flags_t = bindings::VM_MAYREAD as vm_flags_t; + + /// Mapping may be updated to allow writes. + pub const MAYWRITE: vm_flags_t = bindings::VM_MAYWRITE as vm_flags_t; + + /// Mapping may be updated to allow execution. + pub const MAYEXEC: vm_flags_t = bindings::VM_MAYEXEC as vm_flags_t; + + /// Mapping may be updated to be shared. + pub const MAYSHARE: vm_flags_t = bindings::VM_MAYSHARE as vm_flags_t; + + /// Page-ranges managed without `struct page`, just pure PFN. 
+ pub const PFNMAP: vm_flags_t = bindings::VM_PFNMAP as vm_flags_t; + + /// Memory mapped I/O or similar. + pub const IO: vm_flags_t = bindings::VM_IO as vm_flags_t; + + /// Do not copy this vma on fork. + pub const DONTCOPY: vm_flags_t = bindings::VM_DONTCOPY as vm_flags_t; + + /// Cannot expand with mremap(). + pub const DONTEXPAND: vm_flags_t = bindings::VM_DONTEXPAND as vm_flags_t; + + /// Lock the pages covered when they are faulted in. + pub const LOCKONFAULT: vm_flags_t = bindings::VM_LOCKONFAULT as vm_flags_t; + + /// Is a VM accounted object. + pub const ACCOUNT: vm_flags_t = bindings::VM_ACCOUNT as vm_flags_t; + + /// Should the VM suppress accounting. + pub const NORESERVE: vm_flags_t = bindings::VM_NORESERVE as vm_flags_t; + + /// Huge TLB Page VM. + pub const HUGETLB: vm_flags_t = bindings::VM_HUGETLB as vm_flags_t; + + /// Synchronous page faults. (DAX-specific) + pub const SYNC: vm_flags_t = bindings::VM_SYNC as vm_flags_t; + + /// Architecture-specific flag. + pub const ARCH_1: vm_flags_t = bindings::VM_ARCH_1 as vm_flags_t; + + /// Wipe VMA contents in child on fork. + pub const WIPEONFORK: vm_flags_t = bindings::VM_WIPEONFORK as vm_flags_t; + + /// Do not include in the core dump. + pub const DONTDUMP: vm_flags_t = bindings::VM_DONTDUMP as vm_flags_t; + + /// Not soft dirty clean area. + pub const SOFTDIRTY: vm_flags_t = bindings::VM_SOFTDIRTY as vm_flags_t; + + /// Can contain `struct page` and pure PFN pages. + pub const MIXEDMAP: vm_flags_t = bindings::VM_MIXEDMAP as vm_flags_t; + + /// MADV_HUGEPAGE marked this vma. + pub const HUGEPAGE: vm_flags_t = bindings::VM_HUGEPAGE as vm_flags_t; + + /// MADV_NOHUGEPAGE marked this vma. + pub const NOHUGEPAGE: vm_flags_t = bindings::VM_NOHUGEPAGE as vm_flags_t; + + /// KSM may merge identical pages. + pub const MERGEABLE: vm_flags_t = bindings::VM_MERGEABLE as vm_flags_t; +} diff --git a/rust/kernel/module_param.rs b/rust/kernel/module_param.rs new file mode 100644 index 000000000000..6a8a7a875643 --- /dev/null +++ b/rust/kernel/module_param.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Support for module parameters. +//! +//! C header: [`include/linux/moduleparam.h`](srctree/include/linux/moduleparam.h) + +use crate::prelude::*; +use crate::str::BStr; +use bindings; +use kernel::sync::SetOnce; + +/// Newtype to make `bindings::kernel_param` [`Sync`]. +#[repr(transparent)] +#[doc(hidden)] +pub struct KernelParam(bindings::kernel_param); + +impl KernelParam { + #[doc(hidden)] + pub const fn new(val: bindings::kernel_param) -> Self { + Self(val) + } +} + +// SAFETY: C kernel handles serializing access to this type. We never access it +// from Rust module. +unsafe impl Sync for KernelParam {} + +/// Types that can be used for module parameters. +// NOTE: This trait is `Copy` because drop could produce unsoundness during teardown. +pub trait ModuleParam: Sized + Copy { + /// Parse a parameter argument into the parameter value. + fn try_from_param_arg(arg: &BStr) -> Result<Self>; +} + +/// Set the module parameter from a string. +/// +/// Used to set the parameter value at kernel initialization, when loading +/// the module or when set through `sysfs`. +/// +/// See `struct kernel_param_ops.set`. +/// +/// # Safety +/// +/// - If `val` is non-null then it must point to a valid null-terminated string that must be valid +/// for reads for the duration of the call. +/// - `param` must be a pointer to a `bindings::kernel_param` initialized by the rust module macro. 
+///   The pointee must be valid for reads for the duration of the call.
+///
+/// # Note
+///
+/// - The safety requirements are satisfied by the C API contract when this function is invoked
+///   by the module subsystem C code.
+/// - Currently, we only support read-only parameters that are not readable from `sysfs`. Thus,
+///   this function is only called at kernel initialization time, or at module load time, and we
+///   have exclusive access to the parameter for the duration of the function.
+///
+/// [`module!`]: macros::module
+unsafe extern "C" fn set_param<T>(val: *const c_char, param: *const bindings::kernel_param) -> c_int
+where
+    T: ModuleParam,
+{
+    // NOTE: If we start supporting arguments without values, val _is_ allowed
+    // to be null here.
+    if val.is_null() {
+        // TODO: Use pr_warn_once when available.
+        crate::pr_warn!("Null pointer passed to `module_param::set_param`");
+        return EINVAL.to_errno();
+    }
+
+    // SAFETY: By the function safety requirements, `val` is non-null, null-terminated
+    // and valid for reads for the duration of this function.
+    let arg = unsafe { CStr::from_char_ptr(val) };
+    let arg: &BStr = arg.as_ref();
+
+    crate::error::from_result(|| {
+        let new_value = T::try_from_param_arg(arg)?;
+
+        // SAFETY: By the function safety requirements, this access is safe.
+        let container = unsafe { &*((*param).__bindgen_anon_1.arg.cast::<SetOnce<T>>()) };
+
+        container
+            .populate(new_value)
+            .then_some(0)
+            .ok_or(kernel::error::code::EEXIST)
+    })
+}
+
+macro_rules! impl_int_module_param {
+    ($ty:ident) => {
+        impl ModuleParam for $ty {
+            fn try_from_param_arg(arg: &BStr) -> Result<Self> {
+                <$ty as crate::str::parse_int::ParseInt>::from_str(arg)
+            }
+        }
+    };
+}
+
+impl_int_module_param!(i8);
+impl_int_module_param!(u8);
+impl_int_module_param!(i16);
+impl_int_module_param!(u16);
+impl_int_module_param!(i32);
+impl_int_module_param!(u32);
+impl_int_module_param!(i64);
+impl_int_module_param!(u64);
+impl_int_module_param!(isize);
+impl_int_module_param!(usize);
+
+/// A wrapper for kernel parameters.
+///
+/// This type is instantiated by the [`module!`] macro when module parameters are
+/// defined. You should never need to instantiate this type directly.
+///
+/// Note: This type is `pub` because it is used by module crates to access
+/// parameter values.
+pub struct ModuleParamAccess<T> {
+    value: SetOnce<T>,
+    default: T,
+}
+
+// SAFETY: We only create shared references to the contents of this container,
+// so if `T` is `Sync`, so is `ModuleParamAccess`.
+unsafe impl<T: Sync> Sync for ModuleParamAccess<T> {}
+
+impl<T> ModuleParamAccess<T> {
+    #[doc(hidden)]
+    pub const fn new(default: T) -> Self {
+        Self {
+            value: SetOnce::new(),
+            default,
+        }
+    }
+
+    /// Get a shared reference to the parameter value.
+    // Note: When sysfs access to parameters is enabled, we have to pass in a
+    // held lock guard here.
+    pub fn value(&self) -> &T {
+        self.value.as_ref().unwrap_or(&self.default)
+    }
+
+    /// Get a mutable pointer to `self`.
+    ///
+    /// NOTE: In most cases it is not safe to deref the returned pointer.
+    pub const fn as_void_ptr(&self) -> *mut c_void {
+        core::ptr::from_ref(self).cast_mut().cast()
+    }
+}
+
+#[doc(hidden)]
+/// Generate a static [`kernel_param_ops`](srctree/include/linux/moduleparam.h) struct.
+///
+/// # Examples
+///
+/// ```ignore
+/// make_param_ops!(
+///     /// Documentation for new param ops.
+///     PARAM_OPS_MYTYPE, // Name for the static.
+///     MyType // A type which implements [`ModuleParam`].
+/// );
+/// ```
+macro_rules! make_param_ops {
+    ($ops:ident, $ty:ty) => {
+        #[doc(hidden)]
+        pub static $ops: $crate::bindings::kernel_param_ops = $crate::bindings::kernel_param_ops {
+            flags: 0,
+            set: Some(set_param::<$ty>),
+            get: None,
+            free: None,
+        };
+    };
+}
+
+make_param_ops!(PARAM_OPS_I8, i8);
+make_param_ops!(PARAM_OPS_U8, u8);
+make_param_ops!(PARAM_OPS_I16, i16);
+make_param_ops!(PARAM_OPS_U16, u16);
+make_param_ops!(PARAM_OPS_I32, i32);
+make_param_ops!(PARAM_OPS_U32, u32);
+make_param_ops!(PARAM_OPS_I64, i64);
+make_param_ops!(PARAM_OPS_U64, u64);
+make_param_ops!(PARAM_OPS_ISIZE, isize);
+make_param_ops!(PARAM_OPS_USIZE, usize);
diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs
new file mode 100644
index 000000000000..fe415cb369d3
--- /dev/null
+++ b/rust/kernel/net.rs
@@ -0,0 +1,6 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Networking.
+
+#[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)]
+pub mod phy;
diff --git a/rust/kernel/net/phy.rs b/rust/kernel/net/phy.rs
new file mode 100644
index 000000000000..bf6272d87a7b
--- /dev/null
+++ b/rust/kernel/net/phy.rs
@@ -0,0 +1,909 @@
+// SPDX-License-Identifier: GPL-2.0
+
+// Copyright (C) 2023 FUJITA Tomonori <fujita.tomonori@gmail.com>
+
+//! Network PHY device.
+//!
+//! C headers: [`include/linux/phy.h`](srctree/include/linux/phy.h).
+
+use crate::{device_id::RawDeviceId, error::*, prelude::*, types::Opaque};
+use core::{marker::PhantomData, ptr::addr_of_mut};
+
+pub mod reg;
+
+/// PHY state machine states.
+///
+/// Corresponds to the kernel's [`enum phy_state`].
+///
+/// Some PHY drivers access the state of the PHY's software state machine.
+///
+/// [`enum phy_state`]: srctree/include/linux/phy.h
+#[derive(PartialEq, Eq)]
+pub enum DeviceState {
+    /// PHY device and driver are not ready for anything.
+    Down,
+    /// PHY is ready to send and receive packets.
+    Ready,
+    /// PHY is up, but no polling or interrupts are done.
+    Halted,
+    /// PHY is up, but is in an error state.
+    Error,
+    /// PHY and attached device are ready to do work.
+    Up,
+    /// PHY is currently running.
+    Running,
+    /// PHY is up, but not currently plugged in.
+    NoLink,
+    /// PHY is performing a cable test.
+    CableTest,
+}
+
+/// A mode of Ethernet communication.
+///
+/// PHY drivers get duplex information from hardware and update the current state.
+pub enum DuplexMode {
+    /// PHY is in full-duplex mode.
+    Full,
+    /// PHY is in half-duplex mode.
+    Half,
+    /// PHY is in unknown duplex mode.
+    Unknown,
+}
+
+/// An instance of a PHY device.
+///
+/// Wraps the kernel's [`struct phy_device`].
+///
+/// A [`Device`] instance is created when a callback in [`Driver`] is executed. A PHY driver
+/// executes [`Driver`]'s methods during the callback.
+///
+/// # Invariants
+///
+/// - Referencing a `phy_device` using this struct asserts that you are in
+///   a context where all methods defined on this struct are safe to call.
+/// - This struct always has a valid `self.0.mdio.dev`.
+///
+/// [`struct phy_device`]: srctree/include/linux/phy.h
+// During the calls to most functions in [`Driver`], the C side (`PHYLIB`) holds a lock that is
+// unique for every instance of [`Device`]. `PHYLIB` uses a different serialization technique for
+// [`Driver::resume`] and [`Driver::suspend`]: `PHYLIB` updates `phy_device`'s state with
+// the lock held, thus guaranteeing that [`Driver::resume`] has exclusive access to the instance.
+// [`Driver::resume`] and [`Driver::suspend`] are also called in contexts where only one thread
+// can access the instance.
+#[repr(transparent)]
+pub struct Device(Opaque<bindings::phy_device>);
+
+impl Device {
+    /// Creates a new [`Device`] instance from a raw pointer.
+    ///
+    /// # Safety
+    ///
+    /// For the duration of `'a`,
+    /// - the pointer must point at a valid `phy_device`, and the caller
+    ///   must be in a context where all methods defined on this struct
+    ///   are safe to call.
+    /// - `(*ptr).mdio.dev` must be valid.
+    unsafe fn from_raw<'a>(ptr: *mut bindings::phy_device) -> &'a mut Self {
+        // CAST: `Self` is a `repr(transparent)` wrapper around `bindings::phy_device`.
+        let ptr = ptr.cast::<Self>();
+        // SAFETY: by the function requirements the pointer is valid and we have unique access for
+        // the duration of `'a`.
+        unsafe { &mut *ptr }
+    }
+
+    /// Gets the id of the PHY.
+    pub fn phy_id(&self) -> u32 {
+        let phydev = self.0.get();
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { (*phydev).phy_id }
+    }
+
+    /// Gets the state of the PHY state machine.
+    pub fn state(&self) -> DeviceState {
+        let phydev = self.0.get();
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        let state = unsafe { (*phydev).state };
+        // TODO: this conversion code will be replaced with automatically generated code by bindgen
+        // when it becomes possible.
+        match state {
+            bindings::phy_state_PHY_DOWN => DeviceState::Down,
+            bindings::phy_state_PHY_READY => DeviceState::Ready,
+            bindings::phy_state_PHY_HALTED => DeviceState::Halted,
+            bindings::phy_state_PHY_ERROR => DeviceState::Error,
+            bindings::phy_state_PHY_UP => DeviceState::Up,
+            bindings::phy_state_PHY_RUNNING => DeviceState::Running,
+            bindings::phy_state_PHY_NOLINK => DeviceState::NoLink,
+            bindings::phy_state_PHY_CABLETEST => DeviceState::CableTest,
+            _ => DeviceState::Error,
+        }
+    }
+
+    /// Gets the current link state.
+    ///
+    /// It returns true if the link is up.
+    pub fn is_link_up(&self) -> bool {
+        const LINK_IS_UP: u64 = 1;
+        // TODO: the code to access the bit field will be replaced with automatically
+        // generated code by bindgen when it becomes possible.
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        let bit_field = unsafe { &(*self.0.get())._bitfield_1 };
+        bit_field.get(14, 1) == LINK_IS_UP
+    }
+
+    /// Gets the current auto-negotiation configuration.
+    ///
+    /// It returns true if auto-negotiation is enabled.
+    pub fn is_autoneg_enabled(&self) -> bool {
+        // TODO: the code to access the bit field will be replaced with automatically
+        // generated code by bindgen when it becomes possible.
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        let bit_field = unsafe { &(*self.0.get())._bitfield_1 };
+        bit_field.get(13, 1) == u64::from(bindings::AUTONEG_ENABLE)
+    }
+
+    /// Gets the current auto-negotiation state.
+    ///
+    /// It returns true if auto-negotiation is completed.
+    pub fn is_autoneg_completed(&self) -> bool {
+        const AUTONEG_COMPLETED: u64 = 1;
+        // TODO: the code to access the bit field will be replaced with automatically
+        // generated code by bindgen when it becomes possible.
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        let bit_field = unsafe { &(*self.0.get())._bitfield_1 };
+        bit_field.get(15, 1) == AUTONEG_COMPLETED
+    }
+
+    /// Sets the speed of the PHY.
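+    ///
+    /// For example, a driver that has determined the link parameters from hardware might
+    /// record them like this (a sketch with illustrative values):
+    ///
+    /// ```ignore
+    /// dev.set_speed(1000); // Mbit/s
+    /// dev.set_duplex(DuplexMode::Full);
+    /// ```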
+    pub fn set_speed(&mut self, speed: u32) {
+        let phydev = self.0.get();
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { (*phydev).speed = speed as c_int };
+    }
+
+    /// Sets the duplex mode.
+    pub fn set_duplex(&mut self, mode: DuplexMode) {
+        let phydev = self.0.get();
+        let v = match mode {
+            DuplexMode::Full => bindings::DUPLEX_FULL,
+            DuplexMode::Half => bindings::DUPLEX_HALF,
+            DuplexMode::Unknown => bindings::DUPLEX_UNKNOWN,
+        };
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { (*phydev).duplex = v as c_int };
+    }
+
+    /// Reads a PHY register.
+    // This function reads a hardware register and updates the stats, so it takes `&mut self`.
+    pub fn read<R: reg::Register>(&mut self, reg: R) -> Result<u16> {
+        reg.read(self)
+    }
+
+    /// Writes a PHY register.
+    pub fn write<R: reg::Register>(&mut self, reg: R, val: u16) -> Result {
+        reg.write(self, val)
+    }
+
+    /// Reads a paged register.
+    pub fn read_paged(&mut self, page: u16, regnum: u16) -> Result<u16> {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        let ret = unsafe { bindings::phy_read_paged(phydev, page.into(), regnum.into()) };
+
+        to_result(ret).map(|()| ret as u16)
+    }
+
+    /// Resolves the advertisements into PHY settings.
+    pub fn resolve_aneg_linkmode(&mut self) {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        unsafe { bindings::phy_resolve_aneg_linkmode(phydev) };
+    }
+
+    /// Executes a software reset of the PHY via the `BMCR_RESET` bit.
+    pub fn genphy_soft_reset(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        to_result(unsafe { bindings::genphy_soft_reset(phydev) })
+    }
+
+    /// Initializes the PHY.
+    pub fn init_hw(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        to_result(unsafe { bindings::phy_init_hw(phydev) })
+    }
+
+    /// Starts auto-negotiation.
+    pub fn start_aneg(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        to_result(unsafe { bindings::_phy_start_aneg(phydev) })
+    }
+
+    /// Resumes the PHY via the `BMCR_PDOWN` bit.
+    pub fn genphy_resume(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        to_result(unsafe { bindings::genphy_resume(phydev) })
+    }
+
+    /// Suspends the PHY via the `BMCR_PDOWN` bit.
+    pub fn genphy_suspend(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+        // So it's just an FFI call.
+        to_result(unsafe { bindings::genphy_suspend(phydev) })
+    }
+
+    /// Checks the link status and updates the current link state.
+    pub fn genphy_read_status<R: reg::Register>(&mut self) -> Result<u16> {
+        R::read_status(self)
+    }
+
+    /// Updates the link status.
+    pub fn genphy_update_link(&mut self) -> Result {
+        let phydev = self.0.get();
+        // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`.
+ // So it's just an FFI call. + to_result(unsafe { bindings::genphy_update_link(phydev) }) + } + + /// Reads link partner ability. + pub fn genphy_read_lpa(&mut self) -> Result { + let phydev = self.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + to_result(unsafe { bindings::genphy_read_lpa(phydev) }) + } + + /// Reads PHY abilities. + pub fn genphy_read_abilities(&mut self) -> Result { + let phydev = self.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + to_result(unsafe { bindings::genphy_read_abilities(phydev) }) + } +} + +impl AsRef<kernel::device::Device> for Device { + fn as_ref(&self) -> &kernel::device::Device { + let phydev = self.0.get(); + // SAFETY: The struct invariant ensures that `mdio.dev` is valid. + unsafe { kernel::device::Device::from_raw(addr_of_mut!((*phydev).mdio.dev)) } + } +} + +/// Defines certain other features this PHY supports (like interrupts). +/// +/// These flag values are used in [`Driver::FLAGS`]. +pub mod flags { + /// PHY is internal. + pub const IS_INTERNAL: u32 = bindings::PHY_IS_INTERNAL; + /// PHY needs to be reset after the refclk is enabled. + pub const RST_AFTER_CLK_EN: u32 = bindings::PHY_RST_AFTER_CLK_EN; + /// Polling is used to detect PHY status changes. + pub const POLL_CABLE_TEST: u32 = bindings::PHY_POLL_CABLE_TEST; + /// Don't suspend. + pub const ALWAYS_CALL_SUSPEND: u32 = bindings::PHY_ALWAYS_CALL_SUSPEND; +} + +/// An adapter for the registration of a PHY driver. +struct Adapter<T: Driver> { + _p: PhantomData<T>, +} + +impl<T: Driver> Adapter<T> { + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn soft_reset_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::soft_reset(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn probe_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we can exclusively access `phy_device` because + // it's not published yet, so the accessors on `Device` are okay + // to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::probe(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn get_features_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::get_features(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn suspend_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: The C core code ensures that the accessors on + // `Device` are okay to call even though `phy_device->lock` + // might not be held. 
+ let dev = unsafe { Device::from_raw(phydev) }; + T::suspend(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn resume_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: The C core code ensures that the accessors on + // `Device` are okay to call even though `phy_device->lock` + // might not be held. + let dev = unsafe { Device::from_raw(phydev) }; + T::resume(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn config_aneg_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::config_aneg(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn read_status_callback(phydev: *mut bindings::phy_device) -> c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::read_status(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn match_phy_device_callback( + phydev: *mut bindings::phy_device, + _phydrv: *const bindings::phy_driver, + ) -> c_int { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::match_phy_device(dev).into() + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn read_mmd_callback( + phydev: *mut bindings::phy_device, + devnum: i32, + regnum: u16, + ) -> i32 { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + // CAST: the C side verifies devnum < 32. + let ret = T::read_mmd(dev, devnum as u8, regnum)?; + Ok(ret.into()) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn write_mmd_callback( + phydev: *mut bindings::phy_device, + devnum: i32, + regnum: u16, + val: u16, + ) -> i32 { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::write_mmd(dev, devnum as u8, regnum, val)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn link_change_notify_callback(phydev: *mut bindings::phy_device) { + // SAFETY: This callback is called only in contexts + // where we hold `phy_device->lock`, so the accessors on + // `Device` are okay to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::link_change_notify(dev); + } +} + +/// Driver structure for a particular PHY type. +/// +/// Wraps the kernel's [`struct phy_driver`]. 
+/// This is used to register a driver for a particular PHY type with the kernel. +/// +/// # Invariants +/// +/// `self.0` is always in a valid state. +/// +/// [`struct phy_driver`]: srctree/include/linux/phy.h +#[repr(transparent)] +pub struct DriverVTable(Opaque<bindings::phy_driver>); + +// SAFETY: `DriverVTable` doesn't expose any &self method to access internal data, so it's safe to +// share `&DriverVTable` across execution context boundaries. +unsafe impl Sync for DriverVTable {} + +/// Creates a [`DriverVTable`] instance from [`Driver`]. +/// +/// This is used by [`module_phy_driver`] macro to create a static array of `phy_driver`. +/// +/// [`module_phy_driver`]: crate::module_phy_driver +pub const fn create_phy_driver<T: Driver>() -> DriverVTable { + // INVARIANT: All the fields of `struct phy_driver` are initialized properly. + DriverVTable(Opaque::new(bindings::phy_driver { + name: crate::str::as_char_ptr_in_const_context(T::NAME).cast_mut(), + flags: T::FLAGS, + phy_id: T::PHY_DEVICE_ID.id(), + phy_id_mask: T::PHY_DEVICE_ID.mask_as_int(), + soft_reset: if T::HAS_SOFT_RESET { + Some(Adapter::<T>::soft_reset_callback) + } else { + None + }, + probe: if T::HAS_PROBE { + Some(Adapter::<T>::probe_callback) + } else { + None + }, + get_features: if T::HAS_GET_FEATURES { + Some(Adapter::<T>::get_features_callback) + } else { + None + }, + match_phy_device: if T::HAS_MATCH_PHY_DEVICE { + Some(Adapter::<T>::match_phy_device_callback) + } else { + None + }, + suspend: if T::HAS_SUSPEND { + Some(Adapter::<T>::suspend_callback) + } else { + None + }, + resume: if T::HAS_RESUME { + Some(Adapter::<T>::resume_callback) + } else { + None + }, + config_aneg: if T::HAS_CONFIG_ANEG { + Some(Adapter::<T>::config_aneg_callback) + } else { + None + }, + read_status: if T::HAS_READ_STATUS { + Some(Adapter::<T>::read_status_callback) + } else { + None + }, + read_mmd: if T::HAS_READ_MMD { + Some(Adapter::<T>::read_mmd_callback) + } else { + None + }, + write_mmd: if T::HAS_WRITE_MMD { + Some(Adapter::<T>::write_mmd_callback) + } else { + None + }, + link_change_notify: if T::HAS_LINK_CHANGE_NOTIFY { + Some(Adapter::<T>::link_change_notify_callback) + } else { + None + }, + // SAFETY: The rest is zeroed out to initialize `struct phy_driver`, + // sets `Option<&F>` to be `None`. + ..unsafe { core::mem::MaybeUninit::<bindings::phy_driver>::zeroed().assume_init() } + })) +} + +/// Driver implementation for a particular PHY type. +/// +/// This trait is used to create a [`DriverVTable`]. +#[vtable] +pub trait Driver { + /// Defines certain other features this PHY supports. + /// It is a combination of the flags in the [`flags`] module. + const FLAGS: u32 = 0; + + /// The friendly name of this PHY type. + const NAME: &'static CStr; + + /// This driver only works for PHYs with IDs which match this field. + /// The default id and mask are zero. + const PHY_DEVICE_ID: DeviceId = DeviceId::new_with_custom_mask(0, 0); + + /// Issues a PHY software reset. + fn soft_reset(_dev: &mut Device) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Sets up device-specific structures during discovery. + fn probe(_dev: &mut Device) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Probes the hardware to determine what abilities it has. + fn get_features(_dev: &mut Device) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Returns true if this is a suitable driver for the given phydev. + /// If not implemented, matching is based on [`Driver::PHY_DEVICE_ID`]. 
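+ ///
+ /// A hedged sketch of a custom matcher; the ID and mask values below are
+ /// illustrative only, not a real vendor assignment.
+ ///
+ /// ```ignore
+ /// fn match_phy_device(dev: &Device) -> bool {
+ /// // Match on everything but the revision bits of the PHY ID.
+ /// dev.phy_id() & 0xfffffff0 == 0x001cc940
+ /// }
+ /// ```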
+ fn match_phy_device(_dev: &Device) -> bool {
+ false
+ }
+
+ /// Configures the advertisement and resets auto-negotiation
+ /// if auto-negotiation is enabled.
+ fn config_aneg(_dev: &mut Device) -> Result {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Determines the negotiated speed and duplex.
+ fn read_status(_dev: &mut Device) -> Result<u16> {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Suspends the hardware, saving state if needed.
+ fn suspend(_dev: &mut Device) -> Result {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Resumes the hardware, restoring state if needed.
+ fn resume(_dev: &mut Device) -> Result {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Overrides the default MMD read function for reading an MMD register.
+ fn read_mmd(_dev: &mut Device, _devnum: u8, _regnum: u16) -> Result<u16> {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Overrides the default MMD write function for writing an MMD register.
+ fn write_mmd(_dev: &mut Device, _devnum: u8, _regnum: u16, _val: u16) -> Result {
+ build_error!(VTABLE_DEFAULT_ERROR)
+ }
+
+ /// Callback for notification of link change.
+ fn link_change_notify(_dev: &mut Device) {}
+}
+
+/// Registration structure for PHY drivers.
+///
+/// Registers [`DriverVTable`] instances with the kernel. They will be unregistered when dropped.
+///
+/// # Invariants
+///
+/// The `drivers` slice is registered with the kernel via `phy_drivers_register`.
+pub struct Registration {
+ drivers: Pin<&'static mut [DriverVTable]>,
+}
+
+// SAFETY: The only action allowed on a `Registration` instance is dropping it, which is safe to do
+// from any thread because `phy_drivers_unregister` can be called from any thread context.
+unsafe impl Send for Registration {}
+
+impl Registration {
+ /// Registers a PHY driver.
+ pub fn register(
+ module: &'static crate::ThisModule,
+ drivers: Pin<&'static mut [DriverVTable]>,
+ ) -> Result<Self> {
+ if drivers.is_empty() {
+ return Err(code::EINVAL);
+ }
+ // SAFETY: The type invariants of [`DriverVTable`] ensure that all elements of
+ // the `drivers` slice are initialized properly. `drivers` will not be moved.
+ // So it's just an FFI call.
+ to_result(unsafe {
+ bindings::phy_drivers_register(drivers[0].0.get(), drivers.len().try_into()?, module.0)
+ })?;
+ // INVARIANT: The `drivers` slice is successfully registered with the kernel via `phy_drivers_register`.
+ Ok(Registration { drivers })
+ }
+}
+
+impl Drop for Registration {
+ fn drop(&mut self) {
+ // SAFETY: The type invariants guarantee that `self.drivers` is valid.
+ // So it's just an FFI call.
+ unsafe {
+ bindings::phy_drivers_unregister(self.drivers[0].0.get(), self.drivers.len() as i32)
+ };
+ }
+}
+
+/// An identifier for PHY devices on an MDIO/MII bus.
+///
+/// Represents the kernel's `struct mdio_device_id`. This is used to find an appropriate
+/// PHY driver.
+#[repr(transparent)]
+#[derive(Clone, Copy)]
+pub struct DeviceId(bindings::mdio_device_id);
+
+impl DeviceId {
+ /// Creates a new instance with the exact match mask.
+ pub const fn new_with_exact_mask(id: u32) -> Self {
+ Self(bindings::mdio_device_id {
+ phy_id: id,
+ phy_id_mask: DeviceMask::Exact.as_int(),
+ })
+ }
+
+ /// Creates a new instance with the model match mask.
+ pub const fn new_with_model_mask(id: u32) -> Self {
+ Self(bindings::mdio_device_id {
+ phy_id: id,
+ phy_id_mask: DeviceMask::Model.as_int(),
+ })
+ }
+
+ /// Creates a new instance with the vendor match mask.
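+ ///
+ /// For example (the ID below is illustrative only); with the vendor mask,
+ /// only the OUI portion of the PHY ID participates in matching:
+ ///
+ /// ```ignore
+ /// const ID: DeviceId = DeviceId::new_with_vendor_mask(0x001cc800);
+ /// ```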
+ pub const fn new_with_vendor_mask(id: u32) -> Self { + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Vendor.as_int(), + }) + } + + /// Creates a new instance with a custom match mask. + pub const fn new_with_custom_mask(id: u32, mask: u32) -> Self { + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Custom(mask).as_int(), + }) + } + + /// Creates a new instance from [`Driver`]. + pub const fn new_with_driver<T: Driver>() -> Self { + T::PHY_DEVICE_ID + } + + /// Get the MDIO device's PHY ID. + pub const fn id(&self) -> u32 { + self.0.phy_id + } + + /// Get the MDIO device's match mask. + pub const fn mask_as_int(&self) -> u32 { + self.0.phy_id_mask + } + + // macro use only + #[doc(hidden)] + pub const fn mdio_device_id(&self) -> bindings::mdio_device_id { + self.0 + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct mdio_device_id` +// and does not add additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::mdio_device_id; +} + +enum DeviceMask { + Exact, + Model, + Vendor, + Custom(u32), +} + +impl DeviceMask { + const MASK_EXACT: u32 = !0; + const MASK_MODEL: u32 = !0 << 4; + const MASK_VENDOR: u32 = !0 << 10; + + const fn as_int(&self) -> u32 { + match self { + DeviceMask::Exact => Self::MASK_EXACT, + DeviceMask::Model => Self::MASK_MODEL, + DeviceMask::Vendor => Self::MASK_VENDOR, + DeviceMask::Custom(mask) => *mask, + } + } +} + +/// Declares a kernel module for PHYs drivers. +/// +/// This creates a static array of kernel's `struct phy_driver` and registers it. +/// This also corresponds to the kernel's `MODULE_DEVICE_TABLE` macro, which embeds the information +/// for module loading into the module binary file. Every driver needs an entry in `device_table`. +/// +/// # Examples +/// +/// ``` +/// # mod module_phy_driver_sample { +/// use kernel::c_str; +/// use kernel::net::phy::{self, DeviceId}; +/// use kernel::prelude::*; +/// +/// kernel::module_phy_driver! { +/// drivers: [PhySample], +/// device_table: [ +/// DeviceId::new_with_driver::<PhySample>() +/// ], +/// name: "rust_sample_phy", +/// authors: ["Rust for Linux Contributors"], +/// description: "Rust sample PHYs driver", +/// license: "GPL", +/// } +/// +/// struct PhySample; +/// +/// #[vtable] +/// impl phy::Driver for PhySample { +/// const NAME: &'static CStr = c_str!("PhySample"); +/// const PHY_DEVICE_ID: phy::DeviceId = phy::DeviceId::new_with_exact_mask(0x00000001); +/// } +/// # } +/// ``` +/// +/// This expands to the following code: +/// +/// ```ignore +/// use kernel::c_str; +/// use kernel::net::phy::{self, DeviceId}; +/// use kernel::prelude::*; +/// +/// struct Module { +/// _reg: ::kernel::net::phy::Registration, +/// } +/// +/// module! 
{ +/// type: Module, +/// name: "rust_sample_phy", +/// authors: ["Rust for Linux Contributors"], +/// description: "Rust sample PHYs driver", +/// license: "GPL", +/// } +/// +/// struct PhySample; +/// +/// #[vtable] +/// impl phy::Driver for PhySample { +/// const NAME: &'static CStr = c_str!("PhySample"); +/// const PHY_DEVICE_ID: phy::DeviceId = phy::DeviceId::new_with_exact_mask(0x00000001); +/// } +/// +/// const _: () = { +/// static mut DRIVERS: [::kernel::net::phy::DriverVTable; 1] = +/// [::kernel::net::phy::create_phy_driver::<PhySample>()]; +/// +/// impl ::kernel::Module for Module { +/// fn init(module: &'static ::kernel::ThisModule) -> Result<Self> { +/// let drivers = unsafe { &mut DRIVERS }; +/// let mut reg = ::kernel::net::phy::Registration::register( +/// module, +/// ::core::pin::Pin::static_mut(drivers), +/// )?; +/// Ok(Module { _reg: reg }) +/// } +/// } +/// }; +/// +/// const N: usize = 1; +/// +/// const TABLE: ::kernel::device_id::IdArray<::kernel::net::phy::DeviceId, (), N> = +/// ::kernel::device_id::IdArray::new_without_index([ +/// ::kernel::net::phy::DeviceId( +/// ::kernel::bindings::mdio_device_id { +/// phy_id: 0x00000001, +/// phy_id_mask: 0xffffffff, +/// }), +/// ]); +/// +/// ::kernel::module_device_table!("mdio", phydev, TABLE); +/// ``` +#[macro_export] +macro_rules! module_phy_driver { + (@replace_expr $_t:tt $sub:expr) => {$sub}; + + (@count_devices $($x:expr),*) => { + 0usize $(+ $crate::module_phy_driver!(@replace_expr $x 1usize))* + }; + + (@device_table [$($dev:expr),+]) => { + const N: usize = $crate::module_phy_driver!(@count_devices $($dev),+); + + const TABLE: $crate::device_id::IdArray<$crate::net::phy::DeviceId, (), N> = + $crate::device_id::IdArray::new_without_index([ $(($dev,())),+, ]); + + $crate::module_device_table!("mdio", phydev, TABLE); + }; + + (drivers: [$($driver:ident),+ $(,)?], device_table: [$($dev:expr),+ $(,)?], $($f:tt)*) => { + struct Module { + _reg: $crate::net::phy::Registration, + } + + $crate::prelude::module! { + type: Module, + $($f)* + } + + const _: () = { + static mut DRIVERS: [$crate::net::phy::DriverVTable; + $crate::module_phy_driver!(@count_devices $($driver),+)] = + [$($crate::net::phy::create_phy_driver::<$driver>()),+]; + + impl $crate::Module for Module { + fn init(module: &'static $crate::ThisModule) -> Result<Self> { + // SAFETY: The anonymous constant guarantees that nobody else can access + // the `DRIVERS` static. The array is used only in the C side. + let drivers = unsafe { &mut DRIVERS }; + let mut reg = $crate::net::phy::Registration::register( + module, + ::core::pin::Pin::static_mut(drivers), + )?; + Ok(Module { _reg: reg }) + } + } + }; + + $crate::module_phy_driver!(@device_table [$($dev),+]); + } +} diff --git a/rust/kernel/net/phy/reg.rs b/rust/kernel/net/phy/reg.rs new file mode 100644 index 000000000000..a7db0064cb7d --- /dev/null +++ b/rust/kernel/net/phy/reg.rs @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 FUJITA Tomonori <fujita.tomonori@gmail.com> + +//! PHY register interfaces. +//! +//! This module provides support for accessing PHY registers in the +//! Ethernet management interface clauses 22 and 45 register namespaces, as +//! defined in IEEE 802.3. + +use super::Device; +use crate::build_assert; +use crate::error::*; +use crate::uapi; + +mod private { + /// Marker that a trait cannot be implemented outside of this crate + pub trait Sealed {} +} + +/// Accesses PHY registers. 
+/// +/// This trait is used to implement the unified interface to access +/// C22 and C45 PHY registers. +/// +/// # Examples +/// +/// ```ignore +/// fn link_change_notify(dev: &mut Device) { +/// // read C22 BMCR register +/// dev.read(C22::BMCR); +/// // read C45 PMA/PMD control 1 register +/// dev.read(C45::new(Mmd::PMAPMD, 0)); +/// +/// // Checks the link status as reported by registers in the C22 namespace +/// // and updates current link state. +/// dev.genphy_read_status::<phy::C22>(); +/// // Checks the link status as reported by registers in the C45 namespace +/// // and updates current link state. +/// dev.genphy_read_status::<phy::C45>(); +/// } +/// ``` +pub trait Register: private::Sealed { + /// Reads a PHY register. + fn read(&self, dev: &mut Device) -> Result<u16>; + + /// Writes a PHY register. + fn write(&self, dev: &mut Device, val: u16) -> Result; + + /// Checks the link status and updates current link state. + fn read_status(dev: &mut Device) -> Result<u16>; +} + +/// A single MDIO clause 22 register address (5 bits). +#[derive(Copy, Clone, Debug)] +pub struct C22(u8); + +impl C22 { + /// Basic mode control. + pub const BMCR: Self = C22(0x00); + /// Basic mode status. + pub const BMSR: Self = C22(0x01); + /// PHY identifier 1. + pub const PHYSID1: Self = C22(0x02); + /// PHY identifier 2. + pub const PHYSID2: Self = C22(0x03); + /// Auto-negotiation advertisement. + pub const ADVERTISE: Self = C22(0x04); + /// Auto-negotiation link partner base page ability. + pub const LPA: Self = C22(0x05); + /// Auto-negotiation expansion. + pub const EXPANSION: Self = C22(0x06); + /// Auto-negotiation next page transmit. + pub const NEXT_PAGE_TRANSMIT: Self = C22(0x07); + /// Auto-negotiation link partner received next page. + pub const LP_RECEIVED_NEXT_PAGE: Self = C22(0x08); + /// Master-slave control. + pub const MASTER_SLAVE_CONTROL: Self = C22(0x09); + /// Master-slave status. + pub const MASTER_SLAVE_STATUS: Self = C22(0x0a); + /// PSE Control. + pub const PSE_CONTROL: Self = C22(0x0b); + /// PSE Status. + pub const PSE_STATUS: Self = C22(0x0c); + /// MMD Register control. + pub const MMD_CONTROL: Self = C22(0x0d); + /// MMD Register address data. + pub const MMD_DATA: Self = C22(0x0e); + /// Extended status. + pub const EXTENDED_STATUS: Self = C22(0x0f); + + /// Creates a new instance of `C22` with a vendor specific register. + pub const fn vendor_specific<const N: u8>() -> Self { + build_assert!( + N > 0x0f && N < 0x20, + "Vendor-specific register address must be between 16 and 31" + ); + C22(N) + } +} + +impl private::Sealed for C22 {} + +impl Register for C22 { + fn read(&self, dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call, open code of `phy_read()` with a valid `phy_device` pointer + // `phydev`. + let ret = unsafe { + bindings::mdiobus_read((*phydev).mdio.bus, (*phydev).mdio.addr, self.0.into()) + }; + to_result(ret)?; + Ok(ret as u16) + } + + fn write(&self, dev: &mut Device, val: u16) -> Result { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call, open code of `phy_write()` with a valid `phy_device` pointer + // `phydev`. 
+ to_result(unsafe { + bindings::mdiobus_write((*phydev).mdio.bus, (*phydev).mdio.addr, self.0.into(), val) + }) + } + + fn read_status(dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + let ret = unsafe { bindings::genphy_read_status(phydev) }; + to_result(ret)?; + Ok(ret as u16) + } +} + +/// A single MDIO clause 45 register device and address. +#[derive(Copy, Clone, Debug)] +pub struct Mmd(u8); + +impl Mmd { + /// Physical Medium Attachment/Dependent. + pub const PMAPMD: Self = Mmd(uapi::MDIO_MMD_PMAPMD as u8); + /// WAN interface sublayer. + pub const WIS: Self = Mmd(uapi::MDIO_MMD_WIS as u8); + /// Physical coding sublayer. + pub const PCS: Self = Mmd(uapi::MDIO_MMD_PCS as u8); + /// PHY Extender sublayer. + pub const PHYXS: Self = Mmd(uapi::MDIO_MMD_PHYXS as u8); + /// DTE Extender sublayer. + pub const DTEXS: Self = Mmd(uapi::MDIO_MMD_DTEXS as u8); + /// Transmission convergence. + pub const TC: Self = Mmd(uapi::MDIO_MMD_TC as u8); + /// Auto negotiation. + pub const AN: Self = Mmd(uapi::MDIO_MMD_AN as u8); + /// Separated PMA (1). + pub const SEPARATED_PMA1: Self = Mmd(8); + /// Separated PMA (2). + pub const SEPARATED_PMA2: Self = Mmd(9); + /// Separated PMA (3). + pub const SEPARATED_PMA3: Self = Mmd(10); + /// Separated PMA (4). + pub const SEPARATED_PMA4: Self = Mmd(11); + /// OFDM PMA/PMD. + pub const OFDM_PMAPMD: Self = Mmd(12); + /// Power unit. + pub const POWER_UNIT: Self = Mmd(13); + /// Clause 22 extension. + pub const C22_EXT: Self = Mmd(uapi::MDIO_MMD_C22EXT as u8); + /// Vendor specific 1. + pub const VEND1: Self = Mmd(uapi::MDIO_MMD_VEND1 as u8); + /// Vendor specific 2. + pub const VEND2: Self = Mmd(uapi::MDIO_MMD_VEND2 as u8); +} + +/// A single MDIO clause 45 register device and address. +/// +/// Clause 45 uses a 5-bit device address to access a specific MMD within +/// a port, then a 16-bit register address to access a location within +/// that device. `C45` represents this by storing a [`Mmd`] and +/// a register number. +pub struct C45 { + devad: Mmd, + regnum: u16, +} + +impl C45 { + /// Creates a new instance of `C45`. + pub fn new(devad: Mmd, regnum: u16) -> Self { + Self { devad, regnum } + } +} + +impl private::Sealed for C45 {} + +impl Register for C45 { + fn read(&self, dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call. + let ret = + unsafe { bindings::phy_read_mmd(phydev, self.devad.0.into(), self.regnum.into()) }; + to_result(ret)?; + Ok(ret as u16) + } + + fn write(&self, dev: &mut Device, val: u16) -> Result { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call. + to_result(unsafe { + bindings::phy_write_mmd(phydev, self.devad.0.into(), self.regnum.into(), val) + }) + } + + fn read_status(dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + let ret = unsafe { bindings::genphy_c45_read_status(phydev) }; + to_result(ret)?; + Ok(ret as u16) + } +} diff --git a/rust/kernel/num.rs b/rust/kernel/num.rs new file mode 100644 index 000000000000..8532b511384c --- /dev/null +++ b/rust/kernel/num.rs @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! 
Additional numerical features for the kernel.
+
+use core::ops;
+
+pub mod bounded;
+pub use bounded::*;
+
+/// Designates unsigned primitive types.
+pub enum Unsigned {}
+
+/// Designates signed primitive types.
+pub enum Signed {}
+
+/// Describes core properties of integer types.
+pub trait Integer:
+ Sized
+ + Copy
+ + Clone
+ + PartialEq
+ + Eq
+ + PartialOrd
+ + Ord
+ + ops::Add<Output = Self>
+ + ops::AddAssign
+ + ops::Sub<Output = Self>
+ + ops::SubAssign
+ + ops::Mul<Output = Self>
+ + ops::MulAssign
+ + ops::Div<Output = Self>
+ + ops::DivAssign
+ + ops::Rem<Output = Self>
+ + ops::RemAssign
+ + ops::BitAnd<Output = Self>
+ + ops::BitAndAssign
+ + ops::BitOr<Output = Self>
+ + ops::BitOrAssign
+ + ops::BitXor<Output = Self>
+ + ops::BitXorAssign
+ + ops::Shl<u32, Output = Self>
+ + ops::ShlAssign<u32>
+ + ops::Shr<u32, Output = Self>
+ + ops::ShrAssign<u32>
+ + ops::Not
+{
+ /// Whether this type is [`Signed`] or [`Unsigned`].
+ type Signedness;
+
+ /// Number of bits used for value representation.
+ const BITS: u32;
+}
+
+macro_rules! impl_integer {
+ ($($type:ty: $signedness:ty), *) => {
+ $(
+ impl Integer for $type {
+ type Signedness = $signedness;
+
+ const BITS: u32 = <$type>::BITS;
+ }
+ )*
+ };
+}
+
+impl_integer!(
+ u8: Unsigned,
+ u16: Unsigned,
+ u32: Unsigned,
+ u64: Unsigned,
+ u128: Unsigned,
+ usize: Unsigned,
+ i8: Signed,
+ i16: Signed,
+ i32: Signed,
+ i64: Signed,
+ i128: Signed,
+ isize: Signed
+);
diff --git a/rust/kernel/num/bounded.rs b/rust/kernel/num/bounded.rs
new file mode 100644
index 000000000000..f870080af8ac
--- /dev/null
+++ b/rust/kernel/num/bounded.rs
@@ -0,0 +1,1058 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Implementation of [`Bounded`], a wrapper around integer types limiting the number of bits
+//! usable for value representation.
+
+use core::{
+ cmp,
+ fmt,
+ ops::{
+ self,
+ Deref, //
+ }, //
+};
+
+use kernel::{
+ num::Integer,
+ prelude::*, //
+};
+
+/// Evaluates to `true` if `$value` can be represented using at most `$n` bits in a `$type`.
+///
+/// `$value` must be of type `$type`, or the result will be incorrect.
+///
+/// Can be used in const context.
+macro_rules! fits_within {
+ ($value:expr, $type:ty, $n:expr) => {{
+ let shift: u32 = <$type>::BITS - $n;
+
+ // `$value` fits within `$n` bits if shifting it left by the number of unused bits, then
+ // right by the same number, doesn't change it.
+ //
+ // This method has the benefit of working for both unsigned and signed values.
+ ($value << shift) >> shift == $value
+ }};
+}
+
+/// Returns `true` if `value` can be represented with at most `num_bits` bits in a `T`.
+#[inline(always)]
+fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool {
+ fits_within!(value, T, num_bits)
+}
+
+/// An integer value that requires only the `N` least significant bits of the wrapped type to be
+/// encoded.
+///
+/// This limits the number of usable bits in the wrapped integer type, and thus the stored value to
+/// a narrower range, which provides guarantees that can be useful when working with e.g.
+/// bitfields.
+///
+/// # Invariants
+///
+/// - `N` is greater than `0`.
+/// - `N` is less than or equal to `T::BITS`.
+/// - Stored values can be represented with at most `N` bits.
+///
+/// # Examples
+///
+/// The preferred way to create values is through constants and the [`Bounded::new`] family of
+/// constructors, as they trigger a build error if the type invariants cannot be upheld.
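+///
+/// For intuition on the representability check: with `T = u8` and `N = 4`, the value `22`
+/// (`0b0001_0110`) shifted left by the `4` unused bits becomes `0b0110_0000`, and shifting
+/// back right yields `6`, not `22`, so `22` does not fit within 4 bits, while any value in
+/// `0..=15` survives the round trip.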
+/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // An unsigned 8-bit integer, of which only the 4 LSBs are used. +/// // The value `15` is statically validated to fit that constraint at build time. +/// let v = Bounded::<u8, 4>::new::<15>(); +/// assert_eq!(v.get(), 15); +/// +/// // Same using signed values. +/// let v = Bounded::<i8, 4>::new::<-8>(); +/// assert_eq!(v.get(), -8); +/// +/// // This doesn't build: a `u8` is smaller than the requested 9 bits. +/// // let _ = Bounded::<u8, 9>::new::<10>(); +/// +/// // This also doesn't build: the requested value doesn't fit within 4 signed bits. +/// // let _ = Bounded::<i8, 4>::new::<8>(); +/// ``` +/// +/// Values can also be validated at runtime with [`Bounded::try_new`]. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // This succeeds because `15` can be represented with 4 unsigned bits. +/// assert!(Bounded::<u8, 4>::try_new(15).is_some()); +/// +/// // This fails because `16` cannot be represented with 4 unsigned bits. +/// assert!(Bounded::<u8, 4>::try_new(16).is_none()); +/// ``` +/// +/// Non-constant expressions can be validated at build-time thanks to compiler optimizations. This +/// should be used with caution, on simple expressions only. +/// +/// ``` +/// use kernel::num::Bounded; +/// # fn some_number() -> u32 { 0xffffffff } +/// +/// // Here the compiler can infer from the mask that the type invariants are not violated, even +/// // though the value returned by `some_number` is not statically known. +/// let v = Bounded::<u32, 4>::from_expr(some_number() & 0xf); +/// ``` +/// +/// Comparison and arithmetic operations are supported on [`Bounded`]s with a compatible backing +/// type, regardless of their number of valid bits. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v1 = Bounded::<u32, 8>::new::<4>(); +/// let v2 = Bounded::<u32, 4>::new::<15>(); +/// +/// assert!(v1 != v2); +/// assert!(v1 < v2); +/// assert_eq!(v1 + v2, 19); +/// assert_eq!(v2 % v1, 3); +/// ``` +/// +/// These operations are also supported between a [`Bounded`] and its backing type. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v = Bounded::<u8, 4>::new::<15>(); +/// +/// assert!(v == 15); +/// assert!(v > 12); +/// assert_eq!(v + 5, 20); +/// assert_eq!(v / 3, 5); +/// ``` +/// +/// A change of backing types is possible using [`Bounded::cast`], and the number of valid bits can +/// be extended or reduced with [`Bounded::extend`] and [`Bounded::try_shrink`]. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v = Bounded::<u32, 12>::new::<127>(); +/// +/// // Changes backing type from `u32` to `u16`. +/// let _: Bounded<u16, 12> = v.cast(); +/// +/// // This does not build, as `u8` is smaller than 12 bits. +/// // let _: Bounded<u8, 12> = v.cast(); +/// +/// // We can safely extend the number of bits... +/// let _ = v.extend::<15>(); +/// +/// // ... to the limits of the backing type. This doesn't build as a `u32` cannot contain 33 bits. +/// // let _ = v.extend::<33>(); +/// +/// // Reducing the number of bits is validated at runtime. This works because `127` can be +/// // represented with 8 bits. +/// assert!(v.try_shrink::<8>().is_some()); +/// +/// // ... but not with 6, so this fails. +/// assert!(v.try_shrink::<6>().is_none()); +/// ``` +/// +/// Infallible conversions from a primitive integer to a large-enough [`Bounded`] are supported. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // This unsigned `Bounded` has 8 bits, so it can represent any `u8`. 
+/// let v = Bounded::<u32, 8>::from(128u8);
+/// assert_eq!(v.get(), 128);
+///
+/// // This signed `Bounded` has 8 bits, so it can represent any `i8`.
+/// let v = Bounded::<i32, 8>::from(-128i8);
+/// assert_eq!(v.get(), -128);
+///
+/// // This doesn't build, as this 6-bit `Bounded` does not have enough capacity to represent a
+/// // `u8` (regardless of the passed value).
+/// // let _ = Bounded::<u32, 6>::from(10u8);
+///
+/// // Booleans can be converted into single-bit `Bounded`s.
+///
+/// let v = Bounded::<u64, 1>::from(false);
+/// assert_eq!(v.get(), 0);
+///
+/// let v = Bounded::<u64, 1>::from(true);
+/// assert_eq!(v.get(), 1);
+/// ```
+///
+/// Infallible conversions from a [`Bounded`] to a primitive integer are also supported, and
+/// depend on the number of bits used for value representation, not on the backing type.
+///
+/// ```
+/// use kernel::num::Bounded;
+///
+/// // Even though its backing type is `u32`, this `Bounded` only uses 6 bits and thus can safely
+/// // be converted to a `u8`.
+/// let v = Bounded::<u32, 6>::new::<63>();
+/// assert_eq!(u8::from(v), 63);
+///
+/// // Same using signed values.
+/// let v = Bounded::<i32, 8>::new::<-128>();
+/// assert_eq!(i8::from(v), -128);
+///
+/// // This however does not build, as 10 bits won't fit into a `u8` (regardless of the actually
+/// // contained value).
+/// let _v = Bounded::<u32, 10>::new::<10>();
+/// // assert_eq!(u8::from(_v), 10);
+///
+/// // Single-bit `Bounded`s can be converted into a boolean.
+/// let v = Bounded::<u8, 1>::new::<1>();
+/// assert_eq!(bool::from(v), true);
+///
+/// let v = Bounded::<u8, 1>::new::<0>();
+/// assert_eq!(bool::from(v), false);
+/// ```
+///
+/// Fallible conversions from any primitive integer to any [`Bounded`] are also supported using the
+/// [`TryIntoBounded`] trait.
+///
+/// ```
+/// use kernel::num::{Bounded, TryIntoBounded};
+///
+/// // Succeeds because `128` fits into 8 bits.
+/// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded();
+/// assert_eq!(v.as_deref().copied(), Some(128));
+///
+/// // Fails because `128` doesn't fit into 6 bits.
+/// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded();
+/// assert_eq!(v, None);
+/// ```
+#[repr(transparent)]
+#[derive(Clone, Copy, Debug, Default, Hash)]
+pub struct Bounded<T: Integer, const N: u32>(T);
+
+/// Validating the value as a const expression cannot be done as a regular method, as the
+/// arithmetic operations we rely on to check the bounds are not const. Thus, implement
+/// [`Bounded::new`] using a macro.
+macro_rules! impl_const_new {
+ ($($type:ty)*) => {
+ $(
+ impl<const N: u32> Bounded<$type, N> {
+ /// Creates a [`Bounded`] for the constant `VALUE`.
+ ///
+ /// Fails at build time if `VALUE` cannot be represented with `N` bits.
+ ///
+ /// This method should be preferred to [`Self::from_expr`] whenever possible.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use kernel::num::Bounded;
+ ///
+ #[doc = ::core::concat!(
+ "let v = Bounded::<",
+ ::core::stringify!($type),
+ ", 4>::new::<7>();")]
+ /// assert_eq!(v.get(), 7);
+ /// ```
+ pub const fn new<const VALUE: $type>() -> Self {
+ // Statically assert that `VALUE` fits within the set number of bits.
+ const {
+ assert!(fits_within!(VALUE, $type, N));
+ }
+
+ // INVARIANT: `fits_within` confirmed that `VALUE` can be represented within
+ // `N` bits.
+ Self::__new(VALUE) + } + } + )* + }; +} + +impl_const_new!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +impl<T, const N: u32> Bounded<T, N> +where + T: Integer, +{ + /// Private constructor enforcing the type invariants. + /// + /// All instances of [`Bounded`] must be created through this method as it enforces most of the + /// type invariants. + /// + /// The caller remains responsible for checking, either statically or dynamically, that `value` + /// can be represented as a `T` using at most `N` bits. + const fn __new(value: T) -> Self { + // Enforce the type invariants. + const { + // `N` cannot be zero. + assert!(N != 0); + // The backing type is at least as large as `N` bits. + assert!(N <= T::BITS); + } + + Self(value) + } + + /// Attempts to turn `value` into a `Bounded` using `N` bits. + /// + /// Returns [`None`] if `value` doesn't fit within `N` bits. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u8, 1>::try_new(1); + /// assert_eq!(v.as_deref().copied(), Some(1)); + /// + /// let v = Bounded::<i8, 4>::try_new(-2); + /// assert_eq!(v.as_deref().copied(), Some(-2)); + /// + /// // `0x1ff` doesn't fit into 8 unsigned bits. + /// let v = Bounded::<u32, 8>::try_new(0x1ff); + /// assert_eq!(v, None); + /// + /// // The range of values representable with 4 bits is `[-8..=7]`. The following tests these + /// // limits. + /// let v = Bounded::<i8, 4>::try_new(-8); + /// assert_eq!(v.map(Bounded::get), Some(-8)); + /// let v = Bounded::<i8, 4>::try_new(-9); + /// assert_eq!(v, None); + /// let v = Bounded::<i8, 4>::try_new(7); + /// assert_eq!(v.map(Bounded::get), Some(7)); + /// let v = Bounded::<i8, 4>::try_new(8); + /// assert_eq!(v, None); + /// ``` + pub fn try_new(value: T) -> Option<Self> { + fits_within(value, N).then(|| { + // INVARIANT: `fits_within` confirmed that `value` can be represented within `N` bits. + Self::__new(value) + }) + } + + /// Checks that `expr` is valid for this type at compile-time and build a new value. + /// + /// This relies on [`build_assert!`] and guaranteed optimization to perform validation at + /// compile-time. If `expr` cannot be proved to be within the requested bounds at compile-time, + /// use the fallible [`Self::try_new`] instead. + /// + /// Limit this to simple, easily provable expressions, and prefer one of the [`Self::new`] + /// constructors whenever possible as they statically validate the value instead of relying on + /// compiler optimizations. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// # fn some_number() -> u32 { 0xffffffff } + /// + /// // Some undefined number. + /// let v: u32 = some_number(); + /// + /// // Triggers a build error as `v` cannot be asserted to fit within 4 bits... + /// // let _ = Bounded::<u32, 4>::from_expr(v); + /// + /// // ... but this works as the compiler can assert the range from the mask. + /// let _ = Bounded::<u32, 4>::from_expr(v & 0xf); + /// + /// // These expressions are simple enough to be proven correct, but since they are static the + /// // `new` constructor should be preferred. + /// assert_eq!(Bounded::<u8, 1>::from_expr(1).get(), 1); + /// assert_eq!(Bounded::<u16, 8>::from_expr(0xff).get(), 0xff); + /// ``` + #[inline(always)] + pub fn from_expr(expr: T) -> Self { + crate::build_assert!( + fits_within(expr, N), + "Requested value larger than maximal representable value." + ); + + // INVARIANT: `fits_within` confirmed that `expr` can be represented within `N` bits. 
+ Self::__new(expr) + } + + /// Returns the wrapped value as the backing type. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 4>::new::<7>(); + /// assert_eq!(v.get(), 7u32); + /// ``` + pub fn get(self) -> T { + *self.deref() + } + + /// Increases the number of bits usable for `self`. + /// + /// This operation cannot fail. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 4>::new::<7>(); + /// let larger_v = v.extend::<12>(); + /// // The contained values are equal even though `larger_v` has a bigger capacity. + /// assert_eq!(larger_v, v); + /// ``` + pub const fn extend<const M: u32>(self) -> Bounded<T, M> { + const { + assert!( + M >= N, + "Requested number of bits is less than the current representation." + ); + } + + // INVARIANT: The value did fit within `N` bits, so it will all the more fit within + // the larger `M` bits. + Bounded::__new(self.0) + } + + /// Attempts to shrink the number of bits usable for `self`. + /// + /// Returns [`None`] if the value of `self` cannot be represented within `M` bits. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 12>::new::<7>(); + /// + /// // `7` can be represented using 3 unsigned bits... + /// let smaller_v = v.try_shrink::<3>(); + /// assert_eq!(smaller_v.as_deref().copied(), Some(7)); + /// + /// // ... but doesn't fit within `2` bits. + /// assert_eq!(v.try_shrink::<2>(), None); + /// ``` + pub fn try_shrink<const M: u32>(self) -> Option<Bounded<T, M>> { + Bounded::<T, M>::try_new(self.get()) + } + + /// Casts `self` into a [`Bounded`] backed by a different storage type, but using the same + /// number of valid bits. + /// + /// Both `T` and `U` must be of same signedness, and `U` must be at least as large as + /// `N` bits, or a build error will occur. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 12>::new::<127>(); + /// + /// let u16_v: Bounded<u16, 12> = v.cast(); + /// assert_eq!(u16_v.get(), 127); + /// + /// // This won't build: a `u8` is smaller than the required 12 bits. + /// // let _: Bounded<u8, 12> = v.cast(); + /// ``` + pub fn cast<U>(self) -> Bounded<U, N> + where + U: TryFrom<T> + Integer, + T: Integer, + U: Integer<Signedness = T::Signedness>, + { + // SAFETY: The converted value is represented using `N` bits, `U` can contain `N` bits, and + // `U` and `T` have the same sign, hence this conversion cannot fail. + let value = unsafe { U::try_from(self.get()).unwrap_unchecked() }; + + // INVARIANT: Although the backing type has changed, the value is still represented within + // `N` bits, and with the same signedness. + Bounded::__new(value) + } +} + +impl<T, const N: u32> Deref for Bounded<T, N> +where + T: Integer, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + // Enforce the invariant to inform the compiler of the bounds of the value. + if !fits_within(self.0, N) { + // SAFETY: Per the `Bounded` invariants, `fits_within` can never return `false` on the + // value of a valid instance. + unsafe { core::hint::unreachable_unchecked() } + } + + &self.0 + } +} + +/// Trait similar to [`TryInto`] but for [`Bounded`], to avoid conflicting implementations. +/// +/// # Examples +/// +/// ``` +/// use kernel::num::{Bounded, TryIntoBounded}; +/// +/// // Succeeds because `128` fits into 8 bits. 
+/// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded();
+/// assert_eq!(v.as_deref().copied(), Some(128));
+///
+/// // Fails because `128` doesn't fit into 6 bits.
+/// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded();
+/// assert_eq!(v, None);
+/// ```
+pub trait TryIntoBounded<T: Integer, const N: u32> {
+ /// Attempts to convert `self` into a [`Bounded`] using `N` bits.
+ ///
+ /// Returns [`None`] if `self` does not fit into the target type.
+ fn try_into_bounded(self) -> Option<Bounded<T, N>>;
+}
+
+/// Any integer value can be fallibly converted into a [`Bounded`] of any size.
+impl<T, U, const N: u32> TryIntoBounded<T, N> for U
+where
+ T: Integer,
+ U: TryInto<T>,
+{
+ fn try_into_bounded(self) -> Option<Bounded<T, N>> {
+ self.try_into().ok().and_then(Bounded::try_new)
+ }
+}
+
+// Comparisons between `Bounded`s.
+
+impl<T, U, const N: u32, const M: u32> PartialEq<Bounded<U, M>> for Bounded<T, N>
+where
+ T: Integer,
+ U: Integer,
+ T: PartialEq<U>,
+{
+ fn eq(&self, other: &Bounded<U, M>) -> bool {
+ self.get() == other.get()
+ }
+}
+
+impl<T, const N: u32> Eq for Bounded<T, N> where T: Integer {}
+
+impl<T, U, const N: u32, const M: u32> PartialOrd<Bounded<U, M>> for Bounded<T, N>
+where
+ T: Integer,
+ U: Integer,
+ T: PartialOrd<U>,
+{
+ fn partial_cmp(&self, other: &Bounded<U, M>) -> Option<cmp::Ordering> {
+ self.get().partial_cmp(&other.get())
+ }
+}
+
+impl<T, const N: u32> Ord for Bounded<T, N>
+where
+ T: Integer,
+ T: Ord,
+{
+ fn cmp(&self, other: &Self) -> cmp::Ordering {
+ self.get().cmp(&other.get())
+ }
+}
+
+// Comparisons between a `Bounded` and its backing type.
+
+impl<T, const N: u32> PartialEq<T> for Bounded<T, N>
+where
+ T: Integer,
+ T: PartialEq,
+{
+ fn eq(&self, other: &T) -> bool {
+ self.get() == *other
+ }
+}
+
+impl<T, const N: u32> PartialOrd<T> for Bounded<T, N>
+where
+ T: Integer,
+ T: PartialOrd,
+{
+ fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
+ self.get().partial_cmp(other)
+ }
+}
+
+// Implementations of `core::ops` for two `Bounded` with the same backing type.
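+//
+// Note that these operations return the plain backing type `T` rather than a new
+// `Bounded`: the result may need more than `N` bits (e.g. adding two `N`-bit values
+// can require `N + 1` bits), so no bound can be promised without a runtime check.
+// `Bounded::try_new` can be used to re-validate the result if needed.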
+ +impl<T, const N: u32, const M: u32> ops::Add<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Add<Output = T>, +{ + type Output = T; + + fn add(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() + rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitAnd<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitAnd<Output = T>, +{ + type Output = T; + + fn bitand(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() & rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitOr<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitOr<Output = T>, +{ + type Output = T; + + fn bitor(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() | rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitXor<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitXor<Output = T>, +{ + type Output = T; + + fn bitxor(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() ^ rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Div<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Div<Output = T>, +{ + type Output = T; + + fn div(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() / rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Mul<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Mul<Output = T>, +{ + type Output = T; + + fn mul(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() * rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Rem<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Rem<Output = T>, +{ + type Output = T; + + fn rem(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() % rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Sub<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Sub<Output = T>, +{ + type Output = T; + + fn sub(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() - rhs.get() + } +} + +// Implementations of `core::ops` between a `Bounded` and its backing type. 
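+//
+// As above, these return the plain backing type `T`, since e.g. the sum of a
+// `Bounded<u8, 4>` and a `u8` can exceed 4 bits.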
+ +impl<T, const N: u32> ops::Add<T> for Bounded<T, N> +where + T: Integer, + T: ops::Add<Output = T>, +{ + type Output = T; + + fn add(self, rhs: T) -> Self::Output { + self.get() + rhs + } +} + +impl<T, const N: u32> ops::BitAnd<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitAnd<Output = T>, +{ + type Output = T; + + fn bitand(self, rhs: T) -> Self::Output { + self.get() & rhs + } +} + +impl<T, const N: u32> ops::BitOr<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitOr<Output = T>, +{ + type Output = T; + + fn bitor(self, rhs: T) -> Self::Output { + self.get() | rhs + } +} + +impl<T, const N: u32> ops::BitXor<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitXor<Output = T>, +{ + type Output = T; + + fn bitxor(self, rhs: T) -> Self::Output { + self.get() ^ rhs + } +} + +impl<T, const N: u32> ops::Div<T> for Bounded<T, N> +where + T: Integer, + T: ops::Div<Output = T>, +{ + type Output = T; + + fn div(self, rhs: T) -> Self::Output { + self.get() / rhs + } +} + +impl<T, const N: u32> ops::Mul<T> for Bounded<T, N> +where + T: Integer, + T: ops::Mul<Output = T>, +{ + type Output = T; + + fn mul(self, rhs: T) -> Self::Output { + self.get() * rhs + } +} + +impl<T, const N: u32> ops::Neg for Bounded<T, N> +where + T: Integer, + T: ops::Neg<Output = T>, +{ + type Output = T; + + fn neg(self) -> Self::Output { + -self.get() + } +} + +impl<T, const N: u32> ops::Not for Bounded<T, N> +where + T: Integer, + T: ops::Not<Output = T>, +{ + type Output = T; + + fn not(self) -> Self::Output { + !self.get() + } +} + +impl<T, const N: u32> ops::Rem<T> for Bounded<T, N> +where + T: Integer, + T: ops::Rem<Output = T>, +{ + type Output = T; + + fn rem(self, rhs: T) -> Self::Output { + self.get() % rhs + } +} + +impl<T, const N: u32> ops::Sub<T> for Bounded<T, N> +where + T: Integer, + T: ops::Sub<Output = T>, +{ + type Output = T; + + fn sub(self, rhs: T) -> Self::Output { + self.get() - rhs + } +} + +// Proxy implementations of `core::fmt`. + +impl<T, const N: u32> fmt::Display for Bounded<T, N> +where + T: Integer, + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::Binary for Bounded<T, N> +where + T: Integer, + T: fmt::Binary, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::LowerExp for Bounded<T, N> +where + T: Integer, + T: fmt::LowerExp, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::LowerHex for Bounded<T, N> +where + T: Integer, + T: fmt::LowerHex, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::Octal for Bounded<T, N> +where + T: Integer, + T: fmt::Octal, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::UpperExp for Bounded<T, N> +where + T: Integer, + T: fmt::UpperExp, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::UpperHex for Bounded<T, N> +where + T: Integer, + T: fmt::UpperHex, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +/// Implements `$trait` for all [`Bounded`] types represented using `$num_bits`. +/// +/// This is used to declare size properties as traits that we can constrain against in impl blocks. +macro_rules! 
impl_size_rule { + ($trait:ty, $($num_bits:literal)*) => { + $( + impl<T> $trait for Bounded<T, $num_bits> where T: Integer {} + )* + }; +} + +/// Local trait expressing the fact that a given [`Bounded`] has at least `N` bits used for value +/// representation. +trait AtLeastXBits<const N: usize> {} + +/// Implementations for infallibly converting a primitive type into a [`Bounded`] that can contain +/// it. +/// +/// Put into their own module for readability, and to avoid cluttering the rustdoc of the parent +/// module. +mod atleast_impls { + use super::*; + + // Number of bits at least as large as 64. + impl_size_rule!(AtLeastXBits<64>, 64); + + // Anything 64 bits or more is also larger than 32. + impl<T> AtLeastXBits<32> for T where T: AtLeastXBits<64> {} + // Other numbers of bits at least as large as 32. + impl_size_rule!(AtLeastXBits<32>, + 32 33 34 35 36 37 38 39 + 40 41 42 43 44 45 46 47 + 48 49 50 51 52 53 54 55 + 56 57 58 59 60 61 62 63 + ); + + // Anything 32 bits or more is also larger than 16. + impl<T> AtLeastXBits<16> for T where T: AtLeastXBits<32> {} + // Other numbers of bits at least as large as 16. + impl_size_rule!(AtLeastXBits<16>, + 16 17 18 19 20 21 22 23 + 24 25 26 27 28 29 30 31 + ); + + // Anything 16 bits or more is also larger than 8. + impl<T> AtLeastXBits<8> for T where T: AtLeastXBits<16> {} + // Other numbers of bits at least as large as 8. + impl_size_rule!(AtLeastXBits<8>, 8 9 10 11 12 13 14 15); +} + +/// Generates `From` implementations from a primitive type into a [`Bounded`] with +/// enough bits to store any value of that type. +/// +/// Note: The only reason for having this macro is that if we pass `$type` as a generic +/// parameter, we cannot use it in the const context of [`AtLeastXBits`]'s generic parameter. This +/// can be fixed once the `generic_const_exprs` feature is usable, and this macro replaced by a +/// regular `impl` block. +macro_rules! impl_from_primitive { + ($($type:ty)*) => { + $( + #[doc = ::core::concat!( + "Conversion from a [`", + ::core::stringify!($type), + "`] into a [`Bounded`] of same signedness with enough bits to store it.")] + impl<T, const N: u32> From<$type> for Bounded<T, N> + where + $type: Integer, + T: Integer<Signedness = <$type as Integer>::Signedness> + From<$type>, + Self: AtLeastXBits<{ <$type as Integer>::BITS as usize }>, + { + fn from(value: $type) -> Self { + // INVARIANT: The trait bound on `Self` guarantees that `N` bits is + // enough to hold any value of the source type. + Self::__new(T::from(value)) + } + } + )* + } +} + +impl_from_primitive!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +/// Local trait expressing the fact that a given [`Bounded`] fits into a primitive type of `N` bits, +/// provided they have the same signedness. +trait FitsInXBits<const N: usize> {} + +/// Implementations for infallibly converting a [`Bounded`] into a primitive type that can contain +/// it. +/// +/// Put into their own module for readability, and to avoid cluttering the rustdoc of the parent +/// module. +mod fits_impls { + use super::*; + + // Number of bits that fit into a 8-bits primitive. + impl_size_rule!(FitsInXBits<8>, 1 2 3 4 5 6 7 8); + + // Anything that fits into 8 bits also fits into 16. + impl<T> FitsInXBits<16> for T where T: FitsInXBits<8> {} + // Other number of bits that fit into a 16-bits primitive. + impl_size_rule!(FitsInXBits<16>, 9 10 11 12 13 14 15 16); + + // Anything that fits into 16 bits also fits into 32. 
+ impl<T> FitsInXBits<32> for T where T: FitsInXBits<16> {} + // Other number of bits that fit into a 32-bits primitive. + impl_size_rule!(FitsInXBits<32>, + 17 18 19 20 21 22 23 24 + 25 26 27 28 29 30 31 32 + ); + + // Anything that fits into 32 bits also fits into 64. + impl<T> FitsInXBits<64> for T where T: FitsInXBits<32> {} + // Other number of bits that fit into a 64-bits primitive. + impl_size_rule!(FitsInXBits<64>, + 33 34 35 36 37 38 39 40 + 41 42 43 44 45 46 47 48 + 49 50 51 52 53 54 55 56 + 57 58 59 60 61 62 63 64 + ); +} + +/// Generates [`From`] implementations from a [`Bounded`] into a primitive type that is +/// guaranteed to contain it. +/// +/// Note: The only reason for having this macro is that if we pass `$type` as a generic +/// parameter, we cannot use it in the const context of `AtLeastXBits`'s generic parameter. This +/// can be fixed once the `generic_const_exprs` feature is usable, and this macro replaced by a +/// regular `impl` block. +macro_rules! impl_into_primitive { + ($($type:ty)*) => { + $( + #[doc = ::core::concat!( + "Conversion from a [`Bounded`] with no more bits than a [`", + ::core::stringify!($type), + "`] and of same signedness into [`", + ::core::stringify!($type), + "`]")] + impl<T, const N: u32> From<Bounded<T, N>> for $type + where + $type: Integer + TryFrom<T>, + T: Integer<Signedness = <$type as Integer>::Signedness>, + Bounded<T, N>: FitsInXBits<{ <$type as Integer>::BITS as usize }>, + { + fn from(value: Bounded<T, N>) -> $type { + // SAFETY: The trait bound on `Bounded` ensures that any value it holds (which + // is constrained to `N` bits) can fit into the destination type, so this + // conversion cannot fail. + unsafe { <$type>::try_from(value.get()).unwrap_unchecked() } + } + } + )* + } +} + +impl_into_primitive!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +// Single-bit `Bounded`s can be converted from/to a boolean. + +impl<T> From<Bounded<T, 1>> for bool +where + T: Integer + Zeroable, +{ + fn from(value: Bounded<T, 1>) -> Self { + value.get() != Zeroable::zeroed() + } +} + +impl<T, const N: u32> From<bool> for Bounded<T, N> +where + T: Integer + From<bool>, +{ + fn from(value: bool) -> Self { + // INVARIANT: A boolean can be represented using a single bit, and thus fits within any + // integer type for any `N` > 0. + Self::__new(T::from(value)) + } +} diff --git a/rust/kernel/of.rs b/rust/kernel/of.rs new file mode 100644 index 000000000000..58b20c367f99 --- /dev/null +++ b/rust/kernel/of.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Device Tree / Open Firmware abstractions. + +use crate::{ + bindings, + device_id::{RawDeviceId, RawDeviceIdIndex}, + prelude::*, +}; + +/// IdTable type for OF drivers. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// An open firmware device id. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::of_device_id); + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` and +// does not add additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::of_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::of_device_id, data); + + fn index(&self) -> usize { + self.0.data as usize + } +} + +impl DeviceId { + /// Create a new device id from an OF 'compatible' string. 
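+ ///
+ /// A minimal sketch (the compatible string is illustrative):
+ ///
+ /// ```ignore
+ /// use kernel::{c_str, of};
+ ///
+ /// const ID: of::DeviceId = of::DeviceId::new(c_str!("vendor,sample-device"));
+ /// ```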
+ pub const fn new(compatible: &'static CStr) -> Self { + let src = compatible.to_bytes_with_nul(); + // Replace with `bindings::of_device_id::default()` once stabilized for `const`. + // SAFETY: FFI type is valid to be zero-initialized. + let mut of: bindings::of_device_id = unsafe { core::mem::zeroed() }; + + // TODO: Use `copy_from_slice` once stabilized for `const`. + let mut i = 0; + while i < src.len() { + of.compatible[i] = src[i]; + i += 1; + } + + Self(of) + } +} + +/// Create an OF `IdTable` with an "alias" for modpost. +#[macro_export] +macro_rules! of_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::of::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("of", $module_table_name, $table_name); + }; +} diff --git a/rust/kernel/opp.rs b/rust/kernel/opp.rs new file mode 100644 index 000000000000..a760fac28765 --- /dev/null +++ b/rust/kernel/opp.rs @@ -0,0 +1,1145 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Operating performance points. +//! +//! This module provides rust abstractions for interacting with the OPP subsystem. +//! +//! C header: [`include/linux/pm_opp.h`](srctree/include/linux/pm_opp.h) +//! +//! Reference: <https://docs.kernel.org/power/opp.html> + +use crate::{ + clk::Hertz, + cpumask::{Cpumask, CpumaskVar}, + device::Device, + error::{code::*, from_err_ptr, from_result, to_result, Result, VTABLE_DEFAULT_ERROR}, + ffi::{c_char, c_ulong}, + prelude::*, + str::CString, + sync::aref::{ARef, AlwaysRefCounted}, + types::Opaque, +}; + +#[cfg(CONFIG_CPU_FREQ)] +/// Frequency table implementation. +mod freq { + use super::*; + use crate::cpufreq; + use core::ops::Deref; + + /// OPP frequency table. + /// + /// A [`cpufreq::Table`] created from [`Table`]. + pub struct FreqTable { + dev: ARef<Device>, + ptr: *mut bindings::cpufreq_frequency_table, + } + + impl FreqTable { + /// Creates a new instance of [`FreqTable`] from [`Table`]. + pub(crate) fn new(table: &Table) -> Result<Self> { + let mut ptr: *mut bindings::cpufreq_frequency_table = ptr::null_mut(); + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_init_cpufreq_table(table.dev.as_raw(), &mut ptr) + })?; + + Ok(Self { + dev: table.dev.clone(), + ptr, + }) + } + + /// Returns a reference to the underlying [`cpufreq::Table`]. + #[inline] + fn table(&self) -> &cpufreq::Table { + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { cpufreq::Table::from_raw(self.ptr) } + } + } + + impl Deref for FreqTable { + type Target = cpufreq::Table; + + #[inline] + fn deref(&self) -> &Self::Target { + self.table() + } + } + + impl Drop for FreqTable { + fn drop(&mut self) { + // SAFETY: The pointer was created via `dev_pm_opp_init_cpufreq_table`, and is only + // freed here. + unsafe { + bindings::dev_pm_opp_free_cpufreq_table(self.dev.as_raw(), &mut self.as_raw()) + }; + } + } +} + +#[cfg(CONFIG_CPU_FREQ)] +pub use freq::FreqTable; + +use core::{marker::PhantomData, ptr}; + +use macros::vtable; + +/// Creates a null-terminated slice of pointers to [`CString`]s. +fn to_c_str_array(names: &[CString]) -> Result<KVec<*const c_char>> { + // Allocated a null-terminated vector of pointers. 
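+    // The extra capacity slot below is for the terminating null pointer pushed at the end.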
+ let mut list = KVec::with_capacity(names.len() + 1, GFP_KERNEL)?; + + for name in names.iter() { + list.push(name.as_char_ptr(), GFP_KERNEL)?; + } + + list.push(ptr::null(), GFP_KERNEL)?; + Ok(list) +} + +/// The voltage unit. +/// +/// Represents voltage in microvolts, wrapping a [`c_ulong`] value. +/// +/// # Examples +/// +/// ``` +/// use kernel::opp::MicroVolt; +/// +/// let raw = 90500; +/// let volt = MicroVolt(raw); +/// +/// assert_eq!(usize::from(volt), raw); +/// assert_eq!(volt, MicroVolt(raw)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct MicroVolt(pub c_ulong); + +impl From<MicroVolt> for c_ulong { + #[inline] + fn from(volt: MicroVolt) -> Self { + volt.0 + } +} + +/// The power unit. +/// +/// Represents power in microwatts, wrapping a [`c_ulong`] value. +/// +/// # Examples +/// +/// ``` +/// use kernel::opp::MicroWatt; +/// +/// let raw = 1000000; +/// let power = MicroWatt(raw); +/// +/// assert_eq!(usize::from(power), raw); +/// assert_eq!(power, MicroWatt(raw)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct MicroWatt(pub c_ulong); + +impl From<MicroWatt> for c_ulong { + #[inline] + fn from(power: MicroWatt) -> Self { + power.0 + } +} + +/// Handle for a dynamically created [`OPP`]. +/// +/// The associated [`OPP`] is automatically removed when the [`Token`] is dropped. +/// +/// # Examples +/// +/// The following example demonstrates how to create an [`OPP`] dynamically. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Data, MicroVolt, Token}; +/// use kernel::sync::aref::ARef; +/// +/// fn create_opp(dev: &ARef<Device>, freq: Hertz, volt: MicroVolt, level: u32) -> Result<Token> { +/// let data = Data::new(freq, volt, level, false); +/// +/// // OPP is removed once token goes out of scope. +/// data.add_opp(dev) +/// } +/// ``` +pub struct Token { + dev: ARef<Device>, + freq: Hertz, +} + +impl Token { + /// Dynamically adds an [`OPP`] and returns a [`Token`] that removes it on drop. + fn new(dev: &ARef<Device>, mut data: Data) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_add_dynamic(dev.as_raw(), &mut data.0) })?; + Ok(Self { + dev: dev.clone(), + freq: data.freq(), + }) + } +} + +impl Drop for Token { + fn drop(&mut self) { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_remove(self.dev.as_raw(), self.freq.into()) }; + } +} + +/// OPP data. +/// +/// Rust abstraction for the C `struct dev_pm_opp_data`, used to define operating performance +/// points (OPPs) dynamically. +/// +/// # Examples +/// +/// The following example demonstrates how to create an [`OPP`] with [`Data`]. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Data, MicroVolt, Token}; +/// use kernel::sync::aref::ARef; +/// +/// fn create_opp(dev: &ARef<Device>, freq: Hertz, volt: MicroVolt, level: u32) -> Result<Token> { +/// let data = Data::new(freq, volt, level, false); +/// +/// // OPP is removed once token goes out of scope. +/// data.add_opp(dev) +/// } +/// ``` +#[repr(transparent)] +pub struct Data(bindings::dev_pm_opp_data); + +impl Data { + /// Creates a new instance of [`Data`]. + /// + /// This can be used to define a dynamic OPP to be added to a device. 
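+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch with made-up values (1 GHz at 900 mV, level 1, non-turbo):
+    ///
+    /// ```
+    /// use kernel::clk::Hertz;
+    /// use kernel::opp::{Data, MicroVolt};
+    ///
+    /// let _data = Data::new(Hertz(1_000_000_000), MicroVolt(900_000), 1, false);
+    /// ```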
+ pub fn new(freq: Hertz, volt: MicroVolt, level: u32, turbo: bool) -> Self { + Self(bindings::dev_pm_opp_data { + turbo, + freq: freq.into(), + u_volt: volt.into(), + level, + }) + } + + /// Adds an [`OPP`] dynamically. + /// + /// Returns a [`Token`] that ensures the OPP is automatically removed + /// when it goes out of scope. + #[inline] + pub fn add_opp(self, dev: &ARef<Device>) -> Result<Token> { + Token::new(dev, self) + } + + /// Returns the frequency associated with this OPP data. + #[inline] + fn freq(&self) -> Hertz { + Hertz(self.0.freq) + } +} + +/// [`OPP`] search options. +/// +/// # Examples +/// +/// Defines how to search for an [`OPP`] in a [`Table`] relative to a frequency. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::error::Result; +/// use kernel::opp::{OPP, SearchType, Table}; +/// use kernel::sync::aref::ARef; +/// +/// fn find_opp(table: &Table, freq: Hertz) -> Result<ARef<OPP>> { +/// let opp = table.opp_from_freq(freq, Some(true), None, SearchType::Exact)?; +/// +/// pr_info!("OPP frequency is: {:?}\n", opp.freq(None)); +/// pr_info!("OPP voltage is: {:?}\n", opp.voltage()); +/// pr_info!("OPP level is: {}\n", opp.level()); +/// pr_info!("OPP power is: {:?}\n", opp.power()); +/// +/// Ok(opp) +/// } +/// ``` +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SearchType { + /// Match the exact frequency. + Exact, + /// Find the highest frequency less than or equal to the given value. + Floor, + /// Find the lowest frequency greater than or equal to the given value. + Ceil, +} + +/// OPP configuration callbacks. +/// +/// Implement this trait to customize OPP clock and regulator setup for your device. +#[vtable] +pub trait ConfigOps { + /// This is typically used to scale clocks when transitioning between OPPs. + #[inline] + fn config_clks(_dev: &Device, _table: &Table, _opp: &OPP, _scaling_down: bool) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// This provides access to the old and new OPPs, allowing for safe regulator adjustments. + #[inline] + fn config_regulators( + _dev: &Device, + _opp_old: &OPP, + _opp_new: &OPP, + _data: *mut *mut bindings::regulator, + _count: u32, + ) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } +} + +/// OPP configuration token. +/// +/// Returned by the OPP core when configuration is applied to a [`Device`]. The associated +/// configuration is automatically cleared when the token is dropped. +pub struct ConfigToken(i32); + +impl Drop for ConfigToken { + fn drop(&mut self) { + // SAFETY: This is the same token value returned by the C code via `dev_pm_opp_set_config`. + unsafe { bindings::dev_pm_opp_clear_config(self.0) }; + } +} + +/// OPP configurations. +/// +/// Rust abstraction for the C `struct dev_pm_opp_config`. +/// +/// # Examples +/// +/// The following example demonstrates how to set OPP property-name configuration for a [`Device`]. +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Config, ConfigOps, ConfigToken}; +/// use kernel::str::CString; +/// use kernel::sync::aref::ARef; +/// use kernel::macros::vtable; +/// +/// #[derive(Default)] +/// struct Driver; +/// +/// #[vtable] +/// impl ConfigOps for Driver {} +/// +/// fn configure(dev: &ARef<Device>) -> Result<ConfigToken> { +/// let name = CString::try_from_fmt(fmt!("slow"))?; +/// +/// // The OPP configuration is cleared once the [`ConfigToken`] goes out of scope. +/// Config::<Driver>::new() +/// .set_prop_name(name)? 
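+///         // Other setters (e.g. `set_clk_names()` or `set_supported_hw()`) could be
+///         // chained here as well before applying the configuration.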
+/// .set(dev) +/// } +/// ``` +#[derive(Default)] +pub struct Config<T: ConfigOps> +where + T: Default, +{ + clk_names: Option<KVec<CString>>, + prop_name: Option<CString>, + regulator_names: Option<KVec<CString>>, + supported_hw: Option<KVec<u32>>, + + // Tuple containing (required device, index) + required_dev: Option<(ARef<Device>, u32)>, + _data: PhantomData<T>, +} + +impl<T: ConfigOps + Default> Config<T> { + /// Creates a new instance of [`Config`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Initializes clock names. + pub fn set_clk_names(mut self, names: KVec<CString>) -> Result<Self> { + if self.clk_names.is_some() { + return Err(EBUSY); + } + + if names.is_empty() { + return Err(EINVAL); + } + + self.clk_names = Some(names); + Ok(self) + } + + /// Initializes property name. + pub fn set_prop_name(mut self, name: CString) -> Result<Self> { + if self.prop_name.is_some() { + return Err(EBUSY); + } + + self.prop_name = Some(name); + Ok(self) + } + + /// Initializes regulator names. + pub fn set_regulator_names(mut self, names: KVec<CString>) -> Result<Self> { + if self.regulator_names.is_some() { + return Err(EBUSY); + } + + if names.is_empty() { + return Err(EINVAL); + } + + self.regulator_names = Some(names); + + Ok(self) + } + + /// Initializes required devices. + pub fn set_required_dev(mut self, dev: ARef<Device>, index: u32) -> Result<Self> { + if self.required_dev.is_some() { + return Err(EBUSY); + } + + self.required_dev = Some((dev, index)); + Ok(self) + } + + /// Initializes supported hardware. + pub fn set_supported_hw(mut self, hw: KVec<u32>) -> Result<Self> { + if self.supported_hw.is_some() { + return Err(EBUSY); + } + + if hw.is_empty() { + return Err(EINVAL); + } + + self.supported_hw = Some(hw); + Ok(self) + } + + /// Sets the configuration with the OPP core. + /// + /// The returned [`ConfigToken`] will remove the configuration when dropped. + pub fn set(self, dev: &Device) -> Result<ConfigToken> { + let clk_names = self.clk_names.as_deref().map(to_c_str_array).transpose()?; + let regulator_names = self + .regulator_names + .as_deref() + .map(to_c_str_array) + .transpose()?; + + let set_config = || { + let clk_names = clk_names.as_ref().map_or(ptr::null(), |c| c.as_ptr()); + let regulator_names = regulator_names.as_ref().map_or(ptr::null(), |c| c.as_ptr()); + + let prop_name = self + .prop_name + .as_ref() + .map_or(ptr::null(), |p| p.as_char_ptr()); + + let (supported_hw, supported_hw_count) = self + .supported_hw + .as_ref() + .map_or((ptr::null(), 0), |hw| (hw.as_ptr(), hw.len() as u32)); + + let (required_dev, required_dev_index) = self + .required_dev + .as_ref() + .map_or((ptr::null_mut(), 0), |(dev, idx)| (dev.as_raw(), *idx)); + + let mut config = bindings::dev_pm_opp_config { + clk_names, + config_clks: if T::HAS_CONFIG_CLKS { + Some(Self::config_clks) + } else { + None + }, + prop_name, + regulator_names, + config_regulators: if T::HAS_CONFIG_REGULATORS { + Some(Self::config_regulators) + } else { + None + }, + supported_hw, + supported_hw_count, + + required_dev, + required_dev_index, + }; + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The OPP core guarantees not to access fields of [`Config`] after this + // call and so we don't need to save a copy of them for future use. 
+ let ret = unsafe { bindings::dev_pm_opp_set_config(dev.as_raw(), &mut config) }; + + to_result(ret).map(|()| ConfigToken(ret)) + }; + + // Ensure the closure does not accidentally drop owned data; if violated, the compiler + // produces E0525 with e.g.: + // + // ``` + // closure is `FnOnce` because it moves the variable `clk_names` out of its environment + // ``` + let _: &dyn Fn() -> _ = &set_config; + + set_config() + } + + /// Config's clk callback. + /// + /// SAFETY: Called from C. Inputs must be valid pointers. + extern "C" fn config_clks( + dev: *mut bindings::device, + opp_table: *mut bindings::opp_table, + opp: *mut bindings::dev_pm_opp, + _data: *mut c_void, + scaling_down: bool, + ) -> c_int { + from_result(|| { + // SAFETY: 'dev' is guaranteed by the C code to be valid. + let dev = unsafe { Device::get_device(dev) }; + T::config_clks( + &dev, + // SAFETY: 'opp_table' is guaranteed by the C code to be valid. + &unsafe { Table::from_raw_table(opp_table, &dev) }, + // SAFETY: 'opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(opp)? }, + scaling_down, + ) + .map(|()| 0) + }) + } + + /// Config's regulator callback. + /// + /// SAFETY: Called from C. Inputs must be valid pointers. + extern "C" fn config_regulators( + dev: *mut bindings::device, + old_opp: *mut bindings::dev_pm_opp, + new_opp: *mut bindings::dev_pm_opp, + regulators: *mut *mut bindings::regulator, + count: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: 'dev' is guaranteed by the C code to be valid. + let dev = unsafe { Device::get_device(dev) }; + T::config_regulators( + &dev, + // SAFETY: 'old_opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(old_opp)? }, + // SAFETY: 'new_opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(new_opp)? }, + regulators, + count, + ) + .map(|()| 0) + }) + } +} + +/// A reference-counted OPP table. +/// +/// Rust abstraction for the C `struct opp_table`. +/// +/// # Invariants +/// +/// The pointer stored in `Self` is non-null and valid for the lifetime of the [`Table`]. +/// +/// Instances of this type are reference-counted. +/// +/// # Examples +/// +/// The following example demonstrates how to get OPP [`Table`] for a [`Cpumask`] and set its +/// frequency. +/// +/// ``` +/// # #![cfg(CONFIG_OF)] +/// use kernel::clk::Hertz; +/// use kernel::cpumask::Cpumask; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::Table; +/// use kernel::sync::aref::ARef; +/// +/// fn get_table(dev: &ARef<Device>, mask: &mut Cpumask, freq: Hertz) -> Result<Table> { +/// let mut opp_table = Table::from_of_cpumask(dev, mask)?; +/// +/// if opp_table.opp_count()? == 0 { +/// return Err(EINVAL); +/// } +/// +/// pr_info!("Max transition latency is: {} ns\n", opp_table.max_transition_latency_ns()); +/// pr_info!("Suspend frequency is: {:?}\n", opp_table.suspend_freq()); +/// +/// opp_table.set_rate(freq)?; +/// Ok(opp_table) +/// } +/// ``` +pub struct Table { + ptr: *mut bindings::opp_table, + dev: ARef<Device>, + #[allow(dead_code)] + em: bool, + #[allow(dead_code)] + of: bool, + cpus: Option<CpumaskVar>, +} + +/// SAFETY: It is okay to send ownership of [`Table`] across thread boundaries. +unsafe impl Send for Table {} + +/// SAFETY: It is okay to access [`Table`] through shared references from other threads because +/// we're either accessing properties that don't change or that are properly synchronised by C code. 
+unsafe impl Sync for Table {} + +impl Table { + /// Creates a new reference-counted [`Table`] from a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid and non-null. + unsafe fn from_raw_table(ptr: *mut bindings::opp_table, dev: &ARef<Device>) -> Self { + // SAFETY: By the safety requirements, ptr is valid and its refcount will be incremented. + // + // INVARIANT: The reference-count is decremented when [`Table`] goes out of scope. + unsafe { bindings::dev_pm_opp_get_opp_table_ref(ptr) }; + + Self { + ptr, + dev: dev.clone(), + em: false, + of: false, + cpus: None, + } + } + + /// Creates a new reference-counted [`Table`] instance for a [`Device`]. + pub fn from_dev(dev: &Device) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + let ptr = from_err_ptr(unsafe { bindings::dev_pm_opp_get_opp_table(dev.as_raw()) })?; + + Ok(Self { + ptr, + dev: dev.into(), + em: false, + of: false, + cpus: None, + }) + } + + /// Creates a new reference-counted [`Table`] instance for a [`Device`] based on device tree + /// entries. + #[cfg(CONFIG_OF)] + pub fn from_of(dev: &ARef<Device>, index: i32) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + to_result(unsafe { bindings::dev_pm_opp_of_add_table_indexed(dev.as_raw(), index) })?; + + // Get the newly created [`Table`]. + let mut table = Self::from_dev(dev)?; + table.of = true; + + Ok(table) + } + + /// Remove device tree based [`Table`]. + #[cfg(CONFIG_OF)] + #[inline] + fn remove_of(&self) { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. We took the reference from [`from_of`] earlier, it is safe to drop the + // same now. + unsafe { bindings::dev_pm_opp_of_remove_table(self.dev.as_raw()) }; + } + + /// Creates a new reference-counted [`Table`] instance for a [`Cpumask`] based on device tree + /// entries. + #[cfg(CONFIG_OF)] + pub fn from_of_cpumask(dev: &Device, cpumask: &mut Cpumask) -> Result<Self> { + // SAFETY: The cpumask is valid and the returned pointer will be owned by the [`Table`] + // instance. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + to_result(unsafe { bindings::dev_pm_opp_of_cpumask_add_table(cpumask.as_raw()) })?; + + // Fetch the newly created table. + let mut table = Self::from_dev(dev)?; + table.cpus = Some(CpumaskVar::try_clone(cpumask)?); + + Ok(table) + } + + /// Remove device tree based [`Table`] for a [`Cpumask`]. + #[cfg(CONFIG_OF)] + #[inline] + fn remove_of_cpumask(&self, cpumask: &Cpumask) { + // SAFETY: The cpumask is valid and we took the reference from [`from_of_cpumask`] earlier, + // it is safe to drop the same now. + unsafe { bindings::dev_pm_opp_of_cpumask_remove_table(cpumask.as_raw()) }; + } + + /// Returns the number of [`OPP`]s in the [`Table`]. + pub fn opp_count(&self) -> Result<u32> { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. 
+ let ret = unsafe { bindings::dev_pm_opp_get_opp_count(self.dev.as_raw()) }; + + to_result(ret).map(|()| ret as u32) + } + + /// Returns max clock latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_clock_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_clock_latency(self.dev.as_raw()) } + } + + /// Returns max volt latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_volt_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_volt_latency(self.dev.as_raw()) } + } + + /// Returns max transition latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_transition_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_transition_latency(self.dev.as_raw()) } + } + + /// Returns the suspend [`OPP`]'s frequency. + #[inline] + pub fn suspend_freq(&self) -> Hertz { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + Hertz(unsafe { bindings::dev_pm_opp_get_suspend_opp_freq(self.dev.as_raw()) }) + } + + /// Synchronizes regulators used by the [`Table`]. + #[inline] + pub fn sync_regulators(&self) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_sync_regulators(self.dev.as_raw()) }) + } + + /// Gets sharing CPUs. + #[inline] + pub fn sharing_cpus(dev: &Device, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_get_sharing_cpus(dev.as_raw(), cpumask.as_raw()) }) + } + + /// Sets sharing CPUs. + pub fn set_sharing_cpus(&mut self, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_set_sharing_cpus(self.dev.as_raw(), cpumask.as_raw()) + })?; + + if let Some(mask) = self.cpus.as_mut() { + // Update the cpumask as this will be used while removing the table. + cpumask.copy(mask); + } + + Ok(()) + } + + /// Gets sharing CPUs from device tree. + #[cfg(CONFIG_OF)] + #[inline] + pub fn of_sharing_cpus(dev: &Device, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_of_get_sharing_cpus(dev.as_raw(), cpumask.as_raw()) + }) + } + + /// Updates the voltage value for an [`OPP`]. + #[inline] + pub fn adjust_voltage( + &self, + freq: Hertz, + volt: MicroVolt, + volt_min: MicroVolt, + volt_max: MicroVolt, + ) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_adjust_voltage( + self.dev.as_raw(), + freq.into(), + volt.into(), + volt_min.into(), + volt_max.into(), + ) + }) + } + + /// Creates [`FreqTable`] from [`Table`]. 
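+    ///
+    /// A sketch of exposing the table to cpufreq; only available with `CONFIG_CPU_FREQ`:
+    ///
+    /// ```
+    /// # #![cfg(CONFIG_CPU_FREQ)]
+    /// use kernel::error::Result;
+    /// use kernel::opp::{FreqTable, Table};
+    ///
+    /// fn freq_table(opp_table: &mut Table) -> Result<FreqTable> {
+    ///     // The returned `FreqTable` frees the C table again when dropped.
+    ///     opp_table.cpufreq_table()
+    /// }
+    /// ```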
+ #[cfg(CONFIG_CPU_FREQ)] + #[inline] + pub fn cpufreq_table(&mut self) -> Result<FreqTable> { + FreqTable::new(self) + } + + /// Configures device with [`OPP`] matching the frequency value. + #[inline] + pub fn set_rate(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_set_rate(self.dev.as_raw(), freq.into()) }) + } + + /// Configures device with [`OPP`]. + #[inline] + pub fn set_opp(&self, opp: &OPP) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_set_opp(self.dev.as_raw(), opp.as_raw()) }) + } + + /// Finds [`OPP`] based on frequency. + pub fn opp_from_freq( + &self, + freq: Hertz, + available: Option<bool>, + index: Option<u32>, + stype: SearchType, + ) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + let index = index.unwrap_or(0); + let mut rate = freq.into(); + + let ptr = from_err_ptr(match stype { + SearchType::Exact => { + if let Some(available) = available { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and + // its safety requirements. The returned pointer will be owned by the new + // [`OPP`] instance. + unsafe { + bindings::dev_pm_opp_find_freq_exact_indexed( + raw_dev, rate, index, available, + ) + } + } else { + return Err(EINVAL); + } + } + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_freq_ceil_indexed(raw_dev, &mut rate, index) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_freq_floor_indexed(raw_dev, &mut rate, index) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Finds [`OPP`] based on level. + pub fn opp_from_level(&self, mut level: u32, stype: SearchType) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + + let ptr = from_err_ptr(match stype { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Exact => unsafe { bindings::dev_pm_opp_find_level_exact(raw_dev, level) }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_level_ceil(raw_dev, &mut level) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_level_floor(raw_dev, &mut level) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Finds [`OPP`] based on bandwidth. + pub fn opp_from_bw(&self, mut bw: u32, index: i32, stype: SearchType) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + + let ptr = from_err_ptr(match stype { + // The OPP core doesn't support this yet. 
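+            // Unlike the frequency and level searches above, the C side only provides
+            // ceil/floor lookups for bandwidth.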
+ SearchType::Exact => return Err(EINVAL), + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_bw_ceil(raw_dev, &mut bw, index) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_bw_floor(raw_dev, &mut bw, index) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Enables the [`OPP`]. + #[inline] + pub fn enable_opp(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_enable(self.dev.as_raw(), freq.into()) }) + } + + /// Disables the [`OPP`]. + #[inline] + pub fn disable_opp(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_disable(self.dev.as_raw(), freq.into()) }) + } + + /// Registers with the Energy model. + #[cfg(CONFIG_OF)] + pub fn of_register_em(&mut self, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_of_register_em(self.dev.as_raw(), cpumask.as_raw()) + })?; + + self.em = true; + Ok(()) + } + + /// Unregisters with the Energy model. + #[cfg(all(CONFIG_OF, CONFIG_ENERGY_MODEL))] + #[inline] + fn of_unregister_em(&self) { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. We registered with the EM framework earlier, it is safe to unregister now. + unsafe { bindings::em_dev_unregister_perf_domain(self.dev.as_raw()) }; + } +} + +impl Drop for Table { + fn drop(&mut self) { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe + // to relinquish it now. + unsafe { bindings::dev_pm_opp_put_opp_table(self.ptr) }; + + #[cfg(CONFIG_OF)] + { + #[cfg(CONFIG_ENERGY_MODEL)] + if self.em { + self.of_unregister_em(); + } + + if self.of { + self.remove_of(); + } else if let Some(cpumask) = self.cpus.take() { + self.remove_of_cpumask(&cpumask); + } + } + } +} + +/// A reference-counted Operating performance point (OPP). +/// +/// Rust abstraction for the C `struct dev_pm_opp`. +/// +/// # Invariants +/// +/// The pointer stored in `Self` is non-null and valid for the lifetime of the [`OPP`]. +/// +/// Instances of this type are reference-counted. The reference count is incremented by the +/// `dev_pm_opp_get` function and decremented by `dev_pm_opp_put`. The Rust type `ARef<OPP>` +/// represents a pointer that owns a reference count on the [`OPP`]. +/// +/// A reference to the [`OPP`], &[`OPP`], isn't refcounted by the Rust code. +/// +/// # Examples +/// +/// The following example demonstrates how to get [`OPP`] corresponding to a frequency value and +/// configure the device with it. 
+/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::error::Result; +/// use kernel::opp::{SearchType, Table}; +/// +/// fn configure_opp(table: &Table, freq: Hertz) -> Result { +/// let opp = table.opp_from_freq(freq, Some(true), None, SearchType::Exact)?; +/// +/// if opp.freq(None) != freq { +/// return Err(EINVAL); +/// } +/// +/// table.set_opp(&opp) +/// } +/// ``` +#[repr(transparent)] +pub struct OPP(Opaque<bindings::dev_pm_opp>); + +/// SAFETY: It is okay to send the ownership of [`OPP`] across thread boundaries. +unsafe impl Send for OPP {} + +/// SAFETY: It is okay to access [`OPP`] through shared references from other threads because we're +/// either accessing properties that don't change or that are properly synchronised by C code. +unsafe impl Sync for OPP {} + +/// SAFETY: The type invariants guarantee that [`OPP`] is always refcounted. +unsafe impl AlwaysRefCounted for OPP { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::dev_pm_opp_get(self.0.get()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is nonzero. + unsafe { bindings::dev_pm_opp_put(obj.cast().as_ptr()) } + } +} + +impl OPP { + /// Creates an owned reference to a [`OPP`] from a valid pointer. + /// + /// The refcount is incremented by the C code and will be decremented by `dec_ref` when the + /// [`ARef`] object is dropped. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and the refcount of the [`OPP`] is incremented. + /// The caller must also ensure that it doesn't explicitly drop the refcount of the [`OPP`], as + /// the returned [`ARef`] object takes over the refcount increment on the underlying object and + /// the same will be dropped along with it. + pub unsafe fn from_raw_opp_owned(ptr: *mut bindings::dev_pm_opp) -> Result<ARef<Self>> { + let ptr = ptr::NonNull::new(ptr).ok_or(ENODEV)?; + + // SAFETY: The safety requirements guarantee the validity of the pointer. + // + // INVARIANT: The reference-count is decremented when [`OPP`] goes out of scope. + Ok(unsafe { ARef::from_raw(ptr.cast()) }) + } + + /// Creates a reference to a [`OPP`] from a valid pointer. + /// + /// The refcount is not updated by the Rust API unless the returned reference is converted to + /// an [`ARef`] object. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the duration of `'a`. + #[inline] + pub unsafe fn from_raw_opp<'a>(ptr: *mut bindings::dev_pm_opp) -> Result<&'a Self> { + // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the + // duration of 'a. The cast is okay because [`OPP`] is `repr(transparent)`. + Ok(unsafe { &*ptr.cast() }) + } + + #[inline] + fn as_raw(&self) -> *mut bindings::dev_pm_opp { + self.0.get() + } + + /// Returns the frequency of an [`OPP`]. + pub fn freq(&self, index: Option<u32>) -> Hertz { + let index = index.unwrap_or(0); + + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + Hertz(unsafe { bindings::dev_pm_opp_get_freq_indexed(self.as_raw(), index) }) + } + + /// Returns the voltage of an [`OPP`]. + #[inline] + pub fn voltage(&self) -> MicroVolt { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + MicroVolt(unsafe { bindings::dev_pm_opp_get_voltage(self.as_raw()) }) + } + + /// Returns the level of an [`OPP`]. 
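+    ///
+    /// For example, a small sketch that assumes an [`OPP`] reference is already at hand:
+    ///
+    /// ```
+    /// use kernel::opp::OPP;
+    ///
+    /// fn log_level(opp: &OPP) {
+    ///     pr_info!("OPP level: {}\n", opp.level());
+    /// }
+    /// ```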
+ #[inline] + pub fn level(&self) -> u32 { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_get_level(self.as_raw()) } + } + + /// Returns the power of an [`OPP`]. + #[inline] + pub fn power(&self) -> MicroWatt { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + MicroWatt(unsafe { bindings::dev_pm_opp_get_power(self.as_raw()) }) + } + + /// Returns the required pstate of an [`OPP`]. + #[inline] + pub fn required_pstate(&self, index: u32) -> u32 { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_get_required_pstate(self.as_raw(), index) } + } + + /// Returns true if the [`OPP`] is turbo. + #[inline] + pub fn is_turbo(&self) -> bool { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_is_turbo(self.as_raw()) } + } +} diff --git a/rust/kernel/page.rs b/rust/kernel/page.rs new file mode 100644 index 000000000000..432fc0297d4a --- /dev/null +++ b/rust/kernel/page.rs @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Kernel page allocation and management. + +use crate::{ + alloc::{AllocError, Flags}, + bindings, + error::code::*, + error::Result, + uaccess::UserSliceReader, +}; +use core::{ + marker::PhantomData, + mem::ManuallyDrop, + ops::Deref, + ptr::{self, NonNull}, +}; + +/// A bitwise shift for the page size. +pub const PAGE_SHIFT: usize = bindings::PAGE_SHIFT as usize; + +/// The number of bytes in a page. +pub const PAGE_SIZE: usize = bindings::PAGE_SIZE; + +/// A bitmask that gives the page containing a given address. +pub const PAGE_MASK: usize = !(PAGE_SIZE - 1); + +/// Round up the given number to the next multiple of [`PAGE_SIZE`]. +/// +/// It is incorrect to pass an address where the next multiple of [`PAGE_SIZE`] doesn't fit in a +/// [`usize`]. +pub const fn page_align(addr: usize) -> usize { + // Parentheses around `PAGE_SIZE - 1` to avoid triggering overflow sanitizers in the wrong + // cases. + (addr + (PAGE_SIZE - 1)) & PAGE_MASK +} + +/// Representation of a non-owning reference to a [`Page`]. +/// +/// This type provides a borrowed version of a [`Page`] that is owned by some other entity, e.g. a +/// [`Vmalloc`] allocation such as [`VBox`]. +/// +/// # Example +/// +/// ``` +/// # use kernel::{bindings, prelude::*}; +/// use kernel::page::{BorrowedPage, Page, PAGE_SIZE}; +/// # use core::{mem::MaybeUninit, ptr, ptr::NonNull }; +/// +/// fn borrow_page<'a>(vbox: &'a mut VBox<MaybeUninit<[u8; PAGE_SIZE]>>) -> BorrowedPage<'a> { +/// let ptr = ptr::from_ref(&**vbox); +/// +/// // SAFETY: `ptr` is a valid pointer to `Vmalloc` memory. +/// let page = unsafe { bindings::vmalloc_to_page(ptr.cast()) }; +/// +/// // SAFETY: `vmalloc_to_page` returns a valid pointer to a `struct page` for a valid +/// // pointer to `Vmalloc` memory. +/// let page = unsafe { NonNull::new_unchecked(page) }; +/// +/// // SAFETY: +/// // - `self.0` is a valid pointer to a `struct page`. +/// // - `self.0` is valid for the entire lifetime of `self`. +/// unsafe { BorrowedPage::from_raw(page) } +/// } +/// +/// let mut vbox = VBox::<[u8; PAGE_SIZE]>::new_uninit(GFP_KERNEL)?; +/// let page = borrow_page(&mut vbox); +/// +/// // SAFETY: There is no concurrent read or write to this page. +/// unsafe { page.fill_zero_raw(0, PAGE_SIZE)? 
}; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Invariants +/// +/// The borrowed underlying pointer to a `struct page` is valid for the entire lifetime `'a`. +/// +/// [`VBox`]: kernel::alloc::VBox +/// [`Vmalloc`]: kernel::alloc::allocator::Vmalloc +pub struct BorrowedPage<'a>(ManuallyDrop<Page>, PhantomData<&'a Page>); + +impl<'a> BorrowedPage<'a> { + /// Constructs a [`BorrowedPage`] from a raw pointer to a `struct page`. + /// + /// # Safety + /// + /// - `ptr` must point to a valid `bindings::page`. + /// - `ptr` must remain valid for the entire lifetime `'a`. + pub unsafe fn from_raw(ptr: NonNull<bindings::page>) -> Self { + let page = Page { page: ptr }; + + // INVARIANT: The safety requirements guarantee that `ptr` is valid for the entire lifetime + // `'a`. + Self(ManuallyDrop::new(page), PhantomData) + } +} + +impl<'a> Deref for BorrowedPage<'a> { + type Target = Page; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// Trait to be implemented by types which provide an [`Iterator`] implementation of +/// [`BorrowedPage`] items, such as [`VmallocPageIter`](kernel::alloc::allocator::VmallocPageIter). +pub trait AsPageIter { + /// The [`Iterator`] type, e.g. [`VmallocPageIter`](kernel::alloc::allocator::VmallocPageIter). + type Iter<'a>: Iterator<Item = BorrowedPage<'a>> + where + Self: 'a; + + /// Returns an [`Iterator`] of [`BorrowedPage`] items over all pages owned by `self`. + fn page_iter(&mut self) -> Self::Iter<'_>; +} + +/// A pointer to a page that owns the page allocation. +/// +/// # Invariants +/// +/// The pointer is valid, and has ownership over the page. +pub struct Page { + page: NonNull<bindings::page>, +} + +// SAFETY: Pages have no logic that relies on them staying on a given thread, so moving them across +// threads is safe. +unsafe impl Send for Page {} + +// SAFETY: Pages have no logic that relies on them not being accessed concurrently, so accessing +// them concurrently is safe. +unsafe impl Sync for Page {} + +impl Page { + /// Allocates a new page. + /// + /// # Examples + /// + /// Allocate memory for a page. + /// + /// ``` + /// use kernel::page::Page; + /// + /// let page = Page::alloc_page(GFP_KERNEL)?; + /// # Ok::<(), kernel::alloc::AllocError>(()) + /// ``` + /// + /// Allocate memory for a page and zero its contents. + /// + /// ``` + /// use kernel::page::Page; + /// + /// let page = Page::alloc_page(GFP_KERNEL | __GFP_ZERO)?; + /// # Ok::<(), kernel::alloc::AllocError>(()) + /// ``` + #[inline] + pub fn alloc_page(flags: Flags) -> Result<Self, AllocError> { + // SAFETY: Depending on the value of `gfp_flags`, this call may sleep. Other than that, it + // is always safe to call this method. + let page = unsafe { bindings::alloc_pages(flags.as_raw(), 0) }; + let page = NonNull::new(page).ok_or(AllocError)?; + // INVARIANT: We just successfully allocated a page, so we now have ownership of the newly + // allocated page. We transfer that ownership to the new `Page` object. + Ok(Self { page }) + } + + /// Returns a raw pointer to the page. + pub fn as_ptr(&self) -> *mut bindings::page { + self.page.as_ptr() + } + + /// Get the node id containing this page. + pub fn nid(&self) -> i32 { + // SAFETY: Always safe to call with a valid page. + unsafe { bindings::page_to_nid(self.as_ptr()) } + } + + /// Runs a piece of code with this page mapped to an address. + /// + /// The page is unmapped when this call returns. + /// + /// # Using the raw pointer + /// + /// It is up to the caller to use the provided raw pointer correctly. 
The pointer is valid for + /// `PAGE_SIZE` bytes and for the duration in which the closure is called. The pointer might + /// only be mapped on the current thread, and when that is the case, dereferencing it on other + /// threads is UB. Other than that, the usual rules for dereferencing a raw pointer apply: don't + /// cause data races, the memory may be uninitialized, and so on. + /// + /// If multiple threads map the same page at the same time, then they may reference with + /// different addresses. However, even if the addresses are different, the underlying memory is + /// still the same for these purposes (e.g., it's still a data race if they both write to the + /// same underlying byte at the same time). + fn with_page_mapped<T>(&self, f: impl FnOnce(*mut u8) -> T) -> T { + // SAFETY: `page` is valid due to the type invariants on `Page`. + let mapped_addr = unsafe { bindings::kmap_local_page(self.as_ptr()) }; + + let res = f(mapped_addr.cast()); + + // This unmaps the page mapped above. + // + // SAFETY: Since this API takes the user code as a closure, it can only be used in a manner + // where the pages are unmapped in reverse order. This is as required by `kunmap_local`. + // + // In other words, if this call to `kunmap_local` happens when a different page should be + // unmapped first, then there must necessarily be a call to `kmap_local_page` other than the + // call just above in `with_page_mapped` that made that possible. In this case, it is the + // unsafe block that wraps that other call that is incorrect. + unsafe { bindings::kunmap_local(mapped_addr) }; + + res + } + + /// Runs a piece of code with a raw pointer to a slice of this page, with bounds checking. + /// + /// If `f` is called, then it will be called with a pointer that points at `off` bytes into the + /// page, and the pointer will be valid for at least `len` bytes. The pointer is only valid on + /// this task, as this method uses a local mapping. + /// + /// If `off` and `len` refers to a region outside of this page, then this method returns + /// [`EINVAL`] and does not call `f`. + /// + /// # Using the raw pointer + /// + /// It is up to the caller to use the provided raw pointer correctly. The pointer is valid for + /// `len` bytes and for the duration in which the closure is called. The pointer might only be + /// mapped on the current thread, and when that is the case, dereferencing it on other threads + /// is UB. Other than that, the usual rules for dereferencing a raw pointer apply: don't cause + /// data races, the memory may be uninitialized, and so on. + /// + /// If multiple threads map the same page at the same time, then they may reference with + /// different addresses. However, even if the addresses are different, the underlying memory is + /// still the same for these purposes (e.g., it's still a data race if they both write to the + /// same underlying byte at the same time). + fn with_pointer_into_page<T>( + &self, + off: usize, + len: usize, + f: impl FnOnce(*mut u8) -> Result<T>, + ) -> Result<T> { + let bounds_ok = off <= PAGE_SIZE && len <= PAGE_SIZE && (off + len) <= PAGE_SIZE; + + if bounds_ok { + self.with_page_mapped(move |page_addr| { + // SAFETY: The `off` integer is at most `PAGE_SIZE`, so this pointer offset will + // result in a pointer that is in bounds or one off the end of the page. + f(unsafe { page_addr.add(off) }) + }) + } else { + Err(EINVAL) + } + } + + /// Maps the page and reads from it into the given buffer. 
+    ///
+    /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes
+    /// outside of the page, then this call returns [`EINVAL`].
+    ///
+    /// # Safety
+    ///
+    /// * Callers must ensure that `dst` is valid for writing `len` bytes.
+    /// * Callers must ensure that this call does not race with a write to the same page that
+    ///   overlaps with this read.
+    pub unsafe fn read_raw(&self, dst: *mut u8, offset: usize, len: usize) -> Result {
+        self.with_pointer_into_page(offset, len, move |src| {
+            // SAFETY: If `with_pointer_into_page` calls into this closure, then
+            // it has performed a bounds check and guarantees that `src` is
+            // valid for `len` bytes.
+            //
+            // The caller guarantees that there is no data race.
+            unsafe { ptr::copy_nonoverlapping(src, dst, len) };
+            Ok(())
+        })
+    }
+
+    /// Maps the page and writes into it from the given buffer.
+    ///
+    /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes
+    /// outside of the page, then this call returns [`EINVAL`].
+    ///
+    /// # Safety
+    ///
+    /// * Callers must ensure that `src` is valid for reading `len` bytes.
+    /// * Callers must ensure that this call does not race with a read or write to the same page
+    ///   that overlaps with this write.
+    pub unsafe fn write_raw(&self, src: *const u8, offset: usize, len: usize) -> Result {
+        self.with_pointer_into_page(offset, len, move |dst| {
+            // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a
+            // bounds check and guarantees that `dst` is valid for `len` bytes.
+            //
+            // The caller guarantees that there is no data race.
+            unsafe { ptr::copy_nonoverlapping(src, dst, len) };
+            Ok(())
+        })
+    }
+
+    /// Maps the page and zeroes the given slice.
+    ///
+    /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes
+    /// outside of the page, then this call returns [`EINVAL`].
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that this call does not race with a read or write to the same page that
+    /// overlaps with this write.
+    pub unsafe fn fill_zero_raw(&self, offset: usize, len: usize) -> Result {
+        self.with_pointer_into_page(offset, len, move |dst| {
+            // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a
+            // bounds check and guarantees that `dst` is valid for `len` bytes.
+            //
+            // The caller guarantees that there is no data race.
+            unsafe { ptr::write_bytes(dst, 0u8, len) };
+            Ok(())
+        })
+    }
+
+    /// Copies data from userspace into this page.
+    ///
+    /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes
+    /// outside of the page, then this call returns [`EINVAL`].
+    ///
+    /// Like the other `UserSliceReader` methods, data races are allowed on the userspace address.
+    /// However, they are not allowed on the page you are copying into.
+    ///
+    /// # Safety
+    ///
+    /// Callers must ensure that this call does not race with a read or write to the same page that
+    /// overlaps with this write.
+    pub unsafe fn copy_from_user_slice_raw(
+        &self,
+        reader: &mut UserSliceReader,
+        offset: usize,
+        len: usize,
+    ) -> Result {
+        self.with_pointer_into_page(offset, len, move |dst| {
+            // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a
+            // bounds check and guarantees that `dst` is valid for `len` bytes. Furthermore, we have
+            // exclusive access to the slice since the caller guarantees that there are no races.
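+            //
+            // On success, `read_raw()` also advances the reader past the `len` bytes consumed.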
+ reader.read_raw(unsafe { core::slice::from_raw_parts_mut(dst.cast(), len) }) + }) + } +} + +impl Drop for Page { + #[inline] + fn drop(&mut self) { + // SAFETY: By the type invariants, we have ownership of the page and can free it. + unsafe { bindings::__free_pages(self.page.as_ptr(), 0) }; + } +} diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs new file mode 100644 index 000000000000..82e128431f08 --- /dev/null +++ b/rust/kernel/pci.rs @@ -0,0 +1,509 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the PCI bus. +//! +//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h) + +use crate::{ + bindings, + container_of, + device, + device_id::{ + RawDeviceId, + RawDeviceIdIndex, // + }, + driver, + error::{ + from_result, + to_result, // + }, + prelude::*, + str::CStr, + types::Opaque, + ThisModule, // +}; +use core::{ + marker::PhantomData, + mem::offset_of, + ptr::{ + addr_of_mut, + NonNull, // + }, +}; + +mod id; +mod io; +mod irq; + +pub use self::id::{ + Class, + ClassMask, + Vendor, // +}; +pub use self::io::Bar; +pub use self::irq::{ + IrqType, + IrqTypes, + IrqVector, // +}; + +/// An adapter for the registration of PCI drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::pci_driver; + + unsafe fn register( + pdrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: It's safe to set the fields of `struct pci_driver` on initialization. + unsafe { + (*pdrv.get()).name = name.as_char_ptr(); + (*pdrv.get()).probe = Some(Self::probe_callback); + (*pdrv.get()).remove = Some(Self::remove_callback); + (*pdrv.get()).id_table = T::ID_TABLE.as_ptr(); + } + + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { + bindings::__pci_register_driver(pdrv.get(), module.0, name.as_char_ptr()) + }) + } + + unsafe fn unregister(pdrv: &Opaque<Self::RegType>) { + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::pci_unregister_driver(pdrv.get()) } + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback( + pdev: *mut bindings::pci_dev, + id: *const bindings::pci_device_id, + ) -> c_int { + // SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a + // `struct pci_dev`. + // + // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and + // does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*id.cast::<DeviceId>() }; + let info = T::ID_TABLE.info(id.index()); + + from_result(|| { + let data = T::probe(pdev, info); + + pdev.as_ref().set_drvdata(data)?; + Ok(0) + }) + } + + extern "C" fn remove_callback(pdev: *mut bindings::pci_dev) { + // SAFETY: The PCI bus only ever calls the remove callback with a valid pointer to a + // `struct pci_dev`. + // + // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. 
+ let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `remove_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; + + T::unbind(pdev, data.as_ref()); + } +} + +/// Declares a kernel module that exposes a single PCI driver. +/// +/// # Examples +/// +///```ignore +/// kernel::module_pci_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +///``` +#[macro_export] +macro_rules! module_pci_driver { +($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::pci::Adapter<T>, { $($f)* }); +}; +} + +/// Abstraction for the PCI device ID structure ([`struct pci_device_id`]). +/// +/// [`struct pci_device_id`]: https://docs.kernel.org/PCI/pci.html#c.pci_device_id +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::pci_device_id); + +impl DeviceId { + const PCI_ANY_ID: u32 = !0; + + /// Equivalent to C's `PCI_DEVICE` macro. + /// + /// Create a new `pci::DeviceId` from a vendor and device ID. + #[inline] + pub const fn from_id(vendor: Vendor, device: u32) -> Self { + Self(bindings::pci_device_id { + vendor: vendor.as_raw() as u32, + device, + subvendor: DeviceId::PCI_ANY_ID, + subdevice: DeviceId::PCI_ANY_ID, + class: 0, + class_mask: 0, + driver_data: 0, + override_only: 0, + }) + } + + /// Equivalent to C's `PCI_DEVICE_CLASS` macro. + /// + /// Create a new `pci::DeviceId` from a class number and mask. + #[inline] + pub const fn from_class(class: u32, class_mask: u32) -> Self { + Self(bindings::pci_device_id { + vendor: DeviceId::PCI_ANY_ID, + device: DeviceId::PCI_ANY_ID, + subvendor: DeviceId::PCI_ANY_ID, + subdevice: DeviceId::PCI_ANY_ID, + class, + class_mask, + driver_data: 0, + override_only: 0, + }) + } + + /// Create a new [`DeviceId`] from a class number, mask, and specific vendor. + /// + /// This is more targeted than [`DeviceId::from_class`]: in addition to matching by [`Vendor`], + /// it also matches the PCI [`Class`] (up to the entire 24 bits, depending on the + /// [`ClassMask`]). + #[inline] + pub const fn from_class_and_vendor( + class: Class, + class_mask: ClassMask, + vendor: Vendor, + ) -> Self { + Self(bindings::pci_device_id { + vendor: vendor.as_raw() as u32, + device: DeviceId::PCI_ANY_ID, + subvendor: DeviceId::PCI_ANY_ID, + subdevice: DeviceId::PCI_ANY_ID, + class: class.as_raw(), + class_mask: class_mask.as_raw(), + driver_data: 0, + override_only: 0, + }) + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `pci_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::pci_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::pci_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +/// `IdTable` type for PCI. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// Create a PCI `IdTable` with its alias for modpost. +#[macro_export] +macro_rules! 
pci_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::pci::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("pci", $module_table_name, $table_name); + }; +} + +/// The PCI driver trait. +/// +/// # Examples +/// +///``` +/// # use kernel::{bindings, device::Core, pci}; +/// +/// struct MyDriver; +/// +/// kernel::pci_device_table!( +/// PCI_TABLE, +/// MODULE_PCI_TABLE, +/// <MyDriver as pci::Driver>::IdInfo, +/// [ +/// ( +/// pci::DeviceId::from_id(pci::Vendor::REDHAT, bindings::PCI_ANY_ID as u32), +/// (), +/// ) +/// ] +/// ); +/// +/// impl pci::Driver for MyDriver { +/// type IdInfo = (); +/// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE; +/// +/// fn probe( +/// _pdev: &pci::Device<Core>, +/// _id_info: &Self::IdInfo, +/// ) -> impl PinInit<Self, Error> { +/// Err(ENODEV) +/// } +/// } +///``` +/// Drivers must implement this trait in order to get a PCI driver registered. Please refer to the +/// `Adapter` documentation for an example. +pub trait Driver: Send { + /// The type holding information about each device id supported by the driver. + // TODO: Use `associated_type_defaults` once stabilized: + // + // ``` + // type IdInfo: 'static = (); + // ``` + type IdInfo: 'static; + + /// The table of device ids supported by the driver. + const ID_TABLE: IdTable<Self::IdInfo>; + + /// PCI driver probe. + /// + /// Called when a new pci device is added or discovered. Implementers should + /// attempt to initialize the device here. + fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + + /// PCI driver unbind. + /// + /// Called when a [`Device`] is unbound from its bound [`Driver`]. Implementing this callback + /// is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } +} + +/// The PCI device representation. +/// +/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation +/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get +/// passed from the C side. +/// +/// # Invariants +/// +/// A [`Device`] instance represents a valid `struct pci_dev` created by the C portion of the +/// kernel. +#[repr(transparent)] +pub struct Device<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::pci_dev>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> Device<Ctx> { + #[inline] + fn as_raw(&self) -> *mut bindings::pci_dev { + self.0.get() + } +} + +impl Device { + /// Returns the PCI vendor ID as [`Vendor`]. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{device::Core, pci::{self, Vendor}, prelude::*}; + /// fn log_device_info(pdev: &pci::Device<Core>) -> Result { + /// // Get an instance of `Vendor`. 
+ /// let vendor = pdev.vendor_id(); + /// dev_info!( + /// pdev.as_ref(), + /// "Device: Vendor={}, Device=0x{:x}\n", + /// vendor, + /// pdev.device_id() + /// ); + /// Ok(()) + /// } + /// ``` + #[inline] + pub fn vendor_id(&self) -> Vendor { + // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`. + let vendor_id = unsafe { (*self.as_raw()).vendor }; + Vendor::from_raw(vendor_id) + } + + /// Returns the PCI device ID. + #[inline] + pub fn device_id(&self) -> u16 { + // SAFETY: By its type invariant `self.as_raw` is always a valid pointer to a + // `struct pci_dev`. + unsafe { (*self.as_raw()).device } + } + + /// Returns the PCI revision ID. + #[inline] + pub fn revision_id(&self) -> u8 { + // SAFETY: By its type invariant `self.as_raw` is always a valid pointer to a + // `struct pci_dev`. + unsafe { (*self.as_raw()).revision } + } + + /// Returns the PCI bus device/function. + #[inline] + pub fn dev_id(&self) -> u16 { + // SAFETY: By its type invariant `self.as_raw` is always a valid pointer to a + // `struct pci_dev`. + unsafe { bindings::pci_dev_id(self.as_raw()) } + } + + /// Returns the PCI subsystem vendor ID. + #[inline] + pub fn subsystem_vendor_id(&self) -> u16 { + // SAFETY: By its type invariant `self.as_raw` is always a valid pointer to a + // `struct pci_dev`. + unsafe { (*self.as_raw()).subsystem_vendor } + } + + /// Returns the PCI subsystem device ID. + #[inline] + pub fn subsystem_device_id(&self) -> u16 { + // SAFETY: By its type invariant `self.as_raw` is always a valid pointer to a + // `struct pci_dev`. + unsafe { (*self.as_raw()).subsystem_device } + } + + /// Returns the start of the given PCI BAR resource. + pub fn resource_start(&self, bar: u32) -> Result<bindings::resource_size_t> { + if !Bar::index_is_valid(bar) { + return Err(EINVAL); + } + + // SAFETY: + // - `bar` is a valid bar number, as guaranteed by the above call to `Bar::index_is_valid`, + // - by its type invariant `self.as_raw` is always a valid pointer to a `struct pci_dev`. + Ok(unsafe { bindings::pci_resource_start(self.as_raw(), bar.try_into()?) }) + } + + /// Returns the size of the given PCI BAR resource. + pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> { + if !Bar::index_is_valid(bar) { + return Err(EINVAL); + } + + // SAFETY: + // - `bar` is a valid bar number, as guaranteed by the above call to `Bar::index_is_valid`, + // - by its type invariant `self.as_raw` is always a valid pointer to a `struct pci_dev`. + Ok(unsafe { bindings::pci_resource_len(self.as_raw(), bar.try_into()?) }) + } + + /// Returns the PCI class as a `Class` struct. + #[inline] + pub fn pci_class(&self) -> Class { + // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`. + Class::from_raw(unsafe { (*self.as_raw()).class }) + } +} + +impl Device<device::Core> { + /// Enable memory resources for this device. + pub fn enable_device_mem(&self) -> Result { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) }) + } + + /// Enable bus-mastering for this device. + #[inline] + pub fn set_master(&self) { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + unsafe { bindings::pci_set_master(self.as_raw()) }; + } +} + +// SAFETY: `pci::Device` is a transparent wrapper of `struct pci_dev`. +// The offset is guaranteed to point to a valid device field inside `pci::Device`. 
+unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> {
+    const OFFSET: usize = offset_of!(bindings::pci_dev, dev);
+}
+
+// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic
+// argument.
+kernel::impl_device_context_deref!(unsafe { Device });
+kernel::impl_device_context_into_aref!(Device);
+
+impl crate::dma::Device for Device<device::Core> {}
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl crate::sync::aref::AlwaysRefCounted for Device {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+        unsafe { bindings::pci_dev_get(self.as_raw()) };
+    }
+
+    unsafe fn dec_ref(obj: NonNull<Self>) {
+        // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+        unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) }
+    }
+}
+
+impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> {
+    fn as_ref(&self) -> &device::Device<Ctx> {
+        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+        // `struct pci_dev`.
+        let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
+
+        // SAFETY: `dev` points to a valid `struct device`.
+        unsafe { device::Device::from_raw(dev) }
+    }
+}
+
+impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &Device<Ctx> {
+    type Error = kernel::error::Error;
+
+    fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> {
+        // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a
+        // `struct device`.
+        if !unsafe { bindings::dev_is_pci(dev.as_raw()) } {
+            return Err(EINVAL);
+        }
+
+        // SAFETY: We've just verified that the bus type of `dev` equals `bindings::pci_bus_type`,
+        // hence `dev` must be embedded in a valid `struct pci_dev` as guaranteed by the
+        // corresponding C code.
+        let pdev = unsafe { container_of!(dev.as_raw(), bindings::pci_dev, dev) };
+
+        // SAFETY: `pdev` is a valid pointer to a `struct pci_dev`.
+        Ok(unsafe { &*pdev.cast() })
+    }
+}
+
+// SAFETY: A `Device` is always reference-counted and can be released from any thread.
+unsafe impl Send for Device {}
+
+// SAFETY: `Device` can be shared among threads because all methods of `Device`
+// (i.e. `Device<Normal>`) are thread safe.
+unsafe impl Sync for Device {}
diff --git a/rust/kernel/pci/id.rs b/rust/kernel/pci/id.rs
new file mode 100644
index 000000000000..c09125946d9e
--- /dev/null
+++ b/rust/kernel/pci/id.rs
@@ -0,0 +1,581 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! PCI device identifiers and related types.
+//!
+//! This module contains PCI class codes, vendor IDs, and supporting types.
+
+use crate::{
+    bindings,
+    fmt,
+    prelude::*, //
+};
+
+/// PCI device class codes.
+///
+/// Each entry contains the full 24-bit PCI class code (base class in bits
+/// 23-16, subclass in bits 15-8, programming interface in bits 7-0).
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::{device::Core, pci::{self, Class}, prelude::*};
+/// fn probe_device(pdev: &pci::Device<Core>) -> Result {
+///     let pci_class = pdev.pci_class();
+///     dev_info!(
+///         pdev.as_ref(),
+///         "Detected PCI class: {}\n",
+///         pci_class
+///     );
+///     Ok(())
+/// }
+/// ```
+#[derive(Clone, Copy, PartialEq, Eq)]
+#[repr(transparent)]
+pub struct Class(u32);
+
+/// PCI class mask constants for matching [`Class`] codes.
+#[repr(u32)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ClassMask {
+    /// Match the full 24-bit class code.
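+    /// This covers the base class, subclass, and programming interface bits.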
+ Full = 0xffffff, + /// Match the upper 16 bits of the class code (base class and subclass only) + ClassSubclass = 0xffff00, +} + +macro_rules! define_all_pci_classes { + ( + $($variant:ident = $binding:expr,)+ + ) => { + impl Class { + $( + #[allow(missing_docs)] + pub const $variant: Self = Self(Self::to_24bit_class($binding)); + )+ + } + + impl fmt::Display for Class { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + $( + &Self::$variant => write!(f, stringify!($variant)), + )+ + _ => <Self as fmt::Debug>::fmt(self, f), + } + } + } + }; +} + +/// Once constructed, a [`Class`] contains a valid PCI class code. +impl Class { + /// Create a [`Class`] from a raw 24-bit class code. + #[inline] + pub(super) fn from_raw(class_code: u32) -> Self { + Self(class_code) + } + + /// Get the raw 24-bit class code value. + #[inline] + pub const fn as_raw(self) -> u32 { + self.0 + } + + // Converts a PCI class constant to 24-bit format. + // + // Many device drivers use only the upper 16 bits (base class and subclass), + // but some use the full 24 bits. In order to support both cases, store the + // class code as a 24-bit value, where 16-bit values are shifted up 8 bits. + const fn to_24bit_class(val: u32) -> u32 { + if val > 0xFFFF { + val + } else { + val << 8 + } + } +} + +impl fmt::Debug for Class { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "0x{:06x}", self.0) + } +} + +impl ClassMask { + /// Get the raw mask value. + #[inline] + pub const fn as_raw(self) -> u32 { + self as u32 + } +} + +impl TryFrom<u32> for ClassMask { + type Error = Error; + + fn try_from(value: u32) -> Result<Self, Self::Error> { + match value { + 0xffffff => Ok(ClassMask::Full), + 0xffff00 => Ok(ClassMask::ClassSubclass), + _ => Err(EINVAL), + } + } +} + +/// PCI vendor IDs. +/// +/// Each entry contains the 16-bit PCI vendor ID as assigned by the PCI SIG. +#[derive(Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct Vendor(u16); + +macro_rules! define_all_pci_vendors { + ( + $($variant:ident = $binding:expr,)+ + ) => { + impl Vendor { + $( + #[allow(missing_docs)] + pub const $variant: Self = Self($binding as u16); + )+ + } + + impl fmt::Display for Vendor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + $( + &Self::$variant => write!(f, stringify!($variant)), + )+ + _ => <Self as fmt::Debug>::fmt(self, f), + } + } + } + }; +} + +/// Once constructed, a `Vendor` contains a valid PCI Vendor ID. +impl Vendor { + /// Create a Vendor from a raw 16-bit vendor ID. + #[inline] + pub(super) fn from_raw(vendor_id: u16) -> Self { + Self(vendor_id) + } + + /// Get the raw 16-bit vendor ID value. + #[inline] + pub const fn as_raw(self) -> u16 { + self.0 + } +} + +impl fmt::Debug for Vendor { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "0x{:04x}", self.0) + } +} + +define_all_pci_classes! 
{ + NOT_DEFINED = bindings::PCI_CLASS_NOT_DEFINED, // 0x000000 + NOT_DEFINED_VGA = bindings::PCI_CLASS_NOT_DEFINED_VGA, // 0x000100 + + STORAGE_SCSI = bindings::PCI_CLASS_STORAGE_SCSI, // 0x010000 + STORAGE_IDE = bindings::PCI_CLASS_STORAGE_IDE, // 0x010100 + STORAGE_FLOPPY = bindings::PCI_CLASS_STORAGE_FLOPPY, // 0x010200 + STORAGE_IPI = bindings::PCI_CLASS_STORAGE_IPI, // 0x010300 + STORAGE_RAID = bindings::PCI_CLASS_STORAGE_RAID, // 0x010400 + STORAGE_SATA = bindings::PCI_CLASS_STORAGE_SATA, // 0x010600 + STORAGE_SATA_AHCI = bindings::PCI_CLASS_STORAGE_SATA_AHCI, // 0x010601 + STORAGE_SAS = bindings::PCI_CLASS_STORAGE_SAS, // 0x010700 + STORAGE_EXPRESS = bindings::PCI_CLASS_STORAGE_EXPRESS, // 0x010802 + STORAGE_OTHER = bindings::PCI_CLASS_STORAGE_OTHER, // 0x018000 + + NETWORK_ETHERNET = bindings::PCI_CLASS_NETWORK_ETHERNET, // 0x020000 + NETWORK_TOKEN_RING = bindings::PCI_CLASS_NETWORK_TOKEN_RING, // 0x020100 + NETWORK_FDDI = bindings::PCI_CLASS_NETWORK_FDDI, // 0x020200 + NETWORK_ATM = bindings::PCI_CLASS_NETWORK_ATM, // 0x020300 + NETWORK_OTHER = bindings::PCI_CLASS_NETWORK_OTHER, // 0x028000 + + DISPLAY_VGA = bindings::PCI_CLASS_DISPLAY_VGA, // 0x030000 + DISPLAY_XGA = bindings::PCI_CLASS_DISPLAY_XGA, // 0x030100 + DISPLAY_3D = bindings::PCI_CLASS_DISPLAY_3D, // 0x030200 + DISPLAY_OTHER = bindings::PCI_CLASS_DISPLAY_OTHER, // 0x038000 + + MULTIMEDIA_VIDEO = bindings::PCI_CLASS_MULTIMEDIA_VIDEO, // 0x040000 + MULTIMEDIA_AUDIO = bindings::PCI_CLASS_MULTIMEDIA_AUDIO, // 0x040100 + MULTIMEDIA_PHONE = bindings::PCI_CLASS_MULTIMEDIA_PHONE, // 0x040200 + MULTIMEDIA_HD_AUDIO = bindings::PCI_CLASS_MULTIMEDIA_HD_AUDIO, // 0x040300 + MULTIMEDIA_OTHER = bindings::PCI_CLASS_MULTIMEDIA_OTHER, // 0x048000 + + MEMORY_RAM = bindings::PCI_CLASS_MEMORY_RAM, // 0x050000 + MEMORY_FLASH = bindings::PCI_CLASS_MEMORY_FLASH, // 0x050100 + MEMORY_CXL = bindings::PCI_CLASS_MEMORY_CXL, // 0x050200 + MEMORY_OTHER = bindings::PCI_CLASS_MEMORY_OTHER, // 0x058000 + + BRIDGE_HOST = bindings::PCI_CLASS_BRIDGE_HOST, // 0x060000 + BRIDGE_ISA = bindings::PCI_CLASS_BRIDGE_ISA, // 0x060100 + BRIDGE_EISA = bindings::PCI_CLASS_BRIDGE_EISA, // 0x060200 + BRIDGE_MC = bindings::PCI_CLASS_BRIDGE_MC, // 0x060300 + BRIDGE_PCI_NORMAL = bindings::PCI_CLASS_BRIDGE_PCI_NORMAL, // 0x060400 + BRIDGE_PCI_SUBTRACTIVE = bindings::PCI_CLASS_BRIDGE_PCI_SUBTRACTIVE, // 0x060401 + BRIDGE_PCMCIA = bindings::PCI_CLASS_BRIDGE_PCMCIA, // 0x060500 + BRIDGE_NUBUS = bindings::PCI_CLASS_BRIDGE_NUBUS, // 0x060600 + BRIDGE_CARDBUS = bindings::PCI_CLASS_BRIDGE_CARDBUS, // 0x060700 + BRIDGE_RACEWAY = bindings::PCI_CLASS_BRIDGE_RACEWAY, // 0x060800 + BRIDGE_OTHER = bindings::PCI_CLASS_BRIDGE_OTHER, // 0x068000 + + COMMUNICATION_SERIAL = bindings::PCI_CLASS_COMMUNICATION_SERIAL, // 0x070000 + COMMUNICATION_PARALLEL = bindings::PCI_CLASS_COMMUNICATION_PARALLEL, // 0x070100 + COMMUNICATION_MULTISERIAL = bindings::PCI_CLASS_COMMUNICATION_MULTISERIAL, // 0x070200 + COMMUNICATION_MODEM = bindings::PCI_CLASS_COMMUNICATION_MODEM, // 0x070300 + COMMUNICATION_OTHER = bindings::PCI_CLASS_COMMUNICATION_OTHER, // 0x078000 + + SYSTEM_PIC = bindings::PCI_CLASS_SYSTEM_PIC, // 0x080000 + SYSTEM_PIC_IOAPIC = bindings::PCI_CLASS_SYSTEM_PIC_IOAPIC, // 0x080010 + SYSTEM_PIC_IOXAPIC = bindings::PCI_CLASS_SYSTEM_PIC_IOXAPIC, // 0x080020 + SYSTEM_DMA = bindings::PCI_CLASS_SYSTEM_DMA, // 0x080100 + SYSTEM_TIMER = bindings::PCI_CLASS_SYSTEM_TIMER, // 0x080200 + SYSTEM_RTC = bindings::PCI_CLASS_SYSTEM_RTC, // 0x080300 + SYSTEM_PCI_HOTPLUG = bindings::PCI_CLASS_SYSTEM_PCI_HOTPLUG, 
// 0x080400 + SYSTEM_SDHCI = bindings::PCI_CLASS_SYSTEM_SDHCI, // 0x080500 + SYSTEM_RCEC = bindings::PCI_CLASS_SYSTEM_RCEC, // 0x080700 + SYSTEM_OTHER = bindings::PCI_CLASS_SYSTEM_OTHER, // 0x088000 + + INPUT_KEYBOARD = bindings::PCI_CLASS_INPUT_KEYBOARD, // 0x090000 + INPUT_PEN = bindings::PCI_CLASS_INPUT_PEN, // 0x090100 + INPUT_MOUSE = bindings::PCI_CLASS_INPUT_MOUSE, // 0x090200 + INPUT_SCANNER = bindings::PCI_CLASS_INPUT_SCANNER, // 0x090300 + INPUT_GAMEPORT = bindings::PCI_CLASS_INPUT_GAMEPORT, // 0x090400 + INPUT_OTHER = bindings::PCI_CLASS_INPUT_OTHER, // 0x098000 + + DOCKING_GENERIC = bindings::PCI_CLASS_DOCKING_GENERIC, // 0x0a0000 + DOCKING_OTHER = bindings::PCI_CLASS_DOCKING_OTHER, // 0x0a8000 + + PROCESSOR_386 = bindings::PCI_CLASS_PROCESSOR_386, // 0x0b0000 + PROCESSOR_486 = bindings::PCI_CLASS_PROCESSOR_486, // 0x0b0100 + PROCESSOR_PENTIUM = bindings::PCI_CLASS_PROCESSOR_PENTIUM, // 0x0b0200 + PROCESSOR_ALPHA = bindings::PCI_CLASS_PROCESSOR_ALPHA, // 0x0b1000 + PROCESSOR_POWERPC = bindings::PCI_CLASS_PROCESSOR_POWERPC, // 0x0b2000 + PROCESSOR_MIPS = bindings::PCI_CLASS_PROCESSOR_MIPS, // 0x0b3000 + PROCESSOR_CO = bindings::PCI_CLASS_PROCESSOR_CO, // 0x0b4000 + + SERIAL_FIREWIRE = bindings::PCI_CLASS_SERIAL_FIREWIRE, // 0x0c0000 + SERIAL_FIREWIRE_OHCI = bindings::PCI_CLASS_SERIAL_FIREWIRE_OHCI, // 0x0c0010 + SERIAL_ACCESS = bindings::PCI_CLASS_SERIAL_ACCESS, // 0x0c0100 + SERIAL_SSA = bindings::PCI_CLASS_SERIAL_SSA, // 0x0c0200 + SERIAL_USB_UHCI = bindings::PCI_CLASS_SERIAL_USB_UHCI, // 0x0c0300 + SERIAL_USB_OHCI = bindings::PCI_CLASS_SERIAL_USB_OHCI, // 0x0c0310 + SERIAL_USB_EHCI = bindings::PCI_CLASS_SERIAL_USB_EHCI, // 0x0c0320 + SERIAL_USB_XHCI = bindings::PCI_CLASS_SERIAL_USB_XHCI, // 0x0c0330 + SERIAL_USB_CDNS = bindings::PCI_CLASS_SERIAL_USB_CDNS, // 0x0c0380 + SERIAL_USB_DEVICE = bindings::PCI_CLASS_SERIAL_USB_DEVICE, // 0x0c03fe + SERIAL_FIBER = bindings::PCI_CLASS_SERIAL_FIBER, // 0x0c0400 + SERIAL_SMBUS = bindings::PCI_CLASS_SERIAL_SMBUS, // 0x0c0500 + SERIAL_IPMI_SMIC = bindings::PCI_CLASS_SERIAL_IPMI_SMIC, // 0x0c0700 + SERIAL_IPMI_KCS = bindings::PCI_CLASS_SERIAL_IPMI_KCS, // 0x0c0701 + SERIAL_IPMI_BT = bindings::PCI_CLASS_SERIAL_IPMI_BT, // 0x0c0702 + + WIRELESS_RF_CONTROLLER = bindings::PCI_CLASS_WIRELESS_RF_CONTROLLER, // 0x0d1000 + WIRELESS_WHCI = bindings::PCI_CLASS_WIRELESS_WHCI, // 0x0d1010 + + INTELLIGENT_I2O = bindings::PCI_CLASS_INTELLIGENT_I2O, // 0x0e0000 + + SATELLITE_TV = bindings::PCI_CLASS_SATELLITE_TV, // 0x0f0000 + SATELLITE_AUDIO = bindings::PCI_CLASS_SATELLITE_AUDIO, // 0x0f0100 + SATELLITE_VOICE = bindings::PCI_CLASS_SATELLITE_VOICE, // 0x0f0300 + SATELLITE_DATA = bindings::PCI_CLASS_SATELLITE_DATA, // 0x0f0400 + + CRYPT_NETWORK = bindings::PCI_CLASS_CRYPT_NETWORK, // 0x100000 + CRYPT_ENTERTAINMENT = bindings::PCI_CLASS_CRYPT_ENTERTAINMENT, // 0x100100 + CRYPT_OTHER = bindings::PCI_CLASS_CRYPT_OTHER, // 0x108000 + + SP_DPIO = bindings::PCI_CLASS_SP_DPIO, // 0x110000 + SP_OTHER = bindings::PCI_CLASS_SP_OTHER, // 0x118000 + + ACCELERATOR_PROCESSING = bindings::PCI_CLASS_ACCELERATOR_PROCESSING, // 0x120000 + + OTHERS = bindings::PCI_CLASS_OTHERS, // 0xff0000 +} + +define_all_pci_vendors! 
{ + PCI_SIG = bindings::PCI_VENDOR_ID_PCI_SIG, // 0x0001 + LOONGSON = bindings::PCI_VENDOR_ID_LOONGSON, // 0x0014 + SOLIDIGM = bindings::PCI_VENDOR_ID_SOLIDIGM, // 0x025e + TTTECH = bindings::PCI_VENDOR_ID_TTTECH, // 0x0357 + DYNALINK = bindings::PCI_VENDOR_ID_DYNALINK, // 0x0675 + UBIQUITI = bindings::PCI_VENDOR_ID_UBIQUITI, // 0x0777 + BERKOM = bindings::PCI_VENDOR_ID_BERKOM, // 0x0871 + ITTIM = bindings::PCI_VENDOR_ID_ITTIM, // 0x0b48 + COMPAQ = bindings::PCI_VENDOR_ID_COMPAQ, // 0x0e11 + LSI_LOGIC = bindings::PCI_VENDOR_ID_LSI_LOGIC, // 0x1000 + ATI = bindings::PCI_VENDOR_ID_ATI, // 0x1002 + VLSI = bindings::PCI_VENDOR_ID_VLSI, // 0x1004 + ADL = bindings::PCI_VENDOR_ID_ADL, // 0x1005 + NS = bindings::PCI_VENDOR_ID_NS, // 0x100b + TSENG = bindings::PCI_VENDOR_ID_TSENG, // 0x100c + WEITEK = bindings::PCI_VENDOR_ID_WEITEK, // 0x100e + DEC = bindings::PCI_VENDOR_ID_DEC, // 0x1011 + CIRRUS = bindings::PCI_VENDOR_ID_CIRRUS, // 0x1013 + IBM = bindings::PCI_VENDOR_ID_IBM, // 0x1014 + UNISYS = bindings::PCI_VENDOR_ID_UNISYS, // 0x1018 + COMPEX2 = bindings::PCI_VENDOR_ID_COMPEX2, // 0x101a + WD = bindings::PCI_VENDOR_ID_WD, // 0x101c + AMI = bindings::PCI_VENDOR_ID_AMI, // 0x101e + AMD = bindings::PCI_VENDOR_ID_AMD, // 0x1022 + TRIDENT = bindings::PCI_VENDOR_ID_TRIDENT, // 0x1023 + AI = bindings::PCI_VENDOR_ID_AI, // 0x1025 + DELL = bindings::PCI_VENDOR_ID_DELL, // 0x1028 + MATROX = bindings::PCI_VENDOR_ID_MATROX, // 0x102B + MOBILITY_ELECTRONICS = bindings::PCI_VENDOR_ID_MOBILITY_ELECTRONICS, // 0x14f2 + CT = bindings::PCI_VENDOR_ID_CT, // 0x102c + MIRO = bindings::PCI_VENDOR_ID_MIRO, // 0x1031 + NEC = bindings::PCI_VENDOR_ID_NEC, // 0x1033 + FD = bindings::PCI_VENDOR_ID_FD, // 0x1036 + SI = bindings::PCI_VENDOR_ID_SI, // 0x1039 + HP = bindings::PCI_VENDOR_ID_HP, // 0x103c + HP_3PAR = bindings::PCI_VENDOR_ID_HP_3PAR, // 0x1590 + PCTECH = bindings::PCI_VENDOR_ID_PCTECH, // 0x1042 + ASUSTEK = bindings::PCI_VENDOR_ID_ASUSTEK, // 0x1043 + DPT = bindings::PCI_VENDOR_ID_DPT, // 0x1044 + OPTI = bindings::PCI_VENDOR_ID_OPTI, // 0x1045 + ELSA = bindings::PCI_VENDOR_ID_ELSA, // 0x1048 + STMICRO = bindings::PCI_VENDOR_ID_STMICRO, // 0x104A + BUSLOGIC = bindings::PCI_VENDOR_ID_BUSLOGIC, // 0x104B + TI = bindings::PCI_VENDOR_ID_TI, // 0x104c + SONY = bindings::PCI_VENDOR_ID_SONY, // 0x104d + WINBOND2 = bindings::PCI_VENDOR_ID_WINBOND2, // 0x1050 + ANIGMA = bindings::PCI_VENDOR_ID_ANIGMA, // 0x1051 + EFAR = bindings::PCI_VENDOR_ID_EFAR, // 0x1055 + MOTOROLA = bindings::PCI_VENDOR_ID_MOTOROLA, // 0x1057 + PROMISE = bindings::PCI_VENDOR_ID_PROMISE, // 0x105a + FOXCONN = bindings::PCI_VENDOR_ID_FOXCONN, // 0x105b + UMC = bindings::PCI_VENDOR_ID_UMC, // 0x1060 + PICOPOWER = bindings::PCI_VENDOR_ID_PICOPOWER, // 0x1066 + MYLEX = bindings::PCI_VENDOR_ID_MYLEX, // 0x1069 + APPLE = bindings::PCI_VENDOR_ID_APPLE, // 0x106b + YAMAHA = bindings::PCI_VENDOR_ID_YAMAHA, // 0x1073 + QLOGIC = bindings::PCI_VENDOR_ID_QLOGIC, // 0x1077 + CYRIX = bindings::PCI_VENDOR_ID_CYRIX, // 0x1078 + CONTAQ = bindings::PCI_VENDOR_ID_CONTAQ, // 0x1080 + OLICOM = bindings::PCI_VENDOR_ID_OLICOM, // 0x108d + SUN = bindings::PCI_VENDOR_ID_SUN, // 0x108e + NI = bindings::PCI_VENDOR_ID_NI, // 0x1093 + CMD = bindings::PCI_VENDOR_ID_CMD, // 0x1095 + BROOKTREE = bindings::PCI_VENDOR_ID_BROOKTREE, // 0x109e + SGI = bindings::PCI_VENDOR_ID_SGI, // 0x10a9 + WINBOND = bindings::PCI_VENDOR_ID_WINBOND, // 0x10ad + PLX = bindings::PCI_VENDOR_ID_PLX, // 0x10b5 + MADGE = bindings::PCI_VENDOR_ID_MADGE, // 0x10b6 + THREECOM = bindings::PCI_VENDOR_ID_3COM, // 
0x10b7 + AL = bindings::PCI_VENDOR_ID_AL, // 0x10b9 + NEOMAGIC = bindings::PCI_VENDOR_ID_NEOMAGIC, // 0x10c8 + TCONRAD = bindings::PCI_VENDOR_ID_TCONRAD, // 0x10da + ROHM = bindings::PCI_VENDOR_ID_ROHM, // 0x10db + NVIDIA = bindings::PCI_VENDOR_ID_NVIDIA, // 0x10de + IMS = bindings::PCI_VENDOR_ID_IMS, // 0x10e0 + AMCC = bindings::PCI_VENDOR_ID_AMCC, // 0x10e8 + AMPERE = bindings::PCI_VENDOR_ID_AMPERE, // 0x1def + INTERG = bindings::PCI_VENDOR_ID_INTERG, // 0x10ea + REALTEK = bindings::PCI_VENDOR_ID_REALTEK, // 0x10ec + XILINX = bindings::PCI_VENDOR_ID_XILINX, // 0x10ee + INIT = bindings::PCI_VENDOR_ID_INIT, // 0x1101 + CREATIVE = bindings::PCI_VENDOR_ID_CREATIVE, // 0x1102 + TTI = bindings::PCI_VENDOR_ID_TTI, // 0x1103 + SIGMA = bindings::PCI_VENDOR_ID_SIGMA, // 0x1105 + VIA = bindings::PCI_VENDOR_ID_VIA, // 0x1106 + SIEMENS = bindings::PCI_VENDOR_ID_SIEMENS, // 0x110A + VORTEX = bindings::PCI_VENDOR_ID_VORTEX, // 0x1119 + EF = bindings::PCI_VENDOR_ID_EF, // 0x111a + IDT = bindings::PCI_VENDOR_ID_IDT, // 0x111d + FORE = bindings::PCI_VENDOR_ID_FORE, // 0x1127 + PHILIPS = bindings::PCI_VENDOR_ID_PHILIPS, // 0x1131 + EICON = bindings::PCI_VENDOR_ID_EICON, // 0x1133 + CISCO = bindings::PCI_VENDOR_ID_CISCO, // 0x1137 + ZIATECH = bindings::PCI_VENDOR_ID_ZIATECH, // 0x1138 + SYSKONNECT = bindings::PCI_VENDOR_ID_SYSKONNECT, // 0x1148 + DIGI = bindings::PCI_VENDOR_ID_DIGI, // 0x114f + XIRCOM = bindings::PCI_VENDOR_ID_XIRCOM, // 0x115d + SERVERWORKS = bindings::PCI_VENDOR_ID_SERVERWORKS, // 0x1166 + ALTERA = bindings::PCI_VENDOR_ID_ALTERA, // 0x1172 + SBE = bindings::PCI_VENDOR_ID_SBE, // 0x1176 + TOSHIBA = bindings::PCI_VENDOR_ID_TOSHIBA, // 0x1179 + TOSHIBA_2 = bindings::PCI_VENDOR_ID_TOSHIBA_2, // 0x102f + ATTO = bindings::PCI_VENDOR_ID_ATTO, // 0x117c + RICOH = bindings::PCI_VENDOR_ID_RICOH, // 0x1180 + DLINK = bindings::PCI_VENDOR_ID_DLINK, // 0x1186 + ARTOP = bindings::PCI_VENDOR_ID_ARTOP, // 0x1191 + ZEITNET = bindings::PCI_VENDOR_ID_ZEITNET, // 0x1193 + FUJITSU_ME = bindings::PCI_VENDOR_ID_FUJITSU_ME, // 0x119e + MARVELL = bindings::PCI_VENDOR_ID_MARVELL, // 0x11ab + MARVELL_EXT = bindings::PCI_VENDOR_ID_MARVELL_EXT, // 0x1b4b + V3 = bindings::PCI_VENDOR_ID_V3, // 0x11b0 + ATT = bindings::PCI_VENDOR_ID_ATT, // 0x11c1 + SPECIALIX = bindings::PCI_VENDOR_ID_SPECIALIX, // 0x11cb + ANALOG_DEVICES = bindings::PCI_VENDOR_ID_ANALOG_DEVICES, // 0x11d4 + ZORAN = bindings::PCI_VENDOR_ID_ZORAN, // 0x11de + COMPEX = bindings::PCI_VENDOR_ID_COMPEX, // 0x11f6 + MICROSEMI = bindings::PCI_VENDOR_ID_MICROSEMI, // 0x11f8 + RP = bindings::PCI_VENDOR_ID_RP, // 0x11fe + CYCLADES = bindings::PCI_VENDOR_ID_CYCLADES, // 0x120e + ESSENTIAL = bindings::PCI_VENDOR_ID_ESSENTIAL, // 0x120f + O2 = bindings::PCI_VENDOR_ID_O2, // 0x1217 + THREEDX = bindings::PCI_VENDOR_ID_3DFX, // 0x121a + AVM = bindings::PCI_VENDOR_ID_AVM, // 0x1244 + STALLION = bindings::PCI_VENDOR_ID_STALLION, // 0x124d + AT = bindings::PCI_VENDOR_ID_AT, // 0x1259 + ASIX = bindings::PCI_VENDOR_ID_ASIX, // 0x125b + ESS = bindings::PCI_VENDOR_ID_ESS, // 0x125d + SATSAGEM = bindings::PCI_VENDOR_ID_SATSAGEM, // 0x1267 + ENSONIQ = bindings::PCI_VENDOR_ID_ENSONIQ, // 0x1274 + TRANSMETA = bindings::PCI_VENDOR_ID_TRANSMETA, // 0x1279 + ROCKWELL = bindings::PCI_VENDOR_ID_ROCKWELL, // 0x127A + ITE = bindings::PCI_VENDOR_ID_ITE, // 0x1283 + ALTEON = bindings::PCI_VENDOR_ID_ALTEON, // 0x12ae + NVIDIA_SGS = bindings::PCI_VENDOR_ID_NVIDIA_SGS, // 0x12d2 + PERICOM = bindings::PCI_VENDOR_ID_PERICOM, // 0x12D8 + AUREAL = bindings::PCI_VENDOR_ID_AUREAL, // 0x12eb + 
ELECTRONICDESIGNGMBH = bindings::PCI_VENDOR_ID_ELECTRONICDESIGNGMBH, // 0x12f8 + ESDGMBH = bindings::PCI_VENDOR_ID_ESDGMBH, // 0x12fe + CB = bindings::PCI_VENDOR_ID_CB, // 0x1307 + SIIG = bindings::PCI_VENDOR_ID_SIIG, // 0x131f + RADISYS = bindings::PCI_VENDOR_ID_RADISYS, // 0x1331 + MICRO_MEMORY = bindings::PCI_VENDOR_ID_MICRO_MEMORY, // 0x1332 + DOMEX = bindings::PCI_VENDOR_ID_DOMEX, // 0x134a + INTASHIELD = bindings::PCI_VENDOR_ID_INTASHIELD, // 0x135a + QUATECH = bindings::PCI_VENDOR_ID_QUATECH, // 0x135C + SEALEVEL = bindings::PCI_VENDOR_ID_SEALEVEL, // 0x135e + HYPERCOPE = bindings::PCI_VENDOR_ID_HYPERCOPE, // 0x1365 + DIGIGRAM = bindings::PCI_VENDOR_ID_DIGIGRAM, // 0x1369 + KAWASAKI = bindings::PCI_VENDOR_ID_KAWASAKI, // 0x136b + CNET = bindings::PCI_VENDOR_ID_CNET, // 0x1371 + LMC = bindings::PCI_VENDOR_ID_LMC, // 0x1376 + NETGEAR = bindings::PCI_VENDOR_ID_NETGEAR, // 0x1385 + APPLICOM = bindings::PCI_VENDOR_ID_APPLICOM, // 0x1389 + MOXA = bindings::PCI_VENDOR_ID_MOXA, // 0x1393 + CCD = bindings::PCI_VENDOR_ID_CCD, // 0x1397 + EXAR = bindings::PCI_VENDOR_ID_EXAR, // 0x13a8 + MICROGATE = bindings::PCI_VENDOR_ID_MICROGATE, // 0x13c0 + THREEWARE = bindings::PCI_VENDOR_ID_3WARE, // 0x13C1 + IOMEGA = bindings::PCI_VENDOR_ID_IOMEGA, // 0x13ca + ABOCOM = bindings::PCI_VENDOR_ID_ABOCOM, // 0x13D1 + SUNDANCE = bindings::PCI_VENDOR_ID_SUNDANCE, // 0x13f0 + CMEDIA = bindings::PCI_VENDOR_ID_CMEDIA, // 0x13f6 + ADVANTECH = bindings::PCI_VENDOR_ID_ADVANTECH, // 0x13fe + MEILHAUS = bindings::PCI_VENDOR_ID_MEILHAUS, // 0x1402 + LAVA = bindings::PCI_VENDOR_ID_LAVA, // 0x1407 + TIMEDIA = bindings::PCI_VENDOR_ID_TIMEDIA, // 0x1409 + ICE = bindings::PCI_VENDOR_ID_ICE, // 0x1412 + MICROSOFT = bindings::PCI_VENDOR_ID_MICROSOFT, // 0x1414 + OXSEMI = bindings::PCI_VENDOR_ID_OXSEMI, // 0x1415 + CHELSIO = bindings::PCI_VENDOR_ID_CHELSIO, // 0x1425 + EDIMAX = bindings::PCI_VENDOR_ID_EDIMAX, // 0x1432 + ADLINK = bindings::PCI_VENDOR_ID_ADLINK, // 0x144a + SAMSUNG = bindings::PCI_VENDOR_ID_SAMSUNG, // 0x144d + GIGABYTE = bindings::PCI_VENDOR_ID_GIGABYTE, // 0x1458 + AMBIT = bindings::PCI_VENDOR_ID_AMBIT, // 0x1468 + MYRICOM = bindings::PCI_VENDOR_ID_MYRICOM, // 0x14c1 + MEDIATEK = bindings::PCI_VENDOR_ID_MEDIATEK, // 0x14c3 + TITAN = bindings::PCI_VENDOR_ID_TITAN, // 0x14D2 + PANACOM = bindings::PCI_VENDOR_ID_PANACOM, // 0x14d4 + SIPACKETS = bindings::PCI_VENDOR_ID_SIPACKETS, // 0x14d9 + AFAVLAB = bindings::PCI_VENDOR_ID_AFAVLAB, // 0x14db + AMPLICON = bindings::PCI_VENDOR_ID_AMPLICON, // 0x14dc + BCM_GVC = bindings::PCI_VENDOR_ID_BCM_GVC, // 0x14a4 + BROADCOM = bindings::PCI_VENDOR_ID_BROADCOM, // 0x14e4 + TOPIC = bindings::PCI_VENDOR_ID_TOPIC, // 0x151f + MAINPINE = bindings::PCI_VENDOR_ID_MAINPINE, // 0x1522 + ENE = bindings::PCI_VENDOR_ID_ENE, // 0x1524 + SYBA = bindings::PCI_VENDOR_ID_SYBA, // 0x1592 + MORETON = bindings::PCI_VENDOR_ID_MORETON, // 0x15aa + VMWARE = bindings::PCI_VENDOR_ID_VMWARE, // 0x15ad + ZOLTRIX = bindings::PCI_VENDOR_ID_ZOLTRIX, // 0x15b0 + MELLANOX = bindings::PCI_VENDOR_ID_MELLANOX, // 0x15b3 + DFI = bindings::PCI_VENDOR_ID_DFI, // 0x15bd + QUICKNET = bindings::PCI_VENDOR_ID_QUICKNET, // 0x15e2 + ADDIDATA = bindings::PCI_VENDOR_ID_ADDIDATA, // 0x15B8 + PDC = bindings::PCI_VENDOR_ID_PDC, // 0x15e9 + FARSITE = bindings::PCI_VENDOR_ID_FARSITE, // 0x1619 + ARIMA = bindings::PCI_VENDOR_ID_ARIMA, // 0x161f + BROCADE = bindings::PCI_VENDOR_ID_BROCADE, // 0x1657 + SIBYTE = bindings::PCI_VENDOR_ID_SIBYTE, // 0x166d + ATHEROS = bindings::PCI_VENDOR_ID_ATHEROS, // 0x168c + NETCELL = 
bindings::PCI_VENDOR_ID_NETCELL, // 0x169c + CENATEK = bindings::PCI_VENDOR_ID_CENATEK, // 0x16CA + SYNOPSYS = bindings::PCI_VENDOR_ID_SYNOPSYS, // 0x16c3 + USR = bindings::PCI_VENDOR_ID_USR, // 0x16ec + VITESSE = bindings::PCI_VENDOR_ID_VITESSE, // 0x1725 + LINKSYS = bindings::PCI_VENDOR_ID_LINKSYS, // 0x1737 + ALTIMA = bindings::PCI_VENDOR_ID_ALTIMA, // 0x173b + CAVIUM = bindings::PCI_VENDOR_ID_CAVIUM, // 0x177d + TECHWELL = bindings::PCI_VENDOR_ID_TECHWELL, // 0x1797 + BELKIN = bindings::PCI_VENDOR_ID_BELKIN, // 0x1799 + RDC = bindings::PCI_VENDOR_ID_RDC, // 0x17f3 + GLI = bindings::PCI_VENDOR_ID_GLI, // 0x17a0 + LENOVO = bindings::PCI_VENDOR_ID_LENOVO, // 0x17aa + QCOM = bindings::PCI_VENDOR_ID_QCOM, // 0x17cb + CDNS = bindings::PCI_VENDOR_ID_CDNS, // 0x17cd + ARECA = bindings::PCI_VENDOR_ID_ARECA, // 0x17d3 + S2IO = bindings::PCI_VENDOR_ID_S2IO, // 0x17d5 + SITECOM = bindings::PCI_VENDOR_ID_SITECOM, // 0x182d + TOPSPIN = bindings::PCI_VENDOR_ID_TOPSPIN, // 0x1867 + COMMTECH = bindings::PCI_VENDOR_ID_COMMTECH, // 0x18f7 + SILAN = bindings::PCI_VENDOR_ID_SILAN, // 0x1904 + RENESAS = bindings::PCI_VENDOR_ID_RENESAS, // 0x1912 + SOLARFLARE = bindings::PCI_VENDOR_ID_SOLARFLARE, // 0x1924 + TDI = bindings::PCI_VENDOR_ID_TDI, // 0x192E + NXP = bindings::PCI_VENDOR_ID_NXP, // 0x1957 + PASEMI = bindings::PCI_VENDOR_ID_PASEMI, // 0x1959 + ATTANSIC = bindings::PCI_VENDOR_ID_ATTANSIC, // 0x1969 + JMICRON = bindings::PCI_VENDOR_ID_JMICRON, // 0x197B + KORENIX = bindings::PCI_VENDOR_ID_KORENIX, // 0x1982 + HUAWEI = bindings::PCI_VENDOR_ID_HUAWEI, // 0x19e5 + NETRONOME = bindings::PCI_VENDOR_ID_NETRONOME, // 0x19ee + QMI = bindings::PCI_VENDOR_ID_QMI, // 0x1a32 + AZWAVE = bindings::PCI_VENDOR_ID_AZWAVE, // 0x1a3b + REDHAT_QUMRANET = bindings::PCI_VENDOR_ID_REDHAT_QUMRANET, // 0x1af4 + ASMEDIA = bindings::PCI_VENDOR_ID_ASMEDIA, // 0x1b21 + REDHAT = bindings::PCI_VENDOR_ID_REDHAT, // 0x1b36 + WCHIC = bindings::PCI_VENDOR_ID_WCHIC, // 0x1c00 + SILICOM_DENMARK = bindings::PCI_VENDOR_ID_SILICOM_DENMARK, // 0x1c2c + AMAZON_ANNAPURNA_LABS = bindings::PCI_VENDOR_ID_AMAZON_ANNAPURNA_LABS, // 0x1c36 + CIRCUITCO = bindings::PCI_VENDOR_ID_CIRCUITCO, // 0x1cc8 + AMAZON = bindings::PCI_VENDOR_ID_AMAZON, // 0x1d0f + ZHAOXIN = bindings::PCI_VENDOR_ID_ZHAOXIN, // 0x1d17 + ROCKCHIP = bindings::PCI_VENDOR_ID_ROCKCHIP, // 0x1d87 + HYGON = bindings::PCI_VENDOR_ID_HYGON, // 0x1d94 + META = bindings::PCI_VENDOR_ID_META, // 0x1d9b + FUNGIBLE = bindings::PCI_VENDOR_ID_FUNGIBLE, // 0x1dad + HXT = bindings::PCI_VENDOR_ID_HXT, // 0x1dbf + TEKRAM = bindings::PCI_VENDOR_ID_TEKRAM, // 0x1de1 + RPI = bindings::PCI_VENDOR_ID_RPI, // 0x1de4 + ALIBABA = bindings::PCI_VENDOR_ID_ALIBABA, // 0x1ded + CXL = bindings::PCI_VENDOR_ID_CXL, // 0x1e98 + TEHUTI = bindings::PCI_VENDOR_ID_TEHUTI, // 0x1fc9 + SUNIX = bindings::PCI_VENDOR_ID_SUNIX, // 0x1fd4 + HINT = bindings::PCI_VENDOR_ID_HINT, // 0x3388 + THREEDLABS = bindings::PCI_VENDOR_ID_3DLABS, // 0x3d3d + NETXEN = bindings::PCI_VENDOR_ID_NETXEN, // 0x4040 + AKS = bindings::PCI_VENDOR_ID_AKS, // 0x416c + WCHCN = bindings::PCI_VENDOR_ID_WCHCN, // 0x4348 + ACCESSIO = bindings::PCI_VENDOR_ID_ACCESSIO, // 0x494f + S3 = bindings::PCI_VENDOR_ID_S3, // 0x5333 + DUNORD = bindings::PCI_VENDOR_ID_DUNORD, // 0x5544 + DCI = bindings::PCI_VENDOR_ID_DCI, // 0x6666 + GLENFLY = bindings::PCI_VENDOR_ID_GLENFLY, // 0x6766 + INTEL = bindings::PCI_VENDOR_ID_INTEL, // 0x8086 + WANGXUN = bindings::PCI_VENDOR_ID_WANGXUN, // 0x8088 + SCALEMP = bindings::PCI_VENDOR_ID_SCALEMP, // 0x8686 + COMPUTONE = 
bindings::PCI_VENDOR_ID_COMPUTONE, // 0x8e0e
+    KTI = bindings::PCI_VENDOR_ID_KTI, // 0x8e2e
+    ADAPTEC = bindings::PCI_VENDOR_ID_ADAPTEC, // 0x9004
+    ADAPTEC2 = bindings::PCI_VENDOR_ID_ADAPTEC2, // 0x9005
+    HOLTEK = bindings::PCI_VENDOR_ID_HOLTEK, // 0x9412
+    NETMOS = bindings::PCI_VENDOR_ID_NETMOS, // 0x9710
+    THREECOM_2 = bindings::PCI_VENDOR_ID_3COM_2, // 0xa727
+    SOLIDRUN = bindings::PCI_VENDOR_ID_SOLIDRUN, // 0xd063
+    DIGIUM = bindings::PCI_VENDOR_ID_DIGIUM, // 0xd161
+    TIGERJET = bindings::PCI_VENDOR_ID_TIGERJET, // 0xe159
+    XILINX_RME = bindings::PCI_VENDOR_ID_XILINX_RME, // 0xea60
+    XEN = bindings::PCI_VENDOR_ID_XEN, // 0x5853
+    OCZ = bindings::PCI_VENDOR_ID_OCZ, // 0x1b85
+    NCUBE = bindings::PCI_VENDOR_ID_NCUBE, // 0x10ff
+}
diff --git a/rust/kernel/pci/io.rs b/rust/kernel/pci/io.rs
new file mode 100644
index 000000000000..0d55c3139b6f
--- /dev/null
+++ b/rust/kernel/pci/io.rs
@@ -0,0 +1,144 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! PCI memory-mapped I/O infrastructure.
+
+use super::Device;
+use crate::{
+    bindings,
+    device,
+    devres::Devres,
+    io::{
+        Io,
+        IoRaw, //
+    },
+    prelude::*,
+    sync::aref::ARef, //
+};
+use core::ops::Deref;
+
+/// A PCI BAR to perform I/O operations on.
+///
+/// # Invariants
+///
+/// `Bar` always holds an `IoRaw` instance that holds a valid pointer to the start of the I/O
+/// memory mapped PCI BAR and its size.
+pub struct Bar<const SIZE: usize = 0> {
+    pdev: ARef<Device>,
+    io: IoRaw<SIZE>,
+    num: i32,
+}
+
+impl<const SIZE: usize> Bar<SIZE> {
+    pub(super) fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> {
+        let len = pdev.resource_len(num)?;
+        if len == 0 {
+            return Err(ENOMEM);
+        }
+
+        // Convert to `i32`, since that's what all the C bindings use.
+        let num = i32::try_from(num)?;
+
+        // SAFETY:
+        // `pdev` is valid by the invariants of `Device`.
+        // `num` is checked for validity by a previous call to `Device::resource_len`.
+        // `name` is always valid.
+        let ret = unsafe { bindings::pci_request_region(pdev.as_raw(), num, name.as_char_ptr()) };
+        if ret != 0 {
+            return Err(EBUSY);
+        }
+
+        // SAFETY:
+        // `pdev` is valid by the invariants of `Device`.
+        // `num` is checked for validity by a previous call to `Device::resource_len`.
+        // `name` is always valid.
+        let ioptr: usize = unsafe { bindings::pci_iomap(pdev.as_raw(), num, 0) } as usize;
+        if ioptr == 0 {
+            // SAFETY:
+            // `pdev` is valid by the invariants of `Device`.
+            // `num` is checked for validity by a previous call to `Device::resource_len`.
+            unsafe { bindings::pci_release_region(pdev.as_raw(), num) };
+            return Err(ENOMEM);
+        }
+
+        let io = match IoRaw::new(ioptr, len as usize) {
+            Ok(io) => io,
+            Err(err) => {
+                // SAFETY:
+                // `pdev` is valid by the invariants of `Device`.
+                // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region.
+                // `num` is checked for validity by a previous call to `Device::resource_len`.
+                unsafe { Self::do_release(pdev, ioptr, num) };
+                return Err(err);
+            }
+        };
+
+        Ok(Bar {
+            pdev: pdev.into(),
+            io,
+            num,
+        })
+    }
+
+    /// # Safety
+    ///
+    /// `ioptr` must be a valid pointer to the memory mapped PCI BAR number `num`.
+    unsafe fn do_release(pdev: &Device, ioptr: usize, num: i32) {
+        // SAFETY:
+        // `pdev` is valid by the invariants of `Device`.
+        // `ioptr` is valid by the safety requirements.
+        // `num` is valid by the safety requirements.
+        unsafe {
+            bindings::pci_iounmap(pdev.as_raw(), ioptr as *mut c_void);
+            bindings::pci_release_region(pdev.as_raw(), num);
+        }
+    }
+
+    fn release(&self) {
+        // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`.
+        unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) };
+    }
+}
+
+impl Bar {
+    #[inline]
+    pub(super) fn index_is_valid(index: u32) -> bool {
+        // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries.
+        index < bindings::PCI_NUM_RESOURCES
+    }
+}
+
+impl<const SIZE: usize> Drop for Bar<SIZE> {
+    fn drop(&mut self) {
+        self.release();
+    }
+}
+
+impl<const SIZE: usize> Deref for Bar<SIZE> {
+    type Target = Io<SIZE>;
+
+    fn deref(&self) -> &Self::Target {
+        // SAFETY: By the type invariant of `Self`, the MMIO range in `self.io` is properly mapped.
+        unsafe { Io::from_raw(&self.io) }
+    }
+}
+
+impl Device<device::Bound> {
+    /// Maps an entire PCI BAR after performing a region request on it. For offsets (plus the size
+    /// of the requested type) smaller than `SIZE`, I/O bounds checks can be performed at compile
+    /// time.
+    pub fn iomap_region_sized<'a, const SIZE: usize>(
+        &'a self,
+        bar: u32,
+        name: &'a CStr,
+    ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a {
+        Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name))
+    }
+
+    /// Maps an entire PCI BAR after performing a region request on it.
+    pub fn iomap_region<'a>(
+        &'a self,
+        bar: u32,
+        name: &'a CStr,
+    ) -> impl PinInit<Devres<Bar>, Error> + 'a {
+        self.iomap_region_sized::<0>(bar, name)
+    }
+}
diff --git a/rust/kernel/pci/irq.rs b/rust/kernel/pci/irq.rs
new file mode 100644
index 000000000000..d9230e105541
--- /dev/null
+++ b/rust/kernel/pci/irq.rs
@@ -0,0 +1,252 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! PCI interrupt infrastructure.
+
+use super::Device;
+use crate::{
+    bindings,
+    device,
+    device::Bound,
+    devres,
+    error::to_result,
+    irq::{
+        self,
+        IrqRequest, //
+    },
+    prelude::*,
+    str::CStr,
+    sync::aref::ARef, //
+};
+use core::ops::RangeInclusive;
+
+/// IRQ type flags for PCI interrupt allocation.
+#[derive(Debug, Clone, Copy)]
+pub enum IrqType {
+    /// INTx interrupts.
+    Intx,
+    /// Message Signaled Interrupts (MSI).
+    Msi,
+    /// Extended Message Signaled Interrupts (MSI-X).
+    MsiX,
+}
+
+impl IrqType {
+    /// Convert to the corresponding kernel flags.
+    const fn as_raw(self) -> u32 {
+        match self {
+            IrqType::Intx => bindings::PCI_IRQ_INTX,
+            IrqType::Msi => bindings::PCI_IRQ_MSI,
+            IrqType::MsiX => bindings::PCI_IRQ_MSIX,
+        }
+    }
+}
+
+/// Set of IRQ types that can be used for PCI interrupt allocation.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct IrqTypes(u32);
+
+impl IrqTypes {
+    /// Create a set containing all IRQ types (MSI-X, MSI, and INTx).
+    pub const fn all() -> Self {
+        Self(bindings::PCI_IRQ_ALL_TYPES)
+    }
+
+    /// Build a set of IRQ types.
+    ///
+    /// # Examples
+    ///
+    /// ```ignore
+    /// // Create a set with only MSI and MSI-X (no INTx interrupts).
+    /// let msi_only = IrqTypes::default()
+    ///     .with(IrqType::Msi)
+    ///     .with(IrqType::MsiX);
+    /// ```
+    pub const fn with(self, irq_type: IrqType) -> Self {
+        Self(self.0 | irq_type.as_raw())
+    }
+
+    /// Get the raw flags value.
+    const fn as_raw(self) -> u32 {
+        self.0
+    }
+}
+
+/// Represents an allocated IRQ vector for a specific PCI device.
+///
+/// This type ties an IRQ vector to the device it was allocated for,
+/// ensuring the vector is only used with the correct device.
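+///
+/// Instances of this type are obtained from [`Device::alloc_irq_vectors`] and can be converted
+/// into an [`IrqRequest`] (and from there into an IRQ registration) via `TryInto`.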
+#[derive(Clone, Copy)]
+pub struct IrqVector<'a> {
+    dev: &'a Device<Bound>,
+    index: u32,
+}
+
+impl<'a> IrqVector<'a> {
+    /// Creates a new [`IrqVector`] for the given device and index.
+    ///
+    /// # Safety
+    ///
+    /// - `index` must be a valid IRQ vector index for `dev`.
+    /// - `dev` must point to a [`Device`] that has successfully allocated IRQ vectors.
+    unsafe fn new(dev: &'a Device<Bound>, index: u32) -> Self {
+        Self { dev, index }
+    }
+
+    /// Returns the raw vector index.
+    fn index(&self) -> u32 {
+        self.index
+    }
+}
+
+impl<'a> TryInto<IrqRequest<'a>> for IrqVector<'a> {
+    type Error = Error;
+
+    fn try_into(self) -> Result<IrqRequest<'a>> {
+        // SAFETY: `self.dev.as_raw()` returns a valid pointer to a `struct pci_dev`.
+        let irq = unsafe { bindings::pci_irq_vector(self.dev.as_raw(), self.index()) };
+        if irq < 0 {
+            return Err(crate::error::Error::from_errno(irq));
+        }
+        // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`.
+        Ok(unsafe { IrqRequest::new(self.dev.as_ref(), irq as u32) })
+    }
+}
+
+/// Represents an IRQ vector allocation for a PCI device.
+///
+/// This type ensures that IRQ vectors are properly allocated and freed by
+/// tying the allocation to the lifetime of this registration object.
+///
+/// # Invariants
+///
+/// The [`Device`] has successfully allocated IRQ vectors.
+struct IrqVectorRegistration {
+    dev: ARef<Device>,
+}
+
+impl IrqVectorRegistration {
+    /// Allocate and register IRQ vectors for the given PCI device.
+    ///
+    /// Allocates IRQ vectors and registers them with devres for automatic cleanup.
+    /// Returns a range of valid IRQ vectors.
+    fn register<'a>(
+        dev: &'a Device<Bound>,
+        min_vecs: u32,
+        max_vecs: u32,
+        irq_types: IrqTypes,
+    ) -> Result<RangeInclusive<IrqVector<'a>>> {
+        // SAFETY:
+        // - `dev.as_raw()` is guaranteed to be a valid pointer to a `struct pci_dev`
+        //   by the type invariant of `Device`.
+        // - `pci_alloc_irq_vectors` internally validates all other parameters
+        //   and returns error codes.
+        let ret = unsafe {
+            bindings::pci_alloc_irq_vectors(dev.as_raw(), min_vecs, max_vecs, irq_types.as_raw())
+        };
+
+        to_result(ret)?;
+        let count = ret as u32;
+
+        // SAFETY:
+        // - `pci_alloc_irq_vectors` returns the number of allocated vectors on success.
+        // - Vectors are 0-based, so valid indices are [0, count-1].
+        // - `pci_alloc_irq_vectors` guarantees `count >= min_vecs > 0`, so both `0` and
+        //   `count - 1` are valid IRQ vector indices for `dev`.
+        let range = unsafe { IrqVector::new(dev, 0)..=IrqVector::new(dev, count - 1) };
+
+        // INVARIANT: The IRQ vector allocation for `dev` above was successful.
+        let irq_vecs = Self { dev: dev.into() };
+        devres::register(dev.as_ref(), irq_vecs, GFP_KERNEL)?;
+
+        Ok(range)
+    }
+}
+
+impl Drop for IrqVectorRegistration {
+    fn drop(&mut self) {
+        // SAFETY:
+        // - By the type invariant, `self.dev.as_raw()` is a valid pointer to a `struct pci_dev`.
+        // - `self.dev` has successfully allocated IRQ vectors.
+        unsafe { bindings::pci_free_irq_vectors(self.dev.as_raw()) };
+    }
+}
+
+impl Device<device::Bound> {
+    /// Returns a [`kernel::irq::Registration`] for the given IRQ vector.
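+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch, assuming a hypothetical `MyHandler` type implementing
+    /// [`irq::Handler`] and a `flags` value chosen by the driver:
+    ///
+    /// ```ignore
+    /// let vectors = pdev.alloc_irq_vectors(1, 1, pci::IrqTypes::all())?;
+    /// let registration = KBox::pin_init(
+    ///     pdev.request_irq(*vectors.start(), flags, c_str!("my_driver"), MyHandler::new()),
+    ///     GFP_KERNEL,
+    /// )?;
+    /// ```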
+ pub fn request_irq<'a, T: crate::irq::Handler + 'static>( + &'a self, + vector: IrqVector<'a>, + flags: irq::Flags, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::Registration<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = vector.try_into()?; + + Ok(irq::Registration::<T>::new(request, flags, name, handler)) + }) + } + + /// Returns a [`kernel::irq::ThreadedRegistration`] for the given IRQ vector. + pub fn request_threaded_irq<'a, T: crate::irq::ThreadedHandler + 'static>( + &'a self, + vector: IrqVector<'a>, + flags: irq::Flags, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::ThreadedRegistration<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = vector.try_into()?; + + Ok(irq::ThreadedRegistration::<T>::new( + request, flags, name, handler, + )) + }) + } + + /// Allocate IRQ vectors for this PCI device with automatic cleanup. + /// + /// Allocates between `min_vecs` and `max_vecs` interrupt vectors for the device. + /// The allocation will use MSI-X, MSI, or INTx interrupts based on the `irq_types` + /// parameter and hardware capabilities. When multiple types are specified, the kernel + /// will try them in order of preference: MSI-X first, then MSI, then INTx interrupts. + /// + /// The allocated vectors are automatically freed when the device is unbound, using the + /// devres (device resource management) system. + /// + /// # Arguments + /// + /// * `min_vecs` - Minimum number of vectors required. + /// * `max_vecs` - Maximum number of vectors to allocate. + /// * `irq_types` - Types of interrupts that can be used. + /// + /// # Returns + /// + /// Returns a range of IRQ vectors that were successfully allocated, or an error if the + /// allocation fails or cannot meet the minimum requirement. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{ device::Bound, pci}; + /// # fn no_run(dev: &pci::Device<Bound>) -> Result { + /// // Allocate using any available interrupt type in the order mentioned above. + /// let vectors = dev.alloc_irq_vectors(1, 32, pci::IrqTypes::all())?; + /// + /// // Allocate MSI or MSI-X only (no INTx interrupts). + /// let msi_only = pci::IrqTypes::default() + /// .with(pci::IrqType::Msi) + /// .with(pci::IrqType::MsiX); + /// let vectors = dev.alloc_irq_vectors(4, 16, msi_only)?; + /// # Ok(()) + /// # } + /// ``` + pub fn alloc_irq_vectors( + &self, + min_vecs: u32, + max_vecs: u32, + irq_types: IrqTypes, + ) -> Result<RangeInclusive<IrqVector<'_>>> { + IrqVectorRegistration::register(self, min_vecs, max_vecs, irq_types) + } +} diff --git a/rust/kernel/pid_namespace.rs b/rust/kernel/pid_namespace.rs new file mode 100644 index 000000000000..979a9718f153 --- /dev/null +++ b/rust/kernel/pid_namespace.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (c) 2024 Christian Brauner <brauner@kernel.org> + +//! Pid namespaces. +//! +//! C header: [`include/linux/pid_namespace.h`](srctree/include/linux/pid_namespace.h) and +//! [`include/linux/pid.h`](srctree/include/linux/pid.h) + +use crate::{bindings, sync::aref::AlwaysRefCounted, types::Opaque}; +use core::ptr; + +/// Wraps the kernel's `struct pid_namespace`. Thread safe. +/// +/// This structure represents the Rust abstraction for a C `struct pid_namespace`. This +/// implementation abstracts the usage of an already existing C `struct pid_namespace` within Rust +/// code that we get passed from the C side. 
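+///
+/// # Examples
+///
+/// A sketch of wrapping a raw pointer handed over from C; `raw` and its validity guarantee are
+/// assumed to come from the surrounding C calling convention:
+///
+/// ```ignore
+/// // SAFETY: `raw` is valid and remains valid for the lifetime of the returned reference.
+/// let ns: &PidNamespace = unsafe { PidNamespace::from_ptr(raw) };
+/// let _ptr = ns.as_ptr();
+/// ```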
+#[repr(transparent)] +pub struct PidNamespace { + inner: Opaque<bindings::pid_namespace>, +} + +impl PidNamespace { + /// Returns a raw pointer to the inner C struct. + #[inline] + pub fn as_ptr(&self) -> *mut bindings::pid_namespace { + self.inner.get() + } + + /// Creates a reference to a [`PidNamespace`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`PidNamespace`] reference. + pub unsafe fn from_ptr<'a>(ptr: *const bindings::pid_namespace) -> &'a Self { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `PidNamespace` type being transparent makes the cast ok. + unsafe { &*ptr.cast() } + } +} + +// SAFETY: Instances of `PidNamespace` are always reference-counted. +unsafe impl AlwaysRefCounted for PidNamespace { + #[inline] + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_pid_ns(self.as_ptr()) }; + } + + #[inline] + unsafe fn dec_ref(obj: ptr::NonNull<PidNamespace>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::put_pid_ns(obj.cast().as_ptr()) } + } +} + +// SAFETY: +// - `PidNamespace::dec_ref` can be called from any thread. +// - It is okay to send ownership of `PidNamespace` across thread boundaries. +unsafe impl Send for PidNamespace {} + +// SAFETY: It's OK to access `PidNamespace` through shared references from other threads because +// we're either accessing properties that don't change or that are properly synchronised by C code. +unsafe impl Sync for PidNamespace {} diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs new file mode 100644 index 000000000000..ed889f079cab --- /dev/null +++ b/rust/kernel/platform.rs @@ -0,0 +1,532 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the platform bus. +//! +//! C header: [`include/linux/platform_device.h`](srctree/include/linux/platform_device.h) + +use crate::{ + acpi, bindings, container_of, + device::{self, Bound}, + driver, + error::{from_result, to_result, Result}, + io::{mem::IoRequest, Resource}, + irq::{self, IrqRequest}, + of, + prelude::*, + types::Opaque, + ThisModule, +}; + +use core::{ + marker::PhantomData, + mem::offset_of, + ptr::{addr_of_mut, NonNull}, +}; + +/// An adapter for the registration of platform drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::platform_driver; + + unsafe fn register( + pdrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + let of_table = match T::OF_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + let acpi_table = match T::ACPI_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + // SAFETY: It's safe to set the fields of `struct platform_driver` on initialization. + unsafe { + (*pdrv.get()).driver.name = name.as_char_ptr(); + (*pdrv.get()).probe = Some(Self::probe_callback); + (*pdrv.get()).remove = Some(Self::remove_callback); + (*pdrv.get()).driver.of_match_table = of_table; + (*pdrv.get()).driver.acpi_match_table = acpi_table; + } + + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. 
+        to_result(unsafe { bindings::__platform_driver_register(pdrv.get(), module.0) })
+    }
+
+    unsafe fn unregister(pdrv: &Opaque<Self::RegType>) {
+        // SAFETY: `pdrv` is guaranteed to be a valid `RegType`.
+        unsafe { bindings::platform_driver_unregister(pdrv.get()) };
+    }
+}
+
+impl<T: Driver + 'static> Adapter<T> {
+    extern "C" fn probe_callback(pdev: *mut bindings::platform_device) -> kernel::ffi::c_int {
+        // SAFETY: The platform bus only ever calls the probe callback with a valid pointer to a
+        // `struct platform_device`.
+        //
+        // INVARIANT: `pdev` is valid for the duration of `probe_callback()`.
+        let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() };
+        let info = <Self as driver::Adapter>::id_info(pdev.as_ref());
+
+        from_result(|| {
+            let data = T::probe(pdev, info);
+
+            pdev.as_ref().set_drvdata(data)?;
+            Ok(0)
+        })
+    }
+
+    extern "C" fn remove_callback(pdev: *mut bindings::platform_device) {
+        // SAFETY: The platform bus only ever calls the remove callback with a valid pointer to a
+        // `struct platform_device`.
+        //
+        // INVARIANT: `pdev` is valid for the duration of `remove_callback()`.
+        let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() };
+
+        // SAFETY: `remove_callback` is only ever called after a successful call to
+        // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
+        // and stored a `Pin<KBox<T>>`.
+        let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() };
+
+        T::unbind(pdev, data.as_ref());
+    }
+}
+
+impl<T: Driver + 'static> driver::Adapter for Adapter<T> {
+    type IdInfo = T::IdInfo;
+
+    fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> {
+        T::OF_ID_TABLE
+    }
+
+    fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>> {
+        T::ACPI_ID_TABLE
+    }
+}
+
+/// Declares a kernel module that exposes a single platform driver.
+///
+/// # Examples
+///
+/// ```ignore
+/// kernel::module_platform_driver! {
+///     type: MyDriver,
+///     name: "Module name",
+///     authors: ["Author name"],
+///     description: "Description",
+///     license: "GPL v2",
+/// }
+/// ```
+#[macro_export]
+macro_rules! module_platform_driver {
+    ($($f:tt)*) => {
+        $crate::module_driver!(<T>, $crate::platform::Adapter<T>, { $($f)* });
+    };
+}
+
+/// The platform driver trait.
+///
+/// Drivers must implement this trait in order to get a platform driver registered.
+///
+/// # Examples
+///
+///```
+/// # use kernel::{acpi, bindings, c_str, device::Core, of, platform};
+///
+/// struct MyDriver;
+///
+/// kernel::of_device_table!(
+///     OF_TABLE,
+///     MODULE_OF_TABLE,
+///     <MyDriver as platform::Driver>::IdInfo,
+///     [
+///         (of::DeviceId::new(c_str!("test,device")), ())
+///     ]
+/// );
+///
+/// kernel::acpi_device_table!(
+///     ACPI_TABLE,
+///     MODULE_ACPI_TABLE,
+///     <MyDriver as platform::Driver>::IdInfo,
+///     [
+///         (acpi::DeviceId::new(c_str!("LNUXBEEF")), ())
+///     ]
+/// );
+///
+/// impl platform::Driver for MyDriver {
+///     type IdInfo = ();
+///     const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE);
+///     const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE);
+///
+///     fn probe(
+///         _pdev: &platform::Device<Core>,
+///         _id_info: Option<&Self::IdInfo>,
+///     ) -> impl PinInit<Self, Error> {
+///         Err(ENODEV)
+///     }
+/// }
+///```
+pub trait Driver: Send {
+    /// The type holding driver private data about each device id supported by the driver.
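+    ///
+    /// Drivers that do not need to distinguish between the matched device ids can use `()`,
+    /// as the example above does.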
+ // TODO: Use associated_type_defaults once stabilized: + // + // ``` + // type IdInfo: 'static = (); + // ``` + type IdInfo: 'static; + + /// The table of OF device ids supported by the driver. + const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; + + /// The table of ACPI device ids supported by the driver. + const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; + + /// Platform driver probe. + /// + /// Called when a new platform device is added or discovered. + /// Implementers should attempt to initialize the device here. + fn probe( + dev: &Device<device::Core>, + id_info: Option<&Self::IdInfo>, + ) -> impl PinInit<Self, Error>; + + /// Platform driver unbind. + /// + /// Called when a [`Device`] is unbound from its bound [`Driver`]. Implementing this callback + /// is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } +} + +/// The platform device representation. +/// +/// This structure represents the Rust abstraction for a C `struct platform_device`. The +/// implementation abstracts the usage of an already existing C `struct platform_device` within Rust +/// code that we get passed from the C side. +/// +/// # Invariants +/// +/// A [`Device`] instance represents a valid `struct platform_device` created by the C portion of +/// the kernel. +#[repr(transparent)] +pub struct Device<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::platform_device>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> Device<Ctx> { + fn as_raw(&self) -> *mut bindings::platform_device { + self.0.get() + } + + /// Returns the resource at `index`, if any. + pub fn resource_by_index(&self, index: u32) -> Option<&Resource> { + // SAFETY: `self.as_raw()` returns a valid pointer to a `struct platform_device`. + let resource = unsafe { + bindings::platform_get_resource(self.as_raw(), bindings::IORESOURCE_MEM, index) + }; + + if resource.is_null() { + return None; + } + + // SAFETY: `resource` is a valid pointer to a `struct resource` as + // returned by `platform_get_resource`. + Some(unsafe { Resource::from_raw(resource) }) + } + + /// Returns the resource with a given `name`, if any. + pub fn resource_by_name(&self, name: &CStr) -> Option<&Resource> { + // SAFETY: `self.as_raw()` returns a valid pointer to a `struct + // platform_device` and `name` points to a valid C string. + let resource = unsafe { + bindings::platform_get_resource_byname( + self.as_raw(), + bindings::IORESOURCE_MEM, + name.as_char_ptr(), + ) + }; + + if resource.is_null() { + return None; + } + + // SAFETY: `resource` is a valid pointer to a `struct resource` as + // returned by `platform_get_resource`. + Some(unsafe { Resource::from_raw(resource) }) + } +} + +impl Device<Bound> { + /// Returns an `IoRequest` for the resource at `index`, if any. + pub fn io_request_by_index(&self, index: u32) -> Option<IoRequest<'_>> { + self.resource_by_index(index) + // SAFETY: `resource` is a valid resource for `&self` during the + // lifetime of the `IoRequest`. + .map(|resource| unsafe { IoRequest::new(self.as_ref(), resource) }) + } + + /// Returns an `IoRequest` for the resource with a given `name`, if any. 
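+    ///
+    /// # Examples
+    ///
+    /// A sketch of fetching a named MMIO resource from `probe()`; the resource name `"regs"` is
+    /// illustrative and would come from the device tree binding of the hypothetical device:
+    ///
+    /// ```ignore
+    /// let request = pdev.io_request_by_name(c_str!("regs")).ok_or(ENODEV)?;
+    /// ```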
+ pub fn io_request_by_name(&self, name: &CStr) -> Option<IoRequest<'_>> { + self.resource_by_name(name) + // SAFETY: `resource` is a valid resource for `&self` during the + // lifetime of the `IoRequest`. + .map(|resource| unsafe { IoRequest::new(self.as_ref(), resource) }) + } +} + +// SAFETY: `platform::Device` is a transparent wrapper of `struct platform_device`. +// The offset is guaranteed to point to a valid device field inside `platform::Device`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> { + const OFFSET: usize = offset_of!(bindings::platform_device, dev); +} + +macro_rules! define_irq_accessor_by_index { + ( + $(#[$meta:meta])* $fn_name:ident, + $request_fn:ident, + $reg_type:ident, + $handler_trait:ident + ) => { + $(#[$meta])* + pub fn $fn_name<'a, T: irq::$handler_trait + 'static>( + &'a self, + flags: irq::Flags, + index: u32, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::$reg_type<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = self.$request_fn(index)?; + + Ok(irq::$reg_type::<T>::new( + request, + flags, + name, + handler, + )) + }) + } + }; +} + +macro_rules! define_irq_accessor_by_name { + ( + $(#[$meta:meta])* $fn_name:ident, + $request_fn:ident, + $reg_type:ident, + $handler_trait:ident + ) => { + $(#[$meta])* + pub fn $fn_name<'a, T: irq::$handler_trait + 'static>( + &'a self, + flags: irq::Flags, + irq_name: &'a CStr, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::$reg_type<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = self.$request_fn(irq_name)?; + + Ok(irq::$reg_type::<T>::new( + request, + flags, + name, + handler, + )) + }) + } + }; +} + +impl Device<Bound> { + /// Returns an [`IrqRequest`] for the IRQ at the given index, if any. + pub fn irq_by_index(&self, index: u32) -> Result<IrqRequest<'_>> { + // SAFETY: `self.as_raw` returns a valid pointer to a `struct platform_device`. + let irq = unsafe { bindings::platform_get_irq(self.as_raw(), index) }; + + if irq < 0 { + return Err(Error::from_errno(irq)); + } + + // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. + Ok(unsafe { IrqRequest::new(self.as_ref(), irq as u32) }) + } + + /// Returns an [`IrqRequest`] for the IRQ at the given index, but does not + /// print an error if the IRQ cannot be obtained. + pub fn optional_irq_by_index(&self, index: u32) -> Result<IrqRequest<'_>> { + // SAFETY: `self.as_raw` returns a valid pointer to a `struct platform_device`. + let irq = unsafe { bindings::platform_get_irq_optional(self.as_raw(), index) }; + + if irq < 0 { + return Err(Error::from_errno(irq)); + } + + // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. + Ok(unsafe { IrqRequest::new(self.as_ref(), irq as u32) }) + } + + /// Returns an [`IrqRequest`] for the IRQ with the given name, if any. + pub fn irq_by_name(&self, name: &CStr) -> Result<IrqRequest<'_>> { + // SAFETY: `self.as_raw` returns a valid pointer to a `struct platform_device`. + let irq = unsafe { bindings::platform_get_irq_byname(self.as_raw(), name.as_char_ptr()) }; + + if irq < 0 { + return Err(Error::from_errno(irq)); + } + + // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. + Ok(unsafe { IrqRequest::new(self.as_ref(), irq as u32) }) + } + + /// Returns an [`IrqRequest`] for the IRQ with the given name, but does not + /// print an error if the IRQ cannot be obtained. 
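+    ///
+    /// # Examples
+    ///
+    /// A sketch of probing for an interrupt that the device tree binding declares as optional;
+    /// the IRQ name `"wakeup"` is illustrative:
+    ///
+    /// ```ignore
+    /// if let Ok(request) = pdev.optional_irq_by_name(c_str!("wakeup")) {
+    ///     // The device provides a dedicated wake-up IRQ; register a handler for it.
+    /// }
+    /// ```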
+ pub fn optional_irq_by_name(&self, name: &CStr) -> Result<IrqRequest<'_>> { + // SAFETY: `self.as_raw` returns a valid pointer to a `struct platform_device`. + let irq = unsafe { + bindings::platform_get_irq_byname_optional(self.as_raw(), name.as_char_ptr()) + }; + + if irq < 0 { + return Err(Error::from_errno(irq)); + } + + // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. + Ok(unsafe { IrqRequest::new(self.as_ref(), irq as u32) }) + } + + define_irq_accessor_by_index!( + /// Returns a [`irq::Registration`] for the IRQ at the given index. + request_irq_by_index, + irq_by_index, + Registration, + Handler + ); + define_irq_accessor_by_name!( + /// Returns a [`irq::Registration`] for the IRQ with the given name. + request_irq_by_name, + irq_by_name, + Registration, + Handler + ); + define_irq_accessor_by_index!( + /// Does the same as [`Self::request_irq_by_index`], except that it does + /// not print an error message if the IRQ cannot be obtained. + request_optional_irq_by_index, + optional_irq_by_index, + Registration, + Handler + ); + define_irq_accessor_by_name!( + /// Does the same as [`Self::request_irq_by_name`], except that it does + /// not print an error message if the IRQ cannot be obtained. + request_optional_irq_by_name, + optional_irq_by_name, + Registration, + Handler + ); + + define_irq_accessor_by_index!( + /// Returns a [`irq::ThreadedRegistration`] for the IRQ at the given index. + request_threaded_irq_by_index, + irq_by_index, + ThreadedRegistration, + ThreadedHandler + ); + define_irq_accessor_by_name!( + /// Returns a [`irq::ThreadedRegistration`] for the IRQ with the given name. + request_threaded_irq_by_name, + irq_by_name, + ThreadedRegistration, + ThreadedHandler + ); + define_irq_accessor_by_index!( + /// Does the same as [`Self::request_threaded_irq_by_index`], except + /// that it does not print an error message if the IRQ cannot be + /// obtained. + request_optional_threaded_irq_by_index, + optional_irq_by_index, + ThreadedRegistration, + ThreadedHandler + ); + define_irq_accessor_by_name!( + /// Does the same as [`Self::request_threaded_irq_by_name`], except that + /// it does not print an error message if the IRQ cannot be obtained. + request_optional_threaded_irq_by_name, + optional_irq_by_name, + ThreadedRegistration, + ThreadedHandler + ); +} + +// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic +// argument. +kernel::impl_device_context_deref!(unsafe { Device }); +kernel::impl_device_context_into_aref!(Device); + +impl crate::dma::Device for Device<device::Core> {} + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::sync::aref::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::get_device(self.as_ref().as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::platform_device_put(obj.cast().as_ptr()) } + } +} + +impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> { + fn as_ref(&self) -> &device::Device<Ctx> { + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct platform_device`. + let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. 
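+ // The returned reference borrows from `self`, so the `struct device` it
+ // refers to cannot outlive the enclosing `struct platform_device`.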
+ unsafe { device::Device::from_raw(dev) } + } +} + +impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &Device<Ctx> { + type Error = kernel::error::Error; + + fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> { + // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a + // `struct device`. + if !unsafe { bindings::dev_is_platform(dev.as_raw()) } { + return Err(EINVAL); + } + + // SAFETY: We've just verified that the bus type of `dev` equals + // `bindings::platform_bus_type`, hence `dev` must be embedded in a valid + // `struct platform_device` as guaranteed by the corresponding C code. + let pdev = unsafe { container_of!(dev.as_raw(), bindings::platform_device, dev) }; + + // SAFETY: `pdev` is a valid pointer to a `struct platform_device`. + Ok(unsafe { &*pdev.cast() }) + } +} + +// SAFETY: A `Device` is always reference-counted and can be released from any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all methods of `Device` +// (i.e. `Device<Normal>) are thread safe. +unsafe impl Sync for Device {} diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index 7a90249ee9b9..2877e3f7b6d3 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -11,18 +11,45 @@ //! use kernel::prelude::*; //! ``` -pub use core::pin::Pin; +#[doc(no_inline)] +pub use core::{ + mem::{align_of, align_of_val, size_of, size_of_val}, + pin::Pin, +}; -pub use alloc::{boxed::Box, vec::Vec}; +pub use ::ffi::{ + c_char, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, c_ulonglong, + c_ushort, c_void, CStr, +}; -pub use macros::{module, vtable}; +pub use crate::alloc::{flags::*, Box, KBox, KVBox, KVVec, KVec, VBox, VVec, Vec}; -pub use super::build_assert; +#[doc(no_inline)] +pub use macros::{export, fmt, kunit_tests, module, vtable}; -pub use super::{dbg, pr_alert, pr_crit, pr_debug, pr_emerg, pr_err, pr_info, pr_notice, pr_warn}; +pub use pin_init::{init, pin_data, pin_init, pinned_drop, InPlaceWrite, Init, PinInit, Zeroable}; + +pub use super::{build_assert, build_error}; + +// `super::std_vendor` is hidden, which makes the macro inline for some reason. +#[doc(no_inline)] +pub use super::dbg; +pub use super::{dev_alert, dev_crit, dev_dbg, dev_emerg, dev_err, dev_info, dev_notice, dev_warn}; +pub use super::{pr_alert, pr_crit, pr_debug, pr_emerg, pr_err, pr_info, pr_notice, pr_warn}; + +pub use super::{try_init, try_pin_init}; pub use super::static_assert; pub use super::error::{code::*, Error, Result}; -pub use super::{str::CStr, ThisModule}; +pub use super::{str::CStrExt as _, ThisModule}; + +pub use super::init::InPlaceInit; + +pub use super::current; + +pub use super::uaccess::UserPtr; + +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +pub use super::slice::AsFlattened; diff --git a/rust/kernel/print.rs b/rust/kernel/print.rs index 30103325696d..2d743d78d220 100644 --- a/rust/kernel/print.rs +++ b/rust/kernel/print.rs @@ -2,27 +2,30 @@ //! Printing facilities. //! -//! C header: [`include/linux/printk.h`](../../../../include/linux/printk.h) +//! C header: [`include/linux/printk.h`](srctree/include/linux/printk.h) //! -//! Reference: <https://www.kernel.org/doc/html/latest/core-api/printk-basics.html> +//! 
Reference: <https://docs.kernel.org/core-api/printk-basics.html> -use core::{ +use crate::{ ffi::{c_char, c_void}, fmt, + prelude::*, + str::RawFormatter, }; -use crate::str::RawFormatter; - -#[cfg(CONFIG_PRINTK)] -use crate::bindings; - // Called from `vsprintf` with format specifier `%pA`. -#[no_mangle] -unsafe fn rust_fmt_argument(buf: *mut c_char, end: *mut c_char, ptr: *const c_void) -> *mut c_char { +#[expect(clippy::missing_safety_doc)] +#[export] +unsafe extern "C" fn rust_fmt_argument( + buf: *mut c_char, + end: *mut c_char, + ptr: *const c_void, +) -> *mut c_char { use fmt::Write; // SAFETY: The C contract guarantees that `buf` is valid if it's less than `end`. let mut w = unsafe { RawFormatter::from_ptrs(buf.cast(), end.cast()) }; - let _ = w.write_fmt(unsafe { *(ptr as *const fmt::Arguments<'_>) }); + // SAFETY: TODO. + let _ = w.write_fmt(unsafe { *ptr.cast::<fmt::Arguments<'_>>() }); w.pos().cast() } @@ -31,8 +34,6 @@ unsafe fn rust_fmt_argument(buf: *mut c_char, end: *mut c_char, ptr: *const c_vo /// Public but hidden since it should only be used from public macros. #[doc(hidden)] pub mod format_strings { - use crate::bindings; - /// The length we copy from the `KERN_*` kernel prefixes. const LENGTH_PREFIX: usize = 2; @@ -44,7 +45,7 @@ pub mod format_strings { /// The format string is always the same for a given level, i.e. for a /// given `prefix`, which are the kernel's `KERN_*` constants. /// - /// [`_printk`]: ../../../../include/linux/printk.h + /// [`_printk`]: srctree/include/linux/printk.h const fn generate(is_cont: bool, prefix: &[u8; 3]) -> [u8; LENGTH] { // Ensure the `KERN_*` macros are what we expect. assert!(prefix[0] == b'\x01'); @@ -93,7 +94,7 @@ pub mod format_strings { /// The format string must be one of the ones in [`format_strings`], and /// the module name must be null-terminated. /// -/// [`_printk`]: ../../../../include/linux/_printk.h +/// [`_printk`]: srctree/include/linux/_printk.h #[doc(hidden)] #[cfg_attr(not(CONFIG_PRINTK), allow(unused_variables))] pub unsafe fn call_printk( @@ -103,11 +104,12 @@ pub unsafe fn call_printk( ) { // `_printk` does not seem to fail in any path. #[cfg(CONFIG_PRINTK)] + // SAFETY: TODO. unsafe { bindings::_printk( - format_string.as_ptr() as _, + format_string.as_ptr(), module_name.as_ptr(), - &args as *const _ as *const c_void, + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -116,7 +118,7 @@ pub unsafe fn call_printk( /// /// Public but hidden since it should only be used from public macros. /// -/// [`_printk`]: ../../../../include/linux/printk.h +/// [`_printk`]: srctree/include/linux/printk.h #[doc(hidden)] #[cfg_attr(not(CONFIG_PRINTK), allow(unused_variables))] pub fn call_printk_cont(args: fmt::Arguments<'_>) { @@ -126,8 +128,8 @@ pub fn call_printk_cont(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - format_strings::CONT.as_ptr() as _, - &args as *const _ as *const c_void, + format_strings::CONT.as_ptr(), + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -138,7 +140,7 @@ pub fn call_printk_cont(args: fmt::Arguments<'_>) { #[doc(hidden)] #[cfg(not(testlib))] #[macro_export] -#[allow(clippy::crate_in_macro_def)] +#[expect(clippy::crate_in_macro_def)] macro_rules! print_macro ( // The non-continuation cases (most of them, e.g. `INFO`). ($format_string:path, false, $($arg:tt)+) => ( @@ -147,7 +149,7 @@ macro_rules! print_macro ( // takes borrows on the arguments, but does not extend the scope of temporaries. 
// Therefore, a `match` expression is used to keep them around, since // the scrutinee is kept until the end of the `match`. - match format_args!($($arg)+) { + match $crate::prelude::fmt!($($arg)+) { // SAFETY: This hidden macro should only be called by the documented // printing macros which ensure the format string is one of the fixed // ones. All `__LOG_PREFIX`s are null-terminated as they are generated @@ -166,7 +168,7 @@ macro_rules! print_macro ( // The `CONT` case. ($format_string:path, true, $($arg:tt)+) => ( $crate::print::call_printk_cont( - format_args!($($arg)+), + $crate::prelude::fmt!($($arg)+), ); ); ); @@ -196,10 +198,11 @@ macro_rules! print_macro ( /// Equivalent to the kernel's [`pr_emerg`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_emerg`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_emerg +/// [`pr_emerg`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_emerg /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -220,10 +223,11 @@ macro_rules! pr_emerg ( /// Equivalent to the kernel's [`pr_alert`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_alert`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_alert +/// [`pr_alert`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_alert /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -244,10 +248,11 @@ macro_rules! pr_alert ( /// Equivalent to the kernel's [`pr_crit`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_crit`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_crit +/// [`pr_crit`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_crit /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -268,10 +273,11 @@ macro_rules! pr_crit ( /// Equivalent to the kernel's [`pr_err`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_err`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_err +/// [`pr_err`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_err /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -292,10 +298,11 @@ macro_rules! pr_err ( /// Equivalent to the kernel's [`pr_warn`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. 
/// -/// [`pr_warn`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_warn +/// [`pr_warn`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_warn /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -316,10 +323,11 @@ macro_rules! pr_warn ( /// Equivalent to the kernel's [`pr_notice`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_notice`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_notice +/// [`pr_notice`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_notice /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -340,10 +348,11 @@ macro_rules! pr_notice ( /// Equivalent to the kernel's [`pr_info`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_info`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_info +/// [`pr_info`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_info /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -366,10 +375,11 @@ macro_rules! pr_info ( /// yet. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_debug`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_debug +/// [`pr_debug`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_debug /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -393,10 +403,12 @@ macro_rules! pr_debug ( /// Equivalent to the kernel's [`pr_cont`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// -/// [`pr_cont`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_cont +/// [`pr_info!`]: crate::pr_info! +/// [`pr_cont`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_cont /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// diff --git a/rust/kernel/processor.rs b/rust/kernel/processor.rs new file mode 100644 index 000000000000..85b49b3614dd --- /dev/null +++ b/rust/kernel/processor.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Processor related primitives. +//! +//! C header: [`include/linux/processor.h`](srctree/include/linux/processor.h) + +/// Lower CPU power consumption or yield to a hyperthreaded twin processor. +/// +/// It also happens to serve as a compiler barrier. +#[inline] +pub fn cpu_relax() { + // SAFETY: Always safe to call. 
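+ // On x86 this emits the PAUSE hint; architectures without a dedicated relax
+ // instruction fall back to a plain compiler barrier.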
+ unsafe { bindings::cpu_relax() } +} diff --git a/rust/kernel/ptr.rs b/rust/kernel/ptr.rs new file mode 100644 index 000000000000..e3893ed04049 --- /dev/null +++ b/rust/kernel/ptr.rs @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Types and functions to work with pointers and addresses. + +use core::mem::align_of; +use core::num::NonZero; + +use crate::build_assert; + +/// Type representing an alignment, which is always a power of two. +/// +/// It is used to validate that a given value is a valid alignment, and to perform masking and +/// alignment operations. +/// +/// This is a temporary substitute for the [`Alignment`] nightly type from the standard library, +/// and to be eventually replaced by it. +/// +/// [`Alignment`]: https://github.com/rust-lang/rust/issues/102070 +/// +/// # Invariants +/// +/// An alignment is always a power of two. +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Alignment(NonZero<usize>); + +impl Alignment { + /// Validates that `ALIGN` is a power of two at build-time, and returns an [`Alignment`] of the + /// same value. + /// + /// A build error is triggered if `ALIGN` is not a power of two. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// let v = Alignment::new::<16>(); + /// assert_eq!(v.as_usize(), 16); + /// ``` + #[inline(always)] + pub const fn new<const ALIGN: usize>() -> Self { + build_assert!( + ALIGN.is_power_of_two(), + "Provided alignment is not a power of two." + ); + + // INVARIANT: `align` is a power of two. + // SAFETY: `align` is a power of two, and thus non-zero. + Self(unsafe { NonZero::new_unchecked(ALIGN) }) + } + + /// Validates that `align` is a power of two at runtime, and returns an + /// [`Alignment`] of the same value. + /// + /// Returns [`None`] if `align` is not a power of two. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// assert_eq!(Alignment::new_checked(16), Some(Alignment::new::<16>())); + /// assert_eq!(Alignment::new_checked(15), None); + /// assert_eq!(Alignment::new_checked(1), Some(Alignment::new::<1>())); + /// assert_eq!(Alignment::new_checked(0), None); + /// ``` + #[inline(always)] + pub const fn new_checked(align: usize) -> Option<Self> { + if align.is_power_of_two() { + // INVARIANT: `align` is a power of two. + // SAFETY: `align` is a power of two, and thus non-zero. + Some(Self(unsafe { NonZero::new_unchecked(align) })) + } else { + None + } + } + + /// Returns the alignment of `T`. + /// + /// This is equivalent to [`align_of`], but with the return value provided as an [`Alignment`]. + #[inline(always)] + pub const fn of<T>() -> Self { + #![allow(clippy::incompatible_msrv)] + // This cannot panic since alignments are always powers of two. + // + // We unfortunately cannot use `new` as it would require the `generic_const_exprs` feature. + const { Alignment::new_checked(align_of::<T>()).unwrap() } + } + + /// Returns this alignment as a [`usize`]. + /// + /// It is guaranteed to be a power of two. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// assert_eq!(Alignment::new::<16>().as_usize(), 16); + /// ``` + #[inline(always)] + pub const fn as_usize(self) -> usize { + self.as_nonzero().get() + } + + /// Returns this alignment as a [`NonZero`]. + /// + /// It is guaranteed to be a power of two. 
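+ ///
+ /// Encoding the non-zero property in the return type lets niche optimizations
+ /// apply; e.g. `Option<NonZero<usize>>` is the same size as `usize`.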
+ /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// assert_eq!(Alignment::new::<16>().as_nonzero().get(), 16); + /// ``` + #[inline(always)] + pub const fn as_nonzero(self) -> NonZero<usize> { + // Allow the compiler to know that the value is indeed a power of two. This can help + // optimize some operations down the line, like e.g. replacing divisions by bit shifts. + if !self.0.is_power_of_two() { + // SAFETY: Per the invariants, `self.0` is always a power of two so this block will + // never be reached. + unsafe { core::hint::unreachable_unchecked() } + } + self.0 + } + + /// Returns the base-2 logarithm of the alignment. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// assert_eq!(Alignment::of::<u8>().log2(), 0); + /// assert_eq!(Alignment::new::<16>().log2(), 4); + /// ``` + #[inline(always)] + pub const fn log2(self) -> u32 { + self.0.ilog2() + } + + /// Returns the mask for this alignment. + /// + /// This is equivalent to `!(self.as_usize() - 1)`. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::Alignment; + /// + /// assert_eq!(Alignment::new::<0x10>().mask(), !0xf); + /// ``` + #[inline(always)] + pub const fn mask(self) -> usize { + // No underflow can occur as the alignment is guaranteed to be a power of two, and thus is + // non-zero. + !(self.as_usize() - 1) + } +} + +/// Trait for items that can be aligned against an [`Alignment`]. +pub trait Alignable: Sized { + /// Aligns `self` down to `alignment`. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::{Alignable, Alignment}; + /// + /// assert_eq!(0x2f_usize.align_down(Alignment::new::<0x10>()), 0x20); + /// assert_eq!(0x30usize.align_down(Alignment::new::<0x10>()), 0x30); + /// assert_eq!(0xf0u8.align_down(Alignment::new::<0x1000>()), 0x0); + /// ``` + fn align_down(self, alignment: Alignment) -> Self; + + /// Aligns `self` up to `alignment`, returning `None` if aligning would result in an overflow. + /// + /// # Examples + /// + /// ``` + /// use kernel::ptr::{Alignable, Alignment}; + /// + /// assert_eq!(0x4fusize.align_up(Alignment::new::<0x10>()), Some(0x50)); + /// assert_eq!(0x40usize.align_up(Alignment::new::<0x10>()), Some(0x40)); + /// assert_eq!(0x0usize.align_up(Alignment::new::<0x10>()), Some(0x0)); + /// assert_eq!(u8::MAX.align_up(Alignment::new::<0x10>()), None); + /// assert_eq!(0x10u8.align_up(Alignment::new::<0x100>()), None); + /// assert_eq!(0x0u8.align_up(Alignment::new::<0x100>()), Some(0x0)); + /// ``` + fn align_up(self, alignment: Alignment) -> Option<Self>; +} + +/// Implement [`Alignable`] for unsigned integer types. +macro_rules! impl_alignable_uint { + ($($t:ty),*) => { + $( + impl Alignable for $t { + #[inline(always)] + fn align_down(self, alignment: Alignment) -> Self { + // The operands of `&` need to be of the same type so convert the alignment to + // `Self`. This means we need to compute the mask ourselves. + ::core::num::NonZero::<Self>::try_from(alignment.as_nonzero()) + .map(|align| self & !(align.get() - 1)) + // An alignment larger than `Self` always aligns down to `0`. 
+ .unwrap_or(0)
+ }
+
+ #[inline(always)]
+ fn align_up(self, alignment: Alignment) -> Option<Self> {
+ let aligned_down = self.align_down(alignment);
+ if self == aligned_down {
+ Some(aligned_down)
+ } else {
+ Self::try_from(alignment.as_usize())
+ .ok()
+ .and_then(|align| aligned_down.checked_add(align))
+ }
+ }
+ }
+ )*
+ };
+}
+
+impl_alignable_uint!(u8, u16, u32, u64, usize);
diff --git a/rust/kernel/pwm.rs b/rust/kernel/pwm.rs
new file mode 100644
index 000000000000..cb00f8a8765c
--- /dev/null
+++ b/rust/kernel/pwm.rs
@@ -0,0 +1,735 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2025 Samsung Electronics Co., Ltd.
+// Author: Michal Wilczynski <m.wilczynski@samsung.com>
+
+//! PWM subsystem abstractions.
+//!
+//! C header: [`include/linux/pwm.h`](srctree/include/linux/pwm.h).
+
+use crate::{
+ bindings,
+ container_of,
+ device::{self, Bound},
+ devres,
+ error::{self, to_result},
+ prelude::*,
+ types::{ARef, AlwaysRefCounted, Opaque}, //
+};
+use core::{marker::PhantomData, ptr::NonNull};
+
+/// Represents a PWM waveform configuration.
+/// Mirrors [`struct pwm_waveform`](srctree/include/linux/pwm.h).
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
+pub struct Waveform {
+ /// Total duration of one complete PWM cycle, in nanoseconds.
+ pub period_length_ns: u64,
+
+ /// Duty-cycle active time, in nanoseconds.
+ ///
+ /// For a typical normal polarity configuration (active-high) this is the
+ /// high time of the signal.
+ pub duty_length_ns: u64,
+
+ /// Duty-cycle start offset, in nanoseconds.
+ ///
+ /// Delay from the beginning of the period to the first active edge.
+ /// In most simple PWM setups this is `0`, so the duty cycle starts
+ /// immediately at each period’s start.
+ pub duty_offset_ns: u64,
+}
+
+impl From<bindings::pwm_waveform> for Waveform {
+ fn from(wf: bindings::pwm_waveform) -> Self {
+ Waveform {
+ period_length_ns: wf.period_length_ns,
+ duty_length_ns: wf.duty_length_ns,
+ duty_offset_ns: wf.duty_offset_ns,
+ }
+ }
+}
+
+impl From<Waveform> for bindings::pwm_waveform {
+ fn from(wf: Waveform) -> Self {
+ bindings::pwm_waveform {
+ period_length_ns: wf.period_length_ns,
+ duty_length_ns: wf.duty_length_ns,
+ duty_offset_ns: wf.duty_offset_ns,
+ }
+ }
+}
+
+/// Describes the outcome of a `round_waveform` operation.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RoundingOutcome {
+ /// The requested waveform was achievable exactly or by rounding values down.
+ ExactOrRoundedDown,
+
+ /// The requested waveform could only be achieved by rounding up.
+ RoundedUp,
+}
+
+/// Wrapper for a PWM device [`struct pwm_device`](srctree/include/linux/pwm.h).
+#[repr(transparent)]
+pub struct Device(Opaque<bindings::pwm_device>);
+
+impl Device {
+ /// Creates a reference to a [`Device`] from a valid C pointer.
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the
+ /// returned [`Device`] reference.
+ pub(crate) unsafe fn from_raw<'a>(ptr: *mut bindings::pwm_device) -> &'a Self {
+ // SAFETY: The safety requirements guarantee the validity of the dereference, while the
+ // `Device` type being transparent makes the cast ok.
+ unsafe { &*ptr.cast::<Self>() }
+ }
+
+ /// Returns a raw pointer to the underlying `pwm_device`.
+ fn as_raw(&self) -> *mut bindings::pwm_device {
+ self.0.get()
+ }
+
+ /// Gets the hardware PWM index for this device within its chip.
+ pub fn hwpwm(&self) -> u32 {
+ // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime.
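+ // `hwpwm` is assigned by the PWM core when the chip is added and not
+ // modified afterwards, so a plain read suffices.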
+ unsafe { (*self.as_raw()).hwpwm } + } + + /// Gets a reference to the parent `Chip` that this device belongs to. + pub fn chip<T: PwmOps>(&self) -> &Chip<T> { + // SAFETY: `self.as_raw()` provides a valid pointer. (*self.as_raw()).chip + // is assumed to be a valid pointer to `pwm_chip` managed by the kernel. + // Chip::from_raw's safety conditions must be met. + unsafe { Chip::<T>::from_raw((*self.as_raw()).chip) } + } + + /// Gets the label for this PWM device, if any. + pub fn label(&self) -> Option<&CStr> { + // SAFETY: self.as_raw() provides a valid pointer. + let label_ptr = unsafe { (*self.as_raw()).label }; + if label_ptr.is_null() { + return None; + } + + // SAFETY: label_ptr is non-null and points to a C string + // managed by the kernel, valid for the lifetime of the PWM device. + Some(unsafe { CStr::from_char_ptr(label_ptr) }) + } + + /// Sets the PWM waveform configuration and enables the PWM signal. + pub fn set_waveform(&self, wf: &Waveform, exact: bool) -> Result { + let c_wf = bindings::pwm_waveform::from(*wf); + + // SAFETY: `self.as_raw()` provides a valid `*mut pwm_device` pointer. + // `&c_wf` is a valid pointer to a `pwm_waveform` struct. The C function + // handles all necessary internal locking. + let ret = unsafe { bindings::pwm_set_waveform_might_sleep(self.as_raw(), &c_wf, exact) }; + to_result(ret) + } + + /// Queries the hardware for the configuration it would apply for a given + /// request. + pub fn round_waveform(&self, wf: &mut Waveform) -> Result<RoundingOutcome> { + let mut c_wf = bindings::pwm_waveform::from(*wf); + + // SAFETY: `self.as_raw()` provides a valid `*mut pwm_device` pointer. + // `&mut c_wf` is a valid pointer to a mutable `pwm_waveform` struct that + // the C function will update. + let ret = unsafe { bindings::pwm_round_waveform_might_sleep(self.as_raw(), &mut c_wf) }; + + to_result(ret)?; + + *wf = Waveform::from(c_wf); + + if ret == 1 { + Ok(RoundingOutcome::RoundedUp) + } else { + Ok(RoundingOutcome::ExactOrRoundedDown) + } + } + + /// Reads the current waveform configuration directly from the hardware. + pub fn get_waveform(&self) -> Result<Waveform> { + let mut c_wf = bindings::pwm_waveform::default(); + + // SAFETY: `self.as_raw()` is a valid pointer. We provide a valid pointer + // to a stack-allocated `pwm_waveform` struct for the kernel to fill. + let ret = unsafe { bindings::pwm_get_waveform_might_sleep(self.as_raw(), &mut c_wf) }; + + to_result(ret)?; + + Ok(Waveform::from(c_wf)) + } +} + +/// The result of a `round_waveform_tohw` operation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RoundedWaveform<WfHw> { + /// A status code, 0 for success or 1 if values were rounded up. + pub status: c_int, + /// The driver-specific hardware representation of the waveform. + pub hardware_waveform: WfHw, +} + +/// Trait defining the operations for a PWM driver. +pub trait PwmOps: 'static + Sized { + /// The driver-specific hardware representation of a waveform. + /// + /// This type must be [`Copy`], [`Default`], and fit within `PWM_WFHWSIZE`. + type WfHw: Copy + Default; + + /// Optional hook for when a PWM device is requested. + fn request(_chip: &Chip<Self>, _pwm: &Device, _parent_dev: &device::Device<Bound>) -> Result { + Ok(()) + } + + /// Optional hook for capturing a PWM signal. 
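+ ///
+ /// The default implementation returns `Err(ENOTSUPP)`.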
+ fn capture( + _chip: &Chip<Self>, + _pwm: &Device, + _result: &mut bindings::pwm_capture, + _timeout: usize, + _parent_dev: &device::Device<Bound>, + ) -> Result { + Err(ENOTSUPP) + } + + /// Convert a generic waveform to the hardware-specific representation. + /// This is typically a pure calculation and does not perform I/O. + fn round_waveform_tohw( + _chip: &Chip<Self>, + _pwm: &Device, + _wf: &Waveform, + ) -> Result<RoundedWaveform<Self::WfHw>> { + Err(ENOTSUPP) + } + + /// Convert a hardware-specific representation back to a generic waveform. + /// This is typically a pure calculation and does not perform I/O. + fn round_waveform_fromhw( + _chip: &Chip<Self>, + _pwm: &Device, + _wfhw: &Self::WfHw, + _wf: &mut Waveform, + ) -> Result { + Err(ENOTSUPP) + } + + /// Read the current hardware configuration into the hardware-specific representation. + fn read_waveform( + _chip: &Chip<Self>, + _pwm: &Device, + _parent_dev: &device::Device<Bound>, + ) -> Result<Self::WfHw> { + Err(ENOTSUPP) + } + + /// Write a hardware-specific waveform configuration to the hardware. + fn write_waveform( + _chip: &Chip<Self>, + _pwm: &Device, + _wfhw: &Self::WfHw, + _parent_dev: &device::Device<Bound>, + ) -> Result { + Err(ENOTSUPP) + } +} + +/// Bridges Rust `PwmOps` to the C `pwm_ops` vtable. +struct Adapter<T: PwmOps> { + _p: PhantomData<T>, +} + +impl<T: PwmOps> Adapter<T> { + const VTABLE: PwmOpsVTable = create_pwm_ops::<T>(); + + /// # Safety + /// + /// `wfhw_ptr` must be valid for writes of `size_of::<T::WfHw>()` bytes. + unsafe fn serialize_wfhw(wfhw: &T::WfHw, wfhw_ptr: *mut c_void) -> Result { + let size = core::mem::size_of::<T::WfHw>(); + + build_assert!(size <= bindings::PWM_WFHWSIZE as usize); + + // SAFETY: The caller ensures `wfhw_ptr` is valid for `size` bytes. + unsafe { + core::ptr::copy_nonoverlapping( + core::ptr::from_ref::<T::WfHw>(wfhw).cast::<u8>(), + wfhw_ptr.cast::<u8>(), + size, + ); + } + + Ok(()) + } + + /// # Safety + /// + /// `wfhw_ptr` must be valid for reads of `size_of::<T::WfHw>()` bytes. + unsafe fn deserialize_wfhw(wfhw_ptr: *const c_void) -> Result<T::WfHw> { + let size = core::mem::size_of::<T::WfHw>(); + + build_assert!(size <= bindings::PWM_WFHWSIZE as usize); + + let mut wfhw = T::WfHw::default(); + // SAFETY: The caller ensures `wfhw_ptr` is valid for `size` bytes. + unsafe { + core::ptr::copy_nonoverlapping( + wfhw_ptr.cast::<u8>(), + core::ptr::from_mut::<T::WfHw>(&mut wfhw).cast::<u8>(), + size, + ); + } + + Ok(wfhw) + } + + /// # Safety + /// + /// `dev` must be a valid pointer to a `bindings::device` embedded within a + /// `bindings::pwm_chip`. This function is called by the device core when the + /// last reference to the device is dropped. + unsafe extern "C" fn release_callback(dev: *mut bindings::device) { + // SAFETY: The function's contract guarantees that `dev` points to a `device` + // field embedded within a valid `pwm_chip`. `container_of!` can therefore + // safely calculate the address of the containing struct. + let c_chip_ptr = unsafe { container_of!(dev, bindings::pwm_chip, dev) }; + + // SAFETY: `c_chip_ptr` is a valid pointer to a `pwm_chip` as established + // above. Calling this FFI function is safe. + let drvdata_ptr = unsafe { bindings::pwmchip_get_drvdata(c_chip_ptr) }; + + // SAFETY: The driver data was initialized in `new`. We run its destructor here. + unsafe { core::ptr::drop_in_place(drvdata_ptr.cast::<T>()) }; + + // Now, call the original release function to free the `pwm_chip` itself. 
+ // SAFETY: `dev` is the valid pointer passed into this callback, which is
+ // the expected argument for `pwmchip_release`.
+ unsafe {
+ bindings::pwmchip_release(dev);
+ }
+ }
+
+ /// # Safety
+ ///
+ /// Pointers from C must be valid.
+ unsafe extern "C" fn request_callback(
+ chip_ptr: *mut bindings::pwm_chip,
+ pwm_ptr: *mut bindings::pwm_device,
+ ) -> c_int {
+ // SAFETY: PWM core guarantees `chip_ptr` and `pwm_ptr` are valid pointers.
+ let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) };
+
+ // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks.
+ let bound_parent = unsafe { chip.bound_parent_device() };
+ match T::request(chip, pwm, bound_parent) {
+ Ok(()) => 0,
+ Err(e) => e.to_errno(),
+ }
+ }
+
+ /// # Safety
+ ///
+ /// Pointers from C must be valid.
+ unsafe extern "C" fn capture_callback(
+ chip_ptr: *mut bindings::pwm_chip,
+ pwm_ptr: *mut bindings::pwm_device,
+ res: *mut bindings::pwm_capture,
+ timeout: usize,
+ ) -> c_int {
+ // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid
+ // pointers.
+ let (chip, pwm, result) = unsafe {
+ (
+ Chip::<T>::from_raw(chip_ptr),
+ Device::from_raw(pwm_ptr),
+ &mut *res,
+ )
+ };
+
+ // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks.
+ let bound_parent = unsafe { chip.bound_parent_device() };
+ match T::capture(chip, pwm, result, timeout, bound_parent) {
+ Ok(()) => 0,
+ Err(e) => e.to_errno(),
+ }
+ }
+
+ /// # Safety
+ ///
+ /// Pointers from C must be valid.
+ unsafe extern "C" fn round_waveform_tohw_callback(
+ chip_ptr: *mut bindings::pwm_chip,
+ pwm_ptr: *mut bindings::pwm_device,
+ wf_ptr: *const bindings::pwm_waveform,
+ wfhw_ptr: *mut c_void,
+ ) -> c_int {
+ // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid
+ // pointers.
+ let (chip, pwm, wf) = unsafe {
+ (
+ Chip::<T>::from_raw(chip_ptr),
+ Device::from_raw(pwm_ptr),
+ Waveform::from(*wf_ptr),
+ )
+ };
+ match T::round_waveform_tohw(chip, pwm, &wf) {
+ Ok(rounded) => {
+ // SAFETY: `wfhw_ptr` is valid per this function's safety contract.
+ if unsafe { Self::serialize_wfhw(&rounded.hardware_waveform, wfhw_ptr) }.is_err() {
+ return EINVAL.to_errno();
+ }
+ rounded.status
+ }
+ Err(e) => e.to_errno(),
+ }
+ }
+
+ /// # Safety
+ ///
+ /// Pointers from C must be valid.
+ unsafe extern "C" fn round_waveform_fromhw_callback(
+ chip_ptr: *mut bindings::pwm_chip,
+ pwm_ptr: *mut bindings::pwm_device,
+ wfhw_ptr: *const c_void,
+ wf_ptr: *mut bindings::pwm_waveform,
+ ) -> c_int {
+ // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid
+ // pointers.
+ let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) };
+ // SAFETY: `deserialize_wfhw`'s safety contract is met by this function's contract.
+ let wfhw = match unsafe { Self::deserialize_wfhw(wfhw_ptr) } {
+ Ok(v) => v,
+ Err(e) => return e.to_errno(),
+ };
+
+ let mut rust_wf = Waveform::default();
+ match T::round_waveform_fromhw(chip, pwm, &wfhw, &mut rust_wf) {
+ Ok(()) => {
+ // SAFETY: `wf_ptr` is guaranteed valid by the C caller.
+ unsafe {
+ *wf_ptr = rust_wf.into();
+ };
+ 0
+ }
+ Err(e) => e.to_errno(),
+ }
+ }
+
+ /// # Safety
+ ///
+ /// Pointers from C must be valid.
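+ /// In particular, `wfhw_ptr` must be valid for writes of at least
+ /// `size_of::<T::WfHw>()` bytes.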
+ unsafe extern "C" fn read_waveform_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wfhw_ptr: *mut c_void, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + match T::read_waveform(chip, pwm, bound_parent) { + // SAFETY: `wfhw_ptr` is valid per this function's safety contract. + Ok(wfhw) => match unsafe { Self::serialize_wfhw(&wfhw, wfhw_ptr) } { + Ok(()) => 0, + Err(e) => e.to_errno(), + }, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn write_waveform_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wfhw_ptr: *const c_void, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + + // SAFETY: `wfhw_ptr` is valid per this function's safety contract. + let wfhw = match unsafe { Self::deserialize_wfhw(wfhw_ptr) } { + Ok(v) => v, + Err(e) => return e.to_errno(), + }; + match T::write_waveform(chip, pwm, &wfhw, bound_parent) { + Ok(()) => 0, + Err(e) => e.to_errno(), + } + } +} + +/// VTable structure wrapper for PWM operations. +/// Mirrors [`struct pwm_ops`](srctree/include/linux/pwm.h). +#[repr(transparent)] +pub struct PwmOpsVTable(bindings::pwm_ops); + +// SAFETY: PwmOpsVTable is Send. The vtable contains only function pointers +// and a size, which are simple data types that can be safely moved across +// threads. The thread-safety of calling these functions is handled by the +// kernel's locking mechanisms. +unsafe impl Send for PwmOpsVTable {} + +// SAFETY: PwmOpsVTable is Sync. The vtable is immutable after it is created, +// so it can be safely referenced and accessed concurrently by multiple threads +// e.g. to read the function pointers. +unsafe impl Sync for PwmOpsVTable {} + +impl PwmOpsVTable { + /// Returns a raw pointer to the underlying `pwm_ops` struct. + pub(crate) fn as_raw(&self) -> *const bindings::pwm_ops { + &self.0 + } +} + +/// Creates a PWM operations vtable for a type `T` that implements `PwmOps`. +/// +/// This is used to bridge Rust trait implementations to the C `struct pwm_ops` +/// expected by the kernel. +pub const fn create_pwm_ops<T: PwmOps>() -> PwmOpsVTable { + // SAFETY: `core::mem::zeroed()` is unsafe. For `pwm_ops`, all fields are + // `Option<extern "C" fn(...)>` or data, so a zeroed pattern (None/0) is valid initially. 
+ let mut ops: bindings::pwm_ops = unsafe { core::mem::zeroed() }; + + ops.request = Some(Adapter::<T>::request_callback); + ops.capture = Some(Adapter::<T>::capture_callback); + + ops.round_waveform_tohw = Some(Adapter::<T>::round_waveform_tohw_callback); + ops.round_waveform_fromhw = Some(Adapter::<T>::round_waveform_fromhw_callback); + ops.read_waveform = Some(Adapter::<T>::read_waveform_callback); + ops.write_waveform = Some(Adapter::<T>::write_waveform_callback); + ops.sizeof_wfhw = core::mem::size_of::<T::WfHw>(); + + PwmOpsVTable(ops) +} + +/// Wrapper for a PWM chip/controller ([`struct pwm_chip`](srctree/include/linux/pwm.h)). +#[repr(transparent)] +pub struct Chip<T: PwmOps>(Opaque<bindings::pwm_chip>, PhantomData<T>); + +impl<T: PwmOps> Chip<T> { + /// Creates a reference to a [`Chip`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`Chip`] reference. + pub(crate) unsafe fn from_raw<'a>(ptr: *mut bindings::pwm_chip) -> &'a Self { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `Chip` type being transparent makes the cast ok. + unsafe { &*ptr.cast::<Self>() } + } + + /// Returns a raw pointer to the underlying `pwm_chip`. + pub(crate) fn as_raw(&self) -> *mut bindings::pwm_chip { + self.0.get() + } + + /// Gets the number of PWM channels (hardware PWMs) on this chip. + pub fn num_channels(&self) -> u32 { + // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime. + unsafe { (*self.as_raw()).npwm } + } + + /// Returns `true` if the chip supports atomic operations for configuration. + pub fn is_atomic(&self) -> bool { + // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime. + unsafe { (*self.as_raw()).atomic } + } + + /// Returns a reference to the embedded `struct device` abstraction. + pub fn device(&self) -> &device::Device { + // SAFETY: + // - `self.as_raw()` provides a valid pointer to `bindings::pwm_chip`. + // - The `dev` field is an instance of `bindings::device` embedded + // within `pwm_chip`. + // - Taking a pointer to this embedded field is valid. + // - `device::Device` is `#[repr(transparent)]`. + // - The lifetime of the returned reference is tied to `self`. + unsafe { device::Device::from_raw(&raw mut (*self.as_raw()).dev) } + } + + /// Gets the typed driver specific data associated with this chip's embedded device. + pub fn drvdata(&self) -> &T { + // SAFETY: `pwmchip_get_drvdata` returns the pointer to the private data area, + // which we know holds our `T`. The pointer is valid for the lifetime of `self`. + unsafe { &*bindings::pwmchip_get_drvdata(self.as_raw()).cast::<T>() } + } + + /// Returns a reference to the parent device of this PWM chip's device. + /// + /// # Safety + /// + /// The caller must guarantee that the parent device exists and is bound. + /// This is guaranteed by the PWM core during `PwmOps` callbacks. + unsafe fn bound_parent_device(&self) -> &device::Device<Bound> { + // SAFETY: Per the function's safety contract, the parent device exists. + let parent = unsafe { self.device().parent().unwrap_unchecked() }; + + // SAFETY: Per the function's safety contract, the parent device is bound. + // This is guaranteed by the PWM core during `PwmOps` callbacks. + unsafe { parent.as_bound() } + } + + /// Allocates and wraps a PWM chip using `bindings::pwmchip_alloc`. 
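+ /// The chip's private data area is initialized in place from `data`.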
+ /// + /// Returns an [`ARef<Chip>`] managing the chip's lifetime via refcounting + /// on its embedded `struct device`. + pub fn new( + parent_dev: &device::Device, + num_channels: u32, + data: impl pin_init::PinInit<T, Error>, + ) -> Result<ARef<Self>> { + let sizeof_priv = core::mem::size_of::<T>(); + // SAFETY: `pwmchip_alloc` allocates memory for the C struct and our private data. + let c_chip_ptr_raw = + unsafe { bindings::pwmchip_alloc(parent_dev.as_raw(), num_channels, sizeof_priv) }; + + let c_chip_ptr: *mut bindings::pwm_chip = error::from_err_ptr(c_chip_ptr_raw)?; + + // SAFETY: The `drvdata` pointer is the start of the private area, which is where + // we will construct our `T` object. + let drvdata_ptr = unsafe { bindings::pwmchip_get_drvdata(c_chip_ptr) }; + + // SAFETY: We construct the `T` object in-place in the allocated private memory. + unsafe { data.__pinned_init(drvdata_ptr.cast())? }; + + // SAFETY: `c_chip_ptr` points to a valid chip. + unsafe { + (*c_chip_ptr).dev.release = Some(Adapter::<T>::release_callback); + } + + // SAFETY: `c_chip_ptr` points to a valid chip. + // The `Adapter`'s `VTABLE` has a 'static lifetime, so the pointer + // returned by `as_raw()` is always valid. + unsafe { + (*c_chip_ptr).ops = Adapter::<T>::VTABLE.as_raw(); + } + + // Cast the `*mut bindings::pwm_chip` to `*mut Chip`. This is valid because + // `Chip` is `repr(transparent)` over `Opaque<bindings::pwm_chip>`, and + // `Opaque<T>` is `repr(transparent)` over `T`. + let chip_ptr_as_self = c_chip_ptr.cast::<Self>(); + + // SAFETY: `chip_ptr_as_self` points to a valid `Chip` (layout-compatible with + // `bindings::pwm_chip`) whose embedded device has refcount 1. + // `ARef::from_raw` takes this pointer and manages it via `AlwaysRefCounted`. + Ok(unsafe { ARef::from_raw(NonNull::new_unchecked(chip_ptr_as_self)) }) + } +} + +// SAFETY: Implements refcounting for `Chip` using the embedded `struct device`. +unsafe impl<T: PwmOps> AlwaysRefCounted for Chip<T> { + #[inline] + fn inc_ref(&self) { + // SAFETY: `self.0.get()` points to a valid `pwm_chip` because `self` exists. + // The embedded `dev` is valid. `get_device` increments its refcount. + unsafe { + bindings::get_device(&raw mut (*self.0.get()).dev); + } + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Chip<T>>) { + let c_chip_ptr = obj.cast::<bindings::pwm_chip>().as_ptr(); + + // SAFETY: `obj` is a valid pointer to a `Chip` (and thus `bindings::pwm_chip`) + // with a non-zero refcount. `put_device` handles decrement and final release. + unsafe { + bindings::put_device(&raw mut (*c_chip_ptr).dev); + } + } +} + +// SAFETY: `Chip` is a wrapper around `*mut bindings::pwm_chip`. The underlying C +// structure's state is managed and synchronized by the kernel's device model +// and PWM core locking mechanisms. Therefore, it is safe to move the `Chip` +// wrapper (and the pointer it contains) across threads. +unsafe impl<T: PwmOps + Send> Send for Chip<T> {} + +// SAFETY: It is safe for multiple threads to have shared access (`&Chip`) because +// the `Chip` data is immutable from the Rust side without holding the appropriate +// kernel locks, which the C core is responsible for. Any interior mutability is +// handled and synchronized by the C kernel code. +unsafe impl<T: PwmOps + Sync> Sync for Chip<T> {} + +/// A resource guard that ensures `pwmchip_remove` is called on drop. +/// +/// This struct is intended to be managed by the `devres` framework by transferring its ownership +/// via [`devres::register`]. 
This ties the lifetime of the PWM chip registration +/// to the lifetime of the underlying device. +pub struct Registration<T: PwmOps> { + chip: ARef<Chip<T>>, +} + +impl<T: 'static + PwmOps + Send + Sync> Registration<T> { + /// Registers a PWM chip with the PWM subsystem. + /// + /// Transfers its ownership to the `devres` framework, which ties its lifetime + /// to the parent device. + /// On unbind of the parent device, the `devres` entry will be dropped, automatically + /// calling `pwmchip_remove`. This function should be called from the driver's `probe`. + pub fn register(dev: &device::Device<Bound>, chip: ARef<Chip<T>>) -> Result { + let chip_parent = chip.device().parent().ok_or(EINVAL)?; + if dev.as_raw() != chip_parent.as_raw() { + return Err(EINVAL); + } + + let c_chip_ptr = chip.as_raw(); + + // SAFETY: `c_chip_ptr` points to a valid chip with its ops initialized. + // `__pwmchip_add` is the C function to register the chip with the PWM core. + unsafe { + to_result(bindings::__pwmchip_add(c_chip_ptr, core::ptr::null_mut()))?; + } + + let registration = Registration { chip }; + + devres::register(dev, registration, GFP_KERNEL) + } +} + +impl<T: PwmOps> Drop for Registration<T> { + fn drop(&mut self) { + let chip_raw = self.chip.as_raw(); + + // SAFETY: `chip_raw` points to a chip that was successfully registered. + // `bindings::pwmchip_remove` is the correct C function to unregister it. + // This `drop` implementation is called automatically by `devres` on driver unbind. + unsafe { + bindings::pwmchip_remove(chip_raw); + } + } +} + +/// Declares a kernel module that exposes a single PWM driver. +/// +/// # Examples +/// +///```ignore +/// kernel::module_pwm_platform_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +///``` +#[macro_export] +macro_rules! module_pwm_platform_driver { + ($($user_args:tt)*) => { + $crate::module_platform_driver! { + $($user_args)* + imports_ns: ["PWM"], + } + }; +} diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs new file mode 100644 index 000000000000..4729eb56827a --- /dev/null +++ b/rust/kernel/rbtree.rs @@ -0,0 +1,1430 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Red-black trees. +//! +//! C header: [`include/linux/rbtree.h`](srctree/include/linux/rbtree.h) +//! +//! Reference: <https://docs.kernel.org/core-api/rbtree.html> + +use crate::{alloc::Flags, bindings, container_of, error::Result, prelude::*}; +use core::{ + cmp::{Ord, Ordering}, + marker::PhantomData, + mem::MaybeUninit, + ptr::{addr_of_mut, from_mut, NonNull}, +}; + +/// A red-black tree with owned nodes. +/// +/// It is backed by the kernel C red-black trees. +/// +/// # Examples +/// +/// In the example below we do several operations on a tree. We note that insertions may fail if +/// the system is out of memory. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNode, RBTreeNodeReservation}}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Check the nodes we just inserted. +/// { +/// assert_eq!(tree.get(&10), Some(&100)); +/// assert_eq!(tree.get(&20), Some(&200)); +/// assert_eq!(tree.get(&30), Some(&300)); +/// } +/// +/// // Iterate over the nodes we just inserted. 
+/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Print all elements. +/// for (key, value) in &tree { +/// pr_info!("{} = {}\n", key, value); +/// } +/// +/// // Replace one of the elements. +/// tree.try_create_and_insert(10, 1000, flags::GFP_KERNEL)?; +/// +/// // Check that the tree reflects the replacement. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&10, &1000))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Change the value of one of the elements. +/// *tree.get_mut(&30).unwrap() = 3000; +/// +/// // Check that the tree reflects the update. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&10, &1000))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &3000))); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Remove an element. +/// tree.remove(&10); +/// +/// // Check that the tree reflects the removal. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &3000))); +/// assert!(iter.next().is_none()); +/// } +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// In the example below, we first allocate a node, acquire a spinlock, then insert the node into +/// the tree. This is useful when the insertion context does not allow sleeping, for example, when +/// holding a spinlock. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNode}, sync::SpinLock}; +/// +/// fn insert_test(tree: &SpinLock<RBTree<u32, u32>>) -> Result { +/// // Pre-allocate node. This may fail (as it allocates memory). +/// let node = RBTreeNode::new(10, 100, flags::GFP_KERNEL)?; +/// +/// // Insert node while holding the lock. It is guaranteed to succeed with no allocation +/// // attempts. +/// let mut guard = tree.lock(); +/// guard.insert(node); +/// Ok(()) +/// } +/// ``` +/// +/// In the example below, we reuse an existing node allocation from an element we removed. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNodeReservation}}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Check the nodes we just inserted. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Remove a node, getting back ownership of it. +/// let existing = tree.remove(&30); +/// +/// // Check that the tree reflects the removal. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Create a preallocated reservation that we can re-use later. +/// let reservation = RBTreeNodeReservation::new(flags::GFP_KERNEL)?; +/// +/// // Insert a new node into the tree, reusing the previous allocation. This is guaranteed to +/// // succeed (no memory allocations). 
+/// tree.insert(reservation.into_node(15, 150));
+///
+/// // Check that the tree reflects the new insertion.
+/// {
+/// let mut iter = tree.iter();
+/// assert_eq!(iter.next(), Some((&10, &100)));
+/// assert_eq!(iter.next(), Some((&15, &150)));
+/// assert_eq!(iter.next(), Some((&20, &200)));
+/// assert!(iter.next().is_none());
+/// }
+///
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// # Invariants
+///
+/// Non-null parent/children pointers stored in instances of the `rb_node` C struct are always
+/// valid, and point to a field of our internal representation of a node.
+pub struct RBTree<K, V> {
+ root: bindings::rb_root,
+ _p: PhantomData<Node<K, V>>,
+}
+
+// SAFETY: An [`RBTree`] allows the same kinds of access to its values that a struct allows to its
+// fields, so we use the same Send condition as would be used for a struct with K and V fields.
+unsafe impl<K: Send, V: Send> Send for RBTree<K, V> {}
+
+// SAFETY: An [`RBTree`] allows the same kinds of access to its values that a struct allows to its
+// fields, so we use the same Sync condition as would be used for a struct with K and V fields.
+unsafe impl<K: Sync, V: Sync> Sync for RBTree<K, V> {}
+
+impl<K, V> RBTree<K, V> {
+ /// Creates a new and empty tree.
+ pub fn new() -> Self {
+ Self {
+ // INVARIANT: There are no nodes in the tree, so the invariant holds vacuously.
+ root: bindings::rb_root::default(),
+ _p: PhantomData,
+ }
+ }
+
+ /// Returns true if this tree is empty.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.root.rb_node.is_null()
+ }
+
+ /// Returns an iterator over the tree nodes, sorted by key.
+ pub fn iter(&self) -> Iter<'_, K, V> {
+ Iter {
+ _tree: PhantomData,
+ // INVARIANT:
+ // - `self.root` is a valid pointer to a tree root.
+ // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid.
+ iter_raw: IterRaw {
+ // SAFETY: by the invariants, all pointers are valid.
+ next: unsafe { bindings::rb_first(&self.root) },
+ _phantom: PhantomData,
+ },
+ }
+ }
+
+ /// Returns a mutable iterator over the tree nodes, sorted by key.
+ pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
+ IterMut {
+ _tree: PhantomData,
+ // INVARIANT:
+ // - `self.root` is a valid pointer to a tree root.
+ // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid.
+ iter_raw: IterRaw {
+ // SAFETY: by the invariants, all pointers are valid.
+ next: unsafe { bindings::rb_first(from_mut(&mut self.root)) },
+ _phantom: PhantomData,
+ },
+ }
+ }
+
+ /// Returns an iterator over the keys of the nodes in the tree, in sorted order.
+ pub fn keys(&self) -> impl Iterator<Item = &'_ K> {
+ self.iter().map(|(k, _)| k)
+ }
+
+ /// Returns an iterator over the values of the nodes in the tree, sorted by key.
+ pub fn values(&self) -> impl Iterator<Item = &'_ V> {
+ self.iter().map(|(_, v)| v)
+ }
+
+ /// Returns a mutable iterator over the values of the nodes in the tree, sorted by key.
+ pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
+ self.iter_mut().map(|(_, v)| v)
+ }
+
+ /// Returns a cursor over the tree nodes, starting with the smallest key.
+ pub fn cursor_front_mut(&mut self) -> Option<CursorMut<'_, K, V>> {
+ let root = addr_of_mut!(self.root);
+ // SAFETY: `self.root` is always a valid root node.
+ let current = unsafe { bindings::rb_first(root) };
+ NonNull::new(current).map(|current| {
+ // INVARIANT:
+ // - `current` is a valid node in the [`RBTree`] pointed to by `self`.
+ CursorMut { + current, + tree: self, + } + }) + } + + /// Returns an immutable cursor over the tree nodes, starting with the smallest key. + pub fn cursor_front(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_first(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + + /// Returns a cursor over the tree nodes, starting with the largest key. + pub fn cursor_back_mut(&mut self) -> Option<CursorMut<'_, K, V>> { + let root = addr_of_mut!(self.root); + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_last(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + CursorMut { + current, + tree: self, + } + }) + } + + /// Returns a cursor over the tree nodes, starting with the largest key. + pub fn cursor_back(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_last(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } +} + +impl<K, V> RBTree<K, V> +where + K: Ord, +{ + /// Tries to insert a new value into the tree. + /// + /// It overwrites a node if one already exists with the same key and returns it (containing the + /// key/value pair). Returns [`None`] if a node with the same key didn't already exist. + /// + /// Returns an error if it cannot allocate memory for the new node. + pub fn try_create_and_insert( + &mut self, + key: K, + value: V, + flags: Flags, + ) -> Result<Option<RBTreeNode<K, V>>> { + Ok(self.insert(RBTreeNode::new(key, value, flags)?)) + } + + /// Inserts a new node into the tree. + /// + /// It overwrites a node if one already exists with the same key and returns it (containing the + /// key/value pair). Returns [`None`] if a node with the same key didn't already exist. + /// + /// This function always succeeds. + pub fn insert(&mut self, node: RBTreeNode<K, V>) -> Option<RBTreeNode<K, V>> { + match self.raw_entry(&node.node.key) { + RawEntry::Occupied(entry) => Some(entry.replace(node)), + RawEntry::Vacant(entry) => { + entry.insert(node); + None + } + } + } + + fn raw_entry(&mut self, key: &K) -> RawEntry<'_, K, V> { + let raw_self: *mut RBTree<K, V> = self; + // The returned `RawEntry` is used to call either `rb_link_node` or `rb_replace_node`. + // The parameters of `bindings::rb_link_node` are as follows: + // - `node`: A pointer to an uninitialized node being inserted. + // - `parent`: A pointer to an existing node in the tree. One of its child pointers must be + // null, and `node` will become a child of `parent` by replacing that child pointer + // with a pointer to `node`. + // - `rb_link`: A pointer to either the left-child or right-child field of `parent`. This + // specifies which child of `parent` should hold `node` after this call. The + // value of `*rb_link` must be null before the call to `rb_link_node`. If the + // red/black tree is empty, then it’s also possible for `parent` to be null. In + // this case, `rb_link` is a pointer to the `root` field of the red/black tree. 
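+        //
+        // In other words, `rb_link_node` only links the new node into place; the caller is then
+        // expected to call `rb_insert_color` to rebalance the tree (as `RawVacantEntry::insert`
+        // does below).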
+ // + // We will traverse the tree looking for a node that has a null pointer as its child, + // representing an empty subtree where we can insert our new node. We need to make sure + // that we preserve the ordering of the nodes in the tree. In each iteration of the loop + // we store `parent` and `child_field_of_parent`, and the new `node` will go somewhere + // in the subtree of `parent` that `child_field_of_parent` points at. Once + // we find an empty subtree, we can insert the new node using `rb_link_node`. + let mut parent = core::ptr::null_mut(); + let mut child_field_of_parent: &mut *mut bindings::rb_node = + // SAFETY: `raw_self` is a valid pointer to the `RBTree` (created from `self` above). + unsafe { &mut (*raw_self).root.rb_node }; + while !(*child_field_of_parent).is_null() { + let curr = *child_field_of_parent; + // SAFETY: All links fields we create are in a `Node<K, V>`. + let node = unsafe { container_of!(curr, Node<K, V>, links) }; + + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + match key.cmp(unsafe { &(*node).key }) { + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Less => child_field_of_parent = unsafe { &mut (*curr).rb_left }, + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Greater => child_field_of_parent = unsafe { &mut (*curr).rb_right }, + Ordering::Equal => { + return RawEntry::Occupied(OccupiedEntry { + rbtree: self, + node_links: curr, + }) + } + } + parent = curr; + } + + RawEntry::Vacant(RawVacantEntry { + rbtree: raw_self, + parent, + child_field_of_parent, + _phantom: PhantomData, + }) + } + + /// Gets the given key's corresponding entry in the map for in-place manipulation. + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + match self.raw_entry(&key) { + RawEntry::Occupied(entry) => Entry::Occupied(entry), + RawEntry::Vacant(entry) => Entry::Vacant(VacantEntry { raw: entry, key }), + } + } + + /// Used for accessing the given node, if it exists. + pub fn find_mut(&mut self, key: &K) -> Option<OccupiedEntry<'_, K, V>> { + match self.raw_entry(key) { + RawEntry::Occupied(entry) => Some(entry), + RawEntry::Vacant(_entry) => None, + } + } + + /// Returns a reference to the value corresponding to the key. + pub fn get(&self, key: &K) -> Option<&V> { + let mut node = self.root.rb_node; + while !node.is_null() { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node, Node<K, V>, links) }; + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + node = match key.cmp(unsafe { &(*this).key }) { + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Less => unsafe { (*node).rb_left }, + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Greater => unsafe { (*node).rb_right }, + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Equal => return Some(unsafe { &(*this).value }), + } + } + None + } + + /// Returns a mutable reference to the value corresponding to the key. + pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { + self.find_mut(key).map(|node| node.into_mut()) + } + + /// Removes the node with the given key from the tree. + /// + /// It returns the node that was removed if one exists, or [`None`] otherwise. 
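+    ///
+    /// # Examples
+    ///
+    /// A short sketch of removing a node and taking it apart:
+    ///
+    /// ```
+    /// use kernel::{alloc::flags, rbtree::RBTree};
+    ///
+    /// let mut tree = RBTree::new();
+    /// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?;
+    ///
+    /// // The returned node owns the key/value pair and its allocation.
+    /// let node = tree.remove_node(&10).unwrap();
+    /// assert_eq!(node.to_key_value(), (10, 100));
+    /// assert!(tree.is_empty());
+    /// # Ok::<(), Error>(())
+    /// ```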
+ pub fn remove_node(&mut self, key: &K) -> Option<RBTreeNode<K, V>> { + self.find_mut(key).map(OccupiedEntry::remove_node) + } + + /// Removes the node with the given key from the tree. + /// + /// It returns the value that was removed if one exists, or [`None`] otherwise. + pub fn remove(&mut self, key: &K) -> Option<V> { + self.find_mut(key).map(OccupiedEntry::remove) + } + + /// Returns a cursor over the tree nodes based on the given key. + /// + /// If the given key exists, the cursor starts there. + /// Otherwise it starts with the first larger key in sort order. + /// If there is no larger key, it returns [`None`]. + pub fn cursor_lower_bound_mut(&mut self, key: &K) -> Option<CursorMut<'_, K, V>> + where + K: Ord, + { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + CursorMut { + current, + tree: self, + } + }) + } + + /// Returns a cursor over the tree nodes based on the given key. + /// + /// If the given key exists, the cursor starts there. + /// Otherwise it starts with the first larger key in sort order. + /// If there is no larger key, it returns [`None`]. + pub fn cursor_lower_bound(&self, key: &K) -> Option<Cursor<'_, K, V>> + where + K: Ord, + { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + + fn find_best_match(&self, key: &K) -> Option<NonNull<bindings::rb_node>> { + let mut node = self.root.rb_node; + let mut best_key: Option<&K> = None; + let mut best_links: Option<NonNull<bindings::rb_node>> = None; + while !node.is_null() { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node, Node<K, V>, links) }; + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + let this_key = unsafe { &(*this).key }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let left_child = unsafe { (*node).rb_left }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let right_child = unsafe { (*node).rb_right }; + match key.cmp(this_key) { + Ordering::Equal => { + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); + break; + } + Ordering::Greater => { + node = right_child; + } + Ordering::Less => { + let is_better_match = match best_key { + None => true, + Some(best) => best > this_key, + }; + if is_better_match { + best_key = Some(this_key); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); + } + node = left_child; + } + }; + } + best_links + } +} + +impl<K, V> Default for RBTree<K, V> { + fn default() -> Self { + Self::new() + } +} + +impl<K, V> Drop for RBTree<K, V> { + fn drop(&mut self) { + // SAFETY: `root` is valid as it's embedded in `self` and we have a valid `self`. + let mut next = unsafe { bindings::rb_first_postorder(&self.root) }; + + // INVARIANT: The loop invariant is that all tree nodes from `next` in postorder are valid. + while !next.is_null() { + // SAFETY: All links fields we create are in a `Node<K, V>`. 
+ let this = unsafe { container_of!(next, Node<K, V>, links) }; + + // Find out what the next node is before disposing of the current one. + // SAFETY: `next` and all nodes in postorder are still valid. + next = unsafe { bindings::rb_next_postorder(next) }; + + // INVARIANT: This is the destructor, so we break the type invariant during clean-up, + // but it is not observable. The loop invariant is still maintained. + + // SAFETY: `this` is valid per the loop invariant. + unsafe { drop(KBox::from_raw(this)) }; + } + } +} + +/// A bidirectional mutable cursor over the tree nodes, sorted by key. +/// +/// # Examples +/// +/// In the following example, we obtain a cursor to the first element in the tree. +/// The cursor allows us to iterate bidirectionally over key/value pairs in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let mut cursor = tree.cursor_front_mut().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// // Move the cursor, updating it to the 2nd element. +/// cursor = cursor.move_next().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Peek at the next element without impacting the cursor. +/// let next = cursor.peek_next().unwrap(); +/// assert_eq!(next, (&30, &300)); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Moving past the last element causes the cursor to return [`None`]. +/// cursor = cursor.move_next().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// let cursor = cursor.move_next(); +/// assert!(cursor.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// A cursor can also be obtained at the last element in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// let mut cursor = tree.cursor_back_mut().unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Obtaining a cursor returns [`None`] if the tree is empty. +/// +/// ``` +/// use kernel::rbtree::RBTree; +/// +/// let mut tree: RBTree<u16, u16> = RBTree::new(); +/// assert!(tree.cursor_front_mut().is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// [`RBTree::cursor_lower_bound`] can be used to start at an arbitrary node in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert five elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(40, 400, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(50, 500, flags::GFP_KERNEL)?; +/// +/// // If the provided key exists, a cursor to that key is returned. 
+/// let cursor = tree.cursor_lower_bound(&20).unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // If the provided key doesn't exist, a cursor to the first larger element in sort order is returned. +/// let cursor = tree.cursor_lower_bound(&25).unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// // If there is no larger key, [`None`] is returned. +/// let cursor = tree.cursor_lower_bound(&55); +/// assert!(cursor.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// The cursor allows mutation of values in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Retrieve a cursor. +/// let mut cursor = tree.cursor_front_mut().unwrap(); +/// +/// // Get a mutable reference to the current value. +/// let (k, v) = cursor.current_mut(); +/// *v = 1000; +/// +/// // The updated value is reflected in the tree. +/// let updated = tree.get(&10).unwrap(); +/// assert_eq!(updated, &1000); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// It also allows node removal. The following examples demonstrate the behavior of removing the current node. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Remove the first element. +/// let mut cursor = tree.cursor_front_mut().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// cursor = cursor.remove_current().0.unwrap(); +/// +/// // If a node exists after the current element, it is returned. +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Get a cursor to the last element, and remove it. +/// cursor = tree.cursor_back_mut().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// // Since there is no next node, the previous node is returned. +/// cursor = cursor.remove_current().0.unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Removing the last element in the tree returns [`None`]. +/// assert!(cursor.remove_current().0.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Nodes adjacent to the current node can also be removed. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let mut cursor = tree.cursor_front_mut().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// // Calling `remove_prev` from the first element returns [`None`]. +/// assert!(cursor.remove_prev().is_none()); +/// +/// // Get a cursor to the last element. 
+/// cursor = tree.cursor_back_mut().unwrap();
+/// current = cursor.current();
+/// assert_eq!(current, (&30, &300));
+///
+/// // Calling `remove_prev` removes and returns the middle element.
+/// assert_eq!(cursor.remove_prev().unwrap().to_key_value(), (20, 200));
+///
+/// // Calling `remove_next` from the last element returns [`None`].
+/// assert!(cursor.remove_next().is_none());
+///
+/// // Move to the first element.
+/// cursor = cursor.move_prev().unwrap();
+/// current = cursor.current();
+/// assert_eq!(current, (&10, &100));
+///
+/// // Calling `remove_next` removes and returns the last element.
+/// assert_eq!(cursor.remove_next().unwrap().to_key_value(), (30, 300));
+///
+/// # Ok::<(), Error>(())
+/// ```
+///
+/// # Invariants
+/// - `current` points to a node that is in the same [`RBTree`] as `tree`.
+pub struct CursorMut<'a, K, V> {
+    tree: &'a mut RBTree<K, V>,
+    current: NonNull<bindings::rb_node>,
+}
+
+/// A bidirectional immutable cursor over the tree nodes, sorted by key. This is a simpler
+/// variant of [`CursorMut`] that provides read-only access.
+///
+/// # Examples
+///
+/// In the following example, we obtain a cursor to the first element in the tree.
+/// The cursor allows us to iterate bidirectionally over key/value pairs in the tree.
+///
+/// ```
+/// use kernel::{alloc::flags, rbtree::RBTree};
+///
+/// // Create a new tree.
+/// let mut tree = RBTree::new();
+///
+/// // Insert three elements.
+/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?;
+/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?;
+/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?;
+///
+/// // Get a cursor to the first element.
+/// let cursor = tree.cursor_front().unwrap();
+/// let current = cursor.current();
+/// assert_eq!(current, (&10, &100));
+///
+/// # Ok::<(), Error>(())
+/// ```
+pub struct Cursor<'a, K, V> {
+    _tree: PhantomData<&'a RBTree<K, V>>,
+    current: NonNull<bindings::rb_node>,
+}
+
+// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be
+// shared across threads, then it's safe to share the cursor.
+unsafe impl<'a, K: Sync, V: Sync> Send for Cursor<'a, K, V> {}
+
+// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be
+// shared across threads, then it's safe to share the cursor.
+unsafe impl<'a, K: Sync, V: Sync> Sync for Cursor<'a, K, V> {}
+
+impl<'a, K, V> Cursor<'a, K, V> {
+    /// The current node.
+    pub fn current(&self) -> (&K, &V) {
+        // SAFETY:
+        // - `self.current` is a valid node by the type invariants.
+        // - We have an immutable reference by the function signature.
+        unsafe { Self::to_key_value(self.current) }
+    }
+
+    /// # Safety
+    ///
+    /// - `node` must be a valid pointer to a node in an [`RBTree`].
+    /// - The caller has immutable access to `node` for the duration of `'b`.
+    unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) {
+        // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self`
+        // point to the links field of `Node<K, V>` objects.
+        let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) };
+        // SAFETY: The passed `node` is the current node or a non-null neighbor,
+        // thus `this` is valid by the type invariants.
+        let k = unsafe { &(*this).key };
+        // SAFETY: The passed `node` is the current node or a non-null neighbor,
+        // thus `this` is valid by the type invariants.
+        let v = unsafe { &(*this).value };
+        (k, v)
+    }
+
+    /// Access the previous node without moving the cursor.
+    pub fn peek_prev(&self) -> Option<(&K, &V)> {
+        self.peek(Direction::Prev)
+    }
+
+    /// Access the next node without moving the cursor.
+    pub fn peek_next(&self) -> Option<(&K, &V)> {
+        self.peek(Direction::Next)
+    }
+
+    fn peek(&self, direction: Direction) -> Option<(&K, &V)> {
+        self.get_neighbor_raw(direction).map(|neighbor| {
+            // SAFETY:
+            // - `neighbor` is a valid tree node.
+            // - By the function signature, we have an immutable reference to `self`.
+            unsafe { Self::to_key_value(neighbor) }
+        })
+    }
+
+    fn get_neighbor_raw(&self, direction: Direction) -> Option<NonNull<bindings::rb_node>> {
+        // SAFETY: `self.current` is valid by the type invariants.
+        let neighbor = unsafe {
+            match direction {
+                Direction::Prev => bindings::rb_prev(self.current.as_ptr()),
+                Direction::Next => bindings::rb_next(self.current.as_ptr()),
+            }
+        };
+
+        NonNull::new(neighbor)
+    }
+}
+
+// SAFETY: The [`CursorMut`] has exclusive access to both `K` and `V`, so it is sufficient to
+// require them to be `Send`.
+// The cursor only gives out immutable references to the keys, but since it has exclusive access to
+// those same keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the
+// user.
+unsafe impl<'a, K: Send, V: Send> Send for CursorMut<'a, K, V> {}
+
+// SAFETY: The [`CursorMut`] gives out immutable references to `K` and mutable references to `V`,
+// so it has the same thread safety requirements as mutable references.
+unsafe impl<'a, K: Sync, V: Sync> Sync for CursorMut<'a, K, V> {}
+
+impl<'a, K, V> CursorMut<'a, K, V> {
+    /// The current node.
+    pub fn current(&self) -> (&K, &V) {
+        // SAFETY:
+        // - `self.current` is a valid node by the type invariants.
+        // - We have an immutable reference by the function signature.
+        unsafe { Self::to_key_value(self.current) }
+    }
+
+    /// The current node, with a mutable value.
+    pub fn current_mut(&mut self) -> (&K, &mut V) {
+        // SAFETY:
+        // - `self.current` is a valid node by the type invariants.
+        // - We have a mutable reference by the function signature.
+        unsafe { Self::to_key_value_mut(self.current) }
+    }
+
+    /// Remove the current node from the tree.
+    ///
+    /// Returns a tuple where the first element is a cursor to the next node, if it exists,
+    /// else the previous node, else [`None`] (if the tree becomes empty). The second element
+    /// is the removed node.
+    pub fn remove_current(self) -> (Option<Self>, RBTreeNode<K, V>) {
+        let prev = self.get_neighbor_raw(Direction::Prev);
+        let next = self.get_neighbor_raw(Direction::Next);
+        // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self`
+        // point to the links field of `Node<K, V>` objects.
+        let this = unsafe { container_of!(self.current.as_ptr(), Node<K, V>, links) };
+        // SAFETY: `this` is valid by the type invariants as described above.
+        let node = unsafe { KBox::from_raw(this) };
+        let node = RBTreeNode { node };
+        // SAFETY: The reference to the tree used to create the cursor outlives the cursor, so
+        // the tree cannot change. By the tree invariant, all nodes are valid.
+        unsafe { bindings::rb_erase(&mut (*this).links, addr_of_mut!(self.tree.root)) };
+
+        // INVARIANT:
+        // - `current` is a valid node in the [`RBTree`] pointed to by `self.tree`.
+        let cursor = next.or(prev).map(|current| Self {
+            current,
+            tree: self.tree,
+        });
+
+        (cursor, node)
+    }
+
+    /// Remove the previous node, returning it if it exists.
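+    ///
+    /// A short sketch; see the [`CursorMut`] documentation above for fuller examples:
+    ///
+    /// ```
+    /// use kernel::{alloc::flags, rbtree::RBTree};
+    ///
+    /// let mut tree = RBTree::new();
+    /// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?;
+    /// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?;
+    ///
+    /// // From the last element, the previous element is removed and returned.
+    /// let mut cursor = tree.cursor_back_mut().unwrap();
+    /// assert_eq!(cursor.remove_prev().unwrap().to_key_value(), (10, 100));
+    /// # Ok::<(), Error>(())
+    /// ```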
+    pub fn remove_prev(&mut self) -> Option<RBTreeNode<K, V>> {
+        self.remove_neighbor(Direction::Prev)
+    }
+
+    /// Remove the next node, returning it if it exists.
+    pub fn remove_next(&mut self) -> Option<RBTreeNode<K, V>> {
+        self.remove_neighbor(Direction::Next)
+    }
+
+    fn remove_neighbor(&mut self, direction: Direction) -> Option<RBTreeNode<K, V>> {
+        if let Some(neighbor) = self.get_neighbor_raw(direction) {
+            let neighbor = neighbor.as_ptr();
+            // SAFETY: The reference to the tree used to create the cursor outlives the cursor, so
+            // the tree cannot change. By the tree invariant, all nodes are valid.
+            unsafe { bindings::rb_erase(neighbor, addr_of_mut!(self.tree.root)) };
+            // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self`
+            // point to the links field of `Node<K, V>` objects.
+            let this = unsafe { container_of!(neighbor, Node<K, V>, links) };
+            // SAFETY: `this` is valid by the type invariants as described above.
+            let node = unsafe { KBox::from_raw(this) };
+            return Some(RBTreeNode { node });
+        }
+        None
+    }
+
+    /// Move the cursor to the previous node, returning [`None`] if it doesn't exist.
+    pub fn move_prev(self) -> Option<Self> {
+        self.mv(Direction::Prev)
+    }
+
+    /// Move the cursor to the next node, returning [`None`] if it doesn't exist.
+    pub fn move_next(self) -> Option<Self> {
+        self.mv(Direction::Next)
+    }
+
+    fn mv(self, direction: Direction) -> Option<Self> {
+        // INVARIANT:
+        // - `neighbor` is a valid node in the [`RBTree`] pointed to by `self.tree`.
+        self.get_neighbor_raw(direction).map(|neighbor| Self {
+            tree: self.tree,
+            current: neighbor,
+        })
+    }
+
+    /// Access the previous node without moving the cursor.
+    pub fn peek_prev(&self) -> Option<(&K, &V)> {
+        self.peek(Direction::Prev)
+    }
+
+    /// Access the next node without moving the cursor.
+    pub fn peek_next(&self) -> Option<(&K, &V)> {
+        self.peek(Direction::Next)
+    }
+
+    fn peek(&self, direction: Direction) -> Option<(&K, &V)> {
+        self.get_neighbor_raw(direction).map(|neighbor| {
+            // SAFETY:
+            // - `neighbor` is a valid tree node.
+            // - By the function signature, we have an immutable reference to `self`.
+            unsafe { Self::to_key_value(neighbor) }
+        })
+    }
+
+    /// Access the previous node mutably without moving the cursor.
+    pub fn peek_prev_mut(&mut self) -> Option<(&K, &mut V)> {
+        self.peek_mut(Direction::Prev)
+    }
+
+    /// Access the next node mutably without moving the cursor.
+    pub fn peek_next_mut(&mut self) -> Option<(&K, &mut V)> {
+        self.peek_mut(Direction::Next)
+    }
+
+    fn peek_mut(&mut self, direction: Direction) -> Option<(&K, &mut V)> {
+        self.get_neighbor_raw(direction).map(|neighbor| {
+            // SAFETY:
+            // - `neighbor` is a valid tree node.
+            // - By the function signature, we have a mutable reference to `self`.
+            unsafe { Self::to_key_value_mut(neighbor) }
+        })
+    }
+
+    fn get_neighbor_raw(&self, direction: Direction) -> Option<NonNull<bindings::rb_node>> {
+        // SAFETY: `self.current` is valid by the type invariants.
+        let neighbor = unsafe {
+            match direction {
+                Direction::Prev => bindings::rb_prev(self.current.as_ptr()),
+                Direction::Next => bindings::rb_next(self.current.as_ptr()),
+            }
+        };
+
+        NonNull::new(neighbor)
+    }
+
+    /// # Safety
+    ///
+    /// - `node` must be a valid pointer to a node in an [`RBTree`].
+    /// - The caller has immutable access to `node` for the duration of `'b`.
+    unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) {
+        // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`.
+        let (k, v) = unsafe { Self::to_key_value_raw(node) };
+        // SAFETY: the caller guarantees immutable access to `node`.
+        (k, unsafe { &*v })
+    }
+
+    /// # Safety
+    ///
+    /// - `node` must be a valid pointer to a node in an [`RBTree`].
+    /// - The caller has mutable access to `node` for the duration of `'b`.
+    unsafe fn to_key_value_mut<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b mut V) {
+        // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`.
+        let (k, v) = unsafe { Self::to_key_value_raw(node) };
+        // SAFETY: the caller guarantees mutable access to `node`.
+        (k, unsafe { &mut *v })
+    }
+
+    /// # Safety
+    ///
+    /// - `node` must be a valid pointer to a node in an [`RBTree`].
+    /// - The caller has immutable access to the key for the duration of `'b`.
+    unsafe fn to_key_value_raw<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, *mut V) {
+        // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self`
+        // point to the links field of `Node<K, V>` objects.
+        let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) };
+        // SAFETY: The passed `node` is the current node or a non-null neighbor,
+        // thus `this` is valid by the type invariants.
+        let k = unsafe { &(*this).key };
+        // SAFETY: The passed `node` is the current node or a non-null neighbor,
+        // thus `this` is valid by the type invariants.
+        let v = unsafe { addr_of_mut!((*this).value) };
+        (k, v)
+    }
+}
+
+/// Direction for [`Cursor`] and [`CursorMut`] operations.
+enum Direction {
+    /// the node immediately before, in sort order
+    Prev,
+    /// the node immediately after, in sort order
+    Next,
+}
+
+impl<'a, K, V> IntoIterator for &'a RBTree<K, V> {
+    type Item = (&'a K, &'a V);
+    type IntoIter = Iter<'a, K, V>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+/// An iterator over the nodes of a [`RBTree`].
+///
+/// Instances are created by calling [`RBTree::iter`].
+pub struct Iter<'a, K, V> {
+    _tree: PhantomData<&'a RBTree<K, V>>,
+    iter_raw: IterRaw<K, V>,
+}
+
+// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same
+// thread safety requirements as immutable references.
+unsafe impl<'a, K: Sync, V: Sync> Send for Iter<'a, K, V> {}
+
+// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same
+// thread safety requirements as immutable references.
+unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {}
+
+impl<'a, K, V> Iterator for Iter<'a, K, V> {
+    type Item = (&'a K, &'a V);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // SAFETY: Due to `self._tree`, `k` and `v` are valid for the lifetime of `'a`.
+        self.iter_raw.next().map(|(k, v)| unsafe { (&*k, &*v) })
+    }
+}
+
+impl<'a, K, V> IntoIterator for &'a mut RBTree<K, V> {
+    type Item = (&'a K, &'a mut V);
+    type IntoIter = IterMut<'a, K, V>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter_mut()
+    }
+}
+
+/// A mutable iterator over the nodes of a [`RBTree`].
+///
+/// Instances are created by calling [`RBTree::iter_mut`].
+pub struct IterMut<'a, K, V> {
+    _tree: PhantomData<&'a mut RBTree<K, V>>,
+    iter_raw: IterRaw<K, V>,
+}
+
+// SAFETY: The [`IterMut`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`.
+// The iterator only gives out immutable references to the keys, but since the iterator has exclusive access to those same
+// keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user.
+unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {} + +// SAFETY: The [`IterMut`] gives out immutable references to K and mutable references to V, so it has the same +// thread safety requirements as mutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {} + +impl<'a, K, V> Iterator for IterMut<'a, K, V> { + type Item = (&'a K, &'a mut V); + + fn next(&mut self) -> Option<Self::Item> { + self.iter_raw.next().map(|(k, v)| + // SAFETY: Due to `&mut self`, we have exclusive access to `k` and `v`, for the lifetime of `'a`. + unsafe { (&*k, &mut *v) }) + } +} + +/// A raw iterator over the nodes of a [`RBTree`]. +/// +/// # Invariants +/// - `self.next` is a valid pointer. +/// - `self.next` points to a node stored inside of a valid `RBTree`. +struct IterRaw<K, V> { + next: *mut bindings::rb_node, + _phantom: PhantomData<fn() -> (K, V)>, +} + +impl<K, V> Iterator for IterRaw<K, V> { + type Item = (*mut K, *mut V); + + fn next(&mut self) -> Option<Self::Item> { + if self.next.is_null() { + return None; + } + + // SAFETY: By the type invariant of `IterRaw`, `self.next` is a valid node in an `RBTree`, + // and by the type invariant of `RBTree`, all nodes point to the links field of `Node<K, V>` objects. + let cur = unsafe { container_of!(self.next, Node<K, V>, links) }; + + // SAFETY: `self.next` is a valid tree node by the type invariants. + self.next = unsafe { bindings::rb_next(self.next) }; + + // SAFETY: By the same reasoning above, it is safe to dereference the node. + Some(unsafe { (addr_of_mut!((*cur).key), addr_of_mut!((*cur).value)) }) + } +} + +/// A memory reservation for a red-black tree node. +/// +/// +/// It contains the memory needed to hold a node that can be inserted into a red-black tree. One +/// can be obtained by directly allocating it ([`RBTreeNodeReservation::new`]). +pub struct RBTreeNodeReservation<K, V> { + node: KBox<MaybeUninit<Node<K, V>>>, +} + +impl<K, V> RBTreeNodeReservation<K, V> { + /// Allocates memory for a node to be eventually initialised and inserted into the tree via a + /// call to [`RBTree::insert`]. + pub fn new(flags: Flags) -> Result<RBTreeNodeReservation<K, V>> { + Ok(RBTreeNodeReservation { + node: KBox::new_uninit(flags)?, + }) + } +} + +// SAFETY: This doesn't actually contain K or V, and is just a memory allocation. Those can always +// be moved across threads. +unsafe impl<K, V> Send for RBTreeNodeReservation<K, V> {} + +// SAFETY: This doesn't actually contain K or V, and is just a memory allocation. +unsafe impl<K, V> Sync for RBTreeNodeReservation<K, V> {} + +impl<K, V> RBTreeNodeReservation<K, V> { + /// Initialises a node reservation. + /// + /// It then becomes an [`RBTreeNode`] that can be inserted into a tree. + pub fn into_node(self, key: K, value: V) -> RBTreeNode<K, V> { + let node = KBox::write( + self.node, + Node { + key, + value, + links: bindings::rb_node::default(), + }, + ); + RBTreeNode { node } + } +} + +/// A red-black tree node. +/// +/// The node is fully initialised (with key and value) and can be inserted into a tree without any +/// extra allocations or failure paths. +pub struct RBTreeNode<K, V> { + node: KBox<Node<K, V>>, +} + +impl<K, V> RBTreeNode<K, V> { + /// Allocates and initialises a node that can be inserted into the tree via + /// [`RBTree::insert`]. + pub fn new(key: K, value: V, flags: Flags) -> Result<RBTreeNode<K, V>> { + Ok(RBTreeNodeReservation::new(flags)?.into_node(key, value)) + } + + /// Get the key and value from inside the node. 
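+    ///
+    /// A short sketch; this consumes the node and frees its allocation:
+    ///
+    /// ```
+    /// use kernel::{alloc::flags, rbtree::RBTreeNode};
+    ///
+    /// let node = RBTreeNode::new(10, 100, flags::GFP_KERNEL)?;
+    /// assert_eq!(node.to_key_value(), (10, 100));
+    /// # Ok::<(), Error>(())
+    /// ```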
+ pub fn to_key_value(self) -> (K, V) { + let node = KBox::into_inner(self.node); + + (node.key, node.value) + } +} + +// SAFETY: If K and V can be sent across threads, then it's also okay to send [`RBTreeNode`] across +// threads. +unsafe impl<K: Send, V: Send> Send for RBTreeNode<K, V> {} + +// SAFETY: If K and V can be accessed without synchronization, then it's also okay to access +// [`RBTreeNode`] without synchronization. +unsafe impl<K: Sync, V: Sync> Sync for RBTreeNode<K, V> {} + +impl<K, V> RBTreeNode<K, V> { + /// Drop the key and value, but keep the allocation. + /// + /// It then becomes a reservation that can be re-initialised into a different node (i.e., with + /// a different key and/or value). + /// + /// The existing key and value are dropped in-place as part of this operation, that is, memory + /// may be freed (but only for the key/value; memory for the node itself is kept for reuse). + pub fn into_reservation(self) -> RBTreeNodeReservation<K, V> { + RBTreeNodeReservation { + node: KBox::drop_contents(self.node), + } + } +} + +/// A view into a single entry in a map, which may either be vacant or occupied. +/// +/// This enum is constructed from the [`RBTree::entry`]. +/// +/// [`entry`]: fn@RBTree::entry +pub enum Entry<'a, K, V> { + /// This [`RBTree`] does not have a node with this key. + Vacant(VacantEntry<'a, K, V>), + /// This [`RBTree`] already has a node with this key. + Occupied(OccupiedEntry<'a, K, V>), +} + +/// Like [`Entry`], except that it doesn't have ownership of the key. +enum RawEntry<'a, K, V> { + Vacant(RawVacantEntry<'a, K, V>), + Occupied(OccupiedEntry<'a, K, V>), +} + +/// A view into a vacant entry in a [`RBTree`]. It is part of the [`Entry`] enum. +pub struct VacantEntry<'a, K, V> { + key: K, + raw: RawVacantEntry<'a, K, V>, +} + +/// Like [`VacantEntry`], but doesn't hold on to the key. +/// +/// # Invariants +/// - `parent` may be null if the new node becomes the root. +/// - `child_field_of_parent` is a valid pointer to the left-child or right-child of `parent`. If `parent` is +/// null, it is a pointer to the root of the [`RBTree`]. +struct RawVacantEntry<'a, K, V> { + rbtree: *mut RBTree<K, V>, + /// The node that will become the parent of the new node if we insert one. + parent: *mut bindings::rb_node, + /// This points to the left-child or right-child field of `parent`, or `root` if `parent` is + /// null. + child_field_of_parent: *mut *mut bindings::rb_node, + _phantom: PhantomData<&'a mut RBTree<K, V>>, +} + +impl<'a, K, V> RawVacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + /// + /// The `node` must have a key such that inserting it here does not break the ordering of this + /// [`RBTree`]. + fn insert(self, node: RBTreeNode<K, V>) -> &'a mut V { + let node = KBox::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `KBox::from_raw`, which only happens when + // the node is removed or replaced. + let node_links = unsafe { addr_of_mut!((*node).links) }; + + // INVARIANT: We are linking in a new node, which is valid. It remains valid because we + // "forgot" it with `KBox::into_raw`. + // SAFETY: The type invariants of `RawVacantEntry` are exactly the safety requirements of `rb_link_node`. + unsafe { bindings::rb_link_node(node_links, self.parent, self.child_field_of_parent) }; + + // SAFETY: All pointers are valid. `node` has just been inserted into the tree. 
+ unsafe { bindings::rb_insert_color(node_links, addr_of_mut!((*self.rbtree).root)) }; + + // SAFETY: The node is valid until we remove it from the tree. + unsafe { &mut (*node).value } + } +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + pub fn insert(self, value: V, reservation: RBTreeNodeReservation<K, V>) -> &'a mut V { + self.raw.insert(reservation.into_node(self.key, value)) + } +} + +/// A view into an occupied entry in a [`RBTree`]. It is part of the [`Entry`] enum. +/// +/// # Invariants +/// - `node_links` is a valid, non-null pointer to a tree node in `self.rbtree` +pub struct OccupiedEntry<'a, K, V> { + rbtree: &'a mut RBTree<K, V>, + /// The node that this entry corresponds to. + node_links: *mut bindings::rb_node, +} + +impl<'a, K, V> OccupiedEntry<'a, K, V> { + /// Gets a reference to the value in the entry. + pub fn get(&self) -> &V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have shared access to the underlying tree, and can thus give out a shared reference. + unsafe { &(*container_of!(self.node_links, Node<K, V>, links)).value } + } + + /// Gets a mutable reference to the value in the entry. + pub fn get_mut(&mut self) -> &mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have exclusive access to the underlying tree, and can thus give out a mutable reference. + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links))).value } + } + + /// Converts the entry into a mutable reference to its value. + /// + /// If you need multiple references to the `OccupiedEntry`, see [`self#get_mut`]. + pub fn into_mut(self) -> &'a mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - This consumes the `&'a mut RBTree<K, V>`, therefore it can give out a mutable reference that lives for `'a`. + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links))).value } + } + + /// Remove this entry from the [`RBTree`]. + pub fn remove_node(self) -> RBTreeNode<K, V> { + // SAFETY: The node is a node in the tree, so it is valid. + unsafe { bindings::rb_erase(self.node_links, &mut self.rbtree.root) }; + + // INVARIANT: The node is being returned and the caller may free it, however, it was + // removed from the tree. So the invariants still hold. + RBTreeNode { + // SAFETY: The node was a node in the tree, but we removed it, so we can convert it + // back into a box. + node: unsafe { KBox::from_raw(container_of!(self.node_links, Node<K, V>, links)) }, + } + } + + /// Takes the value of the entry out of the map, and returns it. + pub fn remove(self) -> V { + let rb_node = self.remove_node(); + let node = KBox::into_inner(rb_node.node); + + node.value + } + + /// Swap the current node for the provided node. + /// + /// The key of both nodes must be equal. + fn replace(self, node: RBTreeNode<K, V>) -> RBTreeNode<K, V> { + let node = KBox::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `KBox::from_raw`, which only happens when + // the node is removed or replaced. + let new_node_links = unsafe { addr_of_mut!((*node).links) }; + + // SAFETY: This updates the pointers so that `new_node_links` is in the tree where + // `self.node_links` used to be. + unsafe { + bindings::rb_replace_node(self.node_links, new_node_links, &mut self.rbtree.root) + }; + + // SAFETY: + // - `self.node_ptr` produces a valid pointer to a node in the tree. 
+ // - Now that we removed this entry from the tree, we can convert the node to a box. + let old_node = unsafe { KBox::from_raw(container_of!(self.node_links, Node<K, V>, links)) }; + + RBTreeNode { node: old_node } + } +} + +struct Node<K, V> { + links: bindings::rb_node, + key: K, + value: V, +} diff --git a/rust/kernel/regulator.rs b/rust/kernel/regulator.rs new file mode 100644 index 000000000000..2c44827ad0b7 --- /dev/null +++ b/rust/kernel/regulator.rs @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Regulator abstractions, providing a standard kernel interface to control +//! voltage and current regulators. +//! +//! The intention is to allow systems to dynamically control regulator power +//! output in order to save power and prolong battery life. This applies to both +//! voltage regulators (where voltage output is controllable) and current sinks +//! (where current limit is controllable). +//! +//! C header: [`include/linux/regulator/consumer.h`](srctree/include/linux/regulator/consumer.h) +//! +//! Regulators are modeled in Rust with a collection of states. Each state may +//! enforce a given invariant, and they may convert between each other where applicable. +//! +//! See [Voltage and current regulator API](https://docs.kernel.org/driver-api/regulator.html) +//! for more information. + +use crate::{ + bindings, + device::{Bound, Device}, + error::{from_err_ptr, to_result, Result}, + prelude::*, +}; + +use core::{marker::PhantomData, mem::ManuallyDrop, ptr::NonNull}; + +mod private { + pub trait Sealed {} + + impl Sealed for super::Enabled {} + impl Sealed for super::Disabled {} +} + +/// A trait representing the different states a [`Regulator`] can be in. +pub trait RegulatorState: private::Sealed + 'static { + /// Whether the regulator should be disabled when dropped. + const DISABLE_ON_DROP: bool; +} + +/// A state where the [`Regulator`] is known to be enabled. +/// +/// The `enable` reference count held by this state is decremented when it is +/// dropped. +pub struct Enabled; + +/// A state where this [`Regulator`] handle has not specifically asked for the +/// underlying regulator to be enabled. This means that this reference does not +/// own an `enable` reference count, but the regulator may still be on. +pub struct Disabled; + +impl RegulatorState for Enabled { + const DISABLE_ON_DROP: bool = true; +} + +impl RegulatorState for Disabled { + const DISABLE_ON_DROP: bool = false; +} + +/// A trait that abstracts the ability to check if a [`Regulator`] is enabled. +pub trait IsEnabled: RegulatorState {} +impl IsEnabled for Disabled {} + +/// An error that can occur when trying to convert a [`Regulator`] between states. +pub struct Error<State: RegulatorState> { + /// The error that occurred. + pub error: kernel::error::Error, + + /// The regulator that caused the error, so that the operation may be retried. + pub regulator: Regulator<State>, +} +/// Obtains and enables a [`devres`]-managed regulator for a device. +/// +/// This calls [`regulator_disable()`] and [`regulator_put()`] automatically on +/// driver detach. +/// +/// This API is identical to `devm_regulator_get_enable()`, and should be +/// preferred over the [`Regulator<T: RegulatorState>`] API if the caller only +/// cares about the regulator being enabled. 
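+///
+/// # Examples
+///
+/// A minimal sketch (the function and `"vcc"` supply names are illustrative):
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::c_str;
+/// # use kernel::device::{Bound, Device};
+/// # use kernel::regulator;
+/// fn probe(dev: &Device<Bound>) -> Result {
+///     // The regulator is enabled now and disabled/put automatically on driver detach.
+///     regulator::devm_enable(dev, c_str!("vcc"))?;
+///     Ok(())
+/// }
+/// ```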
+///
+/// [`devres`]: https://docs.kernel.org/driver-api/driver-model/devres.html
+/// [`regulator_disable()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_disable
+/// [`regulator_put()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_put
+pub fn devm_enable(dev: &Device<Bound>, name: &CStr) -> Result {
+    // SAFETY: `dev` is a valid and bound device, while `name` is a valid C
+    // string.
+    to_result(unsafe { bindings::devm_regulator_get_enable(dev.as_raw(), name.as_char_ptr()) })
+}
+
+/// Same as [`devm_enable`], but calls `devm_regulator_get_enable_optional`
+/// instead.
+///
+/// This obtains and enables a [`devres`]-managed regulator for a device, but
+/// does not print a message nor provide a dummy if the regulator is not found.
+///
+/// This calls [`regulator_disable()`] and [`regulator_put()`] automatically on
+/// driver detach.
+///
+/// [`devres`]: https://docs.kernel.org/driver-api/driver-model/devres.html
+/// [`regulator_disable()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_disable
+/// [`regulator_put()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_put
+pub fn devm_enable_optional(dev: &Device<Bound>, name: &CStr) -> Result {
+    // SAFETY: `dev` is a valid and bound device, while `name` is a valid C
+    // string.
+    to_result(unsafe {
+        bindings::devm_regulator_get_enable_optional(dev.as_raw(), name.as_char_ptr())
+    })
+}
+
+/// A `struct regulator` abstraction.
+///
+/// # Examples
+///
+/// ## Enabling a regulator
+///
+/// This example uses [`Regulator<Enabled>`], which is suitable for drivers that
+/// enable a regulator at probe time and leave it on until the device is
+/// removed or otherwise shut down.
+///
+/// These users can store [`Regulator<Enabled>`] directly in their driver's
+/// private data struct.
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::c_str;
+/// # use kernel::device::Device;
+/// # use kernel::regulator::{Voltage, Regulator, Disabled, Enabled};
+/// fn enable(dev: &Device, min_voltage: Voltage, max_voltage: Voltage) -> Result {
+///     // Obtain a reference to a (fictitious) regulator.
+///     let regulator: Regulator<Disabled> = Regulator::<Disabled>::get(dev, c_str!("vcc"))?;
+///
+///     // The voltage can be set before enabling the regulator if needed, e.g.:
+///     regulator.set_voltage(min_voltage, max_voltage)?;
+///
+///     // The same applies for `get_voltage()`, i.e.:
+///     let voltage: Voltage = regulator.get_voltage()?;
+///
+///     // Enables the regulator, consuming the previous value.
+///     //
+///     // From now on, the regulator is known to be enabled because of the type
+///     // `Enabled`.
+///     //
+///     // If this operation fails, the `Error` will contain the regulator
+///     // reference, so that the operation may be retried.
+///     let regulator: Regulator<Enabled> =
+///         regulator.try_into_enabled().map_err(|error| error.error)?;
+///
+///     // The voltage can also be set after enabling the regulator, e.g.:
+///     regulator.set_voltage(min_voltage, max_voltage)?;
+///
+///     // The same applies for `get_voltage()`, i.e.:
+///     let voltage: Voltage = regulator.get_voltage()?;
+///
+///     // Dropping an enabled regulator will disable it. The refcount will be
+///     // decremented.
+///     drop(regulator);
+///
+///     // ...
+///
+///     Ok(())
+/// }
+/// ```
+///
+/// A more concise shortcut is available for enabling a regulator. This is
+/// equivalent to `regulator_get_enable()`:
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::c_str;
+/// # use kernel::device::Device;
+/// # use kernel::regulator::{Voltage, Regulator, Enabled};
+/// fn enable(dev: &Device) -> Result {
+///     // Obtain a reference to a (fictitious) regulator and enable it.
+///     let regulator: Regulator<Enabled> = Regulator::<Enabled>::get(dev, c_str!("vcc"))?;
+///
+///     // Dropping an enabled regulator will disable it. The refcount will be
+///     // decremented.
+///     drop(regulator);
+///
+///     // ...
+///
+///     Ok(())
+/// }
+/// ```
+///
+/// If a driver only cares about the regulator being on for as long as it is bound
+/// to a device, then it should use [`devm_enable`] or [`devm_enable_optional`].
+/// This should be the default use-case unless more fine-grained control over
+/// the regulator's state is required.
+///
+/// [`devm_enable`]: crate::regulator::devm_enable
+/// [`devm_enable_optional`]: crate::regulator::devm_enable_optional
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::c_str;
+/// # use kernel::device::{Bound, Device};
+/// # use kernel::regulator;
+/// fn enable(dev: &Device<Bound>) -> Result {
+///     // Obtain a reference to a (fictitious) regulator and enable it. This
+///     // call only returns whether the operation succeeded.
+///     regulator::devm_enable(dev, c_str!("vcc"))?;
+///
+///     // The regulator will be disabled and put when `dev` is unbound.
+///     Ok(())
+/// }
+/// ```
+///
+/// ## Disabling a regulator
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::device::Device;
+/// # use kernel::regulator::{Regulator, Enabled, Disabled};
+/// fn disable(dev: &Device, regulator: Regulator<Enabled>) -> Result {
+///     // We can also disable an enabled regulator without relinquishing our
+///     // refcount:
+///     //
+///     // If this operation fails, the `Error` will contain the regulator
+///     // reference, so that the operation may be retried.
+///     let regulator: Regulator<Disabled> =
+///         regulator.try_into_disabled().map_err(|error| error.error)?;
+///
+///     // The refcount will be decremented when `regulator` is dropped.
+///     drop(regulator);
+///
+///     // ...
+///
+///     Ok(())
+/// }
+/// ```
+///
+/// # Invariants
+///
+/// - `inner` is a non-null wrapper over a pointer to a `struct
+///   regulator` obtained from [`regulator_get()`].
+///
+/// [`regulator_get()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_get
+pub struct Regulator<State>
+where
+    State: RegulatorState,
+{
+    inner: NonNull<bindings::regulator>,
+    _phantom: PhantomData<State>,
+}
+
+impl<T: RegulatorState> Regulator<T> {
+    /// Sets the voltage for the regulator.
+    ///
+    /// This can be used to ensure that the device powers up cleanly.
+    pub fn set_voltage(&self, min_voltage: Voltage, max_voltage: Voltage) -> Result {
+        // SAFETY: Safe as per the type invariants of `Regulator`.
+        to_result(unsafe {
+            bindings::regulator_set_voltage(
+                self.inner.as_ptr(),
+                min_voltage.as_microvolts(),
+                max_voltage.as_microvolts(),
+            )
+        })
+    }
+
+    /// Gets the current voltage of the regulator.
+    pub fn get_voltage(&self) -> Result<Voltage> {
+        // SAFETY: Safe as per the type invariants of `Regulator`.
+        let voltage = unsafe { bindings::regulator_get_voltage(self.inner.as_ptr()) };
+
+        to_result(voltage).map(|()| Voltage::from_microvolts(voltage))
+    }
+
+    fn get_internal(dev: &Device, name: &CStr) -> Result<Regulator<T>> {
+        let inner =
+            // SAFETY: It is safe to call `regulator_get()` on a device pointer
+            // received from the C code.
+ from_err_ptr(unsafe { bindings::regulator_get(dev.as_raw(), name.as_char_ptr()) })?; + + // SAFETY: We can safely trust `inner` to be a pointer to a valid + // regulator if `ERR_PTR` was not returned. + let inner = unsafe { NonNull::new_unchecked(inner) }; + + Ok(Self { + inner, + _phantom: PhantomData, + }) + } + + fn enable_internal(&self) -> Result { + // SAFETY: Safe as per the type invariants of `Regulator`. + to_result(unsafe { bindings::regulator_enable(self.inner.as_ptr()) }) + } + + fn disable_internal(&self) -> Result { + // SAFETY: Safe as per the type invariants of `Regulator`. + to_result(unsafe { bindings::regulator_disable(self.inner.as_ptr()) }) + } +} + +impl Regulator<Disabled> { + /// Obtains a [`Regulator`] instance from the system. + pub fn get(dev: &Device, name: &CStr) -> Result<Self> { + Regulator::get_internal(dev, name) + } + + /// Attempts to convert the regulator to an enabled state. + pub fn try_into_enabled(self) -> Result<Regulator<Enabled>, Error<Disabled>> { + // We will be transferring the ownership of our `regulator_get()` count to + // `Regulator<Enabled>`. + let regulator = ManuallyDrop::new(self); + + regulator + .enable_internal() + .map(|()| Regulator { + inner: regulator.inner, + _phantom: PhantomData, + }) + .map_err(|error| Error { + error, + regulator: ManuallyDrop::into_inner(regulator), + }) + } +} + +impl Regulator<Enabled> { + /// Obtains a [`Regulator`] instance from the system and enables it. + /// + /// This is equivalent to calling `regulator_get_enable()` in the C API. + pub fn get(dev: &Device, name: &CStr) -> Result<Self> { + Regulator::<Disabled>::get_internal(dev, name)? + .try_into_enabled() + .map_err(|error| error.error) + } + + /// Attempts to convert the regulator to a disabled state. + pub fn try_into_disabled(self) -> Result<Regulator<Disabled>, Error<Enabled>> { + // We will be transferring the ownership of our `regulator_get()` count + // to `Regulator<Disabled>`. + let regulator = ManuallyDrop::new(self); + + regulator + .disable_internal() + .map(|()| Regulator { + inner: regulator.inner, + _phantom: PhantomData, + }) + .map_err(|error| Error { + error, + regulator: ManuallyDrop::into_inner(regulator), + }) + } +} + +impl<T: IsEnabled> Regulator<T> { + /// Checks if the regulator is enabled. + pub fn is_enabled(&self) -> bool { + // SAFETY: Safe as per the type invariants of `Regulator`. + unsafe { bindings::regulator_is_enabled(self.inner.as_ptr()) != 0 } + } +} + +impl<T: RegulatorState> Drop for Regulator<T> { + fn drop(&mut self) { + if T::DISABLE_ON_DROP { + // SAFETY: By the type invariants, we know that `self` owns a + // reference on the enabled refcount, so it is safe to relinquish it + // now. + unsafe { bindings::regulator_disable(self.inner.as_ptr()) }; + } + // SAFETY: By the type invariants, we know that `self` owns a reference, + // so it is safe to relinquish it now. + unsafe { bindings::regulator_put(self.inner.as_ptr()) }; + } +} + +// SAFETY: It is safe to send a `Regulator<T>` across threads. In particular, a +// Regulator<T> can be dropped from any thread. +unsafe impl<T: RegulatorState> Send for Regulator<T> {} + +// SAFETY: It is safe to send a &Regulator<T> across threads because the C side +// handles its own locking. +unsafe impl<T: RegulatorState> Sync for Regulator<T> {} + +/// A voltage. +/// +/// This type represents a voltage value in microvolts. 
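+///
+/// # Examples
+///
+/// A small sketch of round-tripping a value through [`Voltage`]:
+///
+/// ```
+/// use kernel::regulator::Voltage;
+///
+/// // 1.8 V expressed in microvolts.
+/// let v = Voltage::from_microvolts(1_800_000);
+/// assert_eq!(v.as_microvolts(), 1_800_000);
+/// ```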
+#[repr(transparent)]
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct Voltage(i32);
+
+impl Voltage {
+    /// Creates a new `Voltage` from a value in microvolts.
+    pub fn from_microvolts(uv: i32) -> Self {
+        Self(uv)
+    }
+
+    /// Returns the value of the voltage in microvolts as an [`i32`].
+    pub fn as_microvolts(self) -> i32 {
+        self.0
+    }
+}
diff --git a/rust/kernel/revocable.rs b/rust/kernel/revocable.rs
new file mode 100644
index 000000000000..0f4ae673256d
--- /dev/null
+++ b/rust/kernel/revocable.rs
@@ -0,0 +1,263 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Revocable objects.
+//!
+//! The [`Revocable`] type wraps other types and allows access to them to be revoked. The existence
+//! of a [`RevocableGuard`] ensures that objects remain valid.
+
+use pin_init::Wrapper;
+
+use crate::{bindings, prelude::*, sync::rcu, types::Opaque};
+use core::{
+    marker::PhantomData,
+    ops::Deref,
+    ptr::drop_in_place,
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+/// An object that can become inaccessible at runtime.
+///
+/// Once access is revoked and all concurrent users complete (i.e., all existing instances of
+/// [`RevocableGuard`] are dropped), the wrapped object is also dropped.
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::revocable::Revocable;
+///
+/// struct Example {
+///     a: u32,
+///     b: u32,
+/// }
+///
+/// fn add_two(v: &Revocable<Example>) -> Option<u32> {
+///     let guard = v.try_access()?;
+///     Some(guard.a + guard.b)
+/// }
+///
+/// let v = KBox::pin_init(Revocable::new(Example { a: 10, b: 20 }), GFP_KERNEL).unwrap();
+/// assert_eq!(add_two(&v), Some(30));
+/// v.revoke();
+/// assert_eq!(add_two(&v), None);
+/// ```
+///
+/// Same example as above, but explicitly using the rcu read side lock.
+///
+/// ```
+/// # use kernel::revocable::Revocable;
+/// use kernel::sync::rcu;
+///
+/// struct Example {
+///     a: u32,
+///     b: u32,
+/// }
+///
+/// fn add_two(v: &Revocable<Example>) -> Option<u32> {
+///     let guard = rcu::read_lock();
+///     let e = v.try_access_with_guard(&guard)?;
+///     Some(e.a + e.b)
+/// }
+///
+/// let v = KBox::pin_init(Revocable::new(Example { a: 10, b: 20 }), GFP_KERNEL).unwrap();
+/// assert_eq!(add_two(&v), Some(30));
+/// v.revoke();
+/// assert_eq!(add_two(&v), None);
+/// ```
+#[pin_data(PinnedDrop)]
+pub struct Revocable<T> {
+    is_available: AtomicBool,
+    #[pin]
+    data: Opaque<T>,
+}
+
+// SAFETY: `Revocable` is `Send` if the wrapped object is also `Send`. This is because while the
+// functionality exposed by `Revocable` can be accessed from any thread/CPU, it is possible that
+// this isn't supported by the wrapped object.
+unsafe impl<T: Send> Send for Revocable<T> {}
+
+// SAFETY: `Revocable` is `Sync` if the wrapped object is both `Send` and `Sync`. We require `Send`
+// from the wrapped object as well because of `Revocable::revoke`, which can trigger the `Drop`
+// implementation of the wrapped object from an arbitrary thread.
+unsafe impl<T: Sync + Send> Sync for Revocable<T> {}
+
+impl<T> Revocable<T> {
+    /// Creates a new revocable instance of the given data.
+    pub fn new<E>(data: impl PinInit<T, E>) -> impl PinInit<Self, E> {
+        try_pin_init!(Self {
+            is_available: AtomicBool::new(true),
+            data <- Opaque::pin_init(data),
+        }? E)
+    }
+
+    /// Tries to access the revocable wrapped object.
+    ///
+    /// Returns `None` if the object has been revoked and is therefore no longer accessible.
+    ///
+    /// Returns a guard that gives access to the object otherwise; the object is guaranteed to
+    /// remain accessible while the guard is alive. In such cases, callers are not allowed to sleep
In such cases, callers are not allowed to sleep + /// because another CPU may be waiting to complete the revocation of this object. + pub fn try_access(&self) -> Option<RevocableGuard<'_, T>> { + let guard = rcu::read_lock(); + if self.is_available.load(Ordering::Relaxed) { + // Since `self.is_available` is true, data is initialised and has to remain valid + // because the RCU read side lock prevents it from being dropped. + Some(RevocableGuard::new(self.data.get(), guard)) + } else { + None + } + } + + /// Tries to access the revocable wrapped object. + /// + /// Returns `None` if the object has been revoked and is therefore no longer accessible. + /// + /// Returns a shared reference to the object otherwise; the object is guaranteed to + /// remain accessible while the rcu read side guard is alive. In such cases, callers are not + /// allowed to sleep because another CPU may be waiting to complete the revocation of this + /// object. + pub fn try_access_with_guard<'a>(&'a self, _guard: &'a rcu::Guard) -> Option<&'a T> { + if self.is_available.load(Ordering::Relaxed) { + // SAFETY: Since `self.is_available` is true, data is initialised and has to remain + // valid because the RCU read side lock prevents it from being dropped. + Some(unsafe { &*self.data.get() }) + } else { + None + } + } + + /// Tries to access the wrapped object and run a closure on it while the guard is held. + /// + /// This is a convenience method to run short non-sleepable code blocks while ensuring the + /// guard is dropped afterwards. [`Self::try_access`] carries the risk that the caller will + /// forget to explicitly drop that returned guard before calling sleepable code; this method + /// adds an extra layer of safety to make sure it doesn't happen. + /// + /// Returns [`None`] if the object has been revoked and is therefore no longer accessible, or + /// the result of the closure wrapped in [`Some`]. If the closure returns a [`Result`] then the + /// return type becomes `Option<Result<>>`, which can be inconvenient. Users are encouraged to + /// define their own macro that turns the [`Option`] into a proper error code and flattens the + /// inner result into it if it makes sense within their subsystem. + pub fn try_access_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Option<R> { + self.try_access().map(|t| f(&*t)) + } + + /// Directly access the revocable wrapped object. + /// + /// # Safety + /// + /// The caller must ensure this [`Revocable`] instance hasn't been revoked and won't be revoked + /// as long as the returned `&T` lives. + pub unsafe fn access(&self) -> &T { + // SAFETY: By the safety requirement of this function it is guaranteed that + // `self.data.get()` is a valid pointer to an instance of `T`. + unsafe { &*self.data.get() } + } + + /// # Safety + /// + /// Callers must ensure that there are no more concurrent users of the revocable object. + unsafe fn revoke_internal<const SYNC: bool>(&self) -> bool { + let revoke = self.is_available.swap(false, Ordering::Relaxed); + + if revoke { + if SYNC { + // SAFETY: Just an FFI call, there are no further requirements. + unsafe { bindings::synchronize_rcu() }; + } + + // SAFETY: We know `self.data` is valid because only one CPU can succeed the + // `swap` above that takes `is_available` from `true` to `false`. + unsafe { drop_in_place(self.data.get()) }; + } + + revoke + } + + /// Revokes access to and drops the wrapped object.
+ /// + /// Access to the object is revoked immediately to new callers of [`Revocable::try_access`], + /// expecting that there are no concurrent users of the object. + /// + /// Returns `true` if `&self` has been revoked with this call, `false` if it was revoked + /// already. + /// + /// # Safety + /// + /// Callers must ensure that there are no more concurrent users of the revocable object. + pub unsafe fn revoke_nosync(&self) -> bool { + // SAFETY: By the safety requirement of this function, the caller ensures that nobody is + // accessing the data anymore and hence we don't have to wait for the grace period to + // finish. + unsafe { self.revoke_internal::<false>() } + } + + /// Revokes access to and drops the wrapped object. + /// + /// Access to the object is revoked immediately to new callers of [`Revocable::try_access`]. + /// + /// If there are concurrent users of the object (i.e., ones that called + /// [`Revocable::try_access`] beforehand and still haven't dropped the returned guard), this + /// function waits for the concurrent access to complete before dropping the wrapped object. + /// + /// Returns `true` if `&self` has been revoked with this call, `false` if it was revoked + /// already. + pub fn revoke(&self) -> bool { + // SAFETY: By passing `true` we ask `revoke_internal` to wait for the grace period to + // finish. + unsafe { self.revoke_internal::<true>() } + } +} + +#[pinned_drop] +impl<T> PinnedDrop for Revocable<T> { + fn drop(self: Pin<&mut Self>) { + // Drop only if the data hasn't been revoked yet (in which case it has already been + // dropped). + // SAFETY: We are not moving out of `p`, only dropping in place + let p = unsafe { self.get_unchecked_mut() }; + if *p.is_available.get_mut() { + // SAFETY: We know `self.data` is valid because no other CPU has changed + // `is_available` to `false` yet, and no other CPU can do it anymore because this CPU + // holds the only reference (mutable) to `self` now. + unsafe { drop_in_place(p.data.get()) }; + } + } +} + +/// A guard that allows access to a revocable object and keeps it alive. +/// +/// CPUs may not sleep while holding on to [`RevocableGuard`] because it's in atomic context +/// holding the RCU read-side lock. +/// +/// # Invariants +/// +/// The RCU read-side lock is held while the guard is alive. +pub struct RevocableGuard<'a, T> { + // This can't use the `&'a T` type because references that appear in function arguments must + // not become dangling during the execution of the function, which can happen if the + // `RevocableGuard` is passed as a function argument and then dropped during execution of the + // function. + data_ref: *const T, + _rcu_guard: rcu::Guard, + _p: PhantomData<&'a ()>, +} + +impl<T> RevocableGuard<'_, T> { + fn new(data_ref: *const T, rcu_guard: rcu::Guard) -> Self { + Self { + data_ref, + _rcu_guard: rcu_guard, + _p: PhantomData, + } + } +} + +impl<T> Deref for RevocableGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: By the type invariants, we hold the rcu read-side lock, so the object is + // guaranteed to remain valid. + unsafe { &*self.data_ref } + } +} diff --git a/rust/kernel/scatterlist.rs b/rust/kernel/scatterlist.rs new file mode 100644 index 000000000000..196fdb9a75e7 --- /dev/null +++ b/rust/kernel/scatterlist.rs @@ -0,0 +1,491 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for scatter-gather lists. +//! +//! C header: [`include/linux/scatterlist.h`](srctree/include/linux/scatterlist.h) +//! +//! 
Scatter-gather (SG) I/O is a memory access technique that allows devices to perform DMA +//! operations on data buffers that are not physically contiguous in memory. It works by creating a +//! "scatter-gather list", an array where each entry specifies the address and length of a +//! physically contiguous memory segment. +//! +//! The device's DMA controller can then read this list and process the segments sequentially as +//! part of one logical I/O request. This avoids the need for a single, large, physically contiguous +//! memory buffer, which can be difficult or impossible to allocate. +//! +//! This module provides safe Rust abstractions over the kernel's `struct scatterlist` and +//! `struct sg_table` types. +//! +//! The main entry point is the [`SGTable`] type, which represents a complete scatter-gather table. +//! It can be either: +//! +//! - An owned table ([`SGTable<Owned<P>>`]), created from a Rust memory buffer (e.g., [`VVec`]). +//! This type manages the allocation of the `struct sg_table`, the DMA mapping of the buffer, and +//! the automatic cleanup of all resources. +//! - A borrowed reference (&[`SGTable`]), which provides safe, read-only access to a table that was +//! allocated by other (e.g., C) code. +//! +//! Individual entries in the table are represented by [`SGEntry`], which can be accessed by +//! iterating over an [`SGTable`]. + +use crate::{ + alloc, + alloc::allocator::VmallocPageIter, + bindings, + device::{Bound, Device}, + devres::Devres, + dma, error, + io::ResourceSize, + page, + prelude::*, + types::{ARef, Opaque}, +}; +use core::{ops::Deref, ptr::NonNull}; + +/// A single entry in a scatter-gather list. +/// +/// An `SGEntry` represents a single, physically contiguous segment of memory that has been mapped +/// for DMA. +/// +/// Instances of this struct are obtained by iterating over an [`SGTable`]. Drivers do not create +/// or own [`SGEntry`] objects directly. +#[repr(transparent)] +pub struct SGEntry(Opaque<bindings::scatterlist>); + +// SAFETY: `SGEntry` can be sent to any task. +unsafe impl Send for SGEntry {} + +// SAFETY: `SGEntry` has no interior mutability and can be accessed concurrently. +unsafe impl Sync for SGEntry {} + +impl SGEntry { + /// Convert a raw `struct scatterlist *` to a `&'a SGEntry`. + /// + /// # Safety + /// + /// Callers must ensure that the `struct scatterlist` pointed to by `ptr` is valid for the + /// lifetime `'a`. + #[inline] + unsafe fn from_raw<'a>(ptr: *mut bindings::scatterlist) -> &'a Self { + // SAFETY: The safety requirements of this function guarantee that `ptr` is a valid pointer + // to a `struct scatterlist` for the duration of `'a`. + unsafe { &*ptr.cast() } + } + + /// Obtain the raw `struct scatterlist *`. + #[inline] + fn as_raw(&self) -> *mut bindings::scatterlist { + self.0.get() + } + + /// Returns the DMA address of this SG entry. + /// + /// This is the address that the device should use to access the memory segment. + #[inline] + pub fn dma_address(&self) -> dma::DmaAddress { + // SAFETY: `self.as_raw()` is a valid pointer to a `struct scatterlist`. + unsafe { bindings::sg_dma_address(self.as_raw()) } + } + + /// Returns the length of this SG entry in bytes. + #[inline] + pub fn dma_len(&self) -> ResourceSize { + #[allow(clippy::useless_conversion)] + // SAFETY: `self.as_raw()` is a valid pointer to a `struct scatterlist`. + unsafe { bindings::sg_dma_len(self.as_raw()) }.into() + } +} + +/// The borrowed generic type of an [`SGTable`], representing a borrowed or externally managed +/// table. 
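+/// +/// # Examples +/// +/// A rough sketch of walking a table that C code owns; `sgt_ptr` is a hypothetical, valid +/// `struct sg_table` pointer provided by the C side: +/// +/// ```ignore +/// // SAFETY: `sgt_ptr` is assumed valid for `'a` and not modified concurrently. +/// let table = unsafe { kernel::scatterlist::SGTable::from_raw(sgt_ptr) }; +/// for entry in table.iter() { +/// let _ = (entry.dma_address(), entry.dma_len()); +/// } +/// ```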
+#[repr(transparent)] +pub struct Borrowed(Opaque<bindings::sg_table>); + +// SAFETY: `Borrowed` can be sent to any task. +unsafe impl Send for Borrowed {} + +// SAFETY: `Borrowed` has no interior mutability and can be accessed concurrently. +unsafe impl Sync for Borrowed {} + +/// A scatter-gather table. +/// +/// This struct is a wrapper around the kernel's `struct sg_table`. It manages a list of DMA-mapped +/// memory segments that can be passed to a device for I/O operations. +/// +/// The generic parameter `T` is used to distinguish between owned and borrowed +/// tables. +/// +/// - [`SGTable<Owned>`]: An owned table created and managed entirely by Rust code. It handles +/// allocation, DMA mapping, and cleanup of all associated resources. See [`SGTable::new`]. +/// - [`SGTable<Borrowed>`] (or simply [`SGTable`]): Represents a table whose lifetime is managed +/// externally. It can be used safely via a borrowed reference `&'a SGTable`, where `'a` is the +/// external lifetime. +/// +/// All [`SGTable`] variants allow iterating over the individual [`SGEntry`] items. +#[repr(transparent)] +#[pin_data] +pub struct SGTable<T: private::Sealed = Borrowed> { + #[pin] + inner: T, +} + +impl SGTable { + /// Creates a borrowed `&'a SGTable` from a raw `struct sg_table` pointer. + /// + /// This allows safe access to an `sg_table` that is managed elsewhere (for example, in C code). + /// + /// # Safety + /// + /// Callers must ensure that: + /// + /// - the `struct sg_table` pointed to by `ptr` is valid for the entire lifetime of `'a`, + /// - the data behind `ptr` is not modified concurrently for the duration of `'a`. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *mut bindings::sg_table) -> &'a Self { + // SAFETY: The safety requirements of this function guarantee that `ptr` is a valid pointer + // to a `struct sg_table` for the duration of `'a`. + unsafe { &*ptr.cast() } + } + + #[inline] + fn as_raw(&self) -> *mut bindings::sg_table { + self.inner.0.get() + } + + /// Returns an [`SGTableIter`] bound to the lifetime of `self`. + pub fn iter(&self) -> SGTableIter<'_> { + // SAFETY: `self.as_raw()` is a valid pointer to a `struct sg_table`. + let nents = unsafe { (*self.as_raw()).nents }; + + let pos = if nents > 0 { + // SAFETY: `self.as_raw()` is a valid pointer to a `struct sg_table`. + let ptr = unsafe { (*self.as_raw()).sgl }; + + // SAFETY: `ptr` is guaranteed to be a valid pointer to a `struct scatterlist`. + Some(unsafe { SGEntry::from_raw(ptr) }) + } else { + None + }; + + SGTableIter { pos, nents } + } +} + +/// Represents the DMA mapping state of a `struct sg_table`. +/// +/// This is used as an inner type of [`Owned`] to manage the DMA mapping lifecycle. +/// +/// # Invariants +/// +/// - `sgt` is a valid pointer to a `struct sg_table` for the entire lifetime of the +/// [`DmaMappedSgt`]. +/// - `sgt` is always DMA mapped. +struct DmaMappedSgt { + sgt: NonNull<bindings::sg_table>, + dev: ARef<Device>, + dir: dma::DataDirection, +} + +// SAFETY: `DmaMappedSgt` can be sent to any task. +unsafe impl Send for DmaMappedSgt {} + +// SAFETY: `DmaMappedSgt` has no interior mutability and can be accessed concurrently. +unsafe impl Sync for DmaMappedSgt {} + +impl DmaMappedSgt { + /// # Safety + /// + /// - `sgt` must be a valid pointer to a `struct sg_table` for the entire lifetime of the + /// returned [`DmaMappedSgt`]. + /// - The caller must guarantee that `sgt` remains DMA mapped for the entire lifetime of + /// [`DmaMappedSgt`].
+ unsafe fn new( + sgt: NonNull<bindings::sg_table>, + dev: &Device<Bound>, + dir: dma::DataDirection, + ) -> Result<Self> { + // SAFETY: + // - `dev.as_raw()` is a valid pointer to a `struct device`, which is guaranteed to be + // bound to a driver for the duration of this call. + // - `sgt` is a valid pointer to a `struct sg_table`. + error::to_result(unsafe { + bindings::dma_map_sgtable(dev.as_raw(), sgt.as_ptr(), dir.into(), 0) + })?; + + // INVARIANT: By the safety requirements of this function it is guaranteed that `sgt` is + // valid for the entire lifetime of this object instance. + Ok(Self { + sgt, + dev: dev.into(), + dir, + }) + } +} + +impl Drop for DmaMappedSgt { + #[inline] + fn drop(&mut self) { + // SAFETY: + // - `self.dev.as_raw()` is a pointer to a valid `struct device`. + // - `self.dev` is the same device the mapping has been created for in `Self::new()`. + // - `self.sgt.as_ptr()` is a valid pointer to a `struct sg_table` by the type invariants + // of `Self`. + // - `self.dir` is the same `dma::DataDirection` the mapping has been created with in + // `Self::new()`. + unsafe { + bindings::dma_unmap_sgtable(self.dev.as_raw(), self.sgt.as_ptr(), self.dir.into(), 0) + }; + } +} + +/// A transparent wrapper around a `struct sg_table`. +/// +/// While we could also create the `struct sg_table` in the constructor of [`Owned`], we can't tear +/// down the `struct sg_table` in [`Owned::drop`]; the drop order in [`Owned`] matters. +#[repr(transparent)] +struct RawSGTable(Opaque<bindings::sg_table>); + +// SAFETY: `RawSGTable` can be sent to any task. +unsafe impl Send for RawSGTable {} + +// SAFETY: `RawSGTable` has no interior mutability and can be accessed concurrently. +unsafe impl Sync for RawSGTable {} + +impl RawSGTable { + /// # Safety + /// + /// - `pages` must be a slice of valid `struct page *`. + /// - The pages pointed to by `pages` must remain valid for the entire lifetime of the returned + /// [`RawSGTable`]. + unsafe fn new( + pages: &mut [*mut bindings::page], + size: usize, + max_segment: u32, + flags: alloc::Flags, + ) -> Result<Self> { + // `sg_alloc_table_from_pages_segment()` expects at least one page, otherwise it + // produces a NPE. + if pages.is_empty() { + return Err(EINVAL); + } + + let sgt = Opaque::zeroed(); + // SAFETY: + // - `sgt.get()` is a valid pointer to uninitialized memory. + // - As by the check above, `pages` is not empty. + error::to_result(unsafe { + bindings::sg_alloc_table_from_pages_segment( + sgt.get(), + pages.as_mut_ptr(), + pages.len().try_into()?, + 0, + size, + max_segment, + flags.as_raw(), + ) + })?; + + Ok(Self(sgt)) + } + + #[inline] + fn as_raw(&self) -> *mut bindings::sg_table { + self.0.get() + } +} + +impl Drop for RawSGTable { + #[inline] + fn drop(&mut self) { + // SAFETY: `sgt` is a valid and initialized `struct sg_table`. + unsafe { bindings::sg_free_table(self.0.get()) }; + } +} + +/// The [`Owned`] generic type of an [`SGTable`]. +/// +/// A [`SGTable<Owned>`] signifies that the [`SGTable`] owns all associated resources: +/// +/// - The backing memory pages. +/// - The `struct sg_table` allocation (`sgt`). +/// - The DMA mapping, managed through a [`Devres`]-managed `DmaMappedSgt`. +/// +/// Users interact with this type through the [`SGTable`] handle and do not need to manage +/// [`Owned`] directly. +#[pin_data] +pub struct Owned<P> { + // Note: The drop order is relevant; we first have to unmap the `struct sg_table`, then free the + // `struct sg_table` and finally free the backing pages. 
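+ // Rust drops struct fields in declaration order, so `dma` (the DMA mapping) is dropped + // first, then `sgt` (the table), and finally the backing `_pages`.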
+ #[pin] + dma: Devres<DmaMappedSgt>, + sgt: RawSGTable, + _pages: P, +} + +// SAFETY: `Owned` can be sent to any task if `P` can be sent to any task. +unsafe impl<P: Send> Send for Owned<P> {} + +// SAFETY: `Owned` has no interior mutability and can be accessed concurrently if `P` can be +// accessed concurrently. +unsafe impl<P: Sync> Sync for Owned<P> {} + +impl<P> Owned<P> +where + for<'a> P: page::AsPageIter<Iter<'a> = VmallocPageIter<'a>> + 'static, +{ + fn new( + dev: &Device<Bound>, + mut pages: P, + dir: dma::DataDirection, + flags: alloc::Flags, + ) -> Result<impl PinInit<Self, Error> + '_> { + let page_iter = pages.page_iter(); + let size = page_iter.size(); + + let mut page_vec: KVec<*mut bindings::page> = + KVec::with_capacity(page_iter.page_count(), flags)?; + + for page in page_iter { + page_vec.push(page.as_ptr(), flags)?; + } + + // `dma_max_mapping_size` returns `size_t`, but `sg_alloc_table_from_pages_segment()` takes + // an `unsigned int`. + // + // SAFETY: `dev.as_raw()` is a valid pointer to a `struct device`. + let max_segment = match unsafe { bindings::dma_max_mapping_size(dev.as_raw()) } { + 0 => u32::MAX, + max_segment => u32::try_from(max_segment).unwrap_or(u32::MAX), + }; + + Ok(try_pin_init!(&this in Self { + // SAFETY: + // - `page_vec` is a `KVec` of valid `struct page *` obtained from `pages`. + // - The pages contained in `pages` remain valid for the entire lifetime of the + // `RawSGTable`. + sgt: unsafe { RawSGTable::new(&mut page_vec, size, max_segment, flags) }?, + dma <- { + // SAFETY: `this` is a valid pointer to uninitialized memory. + let sgt = unsafe { &raw mut (*this.as_ptr()).sgt }.cast(); + + // SAFETY: `sgt` is guaranteed to be non-null. + let sgt = unsafe { NonNull::new_unchecked(sgt) }; + + // SAFETY: + // - It is guaranteed that the object returned by `DmaMappedSgt::new` won't out-live + // `sgt`. + // - `sgt` is never DMA unmapped manually. + Devres::new(dev, unsafe { DmaMappedSgt::new(sgt, dev, dir) }) + }, + _pages: pages, + })) + } +} + +impl<P> SGTable<Owned<P>> +where + for<'a> P: page::AsPageIter<Iter<'a> = VmallocPageIter<'a>> + 'static, +{ + /// Allocates a new scatter-gather table from the given pages and maps it for DMA. + /// + /// This constructor creates a new [`SGTable<Owned>`] that takes ownership of `P`. + /// It allocates a `struct sg_table`, populates it with entries corresponding to the physical + /// pages of `P`, and maps the table for DMA with the specified [`Device`] and + /// [`dma::DataDirection`]. + /// + /// The DMA mapping is managed through [`Devres`], ensuring that the DMA mapping is unmapped + /// once the associated [`Device`] is unbound, or when the [`SGTable<Owned>`] is dropped. + /// + /// # Parameters + /// + /// * `dev`: The [`Device`] that will be performing the DMA. + /// * `pages`: The entity providing the backing pages. It must implement [`page::AsPageIter`]. + /// The ownership of this entity is moved into the new [`SGTable<Owned>`]. + /// * `dir`: The [`dma::DataDirection`] of the DMA transfer. + /// * `flags`: Allocation flags for internal allocations (e.g., [`GFP_KERNEL`]).
+ /// + /// # Examples + /// + /// ``` + /// use kernel::{ + /// device::{Bound, Device}, + /// dma, page, + /// prelude::*, + /// scatterlist::{SGTable, Owned}, + /// }; + /// + /// fn test(dev: &Device<Bound>) -> Result { + /// let size = 4 * page::PAGE_SIZE; + /// let pages = VVec::<u8>::with_capacity(size, GFP_KERNEL)?; + /// + /// let sgt = KBox::pin_init(SGTable::new( + /// dev, + /// pages, + /// dma::DataDirection::ToDevice, + /// GFP_KERNEL, + /// ), GFP_KERNEL)?; + /// + /// Ok(()) + /// } + /// ``` + pub fn new( + dev: &Device<Bound>, + pages: P, + dir: dma::DataDirection, + flags: alloc::Flags, + ) -> impl PinInit<Self, Error> + '_ { + try_pin_init!(Self { + inner <- Owned::new(dev, pages, dir, flags)? + }) + } +} + +impl<P> Deref for SGTable<Owned<P>> { + type Target = SGTable; + + #[inline] + fn deref(&self) -> &Self::Target { + // SAFETY: + // - `self.inner.sgt.as_raw()` is a valid pointer to a `struct sg_table` for the entire + // lifetime of `self`. + // - The backing `struct sg_table` is not modified for the entire lifetime of `self`. + unsafe { SGTable::from_raw(self.inner.sgt.as_raw()) } + } +} + +mod private { + pub trait Sealed {} + + impl Sealed for super::Borrowed {} + impl<P> Sealed for super::Owned<P> {} +} + +/// An [`Iterator`] over the DMA mapped [`SGEntry`] items of an [`SGTable`]. +/// +/// Note that the existence of an [`SGTableIter`] does not guarantee that the [`SGEntry`] items +/// actually remain DMA mapped; they are prone to be unmapped on device unbind. +pub struct SGTableIter<'a> { + pos: Option<&'a SGEntry>, + /// The number of DMA mapped entries in a `struct sg_table`. + nents: c_uint, +} + +impl<'a> Iterator for SGTableIter<'a> { + type Item = &'a SGEntry; + + fn next(&mut self) -> Option<Self::Item> { + let entry = self.pos?; + self.nents = self.nents.saturating_sub(1); + + // SAFETY: `entry.as_raw()` is a valid pointer to a `struct scatterlist`. + let next = unsafe { bindings::sg_next(entry.as_raw()) }; + + self.pos = (!next.is_null() && self.nents > 0).then(|| { + // SAFETY: If `next` is not NULL, `sg_next()` guarantees to return a valid pointer to + // the next `struct scatterlist`. + unsafe { SGEntry::from_raw(next) } + }); + + Some(entry) + } +} diff --git a/rust/kernel/security.rs b/rust/kernel/security.rs new file mode 100644 index 000000000000..9d271695265f --- /dev/null +++ b/rust/kernel/security.rs @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Linux Security Modules (LSM). +//! +//! C header: [`include/linux/security.h`](srctree/include/linux/security.h). + +use crate::{ + bindings, + cred::Credential, + error::{to_result, Result}, + fs::File, +}; + +/// Calls the security modules to determine if the given task can become the manager of a binder +/// context. +#[inline] +pub fn binder_set_context_mgr(mgr: &Credential) -> Result { + // SAFETY: `mgr` is valid because the shared reference guarantees a nonzero refcount. + to_result(unsafe { bindings::security_binder_set_context_mgr(mgr.as_ptr()) }) +} + +/// Calls the security modules to determine if binder transactions are allowed from task `from` to +/// task `to`. +#[inline] +pub fn binder_transaction(from: &Credential, to: &Credential) -> Result { + // SAFETY: `from` and `to` are valid because the shared references guarantee nonzero refcounts.
+ to_result(unsafe { bindings::security_binder_transaction(from.as_ptr(), to.as_ptr()) }) +} + +/// Calls the security modules to determine if task `from` is allowed to send binder objects +/// (owned by itself or other processes) to task `to` through a binder transaction. +#[inline] +pub fn binder_transfer_binder(from: &Credential, to: &Credential) -> Result { + // SAFETY: `from` and `to` are valid because the shared references guarantee nonzero refcounts. + to_result(unsafe { bindings::security_binder_transfer_binder(from.as_ptr(), to.as_ptr()) }) +} + +/// Calls the security modules to determine if task `from` is allowed to send the given file to +/// task `to` (which would get its own file descriptor) through a binder transaction. +#[inline] +pub fn binder_transfer_file(from: &Credential, to: &Credential, file: &File) -> Result { + // SAFETY: `from`, `to` and `file` are valid because the shared references guarantee nonzero + // refcounts. + to_result(unsafe { + bindings::security_binder_transfer_file(from.as_ptr(), to.as_ptr(), file.as_ptr()) + }) +} + +/// A security context string. +/// +/// # Invariants +/// +/// The `ctx` field corresponds to a valid security context as returned by a successful call to +/// `security_secid_to_secctx`, that has not yet been released by `security_release_secctx`. +pub struct SecurityCtx { + ctx: bindings::lsm_context, +} + +impl SecurityCtx { + /// Get the security context given its id. + #[inline] + pub fn from_secid(secid: u32) -> Result<Self> { + // SAFETY: `struct lsm_context` can be initialized to all zeros. + let mut ctx: bindings::lsm_context = unsafe { core::mem::zeroed() }; + + // SAFETY: Just a C FFI call. The pointer is valid for writes. + to_result(unsafe { bindings::security_secid_to_secctx(secid, &mut ctx) })?; + + // INVARIANT: If the above call did not fail, then we have a valid security context. + Ok(Self { ctx }) + } + + /// Returns whether the security context is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.ctx.len == 0 + } + + /// Returns the length of this security context. + #[inline] + pub fn len(&self) -> usize { + self.ctx.len as usize + } + + /// Returns the bytes for this security context. + #[inline] + pub fn as_bytes(&self) -> &[u8] { + let ptr = self.ctx.context; + if ptr.is_null() { + debug_assert_eq!(self.len(), 0); + // We can't pass a null pointer to `slice::from_raw_parts` even if the length is zero. + return &[]; + } + + // SAFETY: The call to `security_secid_to_secctx` guarantees that the pointer is valid for + // `self.len()` bytes. Furthermore, if the length is zero, then we have ensured that the + // pointer is not null. + unsafe { core::slice::from_raw_parts(ptr.cast(), self.len()) } + } +} + +impl Drop for SecurityCtx { + #[inline] + fn drop(&mut self) { + // SAFETY: By the invariant of `Self`, this releases an lsm context that came from a + // successful call to `security_secid_to_secctx` and has not yet been released. + unsafe { bindings::security_release_secctx(&mut self.ctx) }; + } +} diff --git a/rust/kernel/seq_file.rs b/rust/kernel/seq_file.rs new file mode 100644 index 000000000000..855e533813a6 --- /dev/null +++ b/rust/kernel/seq_file.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Seq file bindings. +//! +//! C header: [`include/linux/seq_file.h`](srctree/include/linux/seq_file.h) + +use crate::{bindings, c_str, fmt, str::CStrExt as _, types::NotThreadSafe, types::Opaque}; + +/// A utility for generating the contents of a seq file. 
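+/// +/// # Examples +/// +/// A sketch of a `show()`-style callback that writes via [`seq_print`]; the callback +/// registration is assumed to happen elsewhere: +/// +/// ```ignore +/// use kernel::seq_file::SeqFile; +/// use kernel::seq_print; +/// +/// fn show(m: &SeqFile, value: u32) { +/// seq_print!(m, "value: {value}\n"); +/// } +/// ```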
+#[repr(transparent)] +pub struct SeqFile { + inner: Opaque<bindings::seq_file>, + _not_send: NotThreadSafe, +} + +impl SeqFile { + /// Creates a new [`SeqFile`] from a raw pointer. + /// + /// # Safety + /// + /// The caller must ensure that for the duration of `'a` the following is satisfied: + /// * The pointer points at a valid `struct seq_file`. + /// * The `struct seq_file` is not accessed from any other thread. + pub unsafe fn from_raw<'a>(ptr: *mut bindings::seq_file) -> &'a SeqFile { + // SAFETY: The caller ensures that the reference is valid for 'a. There's no way to trigger + // a data race by using the `&SeqFile` since this is the only thread accessing the seq_file. + // + // CAST: The layout of `struct seq_file` and `SeqFile` is compatible. + unsafe { &*ptr.cast() } + } + + /// Used by the [`seq_print`] macro. + #[inline] + pub fn call_printf(&self, args: fmt::Arguments<'_>) { + // SAFETY: Passing a void pointer to `Arguments` is valid for `%pA`. + unsafe { + bindings::seq_printf( + self.inner.get(), + c_str!("%pA").as_char_ptr(), + core::ptr::from_ref(&args).cast::<crate::ffi::c_void>(), + ); + } + } +} + +/// Write to a [`SeqFile`] with the ordinary Rust formatting syntax. +#[macro_export] +macro_rules! seq_print { + ($m:expr, $($arg:tt)+) => ( + $m.call_printf($crate::prelude::fmt!($($arg)+)) + ); +} +pub use seq_print; diff --git a/rust/kernel/sizes.rs b/rust/kernel/sizes.rs new file mode 100644 index 000000000000..661e680d9330 --- /dev/null +++ b/rust/kernel/sizes.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Commonly used sizes. +//! +//! C headers: [`include/linux/sizes.h`](srctree/include/linux/sizes.h). + +/// 0x00000400 +pub const SZ_1K: usize = bindings::SZ_1K as usize; +/// 0x00000800 +pub const SZ_2K: usize = bindings::SZ_2K as usize; +/// 0x00001000 +pub const SZ_4K: usize = bindings::SZ_4K as usize; +/// 0x00002000 +pub const SZ_8K: usize = bindings::SZ_8K as usize; +/// 0x00004000 +pub const SZ_16K: usize = bindings::SZ_16K as usize; +/// 0x00008000 +pub const SZ_32K: usize = bindings::SZ_32K as usize; +/// 0x00010000 +pub const SZ_64K: usize = bindings::SZ_64K as usize; +/// 0x00020000 +pub const SZ_128K: usize = bindings::SZ_128K as usize; +/// 0x00040000 +pub const SZ_256K: usize = bindings::SZ_256K as usize; +/// 0x00080000 +pub const SZ_512K: usize = bindings::SZ_512K as usize; +/// 0x00100000 +pub const SZ_1M: usize = bindings::SZ_1M as usize; +/// 0x00200000 +pub const SZ_2M: usize = bindings::SZ_2M as usize; +/// 0x00400000 +pub const SZ_4M: usize = bindings::SZ_4M as usize; +/// 0x00800000 +pub const SZ_8M: usize = bindings::SZ_8M as usize; +/// 0x01000000 +pub const SZ_16M: usize = bindings::SZ_16M as usize; +/// 0x02000000 +pub const SZ_32M: usize = bindings::SZ_32M as usize; +/// 0x04000000 +pub const SZ_64M: usize = bindings::SZ_64M as usize; +/// 0x08000000 +pub const SZ_128M: usize = bindings::SZ_128M as usize; +/// 0x10000000 +pub const SZ_256M: usize = bindings::SZ_256M as usize; +/// 0x20000000 +pub const SZ_512M: usize = bindings::SZ_512M as usize; +/// 0x40000000 +pub const SZ_1G: usize = bindings::SZ_1G as usize; +/// 0x80000000 +pub const SZ_2G: usize = bindings::SZ_2G as usize; diff --git a/rust/kernel/slice.rs b/rust/kernel/slice.rs new file mode 100644 index 000000000000..ca2cde135061 --- /dev/null +++ b/rust/kernel/slice.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Additional (and temporary) slice helpers. 
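+//! +//! For example, flattening a nested slice with the extension trait below (a sketch; on +//! compilers where `CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED` is set the inherent method applies +//! instead): +//! +//! ```ignore +//! use kernel::slice::AsFlattened; +//! +//! let grid = [[1u8, 2], [3, 4]]; +//! assert_eq!(grid.as_flattened(), &[1, 2, 3, 4]); +//! ```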
+ +/// Extension trait providing a portable version of [`as_flattened`] and +/// [`as_flattened_mut`]. +/// +/// In Rust 1.80, the previously unstable `slice::flatten` family of methods +/// has been stabilized and renamed from `flatten` to `as_flattened`. +/// +/// This creates an issue for as long as the MSRV is < 1.80, as the same functionality is provided +/// by different methods depending on the compiler version. +/// +/// This extension trait solves this by abstracting `as_flattened` and calling the correct method +/// depending on the Rust version. +/// +/// This trait can be removed once the MSRV passes 1.80. +/// +/// [`as_flattened`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened +/// [`as_flattened_mut`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened_mut +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +pub trait AsFlattened<T> { + /// Takes a `&[[T; N]]` and flattens it to a `&[T]`. + /// + /// This is a portable layer on top of [`as_flattened`]; see its documentation for details. + /// + /// [`as_flattened`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened + fn as_flattened(&self) -> &[T]; + + /// Takes a `&mut [[T; N]]` and flattens it to a `&mut [T]`. + /// + /// This is a portable layer on top of [`as_flattened_mut`]; see its documentation for details. + /// + /// [`as_flattened_mut`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened_mut + fn as_flattened_mut(&mut self) -> &mut [T]; +} + +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +impl<T, const N: usize> AsFlattened<T> for [[T; N]] { + #[allow(clippy::incompatible_msrv)] + fn as_flattened(&self) -> &[T] { + self.flatten() + } + + #[allow(clippy::incompatible_msrv)] + fn as_flattened_mut(&mut self) -> &mut [T] { + self.flatten_mut() + } +} diff --git a/rust/kernel/static_assert.rs b/rust/kernel/static_assert.rs index 3115ee0ba8e9..a57ba14315a0 100644 --- a/rust/kernel/static_assert.rs +++ b/rust/kernel/static_assert.rs @@ -6,6 +6,10 @@ /// /// Similar to C11 [`_Static_assert`] and C++11 [`static_assert`]. /// +/// An optional panic message can be supplied after the expression. +/// Currently only a string literal without formatting is supported +/// due to constness limitations of the [`assert!`] macro. +/// /// The feature may be added to Rust in the future: see [RFC 2790]. /// /// [`_Static_assert`]: https://en.cppreference.com/w/c/language/_Static_assert @@ -25,10 +29,11 @@ /// x + 2 /// } /// static_assert!(f(40) == 42); +/// static_assert!(f(40) == 42, "f(x) must add 2 to the given input."); /// ``` #[macro_export] macro_rules! static_assert { - ($condition:expr) => { - const _: () = core::assert!($condition); + ($condition:expr $(,$arg:literal)?) => { + const _: () = ::core::assert!($condition $(,$arg)?); }; } diff --git a/rust/kernel/std_vendor.rs b/rust/kernel/std_vendor.rs index b3e68b24a8c6..abbab5050cc5 100644 --- a/rust/kernel/std_vendor.rs +++ b/rust/kernel/std_vendor.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT +//! Rust standard library vendored code. +//! //! The contents of this file come from the Rust standard library, hosted in //! the <https://github.com/rust-lang/rust> repository, licensed under //! "Apache-2.0 OR MIT" and adapted for kernel use.
For copyright details, @@ -14,9 +16,9 @@ /// /// ```rust /// let a = 2; -/// # #[allow(clippy::dbg_macro)] +/// # #[expect(clippy::disallowed_macros)] /// let b = dbg!(a * 2) + 1; -/// // ^-- prints: [src/main.rs:2] a * 2 = 4 +/// // ^-- prints: [src/main.rs:3:9] a * 2 = 4 /// assert_eq!(b, 5); /// ``` /// @@ -52,7 +54,7 @@ /// With a method call: /// /// ```rust -/// # #[allow(clippy::dbg_macro)] +/// # #[expect(clippy::disallowed_macros)] /// fn foo(n: usize) { /// if dbg!(n.checked_sub(4)).is_some() { /// // ... @@ -65,14 +67,13 @@ /// This prints to the kernel log: /// /// ```text,ignore -/// [src/main.rs:4] n.checked_sub(4) = None +/// [src/main.rs:3:8] n.checked_sub(4) = None /// ``` /// /// Naive factorial implementation: /// /// ```rust -/// # #[allow(clippy::dbg_macro)] -/// # { +/// # #![expect(clippy::disallowed_macros)] /// fn factorial(n: u32) -> u32 { /// if dbg!(n <= 1) { /// dbg!(1) @@ -82,21 +83,20 @@ /// } /// /// dbg!(factorial(4)); -/// # } /// ``` /// /// This prints to the kernel log: /// /// ```text,ignore -/// [src/main.rs:3] n <= 1 = false -/// [src/main.rs:3] n <= 1 = false -/// [src/main.rs:3] n <= 1 = false -/// [src/main.rs:3] n <= 1 = true -/// [src/main.rs:4] 1 = 1 -/// [src/main.rs:5] n * factorial(n - 1) = 2 -/// [src/main.rs:5] n * factorial(n - 1) = 6 -/// [src/main.rs:5] n * factorial(n - 1) = 24 -/// [src/main.rs:11] factorial(4) = 24 +/// [src/main.rs:3:8] n <= 1 = false +/// [src/main.rs:3:8] n <= 1 = false +/// [src/main.rs:3:8] n <= 1 = false +/// [src/main.rs:3:8] n <= 1 = true +/// [src/main.rs:4:9] 1 = 1 +/// [src/main.rs:5:9] n * factorial(n - 1) = 2 +/// [src/main.rs:5:9] n * factorial(n - 1) = 6 +/// [src/main.rs:5:9] n * factorial(n - 1) = 24 +/// [src/main.rs:11:1] factorial(4) = 24 /// ``` /// /// The `dbg!(..)` macro moves the input: @@ -118,7 +118,7 @@ /// a tuple (and return it, too): /// /// ``` -/// # #[allow(clippy::dbg_macro)] +/// # #![expect(clippy::disallowed_macros)] /// assert_eq!(dbg!(1usize, 2u32), (1, 2)); /// ``` /// @@ -127,16 +127,16 @@ /// invocations. You can use a 1-tuple directly if you need one: /// /// ``` -/// # #[allow(clippy::dbg_macro)] -/// # { +/// # #![expect(clippy::disallowed_macros)] /// assert_eq!(1, dbg!(1u32,)); // trailing comma ignored /// assert_eq!((1,), dbg!((1u32,))); // 1-tuple -/// # } /// ``` /// /// [`std::dbg`]: https://doc.rust-lang.org/std/macro.dbg.html /// [`eprintln`]: https://doc.rust-lang.org/std/macro.eprintln.html -/// [`printk`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html +/// [`printk`]: https://docs.kernel.org/core-api/printk-basics.html +/// [`pr_info`]: crate::pr_info! +/// [`pr_debug`]: crate::pr_debug! #[macro_export] macro_rules! dbg { // NOTE: We cannot use `concat!` to make a static string as a format argument @@ -144,15 +144,16 @@ macro_rules! dbg { // `$val` expression could be a block (`{ .. }`), in which case the `pr_info!` // will be malformed. () => { - $crate::pr_info!("[{}:{}]\n", ::core::file!(), ::core::line!()) + $crate::pr_info!("[{}:{}:{}]\n", ::core::file!(), ::core::line!(), ::core::column!()) }; ($val:expr $(,)?) 
=> { // Use of `match` here is intentional because it affects the lifetimes - // of temporaries - https://stackoverflow.com/a/48732525/1063961 + // of temporaries - <https://stackoverflow.com/a/48732525/1063961> match $val { tmp => { - $crate::pr_info!("[{}:{}] {} = {:#?}\n", - ::core::file!(), ::core::line!(), ::core::stringify!($val), &tmp); + $crate::pr_info!("[{}:{}:{}] {} = {:#?}\n", + ::core::file!(), ::core::line!(), ::core::column!(), + ::core::stringify!($val), &tmp); tmp } } diff --git a/rust/kernel/str.rs b/rust/kernel/str.rs index b771310fa4a4..fa87779d2253 100644 --- a/rust/kernel/str.rs +++ b/rust/kernel/str.rs @@ -2,19 +2,166 @@ //! String representations. -use alloc::vec::Vec; -use core::fmt::{self, Write}; -use core::ops::{self, Deref, Index}; - use crate::{ - bindings, - error::{code::*, Error}, + alloc::{flags::*, AllocError, KVec}, + error::{to_result, Result}, + fmt::{self, Write}, + prelude::*, +}; +use core::{ + marker::PhantomData, + ops::{Deref, DerefMut, Index}, }; +pub use crate::prelude::CStr; + +pub mod parse_int; + /// Byte string without UTF-8 validity guarantee. -/// -/// `BStr` is simply an alias to `[u8]`, but has a more evident semantical meaning. -pub type BStr = [u8]; +#[repr(transparent)] +pub struct BStr([u8]); + +impl BStr { + /// Returns the length of this string. + #[inline] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if the string is empty. + #[inline] + pub const fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Creates a [`BStr`] from a `[u8]`. + #[inline] + pub const fn from_bytes(bytes: &[u8]) -> &Self { + // SAFETY: `BStr` is transparent to `[u8]`. + unsafe { &*(core::ptr::from_ref(bytes) as *const BStr) } + } + + /// Strip a prefix from `self`. Delegates to [`slice::strip_prefix`]. + /// + /// # Examples + /// + /// ``` + /// # use kernel::b_str; + /// assert_eq!(Some(b_str!("bar")), b_str!("foobar").strip_prefix(b_str!("foo"))); + /// assert_eq!(None, b_str!("foobar").strip_prefix(b_str!("bar"))); + /// assert_eq!(Some(b_str!("foobar")), b_str!("foobar").strip_prefix(b_str!(""))); + /// assert_eq!(Some(b_str!("")), b_str!("foobar").strip_prefix(b_str!("foobar"))); + /// ``` + pub fn strip_prefix(&self, pattern: impl AsRef<Self>) -> Option<&BStr> { + self.deref() + .strip_prefix(pattern.as_ref().deref()) + .map(Self::from_bytes) + } +} + +impl fmt::Display for BStr { + /// Formats printable ASCII characters, escaping the rest. + /// + /// ``` + /// # use kernel::{prelude::fmt, b_str, str::{BStr, CString}}; + /// let ascii = b_str!("Hello, BStr!"); + /// let s = CString::try_from_fmt(fmt!("{ascii}"))?; + /// assert_eq!(s.to_bytes(), "Hello, BStr!".as_bytes()); + /// + /// let non_ascii = b_str!("🦀"); + /// let s = CString::try_from_fmt(fmt!("{non_ascii}"))?; + /// assert_eq!(s.to_bytes(), "\\xf0\\x9f\\xa6\\x80".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) + /// ``` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for &b in &self.0 { + match b { + // Common escape codes. + b'\t' => f.write_str("\\t")?, + b'\n' => f.write_str("\\n")?, + b'\r' => f.write_str("\\r")?, + // Printable characters. + 0x20..=0x7e => f.write_char(b as char)?, + _ => write!(f, "\\x{b:02x}")?, + } + } + Ok(()) + } +} + +impl fmt::Debug for BStr { + /// Formats printable ASCII characters with a double quote on either end, + /// escaping the rest. + /// + /// ``` + /// # use kernel::{prelude::fmt, b_str, str::{BStr, CString}}; + /// // Embedded double quotes are escaped. 
+ /// let ascii = b_str!("Hello, \"BStr\"!"); + /// let s = CString::try_from_fmt(fmt!("{ascii:?}"))?; + /// assert_eq!(s.to_bytes(), "\"Hello, \\\"BStr\\\"!\"".as_bytes()); + /// + /// let non_ascii = b_str!("😺"); + /// let s = CString::try_from_fmt(fmt!("{non_ascii:?}"))?; + /// assert_eq!(s.to_bytes(), "\"\\xf0\\x9f\\x98\\xba\"".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) + /// ``` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_char('"')?; + for &b in &self.0 { + match b { + // Common escape codes. + b'\t' => f.write_str("\\t")?, + b'\n' => f.write_str("\\n")?, + b'\r' => f.write_str("\\r")?, + // String escape characters. + b'\"' => f.write_str("\\\"")?, + b'\\' => f.write_str("\\\\")?, + // Printable characters. + 0x20..=0x7e => f.write_char(b as char)?, + _ => write!(f, "\\x{b:02x}")?, + } + } + f.write_char('"') + } +} + +impl Deref for BStr { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PartialEq for BStr { + fn eq(&self, other: &Self) -> bool { + self.deref().eq(other.deref()) + } +} + +impl<Idx> Index<Idx> for BStr +where + [u8]: Index<Idx, Output = [u8]>, +{ + type Output = Self; + + fn index(&self, index: Idx) -> &Self::Output { + BStr::from_bytes(&self.0[index]) + } +} + +impl AsRef<BStr> for [u8] { + fn as_ref(&self) -> &BStr { + BStr::from_bytes(self) + } +} + +impl AsRef<BStr> for BStr { + fn as_ref(&self) -> &BStr { + self + } +} /// Creates a new [`BStr`] from a string literal. /// @@ -32,60 +179,28 @@ pub type BStr = [u8]; macro_rules! b_str { ($str:literal) => {{ const S: &'static str = $str; - const C: &'static $crate::str::BStr = S.as_bytes(); + const C: &'static $crate::str::BStr = $crate::str::BStr::from_bytes(S.as_bytes()); C }}; } -/// Possible errors when using conversion functions in [`CStr`]. -#[derive(Debug, Clone, Copy)] -pub enum CStrConvertError { - /// Supplied bytes contain an interior `NUL`. - InteriorNul, - - /// Supplied bytes are not terminated by `NUL`. - NotNulTerminated, -} - -impl From<CStrConvertError> for Error { - #[inline] - fn from(_: CStrConvertError) -> Error { - EINVAL - } +/// Returns a C pointer to the string. +// It is a free function rather than a method on an extension trait because: +// +// - error[E0379]: functions in trait impls cannot be declared const +#[inline] +pub const fn as_char_ptr_in_const_context(c_str: &CStr) -> *const c_char { + c_str.as_ptr().cast() } -/// A string that is guaranteed to have exactly one `NUL` byte, which is at the -/// end. -/// -/// Used for interoperability with kernel APIs that take C strings. -#[repr(transparent)] -pub struct CStr([u8]); - -impl CStr { - /// Returns the length of this string excluding `NUL`. - #[inline] - pub const fn len(&self) -> usize { - self.len_with_nul() - 1 - } - - /// Returns the length of this string with `NUL`. - #[inline] - pub const fn len_with_nul(&self) -> usize { - // SAFETY: This is one of the invariant of `CStr`. - // We add a `unreachable_unchecked` here to hint the optimizer that - // the value returned from this function is non-zero. - if self.0.is_empty() { - unsafe { core::hint::unreachable_unchecked() }; - } - self.0.len() - } +mod private { + pub trait Sealed {} - /// Returns `true` if the string only includes `NUL`. - #[inline] - pub const fn is_empty(&self) -> bool { - self.len() == 0 - } + impl Sealed for super::CStr {} +} +/// Extensions to [`CStr`]. +pub trait CStrExt: private::Sealed { /// Wraps a raw C string pointer. 
/// /// # Safety @@ -93,232 +208,171 @@ impl CStr { /// `ptr` must be a valid pointer to a `NUL`-terminated C string, and it must /// last at least `'a`. When `CStr` is alive, the memory pointed by `ptr` /// must not be mutated. - #[inline] - pub unsafe fn from_char_ptr<'a>(ptr: *const core::ffi::c_char) -> &'a Self { - // SAFETY: The safety precondition guarantees `ptr` is a valid pointer - // to a `NUL`-terminated C string. - let len = unsafe { bindings::strlen(ptr) } + 1; - // SAFETY: Lifetime guaranteed by the safety precondition. - let bytes = unsafe { core::slice::from_raw_parts(ptr as _, len as _) }; - // SAFETY: As `len` is returned by `strlen`, `bytes` does not contain interior `NUL`. - // As we have added 1 to `len`, the last byte is known to be `NUL`. - unsafe { Self::from_bytes_with_nul_unchecked(bytes) } - } - - /// Creates a [`CStr`] from a `[u8]`. - /// - /// The provided slice must be `NUL`-terminated, does not contain any - /// interior `NUL` bytes. - pub const fn from_bytes_with_nul(bytes: &[u8]) -> Result<&Self, CStrConvertError> { - if bytes.is_empty() { - return Err(CStrConvertError::NotNulTerminated); - } - if bytes[bytes.len() - 1] != 0 { - return Err(CStrConvertError::NotNulTerminated); - } - let mut i = 0; - // `i + 1 < bytes.len()` allows LLVM to optimize away bounds checking, - // while it couldn't optimize away bounds checks for `i < bytes.len() - 1`. - while i + 1 < bytes.len() { - if bytes[i] == 0 { - return Err(CStrConvertError::InteriorNul); - } - i += 1; - } - // SAFETY: We just checked that all properties hold. - Ok(unsafe { Self::from_bytes_with_nul_unchecked(bytes) }) - } + // This function exists to paper over the fact that `CStr::from_ptr` takes a `*const + // core::ffi::c_char` rather than a `*const crate::ffi::c_char`. + unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self; - /// Creates a [`CStr`] from a `[u8]` without performing any additional - /// checks. + /// Creates a mutable [`CStr`] from a `[u8]` without performing any + /// additional checks. /// /// # Safety /// /// `bytes` *must* end with a `NUL` byte, and should only have a single /// `NUL` byte (or the string will be truncated). - #[inline] - pub const unsafe fn from_bytes_with_nul_unchecked(bytes: &[u8]) -> &CStr { - // SAFETY: Properties of `bytes` guaranteed by the safety precondition. - unsafe { core::mem::transmute(bytes) } - } + unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut Self; /// Returns a C pointer to the string. - #[inline] - pub const fn as_char_ptr(&self) -> *const core::ffi::c_char { - self.0.as_ptr() as _ - } + // This function exists to paper over the fact that `CStr::as_ptr` returns a `*const + // core::ffi::c_char` rather than a `*const crate::ffi::c_char`. + fn as_char_ptr(&self) -> *const c_char; - /// Convert the string to a byte slice without the trailing 0 byte. - #[inline] - pub fn as_bytes(&self) -> &[u8] { - &self.0[..self.len()] - } + /// Convert this [`CStr`] into a [`CString`] by allocating memory and + /// copying over the string data. + fn to_cstring(&self) -> Result<CString, AllocError>; - /// Convert the string to a byte slice containing the trailing 0 byte. - #[inline] - pub const fn as_bytes_with_nul(&self) -> &[u8] { - &self.0 - } + /// Converts this [`CStr`] to its ASCII lower case equivalent in-place. + /// + /// ASCII letters 'A' to 'Z' are mapped to 'a' to 'z', + /// but non-ASCII letters are unchanged. 
+ /// + /// To return a new lowercased value without modifying the existing one, use + /// [`to_ascii_lowercase()`]. + /// + /// [`to_ascii_lowercase()`]: #method.to_ascii_lowercase + fn make_ascii_lowercase(&mut self); - /// Yields a [`&str`] slice if the [`CStr`] contains valid UTF-8. + /// Converts this [`CStr`] to its ASCII upper case equivalent in-place. /// - /// If the contents of the [`CStr`] are valid UTF-8 data, this - /// function will return the corresponding [`&str`] slice. Otherwise, - /// it will return an error with details of where UTF-8 validation failed. + /// ASCII letters 'a' to 'z' are mapped to 'A' to 'Z', + /// but non-ASCII letters are unchanged. /// - /// # Examples + /// To return a new uppercased value without modifying the existing one, use + /// [`to_ascii_uppercase()`]. /// - /// ``` - /// # use kernel::str::CStr; - /// let cstr = CStr::from_bytes_with_nul(b"foo\0").unwrap(); - /// assert_eq!(cstr.to_str(), Ok("foo")); - /// ``` - #[inline] - pub fn to_str(&self) -> Result<&str, core::str::Utf8Error> { - core::str::from_utf8(self.as_bytes()) - } + /// [`to_ascii_uppercase()`]: #method.to_ascii_uppercase + fn make_ascii_uppercase(&mut self); - /// Unsafely convert this [`CStr`] into a [`&str`], without checking for - /// valid UTF-8. + /// Returns a copy of this [`CString`] where each character is mapped to its + /// ASCII lower case equivalent. /// - /// # Safety + /// ASCII letters 'A' to 'Z' are mapped to 'a' to 'z', + /// but non-ASCII letters are unchanged. /// - /// The contents must be valid UTF-8. + /// To lowercase the value in-place, use [`make_ascii_lowercase`]. /// - /// # Examples + /// [`make_ascii_lowercase`]: str::make_ascii_lowercase + fn to_ascii_lowercase(&self) -> Result<CString, AllocError>; + + /// Returns a copy of this [`CString`] where each character is mapped to its + /// ASCII upper case equivalent. /// - /// ``` - /// # use kernel::c_str; - /// # use kernel::str::CStr; - /// // SAFETY: String literals are guaranteed to be valid UTF-8 - /// // by the Rust compiler. - /// let bar = c_str!("ツ"); - /// assert_eq!(unsafe { bar.as_str_unchecked() }, "ツ"); - /// ``` - #[inline] - pub unsafe fn as_str_unchecked(&self) -> &str { - unsafe { core::str::from_utf8_unchecked(self.as_bytes()) } - } + /// ASCII letters 'a' to 'z' are mapped to 'A' to 'Z', + /// but non-ASCII letters are unchanged. + /// + /// To uppercase the value in-place, use [`make_ascii_uppercase`]. + /// + /// [`make_ascii_uppercase`]: str::make_ascii_uppercase + fn to_ascii_uppercase(&self) -> Result<CString, AllocError>; } impl fmt::Display for CStr { /// Formats printable ASCII characters, escaping the rest. 
/// /// ``` - /// # use kernel::c_str; + /// # use kernel::prelude::fmt; /// # use kernel::str::CStr; /// # use kernel::str::CString; - /// let penguin = c_str!("🐧"); - /// let s = CString::try_from_fmt(fmt!("{}", penguin)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "\\xf0\\x9f\\x90\\xa7\0".as_bytes()); + /// let penguin = c"🐧"; + /// let s = CString::try_from_fmt(fmt!("{penguin}"))?; + /// assert_eq!(s.to_bytes_with_nul(), "\\xf0\\x9f\\x90\\xa7\0".as_bytes()); /// - /// let ascii = c_str!("so \"cool\""); - /// let s = CString::try_from_fmt(fmt!("{}", ascii)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "so \"cool\"\0".as_bytes()); + /// let ascii = c"so \"cool\""; + /// let s = CString::try_from_fmt(fmt!("{ascii}"))?; + /// assert_eq!(s.to_bytes_with_nul(), "so \"cool\"\0".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) /// ``` fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for &c in self.as_bytes() { + for &c in self.to_bytes() { if (0x20..0x7f).contains(&c) { // Printable character. f.write_char(c as char)?; } else { - write!(f, "\\x{:02x}", c)?; + write!(f, "\\x{c:02x}")?; } } Ok(()) } } -impl fmt::Debug for CStr { - /// Formats printable ASCII characters with a double quote on either end, escaping the rest. - /// - /// ``` - /// # use kernel::c_str; - /// # use kernel::str::CStr; - /// # use kernel::str::CString; - /// let penguin = c_str!("🐧"); - /// let s = CString::try_from_fmt(fmt!("{:?}", penguin)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "\"\\xf0\\x9f\\x90\\xa7\"\0".as_bytes()); - /// - /// // Embedded double quotes are escaped. - /// let ascii = c_str!("so \"cool\""); - /// let s = CString::try_from_fmt(fmt!("{:?}", ascii)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "\"so \\\"cool\\\"\"\0".as_bytes()); - /// ``` - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("\"")?; - for &c in self.as_bytes() { - match c { - // Printable characters. - b'\"' => f.write_str("\\\"")?, - 0x20..=0x7e => f.write_char(c as char)?, - _ => write!(f, "\\x{:02x}", c)?, - } - } - f.write_str("\"") - } +/// Converts a mutable C string to a mutable byte slice. +/// +/// # Safety +/// +/// The caller must ensure that the slice ends in a NUL byte and contains no other NUL bytes before +/// the borrow ends and the underlying [`CStr`] is used. +unsafe fn to_bytes_mut(s: &mut CStr) -> &mut [u8] { + // SAFETY: the cast from `&CStr` to `&[u8]` is safe since `CStr` has the same layout as `&[u8]` + // (this is technically not guaranteed, but we rely on it here). The pointer dereference is + // safe since it comes from a mutable reference which is guaranteed to be valid for writes. + unsafe { &mut *(core::ptr::from_mut(s) as *mut [u8]) } } -impl AsRef<BStr> for CStr { +impl CStrExt for CStr { #[inline] - fn as_ref(&self) -> &BStr { - self.as_bytes() + unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self { + // SAFETY: The safety preconditions are the same as for `CStr::from_ptr`. + unsafe { CStr::from_ptr(ptr.cast()) } } -} -impl Deref for CStr { - type Target = BStr; + #[inline] + unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: the cast from `&[u8]` to `&CStr` is safe since the properties of `bytes` are + // guaranteed by the safety precondition and `CStr` has the same layout as `&[u8]` (this is + // technically not guaranteed, but we rely on it here). The pointer dereference is safe + // since it comes from a mutable reference which is guaranteed to be valid for writes. 
+ unsafe { &mut *(core::ptr::from_mut(bytes) as *mut CStr) } + } #[inline] - fn deref(&self) -> &Self::Target { - self.as_bytes() + fn as_char_ptr(&self) -> *const c_char { + self.as_ptr().cast() + } + + fn to_cstring(&self) -> Result<CString, AllocError> { + CString::try_from(self) } -} -impl Index<ops::RangeFrom<usize>> for CStr { - type Output = CStr; + fn make_ascii_lowercase(&mut self) { + // SAFETY: This doesn't introduce or remove NUL bytes in the C string. + unsafe { to_bytes_mut(self) }.make_ascii_lowercase(); + } - #[inline] - fn index(&self, index: ops::RangeFrom<usize>) -> &Self::Output { - // Delegate bounds checking to slice. - // Assign to _ to mute clippy's unnecessary operation warning. - let _ = &self.as_bytes()[index.start..]; - // SAFETY: We just checked the bounds. - unsafe { Self::from_bytes_with_nul_unchecked(&self.0[index.start..]) } + fn make_ascii_uppercase(&mut self) { + // SAFETY: This doesn't introduce or remove NUL bytes in the C string. + unsafe { to_bytes_mut(self) }.make_ascii_uppercase(); } -} -impl Index<ops::RangeFull> for CStr { - type Output = CStr; + fn to_ascii_lowercase(&self) -> Result<CString, AllocError> { + let mut s = self.to_cstring()?; - #[inline] - fn index(&self, _index: ops::RangeFull) -> &Self::Output { - self + s.make_ascii_lowercase(); + + Ok(s) } -} -mod private { - use core::ops; + fn to_ascii_uppercase(&self) -> Result<CString, AllocError> { + let mut s = self.to_cstring()?; - // Marker trait for index types that can be forward to `BStr`. - pub trait CStrIndex {} + s.make_ascii_uppercase(); - impl CStrIndex for usize {} - impl CStrIndex for ops::Range<usize> {} - impl CStrIndex for ops::RangeInclusive<usize> {} - impl CStrIndex for ops::RangeToInclusive<usize> {} + Ok(s) + } } -impl<Idx> Index<Idx> for CStr -where - Idx: private::CStrIndex, - BStr: Index<Idx>, -{ - type Output = <BStr as Index<Idx>>::Output; - +impl AsRef<BStr> for CStr { #[inline] - fn index(&self, index: Idx) -> &Self::Output { - &self.as_bytes()[index] + fn as_ref(&self) -> &BStr { + BStr::from_bytes(self.to_bytes()) } } @@ -345,32 +399,116 @@ macro_rules! c_str { }}; } -#[cfg(test)] +#[kunit_tests(rust_kernel_str)] mod tests { use super::*; + impl From<core::ffi::FromBytesWithNulError> for Error { + #[inline] + fn from(_: core::ffi::FromBytesWithNulError) -> Error { + EINVAL + } + } + + macro_rules! format { + ($($f:tt)*) => ({ + CString::try_from_fmt(fmt!($($f)*))?.to_str()? 
+ }) + } + + const ALL_ASCII_CHARS: &str = + "\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\ + \\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f \ + !\"#$%&'()*+,-./0123456789:;<=>?@\ + ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\ + \\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\ + \\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\ + \\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\ + \\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\ + \\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\ + \\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\ + \\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\ + \\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\\xff"; + #[test] - fn test_cstr_to_str() { - let good_bytes = b"\xf0\x9f\xa6\x80\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes).unwrap(); - let checked_str = checked_cstr.to_str().unwrap(); + fn test_cstr_to_str() -> Result { + let cstr = c"\xf0\x9f\xa6\x80"; + let checked_str = cstr.to_str()?; assert_eq!(checked_str, "🦀"); + Ok(()) + } + + #[test] + fn test_cstr_to_str_invalid_utf8() -> Result { + let cstr = c"\xc3\x28"; + assert!(cstr.to_str().is_err()); + Ok(()) } #[test] - #[should_panic] - fn test_cstr_to_str_panic() { - let bad_bytes = b"\xc3\x28\0"; - let checked_cstr = CStr::from_bytes_with_nul(bad_bytes).unwrap(); - checked_cstr.to_str().unwrap(); + fn test_cstr_display() -> Result { + let hello_world = c"hello, world!"; + assert_eq!(format!("{hello_world}"), "hello, world!"); + let non_printables = c"\x01\x09\x0a"; + assert_eq!(format!("{non_printables}"), "\\x01\\x09\\x0a"); + let non_ascii = c"d\xe9j\xe0 vu"; + assert_eq!(format!("{non_ascii}"), "d\\xe9j\\xe0 vu"); + let good_bytes = c"\xf0\x9f\xa6\x80"; + assert_eq!(format!("{good_bytes}"), "\\xf0\\x9f\\xa6\\x80"); + Ok(()) } #[test] - fn test_cstr_as_str_unchecked() { - let good_bytes = b"\xf0\x9f\x90\xA7\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes).unwrap(); - let unchecked_str = unsafe { checked_cstr.as_str_unchecked() }; - assert_eq!(unchecked_str, "🐧"); + fn test_cstr_display_all_bytes() -> Result { + let mut bytes: [u8; 256] = [0; 256]; + // fill `bytes` with [1..=255] + [0] + for i in u8::MIN..=u8::MAX { + bytes[i as usize] = i.wrapping_add(1); + } + let cstr = CStr::from_bytes_with_nul(&bytes)?; + assert_eq!(format!("{cstr}"), ALL_ASCII_CHARS); + Ok(()) + } + + #[test] + fn test_cstr_debug() -> Result { + let hello_world = c"hello, world!"; + assert_eq!(format!("{hello_world:?}"), "\"hello, world!\""); + let non_printables = c"\x01\x09\x0a"; + assert_eq!(format!("{non_printables:?}"), "\"\\x01\\t\\n\""); + let non_ascii = c"d\xe9j\xe0 vu"; + assert_eq!(format!("{non_ascii:?}"), "\"d\\xe9j\\xe0 vu\""); + Ok(()) + } + + #[test] + fn test_bstr_display() -> Result { + let hello_world = BStr::from_bytes(b"hello, world!"); + assert_eq!(format!("{hello_world}"), "hello, world!"); + let escapes = BStr::from_bytes(b"_\t_\n_\r_\\_\'_\"_"); + assert_eq!(format!("{escapes}"), "_\\t_\\n_\\r_\\_'_\"_"); + let others = BStr::from_bytes(b"\x01"); + assert_eq!(format!("{others}"), "\\x01"); + let non_ascii = BStr::from_bytes(b"d\xe9j\xe0 vu"); + assert_eq!(format!("{non_ascii}"), "d\\xe9j\\xe0 vu"); + let good_bytes = 
BStr::from_bytes(b"\xf0\x9f\xa6\x80"); + assert_eq!(format!("{good_bytes}"), "\\xf0\\x9f\\xa6\\x80"); + Ok(()) + } + + #[test] + fn test_bstr_debug() -> Result { + let hello_world = BStr::from_bytes(b"hello, world!"); + assert_eq!(format!("{hello_world:?}"), "\"hello, world!\""); + let escapes = BStr::from_bytes(b"_\t_\n_\r_\\_\'_\"_"); + assert_eq!(format!("{escapes:?}"), "\"_\\t_\\n_\\r_\\\\_'_\\\"_\""); + let others = BStr::from_bytes(b"\x01"); + assert_eq!(format!("{others:?}"), "\"\\x01\""); + let non_ascii = BStr::from_bytes(b"d\xe9j\xe0 vu"); + assert_eq!(format!("{non_ascii:?}"), "\"d\\xe9j\\xe0 vu\""); + let good_bytes = BStr::from_bytes(b"\xf0\x9f\xa6\x80"); + assert_eq!(format!("{good_bytes:?}"), "\"\\xf0\\x9f\\xa6\\x80\""); + Ok(()) } } @@ -383,7 +521,7 @@ mod tests { /// /// The memory region between `pos` (inclusive) and `end` (exclusive) is valid for writes if `pos` /// is less than `end`. -pub(crate) struct RawFormatter { +pub struct RawFormatter { // Use `usize` to use `saturating_*` functions. beg: usize, pos: usize, @@ -408,11 +546,11 @@ impl RawFormatter { /// If `pos` is less than `end`, then the region between `pos` (inclusive) and `end` /// (exclusive) must be valid for writes for the lifetime of the returned [`RawFormatter`]. pub(crate) unsafe fn from_ptrs(pos: *mut u8, end: *mut u8) -> Self { - // INVARIANT: The safety requierments guarantee the type invariants. + // INVARIANT: The safety requirements guarantee the type invariants. Self { - beg: pos as _, - pos: pos as _, - end: end as _, + beg: pos as usize, + pos: pos as usize, + end: end as usize, } } @@ -424,7 +562,7 @@ impl RawFormatter { /// for the lifetime of the returned [`RawFormatter`]. pub(crate) unsafe fn from_buffer(buf: *mut u8, len: usize) -> Self { let pos = buf as usize; - // INVARIANT: We ensure that `end` is never less then `buf`, and the safety requirements + // INVARIANT: We ensure that `end` is never less than `buf`, and the safety requirements // guarantees that the memory region is valid for writes. Self { pos, @@ -437,11 +575,11 @@ impl RawFormatter { /// /// N.B. It may point to invalid memory. pub(crate) fn pos(&self) -> *mut u8 { - self.pos as _ + self.pos as *mut u8 } - /// Return the number of bytes written to the formatter. - pub(crate) fn bytes_written(&self) -> usize { + /// Returns the number of bytes written to the formatter. + pub fn bytes_written(&self) -> usize { self.pos - self.beg } } @@ -475,9 +613,9 @@ impl fmt::Write for RawFormatter { /// Allows formatting of [`fmt::Arguments`] into a raw buffer. /// /// Fails if callers attempt to write more than will fit in the buffer. -pub(crate) struct Formatter(RawFormatter); +pub struct Formatter<'a>(RawFormatter, PhantomData<&'a mut ()>); -impl Formatter { +impl Formatter<'_> { /// Creates a new instance of [`Formatter`] with the given buffer. /// /// # Safety @@ -486,11 +624,18 @@ impl Formatter { /// for the lifetime of the returned [`Formatter`]. pub(crate) unsafe fn from_buffer(buf: *mut u8, len: usize) -> Self { // SAFETY: The safety requirements of this function satisfy those of the callee. - Self(unsafe { RawFormatter::from_buffer(buf, len) }) + Self(unsafe { RawFormatter::from_buffer(buf, len) }, PhantomData) + } + + /// Create a new [`Self`] instance. + pub fn new(buffer: &mut [u8]) -> Self { + // SAFETY: `buffer` is valid for writes for the entire length for + // the lifetime of `Self`. 
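+        // The `PhantomData<&'a mut ()>` in `Formatter` ties the exclusive borrow of `buffer` to
+        // the returned value, so the buffer cannot be accessed elsewhere while the `Formatter`
+        // is alive.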
+ unsafe { Formatter::from_buffer(buffer.as_mut_ptr(), buffer.len()) } } } -impl Deref for Formatter { +impl Deref for Formatter<'_> { type Target = RawFormatter; fn deref(&self) -> &Self::Target { @@ -498,7 +643,7 @@ impl Deref for Formatter { } } -impl fmt::Write for Formatter { +impl fmt::Write for Formatter<'_> { fn write_str(&mut self, s: &str) -> fmt::Result { self.0.write_str(s)?; @@ -511,6 +656,132 @@ impl fmt::Write for Formatter { } } +/// A mutable reference to a byte buffer where a string can be written into. +/// +/// The buffer will be automatically null terminated after the last written character. +/// +/// # Invariants +/// +/// * The first byte of `buffer` is always zero. +/// * The length of `buffer` is at least 1. +pub(crate) struct NullTerminatedFormatter<'a> { + buffer: &'a mut [u8], +} + +impl<'a> NullTerminatedFormatter<'a> { + /// Create a new [`Self`] instance. + pub(crate) fn new(buffer: &'a mut [u8]) -> Option<NullTerminatedFormatter<'a>> { + *(buffer.first_mut()?) = 0; + + // INVARIANT: + // - We wrote zero to the first byte above. + // - If buffer was not at least length 1, `buffer.first_mut()` would return None. + Some(Self { buffer }) + } +} + +impl Write for NullTerminatedFormatter<'_> { + fn write_str(&mut self, s: &str) -> fmt::Result { + let bytes = s.as_bytes(); + let len = bytes.len(); + + // We want space for a zero. By type invariant, buffer length is always at least 1, so no + // underflow. + if len > self.buffer.len() - 1 { + return Err(fmt::Error); + } + + let buffer = core::mem::take(&mut self.buffer); + // We break the zero start invariant for a short while. + buffer[..len].copy_from_slice(bytes); + // INVARIANT: We checked above that buffer will have size at least 1 after this assignment. + self.buffer = &mut buffer[len..]; + + // INVARIANT: We write zero to the first byte of the buffer. + self.buffer[0] = 0; + + Ok(()) + } +} + +/// # Safety +/// +/// - `string` must point to a null terminated string that is valid for read. +unsafe fn kstrtobool_raw(string: *const u8) -> Result<bool> { + let mut result: bool = false; + + // SAFETY: + // - By function safety requirement, `string` is a valid null-terminated string. + // - `result` is a valid `bool` that we own. + to_result(unsafe { bindings::kstrtobool(string, &mut result) })?; + Ok(result) +} + +/// Convert common user inputs into boolean values using the kernel's `kstrtobool` function. +/// +/// This routine returns `Ok(bool)` if the first character is one of 'YyTt1NnFf0', or +/// \[oO\]\[NnFf\] for "on" and "off". Otherwise it will return `Err(EINVAL)`. 
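+///
+/// Note that only the leading characters are examined, so trailing characters after a
+/// recognized prefix are ignored (e.g. `"twrong"` parses as `true`, as the examples below show).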
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::str::kstrtobool;
+///
+/// // Lowercase
+/// assert_eq!(kstrtobool(c"true"), Ok(true));
+/// assert_eq!(kstrtobool(c"tr"), Ok(true));
+/// assert_eq!(kstrtobool(c"t"), Ok(true));
+/// assert_eq!(kstrtobool(c"twrong"), Ok(true));
+/// assert_eq!(kstrtobool(c"false"), Ok(false));
+/// assert_eq!(kstrtobool(c"f"), Ok(false));
+/// assert_eq!(kstrtobool(c"yes"), Ok(true));
+/// assert_eq!(kstrtobool(c"no"), Ok(false));
+/// assert_eq!(kstrtobool(c"on"), Ok(true));
+/// assert_eq!(kstrtobool(c"off"), Ok(false));
+///
+/// // Camel case
+/// assert_eq!(kstrtobool(c"True"), Ok(true));
+/// assert_eq!(kstrtobool(c"False"), Ok(false));
+/// assert_eq!(kstrtobool(c"Yes"), Ok(true));
+/// assert_eq!(kstrtobool(c"No"), Ok(false));
+/// assert_eq!(kstrtobool(c"On"), Ok(true));
+/// assert_eq!(kstrtobool(c"Off"), Ok(false));
+///
+/// // All caps
+/// assert_eq!(kstrtobool(c"TRUE"), Ok(true));
+/// assert_eq!(kstrtobool(c"FALSE"), Ok(false));
+/// assert_eq!(kstrtobool(c"YES"), Ok(true));
+/// assert_eq!(kstrtobool(c"NO"), Ok(false));
+/// assert_eq!(kstrtobool(c"ON"), Ok(true));
+/// assert_eq!(kstrtobool(c"OFF"), Ok(false));
+///
+/// // Numeric
+/// assert_eq!(kstrtobool(c"1"), Ok(true));
+/// assert_eq!(kstrtobool(c"0"), Ok(false));
+///
+/// // Invalid input
+/// assert_eq!(kstrtobool(c"invalid"), Err(EINVAL));
+/// assert_eq!(kstrtobool(c"2"), Err(EINVAL));
+/// ```
+pub fn kstrtobool(string: &CStr) -> Result<bool> {
+    // SAFETY:
+    // - The pointer returned by `CStr::as_char_ptr` is guaranteed to be
+    //   null terminated.
+    // - `string` is live and thus the string is valid for read.
+    unsafe { kstrtobool_raw(string.as_char_ptr()) }
+}
+
+/// Convert `&[u8]` to `bool` by deferring to [`kernel::str::kstrtobool`].
+///
+/// Only considers at most the first two bytes of `bytes`.
+pub fn kstrtobool_bytes(bytes: &[u8]) -> Result<bool> {
+    // `kstrtobool` only considers the first two bytes of the input.
+    let stack_string = [*bytes.first().unwrap_or(&0), *bytes.get(1).unwrap_or(&0), 0];
+    // SAFETY: `stack_string` is null terminated and it is live on the stack so
+    // it is valid for read.
+    unsafe { kstrtobool_raw(stack_string.as_ptr()) }
+}
+
 /// An owned string that is guaranteed to have exactly one `NUL` byte, which is at the end.
 ///
 /// Used for interoperability with kernel APIs that take C strings.
@@ -522,21 +793,22 @@ impl fmt::Write for Formatter {
 /// # Examples
 ///
 /// ```
-/// use kernel::str::CString;
+/// use kernel::{str::CString, prelude::fmt};
 ///
-/// let s = CString::try_from_fmt(fmt!("{}{}{}", "abc", 10, 20)).unwrap();
-/// assert_eq!(s.as_bytes_with_nul(), "abc1020\0".as_bytes());
+/// let s = CString::try_from_fmt(fmt!("{}{}{}", "abc", 10, 20))?;
+/// assert_eq!(s.to_bytes_with_nul(), "abc1020\0".as_bytes());
 ///
 /// let tmp = "testing";
-/// let s = CString::try_from_fmt(fmt!("{tmp}{}", 123)).unwrap();
-/// assert_eq!(s.as_bytes_with_nul(), "testing123\0".as_bytes());
+/// let s = CString::try_from_fmt(fmt!("{tmp}{}", 123))?;
+/// assert_eq!(s.to_bytes_with_nul(), "testing123\0".as_bytes());
 ///
 /// // This fails because it has an embedded `NUL` byte.
 /// let s = CString::try_from_fmt(fmt!("a\0b{}", 123));
 /// assert_eq!(s.is_ok(), false);
+/// # Ok::<(), kernel::error::Error>(())
 /// ```
 pub struct CString {
-    buf: Vec<u8>,
+    buf: KVec<u8>,
 }
 
 impl CString {
@@ -549,7 +821,7 @@ impl CString {
         let size = f.bytes_written();
 
         // Allocate a vector with the required number of bytes, and write to it.
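+        // Note that `size` already includes the trailing `NUL` terminator (the underflow check
+        // below relies on it having been written), so this allocation covers the full C string.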
- let mut buf = Vec::try_with_capacity(size)?; + let mut buf = KVec::with_capacity(size, GFP_KERNEL)?; // SAFETY: The buffer stored in `buf` is at least of size `size` and is valid for writes. let mut f = unsafe { Formatter::from_buffer(buf.as_mut_ptr(), size) }; f.write_fmt(args)?; @@ -557,13 +829,13 @@ impl CString { // SAFETY: The number of bytes that can be written to `f` is bounded by `size`, which is // `buf`'s capacity. The contents of the buffer have been initialised by writes to `f`. - unsafe { buf.set_len(f.bytes_written()) }; + unsafe { buf.inc_len(f.bytes_written()) }; // Check that there are no `NUL` bytes before the end. // SAFETY: The buffer is valid for read because `f.bytes_written()` is bounded by `size` // (which the minimum buffer size) and is non-zero (we wrote at least the `NUL` terminator) // so `f.bytes_written() - 1` doesn't underflow. - let ptr = unsafe { bindings::memchr(buf.as_ptr().cast(), 0, (f.bytes_written() - 1) as _) }; + let ptr = unsafe { bindings::memchr(buf.as_ptr().cast(), 0, f.bytes_written() - 1) }; if !ptr.is_null() { return Err(EINVAL); } @@ -584,8 +856,30 @@ impl Deref for CString { } } -/// A convenience alias for [`core::format_args`]. -#[macro_export] -macro_rules! fmt { - ($($f:tt)*) => ( core::format_args!($($f)*) ) +impl DerefMut for CString { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: A `CString` is always NUL-terminated and contains no other + // NUL bytes. + unsafe { CStr::from_bytes_with_nul_unchecked_mut(self.buf.as_mut_slice()) } + } +} + +impl<'a> TryFrom<&'a CStr> for CString { + type Error = AllocError; + + fn try_from(cstr: &'a CStr) -> Result<CString, AllocError> { + let mut buf = KVec::new(); + + buf.extend_from_slice(cstr.to_bytes_with_nul(), GFP_KERNEL)?; + + // INVARIANT: The `CStr` and `CString` types have the same invariants for + // the string data, and we copied it over without changes. + Ok(CString { buf }) + } +} + +impl fmt::Debug for CString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } } diff --git a/rust/kernel/str/parse_int.rs b/rust/kernel/str/parse_int.rs new file mode 100644 index 000000000000..48eb4c202984 --- /dev/null +++ b/rust/kernel/str/parse_int.rs @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Integer parsing functions. +//! +//! Integer parsing functions for parsing signed and unsigned integers +//! potentially prefixed with `0x`, `0o`, or `0b`. + +use crate::prelude::*; +use crate::str::BStr; +use core::ops::Deref; + +// Make `FromStrRadix` a public type with a private name. This seals +// `ParseInt`, that is, prevents downstream users from implementing the +// trait. +mod private { + use crate::prelude::*; + use crate::str::BStr; + + /// Trait that allows parsing a [`&BStr`] to an integer with a radix. + pub trait FromStrRadix: Sized { + /// Parse `src` to [`Self`] using radix `radix`. + fn from_str_radix(src: &BStr, radix: u32) -> Result<Self>; + + /// Tries to convert `value` into [`Self`] and negates the resulting value. + fn from_u64_negated(value: u64) -> Result<Self>; + } +} + +/// Extract the radix from an integer literal optionally prefixed with +/// one of `0x`, `0X`, `0o`, `0O`, `0b`, `0B`, `0`. +fn strip_radix(src: &BStr) -> (u32, &BStr) { + match src.deref() { + [b'0', b'x' | b'X', rest @ ..] => (16, rest.as_ref()), + [b'0', b'o' | b'O', rest @ ..] => (8, rest.as_ref()), + [b'0', b'b' | b'B', rest @ ..] 
=> (2, rest.as_ref()), + // NOTE: We are including the leading zero to be able to parse + // literal `0` here. If we removed it as a radix prefix, we would + // not be able to parse `0`. + [b'0', ..] => (8, src), + _ => (10, src), + } +} + +/// Trait for parsing string representations of integers. +/// +/// Strings beginning with `0x`, `0o`, or `0b` are parsed as hex, octal, or +/// binary respectively. Strings beginning with `0` otherwise are parsed as +/// octal. Anything else is parsed as decimal. A leading `+` or `-` is also +/// permitted. Any string parsed by [`kstrtol()`] or [`kstrtoul()`] will be +/// successfully parsed. +/// +/// [`kstrtol()`]: https://docs.kernel.org/core-api/kernel-api.html#c.kstrtol +/// [`kstrtoul()`]: https://docs.kernel.org/core-api/kernel-api.html#c.kstrtoul +/// +/// # Examples +/// +/// ``` +/// # use kernel::str::parse_int::ParseInt; +/// # use kernel::b_str; +/// +/// assert_eq!(Ok(0u8), u8::from_str(b_str!("0"))); +/// +/// assert_eq!(Ok(0xa2u8), u8::from_str(b_str!("0xa2"))); +/// assert_eq!(Ok(-0xa2i32), i32::from_str(b_str!("-0xa2"))); +/// +/// assert_eq!(Ok(-0o57i8), i8::from_str(b_str!("-0o57"))); +/// assert_eq!(Ok(0o57i8), i8::from_str(b_str!("057"))); +/// +/// assert_eq!(Ok(0b1001i16), i16::from_str(b_str!("0b1001"))); +/// assert_eq!(Ok(-0b1001i16), i16::from_str(b_str!("-0b1001"))); +/// +/// assert_eq!(Ok(127i8), i8::from_str(b_str!("127"))); +/// assert!(i8::from_str(b_str!("128")).is_err()); +/// assert_eq!(Ok(-128i8), i8::from_str(b_str!("-128"))); +/// assert!(i8::from_str(b_str!("-129")).is_err()); +/// assert_eq!(Ok(255u8), u8::from_str(b_str!("255"))); +/// assert!(u8::from_str(b_str!("256")).is_err()); +/// ``` +pub trait ParseInt: private::FromStrRadix + TryFrom<u64> { + /// Parse a string according to the description in [`Self`]. + fn from_str(src: &BStr) -> Result<Self> { + match src.deref() { + [b'-', rest @ ..] => { + let (radix, digits) = strip_radix(rest.as_ref()); + // 2's complement values range from -2^(b-1) to 2^(b-1)-1. + // So if we want to parse negative numbers as positive and + // later multiply by -1, we have to parse into a larger + // integer. We choose `u64` as sufficiently large. + // + // NOTE: 128 bit integers are not available on all + // platforms, hence the choice of 64 bits. + let val = + u64::from_str_radix(core::str::from_utf8(digits).map_err(|_| EINVAL)?, radix) + .map_err(|_| EINVAL)?; + Self::from_u64_negated(val) + } + _ => { + let (radix, digits) = strip_radix(src); + Self::from_str_radix(digits, radix).map_err(|_| EINVAL) + } + } + } +} + +macro_rules! impl_parse_int { + ($($ty:ty),*) => { + $( + impl private::FromStrRadix for $ty { + fn from_str_radix(src: &BStr, radix: u32) -> Result<Self> { + <$ty>::from_str_radix(core::str::from_utf8(src).map_err(|_| EINVAL)?, radix) + .map_err(|_| EINVAL) + } + + fn from_u64_negated(value: u64) -> Result<Self> { + const ABS_MIN: u64 = { + #[allow(unused_comparisons)] + if <$ty>::MIN < 0 { + 1u64 << (<$ty>::BITS - 1) + } else { + 0 + } + }; + + if value > ABS_MIN { + return Err(EINVAL); + } + + if value == ABS_MIN { + return Ok(<$ty>::MIN); + } + + // SAFETY: The above checks guarantee that `value` fits into `Self`: + // - if `Self` is unsigned, then `ABS_MIN == 0` and thus we have returned above + // (either `EINVAL` or `MIN`). + // - if `Self` is signed, then we have that `0 <= value < ABS_MIN`. And since + // `ABS_MIN - 1` fits into `Self` by construction, `value` also does. 
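+                // For example, parsing "-5" as `i8` reaches this point with `value == 5`:
+                // `!5 == -6` and `-6 + 1 == -5`, i.e. the two's complement negation of `value`.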
+ let value: Self = unsafe { value.try_into().unwrap_unchecked() }; + + Ok((!value).wrapping_add(1)) + } + } + + impl ParseInt for $ty {} + )* + }; +} + +impl_parse_int![i8, u8, i16, u16, i32, u32, i64, u64, isize, usize]; diff --git a/rust/kernel/sync.rs b/rust/kernel/sync.rs new file mode 100644 index 000000000000..5df87e2bd212 --- /dev/null +++ b/rust/kernel/sync.rs @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Synchronisation primitives. +//! +//! This module contains the kernel APIs related to synchronisation that have been ported or +//! wrapped for usage by Rust code in the kernel. + +use crate::prelude::*; +use crate::types::Opaque; +use pin_init; + +mod arc; +pub mod aref; +pub mod atomic; +pub mod barrier; +pub mod completion; +mod condvar; +pub mod lock; +mod locked_by; +pub mod poll; +pub mod rcu; +mod refcount; +mod set_once; + +pub use arc::{Arc, ArcBorrow, UniqueArc}; +pub use completion::Completion; +pub use condvar::{new_condvar, CondVar, CondVarTimeoutResult}; +pub use lock::global::{global_lock, GlobalGuard, GlobalLock, GlobalLockBackend, GlobalLockedBy}; +pub use lock::mutex::{new_mutex, Mutex, MutexGuard}; +pub use lock::spinlock::{new_spinlock, SpinLock, SpinLockGuard}; +pub use locked_by::LockedBy; +pub use refcount::Refcount; +pub use set_once::SetOnce; + +/// Represents a lockdep class. It's a wrapper around C's `lock_class_key`. +#[repr(transparent)] +#[pin_data(PinnedDrop)] +pub struct LockClassKey { + #[pin] + inner: Opaque<bindings::lock_class_key>, +} + +// SAFETY: `bindings::lock_class_key` is designed to be used concurrently from multiple threads and +// provides its own synchronization. +unsafe impl Sync for LockClassKey {} + +impl LockClassKey { + /// Initializes a dynamically allocated lock class key. In the common case of using a + /// statically allocated lock class key, the static_lock_class! macro should be used instead. + /// + /// # Examples + /// ``` + /// # use kernel::alloc::KBox; + /// # use kernel::types::ForeignOwnable; + /// # use kernel::sync::{LockClassKey, SpinLock}; + /// # use pin_init::stack_pin_init; + /// + /// let key = KBox::pin_init(LockClassKey::new_dynamic(), GFP_KERNEL)?; + /// let key_ptr = key.into_foreign(); + /// + /// { + /// stack_pin_init!(let num: SpinLock<u32> = SpinLock::new( + /// 0, + /// c"my_spinlock", + /// // SAFETY: `key_ptr` is returned by the above `into_foreign()`, whose + /// // `from_foreign()` has not yet been called. + /// unsafe { <Pin<KBox<LockClassKey>> as ForeignOwnable>::borrow(key_ptr) } + /// )); + /// } + /// + /// // SAFETY: We dropped `num`, the only use of the key, so the result of the previous + /// // `borrow` has also been dropped. Thus, it's safe to use from_foreign. + /// unsafe { drop(<Pin<KBox<LockClassKey>> as ForeignOwnable>::from_foreign(key_ptr)) }; + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn new_dynamic() -> impl PinInit<Self> { + pin_init!(Self { + // SAFETY: lockdep_register_key expects an uninitialized block of memory + inner <- Opaque::ffi_init(|slot| unsafe { bindings::lockdep_register_key(slot) }) + }) + } + + pub(crate) fn as_ptr(&self) -> *mut bindings::lock_class_key { + self.inner.get() + } +} + +#[pinned_drop] +impl PinnedDrop for LockClassKey { + fn drop(self: Pin<&mut Self>) { + // SAFETY: self.as_ptr was registered with lockdep and self is pinned, so the address + // hasn't changed. Thus, it's safe to pass to unregister. 
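+        // This unregistration pairs with the `lockdep_register_key()` call in `new_dynamic()`.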
+ unsafe { bindings::lockdep_unregister_key(self.as_ptr()) } + } +} + +/// Defines a new static lock class and returns a pointer to it. +#[doc(hidden)] +#[macro_export] +macro_rules! static_lock_class { + () => {{ + static CLASS: $crate::sync::LockClassKey = + // Lockdep expects uninitialized memory when it's handed a statically allocated `struct + // lock_class_key`. + // + // SAFETY: `LockClassKey` transparently wraps `Opaque` which permits uninitialized + // memory. + unsafe { ::core::mem::MaybeUninit::uninit().assume_init() }; + $crate::prelude::Pin::static_ref(&CLASS) + }}; +} + +/// Returns the given string, if one is provided, otherwise generates one based on the source code +/// location. +#[doc(hidden)] +#[macro_export] +macro_rules! optional_name { + () => { + $crate::c_str!(::core::concat!(::core::file!(), ":", ::core::line!())) + }; + ($name:literal) => { + $crate::c_str!($name) + }; +} diff --git a/rust/kernel/sync/arc.rs b/rust/kernel/sync/arc.rs new file mode 100644 index 000000000000..289f77abf415 --- /dev/null +++ b/rust/kernel/sync/arc.rs @@ -0,0 +1,920 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A reference-counted pointer. +//! +//! This module implements a way for users to create reference-counted objects and pointers to +//! them. Such a pointer automatically increments and decrements the count, and drops the +//! underlying object when it reaches zero. It is also safe to use concurrently from multiple +//! threads. +//! +//! It is different from the standard library's [`Arc`] in a few ways: +//! 1. It is backed by the kernel's [`Refcount`] type. +//! 2. It does not support weak references, which allows it to be half the size. +//! 3. It saturates the reference count instead of aborting when it goes over a threshold. +//! 4. It does not provide a `get_mut` method, so the ref counted object is pinned. +//! 5. The object in [`Arc`] is pinned implicitly. +//! +//! [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html + +use crate::{ + alloc::{AllocError, Flags, KBox}, + ffi::c_void, + fmt, + init::InPlaceInit, + sync::Refcount, + try_init, + types::ForeignOwnable, +}; +use core::{ + alloc::Layout, + borrow::{Borrow, BorrowMut}, + marker::PhantomData, + mem::{ManuallyDrop, MaybeUninit}, + ops::{Deref, DerefMut}, + pin::Pin, + ptr::NonNull, +}; +use pin_init::{self, pin_data, InPlaceWrite, Init, PinInit}; + +mod std_vendor; + +/// A reference-counted pointer to an instance of `T`. +/// +/// The reference count is incremented when new instances of [`Arc`] are created, and decremented +/// when they are dropped. When the count reaches zero, the underlying `T` is also dropped. +/// +/// # Invariants +/// +/// The reference count on an instance of [`Arc`] is always non-zero. +/// The object pointed to by [`Arc`] is always pinned. +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::Arc; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// // Create a refcounted instance of `Example`. +/// let obj = Arc::new(Example { a: 10, b: 20 }, GFP_KERNEL)?; +/// +/// // Get a new pointer to `obj` and increment the refcount. +/// let cloned = obj.clone(); +/// +/// // Assert that both `obj` and `cloned` point to the same underlying object. +/// assert!(core::ptr::eq(&*obj, &*cloned)); +/// +/// // Destroy `obj` and decrement its refcount. +/// drop(obj); +/// +/// // Check that the values are still accessible through `cloned`. 
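+/// // (`cloned` still owns a refcount increment, so the object stays alive.)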
+/// assert_eq!(cloned.a, 10); +/// assert_eq!(cloned.b, 20); +/// +/// // The refcount drops to zero when `cloned` goes out of scope, and the memory is freed. +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Using `Arc<T>` as the type of `self`: +/// +/// ``` +/// use kernel::sync::Arc; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// impl Example { +/// fn take_over(self: Arc<Self>) { +/// // ... +/// } +/// +/// fn use_reference(self: &Arc<Self>) { +/// // ... +/// } +/// } +/// +/// let obj = Arc::new(Example { a: 10, b: 20 }, GFP_KERNEL)?; +/// obj.use_reference(); +/// obj.take_over(); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Coercion from `Arc<Example>` to `Arc<dyn MyTrait>`: +/// +/// ``` +/// use kernel::sync::{Arc, ArcBorrow}; +/// +/// trait MyTrait { +/// // Trait has a function whose `self` type is `Arc<Self>`. +/// fn example1(self: Arc<Self>) {} +/// +/// // Trait has a function whose `self` type is `ArcBorrow<'_, Self>`. +/// fn example2(self: ArcBorrow<'_, Self>) {} +/// } +/// +/// struct Example; +/// impl MyTrait for Example {} +/// +/// // `obj` has type `Arc<Example>`. +/// let obj: Arc<Example> = Arc::new(Example, GFP_KERNEL)?; +/// +/// // `coerced` has type `Arc<dyn MyTrait>`. +/// let coerced: Arc<dyn MyTrait> = obj; +/// # Ok::<(), Error>(()) +/// ``` +#[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] +pub struct Arc<T: ?Sized> { + ptr: NonNull<ArcInner<T>>, + // NB: this informs dropck that objects of type `ArcInner<T>` may be used in `<Arc<T> as + // Drop>::drop`. Note that dropck already assumes that objects of type `T` may be used in + // `<Arc<T> as Drop>::drop` and the distinction between `T` and `ArcInner<T>` is not presently + // meaningful with respect to dropck - but this may change in the future so this is left here + // out of an abundance of caution. + // + // See <https://doc.rust-lang.org/nomicon/phantom-data.html#generic-parameters-and-drop-checking> + // for more detail on the semantics of dropck in the presence of `PhantomData`. + _p: PhantomData<ArcInner<T>>, +} + +#[pin_data] +#[repr(C)] +struct ArcInner<T: ?Sized> { + refcount: Refcount, + data: T, +} + +impl<T: ?Sized> ArcInner<T> { + /// Converts a pointer to the contents of an [`Arc`] into a pointer to the [`ArcInner`]. + /// + /// # Safety + /// + /// `ptr` must have been returned by a previous call to [`Arc::into_raw`], and the `Arc` must + /// not yet have been destroyed. + unsafe fn container_of(ptr: *const T) -> NonNull<ArcInner<T>> { + let refcount_layout = Layout::new::<Refcount>(); + // SAFETY: The caller guarantees that the pointer is valid. + let val_layout = Layout::for_value(unsafe { &*ptr }); + // SAFETY: We're computing the layout of a real struct that existed when compiling this + // binary, so its layout is not so large that it can trigger arithmetic overflow. + let val_offset = unsafe { refcount_layout.extend(val_layout).unwrap_unchecked().1 }; + + // Pointer casts leave the metadata unchanged. This is okay because the metadata of `T` and + // `ArcInner<T>` is the same since `ArcInner` is a struct with `T` as its last field. + // + // This is documented at: + // <https://doc.rust-lang.org/std/ptr/trait.Pointee.html>. + let ptr = ptr as *const ArcInner<T>; + + // SAFETY: The pointer is in-bounds of an allocation both before and after offsetting the + // pointer, since it originates from a previous call to `Arc::into_raw` on an `Arc` that is + // still valid. 
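+        // `val_offset` is the offset of the `data` field within `ArcInner<T>`, so stepping back
+        // by it turns the pointer to `data` into a pointer to the containing `ArcInner<T>`.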
+ let ptr = unsafe { ptr.byte_sub(val_offset) }; + + // SAFETY: The pointer can't be null since you can't have an `ArcInner<T>` value at the null + // address. + unsafe { NonNull::new_unchecked(ptr.cast_mut()) } + } +} + +// This is to allow coercion from `Arc<T>` to `Arc<U>` if `T` can be converted to the +// dynamically-sized type (DST) `U`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::CoerceUnsized<Arc<U>> for Arc<T> {} + +// This is to allow `Arc<U>` to be dispatched on when `Arc<T>` can be coerced into `Arc<U>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<Arc<U>> for Arc<T> {} + +// SAFETY: It is safe to send `Arc<T>` to another thread when the underlying `T` is `Sync` because +// it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs +// `T` to be `Send` because any thread that has an `Arc<T>` may ultimately access `T` using a +// mutable reference when the reference count reaches zero and `T` is dropped. +unsafe impl<T: ?Sized + Sync + Send> Send for Arc<T> {} + +// SAFETY: It is safe to send `&Arc<T>` to another thread when the underlying `T` is `Sync` +// because it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, +// it needs `T` to be `Send` because any thread that has a `&Arc<T>` may clone it and get an +// `Arc<T>` on that thread, so the thread may ultimately access `T` using a mutable reference when +// the reference count reaches zero and `T` is dropped. +unsafe impl<T: ?Sized + Sync + Send> Sync for Arc<T> {} + +impl<T> InPlaceInit<T> for Arc<T> { + type PinnedSelf = Self; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>, + { + UniqueArc::try_pin_init(init, flags).map(|u| u.into()) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + UniqueArc::try_init(init, flags).map(|u| u.into()) + } +} + +impl<T> Arc<T> { + /// Constructs a new reference counted instance of `T`. + pub fn new(contents: T, flags: Flags) -> Result<Self, AllocError> { + // INVARIANT: The refcount is initialised to a non-zero value. + let value = ArcInner { + refcount: Refcount::new(1), + data: contents, + }; + + let inner = KBox::new(value, flags)?; + let inner = KBox::leak(inner).into(); + + // SAFETY: We just created `inner` with a reference count of 1, which is owned by the new + // `Arc` object. + Ok(unsafe { Self::from_inner(inner) }) + } +} + +impl<T: ?Sized> Arc<T> { + /// Constructs a new [`Arc`] from an existing [`ArcInner`]. + /// + /// # Safety + /// + /// The caller must ensure that `inner` points to a valid location and has a non-zero reference + /// count, one of which will be owned by the new [`Arc`] instance. + unsafe fn from_inner(inner: NonNull<ArcInner<T>>) -> Self { + // INVARIANT: By the safety requirements, the invariants hold. + Arc { + ptr: inner, + _p: PhantomData, + } + } + + /// Convert the [`Arc`] into a raw pointer. + /// + /// The raw pointer has ownership of the refcount that this Arc object owned. + pub fn into_raw(self) -> *const T { + let ptr = self.ptr.as_ptr(); + core::mem::forget(self); + // SAFETY: The pointer is valid. + unsafe { core::ptr::addr_of!((*ptr).data) } + } + + /// Return a raw pointer to the data in this arc. 
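+    ///
+    /// Unlike [`Arc::into_raw`], this does not consume the [`Arc`] or transfer ownership of a
+    /// refcount; the returned pointer is only guaranteed to stay valid while a refcount is held.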
+    pub fn as_ptr(this: &Self) -> *const T {
+        let ptr = this.ptr.as_ptr();
+
+        // SAFETY: As `ptr` points to a valid allocation of type `ArcInner`,
+        // field projection to `data` is within bounds of the allocation.
+        unsafe { core::ptr::addr_of!((*ptr).data) }
+    }
+
+    /// Recreates an [`Arc`] instance previously deconstructed via [`Arc::into_raw`].
+    ///
+    /// # Safety
+    ///
+    /// `ptr` must have been returned by a previous call to [`Arc::into_raw`]. Additionally, it
+    /// must not be called more than once for each previous call to [`Arc::into_raw`].
+    pub unsafe fn from_raw(ptr: *const T) -> Self {
+        // SAFETY: The caller promises that this pointer originates from a call to `into_raw` on an
+        // `Arc` that is still valid.
+        let ptr = unsafe { ArcInner::container_of(ptr) };
+
+        // SAFETY: By the safety requirements we know that `ptr` came from `Arc::into_raw`, so the
+        // reference count held then will be owned by the new `Arc` object.
+        unsafe { Self::from_inner(ptr) }
+    }
+
+    /// Returns an [`ArcBorrow`] from the given [`Arc`].
+    ///
+    /// This is useful when the argument of a function call is an [`ArcBorrow`] (e.g., in a method
+    /// receiver), but we have an [`Arc`] instead. Getting an [`ArcBorrow`] is free when optimised.
+    #[inline]
+    pub fn as_arc_borrow(&self) -> ArcBorrow<'_, T> {
+        // SAFETY: The constraint that the lifetime of the shared reference must outlive that of
+        // the returned `ArcBorrow` ensures that the object remains alive and that no mutable
+        // reference can be created.
+        unsafe { ArcBorrow::new(self.ptr) }
+    }
+
+    /// Compare whether two [`Arc`] pointers reference the same underlying object.
+    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
+        core::ptr::eq(this.ptr.as_ptr(), other.ptr.as_ptr())
+    }
+
+    /// Converts this [`Arc`] into a [`UniqueArc`], or destroys it if it is not unique.
+    ///
+    /// When this destroys the `Arc`, it does so while properly avoiding races. This means that
+    /// this method will never call the destructor of the value.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::sync::{Arc, UniqueArc};
+    ///
+    /// let arc = Arc::new(42, GFP_KERNEL)?;
+    /// let unique_arc = Arc::into_unique_or_drop(arc);
+    ///
+    /// // The above conversion should succeed since refcount of `arc` is 1.
+    /// assert!(unique_arc.is_some());
+    ///
+    /// assert_eq!(*(unique_arc.unwrap()), 42);
+    ///
+    /// # Ok::<(), Error>(())
+    /// ```
+    ///
+    /// ```
+    /// use kernel::sync::{Arc, UniqueArc};
+    ///
+    /// let arc = Arc::new(42, GFP_KERNEL)?;
+    /// let another = arc.clone();
+    ///
+    /// let unique_arc = Arc::into_unique_or_drop(arc);
+    ///
+    /// // The above conversion should fail since refcount of `arc` is >1.
+    /// assert!(unique_arc.is_none());
+    ///
+    /// # Ok::<(), Error>(())
+    /// ```
+    pub fn into_unique_or_drop(this: Self) -> Option<Pin<UniqueArc<T>>> {
+        // We will manually manage the refcount in this method, so we disable the destructor.
+        let this = ManuallyDrop::new(this);
+        // SAFETY: We own a refcount, so the pointer is still valid.
+        let refcount = unsafe { &this.ptr.as_ref().refcount };
+
+        // If the refcount reaches a non-zero value, then we have destroyed this `Arc` and will
+        // return without further touching the `Arc`. If the refcount reaches zero, then there are
+        // no other arcs, and we can create a `UniqueArc`.
+        if refcount.dec_and_test() {
+            refcount.set(1);
+
+            // INVARIANT: We own the only refcount to this arc, so we may create a `UniqueArc`. We
+            // must pin the `UniqueArc` because the value was previously in an `Arc`, and `Arc`s
+            // pin their values.
+            Some(Pin::from(UniqueArc {
+                inner: ManuallyDrop::into_inner(this),
+            }))
+        } else {
+            None
+        }
+    }
+}
+
+// SAFETY: The pointer returned by `into_foreign` was originally allocated as a
+// `KBox<ArcInner<T>>`, so that type is what determines the alignment.
+unsafe impl<T: 'static> ForeignOwnable for Arc<T> {
+    const FOREIGN_ALIGN: usize = <KBox<ArcInner<T>> as ForeignOwnable>::FOREIGN_ALIGN;
+
+    type Borrowed<'a> = ArcBorrow<'a, T>;
+    type BorrowedMut<'a> = Self::Borrowed<'a>;
+
+    fn into_foreign(self) -> *mut c_void {
+        ManuallyDrop::new(self).ptr.as_ptr().cast()
+    }
+
+    unsafe fn from_foreign(ptr: *mut c_void) -> Self {
+        // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous
+        // call to `Self::into_foreign`.
+        let inner = unsafe { NonNull::new_unchecked(ptr.cast::<ArcInner<T>>()) };
+
+        // SAFETY: By the safety requirement of this function, we know that `ptr` came from
+        // a previous call to `Arc::into_foreign`, which guarantees that `ptr` is valid and
+        // holds a reference count increment that is transferrable to us.
+        unsafe { Self::from_inner(inner) }
+    }
+
+    unsafe fn borrow<'a>(ptr: *mut c_void) -> ArcBorrow<'a, T> {
+        // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous
+        // call to `Self::into_foreign`.
+        let inner = unsafe { NonNull::new_unchecked(ptr.cast::<ArcInner<T>>()) };
+
+        // SAFETY: The safety requirements of `from_foreign` ensure that the object remains alive
+        // for the lifetime of the returned value.
+        unsafe { ArcBorrow::new(inner) }
+    }
+
+    unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> ArcBorrow<'a, T> {
+        // SAFETY: The safety requirements for `borrow_mut` are a superset of the safety
+        // requirements for `borrow`.
+        unsafe { <Self as ForeignOwnable>::borrow(ptr) }
+    }
+}
+
+impl<T: ?Sized> Deref for Arc<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        // SAFETY: By the type invariant, there is necessarily a reference to the object, so it is
+        // safe to dereference it.
+        unsafe { &self.ptr.as_ref().data }
+    }
+}
+
+impl<T: ?Sized> AsRef<T> for Arc<T> {
+    fn as_ref(&self) -> &T {
+        self.deref()
+    }
+}
+
+/// # Examples
+///
+/// ```
+/// # use core::borrow::Borrow;
+/// # use kernel::sync::Arc;
+/// struct Foo<B: Borrow<u32>>(B);
+///
+/// // Owned instance.
+/// let owned = Foo(1);
+///
+/// // Shared instance.
+/// let arc = Arc::new(1, GFP_KERNEL)?;
+/// let shared = Foo(arc.clone());
+///
+/// let i = 1;
+/// // Borrowed from `i`.
+/// let borrowed = Foo(&i);
+/// # Ok::<(), Error>(())
+/// ```
+impl<T: ?Sized> Borrow<T> for Arc<T> {
+    fn borrow(&self) -> &T {
+        self.deref()
+    }
+}
+
+impl<T: ?Sized> Clone for Arc<T> {
+    fn clone(&self) -> Self {
+        // INVARIANT: `Refcount` saturates the refcount, so it cannot overflow to zero.
+        // SAFETY: By the type invariant, there is necessarily a reference to the object, so it is
+        // safe to increment the refcount.
+        unsafe { self.ptr.as_ref() }.refcount.inc();
+
+        // SAFETY: We just incremented the refcount. This increment is now owned by the new `Arc`.
+        unsafe { Self::from_inner(self.ptr) }
+    }
+}
+
+impl<T: ?Sized> Drop for Arc<T> {
+    fn drop(&mut self) {
+        // INVARIANT: If the refcount reaches zero, there are no other instances of `Arc`, and
+        // this instance is being dropped, so the broken invariant is not observable.
+        // SAFETY: By the type invariant, there is necessarily a reference to the object.
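+        // `dec_and_test` returns `true` only for the caller that drops the last reference, so
+        // exactly one thread observes `is_zero == true` and frees the allocation.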
+ let is_zero = unsafe { self.ptr.as_ref() }.refcount.dec_and_test(); + if is_zero { + // The count reached zero, we must free the memory. + // + // SAFETY: The pointer was initialised from the result of `KBox::leak`. + unsafe { drop(KBox::from_raw(self.ptr.as_ptr())) }; + } + } +} + +impl<T: ?Sized> From<UniqueArc<T>> for Arc<T> { + fn from(item: UniqueArc<T>) -> Self { + item.inner + } +} + +impl<T: ?Sized> From<Pin<UniqueArc<T>>> for Arc<T> { + fn from(item: Pin<UniqueArc<T>>) -> Self { + // SAFETY: The type invariants of `Arc` guarantee that the data is pinned. + unsafe { Pin::into_inner_unchecked(item).inner } + } +} + +/// A borrowed reference to an [`Arc`] instance. +/// +/// For cases when one doesn't ever need to increment the refcount on the allocation, it is simpler +/// to use just `&T`, which we can trivially get from an [`Arc<T>`] instance. +/// +/// However, when one may need to increment the refcount, it is preferable to use an `ArcBorrow<T>` +/// over `&Arc<T>` because the latter results in a double-indirection: a pointer (shared reference) +/// to a pointer ([`Arc<T>`]) to the object (`T`). An [`ArcBorrow`] eliminates this double +/// indirection while still allowing one to increment the refcount and getting an [`Arc<T>`] when/if +/// needed. +/// +/// # Invariants +/// +/// There are no mutable references to the underlying [`Arc`], and it remains valid for the +/// lifetime of the [`ArcBorrow`] instance. +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::{Arc, ArcBorrow}; +/// +/// struct Example; +/// +/// fn do_something(e: ArcBorrow<'_, Example>) -> Arc<Example> { +/// e.into() +/// } +/// +/// let obj = Arc::new(Example, GFP_KERNEL)?; +/// let cloned = do_something(obj.as_arc_borrow()); +/// +/// // Assert that both `obj` and `cloned` point to the same underlying object. +/// assert!(core::ptr::eq(&*obj, &*cloned)); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Using `ArcBorrow<T>` as the type of `self`: +/// +/// ``` +/// use kernel::sync::{Arc, ArcBorrow}; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// impl Example { +/// fn use_reference(self: ArcBorrow<'_, Self>) { +/// // ... +/// } +/// } +/// +/// let obj = Arc::new(Example { a: 10, b: 20 }, GFP_KERNEL)?; +/// obj.as_arc_borrow().use_reference(); +/// # Ok::<(), Error>(()) +/// ``` +#[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] +pub struct ArcBorrow<'a, T: ?Sized + 'a> { + inner: NonNull<ArcInner<T>>, + _p: PhantomData<&'a ()>, +} + +// This is to allow `ArcBorrow<U>` to be dispatched on when `ArcBorrow<T>` can be coerced into +// `ArcBorrow<U>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<ArcBorrow<'_, U>> + for ArcBorrow<'_, T> +{ +} + +impl<T: ?Sized> Clone for ArcBorrow<'_, T> { + fn clone(&self) -> Self { + *self + } +} + +impl<T: ?Sized> Copy for ArcBorrow<'_, T> {} + +impl<T: ?Sized> ArcBorrow<'_, T> { + /// Creates a new [`ArcBorrow`] instance. + /// + /// # Safety + /// + /// Callers must ensure the following for the lifetime of the returned [`ArcBorrow`] instance: + /// 1. That `inner` remains valid; + /// 2. That no mutable references to `inner` are created. + unsafe fn new(inner: NonNull<ArcInner<T>>) -> Self { + // INVARIANT: The safety requirements guarantee the invariants. 
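+        // Note that no refcount is taken here: an `ArcBorrow` borrows a refcount that the
+        // caller keeps alive for the duration of the borrow.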
+ Self { + inner, + _p: PhantomData, + } + } + + /// Creates an [`ArcBorrow`] to an [`Arc`] that has previously been deconstructed with + /// [`Arc::into_raw`] or [`Arc::as_ptr`]. + /// + /// # Safety + /// + /// * The provided pointer must originate from a call to [`Arc::into_raw`] or [`Arc::as_ptr`]. + /// * For the duration of the lifetime annotated on this `ArcBorrow`, the reference count must + /// not hit zero. + /// * For the duration of the lifetime annotated on this `ArcBorrow`, there must not be a + /// [`UniqueArc`] reference to this value. + pub unsafe fn from_raw(ptr: *const T) -> Self { + // SAFETY: The caller promises that this pointer originates from a call to `into_raw` on an + // `Arc` that is still valid. + let ptr = unsafe { ArcInner::container_of(ptr) }; + + // SAFETY: The caller promises that the value remains valid since the reference count must + // not hit zero, and no mutable reference will be created since that would involve a + // `UniqueArc`. + unsafe { Self::new(ptr) } + } +} + +impl<T: ?Sized> From<ArcBorrow<'_, T>> for Arc<T> { + fn from(b: ArcBorrow<'_, T>) -> Self { + // SAFETY: The existence of `b` guarantees that the refcount is non-zero. `ManuallyDrop` + // guarantees that `drop` isn't called, so it's ok that the temporary `Arc` doesn't own the + // increment. + ManuallyDrop::new(unsafe { Arc::from_inner(b.inner) }) + .deref() + .clone() + } +} + +impl<T: ?Sized> Deref for ArcBorrow<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: By the type invariant, the underlying object is still alive with no mutable + // references to it, so it is safe to create a shared reference. + unsafe { &self.inner.as_ref().data } + } +} + +/// A refcounted object that is known to have a refcount of 1. +/// +/// It is mutable and can be converted to an [`Arc`] so that it can be shared. +/// +/// # Invariants +/// +/// `inner` always has a reference count of 1. +/// +/// # Examples +/// +/// In the following example, we make changes to the inner object before turning it into an +/// `Arc<Test>` object (after which point, it cannot be mutated directly). Note that `x.into()` +/// cannot fail. +/// +/// ``` +/// use kernel::sync::{Arc, UniqueArc}; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn test() -> Result<Arc<Example>> { +/// let mut x = UniqueArc::new(Example { a: 10, b: 20 }, GFP_KERNEL)?; +/// x.a += 1; +/// x.b += 1; +/// Ok(x.into()) +/// } +/// +/// # test().unwrap(); +/// ``` +/// +/// In the following example we first allocate memory for a refcounted `Example` but we don't +/// initialise it on allocation. We do initialise it later with a call to [`UniqueArc::write`], +/// followed by a conversion to `Arc<Example>`. This is particularly useful when allocation happens +/// in one context (e.g., sleepable) and initialisation in another (e.g., atomic): +/// +/// ``` +/// use kernel::sync::{Arc, UniqueArc}; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn test() -> Result<Arc<Example>> { +/// let x = UniqueArc::new_uninit(GFP_KERNEL)?; +/// Ok(x.write(Example { a: 10, b: 20 }).into()) +/// } +/// +/// # test().unwrap(); +/// ``` +/// +/// In the last example below, the caller gets a pinned instance of `Example` while converting to +/// `Arc<Example>`; this is useful in scenarios where one needs a pinned reference during +/// initialisation, for example, when initialising fields that are wrapped in locks. 
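+///
+/// (Since `Example` below contains only `u32` fields, it is `Unpin`, so the pinned value can
+/// still be mutated through `Pin::as_mut`, as the example shows.)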
+/// +/// ``` +/// use kernel::sync::{Arc, UniqueArc}; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn test() -> Result<Arc<Example>> { +/// let mut pinned = Pin::from(UniqueArc::new(Example { a: 10, b: 20 }, GFP_KERNEL)?); +/// // We can modify `pinned` because it is `Unpin`. +/// pinned.as_mut().a += 1; +/// Ok(pinned.into()) +/// } +/// +/// # test().unwrap(); +/// ``` +pub struct UniqueArc<T: ?Sized> { + inner: Arc<T>, +} + +impl<T> InPlaceInit<T> for UniqueArc<T> { + type PinnedSelf = Pin<Self>; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>, + { + UniqueArc::new_uninit(flags)?.write_pin_init(init) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + UniqueArc::new_uninit(flags)?.write_init(init) + } +} + +impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { + type Initialized = UniqueArc<T>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }.into()) + } +} + +impl<T> UniqueArc<T> { + /// Tries to allocate a new [`UniqueArc`] instance. + pub fn new(value: T, flags: Flags) -> Result<Self, AllocError> { + Ok(Self { + // INVARIANT: The newly-created object has a refcount of 1. + inner: Arc::new(value, flags)?, + }) + } + + /// Tries to allocate a new [`UniqueArc`] instance whose contents are not initialised yet. + pub fn new_uninit(flags: Flags) -> Result<UniqueArc<MaybeUninit<T>>, AllocError> { + // INVARIANT: The refcount is initialised to a non-zero value. + let inner = KBox::try_init::<AllocError>( + try_init!(ArcInner { + refcount: Refcount::new(1), + data <- pin_init::uninit::<T, AllocError>(), + }? AllocError), + flags, + )?; + Ok(UniqueArc { + // INVARIANT: The newly-created object has a refcount of 1. + // SAFETY: The pointer from the `KBox` is valid. + inner: unsafe { Arc::from_inner(KBox::leak(inner).into()) }, + }) + } +} + +impl<T> UniqueArc<MaybeUninit<T>> { + /// Converts a `UniqueArc<MaybeUninit<T>>` into a `UniqueArc<T>` by writing a value into it. + pub fn write(mut self, value: T) -> UniqueArc<T> { + self.deref_mut().write(value); + // SAFETY: We just wrote the value to be initialized. + unsafe { self.assume_init() } + } + + /// Unsafely assume that `self` is initialized. + /// + /// # Safety + /// + /// The caller guarantees that the value behind this pointer has been initialized. It is + /// *immediate* UB to call this when the value is not initialized. + pub unsafe fn assume_init(self) -> UniqueArc<T> { + let inner = ManuallyDrop::new(self).inner.ptr; + UniqueArc { + // SAFETY: The new `Arc` is taking over `ptr` from `self.inner` (which won't be + // dropped). The types are compatible because `MaybeUninit<T>` is compatible with `T`. 
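+            // The `cast` goes from `NonNull<ArcInner<MaybeUninit<T>>>` to `NonNull<ArcInner<T>>`;
+            // `ArcInner` is `#[repr(C)]`, so the field offsets are unaffected by the conversion.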
+ inner: unsafe { Arc::from_inner(inner.cast()) }, + } + } + + /// Initialize `self` using the given initializer. + pub fn init_with<E>(mut self, init: impl Init<T, E>) -> core::result::Result<UniqueArc<T>, E> { + // SAFETY: The supplied pointer is valid for initialization. + match unsafe { init.__init(self.as_mut_ptr()) } { + // SAFETY: Initialization completed successfully. + Ok(()) => Ok(unsafe { self.assume_init() }), + Err(err) => Err(err), + } + } + + /// Pin-initialize `self` using the given pin-initializer. + pub fn pin_init_with<E>( + mut self, + init: impl PinInit<T, E>, + ) -> core::result::Result<Pin<UniqueArc<T>>, E> { + // SAFETY: The supplied pointer is valid for initialization and we will later pin the value + // to ensure it does not move. + match unsafe { init.__pinned_init(self.as_mut_ptr()) } { + // SAFETY: Initialization completed successfully. + Ok(()) => Ok(unsafe { self.assume_init() }.into()), + Err(err) => Err(err), + } + } +} + +impl<T: ?Sized> From<UniqueArc<T>> for Pin<UniqueArc<T>> { + fn from(obj: UniqueArc<T>) -> Self { + // SAFETY: It is not possible to move/replace `T` inside a `Pin<UniqueArc<T>>` (unless `T` + // is `Unpin`), so it is ok to convert it to `Pin<UniqueArc<T>>`. + unsafe { Pin::new_unchecked(obj) } + } +} + +impl<T: ?Sized> Deref for UniqueArc<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.inner.deref() + } +} + +impl<T: ?Sized> DerefMut for UniqueArc<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: By the `Arc` type invariant, there is necessarily a reference to the object, so + // it is safe to dereference it. Additionally, we know there is only one reference when + // it's inside a `UniqueArc`, so it is safe to get a mutable reference. + unsafe { &mut self.inner.ptr.as_mut().data } + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// # use kernel::sync::UniqueArc; +/// struct Foo<B: Borrow<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `UniqueArc`. +/// let arc = UniqueArc::new(1, GFP_KERNEL)?; +/// let shared = Foo(arc); +/// +/// let i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T: ?Sized> Borrow<T> for UniqueArc<T> { + fn borrow(&self) -> &T { + self.deref() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// # use kernel::sync::UniqueArc; +/// struct Foo<B: BorrowMut<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `UniqueArc`. +/// let arc = UniqueArc::new(1, GFP_KERNEL)?; +/// let shared = Foo(arc); +/// +/// let mut i = 1; +/// // Borrowed from `i`. 
+/// let borrowed = Foo(&mut i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T: ?Sized> BorrowMut<T> for UniqueArc<T> { + fn borrow_mut(&mut self) -> &mut T { + self.deref_mut() + } +} + +impl<T: fmt::Display + ?Sized> fmt::Display for UniqueArc<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self.deref(), f) + } +} + +impl<T: fmt::Display + ?Sized> fmt::Display for Arc<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self.deref(), f) + } +} + +impl<T: fmt::Debug + ?Sized> fmt::Debug for UniqueArc<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self.deref(), f) + } +} + +impl<T: fmt::Debug + ?Sized> fmt::Debug for Arc<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self.deref(), f) + } +} diff --git a/rust/kernel/sync/arc/std_vendor.rs b/rust/kernel/sync/arc/std_vendor.rs new file mode 100644 index 000000000000..11b3f4ecca5f --- /dev/null +++ b/rust/kernel/sync/arc/std_vendor.rs @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Rust standard library vendored code. +//! +//! The contents of this file come from the Rust standard library, hosted in +//! the <https://github.com/rust-lang/rust> repository, licensed under +//! "Apache-2.0 OR MIT" and adapted for kernel use. For copyright details, +//! see <https://github.com/rust-lang/rust/blob/master/COPYRIGHT>. + +use crate::sync::{arc::ArcInner, Arc}; +use core::any::Any; + +impl Arc<dyn Any + Send + Sync> { + /// Attempt to downcast the `Arc<dyn Any + Send + Sync>` to a concrete type. + pub fn downcast<T>(self) -> core::result::Result<Arc<T>, Self> + where + T: Any + Send + Sync, + { + if (*self).is::<T>() { + // SAFETY: We have just checked that the type is correct, so we can cast the pointer. + unsafe { + let ptr = self.ptr.cast::<ArcInner<T>>(); + core::mem::forget(self); + Ok(Arc::from_inner(ptr)) + } + } else { + Err(self) + } + } +} diff --git a/rust/kernel/sync/aref.rs b/rust/kernel/sync/aref.rs new file mode 100644 index 000000000000..0d24a0432015 --- /dev/null +++ b/rust/kernel/sync/aref.rs @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Internal reference counting support. +//! +//! Many C types already have their own reference counting mechanism (e.g. by storing a +//! `refcount_t`). This module provides support for directly using their internal reference count +//! from Rust; instead of making users have to use an additional Rust-reference count in the form of +//! [`Arc`]. +//! +//! The smart pointer [`ARef<T>`] acts similarly to [`Arc<T>`] in that it holds a refcount on the +//! underlying object, but this refcount is internal to the object. It essentially is a Rust +//! implementation of the `get_` and `put_` pattern used in C for reference counting. +//! +//! To make use of [`ARef<MyType>`], `MyType` needs to implement [`AlwaysRefCounted`]. It is a trait +//! for accessing the internal reference count of an object of the `MyType` type. +//! +//! [`Arc`]: crate::sync::Arc +//! [`Arc<T>`]: crate::sync::Arc + +use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull}; + +/// Types that are _always_ reference counted. +/// +/// It allows such types to define their own custom ref increment and decrement functions. +/// Additionally, it allows users to convert from a shared reference `&T` to an owned reference +/// [`ARef<T>`]. +/// +/// This is usually implemented by wrappers to existing structures on the C side of the code. 
For +/// Rust code, the recommendation is to use [`Arc`](crate::sync::Arc) to create reference-counted +/// instances of a type. +/// +/// # Safety +/// +/// Implementers must ensure that increments to the reference count keep the object alive in memory +/// at least until matching decrements are performed. +/// +/// Implementers must also ensure that all instances are reference-counted. (Otherwise they +/// won't be able to honour the requirement that [`AlwaysRefCounted::inc_ref`] keep the object +/// alive.) +pub unsafe trait AlwaysRefCounted { + /// Increments the reference count on the object. + fn inc_ref(&self); + + /// Decrements the reference count on the object. + /// + /// Frees the object when the count reaches zero. + /// + /// # Safety + /// + /// Callers must ensure that there was a previous matching increment to the reference count, + /// and that the object is no longer used after its reference count is decremented (as it may + /// result in the object being freed), unless the caller owns another increment on the refcount + /// (e.g., it calls [`AlwaysRefCounted::inc_ref`] twice, then calls + /// [`AlwaysRefCounted::dec_ref`] once). + unsafe fn dec_ref(obj: NonNull<Self>); +} + +/// An owned reference to an always-reference-counted object. +/// +/// The object's reference count is automatically decremented when an instance of [`ARef`] is +/// dropped. It is also automatically incremented when a new instance is created via +/// [`ARef::clone`]. +/// +/// # Invariants +/// +/// The pointer stored in `ptr` is non-null and valid for the lifetime of the [`ARef`] instance. In +/// particular, the [`ARef`] instance owns an increment on the underlying object's reference count. +pub struct ARef<T: AlwaysRefCounted> { + ptr: NonNull<T>, + _p: PhantomData<T>, +} + +// SAFETY: It is safe to send `ARef<T>` to another thread when the underlying `T` is `Sync` because +// it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs +// `T` to be `Send` because any thread that has an `ARef<T>` may ultimately access `T` using a +// mutable reference, for example, when the reference count reaches zero and `T` is dropped. +unsafe impl<T: AlwaysRefCounted + Sync + Send> Send for ARef<T> {} + +// SAFETY: It is safe to send `&ARef<T>` to another thread when the underlying `T` is `Sync` +// because it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, +// it needs `T` to be `Send` because any thread that has a `&ARef<T>` may clone it and get an +// `ARef<T>` on that thread, so the thread may ultimately access `T` using a mutable reference, for +// example, when the reference count reaches zero and `T` is dropped. +unsafe impl<T: AlwaysRefCounted + Sync + Send> Sync for ARef<T> {} + +impl<T: AlwaysRefCounted> ARef<T> { + /// Creates a new instance of [`ARef`]. + /// + /// It takes over an increment of the reference count on the underlying object. + /// + /// # Safety + /// + /// Callers must ensure that the reference count was incremented at least once, and that they + /// are properly relinquishing one increment. That is, if there is only one increment, callers + /// must not use the underlying object anymore -- it is only safe to do so via the newly + /// created [`ARef`]. + pub unsafe fn from_raw(ptr: NonNull<T>) -> Self { + // INVARIANT: The safety requirements guarantee that the new instance now owns the + // increment on the refcount. 
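+        // Note that `inc_ref()` is deliberately not called here: the constructor takes over an
+        // increment that the caller has already made.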
+ Self { + ptr, + _p: PhantomData, + } + } + + /// Consumes the `ARef`, returning a raw pointer. + /// + /// This function does not change the refcount. After calling this function, the caller is + /// responsible for the refcount previously managed by the `ARef`. + /// + /// # Examples + /// + /// ``` + /// use core::ptr::NonNull; + /// use kernel::sync::aref::{ARef, AlwaysRefCounted}; + /// + /// struct Empty {} + /// + /// # // SAFETY: TODO. + /// unsafe impl AlwaysRefCounted for Empty { + /// fn inc_ref(&self) {} + /// unsafe fn dec_ref(_obj: NonNull<Self>) {} + /// } + /// + /// let mut data = Empty {}; + /// let ptr = NonNull::<Empty>::new(&mut data).unwrap(); + /// # // SAFETY: TODO. + /// let data_ref: ARef<Empty> = unsafe { ARef::from_raw(ptr) }; + /// let raw_ptr: NonNull<Empty> = ARef::into_raw(data_ref); + /// + /// assert_eq!(ptr, raw_ptr); + /// ``` + pub fn into_raw(me: Self) -> NonNull<T> { + ManuallyDrop::new(me).ptr + } +} + +impl<T: AlwaysRefCounted> Clone for ARef<T> { + fn clone(&self) -> Self { + self.inc_ref(); + // SAFETY: We just incremented the refcount above. + unsafe { Self::from_raw(self.ptr) } + } +} + +impl<T: AlwaysRefCounted> Deref for ARef<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: The type invariants guarantee that the object is valid. + unsafe { self.ptr.as_ref() } + } +} + +impl<T: AlwaysRefCounted> From<&T> for ARef<T> { + fn from(b: &T) -> Self { + b.inc_ref(); + // SAFETY: We just incremented the refcount above. + unsafe { Self::from_raw(NonNull::from(b)) } + } +} + +impl<T: AlwaysRefCounted> Drop for ARef<T> { + fn drop(&mut self) { + // SAFETY: The type invariants guarantee that the `ARef` owns the reference we're about to + // decrement. + unsafe { T::dec_ref(self.ptr) }; + } +} diff --git a/rust/kernel/sync/atomic.rs b/rust/kernel/sync/atomic.rs new file mode 100644 index 000000000000..4aebeacb961a --- /dev/null +++ b/rust/kernel/sync/atomic.rs @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Atomic primitives. +//! +//! These primitives have the same semantics as their C counterparts: and the precise definitions of +//! semantics can be found at [`LKMM`]. Note that Linux Kernel Memory (Consistency) Model is the +//! only model for Rust code in kernel, and Rust's own atomics should be avoided. +//! +//! # Data races +//! +//! [`LKMM`] atomics have different rules regarding data races: +//! +//! - A normal write from C side is treated as an atomic write if +//! CONFIG_KCSAN_ASSUME_PLAIN_WRITES_ATOMIC=y. +//! - Mixed-size atomic accesses don't cause data races. +//! +//! [`LKMM`]: srctree/tools/memory-model/ + +mod internal; +pub mod ordering; +mod predefine; + +pub use internal::AtomicImpl; +pub use ordering::{Acquire, Full, Relaxed, Release}; + +pub(crate) use internal::{AtomicArithmeticOps, AtomicBasicOps, AtomicExchangeOps}; + +use crate::build_error; +use internal::AtomicRepr; +use ordering::OrderingType; + +/// A memory location which can be safely modified from multiple execution contexts. +/// +/// This has the same size, alignment and bit validity as the underlying type `T`. And it disables +/// niche optimization for the same reason as [`UnsafeCell`]. +/// +/// The atomic operations are implemented in a way that is fully compatible with the [Linux Kernel +/// Memory (Consistency) Model][LKMM], hence they should be modeled as the corresponding +/// [`LKMM`][LKMM] atomic primitives. 
With the help of [`Atomic::from_ptr()`] and +/// [`Atomic::as_ptr()`], this provides a way to interact with [C-side atomic operations] +/// (including those without the `atomic` prefix, e.g. `READ_ONCE()`, `WRITE_ONCE()`, +/// `smp_load_acquire()` and `smp_store_release()`). +/// +/// # Invariants +/// +/// `self.0` is a valid `T`. +/// +/// [`UnsafeCell`]: core::cell::UnsafeCell +/// [LKMM]: srctree/tools/memory-model/ +/// [C-side atomic operations]: srctree/Documentation/atomic_t.txt +#[repr(transparent)] +pub struct Atomic<T: AtomicType>(AtomicRepr<T::Repr>); + +// SAFETY: `Atomic<T>` is safe to share among execution contexts because all accesses are atomic. +unsafe impl<T: AtomicType> Sync for Atomic<T> {} + +/// Types that support basic atomic operations. +/// +/// # Round-trip transmutability +/// +/// `T` is round-trip transmutable to `U` if and only if both of these properties hold: +/// +/// - Any valid bit pattern for `T` is also a valid bit pattern for `U`. +/// - Transmuting (e.g. using [`transmute()`]) a value of type `T` to `U` and then to `T` again +/// yields a value that is in all aspects equivalent to the original value. +/// +/// # Safety +/// +/// - [`Self`] must have the same size and alignment as [`Self::Repr`]. +/// - [`Self`] must be [round-trip transmutable] to [`Self::Repr`]. +/// +/// Note that this is more relaxed than requiring the bi-directional transmutability (i.e. +/// [`transmute()`] is always sound between `U` and `T`) because of the support for atomic +/// variables over unit-only enums, see [Examples]. +/// +/// # Limitations +/// +/// Because C primitives are used to implement the atomic operations, and a C function requires a +/// valid object of a type to operate on (i.e. no `MaybeUninit<_>`), hence at the Rust <-> C +/// surface, only types with all the bits initialized can be passed. As a result, types like `(u8, +/// u16)` (padding bytes are uninitialized) are currently not supported. +/// +/// # Examples +/// +/// A unit-only enum that implements [`AtomicType`]: +/// +/// ``` +/// use kernel::sync::atomic::{AtomicType, Atomic, Relaxed}; +/// +/// #[derive(Clone, Copy, PartialEq, Eq)] +/// #[repr(i32)] +/// enum State { +/// Uninit = 0, +/// Working = 1, +/// Done = 2, +/// }; +/// +/// // SAFETY: `State` and `i32` has the same size and alignment, and it's round-trip +/// // transmutable to `i32`. +/// unsafe impl AtomicType for State { +/// type Repr = i32; +/// } +/// +/// let s = Atomic::new(State::Uninit); +/// +/// assert_eq!(State::Uninit, s.load(Relaxed)); +/// ``` +/// [`transmute()`]: core::mem::transmute +/// [round-trip transmutable]: AtomicType#round-trip-transmutability +/// [Examples]: AtomicType#examples +pub unsafe trait AtomicType: Sized + Send + Copy { + /// The backing atomic implementation type. + type Repr: AtomicImpl; +} + +/// Types that support atomic add operations. +/// +/// # Safety +/// +// TODO: Properly defines `wrapping_add` in the following comment. +/// `wrapping_add` any value of type `Self::Repr::Delta` obtained by [`Self::rhs_into_delta()`] to +/// any value of type `Self::Repr` obtained through transmuting a value of type `Self` to must +/// yield a value with a bit pattern also valid for `Self`. +pub unsafe trait AtomicAdd<Rhs = Self>: AtomicType { + /// Converts `Rhs` into the `Delta` type of the atomic implementation. 
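+    ///
+    /// # Examples
+    ///
+    /// An illustrative sketch of implementing [`AtomicAdd`] (and [`AtomicType`]) for a
+    /// transparent newtype; the `Counter` type only exists in this example:
+    ///
+    /// ```
+    /// use kernel::sync::atomic::{Atomic, AtomicAdd, AtomicType, Relaxed};
+    ///
+    /// #[derive(Clone, Copy)]
+    /// #[repr(transparent)]
+    /// struct Counter(u32);
+    ///
+    /// // SAFETY: `Counter` is `#[repr(transparent)]` over `u32`, so it has the same size and
+    /// // alignment as `i32` and is round-trip transmutable to it.
+    /// unsafe impl AtomicType for Counter {
+    ///     type Repr = i32;
+    /// }
+    ///
+    /// // SAFETY: Wrapping-adding any `i32` delta to an `i32` yields a valid `i32`, and every
+    /// // `i32` bit pattern is a valid `Counter`.
+    /// unsafe impl AtomicAdd<u32> for Counter {
+    ///     fn rhs_into_delta(rhs: u32) -> i32 {
+    ///         rhs as i32
+    ///     }
+    /// }
+    ///
+    /// let c = Atomic::new(Counter(0));
+    ///
+    /// c.add(10, Relaxed);
+    /// assert_eq!(10, c.fetch_add(2, Relaxed).0);
+    /// assert_eq!(12, c.load(Relaxed).0);
+    /// ```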
+ fn rhs_into_delta(rhs: Rhs) -> <Self::Repr as AtomicImpl>::Delta; +} + +#[inline(always)] +const fn into_repr<T: AtomicType>(v: T) -> T::Repr { + // SAFETY: Per the safety requirement of `AtomicType`, `T` is round-trip transmutable to + // `T::Repr`, therefore the transmute operation is sound. + unsafe { core::mem::transmute_copy(&v) } +} + +/// # Safety +/// +/// `r` must be a valid bit pattern of `T`. +#[inline(always)] +const unsafe fn from_repr<T: AtomicType>(r: T::Repr) -> T { + // SAFETY: Per the safety requirement of the function, the transmute operation is sound. + unsafe { core::mem::transmute_copy(&r) } +} + +impl<T: AtomicType> Atomic<T> { + /// Creates a new atomic `T`. + pub const fn new(v: T) -> Self { + // INVARIANT: Per the safety requirement of `AtomicType`, `into_repr(v)` is a valid `T`. + Self(AtomicRepr::new(into_repr(v))) + } + + /// Creates a reference to an atomic `T` from a pointer of `T`. + /// + /// This usually is used when communicating with C side or manipulating a C struct, see + /// examples below. + /// + /// # Safety + /// + /// - `ptr` is aligned to `align_of::<T>()`. + /// - `ptr` is valid for reads and writes for `'a`. + /// - For the duration of `'a`, other accesses to `*ptr` must not cause data races (defined + /// by [`LKMM`]) against atomic operations on the returned reference. Note that if all other + /// accesses are atomic, then this safety requirement is trivially fulfilled. + /// + /// [`LKMM`]: srctree/tools/memory-model + /// + /// # Examples + /// + /// Using [`Atomic::from_ptr()`] combined with [`Atomic::load()`] or [`Atomic::store()`] can + /// achieve the same functionality as `READ_ONCE()`/`smp_load_acquire()` or + /// `WRITE_ONCE()`/`smp_store_release()` in C side: + /// + /// ``` + /// # use kernel::types::Opaque; + /// use kernel::sync::atomic::{Atomic, Relaxed, Release}; + /// + /// // Assume there is a C struct `foo`. + /// mod cbindings { + /// #[repr(C)] + /// pub(crate) struct foo { + /// pub(crate) a: i32, + /// pub(crate) b: i32 + /// } + /// } + /// + /// let tmp = Opaque::new(cbindings::foo { a: 1, b: 2 }); + /// + /// // struct foo *foo_ptr = ..; + /// let foo_ptr = tmp.get(); + /// + /// // SAFETY: `foo_ptr` is valid, and `.a` is in bounds. + /// let foo_a_ptr = unsafe { &raw mut (*foo_ptr).a }; + /// + /// // a = READ_ONCE(foo_ptr->a); + /// // + /// // SAFETY: `foo_a_ptr` is valid for read, and all other accesses on it is atomic, so no + /// // data race. + /// let a = unsafe { Atomic::from_ptr(foo_a_ptr) }.load(Relaxed); + /// # assert_eq!(a, 1); + /// + /// // smp_store_release(&foo_ptr->a, 2); + /// // + /// // SAFETY: `foo_a_ptr` is valid for writes, and all other accesses on it is atomic, so + /// // no data race. + /// unsafe { Atomic::from_ptr(foo_a_ptr) }.store(2, Release); + /// ``` + pub unsafe fn from_ptr<'a>(ptr: *mut T) -> &'a Self + where + T: Sync, + { + // CAST: `T` and `Atomic<T>` have the same size, alignment and bit validity. + // SAFETY: Per function safety requirement, `ptr` is a valid pointer and the object will + // live long enough. It's safe to return a `&Atomic<T>` because function safety requirement + // guarantees other accesses won't cause data races. + unsafe { &*ptr.cast::<Self>() } + } + + /// Returns a pointer to the underlying atomic `T`. + /// + /// Note that use of the return pointer must not cause data races defined by [`LKMM`]. + /// + /// # Guarantees + /// + /// The returned pointer is valid and properly aligned (i.e. aligned to [`align_of::<T>()`]). 
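+    ///
+    /// # Examples
+    ///
+    /// A small sketch: the returned pointer can be fed back through [`Atomic::from_ptr()`].
+    ///
+    /// ```
+    /// use kernel::sync::atomic::{Atomic, Relaxed};
+    ///
+    /// let x = Atomic::new(1i32);
+    ///
+    /// // SAFETY: `x.as_ptr()` is valid and properly aligned for as long as `x` lives, and all
+    /// // accesses to it here are atomic.
+    /// let y = unsafe { Atomic::from_ptr(x.as_ptr()) };
+    ///
+    /// y.store(2, Relaxed);
+    /// assert_eq!(2, x.load(Relaxed));
+    /// ```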
+ /// + /// [`LKMM`]: srctree/tools/memory-model + /// [`align_of::<T>()`]: core::mem::align_of + pub const fn as_ptr(&self) -> *mut T { + // GUARANTEE: Per the function guarantee of `AtomicRepr::as_ptr()`, the `self.0.as_ptr()` + // must be a valid and properly aligned pointer for `T::Repr`, and per the safety guarantee + // of `AtomicType`, it's a valid and properly aligned pointer of `T`. + self.0.as_ptr().cast() + } + + /// Returns a mutable reference to the underlying atomic `T`. + /// + /// This is safe because the mutable reference of the atomic `T` guarantees exclusive access. + pub fn get_mut(&mut self) -> &mut T { + // CAST: `T` and `T::Repr` has the same size and alignment per the safety requirement of + // `AtomicType`, and per the type invariants `self.0` is a valid `T`, therefore the casting + // result is a valid pointer of `T`. + // SAFETY: The pointer is valid per the CAST comment above, and the mutable reference + // guarantees exclusive access. + unsafe { &mut *self.0.as_ptr().cast() } + } +} + +impl<T: AtomicType> Atomic<T> +where + T::Repr: AtomicBasicOps, +{ + /// Loads the value from the atomic `T`. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Relaxed}; + /// + /// let x = Atomic::new(42i32); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// + /// let x = Atomic::new(42i64); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// ``` + #[doc(alias("atomic_read", "atomic64_read"))] + #[inline(always)] + pub fn load<Ordering: ordering::AcquireOrRelaxed>(&self, _: Ordering) -> T { + let v = { + match Ordering::TYPE { + OrderingType::Relaxed => T::Repr::atomic_read(&self.0), + OrderingType::Acquire => T::Repr::atomic_read_acquire(&self.0), + _ => build_error!("Wrong ordering"), + } + }; + + // SAFETY: `v` comes from reading `self.0`, which is a valid `T` per the type invariants. + unsafe { from_repr(v) } + } + + /// Stores a value to the atomic `T`. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Relaxed}; + /// + /// let x = Atomic::new(42i32); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// + /// x.store(43, Relaxed); + /// + /// assert_eq!(43, x.load(Relaxed)); + /// ``` + #[doc(alias("atomic_set", "atomic64_set"))] + #[inline(always)] + pub fn store<Ordering: ordering::ReleaseOrRelaxed>(&self, v: T, _: Ordering) { + let v = into_repr(v); + + // INVARIANT: `v` is a valid `T`, and is stored to `self.0` by `atomic_set*()`. + match Ordering::TYPE { + OrderingType::Relaxed => T::Repr::atomic_set(&self.0, v), + OrderingType::Release => T::Repr::atomic_set_release(&self.0, v), + _ => build_error!("Wrong ordering"), + } + } +} + +impl<T: AtomicType + core::fmt::Debug> core::fmt::Debug for Atomic<T> +where + T::Repr: AtomicBasicOps, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(&self.load(Relaxed), f) + } +} + +impl<T: AtomicType> Atomic<T> +where + T::Repr: AtomicExchangeOps, +{ + /// Atomic exchange. + /// + /// Atomically updates `*self` to `v` and returns the old value of `*self`. 
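+    ///
+    /// With [`Full`] ordering this maps to the C `atomic_xchg()`/`atomic64_xchg()`; the
+    /// `Ordering` type parameter selects the `_acquire`, `_release` or `_relaxed` variant in the
+    /// same way.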
+ /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Acquire, Relaxed}; + /// + /// let x = Atomic::new(42); + /// + /// assert_eq!(42, x.xchg(52, Acquire)); + /// assert_eq!(52, x.load(Relaxed)); + /// ``` + #[doc(alias("atomic_xchg", "atomic64_xchg", "swap"))] + #[inline(always)] + pub fn xchg<Ordering: ordering::Ordering>(&self, v: T, _: Ordering) -> T { + let v = into_repr(v); + + // INVARIANT: `self.0` is a valid `T` after `atomic_xchg*()` because `v` is transmutable to + // `T`. + let ret = { + match Ordering::TYPE { + OrderingType::Full => T::Repr::atomic_xchg(&self.0, v), + OrderingType::Acquire => T::Repr::atomic_xchg_acquire(&self.0, v), + OrderingType::Release => T::Repr::atomic_xchg_release(&self.0, v), + OrderingType::Relaxed => T::Repr::atomic_xchg_relaxed(&self.0, v), + } + }; + + // SAFETY: `ret` comes from reading `*self`, which is a valid `T` per type invariants. + unsafe { from_repr(ret) } + } + + /// Atomic compare and exchange. + /// + /// If `*self` == `old`, atomically updates `*self` to `new`. Otherwise, `*self` is not + /// modified. + /// + /// Compare: The comparison is done via the byte level comparison between `*self` and `old`. + /// + /// Ordering: When succeeds, provides the corresponding ordering as the `Ordering` type + /// parameter indicates, and a failed one doesn't provide any ordering, the load part of a + /// failed cmpxchg is a [`Relaxed`] load. + /// + /// Returns `Ok(value)` if cmpxchg succeeds, and `value` is guaranteed to be equal to `old`, + /// otherwise returns `Err(value)`, and `value` is the current value of `*self`. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Full, Relaxed}; + /// + /// let x = Atomic::new(42); + /// + /// // Checks whether cmpxchg succeeded. + /// let success = x.cmpxchg(52, 64, Relaxed).is_ok(); + /// # assert!(!success); + /// + /// // Checks whether cmpxchg failed. + /// let failure = x.cmpxchg(52, 64, Relaxed).is_err(); + /// # assert!(failure); + /// + /// // Uses the old value if failed, probably re-try cmpxchg. + /// match x.cmpxchg(52, 64, Relaxed) { + /// Ok(_) => { }, + /// Err(old) => { + /// // do something with `old`. + /// # assert_eq!(old, 42); + /// } + /// } + /// + /// // Uses the latest value regardlessly, same as atomic_cmpxchg() in C. + /// let latest = x.cmpxchg(42, 64, Full).unwrap_or_else(|old| old); + /// # assert_eq!(42, latest); + /// assert_eq!(64, x.load(Relaxed)); + /// ``` + /// + /// [`Relaxed`]: ordering::Relaxed + #[doc(alias( + "atomic_cmpxchg", + "atomic64_cmpxchg", + "atomic_try_cmpxchg", + "atomic64_try_cmpxchg", + "compare_exchange" + ))] + #[inline(always)] + pub fn cmpxchg<Ordering: ordering::Ordering>( + &self, + mut old: T, + new: T, + o: Ordering, + ) -> Result<T, T> { + // Note on code generation: + // + // try_cmpxchg() is used to implement cmpxchg(), and if the helper functions are inlined, + // the compiler is able to figure out that branch is not needed if the users don't care + // about whether the operation succeeds or not. One exception is on x86, due to commit + // 44fe84459faf ("locking/atomic: Fix atomic_try_cmpxchg() semantics"), the + // atomic_try_cmpxchg() on x86 has a branch even if the caller doesn't care about the + // success of cmpxchg and only wants to use the old value. 
For example, for code like: + // + // let latest = x.cmpxchg(42, 64, Full).unwrap_or_else(|old| old); + // + // It will still generate code: + // + // movl $0x40, %ecx + // movl $0x34, %eax + // lock + // cmpxchgl %ecx, 0x4(%rsp) + // jne 1f + // 2: + // ... + // 1: movl %eax, %ecx + // jmp 2b + // + // This might be "fixed" by introducing a try_cmpxchg_exclusive() that knows the "*old" + // location in the C function is always safe to write. + if self.try_cmpxchg(&mut old, new, o) { + Ok(old) + } else { + Err(old) + } + } + + /// Atomic compare and exchange and returns whether the operation succeeds. + /// + /// If `*self` == `old`, atomically updates `*self` to `new`. Otherwise, `*self` is not + /// modified, `*old` is updated to the current value of `*self`. + /// + /// "Compare" and "Ordering" part are the same as [`Atomic::cmpxchg()`]. + /// + /// Returns `true` means the cmpxchg succeeds otherwise returns `false`. + #[inline(always)] + fn try_cmpxchg<Ordering: ordering::Ordering>(&self, old: &mut T, new: T, _: Ordering) -> bool { + let mut tmp = into_repr(*old); + let new = into_repr(new); + + // INVARIANT: `self.0` is a valid `T` after `atomic_try_cmpxchg*()` because `new` is + // transmutable to `T`. + let ret = { + match Ordering::TYPE { + OrderingType::Full => T::Repr::atomic_try_cmpxchg(&self.0, &mut tmp, new), + OrderingType::Acquire => { + T::Repr::atomic_try_cmpxchg_acquire(&self.0, &mut tmp, new) + } + OrderingType::Release => { + T::Repr::atomic_try_cmpxchg_release(&self.0, &mut tmp, new) + } + OrderingType::Relaxed => { + T::Repr::atomic_try_cmpxchg_relaxed(&self.0, &mut tmp, new) + } + } + }; + + // SAFETY: `tmp` comes from reading `*self`, which is a valid `T` per type invariants. + *old = unsafe { from_repr(tmp) }; + + ret + } +} + +impl<T: AtomicType> Atomic<T> +where + T::Repr: AtomicArithmeticOps, +{ + /// Atomic add. + /// + /// Atomically updates `*self` to `(*self).wrapping_add(v)`. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Relaxed}; + /// + /// let x = Atomic::new(42); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// + /// x.add(12, Relaxed); + /// + /// assert_eq!(54, x.load(Relaxed)); + /// ``` + #[inline(always)] + pub fn add<Rhs>(&self, v: Rhs, _: ordering::Relaxed) + where + T: AtomicAdd<Rhs>, + { + let v = T::rhs_into_delta(v); + + // INVARIANT: `self.0` is a valid `T` after `atomic_add()` due to safety requirement of + // `AtomicAdd`. + T::Repr::atomic_add(&self.0, v); + } + + /// Atomic fetch and add. + /// + /// Atomically updates `*self` to `(*self).wrapping_add(v)`, and returns the value of `*self` + /// before the update. + /// + /// # Examples + /// + /// ``` + /// use kernel::sync::atomic::{Atomic, Acquire, Full, Relaxed}; + /// + /// let x = Atomic::new(42); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// + /// assert_eq!(54, { x.fetch_add(12, Acquire); x.load(Relaxed) }); + /// + /// let x = Atomic::new(42); + /// + /// assert_eq!(42, x.load(Relaxed)); + /// + /// assert_eq!(54, { x.fetch_add(12, Full); x.load(Relaxed) } ); + /// ``` + #[inline(always)] + pub fn fetch_add<Rhs, Ordering: ordering::Ordering>(&self, v: Rhs, _: Ordering) -> T + where + T: AtomicAdd<Rhs>, + { + let v = T::rhs_into_delta(v); + + // INVARIANT: `self.0` is a valid `T` after `atomic_fetch_add*()` due to safety requirement + // of `AtomicAdd`. 
+ let ret = { + match Ordering::TYPE { + OrderingType::Full => T::Repr::atomic_fetch_add(&self.0, v), + OrderingType::Acquire => T::Repr::atomic_fetch_add_acquire(&self.0, v), + OrderingType::Release => T::Repr::atomic_fetch_add_release(&self.0, v), + OrderingType::Relaxed => T::Repr::atomic_fetch_add_relaxed(&self.0, v), + } + }; + + // SAFETY: `ret` comes from reading `self.0`, which is a valid `T` per type invariants. + unsafe { from_repr(ret) } + } +} diff --git a/rust/kernel/sync/atomic/internal.rs b/rust/kernel/sync/atomic/internal.rs new file mode 100644 index 000000000000..6fdd8e59f45b --- /dev/null +++ b/rust/kernel/sync/atomic/internal.rs @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Atomic internal implementations. +//! +//! Provides 1:1 mapping to the C atomic operations. + +use crate::bindings; +use crate::macros::paste; +use core::cell::UnsafeCell; + +mod private { + /// Sealed trait marker to disable customized impls on atomic implementation traits. + pub trait Sealed {} +} + +// `i32` and `i64` are only supported atomic implementations. +impl private::Sealed for i32 {} +impl private::Sealed for i64 {} + +/// A marker trait for types that implement atomic operations with C side primitives. +/// +/// This trait is sealed, and only types that have directly mapping to the C side atomics should +/// impl this: +/// +/// - `i32` maps to `atomic_t`. +/// - `i64` maps to `atomic64_t`. +pub trait AtomicImpl: Sized + Send + Copy + private::Sealed { + /// The type of the delta in arithmetic or logical operations. + /// + /// For example, in `atomic_add(ptr, v)`, it's the type of `v`. Usually it's the same type of + /// [`Self`], but it may be different for the atomic pointer type. + type Delta; +} + +// `atomic_t` implements atomic operations on `i32`. +impl AtomicImpl for i32 { + type Delta = Self; +} + +// `atomic64_t` implements atomic operations on `i64`. +impl AtomicImpl for i64 { + type Delta = Self; +} + +/// Atomic representation. +#[repr(transparent)] +pub struct AtomicRepr<T: AtomicImpl>(UnsafeCell<T>); + +impl<T: AtomicImpl> AtomicRepr<T> { + /// Creates a new atomic representation `T`. + pub const fn new(v: T) -> Self { + Self(UnsafeCell::new(v)) + } + + /// Returns a pointer to the underlying `T`. + /// + /// # Guarantees + /// + /// The returned pointer is valid and properly aligned (i.e. aligned to [`align_of::<T>()`]). + pub const fn as_ptr(&self) -> *mut T { + // GUARANTEE: `self.0` is an `UnsafeCell<T>`, therefore the pointer returned by `.get()` + // must be valid and properly aligned. + self.0.get() + } +} + +// This macro generates the function signature with given argument list and return type. +macro_rules! declare_atomic_method { + ( + $(#[doc=$doc:expr])* + $func:ident($($arg:ident : $arg_type:ty),*) $(-> $ret:ty)? + ) => { + paste!( + $(#[doc = $doc])* + fn [< atomic_ $func >]($($arg: $arg_type,)*) $(-> $ret)?; + ); + }; + ( + $(#[doc=$doc:expr])* + $func:ident [$variant:ident $($rest:ident)*]($($arg_sig:tt)*) $(-> $ret:ty)? + ) => { + paste!( + declare_atomic_method!( + $(#[doc = $doc])* + [< $func _ $variant >]($($arg_sig)*) $(-> $ret)? + ); + ); + + declare_atomic_method!( + $(#[doc = $doc])* + $func [$($rest)*]($($arg_sig)*) $(-> $ret)? + ); + }; + ( + $(#[doc=$doc:expr])* + $func:ident []($($arg_sig:tt)*) $(-> $ret:ty)? + ) => { + declare_atomic_method!( + $(#[doc = $doc])* + $func($($arg_sig)*) $(-> $ret)? 
+ ); + } +} + +// This macro generates the function implementation with given argument list and return type, and it +// will replace "call(...)" expression with "$ctype _ $func" to call the real C function. +macro_rules! impl_atomic_method { + ( + ($ctype:ident) $func:ident($($arg:ident: $arg_type:ty),*) $(-> $ret:ty)? { + $unsafe:tt { call($($c_arg:expr),*) } + } + ) => { + paste!( + #[inline(always)] + fn [< atomic_ $func >]($($arg: $arg_type,)*) $(-> $ret)? { + // TODO: Ideally we want to use the SAFETY comments written at the macro invocation + // (e.g. in `declare_and_impl_atomic_methods!()`, however, since SAFETY comments + // are just comments, and they are not passed to macros as tokens, therefore we + // cannot use them here. One potential improvement is that if we support using + // attributes as an alternative for SAFETY comments, then we can use that for macro + // generating code. + // + // SAFETY: specified on macro invocation. + $unsafe { bindings::[< $ctype _ $func >]($($c_arg,)*) } + } + ); + }; + ( + ($ctype:ident) $func:ident[$variant:ident $($rest:ident)*]($($arg_sig:tt)*) $(-> $ret:ty)? { + $unsafe:tt { call($($arg:tt)*) } + } + ) => { + paste!( + impl_atomic_method!( + ($ctype) [< $func _ $variant >]($($arg_sig)*) $( -> $ret)? { + $unsafe { call($($arg)*) } + } + ); + ); + impl_atomic_method!( + ($ctype) $func [$($rest)*]($($arg_sig)*) $( -> $ret)? { + $unsafe { call($($arg)*) } + } + ); + }; + ( + ($ctype:ident) $func:ident[]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { call($($arg:tt)*) } + } + ) => { + impl_atomic_method!( + ($ctype) $func($($arg_sig)*) $(-> $ret)? { + $unsafe { call($($arg)*) } + } + ); + } +} + +// Delcares $ops trait with methods and implements the trait for `i32` and `i64`. +macro_rules! declare_and_impl_atomic_methods { + ($(#[$attr:meta])* $pub:vis trait $ops:ident { + $( + $(#[doc=$doc:expr])* + fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { bindings::#call($($arg:tt)*) } + } + )* + }) => { + $(#[$attr])* + $pub trait $ops: AtomicImpl { + $( + declare_atomic_method!( + $(#[doc=$doc])* + $func[$($variant)*]($($arg_sig)*) $(-> $ret)? + ); + )* + } + + impl $ops for i32 { + $( + impl_atomic_method!( + (atomic) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { + $unsafe { call($($arg)*) } + } + ); + )* + } + + impl $ops for i64 { + $( + impl_atomic_method!( + (atomic64) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { + $unsafe { call($($arg)*) } + } + ); + )* + } + } +} + +declare_and_impl_atomic_methods!( + /// Basic atomic operations + pub trait AtomicBasicOps { + /// Atomic read (load). + fn read[acquire](a: &AtomicRepr<Self>) -> Self { + // SAFETY: `a.as_ptr()` is valid and properly aligned. + unsafe { bindings::#call(a.as_ptr().cast()) } + } + + /// Atomic set (store). + fn set[release](a: &AtomicRepr<Self>, v: Self) { + // SAFETY: `a.as_ptr()` is valid and properly aligned. + unsafe { bindings::#call(a.as_ptr().cast(), v) } + } + } +); + +declare_and_impl_atomic_methods!( + /// Exchange and compare-and-exchange atomic operations + pub trait AtomicExchangeOps { + /// Atomic exchange. + /// + /// Atomically updates `*a` to `v` and returns the old value. + fn xchg[acquire, release, relaxed](a: &AtomicRepr<Self>, v: Self) -> Self { + // SAFETY: `a.as_ptr()` is valid and properly aligned. + unsafe { bindings::#call(a.as_ptr().cast(), v) } + } + + /// Atomic compare and exchange. + /// + /// If `*a` == `*old`, atomically updates `*a` to `new`. 
Otherwise, `*a` is not + /// modified, `*old` is updated to the current value of `*a`. + /// + /// Return `true` if the update of `*a` occurred, `false` otherwise. + fn try_cmpxchg[acquire, release, relaxed]( + a: &AtomicRepr<Self>, old: &mut Self, new: Self + ) -> bool { + // SAFETY: `a.as_ptr()` is valid and properly aligned. `core::ptr::from_mut(old)` + // is valid and properly aligned. + unsafe { bindings::#call(a.as_ptr().cast(), core::ptr::from_mut(old), new) } + } + } +); + +declare_and_impl_atomic_methods!( + /// Atomic arithmetic operations + pub trait AtomicArithmeticOps { + /// Atomic add (wrapping). + /// + /// Atomically updates `*a` to `(*a).wrapping_add(v)`. + fn add[](a: &AtomicRepr<Self>, v: Self::Delta) { + // SAFETY: `a.as_ptr()` is valid and properly aligned. + unsafe { bindings::#call(v, a.as_ptr().cast()) } + } + + /// Atomic fetch and add (wrapping). + /// + /// Atomically updates `*a` to `(*a).wrapping_add(v)`, and returns the value of `*a` + /// before the update. + fn fetch_add[acquire, release, relaxed](a: &AtomicRepr<Self>, v: Self::Delta) -> Self { + // SAFETY: `a.as_ptr()` is valid and properly aligned. + unsafe { bindings::#call(v, a.as_ptr().cast()) } + } + } +); diff --git a/rust/kernel/sync/atomic/ordering.rs b/rust/kernel/sync/atomic/ordering.rs new file mode 100644 index 000000000000..3f103aa8db99 --- /dev/null +++ b/rust/kernel/sync/atomic/ordering.rs @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Memory orderings. +//! +//! The semantics of these orderings follows the [`LKMM`] definitions and rules. +//! +//! - [`Acquire`] provides ordering between the load part of the annotated operation and all the +//! following memory accesses, and if there is a store part, the store part has the [`Relaxed`] +//! ordering. +//! - [`Release`] provides ordering between all the preceding memory accesses and the store part of +//! the annotated operation, and if there is a load part, the load part has the [`Relaxed`] +//! ordering. +//! - [`Full`] means "fully-ordered", that is: +//! - It provides ordering between all the preceding memory accesses and the annotated operation. +//! - It provides ordering between the annotated operation and all the following memory accesses. +//! - It provides ordering between all the preceding memory accesses and all the following memory +//! accesses. +//! - All the orderings are the same strength as a full memory barrier (i.e. `smp_mb()`). +//! - [`Relaxed`] provides no ordering except the dependency orderings. Dependency orderings are +//! described in "DEPENDENCY RELATIONS" in [`LKMM`]'s [`explanation`]. +//! +//! [`LKMM`]: srctree/tools/memory-model/ +//! [`explanation`]: srctree/tools/memory-model/Documentation/explanation.txt + +/// The annotation type for relaxed memory ordering, for the description of relaxed memory +/// ordering, see [module-level documentation]. +/// +/// [module-level documentation]: crate::sync::atomic::ordering +pub struct Relaxed; + +/// The annotation type for acquire memory ordering, for the description of acquire memory +/// ordering, see [module-level documentation]. +/// +/// [module-level documentation]: crate::sync::atomic::ordering +pub struct Acquire; + +/// The annotation type for release memory ordering, for the description of release memory +/// ordering, see [module-level documentation]. 
+///
+/// [module-level documentation]: crate::sync::atomic::ordering
+pub struct Release;
+
+/// The annotation type for fully-ordered memory ordering, for the description of fully-ordered
+/// memory ordering, see [module-level documentation].
+///
+/// [module-level documentation]: crate::sync::atomic::ordering
+pub struct Full;
+
+/// Describes the exact memory ordering.
+#[doc(hidden)]
+pub enum OrderingType {
+    /// Relaxed ordering.
+    Relaxed,
+    /// Acquire ordering.
+    Acquire,
+    /// Release ordering.
+    Release,
+    /// Fully-ordered.
+    Full,
+}
+
+mod internal {
+    /// Sealed trait; it can only be implemented inside the atomic module.
+    pub trait Sealed {}
+
+    impl Sealed for super::Relaxed {}
+    impl Sealed for super::Acquire {}
+    impl Sealed for super::Release {}
+    impl Sealed for super::Full {}
+}
+
+/// The trait bound for annotating operations that support any ordering.
+pub trait Ordering: internal::Sealed {
+    /// Describes the exact memory ordering.
+    const TYPE: OrderingType;
+}
+
+impl Ordering for Relaxed {
+    const TYPE: OrderingType = OrderingType::Relaxed;
+}
+
+impl Ordering for Acquire {
+    const TYPE: OrderingType = OrderingType::Acquire;
+}
+
+impl Ordering for Release {
+    const TYPE: OrderingType = OrderingType::Release;
+}
+
+impl Ordering for Full {
+    const TYPE: OrderingType = OrderingType::Full;
+}
+
+/// The trait bound for operations that only support acquire or relaxed ordering.
+pub trait AcquireOrRelaxed: Ordering {}
+
+impl AcquireOrRelaxed for Acquire {}
+impl AcquireOrRelaxed for Relaxed {}
+
+/// The trait bound for operations that only support release or relaxed ordering.
+pub trait ReleaseOrRelaxed: Ordering {}
+
+impl ReleaseOrRelaxed for Release {}
+impl ReleaseOrRelaxed for Relaxed {}
diff --git a/rust/kernel/sync/atomic/predefine.rs b/rust/kernel/sync/atomic/predefine.rs
new file mode 100644
index 000000000000..45a17985cda4
--- /dev/null
+++ b/rust/kernel/sync/atomic/predefine.rs
@@ -0,0 +1,169 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Pre-defined atomic types.
+
+use crate::static_assert;
+use core::mem::{align_of, size_of};
+
+// SAFETY: `i32` has the same size and alignment as itself, and is round-trip transmutable to
+// itself.
+unsafe impl super::AtomicType for i32 {
+    type Repr = i32;
+}
+
+// SAFETY: The wrapping add result of two `i32`s is a valid `i32`.
+unsafe impl super::AtomicAdd<i32> for i32 {
+    fn rhs_into_delta(rhs: i32) -> i32 {
+        rhs
+    }
+}
+
+// SAFETY: `i64` has the same size and alignment as itself, and is round-trip transmutable to
+// itself.
+unsafe impl super::AtomicType for i64 {
+    type Repr = i64;
+}
+
+// SAFETY: The wrapping add result of two `i64`s is a valid `i64`.
+unsafe impl super::AtomicAdd<i64> for i64 {
+    fn rhs_into_delta(rhs: i64) -> i64 {
+        rhs
+    }
+}
+
+// Defines an internal type that always maps to the integer type which has the same size and
+// alignment as `isize` and `usize`; `isize` and `usize` are always bi-directionally transmutable
+// to `isize_atomic_repr`, which also always implements `AtomicImpl`.
+#[allow(non_camel_case_types)]
+#[cfg(not(CONFIG_64BIT))]
+type isize_atomic_repr = i32;
+#[allow(non_camel_case_types)]
+#[cfg(CONFIG_64BIT)]
+type isize_atomic_repr = i64;
+
+// Ensure size and alignment requirements are checked.
+static_assert!(size_of::<isize>() == size_of::<isize_atomic_repr>());
+static_assert!(align_of::<isize>() == align_of::<isize_atomic_repr>());
+static_assert!(size_of::<usize>() == size_of::<isize_atomic_repr>());
+static_assert!(align_of::<usize>() == align_of::<isize_atomic_repr>());
+
+// SAFETY: `isize` has the same size and alignment as `isize_atomic_repr`, and is round-trip
+// transmutable to `isize_atomic_repr`.
+unsafe impl super::AtomicType for isize {
+    type Repr = isize_atomic_repr;
+}
+
+// SAFETY: The wrapping add result of two `isize_atomic_repr`s is a valid `isize`.
+unsafe impl super::AtomicAdd<isize> for isize {
+    fn rhs_into_delta(rhs: isize) -> isize_atomic_repr {
+        rhs as isize_atomic_repr
+    }
+}
+
+// SAFETY: `u32` and `i32` have the same size and alignment, and `u32` is round-trip transmutable
+// to `i32`.
+unsafe impl super::AtomicType for u32 {
+    type Repr = i32;
+}
+
+// SAFETY: The wrapping add result of two `i32`s is a valid `u32`.
+unsafe impl super::AtomicAdd<u32> for u32 {
+    fn rhs_into_delta(rhs: u32) -> i32 {
+        rhs as i32
+    }
+}
+
+// SAFETY: `u64` and `i64` have the same size and alignment, and `u64` is round-trip transmutable
+// to `i64`.
+unsafe impl super::AtomicType for u64 {
+    type Repr = i64;
+}
+
+// SAFETY: The wrapping add result of two `i64`s is a valid `u64`.
+unsafe impl super::AtomicAdd<u64> for u64 {
+    fn rhs_into_delta(rhs: u64) -> i64 {
+        rhs as i64
+    }
+}
+
+// SAFETY: `usize` has the same size and alignment as `isize_atomic_repr`, and is round-trip
+// transmutable to `isize_atomic_repr`.
+unsafe impl super::AtomicType for usize {
+    type Repr = isize_atomic_repr;
+}
+
+// SAFETY: The wrapping add result of two `isize_atomic_repr`s is a valid `usize`.
+unsafe impl super::AtomicAdd<usize> for usize {
+    fn rhs_into_delta(rhs: usize) -> isize_atomic_repr {
+        rhs as isize_atomic_repr
+    }
+}
+
+use crate::macros::kunit_tests;
+
+#[kunit_tests(rust_atomics)]
+mod tests {
+    use super::super::*;
+
+    // Calls `$fn($val)` once for each `$type`, with `$val` typed as that `$type`.
+    macro_rules! for_each_type {
+        ($val:literal in [$($type:ty),*] $fn:expr) => {
+            $({
+                let v: $type = $val;
+
+                $fn(v);
+            })*
+        }
+    }
+
+    #[test]
+    fn atomic_basic_tests() {
+        for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| {
+            let x = Atomic::new(v);
+
+            assert_eq!(v, x.load(Relaxed));
+        });
+    }
+
+    #[test]
+    fn atomic_xchg_tests() {
+        for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| {
+            let x = Atomic::new(v);
+
+            let old = v;
+            let new = v + 1;
+
+            assert_eq!(old, x.xchg(new, Full));
+            assert_eq!(new, x.load(Relaxed));
+        });
+    }
+
+    #[test]
+    fn atomic_cmpxchg_tests() {
+        for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| {
+            let x = Atomic::new(v);
+
+            let old = v;
+            let new = v + 1;
+
+            assert_eq!(Err(old), x.cmpxchg(new, new, Full));
+            assert_eq!(old, x.load(Relaxed));
+            assert_eq!(Ok(old), x.cmpxchg(old, new, Relaxed));
+            assert_eq!(new, x.load(Relaxed));
+        });
+    }
+
+    #[test]
+    fn atomic_arithmetic_tests() {
+        for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| {
+            let x = Atomic::new(v);
+
+            assert_eq!(v, x.fetch_add(12, Full));
+            assert_eq!(v + 12, x.load(Relaxed));
+
+            x.add(13, Relaxed);
+
+            assert_eq!(v + 25, x.load(Relaxed));
+        });
+    }
+}
diff --git a/rust/kernel/sync/barrier.rs b/rust/kernel/sync/barrier.rs
new file mode 100644
index 000000000000..8f2d435fcd94
--- /dev/null
+++ b/rust/kernel/sync/barrier.rs
@@ -0,0 +1,61 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Memory barriers.
+//!
+//! These primitives have the same semantics as their C counterparts; the precise definitions of
+//! their semantics can be found in [`LKMM`].
+//!
+//! [`LKMM`]: srctree/tools/memory-model/
+
+/// A compiler barrier.
+///
+/// A barrier that prevents the compiler from reordering memory accesses across it.
+#[inline(always)]
+pub(crate) fn barrier() {
+    // By default, Rust inline asms are treated as being able to access any memory or flags, hence
+    // an empty one suffices as a compiler barrier.
+    //
+    // SAFETY: An empty asm block.
+    unsafe { core::arch::asm!("") };
+}
+
+/// A full memory barrier.
+///
+/// A barrier that prevents the compiler and the CPU from reordering memory accesses across it.
+#[inline(always)]
+pub fn smp_mb() {
+    if cfg!(CONFIG_SMP) {
+        // SAFETY: `smp_mb()` is safe to call.
+        unsafe { bindings::smp_mb() };
+    } else {
+        barrier();
+    }
+}
+
+/// A write-write memory barrier.
+///
+/// A barrier that prevents the compiler and the CPU from reordering memory write accesses across
+/// the barrier.
+#[inline(always)]
+pub fn smp_wmb() {
+    if cfg!(CONFIG_SMP) {
+        // SAFETY: `smp_wmb()` is safe to call.
+        unsafe { bindings::smp_wmb() };
+    } else {
+        barrier();
+    }
+}
+
+/// A read-read memory barrier.
+///
+/// A barrier that prevents the compiler and the CPU from reordering memory read accesses across
+/// the barrier.
+#[inline(always)]
+pub fn smp_rmb() {
+    if cfg!(CONFIG_SMP) {
+        // SAFETY: `smp_rmb()` is safe to call.
+        unsafe { bindings::smp_rmb() };
+    } else {
+        barrier();
+    }
+}
diff --git a/rust/kernel/sync/completion.rs b/rust/kernel/sync/completion.rs
new file mode 100644
index 000000000000..c50012a940a3
--- /dev/null
+++ b/rust/kernel/sync/completion.rs
@@ -0,0 +1,112 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Completion support.
+//!
+//! Reference: <https://docs.kernel.org/scheduler/completion.html>
+//!
+//! C header: [`include/linux/completion.h`](srctree/include/linux/completion.h)
+
+use crate::{bindings, prelude::*, types::Opaque};
+
+/// Synchronization primitive to signal when a certain task has been completed.
+///
+/// The [`Completion`] synchronization primitive signals when a certain task has been completed by
+/// waking up other tasks that have been queued up to wait for the [`Completion`] to be completed.
+///
+/// # Examples
+///
+/// ```
+/// use kernel::sync::{Arc, Completion};
+/// use kernel::workqueue::{self, impl_has_work, new_work, Work, WorkItem};
+///
+/// #[pin_data]
+/// struct MyTask {
+///     #[pin]
+///     work: Work<MyTask>,
+///     #[pin]
+///     done: Completion,
+/// }
+///
+/// impl_has_work! {
+///     impl HasWork<Self> for MyTask { self.work }
+/// }
+///
+/// impl MyTask {
+///     fn new() -> Result<Arc<Self>> {
+///         let this = Arc::pin_init(pin_init!(MyTask {
+///             work <- new_work!("MyTask::work"),
+///             done <- Completion::new(),
+///         }), GFP_KERNEL)?;
+///
+///         let _ = workqueue::system().enqueue(this.clone());
+///
+///         Ok(this)
+///     }
+///
+///     fn wait_for_completion(&self) {
+///         self.done.wait_for_completion();
+///
+///         pr_info!("Completion: task complete\n");
+///     }
+/// }
+///
+/// impl WorkItem for MyTask {
+///     type Pointer = Arc<MyTask>;
+///
+///     fn run(this: Arc<MyTask>) {
+///         // process this task
+///         this.done.complete_all();
+///     }
+/// }
+///
+/// let task = MyTask::new()?;
+/// task.wait_for_completion();
+/// # Ok::<(), Error>(())
+/// ```
+#[pin_data]
+pub struct Completion {
+    #[pin]
+    inner: Opaque<bindings::completion>,
+}
+
+// SAFETY: `Completion` is safe to be sent to any task.
+unsafe impl Send for Completion {} + +// SAFETY: `Completion` is safe to be accessed concurrently. +unsafe impl Sync for Completion {} + +impl Completion { + /// Create an initializer for a new [`Completion`]. + pub fn new() -> impl PinInit<Self> { + pin_init!(Self { + inner <- Opaque::ffi_init(|slot: *mut bindings::completion| { + // SAFETY: `slot` is a valid pointer to an uninitialized `struct completion`. + unsafe { bindings::init_completion(slot) }; + }), + }) + } + + fn as_raw(&self) -> *mut bindings::completion { + self.inner.get() + } + + /// Signal all tasks waiting on this completion. + /// + /// This method wakes up all tasks waiting on this completion; after this operation the + /// completion is permanently done, i.e. signals all current and future waiters. + pub fn complete_all(&self) { + // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. + unsafe { bindings::complete_all(self.as_raw()) }; + } + + /// Wait for completion of a task. + /// + /// This method waits for the completion of a task; it is not interruptible and there is no + /// timeout. + /// + /// See also [`Completion::complete_all`]. + pub fn wait_for_completion(&self) { + // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. + unsafe { bindings::wait_for_completion(self.as_raw()) }; + } +} diff --git a/rust/kernel/sync/condvar.rs b/rust/kernel/sync/condvar.rs new file mode 100644 index 000000000000..69d58dfbad7b --- /dev/null +++ b/rust/kernel/sync/condvar.rs @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A condition variable. +//! +//! This module allows Rust code to use the kernel's [`struct wait_queue_head`] as a condition +//! variable. + +use super::{lock::Backend, lock::Guard, LockClassKey}; +use crate::{ + ffi::{c_int, c_long}, + str::{CStr, CStrExt as _}, + task::{ + MAX_SCHEDULE_TIMEOUT, TASK_FREEZABLE, TASK_INTERRUPTIBLE, TASK_NORMAL, TASK_UNINTERRUPTIBLE, + }, + time::Jiffies, + types::Opaque, +}; +use core::{marker::PhantomPinned, pin::Pin, ptr}; +use pin_init::{pin_data, pin_init, PinInit}; + +/// Creates a [`CondVar`] initialiser with the given name and a newly-created lock class. +#[macro_export] +macro_rules! new_condvar { + ($($name:literal)?) => { + $crate::sync::CondVar::new($crate::optional_name!($($name)?), $crate::static_lock_class!()) + }; +} +pub use new_condvar; + +/// A conditional variable. +/// +/// Exposes the kernel's [`struct wait_queue_head`] as a condition variable. It allows the caller to +/// atomically release the given lock and go to sleep. It reacquires the lock when it wakes up. And +/// it wakes up when notified by another thread (via [`CondVar::notify_one`] or +/// [`CondVar::notify_all`]) or because the thread received a signal. It may also wake up +/// spuriously. +/// +/// Instances of [`CondVar`] need a lock class and to be pinned. The recommended way to create such +/// instances is with the [`pin_init`](pin_init::pin_init!) and [`new_condvar`] macros. +/// +/// # Examples +/// +/// The following is an example of using a condvar with a mutex: +/// +/// ``` +/// use kernel::sync::{new_condvar, new_mutex, CondVar, Mutex}; +/// +/// #[pin_data] +/// pub struct Example { +/// #[pin] +/// value: Mutex<u32>, +/// +/// #[pin] +/// value_changed: CondVar, +/// } +/// +/// /// Waits for `e.value` to become `v`. 
+/// fn wait_for_value(e: &Example, v: u32) { +/// let mut guard = e.value.lock(); +/// while *guard != v { +/// e.value_changed.wait(&mut guard); +/// } +/// } +/// +/// /// Increments `e.value` and notifies all potential waiters. +/// fn increment(e: &Example) { +/// *e.value.lock() += 1; +/// e.value_changed.notify_all(); +/// } +/// +/// /// Allocates a new boxed `Example`. +/// fn new_example() -> Result<Pin<KBox<Example>>> { +/// KBox::pin_init(pin_init!(Example { +/// value <- new_mutex!(0), +/// value_changed <- new_condvar!(), +/// }), GFP_KERNEL) +/// } +/// ``` +/// +/// [`struct wait_queue_head`]: srctree/include/linux/wait.h +#[pin_data] +pub struct CondVar { + #[pin] + pub(crate) wait_queue_head: Opaque<bindings::wait_queue_head>, + + /// A condvar needs to be pinned because it contains a [`struct list_head`] that is + /// self-referential, so it cannot be safely moved once it is initialised. + /// + /// [`struct list_head`]: srctree/include/linux/types.h + #[pin] + _pin: PhantomPinned, +} + +// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on any thread. +unsafe impl Send for CondVar {} + +// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on multiple threads +// concurrently. +unsafe impl Sync for CondVar {} + +impl CondVar { + /// Constructs a new condvar initialiser. + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { + pin_init!(Self { + _pin: PhantomPinned, + // SAFETY: `slot` is valid while the closure is called and both `name` and `key` have + // static lifetimes so they live indefinitely. + wait_queue_head <- Opaque::ffi_init(|slot| unsafe { + bindings::__init_waitqueue_head(slot, name.as_char_ptr(), key.as_ptr()) + }), + }) + } + + fn wait_internal<T: ?Sized, B: Backend>( + &self, + wait_state: c_int, + guard: &mut Guard<'_, T, B>, + timeout_in_jiffies: c_long, + ) -> c_long { + let wait = Opaque::<bindings::wait_queue_entry>::uninit(); + + // SAFETY: `wait` points to valid memory. + unsafe { bindings::init_wait(wait.get()) }; + + // SAFETY: Both `wait` and `wait_queue_head` point to valid memory. + unsafe { + bindings::prepare_to_wait_exclusive(self.wait_queue_head.get(), wait.get(), wait_state) + }; + + // SAFETY: Switches to another thread. The timeout can be any number. + let ret = guard.do_unlocked(|| unsafe { bindings::schedule_timeout(timeout_in_jiffies) }); + + // SAFETY: Both `wait` and `wait_queue_head` point to valid memory. + unsafe { bindings::finish_wait(self.wait_queue_head.get(), wait.get()) }; + + ret + } + + /// Releases the lock and waits for a notification in uninterruptible mode. + /// + /// Atomically releases the given lock (whose ownership is proven by the guard) and puts the + /// thread to sleep, reacquiring the lock on wake up. It wakes up when notified by + /// [`CondVar::notify_one`] or [`CondVar::notify_all`]. Note that it may also wake up + /// spuriously. + pub fn wait<T: ?Sized, B: Backend>(&self, guard: &mut Guard<'_, T, B>) { + self.wait_internal(TASK_UNINTERRUPTIBLE, guard, MAX_SCHEDULE_TIMEOUT); + } + + /// Releases the lock and waits for a notification in interruptible mode. + /// + /// Similar to [`CondVar::wait`], except that the wait is interruptible. That is, the thread may + /// wake up due to signals. It may also wake up spuriously. + /// + /// Returns whether there is a signal pending. 
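+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of an interruptible wait loop; the `Flag` type is illustrative and
+    /// mirrors the mutex/condvar pairing from the type-level example above:
+    ///
+    /// ```
+    /// use kernel::sync::{CondVar, Mutex};
+    ///
+    /// #[pin_data]
+    /// struct Flag {
+    ///     #[pin]
+    ///     value: Mutex<bool>,
+    ///     #[pin]
+    ///     value_changed: CondVar,
+    /// }
+    ///
+    /// /// Waits until `f.value` becomes `true`, bailing out if a signal is pending.
+    /// fn wait_until_set(f: &Flag) -> Result {
+    ///     let mut guard = f.value.lock();
+    ///     while !*guard {
+    ///         if f.value_changed.wait_interruptible(&mut guard) {
+    ///             return Err(EINTR);
+    ///         }
+    ///     }
+    ///     Ok(())
+    /// }
+    /// ```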
+ #[must_use = "wait_interruptible returns if a signal is pending, so the caller must check the return value"] + pub fn wait_interruptible<T: ?Sized, B: Backend>(&self, guard: &mut Guard<'_, T, B>) -> bool { + self.wait_internal(TASK_INTERRUPTIBLE, guard, MAX_SCHEDULE_TIMEOUT); + crate::current!().signal_pending() + } + + /// Releases the lock and waits for a notification in interruptible and freezable mode. + /// + /// The process is allowed to be frozen during this sleep. No lock should be held when calling + /// this function, and there is a lockdep assertion for this. Freezing a task that holds a lock + /// can trivially deadlock vs another task that needs that lock to complete before it too can + /// hit freezable. + #[must_use = "wait_interruptible_freezable returns if a signal is pending, so the caller must check the return value"] + pub fn wait_interruptible_freezable<T: ?Sized, B: Backend>( + &self, + guard: &mut Guard<'_, T, B>, + ) -> bool { + self.wait_internal( + TASK_INTERRUPTIBLE | TASK_FREEZABLE, + guard, + MAX_SCHEDULE_TIMEOUT, + ); + crate::current!().signal_pending() + } + + /// Releases the lock and waits for a notification in interruptible mode. + /// + /// Atomically releases the given lock (whose ownership is proven by the guard) and puts the + /// thread to sleep. It wakes up when notified by [`CondVar::notify_one`] or + /// [`CondVar::notify_all`], or when a timeout occurs, or when the thread receives a signal. + #[must_use = "wait_interruptible_timeout returns if a signal is pending, so the caller must check the return value"] + pub fn wait_interruptible_timeout<T: ?Sized, B: Backend>( + &self, + guard: &mut Guard<'_, T, B>, + jiffies: Jiffies, + ) -> CondVarTimeoutResult { + let jiffies = jiffies.try_into().unwrap_or(MAX_SCHEDULE_TIMEOUT); + let res = self.wait_internal(TASK_INTERRUPTIBLE, guard, jiffies); + + match (res as Jiffies, crate::current!().signal_pending()) { + (jiffies, true) => CondVarTimeoutResult::Signal { jiffies }, + (0, false) => CondVarTimeoutResult::Timeout, + (jiffies, false) => CondVarTimeoutResult::Woken { jiffies }, + } + } + + /// Calls the kernel function to notify the appropriate number of threads. + fn notify(&self, count: c_int) { + // SAFETY: `wait_queue_head` points to valid memory. + unsafe { + bindings::__wake_up( + self.wait_queue_head.get(), + TASK_NORMAL, + count, + ptr::null_mut(), + ) + }; + } + + /// Calls the kernel function to notify one thread synchronously. + /// + /// This method behaves like `notify_one`, except that it hints to the scheduler that the + /// current thread is about to go to sleep, so it should schedule the target thread on the same + /// CPU. + #[inline] + pub fn notify_sync(&self) { + // SAFETY: `wait_queue_head` points to valid memory. + unsafe { bindings::__wake_up_sync(self.wait_queue_head.get(), TASK_NORMAL) }; + } + + /// Wakes a single waiter up, if any. + /// + /// This is not 'sticky' in the sense that if no thread is waiting, the notification is lost + /// completely (as opposed to automatically waking up the next waiter). + #[inline] + pub fn notify_one(&self) { + self.notify(1); + } + + /// Wakes all waiters up, if any. + /// + /// This is not 'sticky' in the sense that if no thread is waiting, the notification is lost + /// completely (as opposed to automatically waking up the next waiter). + #[inline] + pub fn notify_all(&self) { + self.notify(0); + } +} + +/// The return type of `wait_timeout`. +pub enum CondVarTimeoutResult { + /// The timeout was reached. 
+ Timeout, + /// Somebody woke us up. + Woken { + /// Remaining sleep duration. + jiffies: Jiffies, + }, + /// A signal occurred. + Signal { + /// Remaining sleep duration. + jiffies: Jiffies, + }, +} diff --git a/rust/kernel/sync/lock.rs b/rust/kernel/sync/lock.rs new file mode 100644 index 000000000000..46a57d1fc309 --- /dev/null +++ b/rust/kernel/sync/lock.rs @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic kernel lock and guard. +//! +//! It contains a generic Rust lock and guard that allow for different backends (e.g., mutexes, +//! spinlocks, raw spinlocks) to be provided with minimal effort. + +use super::LockClassKey; +use crate::{ + str::{CStr, CStrExt as _}, + types::{NotThreadSafe, Opaque, ScopeGuard}, +}; +use core::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin}; +use pin_init::{pin_data, pin_init, PinInit, Wrapper}; + +pub mod mutex; +pub mod spinlock; + +pub(super) mod global; +pub use global::{GlobalGuard, GlobalLock, GlobalLockBackend, GlobalLockedBy}; + +/// The "backend" of a lock. +/// +/// It is the actual implementation of the lock, without the need to repeat patterns used in all +/// locks. +/// +/// # Safety +/// +/// - Implementers must ensure that only one thread/CPU may access the protected data once the lock +/// is owned, that is, between calls to [`lock`] and [`unlock`]. +/// - Implementers must also ensure that [`relock`] uses the same locking method as the original +/// lock operation. +/// +/// [`lock`]: Backend::lock +/// [`unlock`]: Backend::unlock +/// [`relock`]: Backend::relock +pub unsafe trait Backend { + /// The state required by the lock. + type State; + + /// The state required to be kept between [`lock`] and [`unlock`]. + /// + /// [`lock`]: Backend::lock + /// [`unlock`]: Backend::unlock + type GuardState; + + /// Initialises the lock. + /// + /// # Safety + /// + /// `ptr` must be valid for write for the duration of the call, while `name` and `key` must + /// remain valid for read indefinitely. + unsafe fn init( + ptr: *mut Self::State, + name: *const crate::ffi::c_char, + key: *mut bindings::lock_class_key, + ); + + /// Acquires the lock, making the caller its owner. + /// + /// # Safety + /// + /// Callers must ensure that [`Backend::init`] has been previously called. + #[must_use] + unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState; + + /// Tries to acquire the lock. + /// + /// # Safety + /// + /// Callers must ensure that [`Backend::init`] has been previously called. + unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState>; + + /// Releases the lock, giving up its ownership. + /// + /// # Safety + /// + /// It must only be called by the current owner of the lock. + unsafe fn unlock(ptr: *mut Self::State, guard_state: &Self::GuardState); + + /// Reacquires the lock, making the caller its owner. + /// + /// # Safety + /// + /// Callers must ensure that `guard_state` comes from a previous call to [`Backend::lock`] (or + /// variant) that has been unlocked with [`Backend::unlock`] and will be relocked now. + unsafe fn relock(ptr: *mut Self::State, guard_state: &mut Self::GuardState) { + // SAFETY: The safety requirements ensure that the lock is initialised. + *guard_state = unsafe { Self::lock(ptr) }; + } + + /// Asserts that the lock is held using lockdep. + /// + /// # Safety + /// + /// Callers must ensure that [`Backend::init`] has been previously called. + unsafe fn assert_is_held(ptr: *mut Self::State); +} + +/// A mutual exclusion primitive. 
+/// +/// Exposes one of the kernel locking primitives. Which one is exposed depends on the lock +/// [`Backend`] specified as the generic parameter `B`. +#[repr(C)] +#[pin_data] +pub struct Lock<T: ?Sized, B: Backend> { + /// The kernel lock object. + #[pin] + state: Opaque<B::State>, + + /// Some locks are known to be self-referential (e.g., mutexes), while others are architecture + /// or config defined (e.g., spinlocks). So we conservatively require them to be pinned in case + /// some architecture uses self-references now or in the future. + #[pin] + _pin: PhantomPinned, + + /// The data protected by the lock. + #[pin] + pub(crate) data: UnsafeCell<T>, +} + +// SAFETY: `Lock` can be transferred across thread boundaries iff the data it protects can. +unsafe impl<T: ?Sized + Send, B: Backend> Send for Lock<T, B> {} + +// SAFETY: `Lock` serialises the interior mutability it provides, so it is `Sync` as long as the +// data it protects is `Send`. +unsafe impl<T: ?Sized + Send, B: Backend> Sync for Lock<T, B> {} + +impl<T, B: Backend> Lock<T, B> { + /// Constructs a new lock initialiser. + pub fn new( + t: impl PinInit<T>, + name: &'static CStr, + key: Pin<&'static LockClassKey>, + ) -> impl PinInit<Self> { + pin_init!(Self { + data <- UnsafeCell::pin_init(t), + _pin: PhantomPinned, + // SAFETY: `slot` is valid while the closure is called and both `name` and `key` have + // static lifetimes so they live indefinitely. + state <- Opaque::ffi_init(|slot| unsafe { + B::init(slot, name.as_char_ptr(), key.as_ptr()) + }), + }) + } +} + +impl<B: Backend> Lock<(), B> { + /// Constructs a [`Lock`] from a raw pointer. + /// + /// This can be useful for interacting with a lock which was initialised outside of Rust. + /// + /// # Safety + /// + /// The caller promises that `ptr` points to a valid initialised instance of [`State`] during + /// the whole lifetime of `'a`. + /// + /// [`State`]: Backend::State + pub unsafe fn from_raw<'a>(ptr: *mut B::State) -> &'a Self { + // SAFETY: + // - By the safety contract `ptr` must point to a valid initialised instance of `B::State` + // - Since the lock data type is `()` which is a ZST, `state` is the only non-ZST member of + // the struct + // - Combined with `#[repr(C)]`, this guarantees `Self` has an equivalent data layout to + // `B::State`. + unsafe { &*ptr.cast() } + } +} + +impl<T: ?Sized, B: Backend> Lock<T, B> { + /// Acquires the lock and gives the caller access to the data protected by it. + pub fn lock(&self) -> Guard<'_, T, B> { + // SAFETY: The constructor of the type calls `init`, so the existence of the object proves + // that `init` was called. + let state = unsafe { B::lock(self.state.get()) }; + // SAFETY: The lock was just acquired. + unsafe { Guard::new(self, state) } + } + + /// Tries to acquire the lock. + /// + /// Returns a guard that can be used to access the data protected by the lock if successful. + // `Option<T>` is not `#[must_use]` even if `T` is, thus the attribute is needed here. + #[must_use = "if unused, the lock will be immediately unlocked"] + pub fn try_lock(&self) -> Option<Guard<'_, T, B>> { + // SAFETY: The constructor of the type calls `init`, so the existence of the object proves + // that `init` was called. + unsafe { B::try_lock(self.state.get()).map(|state| Guard::new(self, state)) } + } +} + +/// A lock guard. +/// +/// Allows mutual exclusion primitives that implement the [`Backend`] trait to automatically unlock +/// when a guard goes out of scope. 
It also provides a safe and convenient way to access the data +/// protected by the lock. +#[must_use = "the lock unlocks immediately when the guard is unused"] +pub struct Guard<'a, T: ?Sized, B: Backend> { + pub(crate) lock: &'a Lock<T, B>, + pub(crate) state: B::GuardState, + _not_send: NotThreadSafe, +} + +// SAFETY: `Guard` is sync when the data protected by the lock is also sync. +unsafe impl<T: Sync + ?Sized, B: Backend> Sync for Guard<'_, T, B> {} + +impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { + /// Returns the lock that this guard originates from. + /// + /// # Examples + /// + /// The following example shows how to use [`Guard::lock_ref()`] to assert the corresponding + /// lock is held. + /// + /// ``` + /// # use kernel::{new_spinlock, sync::lock::{Backend, Guard, Lock}}; + /// # use pin_init::stack_pin_init; + /// + /// fn assert_held<T, B: Backend>(guard: &Guard<'_, T, B>, lock: &Lock<T, B>) { + /// // Address-equal means the same lock. + /// assert!(core::ptr::eq(guard.lock_ref(), lock)); + /// } + /// + /// // Creates a new lock on the stack. + /// stack_pin_init!{ + /// let l = new_spinlock!(42) + /// } + /// + /// let g = l.lock(); + /// + /// // `g` originates from `l`. + /// assert_held(&g, &l); + /// ``` + pub fn lock_ref(&self) -> &'a Lock<T, B> { + self.lock + } + + pub(crate) fn do_unlocked<U>(&mut self, cb: impl FnOnce() -> U) -> U { + // SAFETY: The caller owns the lock, so it is safe to unlock it. + unsafe { B::unlock(self.lock.state.get(), &self.state) }; + + let _relock = ScopeGuard::new(|| + // SAFETY: The lock was just unlocked above and is being relocked now. + unsafe { B::relock(self.lock.state.get(), &mut self.state) }); + + cb() + } + + /// Returns a pinned mutable reference to the protected data. + /// + /// The guard implements [`DerefMut`] when `T: Unpin`, so for [`Unpin`] + /// types [`DerefMut`] should be used instead of this function. + /// + /// [`DerefMut`]: core::ops::DerefMut + /// [`Unpin`]: core::marker::Unpin + /// + /// # Examples + /// + /// ``` + /// # use kernel::sync::{Mutex, MutexGuard}; + /// # use core::{pin::Pin, marker::PhantomPinned}; + /// struct Data(PhantomPinned); + /// + /// fn example(mutex: &Mutex<Data>) { + /// let mut data: MutexGuard<'_, Data> = mutex.lock(); + /// let mut data: Pin<&mut Data> = data.as_mut(); + /// } + /// ``` + pub fn as_mut(&mut self) -> Pin<&mut T> { + // SAFETY: `self.lock.data` is structurally pinned. + unsafe { Pin::new_unchecked(&mut *self.lock.data.get()) } + } +} + +impl<T: ?Sized, B: Backend> core::ops::Deref for Guard<'_, T, B> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: The caller owns the lock, so it is safe to deref the protected data. + unsafe { &*self.lock.data.get() } + } +} + +impl<T: ?Sized, B: Backend> core::ops::DerefMut for Guard<'_, T, B> +where + T: Unpin, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: The caller owns the lock, so it is safe to deref the protected data. + unsafe { &mut *self.lock.data.get() } + } +} + +impl<T: ?Sized, B: Backend> Drop for Guard<'_, T, B> { + fn drop(&mut self) { + // SAFETY: The caller owns the lock, so it is safe to unlock it. + unsafe { B::unlock(self.lock.state.get(), &self.state) }; + } +} + +impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { + /// Constructs a new immutable lock guard. + /// + /// # Safety + /// + /// The caller must ensure that it owns the lock. 
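+    /// Specifically, `state` must be the result of a previous call to [`Backend::lock`] or
+    /// [`Backend::try_lock`] on `lock` that has not yet been paired with an unlock.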
+ pub unsafe fn new(lock: &'a Lock<T, B>, state: B::GuardState) -> Self { + // SAFETY: The caller can only hold the lock if `Backend::init` has already been called. + unsafe { B::assert_is_held(lock.state.get()) }; + + Self { + lock, + state, + _not_send: NotThreadSafe, + } + } +} diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs new file mode 100644 index 000000000000..eab48108a4ae --- /dev/null +++ b/rust/kernel/sync/lock/global.rs @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Support for defining statics containing locks. + +use crate::{ + str::{CStr, CStrExt as _}, + sync::lock::{Backend, Guard, Lock}, + sync::{LockClassKey, LockedBy}, + types::Opaque, +}; +use core::{ + cell::UnsafeCell, + marker::{PhantomData, PhantomPinned}, + pin::Pin, +}; + +/// Trait implemented for marker types for global locks. +/// +/// See [`global_lock!`] for examples. +pub trait GlobalLockBackend { + /// The name for this global lock. + const NAME: &'static CStr; + /// Item type stored in this global lock. + type Item: 'static; + /// The backend used for this global lock. + type Backend: Backend + 'static; + /// The class for this global lock. + fn get_lock_class() -> Pin<&'static LockClassKey>; +} + +/// Type used for global locks. +/// +/// See [`global_lock!`] for examples. +pub struct GlobalLock<B: GlobalLockBackend> { + inner: Lock<B::Item, B::Backend>, +} + +impl<B: GlobalLockBackend> GlobalLock<B> { + /// Creates a global lock. + /// + /// # Safety + /// + /// * Before any other method on this lock is called, [`Self::init`] must be called. + /// * The type `B` must not be used with any other lock. + pub const unsafe fn new(data: B::Item) -> Self { + Self { + inner: Lock { + state: Opaque::uninit(), + data: UnsafeCell::new(data), + _pin: PhantomPinned, + }, + } + } + + /// Initializes a global lock. + /// + /// # Safety + /// + /// Must not be called more than once on a given lock. + pub unsafe fn init(&'static self) { + // SAFETY: The pointer to `state` is valid for the duration of this call, and both `name` + // and `key` are valid indefinitely. The `state` is pinned since we have a `'static` + // reference to `self`. + // + // We have exclusive access to the `state` since the caller of `new` promised to call + // `init` before using any other methods. As `init` can only be called once, all other + // uses of this lock must happen after this call. + unsafe { + B::Backend::init( + self.inner.state.get(), + B::NAME.as_char_ptr(), + B::get_lock_class().as_ptr(), + ) + } + } + + /// Lock this global lock. + pub fn lock(&'static self) -> GlobalGuard<B> { + GlobalGuard { + inner: self.inner.lock(), + } + } + + /// Try to lock this global lock. + pub fn try_lock(&'static self) -> Option<GlobalGuard<B>> { + Some(GlobalGuard { + inner: self.inner.try_lock()?, + }) + } +} + +/// A guard for a [`GlobalLock`]. +/// +/// See [`global_lock!`] for examples. +pub struct GlobalGuard<B: GlobalLockBackend> { + inner: Guard<'static, B::Item, B::Backend>, +} + +impl<B: GlobalLockBackend> core::ops::Deref for GlobalGuard<B> { + type Target = B::Item; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<B: GlobalLockBackend> core::ops::DerefMut for GlobalGuard<B> +where + B::Item: Unpin, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +/// A version of [`LockedBy`] for a [`GlobalLock`]. +/// +/// See [`global_lock!`] for examples. 
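+///
+/// The marker type `B` ties the wrapped value to one specific static lock at the type level, so
+/// only a [`GlobalGuard`] for that same lock can be used to access it.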
+pub struct GlobalLockedBy<T: ?Sized, B: GlobalLockBackend> {
+    _backend: PhantomData<B>,
+    value: UnsafeCell<T>,
+}
+
+// SAFETY: The same thread-safety rules as `LockedBy` apply to `GlobalLockedBy`.
+unsafe impl<T, B> Send for GlobalLockedBy<T, B>
+where
+    T: ?Sized,
+    B: GlobalLockBackend,
+    LockedBy<T, B::Item>: Send,
+{
+}
+
+// SAFETY: The same thread-safety rules as `LockedBy` apply to `GlobalLockedBy`.
+unsafe impl<T, B> Sync for GlobalLockedBy<T, B>
+where
+    T: ?Sized,
+    B: GlobalLockBackend,
+    LockedBy<T, B::Item>: Sync,
+{
+}
+
+impl<T, B: GlobalLockBackend> GlobalLockedBy<T, B> {
+    /// Create a new [`GlobalLockedBy`].
+    ///
+    /// The provided value will be protected by the global lock indicated by `B`.
+    pub fn new(val: T) -> Self {
+        Self {
+            value: UnsafeCell::new(val),
+            _backend: PhantomData,
+        }
+    }
+}
+
+impl<T: ?Sized, B: GlobalLockBackend> GlobalLockedBy<T, B> {
+    /// Access the value immutably.
+    ///
+    /// The caller must prove shared access to the lock.
+    pub fn as_ref<'a>(&'a self, _guard: &'a GlobalGuard<B>) -> &'a T {
+        // SAFETY: The lock is globally unique, so there can only be one guard.
+        unsafe { &*self.value.get() }
+    }
+
+    /// Access the value mutably.
+    ///
+    /// The caller must prove exclusive access to the lock.
+    pub fn as_mut<'a>(&'a self, _guard: &'a mut GlobalGuard<B>) -> &'a mut T {
+        // SAFETY: The lock is globally unique, so there can only be one guard.
+        unsafe { &mut *self.value.get() }
+    }
+
+    /// Access the value mutably directly.
+    ///
+    /// The caller has exclusive access to this `GlobalLockedBy`, so they do not need to hold the
+    /// lock.
+    pub fn get_mut(&mut self) -> &mut T {
+        self.value.get_mut()
+    }
+}
+
+/// Defines a global lock.
+///
+/// The global lock must be initialized before first use. Usually this is done by calling
+/// [`GlobalLock::init`] in the module initializer.
+///
+/// # Examples
+///
+/// A global counter:
+///
+/// ```
+/// # mod ex {
+/// # use kernel::prelude::*;
+/// kernel::sync::global_lock! {
+///     // SAFETY: Initialized in module initializer before first use.
+///     unsafe(uninit) static MY_COUNTER: Mutex<u32> = 0;
+/// }
+///
+/// fn increment_counter() -> u32 {
+///     let mut guard = MY_COUNTER.lock();
+///     *guard += 1;
+///     *guard
+/// }
+///
+/// impl kernel::Module for MyModule {
+///     fn init(_module: &'static ThisModule) -> Result<Self> {
+///         // SAFETY: Called exactly once.
+///         unsafe { MY_COUNTER.init() };
+///
+///         Ok(MyModule {})
+///     }
+/// }
+/// # struct MyModule {}
+/// # }
+/// ```
+///
+/// A global mutex used to protect all instances of a given struct:
+///
+/// ```
+/// # mod ex {
+/// # use kernel::prelude::*;
+/// use kernel::sync::{GlobalGuard, GlobalLockedBy};
+///
+/// kernel::sync::global_lock! {
+///     // SAFETY: Initialized in module initializer before first use.
+///     unsafe(uninit) static MY_MUTEX: Mutex<()> = ();
+/// }
+///
+/// /// All instances of this struct are protected by `MY_MUTEX`.
+/// struct MyStruct {
+///     my_counter: GlobalLockedBy<u32, MY_MUTEX>,
+/// }
+///
+/// impl MyStruct {
+///     /// Increment the counter in this instance.
+///     ///
+///     /// The caller must hold the `MY_MUTEX` mutex.
+///     fn increment(&self, guard: &mut GlobalGuard<MY_MUTEX>) -> u32 {
+///         let my_counter = self.my_counter.as_mut(guard);
+///         *my_counter += 1;
+///         *my_counter
+///     }
+/// }
+///
+/// impl kernel::Module for MyModule {
+///     fn init(_module: &'static ThisModule) -> Result<Self> {
+///         // SAFETY: Called exactly once.
+/// unsafe { MY_MUTEX.init() }; +/// +/// Ok(MyModule {}) +/// } +/// } +/// # struct MyModule {} +/// # } +/// ``` +#[macro_export] +macro_rules! global_lock { + { + $(#[$meta:meta])* $pub:vis + unsafe(uninit) static $name:ident: $kind:ident<$valuety:ty> = $value:expr; + } => { + #[doc = ::core::concat!( + "Backend type used by [`", + ::core::stringify!($name), + "`](static@", + ::core::stringify!($name), + ")." + )] + #[allow(non_camel_case_types, unreachable_pub)] + $pub enum $name {} + + impl $crate::sync::lock::GlobalLockBackend for $name { + const NAME: &'static $crate::str::CStr = $crate::c_str!(::core::stringify!($name)); + type Item = $valuety; + type Backend = $crate::global_lock_inner!(backend $kind); + + fn get_lock_class() -> Pin<&'static $crate::sync::LockClassKey> { + $crate::static_lock_class!() + } + } + + $(#[$meta])* + $pub static $name: $crate::sync::lock::GlobalLock<$name> = { + // Defined here to be outside the unsafe scope. + let init: $valuety = $value; + + // SAFETY: + // * The user of this macro promises to initialize the macro before use. + // * We are only generating one static with this backend type. + unsafe { $crate::sync::lock::GlobalLock::new(init) } + }; + }; +} +pub use global_lock; + +#[doc(hidden)] +#[macro_export] +macro_rules! global_lock_inner { + (backend Mutex) => { + $crate::sync::lock::mutex::MutexBackend + }; + (backend SpinLock) => { + $crate::sync::lock::spinlock::SpinLockBackend + }; +} diff --git a/rust/kernel/sync/lock/mutex.rs b/rust/kernel/sync/lock/mutex.rs new file mode 100644 index 000000000000..581cee7ab842 --- /dev/null +++ b/rust/kernel/sync/lock/mutex.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A kernel mutex. +//! +//! This module allows Rust code to use the kernel's `struct mutex`. + +/// Creates a [`Mutex`] initialiser with the given name and a newly-created lock class. +/// +/// It uses the name if one is given, otherwise it generates one based on the file name and line +/// number. +#[macro_export] +macro_rules! new_mutex { + ($inner:expr $(, $name:literal)? $(,)?) => { + $crate::sync::Mutex::new( + $inner, $crate::optional_name!($($name)?), $crate::static_lock_class!()) + }; +} +pub use new_mutex; + +/// A mutual exclusion primitive. +/// +/// Exposes the kernel's [`struct mutex`]. When multiple threads attempt to lock the same mutex, +/// only one at a time is allowed to progress, the others will block (sleep) until the mutex is +/// unlocked, at which point another thread will be allowed to wake up and make progress. +/// +/// Since it may block, [`Mutex`] needs to be used with care in atomic contexts. +/// +/// Instances of [`Mutex`] need a lock class and to be pinned. The recommended way to create such +/// instances is with the [`pin_init`](pin_init::pin_init) and [`new_mutex`] macros. +/// +/// # Examples +/// +/// The following example shows how to declare, allocate and initialise a struct (`Example`) that +/// contains an inner struct (`Inner`) that is protected by a mutex. +/// +/// ``` +/// use kernel::sync::{new_mutex, Mutex}; +/// +/// struct Inner { +/// a: u32, +/// b: u32, +/// } +/// +/// #[pin_data] +/// struct Example { +/// c: u32, +/// #[pin] +/// d: Mutex<Inner>, +/// } +/// +/// impl Example { +/// fn new() -> impl PinInit<Self> { +/// pin_init!(Self { +/// c: 10, +/// d <- new_mutex!(Inner { a: 20, b: 30 }), +/// }) +/// } +/// } +/// +/// // Allocate a boxed `Example`. 
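+/// // (`GFP_KERNEL` may sleep, which is fine in this example's process context.)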
+/// let e = KBox::pin_init(Example::new(), GFP_KERNEL)?; +/// assert_eq!(e.c, 10); +/// assert_eq!(e.d.lock().a, 20); +/// assert_eq!(e.d.lock().b, 30); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// The following example shows how to use interior mutability to modify the contents of a struct +/// protected by a mutex despite only having a shared reference: +/// +/// ``` +/// use kernel::sync::Mutex; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn example(m: &Mutex<Example>) { +/// let mut guard = m.lock(); +/// guard.a += 10; +/// guard.b += 20; +/// } +/// ``` +/// +/// [`struct mutex`]: srctree/include/linux/mutex.h +pub type Mutex<T> = super::Lock<T, MutexBackend>; + +/// A [`Guard`] acquired from locking a [`Mutex`]. +/// +/// This is simply a type alias for a [`Guard`] returned from locking a [`Mutex`]. It will unlock +/// the [`Mutex`] upon being dropped. +/// +/// [`Guard`]: super::Guard +pub type MutexGuard<'a, T> = super::Guard<'a, T, MutexBackend>; + +/// A kernel `struct mutex` lock backend. +pub struct MutexBackend; + +// SAFETY: The underlying kernel `struct mutex` object ensures mutual exclusion. +unsafe impl super::Backend for MutexBackend { + type State = bindings::mutex; + type GuardState = (); + + unsafe fn init( + ptr: *mut Self::State, + name: *const crate::ffi::c_char, + key: *mut bindings::lock_class_key, + ) { + // SAFETY: The safety requirements ensure that `ptr` is valid for writes, and `name` and + // `key` are valid for read indefinitely. + unsafe { bindings::__mutex_init(ptr, name, key) } + } + + unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { + // SAFETY: The safety requirements of this function ensure that `ptr` points to valid + // memory, and that it has been initialised before. + unsafe { bindings::mutex_lock(ptr) }; + } + + unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { + // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the + // caller is the owner of the mutex. + unsafe { bindings::mutex_unlock(ptr) }; + } + + unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. + let result = unsafe { bindings::mutex_trylock(ptr) }; + + if result != 0 { + Some(()) + } else { + None + } + } + + unsafe fn assert_is_held(ptr: *mut Self::State) { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. + unsafe { bindings::mutex_assert_is_held(ptr) } + } +} diff --git a/rust/kernel/sync/lock/spinlock.rs b/rust/kernel/sync/lock/spinlock.rs new file mode 100644 index 000000000000..d7be38ccbdc7 --- /dev/null +++ b/rust/kernel/sync/lock/spinlock.rs @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A kernel spinlock. +//! +//! This module allows Rust code to use the kernel's `spinlock_t`. + +/// Creates a [`SpinLock`] initialiser with the given name and a newly-created lock class. +/// +/// It uses the name if one is given, otherwise it generates one based on the file name and line +/// number. +#[macro_export] +macro_rules! new_spinlock { + ($inner:expr $(, $name:literal)? $(,)?) => { + $crate::sync::SpinLock::new( + $inner, $crate::optional_name!($($name)?), $crate::static_lock_class!()) + }; +} +pub use new_spinlock; + +/// A spinlock. +/// +/// Exposes the kernel's [`spinlock_t`]. 
When multiple CPUs attempt to lock the same spinlock, only +/// one at a time is allowed to progress, the others will block (spinning) until the spinlock is +/// unlocked, at which point another CPU will be allowed to make progress. +/// +/// Instances of [`SpinLock`] need a lock class and to be pinned. The recommended way to create such +/// instances is with the [`pin_init`](pin_init::pin_init) and [`new_spinlock`] macros. +/// +/// # Examples +/// +/// The following example shows how to declare, allocate and initialise a struct (`Example`) that +/// contains an inner struct (`Inner`) that is protected by a spinlock. +/// +/// ``` +/// use kernel::sync::{new_spinlock, SpinLock}; +/// +/// struct Inner { +/// a: u32, +/// b: u32, +/// } +/// +/// #[pin_data] +/// struct Example { +/// c: u32, +/// #[pin] +/// d: SpinLock<Inner>, +/// } +/// +/// impl Example { +/// fn new() -> impl PinInit<Self> { +/// pin_init!(Self { +/// c: 10, +/// d <- new_spinlock!(Inner { a: 20, b: 30 }), +/// }) +/// } +/// } +/// +/// // Allocate a boxed `Example`. +/// let e = KBox::pin_init(Example::new(), GFP_KERNEL)?; +/// assert_eq!(e.c, 10); +/// assert_eq!(e.d.lock().a, 20); +/// assert_eq!(e.d.lock().b, 30); +/// # Ok::<(), Error>(()) +/// ``` +/// +/// The following example shows how to use interior mutability to modify the contents of a struct +/// protected by a spinlock despite only having a shared reference: +/// +/// ``` +/// use kernel::sync::SpinLock; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn example(m: &SpinLock<Example>) { +/// let mut guard = m.lock(); +/// guard.a += 10; +/// guard.b += 20; +/// } +/// ``` +/// +/// [`spinlock_t`]: srctree/include/linux/spinlock.h +pub type SpinLock<T> = super::Lock<T, SpinLockBackend>; + +/// A kernel `spinlock_t` lock backend. +pub struct SpinLockBackend; + +/// A [`Guard`] acquired from locking a [`SpinLock`]. +/// +/// This is simply a type alias for a [`Guard`] returned from locking a [`SpinLock`]. It will unlock +/// the [`SpinLock`] upon being dropped. +/// +/// [`Guard`]: super::Guard +pub type SpinLockGuard<'a, T> = super::Guard<'a, T, SpinLockBackend>; + +// SAFETY: The underlying kernel `spinlock_t` object ensures mutual exclusion. `relock` uses the +// default implementation that always calls the same locking method. +unsafe impl super::Backend for SpinLockBackend { + type State = bindings::spinlock_t; + type GuardState = (); + + unsafe fn init( + ptr: *mut Self::State, + name: *const crate::ffi::c_char, + key: *mut bindings::lock_class_key, + ) { + // SAFETY: The safety requirements ensure that `ptr` is valid for writes, and `name` and + // `key` are valid for read indefinitely. + unsafe { bindings::__spin_lock_init(ptr, name, key) } + } + + unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { + // SAFETY: The safety requirements of this function ensure that `ptr` points to valid + // memory, and that it has been initialised before. + unsafe { bindings::spin_lock(ptr) } + } + + unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { + // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the + // caller is the owner of the spinlock. + unsafe { bindings::spin_unlock(ptr) } + } + + unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. 
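+        // `spin_trylock()` returns a nonzero value if the lock was acquired and zero if it is
+        // already held, hence the mapping to `Option` below.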
+ let result = unsafe { bindings::spin_trylock(ptr) }; + + if result != 0 { + Some(()) + } else { + None + } + } + + unsafe fn assert_is_held(ptr: *mut Self::State) { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. + unsafe { bindings::spin_assert_is_held(ptr) } + } +} diff --git a/rust/kernel/sync/locked_by.rs b/rust/kernel/sync/locked_by.rs new file mode 100644 index 000000000000..61f100a45b35 --- /dev/null +++ b/rust/kernel/sync/locked_by.rs @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A wrapper for data protected by a lock that does not wrap it. + +use super::{lock::Backend, lock::Lock}; +use crate::build_assert; +use core::{cell::UnsafeCell, mem::size_of, ptr}; + +/// Allows access to some data to be serialised by a lock that does not wrap it. +/// +/// In most cases, data protected by a lock is wrapped by the appropriate lock type, e.g., +/// [`Mutex`] or [`SpinLock`]. [`LockedBy`] is meant for cases when this is not possible. +/// For example, if a container has a lock and some data in the contained elements needs +/// to be protected by the same lock. +/// +/// [`LockedBy`] wraps the data in lieu of another locking primitive, and only allows access to it +/// when the caller shows evidence that the 'external' lock is locked. It panics if the evidence +/// refers to the wrong instance of the lock. +/// +/// [`Mutex`]: super::Mutex +/// [`SpinLock`]: super::SpinLock +/// +/// # Examples +/// +/// The following is an example for illustrative purposes: `InnerDirectory::bytes_used` is an +/// aggregate of all `InnerFile::bytes_used` and must be kept consistent; so we wrap `InnerFile` in +/// a `LockedBy` so that it shares a lock with `InnerDirectory`. This allows us to enforce at +/// compile-time that access to `InnerFile` is only granted when an `InnerDirectory` is also +/// locked; we enforce at run time that the right `InnerDirectory` is locked. +/// +/// ``` +/// use kernel::sync::{LockedBy, Mutex}; +/// +/// struct InnerFile { +/// bytes_used: u64, +/// } +/// +/// struct File { +/// _ino: u32, +/// inner: LockedBy<InnerFile, InnerDirectory>, +/// } +/// +/// struct InnerDirectory { +/// /// The sum of the bytes used by all files. +/// bytes_used: u64, +/// _files: KVec<File>, +/// } +/// +/// struct Directory { +/// _ino: u32, +/// inner: Mutex<InnerDirectory>, +/// } +/// +/// /// Prints `bytes_used` from both the directory and file. +/// fn print_bytes_used(dir: &Directory, file: &File) { +/// let guard = dir.inner.lock(); +/// let inner_file = file.inner.access(&guard); +/// pr_info!("{} {}\n", guard.bytes_used, inner_file.bytes_used); +/// } +/// +/// /// Increments `bytes_used` for both the directory and file. +/// fn inc_bytes_used(dir: &Directory, file: &File) { +/// let mut guard = dir.inner.lock(); +/// guard.bytes_used += 10; +/// +/// let file_inner = file.inner.access_mut(&mut guard); +/// file_inner.bytes_used += 10; +/// } +/// +/// /// Creates a new file. +/// fn new_file(ino: u32, dir: &Directory) -> File { +/// File { +/// _ino: ino, +/// inner: LockedBy::new(&dir.inner, InnerFile { bytes_used: 0 }), +/// } +/// } +/// ``` +pub struct LockedBy<T: ?Sized, U: ?Sized> { + owner: *const U, + data: UnsafeCell<T>, +} + +// SAFETY: `LockedBy` can be transferred across thread boundaries iff the data it protects can. 
+unsafe impl<T: ?Sized + Send, U: ?Sized> Send for LockedBy<T, U> {} + +// SAFETY: If `T` is not `Sync`, then parallel shared access to this `LockedBy` allows you to use +// `access_mut` to hand out `&mut T` on one thread at the time. The requirement that `T: Send` is +// sufficient to allow that. +// +// If `T` is `Sync`, then the `access` method also becomes available, which allows you to obtain +// several `&T` from several threads at once. However, this is okay as `T` is `Sync`. +unsafe impl<T: ?Sized + Send, U: ?Sized> Sync for LockedBy<T, U> {} + +impl<T, U> LockedBy<T, U> { + /// Constructs a new instance of [`LockedBy`]. + /// + /// It stores a raw pointer to the owner that is never dereferenced. It is only used to ensure + /// that the right owner is being used to access the protected data. If the owner is freed, the + /// data becomes inaccessible; if another instance of the owner is allocated *on the same + /// memory location*, the data becomes accessible again: none of this affects memory safety + /// because in any case at most one thread (or CPU) can access the protected data at a time. + pub fn new<B: Backend>(owner: &Lock<U, B>, data: T) -> Self { + build_assert!( + size_of::<Lock<U, B>>() > 0, + "The lock type cannot be a ZST because it may be impossible to distinguish instances" + ); + Self { + owner: owner.data.get(), + data: UnsafeCell::new(data), + } + } +} + +impl<T: ?Sized, U> LockedBy<T, U> { + /// Returns a reference to the protected data when the caller provides evidence (via a + /// reference) that the owner is locked. + /// + /// `U` cannot be a zero-sized type (ZST) because there are ways to get an `&U` that matches + /// the data protected by the lock without actually holding it. + /// + /// # Panics + /// + /// Panics if `owner` is different from the data protected by the lock used in + /// [`new`](LockedBy::new). + pub fn access<'a>(&'a self, owner: &'a U) -> &'a T + where + T: Sync, + { + build_assert!( + size_of::<U>() > 0, + "`U` cannot be a ZST because `owner` wouldn't be unique" + ); + if !ptr::eq(owner, self.owner) { + panic!("mismatched owners"); + } + + // SAFETY: `owner` is evidence that there are only shared references to the owner for the + // duration of 'a, so it's not possible to use `Self::access_mut` to obtain a mutable + // reference to the inner value that aliases with this shared reference. The type is `Sync` + // so there are no other requirements. + unsafe { &*self.data.get() } + } + + /// Returns a mutable reference to the protected data when the caller provides evidence (via a + /// mutable owner) that the owner is locked mutably. + /// + /// `U` cannot be a zero-sized type (ZST) because there are ways to get an `&mut U` that + /// matches the data protected by the lock without actually holding it. + /// + /// Showing a mutable reference to the owner is sufficient because we know no other references + /// can exist to it. + /// + /// # Panics + /// + /// Panics if `owner` is different from the data protected by the lock used in + /// [`new`](LockedBy::new). + pub fn access_mut<'a>(&'a self, owner: &'a mut U) -> &'a mut T { + build_assert!( + size_of::<U>() > 0, + "`U` cannot be a ZST because `owner` wouldn't be unique" + ); + if !ptr::eq(owner, self.owner) { + panic!("mismatched owners"); + } + + // SAFETY: `owner` is evidence that there is only one reference to the owner. 
+ unsafe { &mut *self.data.get() } + } +} diff --git a/rust/kernel/sync/poll.rs b/rust/kernel/sync/poll.rs new file mode 100644 index 000000000000..0ec985d560c8 --- /dev/null +++ b/rust/kernel/sync/poll.rs @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Utilities for working with `struct poll_table`. + +use crate::{ + bindings, + fs::File, + prelude::*, + sync::{CondVar, LockClassKey}, +}; +use core::{marker::PhantomData, ops::Deref}; + +/// Creates a [`PollCondVar`] initialiser with the given name and a newly-created lock class. +#[macro_export] +macro_rules! new_poll_condvar { + ($($name:literal)?) => { + $crate::sync::poll::PollCondVar::new( + $crate::optional_name!($($name)?), $crate::static_lock_class!() + ) + }; +} + +/// Wraps the kernel's `poll_table`. +/// +/// # Invariants +/// +/// The pointer must be null or reference a valid `poll_table`. +#[repr(transparent)] +pub struct PollTable<'a> { + table: *mut bindings::poll_table, + _lifetime: PhantomData<&'a bindings::poll_table>, +} + +impl<'a> PollTable<'a> { + /// Creates a [`PollTable`] from a valid pointer. + /// + /// # Safety + /// + /// The pointer must be null or reference a valid `poll_table` for the duration of `'a`. + pub unsafe fn from_raw(table: *mut bindings::poll_table) -> Self { + // INVARIANTS: The safety requirements are the same as the struct invariants. + PollTable { + table, + _lifetime: PhantomData, + } + } + + /// Register this [`PollTable`] with the provided [`PollCondVar`], so that it can be notified + /// using the condition variable. + pub fn register_wait(&self, file: &File, cv: &PollCondVar) { + // SAFETY: + // * `file.as_ptr()` references a valid file for the duration of this call. + // * `self.table` is null or references a valid poll_table for the duration of this call. + // * Since `PollCondVar` is pinned, its destructor is guaranteed to run before the memory + // containing `cv.wait_queue_head` is invalidated. Since the destructor clears all + // waiters and then waits for an rcu grace period, it's guaranteed that + // `cv.wait_queue_head` remains valid for at least an rcu grace period after the removal + // of the last waiter. + unsafe { bindings::poll_wait(file.as_ptr(), cv.wait_queue_head.get(), self.table) } + } +} + +/// A wrapper around [`CondVar`] that makes it usable with [`PollTable`]. +/// +/// [`CondVar`]: crate::sync::CondVar +#[pin_data(PinnedDrop)] +pub struct PollCondVar { + #[pin] + inner: CondVar, +} + +impl PollCondVar { + /// Constructs a new condvar initialiser. + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { + pin_init!(Self { + inner <- CondVar::new(name, key), + }) + } +} + +// Make the `CondVar` methods callable on `PollCondVar`. +impl Deref for PollCondVar { + type Target = CondVar; + + fn deref(&self) -> &CondVar { + &self.inner + } +} + +#[pinned_drop] +impl PinnedDrop for PollCondVar { + #[inline] + fn drop(self: Pin<&mut Self>) { + // Clear anything registered using `register_wait`. + // + // SAFETY: The pointer points at a valid `wait_queue_head`. + unsafe { bindings::__wake_up_pollfree(self.inner.wait_queue_head.get()) }; + + // Wait for epoll items to be properly removed. + // + // SAFETY: Just an FFI call. + unsafe { bindings::synchronize_rcu() }; + } +} diff --git a/rust/kernel/sync/rcu.rs b/rust/kernel/sync/rcu.rs new file mode 100644 index 000000000000..a32bef6e490b --- /dev/null +++ b/rust/kernel/sync/rcu.rs @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! 
RCU support. +//! +//! C header: [`include/linux/rcupdate.h`](srctree/include/linux/rcupdate.h) + +use crate::{bindings, types::NotThreadSafe}; + +/// Evidence that the RCU read side lock is held on the current thread/CPU. +/// +/// The type is explicitly not `Send` because this property is per-thread/CPU. +/// +/// # Invariants +/// +/// The RCU read side lock is actually held while instances of this guard exist. +pub struct Guard(NotThreadSafe); + +impl Guard { + /// Acquires the RCU read side lock and returns a guard. + #[inline] + pub fn new() -> Self { + // SAFETY: An FFI call with no additional requirements. + unsafe { bindings::rcu_read_lock() }; + // INVARIANT: The RCU read side lock was just acquired above. + Self(NotThreadSafe) + } + + /// Explicitly releases the RCU read side lock. + #[inline] + pub fn unlock(self) {} +} + +impl Default for Guard { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Drop for Guard { + #[inline] + fn drop(&mut self) { + // SAFETY: By the type invariants, the RCU read side is locked, so it is ok to unlock it. + unsafe { bindings::rcu_read_unlock() }; + } +} + +/// Acquires the RCU read side lock. +#[inline] +pub fn read_lock() -> Guard { + Guard::new() +} diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs new file mode 100644 index 000000000000..19236a5bccde --- /dev/null +++ b/rust/kernel/sync/refcount.rs @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Atomic reference counting. +//! +//! C header: [`include/linux/refcount.h`](srctree/include/linux/refcount.h) + +use crate::build_assert; +use crate::sync::atomic::Atomic; +use crate::types::Opaque; + +/// Atomic reference counter. +/// +/// This type is conceptually an atomic integer, but provides saturation semantics compared to +/// normal atomic integers. Values in the negative range when viewed as a signed integer are +/// saturation (bad) values. For details about the saturation semantics, please refer to top of +/// [`include/linux/refcount.h`](srctree/include/linux/refcount.h). +/// +/// Wraps the kernel's C `refcount_t`. +#[repr(transparent)] +pub struct Refcount(Opaque<bindings::refcount_t>); + +impl Refcount { + /// Construct a new [`Refcount`] from an initial value. + /// + /// The initial value should be non-saturated. + #[inline] + pub fn new(value: i32) -> Self { + build_assert!(value >= 0, "initial value saturated"); + // SAFETY: There are no safety requirements for this FFI call. + Self(Opaque::new(unsafe { bindings::REFCOUNT_INIT(value) })) + } + + #[inline] + fn as_ptr(&self) -> *mut bindings::refcount_t { + self.0.get() + } + + /// Get the underlying atomic counter that backs the refcount. + /// + /// NOTE: Usage of this function is discouraged as it can circumvent the protections offered by + /// `refcount.h`. If there is no way to achieve the result using APIs in `refcount.h`, then + /// this function can be used. Otherwise consider adding a binding for the required API. + #[inline] + pub fn as_atomic(&self) -> &Atomic<i32> { + let ptr = self.0.get().cast(); + // SAFETY: `refcount_t` is a transparent wrapper of `atomic_t`, which is an atomic 32-bit + // integer that is layout-wise compatible with `Atomic<i32>`. All values are valid for + // `refcount_t`, despite some of the values being considered saturated and "bad". + unsafe { &*ptr } + } + + /// Set a refcount's value. + #[inline] + pub fn set(&self, value: i32) { + // SAFETY: `self.as_ptr()` is valid. 
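+        // (Like the C `refcount_set()`, this is a plain store with no ordering guarantees.)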
+        unsafe { bindings::refcount_set(self.as_ptr(), value) }
+    }
+
+    /// Increment a refcount.
+    ///
+    /// It will saturate on overflow and `WARN`. It will also `WARN` if the refcount is 0, as this
+    /// represents a possible use-after-free condition.
+    ///
+    /// Provides no memory ordering; it is assumed that the caller already has a reference on the
+    /// object.
+    #[inline]
+    pub fn inc(&self) {
+        // SAFETY: `self.as_ptr()` is valid.
+        unsafe { bindings::refcount_inc(self.as_ptr()) }
+    }
+
+    /// Decrement a refcount.
+    ///
+    /// It will `WARN` on underflow and fail to decrement when saturated.
+    ///
+    /// Provides release memory ordering, such that prior loads and stores are done
+    /// before.
+    #[inline]
+    pub fn dec(&self) {
+        // SAFETY: `self.as_ptr()` is valid.
+        unsafe { bindings::refcount_dec(self.as_ptr()) }
+    }
+
+    /// Decrement a refcount and test if it is 0.
+    ///
+    /// It will `WARN` on underflow and fail to decrement when saturated.
+    ///
+    /// Provides release memory ordering, such that prior loads and stores are done
+    /// before, and provides an acquire ordering on success such that memory deallocation
+    /// must come after.
+    ///
+    /// Returns `true` if the resulting refcount is 0, `false` otherwise.
+    ///
+    /// # Notes
+    ///
+    /// A common pattern of using `Refcount` is to free memory when the reference count reaches
+    /// zero. This means that the reference to `Refcount` could become invalid after calling this
+    /// function. This is fine as long as the reference to `Refcount` is no longer used when this
+    /// function returns `false`. It is not necessary to use raw pointers in this scenario, see
+    /// <https://github.com/rust-lang/rust/issues/55005>.
+    #[inline]
+    #[must_use = "use `dec` instead if you do not need to test if it is 0"]
+    pub fn dec_and_test(&self) -> bool {
+        // SAFETY: `self.as_ptr()` is valid.
+        unsafe { bindings::refcount_dec_and_test(self.as_ptr()) }
+    }
+}
+
+// SAFETY: `refcount_t` is thread-safe.
+unsafe impl Send for Refcount {}
+
+// SAFETY: `refcount_t` is thread-safe.
+unsafe impl Sync for Refcount {}
diff --git a/rust/kernel/sync/set_once.rs b/rust/kernel/sync/set_once.rs
new file mode 100644
index 000000000000..bdba601807d8
--- /dev/null
+++ b/rust/kernel/sync/set_once.rs
@@ -0,0 +1,125 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! A container that can be initialized at most once.
+
+use super::atomic::{
+    ordering::{Acquire, Relaxed, Release},
+    Atomic,
+};
+use core::{cell::UnsafeCell, mem::MaybeUninit};
+
+/// A container that can be populated at most once. Thread safe.
+///
+/// Once a [`SetOnce`] is populated, it remains populated by the same object for the lifetime of
+/// `Self`.
+///
+/// # Invariants
+///
+/// - `init` may only increase in value.
+/// - `init` may only assume values in the range `0..=2`.
+/// - `init == 0` if and only if `value` is uninitialized.
+/// - `init == 1` if and only if there is exactly one thread with exclusive
+///   access to `self.value`.
+/// - `init == 2` if and only if `value` is initialized and valid for shared
+///   access.
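+///
+/// A consequence of these invariants is that `init == 2` is a final state: once a value can be
+/// observed via [`SetOnce::as_ref`], it stays valid for as long as the [`SetOnce`] exists.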
+/// +/// # Example +/// +/// ``` +/// # use kernel::sync::SetOnce; +/// let value = SetOnce::new(); +/// assert_eq!(None, value.as_ref()); +/// +/// let status = value.populate(42u8); +/// assert_eq!(true, status); +/// assert_eq!(Some(&42u8), value.as_ref()); +/// assert_eq!(Some(42u8), value.copy()); +/// +/// let status = value.populate(101u8); +/// assert_eq!(false, status); +/// assert_eq!(Some(&42u8), value.as_ref()); +/// assert_eq!(Some(42u8), value.copy()); +/// ``` +pub struct SetOnce<T> { + init: Atomic<u32>, + value: UnsafeCell<MaybeUninit<T>>, +} + +impl<T> Default for SetOnce<T> { + fn default() -> Self { + Self::new() + } +} + +impl<T> SetOnce<T> { + /// Create a new [`SetOnce`]. + /// + /// The returned instance will be empty. + pub const fn new() -> Self { + // INVARIANT: The container is empty and we initialize `init` to `0`. + Self { + value: UnsafeCell::new(MaybeUninit::uninit()), + init: Atomic::new(0), + } + } + + /// Get a reference to the contained object. + /// + /// Returns [`None`] if this [`SetOnce`] is empty. + pub fn as_ref(&self) -> Option<&T> { + if self.init.load(Acquire) == 2 { + // SAFETY: By the type invariants of `Self`, `self.init == 2` means that `self.value` + // is initialized and valid for shared access. + Some(unsafe { &*self.value.get().cast() }) + } else { + None + } + } + + /// Populate the [`SetOnce`]. + /// + /// Returns `true` if the [`SetOnce`] was successfully populated. + pub fn populate(&self, value: T) -> bool { + // INVARIANT: If the swap succeeds: + // - We increase `init`. + // - We write the valid value `1` to `init`. + // - Only one thread can succeed in this write, so we have exclusive access after the + // write. + if let Ok(0) = self.init.cmpxchg(0, 1, Relaxed) { + // SAFETY: By the type invariants of `Self`, the fact that we succeeded in writing `1` + // to `self.init` means we obtained exclusive access to `self.value`. + unsafe { core::ptr::write(self.value.get().cast(), value) }; + // INVARIANT: + // - We increase `init`. + // - We write the valid value `2` to `init`. + // - We release our exclusive access to `self.value` and it is now valid for shared + // access. + self.init.store(2, Release); + true + } else { + false + } + } + + /// Get a copy of the contained object. + /// + /// Returns [`None`] if the [`SetOnce`] is empty. + pub fn copy(&self) -> Option<T> + where + T: Copy, + { + self.as_ref().copied() + } +} + +impl<T> Drop for SetOnce<T> { + fn drop(&mut self) { + if *self.init.get_mut() == 2 { + let value = self.value.get_mut(); + // SAFETY: By the type invariants of `Self`, `self.init == 2` means that `self.value` + // contains a valid value. We have exclusive access, as we hold a `mut` reference to + // `self`. + unsafe { value.assume_init_drop() }; + } + } +} diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs new file mode 100644 index 000000000000..49fad6de0674 --- /dev/null +++ b/rust/kernel/task.rs @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Tasks (threads and processes). +//! +//! C header: [`include/linux/sched.h`](srctree/include/linux/sched.h). + +use crate::{ + bindings, + ffi::{c_int, c_long, c_uint}, + mm::MmWithUser, + pid_namespace::PidNamespace, + sync::aref::ARef, + types::{NotThreadSafe, Opaque}, +}; +use core::{ + cmp::{Eq, PartialEq}, + ops::Deref, + ptr, +}; + +/// A sentinel value used for infinite timeouts. +pub const MAX_SCHEDULE_TIMEOUT: c_long = c_long::MAX; + +/// Bitmask for tasks that are sleeping in an interruptible state. 
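+/// Such sleeps can be ended early by a signal.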
+pub const TASK_INTERRUPTIBLE: c_int = bindings::TASK_INTERRUPTIBLE as c_int; +/// Bitmask for tasks that are sleeping in an uninterruptible state. +pub const TASK_UNINTERRUPTIBLE: c_int = bindings::TASK_UNINTERRUPTIBLE as c_int; +/// Bitmask for tasks that are sleeping in a freezable state. +pub const TASK_FREEZABLE: c_int = bindings::TASK_FREEZABLE as c_int; +/// Convenience constant for waking up tasks regardless of whether they are in interruptible or +/// uninterruptible sleep. +pub const TASK_NORMAL: c_uint = bindings::TASK_NORMAL as c_uint; + +/// Returns the currently running task. +#[macro_export] +macro_rules! current { + () => { + // SAFETY: This expression creates a temporary value that is dropped at the end of the + // caller's scope. The following mechanisms ensure that the resulting `&CurrentTask` cannot + // leave current task context: + // + // * To return to userspace, the caller must leave the current scope. + // * Operations such as `begin_new_exec()` are necessarily unsafe and the caller of + // `begin_new_exec()` is responsible for safety. + // * Rust abstractions for things such as a `kthread_use_mm()` scope must require the + // closure to be `Send`, so the `NotThreadSafe` field of `CurrentTask` ensures that the + // `&CurrentTask` cannot cross the scope in either direction. + unsafe { &*$crate::task::Task::current() } + }; +} + +/// Wraps the kernel's `struct task_struct`. +/// +/// # Invariants +/// +/// All instances are valid tasks created by the C portion of the kernel. +/// +/// Instances of this type are always refcounted, that is, a call to `get_task_struct` ensures +/// that the allocation remains valid at least until the matching call to `put_task_struct`. +/// +/// # Examples +/// +/// The following is an example of getting the PID of the current thread with zero additional cost +/// when compared to the C version: +/// +/// ``` +/// let pid = current!().pid(); +/// ``` +/// +/// Getting the PID of the current process, also zero additional cost: +/// +/// ``` +/// let pid = current!().group_leader().pid(); +/// ``` +/// +/// Getting the current task and storing it in some struct. The reference count is automatically +/// incremented when creating `State` and decremented when it is dropped: +/// +/// ``` +/// use kernel::{task::Task, sync::aref::ARef}; +/// +/// struct State { +/// creator: ARef<Task>, +/// index: u32, +/// } +/// +/// impl State { +/// fn new() -> Self { +/// Self { +/// creator: ARef::from(&**current!()), +/// index: 0, +/// } +/// } +/// } +/// ``` +#[repr(transparent)] +pub struct Task(pub(crate) Opaque<bindings::task_struct>); + +// SAFETY: By design, the only way to access a `Task` is via the `current` function or via an +// `ARef<Task>` obtained through the `AlwaysRefCounted` impl. This means that the only situation in +// which a `Task` can be accessed mutably is when the refcount drops to zero and the destructor +// runs. It is safe for that to happen on any thread, so it is ok for this type to be `Send`. +unsafe impl Send for Task {} + +// SAFETY: It's OK to access `Task` through shared references from other threads because we're +// either accessing properties that don't change (e.g., `pid`, `group_leader`) or that are properly +// synchronised by C code (e.g., `signal_pending`). +unsafe impl Sync for Task {} + +/// Represents the [`Task`] in the `current` global. +/// +/// This type exists to provide more efficient operations that are only valid on the current task. 
+/// For example, to retrieve the pid-namespace of a task, you must use rcu protection unless it is +/// the current task. +/// +/// # Invariants +/// +/// Each value of this type must only be accessed from the task context it was created within. +/// +/// Of course, every thread is in a different task context, but for the purposes of this invariant, +/// these operations also permanently leave the task context: +/// +/// * Returning to userspace from system call context. +/// * Calling `release_task()`. +/// * Calling `begin_new_exec()` in a binary format loader. +/// +/// Other operations temporarily create a new sub-context: +/// +/// * Calling `kthread_use_mm()` creates a new context, and `kthread_unuse_mm()` returns to the +/// old context. +/// +/// This means that a `CurrentTask` obtained before a `kthread_use_mm()` call may be used again +/// once `kthread_unuse_mm()` is called, but it must not be used between these two calls. +/// Conversely, a `CurrentTask` obtained between a `kthread_use_mm()`/`kthread_unuse_mm()` pair +/// must not be used after `kthread_unuse_mm()`. +#[repr(transparent)] +pub struct CurrentTask(Task, NotThreadSafe); + +// Make all `Task` methods available on `CurrentTask`. +impl Deref for CurrentTask { + type Target = Task; + #[inline] + fn deref(&self) -> &Task { + &self.0 + } +} + +/// The type of process identifiers (PIDs). +pub type Pid = bindings::pid_t; + +/// The type of user identifiers (UIDs). +#[derive(Copy, Clone)] +pub struct Kuid { + kuid: bindings::kuid_t, +} + +impl Task { + /// Returns a raw pointer to the current task. + /// + /// It is up to the user to use the pointer correctly. + #[inline] + pub fn current_raw() -> *mut bindings::task_struct { + // SAFETY: Getting the current pointer is always safe. + unsafe { bindings::get_current() } + } + + /// Returns a task reference for the currently executing task/thread. + /// + /// The recommended way to get the current task/thread is to use the + /// [`current`] macro because it is safe. + /// + /// # Safety + /// + /// Callers must ensure that the returned object is only used to access a [`CurrentTask`] + /// within the task context that was active when this function was called. For more details, + /// see the invariants section for [`CurrentTask`]. + #[inline] + pub unsafe fn current() -> impl Deref<Target = CurrentTask> { + struct TaskRef { + task: *const CurrentTask, + } + + impl Deref for TaskRef { + type Target = CurrentTask; + + fn deref(&self) -> &Self::Target { + // SAFETY: The returned reference borrows from this `TaskRef`, so it cannot outlive + // the `TaskRef`, which the caller of `Task::current()` has promised will not + // outlive the task/thread for which `self.task` is the `current` pointer. Thus, it + // is okay to return a `CurrentTask` reference here. + unsafe { &*self.task } + } + } + + TaskRef { + // CAST: The layout of `struct task_struct` and `CurrentTask` is identical. + task: Task::current_raw().cast(), + } + } + + /// Returns a raw pointer to the task. + #[inline] + pub fn as_ptr(&self) -> *mut bindings::task_struct { + self.0.get() + } + + /// Returns the group leader of the given task. + pub fn group_leader(&self) -> &Task { + // SAFETY: The group leader of a task never changes after initialization, so reading this + // field is not a data race. 
+        let ptr = unsafe { *ptr::addr_of!((*self.as_ptr()).group_leader) };
+
+        // SAFETY: The lifetime of the returned task reference is tied to the lifetime of `self`,
+        // and given that a task has a reference to its group leader, we know it must be valid for
+        // the lifetime of the returned task reference.
+        unsafe { &*ptr.cast() }
+    }
+
+    /// Returns the PID of the given task.
+    pub fn pid(&self) -> Pid {
+        // SAFETY: The pid of a task never changes after initialization, so reading this field is
+        // not a data race.
+        unsafe { *ptr::addr_of!((*self.as_ptr()).pid) }
+    }
+
+    /// Returns the UID of the given task.
+    #[inline]
+    pub fn uid(&self) -> Kuid {
+        // SAFETY: It's always safe to call `task_uid` on a valid task.
+        Kuid::from_raw(unsafe { bindings::task_uid(self.as_ptr()) })
+    }
+
+    /// Returns the effective UID of the given task.
+    #[inline]
+    pub fn euid(&self) -> Kuid {
+        // SAFETY: It's always safe to call `task_euid` on a valid task.
+        Kuid::from_raw(unsafe { bindings::task_euid(self.as_ptr()) })
+    }
+
+    /// Determines whether the given task has pending signals.
+    #[inline]
+    pub fn signal_pending(&self) -> bool {
+        // SAFETY: It's always safe to call `signal_pending` on a valid task.
+        unsafe { bindings::signal_pending(self.as_ptr()) != 0 }
+    }
+
+    /// Returns the task's pid namespace with an elevated reference count.
+    #[inline]
+    pub fn get_pid_ns(&self) -> Option<ARef<PidNamespace>> {
+        // SAFETY: By the type invariant, we know that `self.0` is valid.
+        let ptr = unsafe { bindings::task_get_pid_ns(self.as_ptr()) };
+        if ptr.is_null() {
+            None
+        } else {
+            // SAFETY: `ptr` is valid by the safety requirements of this function. And we own a
+            // reference count via `task_get_pid_ns()`.
+            // CAST: `Self` is a `repr(transparent)` wrapper around `bindings::pid_namespace`.
+            Some(unsafe { ARef::from_raw(ptr::NonNull::new_unchecked(ptr.cast::<PidNamespace>())) })
+        }
+    }
+
+    /// Returns the given task's tgid (thread group ID) as seen from the provided pid namespace.
+    #[doc(alias = "task_tgid_nr_ns")]
+    #[inline]
+    pub fn tgid_nr_ns(&self, pidns: Option<&PidNamespace>) -> Pid {
+        let pidns = match pidns {
+            Some(pidns) => pidns.as_ptr(),
+            None => core::ptr::null_mut(),
+        };
+        // SAFETY: By the type invariant, we know that `self.0` is valid. We received a valid
+        // `PidNamespace` that we can use as a pointer, or we received `None` and thus pass a
+        // null pointer. The underlying C function is safe to be used with NULL pointers.
+        unsafe { bindings::task_tgid_nr_ns(self.as_ptr(), pidns) }
+    }
+
+    /// Wakes up the task.
+    #[inline]
+    pub fn wake_up(&self) {
+        // SAFETY: It's always safe to call `wake_up_process` on a valid task, even if the task
+        // is running.
+        unsafe { bindings::wake_up_process(self.as_ptr()) };
+    }
+}
+
+impl CurrentTask {
+    /// Access the address space of the current task.
+    ///
+    /// This function does not touch the refcount of the mm.
+    #[inline]
+    pub fn mm(&self) -> Option<&MmWithUser> {
+        // SAFETY: The `mm` field of `current` is not modified from other threads, so reading it
+        // is not a data race.
+        let mm = unsafe { (*self.as_ptr()).mm };
+
+        if mm.is_null() {
+            return None;
+        }
+
+        // SAFETY: If `current->mm` is non-null, then it references a valid mm with a non-zero
+        // value of `mm_users`. Furthermore, the returned `&MmWithUser` borrows from this
+        // `CurrentTask`, so it cannot escape the scope in which the current pointer was obtained.
+        //
+        // This is safe even if `kthread_use_mm()`/`kthread_unuse_mm()` are used.
There are two + // relevant cases: + // * If the `&CurrentTask` was created before `kthread_use_mm()`, then it cannot be + // accessed during the `kthread_use_mm()`/`kthread_unuse_mm()` scope due to the + // `NotThreadSafe` field of `CurrentTask`. + // * If the `&CurrentTask` was created within a `kthread_use_mm()`/`kthread_unuse_mm()` + // scope, then the `&CurrentTask` cannot escape that scope, so the returned `&MmWithUser` + // also cannot escape that scope. + // In either case, it's not possible to read `current->mm` and keep using it after the + // scope is ended with `kthread_unuse_mm()`. + Some(unsafe { MmWithUser::from_raw(mm) }) + } + + /// Access the pid namespace of the current task. + /// + /// This function does not touch the refcount of the namespace or use RCU protection. + /// + /// To access the pid namespace of another task, see [`Task::get_pid_ns`]. + #[doc(alias = "task_active_pid_ns")] + #[inline] + pub fn active_pid_ns(&self) -> Option<&PidNamespace> { + // SAFETY: It is safe to call `task_active_pid_ns` without RCU protection when calling it + // on the current task. + let active_ns = unsafe { bindings::task_active_pid_ns(self.as_ptr()) }; + + if active_ns.is_null() { + return None; + } + + // The lifetime of `PidNamespace` is bound to `Task` and `struct pid`. + // + // The `PidNamespace` of a `Task` doesn't ever change once the `Task` is alive. + // + // From system call context retrieving the `PidNamespace` for the current task is always + // safe and requires neither RCU locking nor a reference count to be held. Retrieving the + // `PidNamespace` after `release_task()` for current will return `NULL` but no codepath + // like that is exposed to Rust. + // + // SAFETY: If `current`'s pid ns is non-null, then it references a valid pid ns. + // Furthermore, the returned `&PidNamespace` borrows from this `CurrentTask`, so it cannot + // escape the scope in which the current pointer was obtained, e.g. it cannot live past a + // `release_task()` call. + Some(unsafe { PidNamespace::from_ptr(active_ns) }) + } +} + +// SAFETY: The type invariants guarantee that `Task` is always refcounted. +unsafe impl crate::sync::aref::AlwaysRefCounted for Task { + #[inline] + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_task_struct(self.as_ptr()) }; + } + + #[inline] + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is nonzero. + unsafe { bindings::put_task_struct(obj.cast().as_ptr()) } + } +} + +impl Kuid { + /// Get the current euid. + #[inline] + pub fn current_euid() -> Kuid { + // SAFETY: Just an FFI call. + Self::from_raw(unsafe { bindings::current_euid() }) + } + + /// Create a `Kuid` given the raw C type. + #[inline] + pub fn from_raw(kuid: bindings::kuid_t) -> Self { + Self { kuid } + } + + /// Turn this kuid into the raw C type. + #[inline] + pub fn into_raw(self) -> bindings::kuid_t { + self.kuid + } + + /// Converts this kernel UID into a userspace UID. + /// + /// Uses the namespace of the current task. + #[inline] + pub fn into_uid_in_current_ns(self) -> bindings::uid_t { + // SAFETY: Just an FFI call. + unsafe { bindings::from_kuid(bindings::current_user_ns(), self.kuid) } + } +} + +impl PartialEq for Kuid { + #[inline] + fn eq(&self, other: &Kuid) -> bool { + // SAFETY: Just an FFI call. + unsafe { bindings::uid_eq(self.kuid, other.kuid) } + } +} + +impl Eq for Kuid {} + +/// Annotation for functions that can sleep. 
+///
+/// Equivalent to the C side [`might_sleep()`], this function serves as
+/// a debugging aid and a potential scheduling point.
+///
+/// This function can only be used in a nonatomic context.
+///
+/// [`might_sleep()`]: https://docs.kernel.org/driver-api/basics.html#c.might_sleep
+#[track_caller]
+#[inline]
+pub fn might_sleep() {
    #[cfg(CONFIG_DEBUG_ATOMIC_SLEEP)]
+    {
+        let loc = core::panic::Location::caller();
+        let file = kernel::file_from_location(loc);
+
+        // SAFETY: `file.as_ptr()` is valid for reading and guaranteed to be nul-terminated.
+        unsafe { crate::bindings::__might_sleep(file.as_ptr().cast(), loc.line() as i32) }
+    }
+
+    // SAFETY: Always safe to call.
+    unsafe { crate::bindings::might_resched() }
+}
diff --git a/rust/kernel/time.rs b/rust/kernel/time.rs
new file mode 100644
index 000000000000..6ea98dfcd027
--- /dev/null
+++ b/rust/kernel/time.rs
@@ -0,0 +1,476 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Time related primitives.
+//!
+//! This module contains the kernel APIs related to time and timers that
+//! have been ported or wrapped for usage by Rust code in the kernel.
+//!
+//! There are two types in this module:
+//!
+//! - The [`Instant`] type represents a specific point in time.
+//! - The [`Delta`] type represents a span of time.
+//!
+//! Note that the C side uses the `ktime_t` type to represent both. However, timestamps
+//! and time deltas are different concepts. To avoid confusion, we use two different types.
+//!
+//! An [`Instant`] object can be created by calling the [`Instant::now()`] function.
+//! It represents a point in time at which the object was created.
+//! By calling the [`Instant::elapsed()`] method, a [`Delta`] object representing
+//! the elapsed time can be created. A [`Delta`] object can also be created
+//! by subtracting two [`Instant`] objects.
+//!
+//! The [`Delta`] type supports methods to retrieve the duration in various units.
+//!
+//! C header: [`include/linux/jiffies.h`](srctree/include/linux/jiffies.h).
+//! C header: [`include/linux/ktime.h`](srctree/include/linux/ktime.h).
+
+use core::marker::PhantomData;
+use core::ops;
+
+pub mod delay;
+pub mod hrtimer;
+
+/// The number of nanoseconds per microsecond.
+pub const NSEC_PER_USEC: i64 = bindings::NSEC_PER_USEC as i64;
+
+/// The number of nanoseconds per millisecond.
+pub const NSEC_PER_MSEC: i64 = bindings::NSEC_PER_MSEC as i64;
+
+/// The number of nanoseconds per second.
+pub const NSEC_PER_SEC: i64 = bindings::NSEC_PER_SEC as i64;
+
+/// The base time unit of the Linux kernel. One jiffy equals (1/HZ) of a second.
+pub type Jiffies = crate::ffi::c_ulong;
+
+/// The millisecond time unit.
+pub type Msecs = crate::ffi::c_uint;
+
+/// Converts milliseconds to jiffies.
+#[inline]
+pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies {
+    // SAFETY: The `__msecs_to_jiffies` function is always safe to call no
+    // matter what the argument is.
+    unsafe { bindings::__msecs_to_jiffies(msecs) }
+}
+
+/// Trait for clock sources.
+///
+/// Selection of the clock source depends on the use case. In some cases the usage of a
+/// particular clock is mandatory, e.g. in network protocols, filesystems. In other
+/// cases the user of the clock has to decide which clock is best suited for the
+/// purpose. In most scenarios clock [`Monotonic`] is the best choice as it
+/// provides an accurate monotonic notion of time (leap second smearing ignored).
+pub trait ClockSource {
+    /// The kernel clock ID associated with this clock source.
+    ///
+    /// This constant corresponds to the C side `clockid_t` value.
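+    /// (For example, `CLOCK_MONOTONIC` for [`Monotonic`].)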
+    const ID: bindings::clockid_t;
+
+    /// Get the current time from the clock source.
+    ///
+    /// The function must return a value in the range from 0 to `KTIME_MAX`.
+    fn ktime_get() -> bindings::ktime_t;
+}
+
+/// A monotonically increasing clock.
+///
+/// A nonsettable system-wide clock that represents monotonic time since, as
+/// described by POSIX, "some unspecified point in the past". On Linux, that
+/// point corresponds to the number of seconds that the system has been
+/// running since it was booted.
+///
+/// The CLOCK_MONOTONIC clock is not affected by discontinuous jumps in the
+/// CLOCK_REALTIME clock (e.g., if the system administrator manually changes
+/// the clock), but is affected by frequency adjustments. This clock does not
+/// count time that the system is suspended.
+pub struct Monotonic;
+
+impl ClockSource for Monotonic {
+    const ID: bindings::clockid_t = bindings::CLOCK_MONOTONIC as bindings::clockid_t;
+
+    fn ktime_get() -> bindings::ktime_t {
+        // SAFETY: It is always safe to call `ktime_get()` outside of NMI context.
+        unsafe { bindings::ktime_get() }
+    }
+}
+
+/// A settable system-wide clock that measures real (i.e., wall-clock) time.
+///
+/// Setting this clock requires appropriate privileges. This clock is
+/// affected by discontinuous jumps in the system time (e.g., if the system
+/// administrator manually changes the clock), and by frequency adjustments
+/// performed by NTP and similar applications via adjtime(3), adjtimex(2),
+/// clock_adjtime(2), and ntp_adjtime(3). This clock normally counts the
+/// number of seconds since 1970-01-01 00:00:00 Coordinated Universal Time
+/// (UTC) except that it ignores leap seconds; near a leap second it may be
+/// adjusted by leap second smearing to stay roughly in sync with UTC. Leap
+/// second smearing applies frequency adjustments to the clock to speed up
+/// or slow down the clock to account for the leap second without
+/// discontinuities in the clock. If leap second smearing is not applied,
+/// the clock will experience a discontinuity around the leap second
+/// adjustment.
+pub struct RealTime;
+
+impl ClockSource for RealTime {
+    const ID: bindings::clockid_t = bindings::CLOCK_REALTIME as bindings::clockid_t;
+
+    fn ktime_get() -> bindings::ktime_t {
+        // SAFETY: It is always safe to call `ktime_get_real()` outside of NMI context.
+        unsafe { bindings::ktime_get_real() }
+    }
+}
+
+/// A monotonic clock that also ticks while the system is suspended.
+///
+/// A nonsettable system-wide clock that is identical to CLOCK_MONOTONIC,
+/// except that it also includes any time that the system is suspended. This
+/// allows applications to get a suspend-aware monotonic clock without
+/// having to deal with the complications of CLOCK_REALTIME, which may have
+/// discontinuities if the time is changed using settimeofday(2) or similar.
+pub struct BootTime;
+
+impl ClockSource for BootTime {
+    const ID: bindings::clockid_t = bindings::CLOCK_BOOTTIME as bindings::clockid_t;
+
+    fn ktime_get() -> bindings::ktime_t {
+        // SAFETY: It is always safe to call `ktime_get_boottime()` outside of NMI context.
+        unsafe { bindings::ktime_get_boottime() }
+    }
+}
+
+/// International Atomic Time.
+///
+/// A system-wide clock derived from wall-clock time but counting leap seconds.
+///
+/// This clock is coupled to CLOCK_REALTIME and will be set when CLOCK_REALTIME is
+/// set, or when the offset to CLOCK_REALTIME is changed via adjtimex(2). This
+/// usually happens during boot and **should** not happen during normal operations.
+/// However, if NTP or another application performs leap second smearing on
+/// CLOCK_REALTIME, this clock will not be precise while the smearing is in
+/// progress.
+///
+/// The acronym TAI refers to International Atomic Time.
+pub struct Tai;
+
+impl ClockSource for Tai {
+    const ID: bindings::clockid_t = bindings::CLOCK_TAI as bindings::clockid_t;
+
+    fn ktime_get() -> bindings::ktime_t {
+        // SAFETY: It is always safe to call `ktime_get_clocktai()` outside of NMI context.
+        unsafe { bindings::ktime_get_clocktai() }
+    }
+}
+
+/// A specific point in time.
+///
+/// # Invariants
+///
+/// The `inner` value is in the range from 0 to `KTIME_MAX`.
+#[repr(transparent)]
+#[derive(PartialEq, PartialOrd, Eq, Ord)]
+pub struct Instant<C: ClockSource> {
+    inner: bindings::ktime_t,
+    _c: PhantomData<C>,
+}
+
+impl<C: ClockSource> Clone for Instant<C> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl<C: ClockSource> Copy for Instant<C> {}
+
+impl<C: ClockSource> Instant<C> {
+    /// Get the current time from the clock source.
+    #[inline]
+    pub fn now() -> Self {
+        // INVARIANT: The `ClockSource::ktime_get()` function returns a value in the range
+        // from 0 to `KTIME_MAX`.
+        Self {
+            inner: C::ktime_get(),
+            _c: PhantomData,
+        }
+    }
+
+    /// Return the amount of time elapsed since the [`Instant`].
+    #[inline]
+    pub fn elapsed(&self) -> Delta {
+        Self::now() - *self
+    }
+
+    #[inline]
+    pub(crate) fn as_nanos(&self) -> i64 {
+        self.inner
+    }
+
+    /// Create an [`Instant`] from a `ktime_t` without checking if it is non-negative.
+    ///
+    /// # Panics
+    ///
+    /// On debug builds, this function will panic if `ktime` is not in the range from 0 to
+    /// `KTIME_MAX`.
+    ///
+    /// # Safety
+    ///
+    /// The caller promises that `ktime` is in the range from 0 to `KTIME_MAX`.
+    #[inline]
+    pub(crate) unsafe fn from_ktime(ktime: bindings::ktime_t) -> Self {
+        debug_assert!(ktime >= 0);
+
+        // INVARIANT: Our safety contract ensures that `ktime` is in the range from 0 to
+        // `KTIME_MAX`.
+        Self {
+            inner: ktime,
+            _c: PhantomData,
+        }
+    }
+}
+
+impl<C: ClockSource> ops::Sub for Instant<C> {
+    type Output = Delta;
+
+    // By the type invariant, this never overflows.
+    #[inline]
+    fn sub(self, other: Instant<C>) -> Delta {
+        Delta {
+            nanos: self.inner - other.inner,
+        }
+    }
+}
+
+impl<C: ClockSource> ops::Add<Delta> for Instant<C> {
+    type Output = Self;
+
+    #[inline]
+    fn add(self, rhs: Delta) -> Self::Output {
+        // INVARIANT: With arithmetic over/underflow checks enabled, this will panic if the
+        // result overflows (e.g. goes above `KTIME_MAX`).
+        let res = self.inner + rhs.nanos;
+
+        // INVARIANT: With overflow checks enabled, we verify here that the value is >= 0.
+        #[cfg(CONFIG_RUST_OVERFLOW_CHECKS)]
+        assert!(res >= 0);
+
+        Self {
+            inner: res,
+            _c: PhantomData,
+        }
+    }
+}
+
+impl<C: ClockSource> ops::Sub<Delta> for Instant<C> {
+    type Output = Self;
+
+    #[inline]
+    fn sub(self, rhs: Delta) -> Self::Output {
+        // INVARIANT: With arithmetic over/underflow checks enabled, this will panic if the
+        // result over- or underflows (e.g. goes above `KTIME_MAX`).
+        let res = self.inner - rhs.nanos;
+
+        // INVARIANT: With overflow checks enabled, we verify here that the value is >= 0.
+        #[cfg(CONFIG_RUST_OVERFLOW_CHECKS)]
+        assert!(res >= 0);
+
+        Self {
+            inner: res,
+            _c: PhantomData,
+        }
+    }
+}
+
+/// A span of time.
+///
+/// This struct represents a span of time, with its value stored as nanoseconds.
+/// The value can represent any valid i64 value, including negative, zero, and
+/// positive numbers.
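+///
+/// # Examples
+///
+/// The following doctest is an illustrative sketch added in this write-up (it is
+/// not part of the original patch); it only uses items defined in this module.
+///
+/// ```
+/// use kernel::time::{Delta, Instant, Monotonic};
+///
+/// // Measure how long some work takes.
+/// let start = Instant::<Monotonic>::now();
+/// // ... do some work ...
+/// let elapsed: Delta = start.elapsed();
+///
+/// // A monotonic clock never runs backwards.
+/// assert!(!elapsed.is_negative());
+/// ```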
+#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Debug)] +pub struct Delta { + nanos: i64, +} + +impl ops::Add for Delta { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + Self { + nanos: self.nanos + rhs.nanos, + } + } +} + +impl ops::AddAssign for Delta { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.nanos += rhs.nanos; + } +} + +impl ops::Sub for Delta { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self { + nanos: self.nanos - rhs.nanos, + } + } +} + +impl ops::SubAssign for Delta { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.nanos -= rhs.nanos; + } +} + +impl ops::Mul<i64> for Delta { + type Output = Self; + + #[inline] + fn mul(self, rhs: i64) -> Self::Output { + Self { + nanos: self.nanos * rhs, + } + } +} + +impl ops::MulAssign<i64> for Delta { + #[inline] + fn mul_assign(&mut self, rhs: i64) { + self.nanos *= rhs; + } +} + +impl ops::Div for Delta { + type Output = i64; + + #[inline] + fn div(self, rhs: Self) -> Self::Output { + #[cfg(CONFIG_64BIT)] + { + self.nanos / rhs.nanos + } + + #[cfg(not(CONFIG_64BIT))] + { + // SAFETY: This function is always safe to call regardless of the input values + unsafe { bindings::div64_s64(self.nanos, rhs.nanos) } + } + } +} + +impl Delta { + /// A span of time equal to zero. + pub const ZERO: Self = Self { nanos: 0 }; + + /// Create a new [`Delta`] from a number of microseconds. + /// + /// The `micros` can range from -9_223_372_036_854_775 to 9_223_372_036_854_775. + /// If `micros` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. + #[inline] + pub const fn from_micros(micros: i64) -> Self { + Self { + nanos: micros.saturating_mul(NSEC_PER_USEC), + } + } + + /// Create a new [`Delta`] from a number of milliseconds. + /// + /// The `millis` can range from -9_223_372_036_854 to 9_223_372_036_854. + /// If `millis` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. + #[inline] + pub const fn from_millis(millis: i64) -> Self { + Self { + nanos: millis.saturating_mul(NSEC_PER_MSEC), + } + } + + /// Create a new [`Delta`] from a number of seconds. + /// + /// The `secs` can range from -9_223_372_036 to 9_223_372_036. + /// If `secs` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. + #[inline] + pub const fn from_secs(secs: i64) -> Self { + Self { + nanos: secs.saturating_mul(NSEC_PER_SEC), + } + } + + /// Return `true` if the [`Delta`] spans no time. + #[inline] + pub fn is_zero(self) -> bool { + self.as_nanos() == 0 + } + + /// Return `true` if the [`Delta`] spans a negative amount of time. + #[inline] + pub fn is_negative(self) -> bool { + self.as_nanos() < 0 + } + + /// Return the number of nanoseconds in the [`Delta`]. + #[inline] + pub const fn as_nanos(self) -> i64 { + self.nanos + } + + /// Return the smallest number of microseconds greater than or equal + /// to the value in the [`Delta`]. + #[inline] + pub fn as_micros_ceil(self) -> i64 { + #[cfg(CONFIG_64BIT)] + { + self.as_nanos().saturating_add(NSEC_PER_USEC - 1) / NSEC_PER_USEC + } + + #[cfg(not(CONFIG_64BIT))] + // SAFETY: It is always safe to call `ktime_to_us()` with any value. + unsafe { + bindings::ktime_to_us(self.as_nanos().saturating_add(NSEC_PER_USEC - 1)) + } + } + + /// Return the number of milliseconds in the [`Delta`]. 
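+    ///
+    /// # Examples
+    ///
+    /// A small sketch added for illustration (not part of the original patch):
+    ///
+    /// ```
+    /// use kernel::time::Delta;
+    ///
+    /// assert_eq!(Delta::from_secs(1).as_millis(), 1000);
+    /// // Truncates toward zero: 1500 us is 1 ms.
+    /// assert_eq!(Delta::from_micros(1500).as_millis(), 1);
+    /// ```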
+    #[inline]
+    pub fn as_millis(self) -> i64 {
+        #[cfg(CONFIG_64BIT)]
+        {
+            self.as_nanos() / NSEC_PER_MSEC
+        }
+
+        #[cfg(not(CONFIG_64BIT))]
+        // SAFETY: It is always safe to call `ktime_to_ms()` with any value.
+        unsafe {
+            bindings::ktime_to_ms(self.as_nanos())
+        }
+    }
+
+    /// Return `self % dividend` where `dividend` is in nanoseconds.
+    ///
+    /// The kernel doesn't have any emulation for `s64 % s64` on 32 bit platforms, so this is
+    /// limited to 32 bit dividends.
+    #[inline]
+    pub fn rem_nanos(self, dividend: i32) -> Self {
+        #[cfg(CONFIG_64BIT)]
+        {
+            Self {
+                nanos: self.as_nanos() % i64::from(dividend),
+            }
+        }
+
+        #[cfg(not(CONFIG_64BIT))]
+        {
+            let mut rem = 0;
+
+            // SAFETY: `rem` is on the stack, so we can always provide a valid pointer to it.
+            unsafe { bindings::div_s64_rem(self.as_nanos(), dividend, &mut rem) };
+
+            Self {
+                nanos: i64::from(rem),
+            }
+        }
+    }
+}
diff --git a/rust/kernel/time/delay.rs b/rust/kernel/time/delay.rs
new file mode 100644
index 000000000000..b5b1b42797a0
--- /dev/null
+++ b/rust/kernel/time/delay.rs
@@ -0,0 +1,86 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Delay and sleep primitives.
+//!
+//! This module contains the kernel APIs related to delay and sleep that
+//! have been ported or wrapped for usage by Rust code in the kernel.
+//!
+//! C header: [`include/linux/delay.h`](srctree/include/linux/delay.h).
+
+use super::Delta;
+use crate::prelude::*;
+
+/// Sleeps for at least the given duration.
+///
+/// Equivalent to the C side [`fsleep()`], a flexible sleep function which
+/// automatically chooses the best sleep method based on the duration.
+///
+/// `delta` must be within `[0, i32::MAX]` microseconds;
+/// otherwise, it is erroneous behavior. That is, it is considered a bug
+/// to call this function with an out-of-range value, in which case the function
+/// will sleep for at least the maximum value in the range and may warn
+/// in the future.
+///
+/// The behavior above differs from the C side [`fsleep()`], for which out-of-range
+/// values mean "infinite timeout" instead.
+///
+/// This function can only be used in a nonatomic context.
+///
+/// [`fsleep()`]: https://docs.kernel.org/timers/delay_sleep_functions.html#c.fsleep
+pub fn fsleep(delta: Delta) {
+    // The maximum value is set to `i32::MAX` microseconds to prevent integer
+    // overflow inside fsleep, which could lead to an unintentional infinite sleep.
+    const MAX_DELTA: Delta = Delta::from_micros(i32::MAX as i64);
+
+    let delta = if (Delta::ZERO..=MAX_DELTA).contains(&delta) {
+        delta
+    } else {
+        // TODO: Add WARN_ONCE() when it's supported.
+        MAX_DELTA
+    };
+
+    // SAFETY: It is always safe to call `fsleep()` with any duration.
+    unsafe {
+        // Convert the duration to microseconds and round up to preserve
+        // the guarantee: `fsleep()` sleeps for at least the provided duration,
+        // though it may sleep for longer under some circumstances.
+        bindings::fsleep(delta.as_micros_ceil() as c_ulong)
+    }
+}
+
+/// Inserts a busy-wait delay, given in microseconds.
+///
+/// Equivalent to the C side [`udelay()`], which delays in microseconds.
+///
+/// `delta` must be within `[0, MAX_UDELAY_MS]` milliseconds;
+/// otherwise, it is erroneous behavior. That is, it is considered a bug to
+/// call this function with an out-of-range value.
+///
+/// The behavior above differs from the C side [`udelay()`], for which out-of-range
+/// values could lead to an overflow and unexpected behavior.
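+///
+/// # Examples
+///
+/// An illustrative sketch added in this write-up (not part of the original patch):
+///
+/// ```
+/// use kernel::time::{delay::udelay, Delta};
+///
+/// // Busy-wait for 10 microseconds, e.g. to satisfy a short hardware
+/// // settling time where sleeping is not an option.
+/// udelay(Delta::from_micros(10));
+/// ```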
+/// +/// [`udelay()`]: https://docs.kernel.org/timers/delay_sleep_functions.html#c.udelay +pub fn udelay(delta: Delta) { + const MAX_UDELAY_DELTA: Delta = Delta::from_millis(bindings::MAX_UDELAY_MS as i64); + + debug_assert!(delta.as_nanos() >= 0); + debug_assert!(delta <= MAX_UDELAY_DELTA); + + let delta = if (Delta::ZERO..=MAX_UDELAY_DELTA).contains(&delta) { + delta + } else { + MAX_UDELAY_DELTA + }; + + // SAFETY: It is always safe to call `udelay()` with any duration. + // Note that the kernel is compiled with `-fno-strict-overflow` + // so any out-of-range value could lead to unexpected behavior + // but won't lead to undefined behavior. + unsafe { + // Convert the duration to microseconds and round up to preserve + // the guarantee; `udelay()` inserts a delay for at least + // the provided duration, but that it may delay for longer + // under some circumstances. + bindings::udelay(delta.as_micros_ceil() as c_ulong) + } +} diff --git a/rust/kernel/time/hrtimer.rs b/rust/kernel/time/hrtimer.rs new file mode 100644 index 000000000000..856d2d929a00 --- /dev/null +++ b/rust/kernel/time/hrtimer.rs @@ -0,0 +1,784 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Intrusive high resolution timers. +//! +//! Allows running timer callbacks without doing allocations at the time of +//! starting the timer. For now, only one timer per type is allowed. +//! +//! # Vocabulary +//! +//! States: +//! +//! - Stopped: initialized but not started, or cancelled, or not restarted. +//! - Started: initialized and started or restarted. +//! - Running: executing the callback. +//! +//! Operations: +//! +//! * Start +//! * Cancel +//! * Restart +//! +//! Events: +//! +//! * Expire +//! +//! ## State Diagram +//! +//! ```text +//! Return NoRestart +//! +---------------------------------------------------------------------+ +//! | | +//! | | +//! | | +//! | Return Restart | +//! | +------------------------+ | +//! | | | | +//! | | | | +//! v v | | +//! +-----------------+ Start +------------------+ +--------+-----+--+ +//! | +---------------->| | | | +//! Init | | | | Expire | | +//! --------->| Stopped | | Started +---------->| Running | +//! | | Cancel | | | | +//! | |<----------------+ | | | +//! +-----------------+ +---------------+--+ +-----------------+ +//! ^ | +//! | | +//! +---------+ +//! Restart +//! ``` +//! +//! +//! A timer is initialized in the **stopped** state. A stopped timer can be +//! **started** by the `start` operation, with an **expiry** time. After the +//! `start` operation, the timer is in the **started** state. When the timer +//! **expires**, the timer enters the **running** state and the handler is +//! executed. After the handler has returned, the timer may enter the +//! **started* or **stopped** state, depending on the return value of the +//! handler. A timer in the **started** or **running** state may be **canceled** +//! by the `cancel` operation. A timer that is cancelled enters the **stopped** +//! state. +//! +//! A `cancel` or `restart` operation on a timer in the **running** state takes +//! effect after the handler has returned and the timer has transitioned +//! out of the **running** state. +//! +//! A `restart` operation on a timer in the **stopped** state is equivalent to a +//! `start` operation. + +use super::{ClockSource, Delta, Instant}; +use crate::{prelude::*, types::Opaque}; +use core::{marker::PhantomData, ptr::NonNull}; +use pin_init::PinInit; + +/// A type-alias to refer to the [`Instant<C>`] for a given `T` from [`HrTimer<T>`]. 
+/// +/// Where `C` is the [`ClockSource`] of the [`HrTimer`]. +pub type HrTimerInstant<T> = Instant<<<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Clock>; + +/// A timer backed by a C `struct hrtimer`. +/// +/// # Invariants +/// +/// * `self.timer` is initialized by `bindings::hrtimer_setup`. +#[pin_data] +#[repr(C)] +pub struct HrTimer<T> { + #[pin] + timer: Opaque<bindings::hrtimer>, + _t: PhantomData<T>, +} + +// SAFETY: Ownership of an `HrTimer` can be moved to other threads and +// used/dropped from there. +unsafe impl<T> Send for HrTimer<T> {} + +// SAFETY: Timer operations are locked on the C side, so it is safe to operate +// on a timer from multiple threads. +unsafe impl<T> Sync for HrTimer<T> {} + +impl<T> HrTimer<T> { + /// Return an initializer for a new timer instance. + pub fn new() -> impl PinInit<Self> + where + T: HrTimerCallback, + T: HasHrTimer<T>, + { + pin_init!(Self { + // INVARIANT: We initialize `timer` with `hrtimer_setup` below. + timer <- Opaque::ffi_init(move |place: *mut bindings::hrtimer| { + // SAFETY: By design of `pin_init!`, `place` is a pointer to a + // live allocation. hrtimer_setup will initialize `place` and + // does not require `place` to be initialized prior to the call. + unsafe { + bindings::hrtimer_setup( + place, + Some(T::Pointer::run), + <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Clock::ID, + <T as HasHrTimer<T>>::TimerMode::C_MODE, + ); + } + }), + _t: PhantomData, + }) + } + + /// Get a pointer to the contained `bindings::hrtimer`. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must point to a live allocation of at least the size of `Self`. + unsafe fn raw_get(this: *const Self) -> *mut bindings::hrtimer { + // SAFETY: The field projection to `timer` does not go out of bounds, + // because the caller of this function promises that `this` points to an + // allocation of at least the size of `Self`. + unsafe { Opaque::cast_into(core::ptr::addr_of!((*this).timer)) } + } + + /// Cancel an initialized and potentially running timer. + /// + /// If the timer handler is running, this function will block until the + /// handler returns. + /// + /// Note that the timer might be started by a concurrent start operation. If + /// so, the timer might not be in the **stopped** state when this function + /// returns. + /// + /// Users of the `HrTimer` API would not usually call this method directly. + /// Instead they would use the safe [`HrTimerHandle::cancel`] on the handle + /// returned when the timer was started. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must point to a valid `Self`. + pub(crate) unsafe fn raw_cancel(this: *const Self) -> bool { + // SAFETY: `this` points to an allocation of at least `HrTimer` size. + let c_timer_ptr = unsafe { HrTimer::raw_get(this) }; + + // If the handler is running, this will wait for the handler to return + // before returning. + // SAFETY: `c_timer_ptr` is initialized and valid. Synchronization is + // handled on the C side. + unsafe { bindings::hrtimer_cancel(c_timer_ptr) != 0 } + } + + /// Forward the timer expiry for a given timer pointer. + /// + /// # Safety + /// + /// - `self_ptr` must point to a valid `Self`. + /// - The caller must either have exclusive access to the data pointed at by `self_ptr`, or be + /// within the context of the timer callback. 
+ #[inline] + unsafe fn raw_forward(self_ptr: *mut Self, now: HrTimerInstant<T>, interval: Delta) -> u64 + where + T: HasHrTimer<T>, + { + // SAFETY: + // * The C API requirements for this function are fulfilled by our safety contract. + // * `self_ptr` is guaranteed to point to a valid `Self` via our safety contract + unsafe { + bindings::hrtimer_forward(Self::raw_get(self_ptr), now.as_nanos(), interval.as_nanos()) + } + } + + /// Conditionally forward the timer. + /// + /// If the timer expires after `now`, this function does nothing and returns 0. If the timer + /// expired at or before `now`, this function forwards the timer by `interval` until the timer + /// expires after `now` and then returns the number of times the timer was forwarded by + /// `interval`. + /// + /// This function is mainly useful for timer types which can provide exclusive access to the + /// timer when the timer is not running. For forwarding the timer from within the timer callback + /// context, see [`HrTimerCallbackContext::forward()`]. + /// + /// Returns the number of overruns that occurred as a result of the timer expiry change. + pub fn forward(self: Pin<&mut Self>, now: HrTimerInstant<T>, interval: Delta) -> u64 + where + T: HasHrTimer<T>, + { + // SAFETY: `raw_forward` does not move `Self` + let this = unsafe { self.get_unchecked_mut() }; + + // SAFETY: By existence of `Pin<&mut Self>`, the pointer passed to `raw_forward` points to a + // valid `Self` that we have exclusive access to. + unsafe { Self::raw_forward(this, now, interval) } + } + + /// Conditionally forward the timer. + /// + /// This is a variant of [`forward()`](Self::forward) that uses an interval after the current + /// time of the base clock for the [`HrTimer`]. + pub fn forward_now(self: Pin<&mut Self>, interval: Delta) -> u64 + where + T: HasHrTimer<T>, + { + self.forward(HrTimerInstant::<T>::now(), interval) + } + + /// Return the time expiry for this [`HrTimer`]. + /// + /// This value should only be used as a snapshot, as the actual expiry time could change after + /// this function is called. + pub fn expires(&self) -> HrTimerInstant<T> + where + T: HasHrTimer<T>, + { + // SAFETY: `self` is an immutable reference and thus always points to a valid `HrTimer`. + let c_timer_ptr = unsafe { HrTimer::raw_get(self) }; + + // SAFETY: + // - Timers cannot have negative ktime_t values as their expiration time. + // - There's no actual locking here, a racy read is fine and expected + unsafe { + Instant::from_ktime( + // This `read_volatile` is intended to correspond to a READ_ONCE call. + // FIXME(read_once): Replace with `read_once` when available on the Rust side. + core::ptr::read_volatile(&raw const ((*c_timer_ptr).node.expires)), + ) + } + } +} + +/// Implemented by pointer types that point to structs that contain a [`HrTimer`]. +/// +/// `Self` must be [`Sync`] because it is passed to timer callbacks in another +/// thread of execution (hard or soft interrupt context). +/// +/// Starting a timer returns a [`HrTimerHandle`] that can be used to manipulate +/// the timer. Note that it is OK to call the start function repeatedly, and +/// that more than one [`HrTimerHandle`] associated with a [`HrTimerPointer`] may +/// exist. A timer can be manipulated through any of the handles, and a handle +/// may represent a cancelled timer. +pub trait HrTimerPointer: Sync + Sized { + /// The operational mode associated with this timer. + /// + /// This defines how the expiration value is interpreted. 
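+    ///
+    /// For illustration (sketch added in this write-up), using the mode types
+    /// defined later in this module:
+    ///
+    /// ```ignore
+    /// // Either a relative timer on the monotonic clock, where `Expires` is a
+    /// // `Delta` ...
+    /// type TimerMode = RelativeMode<Monotonic>;
+    /// // ... or an absolute one, where `Expires` is an `Instant<Monotonic>`.
+    /// type TimerMode = AbsoluteMode<Monotonic>;
+    /// ```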
+    type TimerMode: HrTimerMode;
+
+    /// A handle representing a started or restarted timer.
+    ///
+    /// If the timer is running or if the timer callback is executing when the
+    /// handle is dropped, the drop method of [`HrTimerHandle`] should not return
+    /// until the timer is stopped and the callback has completed.
+    ///
+    /// Note: When implementing this trait, consider that it is not unsafe to
+    /// leak the handle.
+    type TimerHandle: HrTimerHandle;
+
+    /// Start the timer with expiry after `expires` time units. If the timer was
+    /// already running, it is restarted with the new expiry time.
+    fn start(self, expires: <Self::TimerMode as HrTimerMode>::Expires) -> Self::TimerHandle;
+}
+
+/// Unsafe version of [`HrTimerPointer`] for situations where leaking the
+/// [`HrTimerHandle`] returned by `start` would be unsound. This is the case for
+/// stack allocated timers.
+///
+/// Typical implementers are pinned references such as [`Pin<&T>`].
+///
+/// # Safety
+///
+/// Implementers of this trait must ensure that instances of types implementing
+/// [`UnsafeHrTimerPointer`] outlive any associated [`HrTimerPointer::TimerHandle`]
+/// instances.
+pub unsafe trait UnsafeHrTimerPointer: Sync + Sized {
+    /// The operational mode associated with this timer.
+    ///
+    /// This defines how the expiration value is interpreted.
+    type TimerMode: HrTimerMode;
+
+    /// A handle representing a running timer.
+    ///
+    /// # Safety
+    ///
+    /// If the timer is running, or if the timer callback is executing when the
+    /// handle is dropped, the drop method of [`Self::TimerHandle`] must not return
+    /// until the timer is stopped and the callback has completed.
+    type TimerHandle: HrTimerHandle;
+
+    /// Start the timer after `expires` time units. If the timer was already
+    /// running, it is restarted at the new expiry time.
+    ///
+    /// # Safety
+    ///
+    /// The caller promises to keep the timer structure alive until the timer is
+    /// dead. The caller can ensure this by not leaking the returned
+    /// [`Self::TimerHandle`].
+    unsafe fn start(self, expires: <Self::TimerMode as HrTimerMode>::Expires) -> Self::TimerHandle;
+}
+
+/// A trait for stack allocated timers.
+///
+/// # Safety
+///
+/// Implementers must ensure that `start_scoped` does not return until the
+/// timer is dead and the timer handler is not running.
+pub unsafe trait ScopedHrTimerPointer {
+    /// The operational mode associated with this timer.
+    ///
+    /// This defines how the expiration value is interpreted.
+    type TimerMode: HrTimerMode;
+
+    /// Start the timer to run after `expires` time units and immediately
+    /// call `f`. When `f` returns, the timer is cancelled.
+    fn start_scoped<T, F>(self, expires: <Self::TimerMode as HrTimerMode>::Expires, f: F) -> T
+    where
+        F: FnOnce() -> T;
+}
+
+// SAFETY: By the safety requirement of [`UnsafeHrTimerPointer`], dropping the
+// handle returned by [`UnsafeHrTimerPointer::start`] ensures that the timer is
+// killed.
+unsafe impl<T> ScopedHrTimerPointer for T
+where
+    T: UnsafeHrTimerPointer,
+{
+    type TimerMode = T::TimerMode;
+
+    fn start_scoped<U, F>(
+        self,
+        expires: <<T as UnsafeHrTimerPointer>::TimerMode as HrTimerMode>::Expires,
+        f: F,
+    ) -> U
+    where
+        F: FnOnce() -> U,
+    {
+        // SAFETY: We drop the timer handle below before returning.
+        let handle = unsafe { UnsafeHrTimerPointer::start(self, expires) };
+        let t = f();
+        drop(handle);
+        t
+    }
+}
+
+/// Implemented by [`HrTimerPointer`] implementers to give the C timer callback a
+/// function to call.
+// This is split from `HrTimerPointer` to make it easier to specify trait bounds.
+pub trait RawHrTimerCallback {
+    /// Type of the parameter passed to [`HrTimerCallback::run`]. It may be
+    /// [`Self`], or a pointer type derived from [`Self`].
+    type CallbackTarget<'a>;
+
+    /// Callback to be called from C when the timer fires.
+    ///
+    /// # Safety
+    ///
+    /// Only to be called by C code in the `hrtimer` subsystem. `this` must point
+    /// to the `bindings::hrtimer` structure that was used to start the timer.
+    unsafe extern "C" fn run(this: *mut bindings::hrtimer) -> bindings::hrtimer_restart;
+}
+
+/// Implemented by structs that can be the target of a timer callback.
+pub trait HrTimerCallback {
+    /// The type whose [`RawHrTimerCallback::run`] method will be invoked when
+    /// the timer expires.
+    type Pointer<'a>: RawHrTimerCallback;
+
+    /// Called by the timer logic when the timer fires.
+    fn run(
+        this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>,
+        ctx: HrTimerCallbackContext<'_, Self>,
+    ) -> HrTimerRestart
+    where
+        Self: Sized,
+        Self: HasHrTimer<Self>;
+}
+
+/// A handle representing a potentially running timer.
+///
+/// More than one handle representing the same timer might exist.
+///
+/// # Safety
+///
+/// When dropped, the timer represented by this handle must be cancelled, if it
+/// is running. If the timer handler is running when the handle is dropped, the
+/// drop method must wait for the handler to return before returning.
+///
+/// Note: One way to satisfy the safety requirement is to call `Self::cancel` in
+/// the drop implementation for `Self`.
+pub unsafe trait HrTimerHandle {
+    /// Cancel the timer. If the timer is in the running state, block till the
+    /// handler has returned.
+    ///
+    /// Note that the timer might be started by a concurrent start operation. If
+    /// so, the timer might not be in the **stopped** state when this function
+    /// returns.
+    ///
+    /// Returns `true` if the timer was running.
+    fn cancel(&mut self) -> bool;
+}
+
+/// Implemented by structs that contain timer nodes.
+///
+/// Clients of the timer API would usually safely implement this trait by using
+/// the [`crate::impl_has_hr_timer`] macro.
+///
+/// # Safety
+///
+/// Implementers of this trait must ensure that the implementer has a
+/// [`HrTimer`] field and that all trait methods are implemented according to
+/// their documentation. All the methods of this trait must operate on the same
+/// field.
+pub unsafe trait HasHrTimer<T> {
+    /// The operational mode associated with this timer.
+    ///
+    /// This defines how the expiration value is interpreted.
+    type TimerMode: HrTimerMode;
+
+    /// Return a pointer to the [`HrTimer`] within `Self`.
+    ///
+    /// This function is useful to get access to the value without creating
+    /// intermediate references.
+    ///
+    /// # Safety
+    ///
+    /// `this` must be a valid pointer.
+    unsafe fn raw_get_timer(this: *const Self) -> *const HrTimer<T>;
+
+    /// Return a pointer to the struct that contains the [`HrTimer`] pointed
+    /// to by `ptr`.
+    ///
+    /// This function is useful to get access to the value without creating
+    /// intermediate references.
+    ///
+    /// # Safety
+    ///
+    /// `ptr` must point to a [`HrTimer<T>`] field in a struct of type `Self`.
+    unsafe fn timer_container_of(ptr: *mut HrTimer<T>) -> *mut Self
+    where
+        Self: Sized;
+
+    /// Get a pointer to the contained `bindings::hrtimer` struct.
+    ///
+    /// This function is useful to get access to the value without creating
+    /// intermediate references.
+ /// + /// # Safety + /// + /// `this` must be a valid pointer. + unsafe fn c_timer_ptr(this: *const Self) -> *const bindings::hrtimer { + // SAFETY: `this` is a valid pointer to a `Self`. + let timer_ptr = unsafe { Self::raw_get_timer(this) }; + + // SAFETY: timer_ptr points to an allocation of at least `HrTimer` size. + unsafe { HrTimer::raw_get(timer_ptr) } + } + + /// Start the timer contained in the `Self` pointed to by `self_ptr`. If + /// it is already running it is removed and inserted. + /// + /// # Safety + /// + /// - `this` must point to a valid `Self`. + /// - Caller must ensure that the pointee of `this` lives until the timer + /// fires or is canceled. + unsafe fn start(this: *const Self, expires: <Self::TimerMode as HrTimerMode>::Expires) { + // SAFETY: By function safety requirement, `this` is a valid `Self`. + unsafe { + bindings::hrtimer_start_range_ns( + Self::c_timer_ptr(this).cast_mut(), + expires.as_nanos(), + 0, + <Self::TimerMode as HrTimerMode>::C_MODE, + ); + } + } +} + +/// Restart policy for timers. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[repr(u32)] +pub enum HrTimerRestart { + /// Timer should not be restarted. + NoRestart = bindings::hrtimer_restart_HRTIMER_NORESTART, + /// Timer should be restarted. + Restart = bindings::hrtimer_restart_HRTIMER_RESTART, +} + +impl HrTimerRestart { + fn into_c(self) -> bindings::hrtimer_restart { + self as bindings::hrtimer_restart + } +} + +/// Time representations that can be used as expiration values in [`HrTimer`]. +pub trait HrTimerExpires { + /// Converts the expiration time into a nanosecond representation. + /// + /// This value corresponds to a raw ktime_t value, suitable for passing to kernel + /// timer functions. The interpretation (absolute vs relative) depends on the + /// associated [HrTimerMode] in use. + fn as_nanos(&self) -> i64; +} + +impl<C: ClockSource> HrTimerExpires for Instant<C> { + #[inline] + fn as_nanos(&self) -> i64 { + Instant::<C>::as_nanos(self) + } +} + +impl HrTimerExpires for Delta { + #[inline] + fn as_nanos(&self) -> i64 { + Delta::as_nanos(*self) + } +} + +mod private { + use crate::time::ClockSource; + + pub trait Sealed {} + + impl<C: ClockSource> Sealed for super::AbsoluteMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedMode<C> {} + impl<C: ClockSource> Sealed for super::AbsoluteSoftMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeSoftMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedSoftMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedSoftMode<C> {} + impl<C: ClockSource> Sealed for super::AbsoluteHardMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeHardMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedHardMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedHardMode<C> {} +} + +/// Operational mode of [`HrTimer`]. +pub trait HrTimerMode: private::Sealed { + /// The C representation of hrtimer mode. + const C_MODE: bindings::hrtimer_mode; + + /// Type representing the clock source. + type Clock: ClockSource; + + /// Type representing the expiration specification (absolute or relative time). + type Expires: HrTimerExpires; +} + +/// Timer that expires at a fixed point in time. 
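+///
+/// # Examples
+///
+/// With this mode, the expiry passed to `start` is an [`Instant`]. A brief
+/// sketch added in this write-up (`timer_pointer` is a hypothetical value whose
+/// timer uses `AbsoluteMode<Monotonic>`):
+///
+/// ```ignore
+/// // Fire 100 ms from now, expressed as an absolute point in time.
+/// let expiry = Instant::<Monotonic>::now() + Delta::from_millis(100);
+/// let handle = timer_pointer.start(expiry);
+/// ```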
+pub struct AbsoluteMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for AbsoluteMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS;
+
+    type Clock = C;
+    type Expires = Instant<C>;
+}
+
+/// Timer that expires after a delay from now.
+pub struct RelativeMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for RelativeMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL;
+
+    type Clock = C;
+    type Expires = Delta;
+}
+
+/// Timer with absolute expiration time, pinned to its current CPU.
+pub struct AbsolutePinnedMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for AbsolutePinnedMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED;
+
+    type Clock = C;
+    type Expires = Instant<C>;
+}
+
+/// Timer with relative expiration time, pinned to its current CPU.
+pub struct RelativePinnedMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for RelativePinnedMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED;
+
+    type Clock = C;
+    type Expires = Delta;
+}
+
+/// Timer with absolute expiration, handled in soft irq context.
+pub struct AbsoluteSoftMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for AbsoluteSoftMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_SOFT;
+
+    type Clock = C;
+    type Expires = Instant<C>;
+}
+
+/// Timer with relative expiration, handled in soft irq context.
+pub struct RelativeSoftMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for RelativeSoftMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_SOFT;
+
+    type Clock = C;
+    type Expires = Delta;
+}
+
+/// Timer with absolute expiration, pinned to CPU and handled in soft irq context.
+pub struct AbsolutePinnedSoftMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for AbsolutePinnedSoftMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED_SOFT;
+
+    type Clock = C;
+    type Expires = Instant<C>;
+}
+
+/// Timer with relative expiration, pinned to CPU and handled in soft irq context.
+pub struct RelativePinnedSoftMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for RelativePinnedSoftMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED_SOFT;
+
+    type Clock = C;
+    type Expires = Delta;
+}
+
+/// Timer with absolute expiration, handled in hard irq context.
+pub struct AbsoluteHardMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for AbsoluteHardMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_HARD;
+
+    type Clock = C;
+    type Expires = Instant<C>;
+}
+
+/// Timer with relative expiration, handled in hard irq context.
+pub struct RelativeHardMode<C: ClockSource>(PhantomData<C>);
+
+impl<C: ClockSource> HrTimerMode for RelativeHardMode<C> {
+    const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_HARD;
+
+    type Clock = C;
+    type Expires = Delta;
+}
+
+/// Timer with absolute expiration, pinned to CPU and handled in hard irq context.
+pub struct AbsolutePinnedHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsolutePinnedHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED_HARD; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with relative expiration, pinned to CPU and handled in hard irq context. +pub struct RelativePinnedHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativePinnedHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED_HARD; + + type Clock = C; + type Expires = Delta; +} + +/// Privileged smart-pointer for a [`HrTimer`] callback context. +/// +/// Many [`HrTimer`] methods can only be called in two situations: +/// +/// * When the caller has exclusive access to the `HrTimer` and the `HrTimer` is guaranteed not to +/// be running. +/// * From within the context of an `HrTimer`'s callback method. +/// +/// This type provides access to said methods from within a timer callback context. +/// +/// # Invariants +/// +/// * The existence of this type means the caller is currently within the callback for an +/// [`HrTimer`]. +/// * `self.0` always points to a live instance of [`HrTimer<T>`]. +pub struct HrTimerCallbackContext<'a, T: HasHrTimer<T>>(NonNull<HrTimer<T>>, PhantomData<&'a ()>); + +impl<'a, T: HasHrTimer<T>> HrTimerCallbackContext<'a, T> { + /// Create a new [`HrTimerCallbackContext`]. + /// + /// # Safety + /// + /// This function relies on the caller being within the context of a timer callback, so it must + /// not be used anywhere except for within implementations of [`RawHrTimerCallback::run`]. The + /// caller promises that `timer` points to a valid initialized instance of + /// [`bindings::hrtimer`]. + /// + /// The returned `Self` must not outlive the function context of [`RawHrTimerCallback::run`] + /// where this function is called. + pub(crate) unsafe fn from_raw(timer: *mut HrTimer<T>) -> Self { + // SAFETY: The caller guarantees `timer` is a valid pointer to an initialized + // `bindings::hrtimer` + // INVARIANT: Our safety contract ensures that we're within the context of a timer callback + // and that `timer` points to a live instance of `HrTimer<T>`. + Self(unsafe { NonNull::new_unchecked(timer) }, PhantomData) + } + + /// Conditionally forward the timer. + /// + /// This function is identical to [`HrTimer::forward()`] except that it may only be used from + /// within the context of a [`HrTimer`] callback. + pub fn forward(&mut self, now: HrTimerInstant<T>, interval: Delta) -> u64 { + // SAFETY: + // - We are guaranteed to be within the context of a timer callback by our type invariants + // - By our type invariants, `self.0` always points to a valid `HrTimer<T>` + unsafe { HrTimer::<T>::raw_forward(self.0.as_ptr(), now, interval) } + } + + /// Conditionally forward the timer. + /// + /// This is a variant of [`HrTimerCallbackContext::forward()`] that uses an interval after the + /// current time of the base clock for the [`HrTimer`]. + pub fn forward_now(&mut self, duration: Delta) -> u64 { + self.forward(HrTimerInstant::<T>::now(), duration) + } +} + +/// Use to implement the [`HasHrTimer<T>`] trait. +/// +/// See [`module`] documentation for an example. +/// +/// [`module`]: crate::time::hrtimer +#[macro_export] +macro_rules! impl_has_hr_timer { + ( + impl$({$($generics:tt)*})? + HasHrTimer<$timer_type:ty> + for $self:ty + { + mode : $mode:ty, + field : self.$field:ident $(,)? 
+ } + $($rest:tt)* + ) => { + // SAFETY: This implementation of `raw_get_timer` only compiles if the + // field has the right type. + unsafe impl$(<$($generics)*>)? $crate::time::hrtimer::HasHrTimer<$timer_type> for $self { + type TimerMode = $mode; + + #[inline] + unsafe fn raw_get_timer( + this: *const Self, + ) -> *const $crate::time::hrtimer::HrTimer<$timer_type> { + // SAFETY: The caller promises that the pointer is not dangling. + unsafe { ::core::ptr::addr_of!((*this).$field) } + } + + #[inline] + unsafe fn timer_container_of( + ptr: *mut $crate::time::hrtimer::HrTimer<$timer_type>, + ) -> *mut Self { + // SAFETY: As per the safety requirement of this function, `ptr` + // is pointing inside a `$timer_type`. + unsafe { ::kernel::container_of!(ptr, $timer_type, $field) } + } + } + } +} + +mod arc; +pub use arc::ArcHrTimerHandle; +mod pin; +pub use pin::PinHrTimerHandle; +mod pin_mut; +pub use pin_mut::PinMutHrTimerHandle; +// `box` is a reserved keyword, so prefix with `t` for timer +mod tbox; +pub use tbox::BoxHrTimerHandle; diff --git a/rust/kernel/time/hrtimer/arc.rs b/rust/kernel/time/hrtimer/arc.rs new file mode 100644 index 000000000000..7be82bcb352a --- /dev/null +++ b/rust/kernel/time/hrtimer/arc.rs @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::HasHrTimer; +use super::HrTimer; +use super::HrTimerCallback; +use super::HrTimerCallbackContext; +use super::HrTimerHandle; +use super::HrTimerMode; +use super::HrTimerPointer; +use super::RawHrTimerCallback; +use crate::sync::Arc; +use crate::sync::ArcBorrow; + +/// A handle for an `Arc<HasHrTimer<T>>` returned by a call to +/// [`HrTimerPointer::start`]. +pub struct ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + pub(crate) inner: Arc<T>, +} + +// SAFETY: We implement drop below, and we cancel the timer in the drop +// implementation. +unsafe impl<T> HrTimerHandle for ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + fn cancel(&mut self) -> bool { + let self_ptr = Arc::as_ptr(&self.inner); + + // SAFETY: As we obtained `self_ptr` from a valid reference above, it + // must point to a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) }; + + // SAFETY: As `timer_ptr` points into `T` and `T` is valid, `timer_ptr` + // must point to a valid `HrTimer` instance. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<T> Drop for ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + fn drop(&mut self) { + self.cancel(); + } +} + +impl<T> HrTimerPointer for Arc<T> +where + T: 'static, + T: Send + Sync, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = ArcHrTimerHandle<T>; + + fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> ArcHrTimerHandle<T> { + // SAFETY: + // - We keep `self` alive by wrapping it in a handle below. + // - Since we generate the pointer passed to `start` from a valid + // reference, it is a valid pointer. 
+        unsafe { T::start(Arc::as_ptr(&self), expires) };
+
+        ArcHrTimerHandle { inner: self }
+    }
+}
+
+impl<T> RawHrTimerCallback for Arc<T>
+where
+    T: 'static,
+    T: HasHrTimer<T>,
+    T: for<'a> HrTimerCallback<Pointer<'a> = Self>,
+{
+    type CallbackTarget<'a> = ArcBorrow<'a, T>;
+
+    unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart {
+        // `HrTimer` is `repr(C)`
+        let timer_ptr = ptr.cast::<super::HrTimer<T>>();
+
+        // SAFETY: By C API contract `ptr` is the pointer we passed when
+        // queuing the timer, so it is a `HrTimer<T>` embedded in a `T`.
+        let data_ptr = unsafe { T::timer_container_of(timer_ptr) };
+
+        // SAFETY:
+        // - `data_ptr` is derived from the pointer to the `T` that was used to
+        //   queue the timer.
+        // - As per the safety requirements of the trait `HrTimerHandle`, the
+        //   `ArcHrTimerHandle` associated with this timer is guaranteed to
+        //   be alive until this method returns. That handle borrows the `T`
+        //   behind `data_ptr` thus guaranteeing the validity of
+        //   the `ArcBorrow` created below.
+        // - We own one refcount in the `ArcHrTimerHandle` associated with this
+        //   timer, so it is not possible to get a `UniqueArc` to this
+        //   allocation from other `Arc` clones.
+        let receiver = unsafe { ArcBorrow::from_raw(data_ptr) };
+
+        // SAFETY:
+        // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so
+        //   it is a valid pointer to a `HrTimer<T>` embedded in a `T`.
+        // - We are within `RawHrTimerCallback::run`.
+        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };
+
+        T::run(receiver, context).into_c()
+    }
+}
diff --git a/rust/kernel/time/hrtimer/pin.rs b/rust/kernel/time/hrtimer/pin.rs
new file mode 100644
index 000000000000..4d39ef781697
--- /dev/null
+++ b/rust/kernel/time/hrtimer/pin.rs
@@ -0,0 +1,115 @@
+// SPDX-License-Identifier: GPL-2.0
+
+use super::HasHrTimer;
+use super::HrTimer;
+use super::HrTimerCallback;
+use super::HrTimerCallbackContext;
+use super::HrTimerHandle;
+use super::HrTimerMode;
+use super::RawHrTimerCallback;
+use super::UnsafeHrTimerPointer;
+use core::pin::Pin;
+
+/// A handle for a `Pin<&HasHrTimer>`. When the handle exists, the timer might be
+/// running.
+pub struct PinHrTimerHandle<'a, T>
+where
+    T: HasHrTimer<T>,
+{
+    pub(crate) inner: Pin<&'a T>,
+}
+
+// SAFETY: We cancel the timer when the handle is dropped. The implementation of
+// the `cancel` method will block if the timer handler is running.
+unsafe impl<'a, T> HrTimerHandle for PinHrTimerHandle<'a, T>
+where
+    T: HasHrTimer<T>,
+{
+    fn cancel(&mut self) -> bool {
+        let self_ptr: *const T = self.inner.get_ref();
+
+        // SAFETY: As we got `self_ptr` from a reference above, it must point to
+        // a valid `T`.
+        let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) };
+
+        // SAFETY: As `timer_ptr` is derived from a reference, it must point to
+        // a valid and initialized `HrTimer`.
+        unsafe { HrTimer::<T>::raw_cancel(timer_ptr) }
+    }
+}
+
+impl<'a, T> Drop for PinHrTimerHandle<'a, T>
+where
+    T: HasHrTimer<T>,
+{
+    fn drop(&mut self) {
+        self.cancel();
+    }
+}
+
+// SAFETY: We capture the lifetime of `Self` when we create a `PinHrTimerHandle`,
+// so `Self` will outlive the handle.
+unsafe impl<'a, T> UnsafeHrTimerPointer for Pin<&'a T> +where + T: Send + Sync, + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = PinHrTimerHandle<'a, T>; + + unsafe fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // Cast to pointer + let self_ptr: *const T = self.get_ref(); + + // SAFETY: + // - As we derive `self_ptr` from a reference above, it must point to a + // valid `T`. + // - We keep `self` alive by wrapping it in a handle below. + unsafe { T::start(self_ptr, expires) }; + + PinHrTimerHandle { inner: self } + } +} + +impl<'a, T> RawHrTimerCallback for Pin<&'a T> +where + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type CallbackTarget<'b> = Self; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<HrTimer<T>>(); + + // SAFETY: By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + let receiver_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `PinHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle borrows the `T` + // behind `receiver_ptr`, thus guaranteeing the validity of + // the reference created below. + let receiver_ref = unsafe { &*receiver_ptr }; + + // SAFETY: `receiver_ref` only exists as pinned, so it is safe to pin it + // here. + let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; + + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer<T>` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(receiver_pin, context).into_c() + } +} diff --git a/rust/kernel/time/hrtimer/pin_mut.rs b/rust/kernel/time/hrtimer/pin_mut.rs new file mode 100644 index 000000000000..9d9447d4d57e --- /dev/null +++ b/rust/kernel/time/hrtimer/pin_mut.rs @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::{ + HasHrTimer, HrTimer, HrTimerCallback, HrTimerCallbackContext, HrTimerHandle, HrTimerMode, + RawHrTimerCallback, UnsafeHrTimerPointer, +}; +use core::{marker::PhantomData, pin::Pin, ptr::NonNull}; + +/// A handle for a `Pin<&mut HasHrTimer>`. When the handle exists, the timer might +/// be running. +pub struct PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + pub(crate) inner: NonNull<T>, + _p: PhantomData<&'a mut T>, +} + +// SAFETY: We cancel the timer when the handle is dropped. The implementation of +// the `cancel` method will block if the timer handler is running. +unsafe impl<'a, T> HrTimerHandle for PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn cancel(&mut self) -> bool { + let self_ptr = self.inner.as_ptr(); + + // SAFETY: As we got `self_ptr` from a reference above, it must point to + // a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) }; + + // SAFETY: As `timer_ptr` is derived from a reference, it must point to + // a valid and initialized `HrTimer`. 
+ unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<'a, T> Drop for PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn drop(&mut self) { + self.cancel(); + } +} + +// SAFETY: We capture the lifetime of `Self` when we create a +// `PinMutHrTimerHandle`, so `Self` will outlive the handle. +unsafe impl<'a, T> UnsafeHrTimerPointer for Pin<&'a mut T> +where + T: Send + Sync, + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = PinMutHrTimerHandle<'a, T>; + + unsafe fn start( + mut self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // SAFETY: + // - We promise not to move out of `self`. We only pass `self` + // back to the caller as a `Pin<&mut self>`. + // - The return value of `get_unchecked_mut` is guaranteed not to be null. + let self_ptr = unsafe { NonNull::new_unchecked(self.as_mut().get_unchecked_mut()) }; + + // SAFETY: + // - As we derive `self_ptr` from a reference above, it must point to a + // valid `T`. + // - We keep `self` alive by wrapping it in a handle below. + unsafe { T::start(self_ptr.as_ptr(), expires) }; + + PinMutHrTimerHandle { + inner: self_ptr, + _p: PhantomData, + } + } +} + +impl<'a, T> RawHrTimerCallback for Pin<&'a mut T> +where + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type CallbackTarget<'b> = Self; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<HrTimer<T>>(); + + // SAFETY: By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + let receiver_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `PinMutHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle borrows the `T` + // behind `receiver_ptr` mutably thus guaranteeing the validity of + // the reference created below. + let receiver_ref = unsafe { &mut *receiver_ptr }; + + // SAFETY: `receiver_ref` only exists as pinned, so it is safe to pin it + // here. + let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; + + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer<T>` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(receiver_pin, context).into_c() + } +} diff --git a/rust/kernel/time/hrtimer/tbox.rs b/rust/kernel/time/hrtimer/tbox.rs new file mode 100644 index 000000000000..aa1ee31a7195 --- /dev/null +++ b/rust/kernel/time/hrtimer/tbox.rs @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::HasHrTimer; +use super::HrTimer; +use super::HrTimerCallback; +use super::HrTimerCallbackContext; +use super::HrTimerHandle; +use super::HrTimerMode; +use super::HrTimerPointer; +use super::RawHrTimerCallback; +use crate::prelude::*; +use core::ptr::NonNull; + +/// A handle for a [`Box<HasHrTimer<T>>`] returned by a call to +/// [`HrTimerPointer::start`]. +/// +/// # Invariants +/// +/// - `self.inner` comes from a `Box::into_raw` call. 
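+///
+/// # Examples
+///
+/// An illustrative sketch added in this write-up; `MyTimer` is a hypothetical
+/// type implementing [`HasHrTimer`] and [`HrTimerCallback`] with a relative
+/// timer mode:
+///
+/// ```ignore
+/// // Starting a `Pin<KBox<MyTimer>>` yields a `BoxHrTimerHandle`.
+/// let timer = KBox::pin_init(MyTimer::new(), GFP_KERNEL)?;
+/// let handle = timer.start(Delta::from_millis(100));
+/// // Dropping the handle cancels the timer and frees the allocation.
+/// drop(handle);
+/// ```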
+pub struct BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + pub(crate) inner: NonNull<T>, + _p: core::marker::PhantomData<A>, +} + +// SAFETY: We implement drop below, and we cancel the timer in the drop +// implementation. +unsafe impl<T, A> HrTimerHandle for BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + fn cancel(&mut self) -> bool { + // SAFETY: As we obtained `self.inner` from a valid reference when we + // created `self`, it must point to a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self.inner.as_ptr()) }; + + // SAFETY: As `timer_ptr` points into `T` and `T` is valid, `timer_ptr` + // must point to a valid `HrTimer` instance. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<T, A> Drop for BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + fn drop(&mut self) { + self.cancel(); + // SAFETY: By type invariant, `self.inner` came from a `Box::into_raw` + // call. + drop(unsafe { Box::<T, A>::from_raw(self.inner.as_ptr()) }) + } +} + +impl<T, A> HrTimerPointer for Pin<Box<T, A>> +where + T: 'static, + T: Send + Sync, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Pin<Box<T, A>>>, + A: crate::alloc::Allocator, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = BoxHrTimerHandle<T, A>; + + fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // SAFETY: + // - We will not move out of this box during timer callback (we pass an + // immutable reference to the callback). + // - `Box::into_raw` is guaranteed to return a valid pointer. + let inner = + unsafe { NonNull::new_unchecked(Box::into_raw(Pin::into_inner_unchecked(self))) }; + + // SAFETY: + // - We keep `self` alive by wrapping it in a handle below. + // - Since we generate the pointer passed to `start` from a valid + // reference, it is a valid pointer. + unsafe { T::start(inner.as_ptr(), expires) }; + + // INVARIANT: `inner` came from `Box::into_raw` above. + BoxHrTimerHandle { + inner, + _p: core::marker::PhantomData, + } + } +} + +impl<T, A> RawHrTimerCallback for Pin<Box<T, A>> +where + T: 'static, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Pin<Box<T, A>>>, + A: crate::alloc::Allocator, +{ + type CallbackTarget<'a> = Pin<&'a mut T>; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<super::HrTimer<T>>(); + + // SAFETY: By C API contract `ptr` is the pointer we passed when + // queuing the timer, so it is a `HrTimer<T>` embedded in a `T`. + let data_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `BoxHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle owns the `T` + // behind `data_ptr` thus guaranteeing the validity of + // the reference created below. + // - As `data_ptr` comes from a `Pin<Box<T>>`, only pinned references to + // `data_ptr` exist. + let data_mut_ref = unsafe { Pin::new_unchecked(&mut *data_ptr) }; + + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer<T>` embedded in a `T`. 
+ // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(data_mut_ref, context).into_c() + } +} diff --git a/rust/kernel/tracepoint.rs b/rust/kernel/tracepoint.rs new file mode 100644 index 000000000000..c6e80aa99e8e --- /dev/null +++ b/rust/kernel/tracepoint.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Logic for tracepoints. + +/// Declare the Rust entry point for a tracepoint. +/// +/// This macro generates an unsafe function that calls into C, and its safety requirements will be +/// whatever the relevant C code requires. To document these safety requirements, you may add +/// doc-comments when invoking the macro. +#[macro_export] +macro_rules! declare_trace { + ($($(#[$attr:meta])* $pub:vis unsafe fn $name:ident($($argname:ident : $argtyp:ty),* $(,)?);)*) => {$( + $( #[$attr] )* + #[inline(always)] + $pub unsafe fn $name($($argname : $argtyp),*) { + #[cfg(CONFIG_TRACEPOINTS)] + { + // SAFETY: It's always okay to query the static key for a tracepoint. + let should_trace = unsafe { + $crate::macros::paste! { + $crate::jump_label::static_branch_unlikely!( + $crate::bindings::[< __tracepoint_ $name >], + $crate::bindings::tracepoint, + key + ) + } + }; + + if should_trace { + $crate::macros::paste! { + // SAFETY: The caller guarantees that it is okay to call this tracepoint. + unsafe { $crate::bindings::[< rust_do_trace_ $name >]($($argname),*) }; + } + } + } + + #[cfg(not(CONFIG_TRACEPOINTS))] + { + // If tracepoints are disabled, insert a trivial use of each argument + // to avoid unused argument warnings. + $( let _unused = $argname; )* + } + } + )*} +} + +pub use declare_trace; diff --git a/rust/kernel/transmute.rs b/rust/kernel/transmute.rs new file mode 100644 index 000000000000..be5dbf3829e2 --- /dev/null +++ b/rust/kernel/transmute.rs @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Traits for transmuting types. + +use core::mem::size_of; + +/// Types for which any bit pattern is valid. +/// +/// Not all types are valid for all values. For example, a `bool` must be either zero or one, so +/// reading arbitrary bytes into something that contains a `bool` is not okay. +/// +/// It's okay for the type to have padding, as initializing those bytes has no effect. +/// +/// # Examples +/// +/// ``` +/// use kernel::transmute::FromBytes; +/// +/// # fn test() -> Option<()> { +/// let raw = [1, 2, 3, 4]; +/// +/// let result = u32::from_bytes(&raw)?; +/// +/// #[cfg(target_endian = "little")] +/// assert_eq!(*result, 0x4030201); +/// +/// #[cfg(target_endian = "big")] +/// assert_eq!(*result, 0x1020304); +/// +/// # Some(()) } +/// # test().ok_or(EINVAL)?; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Safety +/// +/// All bit-patterns must be valid for this type. This type must not have interior mutability. +pub unsafe trait FromBytes { + /// Converts a slice of bytes to a reference to `Self`. + /// + /// Succeeds if the reference is properly aligned, and the size of `bytes` is equal to that of + /// `T` and different from zero. + /// + /// Otherwise, returns [`None`]. + fn from_bytes(bytes: &[u8]) -> Option<&Self> + where + Self: Sized, + { + let slice_ptr = bytes.as_ptr().cast::<Self>(); + let size = size_of::<Self>(); + + #[allow(clippy::incompatible_msrv)] + if bytes.len() == size && slice_ptr.is_aligned() { + // SAFETY: Size and alignment were just checked. 
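+            // Since `Self: FromBytes`, any initialized byte sequence of the correct size is a
+            // valid `Self`, and the shared borrow of `bytes` keeps the memory alive and
+            // initialized for the lifetime of the returned reference.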
+            unsafe { Some(&*slice_ptr) }
+        } else {
+            None
+        }
+    }
+
+    /// Converts the beginning of `bytes` to a reference to `Self`.
+    ///
+    /// This method is similar to [`Self::from_bytes`], with the difference that `bytes` does not
+    /// need to be the same size as `Self` - the appropriate portion is cut from the beginning of
+    /// `bytes`, and the remainder returned alongside `Self`.
+    fn from_bytes_prefix(bytes: &[u8]) -> Option<(&Self, &[u8])>
+    where
+        Self: Sized,
+    {
+        if bytes.len() < size_of::<Self>() {
+            None
+        } else {
+            // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at` cannot
+            // panic.
+            // TODO: replace with `split_at_checked` once the MSRV is >= 1.80.
+            let (prefix, remainder) = bytes.split_at(size_of::<Self>());
+
+            Self::from_bytes(prefix).map(|s| (s, remainder))
+        }
+    }
+
+    /// Converts a mutable slice of bytes to a reference to `Self`.
+    ///
+    /// Succeeds if the reference is properly aligned, and the size of `bytes` is equal to that of
+    /// `T` and different from zero.
+    ///
+    /// Otherwise, returns [`None`].
+    fn from_bytes_mut(bytes: &mut [u8]) -> Option<&mut Self>
+    where
+        Self: AsBytes + Sized,
+    {
+        let slice_ptr = bytes.as_mut_ptr().cast::<Self>();
+        let size = size_of::<Self>();
+
+        #[allow(clippy::incompatible_msrv)]
+        if bytes.len() == size && slice_ptr.is_aligned() {
+            // SAFETY: Size and alignment were just checked.
+            unsafe { Some(&mut *slice_ptr) }
+        } else {
+            None
+        }
+    }
+
+    /// Converts the beginning of `bytes` to a mutable reference to `Self`.
+    ///
+    /// This method is similar to [`Self::from_bytes_mut`], with the difference that `bytes` does
+    /// not need to be the same size as `Self` - the appropriate portion is cut from the beginning
+    /// of `bytes`, and the remainder returned alongside `Self`.
+    fn from_bytes_mut_prefix(bytes: &mut [u8]) -> Option<(&mut Self, &mut [u8])>
+    where
+        Self: AsBytes + Sized,
+    {
+        if bytes.len() < size_of::<Self>() {
+            None
+        } else {
+            // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at_mut` cannot
+            // panic.
+            // TODO: replace with `split_at_mut_checked` once the MSRV is >= 1.80.
+            let (prefix, remainder) = bytes.split_at_mut(size_of::<Self>());
+
+            Self::from_bytes_mut(prefix).map(|s| (s, remainder))
+        }
+    }
+
+    /// Creates an owned instance of `Self` by copying `bytes`.
+    ///
+    /// Unlike [`FromBytes::from_bytes`], which requires aligned input, this method can be used on
+    /// non-aligned data at the cost of a copy.
+    fn from_bytes_copy(bytes: &[u8]) -> Option<Self>
+    where
+        Self: Sized,
+    {
+        if bytes.len() == size_of::<Self>() {
+            // SAFETY: we just verified that `bytes` has the same size as `Self`, and per the
+            // invariants of `FromBytes`, any byte sequence of the correct length is a valid value
+            // for `Self`.
+            Some(unsafe { core::ptr::read_unaligned(bytes.as_ptr().cast::<Self>()) })
+        } else {
+            None
+        }
+    }
+
+    /// Creates an owned instance of `Self` from the beginning of `bytes`.
+    ///
+    /// This method is similar to [`Self::from_bytes_copy`], with the difference that `bytes` does
+    /// not need to be the same size as `Self` - the appropriate portion is cut from the beginning
+    /// of `bytes`, and the remainder returned alongside `Self`.
+    fn from_bytes_copy_prefix(bytes: &[u8]) -> Option<(Self, &[u8])>
+    where
+        Self: Sized,
+    {
+        if bytes.len() < size_of::<Self>() {
+            None
+        } else {
+            // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at` cannot
+            // panic.
+            // TODO: replace with `split_at_checked` once the MSRV is >= 1.80.
+            let (prefix, remainder) = bytes.split_at(size_of::<Self>());
+
+            Self::from_bytes_copy(prefix).map(|s| (s, remainder))
+        }
+    }
+}
+
+macro_rules! impl_frombytes {
+    ($($({$($generics:tt)*})? $t:ty, )*) => {
+        // SAFETY: Safety comments written in the macro invocation.
+        $(unsafe impl$($($generics)*)? FromBytes for $t {})*
+    };
+}
+
+impl_frombytes! {
+    // SAFETY: All bit patterns are acceptable values of the types below.
+    u8, u16, u32, u64, usize,
+    i8, i16, i32, i64, isize,
+
+    // SAFETY: If all bit patterns are acceptable for individual values in an array, then all bit
+    // patterns are also acceptable for arrays of that type.
+    {<T: FromBytes>} [T],
+    {<T: FromBytes, const N: usize>} [T; N],
+}
+
+/// Types that can be viewed as an immutable slice of initialized bytes.
+///
+/// If a struct implements this trait, then it is okay to copy it byte-for-byte to userspace. This
+/// means that it should not have any padding, as padding bytes are uninitialized. Reading
+/// uninitialized memory is not just undefined behavior, it may even lead to leaking sensitive
+/// information on the stack to userspace.
+///
+/// The struct should also not hold kernel pointers, as kernel pointer addresses are also considered
+/// sensitive. However, leaking kernel pointers is not considered undefined behavior by Rust, so
+/// this is a correctness requirement, but not a safety requirement.
+///
+/// # Safety
+///
+/// Values of this type may not contain any uninitialized bytes. This type must not have interior
+/// mutability.
+pub unsafe trait AsBytes {
+    /// Returns `self` as a slice of bytes.
+    fn as_bytes(&self) -> &[u8] {
+        // CAST: `Self` implements `AsBytes`, thus all bytes of `self` are initialized.
+        let data = core::ptr::from_ref(self).cast::<u8>();
+        let len = core::mem::size_of_val(self);
+
+        // SAFETY: `data` is non-null and valid for reads of `len * size_of::<u8>()` bytes.
+        unsafe { core::slice::from_raw_parts(data, len) }
+    }
+
+    /// Returns `self` as a mutable slice of bytes.
+    fn as_bytes_mut(&mut self) -> &mut [u8]
+    where
+        Self: FromBytes,
+    {
+        // CAST: `Self` implements both `AsBytes` and `FromBytes`, thus making `Self`
+        // bi-directionally transmutable to `[u8; size_of_val(self)]`.
+        let data = core::ptr::from_mut(self).cast::<u8>();
+        let len = core::mem::size_of_val(self);
+
+        // SAFETY: `data` is non-null and valid for reads and writes of `len * size_of::<u8>()`
+        // bytes.
+        unsafe { core::slice::from_raw_parts_mut(data, len) }
+    }
+}
+
+macro_rules! impl_asbytes {
+    ($($({$($generics:tt)*})? $t:ty, )*) => {
+        // SAFETY: Safety comments written in the macro invocation.
+        $(unsafe impl$($($generics)*)? AsBytes for $t {})*
+    };
+}
+
+impl_asbytes! {
+    // SAFETY: Instances of the following types have no uninitialized portions.
+    u8, u16, u32, u64, usize,
+    i8, i16, i32, i64, isize,
+    bool,
+    char,
+    str,
+
+    // SAFETY: If individual values in an array have no uninitialized portions, then the array
+    // itself does not have any uninitialized portions either.
+    {<T: AsBytes>} [T],
+    {<T: AsBytes, const N: usize>} [T; N],
+}
diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index e84e51ec9716..9c5e7dbf1632 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -2,36 +2,443 @@
 
 //! Kernel types.
-use core::{cell::UnsafeCell, mem::MaybeUninit}; +use crate::ffi::c_void; +use core::{ + cell::UnsafeCell, + marker::{PhantomData, PhantomPinned}, + mem::MaybeUninit, + ops::{Deref, DerefMut}, +}; +use pin_init::{PinInit, Wrapper, Zeroable}; + +pub use crate::sync::aref::{ARef, AlwaysRefCounted}; + +/// Used to transfer ownership to and from foreign (non-Rust) languages. +/// +/// Ownership is transferred from Rust to a foreign language by calling [`Self::into_foreign`] and +/// later may be transferred back to Rust by calling [`Self::from_foreign`]. +/// +/// This trait is meant to be used in cases when Rust objects are stored in C objects and +/// eventually "freed" back to Rust. +/// +/// # Safety +/// +/// - Implementations must satisfy the guarantees of [`Self::into_foreign`]. +pub unsafe trait ForeignOwnable: Sized { + /// The alignment of pointers returned by `into_foreign`. + const FOREIGN_ALIGN: usize; + + /// Type used to immutably borrow a value that is currently foreign-owned. + type Borrowed<'a>; + + /// Type used to mutably borrow a value that is currently foreign-owned. + type BorrowedMut<'a>; + + /// Converts a Rust-owned object to a foreign-owned one. + /// + /// The foreign representation is a pointer to void. Aside from the guarantees listed below, + /// there are no other guarantees for this pointer. For example, it might be invalid, dangling + /// or pointing to uninitialized memory. Using it in any way except for [`from_foreign`], + /// [`try_from_foreign`], [`borrow`], or [`borrow_mut`] can result in undefined behavior. + /// + /// # Guarantees + /// + /// - Minimum alignment of returned pointer is [`Self::FOREIGN_ALIGN`]. + /// - The returned pointer is not null. + /// + /// [`from_foreign`]: Self::from_foreign + /// [`try_from_foreign`]: Self::try_from_foreign + /// [`borrow`]: Self::borrow + /// [`borrow_mut`]: Self::borrow_mut + fn into_foreign(self) -> *mut c_void; + + /// Converts a foreign-owned object back to a Rust-owned one. + /// + /// # Safety + /// + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and it + /// must not be passed to `from_foreign` more than once. + /// + /// [`into_foreign`]: Self::into_foreign + unsafe fn from_foreign(ptr: *mut c_void) -> Self; + + /// Tries to convert a foreign-owned object back to a Rust-owned one. + /// + /// A convenience wrapper over [`ForeignOwnable::from_foreign`] that returns [`None`] if `ptr` + /// is null. + /// + /// # Safety + /// + /// `ptr` must either be null or satisfy the safety requirements for [`from_foreign`]. + /// + /// [`from_foreign`]: Self::from_foreign + unsafe fn try_from_foreign(ptr: *mut c_void) -> Option<Self> { + if ptr.is_null() { + None + } else { + // SAFETY: Since `ptr` is not null here, then `ptr` satisfies the safety requirements + // of `from_foreign` given the safety requirements of this function. + unsafe { Some(Self::from_foreign(ptr)) } + } + } + + /// Borrows a foreign-owned object immutably. + /// + /// This method provides a way to access a foreign-owned value from Rust immutably. It provides + /// you with exactly the same abilities as an `&Self` when the value is Rust-owned. + /// + /// # Safety + /// + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and if + /// the pointer is ever passed to [`from_foreign`], then that call must happen after the end of + /// the lifetime `'a`. 
+ /// + /// [`into_foreign`]: Self::into_foreign + /// [`from_foreign`]: Self::from_foreign + unsafe fn borrow<'a>(ptr: *mut c_void) -> Self::Borrowed<'a>; + + /// Borrows a foreign-owned object mutably. + /// + /// This method provides a way to access a foreign-owned value from Rust mutably. It provides + /// you with exactly the same abilities as an `&mut Self` when the value is Rust-owned, except + /// that the address of the object must not be changed. + /// + /// Note that for types like [`Arc`], an `&mut Arc<T>` only gives you immutable access to the + /// inner value, so this method also only provides immutable access in that case. + /// + /// In the case of `Box<T>`, this method gives you the ability to modify the inner `T`, but it + /// does not let you change the box itself. That is, you cannot change which allocation the box + /// points at. + /// + /// # Safety + /// + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and if + /// the pointer is ever passed to [`from_foreign`], then that call must happen after the end of + /// the lifetime `'a`. + /// + /// The lifetime `'a` must not overlap with the lifetime of any other call to [`borrow`] or + /// `borrow_mut` on the same object. + /// + /// [`into_foreign`]: Self::into_foreign + /// [`from_foreign`]: Self::from_foreign + /// [`borrow`]: Self::borrow + /// [`Arc`]: crate::sync::Arc + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> Self::BorrowedMut<'a>; +} + +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `()`. +unsafe impl ForeignOwnable for () { + const FOREIGN_ALIGN: usize = core::mem::align_of::<()>(); + type Borrowed<'a> = (); + type BorrowedMut<'a> = (); + + fn into_foreign(self) -> *mut c_void { + core::ptr::NonNull::dangling().as_ptr() + } + + unsafe fn from_foreign(_: *mut c_void) -> Self {} + + unsafe fn borrow<'a>(_: *mut c_void) -> Self::Borrowed<'a> {} + unsafe fn borrow_mut<'a>(_: *mut c_void) -> Self::BorrowedMut<'a> {} +} + +/// Runs a cleanup function/closure when dropped. +/// +/// The [`ScopeGuard::dismiss`] function prevents the cleanup function from running. +/// +/// # Examples +/// +/// In the example below, we have multiple exit paths and we want to log regardless of which one is +/// taken: +/// +/// ``` +/// # use kernel::types::ScopeGuard; +/// fn example1(arg: bool) { +/// let _log = ScopeGuard::new(|| pr_info!("example1 completed\n")); +/// +/// if arg { +/// return; +/// } +/// +/// pr_info!("Do something...\n"); +/// } +/// +/// # example1(false); +/// # example1(true); +/// ``` +/// +/// In the example below, we want to log the same message on all early exits but a different one on +/// the main exit path: +/// +/// ``` +/// # use kernel::types::ScopeGuard; +/// fn example2(arg: bool) { +/// let log = ScopeGuard::new(|| pr_info!("example2 returned early\n")); +/// +/// if arg { +/// return; +/// } +/// +/// // (Other early returns...) 
+///
+///     log.dismiss();
+///     pr_info!("example2 no early return\n");
+/// }
+///
+/// # example2(false);
+/// # example2(true);
+/// ```
+///
+/// In the example below, we need a mutable object (the vector) to be accessible within the log
+/// function, so we wrap it in the [`ScopeGuard`]:
+///
+/// ```
+/// # use kernel::types::ScopeGuard;
+/// fn example3(arg: bool) -> Result {
+///     let mut vec =
+///         ScopeGuard::new_with_data(KVec::new(), |v| pr_info!("vec had {} elements\n", v.len()));
+///
+///     vec.push(10u8, GFP_KERNEL)?;
+///     if arg {
+///         return Ok(());
+///     }
+///     vec.push(20u8, GFP_KERNEL)?;
+///     Ok(())
+/// }
+///
+/// # assert_eq!(example3(false), Ok(()));
+/// # assert_eq!(example3(true), Ok(()));
+/// ```
+///
+/// # Invariants
+///
+/// The value stored in the struct is nearly always `Some(_)`, except between
+/// [`ScopeGuard::dismiss`] and [`ScopeGuard::drop`]: in this case, it will be `None` as the value
+/// will have been returned to the caller. Since [`ScopeGuard::dismiss`] consumes the guard,
+/// callers won't be able to use it anymore.
+pub struct ScopeGuard<T, F: FnOnce(T)>(Option<(T, F)>);
+
+impl<T, F: FnOnce(T)> ScopeGuard<T, F> {
+    /// Creates a new guarded object wrapping the given data and with the given cleanup function.
+    pub fn new_with_data(data: T, cleanup_func: F) -> Self {
+        // INVARIANT: The struct is being initialised with `Some(_)`.
+        Self(Some((data, cleanup_func)))
+    }
+
+    /// Prevents the cleanup function from running and returns the guarded data.
+    pub fn dismiss(mut self) -> T {
+        // INVARIANT: This is the exception case in the invariant; it is not visible to callers
+        // because this function consumes `self`.
+        self.0.take().unwrap().0
+    }
+}
+
+impl ScopeGuard<(), fn(())> {
+    /// Creates a new guarded object with the given cleanup function.
+    pub fn new(cleanup: impl FnOnce()) -> ScopeGuard<(), impl FnOnce(())> {
+        ScopeGuard::new_with_data((), move |()| cleanup())
+    }
+}
+
+impl<T, F: FnOnce(T)> Deref for ScopeGuard<T, F> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        // The type invariants guarantee that `unwrap` will succeed.
+        &self.0.as_ref().unwrap().0
+    }
+}
+
+impl<T, F: FnOnce(T)> DerefMut for ScopeGuard<T, F> {
+    fn deref_mut(&mut self) -> &mut T {
+        // The type invariants guarantee that `unwrap` will succeed.
+        &mut self.0.as_mut().unwrap().0
+    }
+}
+
+impl<T, F: FnOnce(T)> Drop for ScopeGuard<T, F> {
+    fn drop(&mut self) {
+        // Run the cleanup function if one is still present.
+        if let Some((data, cleanup)) = self.0.take() {
+            cleanup(data)
+        }
+    }
+}
 
 /// Stores an opaque value.
 ///
-/// This is meant to be used with FFI objects that are never interpreted by Rust code.
+/// [`Opaque<T>`] is meant to be used with FFI objects that are never interpreted by Rust code.
+///
+/// It is used to wrap structs from the C side, for example `Opaque<bindings::mutex>`.
+/// It gets rid of all the usual assumptions that Rust has for a value:
+///
+/// * The value is allowed to be uninitialized (for example, it may contain invalid bit patterns:
+///   `3` for a [`bool`]).
+/// * The value is allowed to be mutated even when a `&Opaque<T>` exists on the Rust side.
+/// * No uniqueness for mutable references: it is fine to have multiple `&mut Opaque<T>` point to
+///   the same value.
+/// * The value is not allowed to be shared with other threads (i.e. it is `!Sync`).
+///
+/// This has to be used for all values that the C side has access to, because it can't be ensured
+/// that the C side is adhering to the usual constraints that Rust needs.
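+///
+/// Since [`Opaque<T>`] is `#[repr(transparent)]` over the inner value, a pointer to it can be
+/// converted to a pointer to the inner `T` and back; see [`Opaque::cast_into`] and
+/// [`Opaque::cast_from`].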
+///
+/// Using [`Opaque<T>`] makes it possible to continue using references on the Rust side even for
+/// values shared with C.
+///
+/// # Examples
+///
+/// ```
+/// use kernel::types::Opaque;
+/// # // Emulate a C struct binding which is from C, maybe uninitialized or not, only the C side
+/// # // knows.
+/// # mod bindings {
+/// #     pub struct Foo {
+/// #         pub val: u8,
+/// #     }
+/// # }
+///
+/// // `foo.val` is assumed to be handled on the C side, so we use `Opaque` to wrap it.
+/// pub struct Foo {
+///     foo: Opaque<bindings::Foo>,
+/// }
+///
+/// impl Foo {
+///     pub fn get_val(&self) -> u8 {
+///         let ptr = Opaque::get(&self.foo);
+///
+///         // SAFETY: `Self` is valid from C side.
+///         unsafe { (*ptr).val }
+///     }
+/// }
+///
+/// // Create an instance of `Foo` with the `Opaque` wrapper.
+/// let foo = Foo {
+///     foo: Opaque::new(bindings::Foo { val: 0xdb }),
+/// };
+///
+/// assert_eq!(foo.get_val(), 0xdb);
+/// ```
 #[repr(transparent)]
-pub struct Opaque<T>(MaybeUninit<UnsafeCell<T>>);
+pub struct Opaque<T> {
+    value: UnsafeCell<MaybeUninit<T>>,
+    _pin: PhantomPinned,
+}
+
+// SAFETY: `Opaque<T>` allows the inner value to be any bit pattern, including all zeros.
+unsafe impl<T> Zeroable for Opaque<T> {}
 
 impl<T> Opaque<T> {
     /// Creates a new opaque value.
     pub const fn new(value: T) -> Self {
-        Self(MaybeUninit::new(UnsafeCell::new(value)))
+        Self {
+            value: UnsafeCell::new(MaybeUninit::new(value)),
+            _pin: PhantomPinned,
+        }
     }
 
     /// Creates an uninitialised value.
     pub const fn uninit() -> Self {
-        Self(MaybeUninit::uninit())
+        Self {
+            value: UnsafeCell::new(MaybeUninit::uninit()),
+            _pin: PhantomPinned,
+        }
+    }
+
+    /// Creates a new zeroed opaque value.
+    pub const fn zeroed() -> Self {
+        Self {
+            value: UnsafeCell::new(MaybeUninit::zeroed()),
+            _pin: PhantomPinned,
+        }
+    }
+
+    /// Creates a pin-initializer from the given initializer closure.
+    ///
+    /// The returned initializer calls the given closure with the pointer to the inner `T` of this
+    /// `Opaque`. Since this memory is uninitialized, the closure is not allowed to read from it.
+    ///
+    /// This function is safe, because the `T` inside of an `Opaque` is allowed to be
+    /// uninitialized. Additionally, access to the inner `T` requires `unsafe`, so the caller needs
+    /// to verify at that point that the inner value is valid.
+    pub fn ffi_init(init_func: impl FnOnce(*mut T)) -> impl PinInit<Self> {
+        // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully
+        // initialize the `T`.
+        unsafe {
+            pin_init::pin_init_from_closure::<_, ::core::convert::Infallible>(move |slot| {
+                init_func(Self::cast_into(slot));
+                Ok(())
+            })
+        }
+    }
+
+    /// Creates a fallible pin-initializer from the given initializer closure.
+    ///
+    /// The returned initializer calls the given closure with the pointer to the inner `T` of this
+    /// `Opaque`. Since this memory is uninitialized, the closure is not allowed to read from it.
+    ///
+    /// This function is safe, because the `T` inside of an `Opaque` is allowed to be
+    /// uninitialized. Additionally, access to the inner `T` requires `unsafe`, so the caller needs
+    /// to verify at that point that the inner value is valid.
+    pub fn try_ffi_init<E>(
+        init_func: impl FnOnce(*mut T) -> Result<(), E>,
+    ) -> impl PinInit<Self, E> {
+        // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully
+        // initialize the `T`.
+        unsafe {
+            pin_init::pin_init_from_closure::<_, E>(move |slot| init_func(Self::cast_into(slot)))
+        }
+    }
 
     /// Returns a raw pointer to the opaque data.
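+    ///
+    /// The returned pointer remains valid for as long as `self` is alive.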
- pub fn get(&self) -> *mut T { - UnsafeCell::raw_get(self.0.as_ptr()) + pub const fn get(&self) -> *mut T { + UnsafeCell::get(&self.value).cast::<T>() } -} -/// A sum type that always holds either a value of type `L` or `R`. -pub enum Either<L, R> { - /// Constructs an instance of [`Either`] containing a value of type `L`. - Left(L), + /// Gets the value behind `this`. + /// + /// This function is useful to get access to the value without creating intermediate + /// references. + pub const fn cast_into(this: *const Self) -> *mut T { + UnsafeCell::raw_get(this.cast::<UnsafeCell<MaybeUninit<T>>>()).cast::<T>() + } + + /// The opposite operation of [`Opaque::cast_into`]. + pub const fn cast_from(this: *const T) -> *const Self { + this.cast() + } +} - /// Constructs an instance of [`Either`] containing a value of type `R`. - Right(R), +impl<T> Wrapper<T> for Opaque<T> { + /// Create an opaque pin-initializer from the given pin-initializer. + fn pin_init<E>(slot: impl PinInit<T, E>) -> impl PinInit<Self, E> { + Self::try_ffi_init(|ptr: *mut T| { + // SAFETY: + // - `ptr` is a valid pointer to uninitialized memory, + // - `slot` is not accessed on error, + // - `slot` is pinned in memory. + unsafe { PinInit::<T, E>::__pinned_init(slot, ptr) } + }) + } } + +/// Zero-sized type to mark types not [`Send`]. +/// +/// Add this type as a field to your struct if your type should not be sent to a different task. +/// Since [`Send`] is an auto trait, adding a single field that is `!Send` will ensure that the +/// whole type is `!Send`. +/// +/// If a type is `!Send` it is impossible to give control over an instance of the type to another +/// task. This is useful to include in types that store or reference task-local information. A file +/// descriptor is an example of such task-local information. +/// +/// This type also makes the type `!Sync`, which prevents immutable access to the value from +/// several threads in parallel. +pub type NotThreadSafe = PhantomData<*mut ()>; + +/// Used to construct instances of type [`NotThreadSafe`] similar to how `PhantomData` is +/// constructed. +/// +/// [`NotThreadSafe`]: type@NotThreadSafe +#[allow(non_upper_case_globals)] +pub const NotThreadSafe: NotThreadSafe = PhantomData; diff --git a/rust/kernel/uaccess.rs b/rust/kernel/uaccess.rs new file mode 100644 index 000000000000..f989539a31b4 --- /dev/null +++ b/rust/kernel/uaccess.rs @@ -0,0 +1,593 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Slices to user space memory regions. +//! +//! C header: [`include/linux/uaccess.h`](srctree/include/linux/uaccess.h) + +use crate::{ + alloc::{Allocator, Flags}, + bindings, + error::Result, + ffi::{c_char, c_void}, + fs::file, + prelude::*, + transmute::{AsBytes, FromBytes}, +}; +use core::mem::{size_of, MaybeUninit}; + +/// A pointer into userspace. +/// +/// This is the Rust equivalent to C pointers tagged with `__user`. +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct UserPtr(*mut c_void); + +impl UserPtr { + /// Create a `UserPtr` from an integer representing the userspace address. + #[inline] + pub fn from_addr(addr: usize) -> Self { + Self(addr as *mut c_void) + } + + /// Create a `UserPtr` from a pointer representing the userspace address. + #[inline] + pub fn from_ptr(addr: *mut c_void) -> Self { + Self(addr) + } + + /// Cast this userspace pointer to a raw const void pointer. + /// + /// It is up to the caller to use the returned pointer correctly. 
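+    ///
+    /// In particular, the pointer must not be dereferenced directly: userspace memory is only
+    /// accessed through dedicated helpers such as `copy_from_user`/`copy_to_user` (see also
+    /// [`UserSlice`]).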
+    #[inline]
+    pub fn as_const_ptr(self) -> *const c_void {
+        self.0
+    }
+
+    /// Cast this userspace pointer to a raw mutable void pointer.
+    ///
+    /// It is up to the caller to use the returned pointer correctly.
+    #[inline]
+    pub fn as_mut_ptr(self) -> *mut c_void {
+        self.0
+    }
+
+    /// Increment this user pointer by `add` bytes.
+    ///
+    /// This addition is wrapping, so wrapping around the address space does not result in a panic
+    /// even if `CONFIG_RUST_OVERFLOW_CHECKS` is enabled.
+    #[inline]
+    pub fn wrapping_byte_add(self, add: usize) -> UserPtr {
+        UserPtr(self.0.wrapping_byte_add(add))
+    }
+}
+
+/// A pointer to an area in userspace memory, which can be either read-only or read-write.
+///
+/// All methods on this struct are safe: attempting to read or write on bad addresses (either out of
+/// the bounds of the slice or unmapped addresses) will return [`EFAULT`]. Concurrent access,
+/// *including data races to/from userspace memory*, is permitted, because fundamentally another
+/// userspace thread/process could always be modifying memory at the same time (in the same way that
+/// userspace Rust's [`std::io`] permits data races with the contents of files on disk). In the
+/// presence of a race, the exact byte values read/written are unspecified but the operation is
+/// well-defined. Kernelspace code should validate its copy of data after completing a read, and not
+/// expect that multiple reads of the same address will return the same value.
+///
+/// These APIs are designed to make it difficult to accidentally write TOCTOU (time-of-check to
+/// time-of-use) bugs. Every time a memory location is read, the reader's position is advanced by
+/// the read length and the next read will start from there. This helps prevent accidentally reading
+/// the same location twice and causing a TOCTOU bug.
+///
+/// Creating a [`UserSliceReader`] and/or [`UserSliceWriter`] consumes the `UserSlice`, helping
+/// ensure that there aren't multiple readers or writers to the same location.
+///
+/// If double-fetching a memory location is necessary for some reason, then that is done by creating
+/// multiple readers to the same memory location, e.g. using [`clone_reader`].
+///
+/// # Examples
+///
+/// Takes a region of userspace memory from the current process, and modifies it by adding one to
+/// every byte in the region.
+///
+/// ```no_run
+/// use kernel::ffi::c_void;
+/// use kernel::uaccess::{UserPtr, UserSlice};
+///
+/// fn bytes_add_one(uptr: UserPtr, len: usize) -> Result {
+///     let (read, mut write) = UserSlice::new(uptr, len).reader_writer();
+///
+///     let mut buf = KVec::new();
+///     read.read_all(&mut buf, GFP_KERNEL)?;
+///
+///     for b in &mut buf {
+///         *b = b.wrapping_add(1);
+///     }
+///
+///     write.write_slice(&buf)?;
+///     Ok(())
+/// }
+/// ```
+///
+/// Example illustrating a TOCTOU (time-of-check to time-of-use) bug.
+///
+/// ```no_run
+/// use kernel::ffi::c_void;
+/// use kernel::uaccess::{UserPtr, UserSlice};
+///
+/// /// Returns whether the data in this region is valid.
+/// fn is_valid(uptr: UserPtr, len: usize) -> Result<bool> {
+///     let read = UserSlice::new(uptr, len).reader();
+///
+///     let mut buf = KVec::new();
+///     read.read_all(&mut buf, GFP_KERNEL)?;
+///
+///     todo!()
+/// }
+///
+/// /// Returns the bytes behind this user pointer if they are valid.
+/// fn get_bytes_if_valid(uptr: UserPtr, len: usize) -> Result<KVec<u8>> {
+///     if !is_valid(uptr, len)? {
+///         return Err(EINVAL);
+///     }
+///
+///     let read = UserSlice::new(uptr, len).reader();
+///
+///     let mut buf = KVec::new();
+///     read.read_all(&mut buf, GFP_KERNEL)?;
+///
+///     // THIS IS A BUG! The bytes could have changed since we checked them.
+///     //
+///     // To avoid this kind of bug, don't call `UserSlice::new` multiple
+///     // times with the same address.
+///     Ok(buf)
+/// }
+/// ```
+///
+/// [`std::io`]: https://doc.rust-lang.org/std/io/index.html
+/// [`clone_reader`]: UserSliceReader::clone_reader
+pub struct UserSlice {
+    ptr: UserPtr,
+    length: usize,
+}
+
+impl UserSlice {
+    /// Constructs a user slice from a raw pointer and a length in bytes.
+    ///
+    /// Constructing a [`UserSlice`] performs no checks on the provided address and length; it can
+    /// safely be constructed inside a kernel thread with no current userspace process. Reads and
+    /// writes wrap the kernel APIs `copy_from_user` and `copy_to_user`, which check the memory map
+    /// of the current process and enforce that the address range is within the user range (no
+    /// additional calls to `access_ok` are needed). Validity of the pointer is checked when you
+    /// attempt to read or write, not in the call to `UserSlice::new`.
+    ///
+    /// Callers must be careful to avoid time-of-check-time-of-use (TOCTOU) issues. The simplest way
+    /// is to create a single instance of [`UserSlice`] per user memory block, as it reads each byte
+    /// at most once.
+    pub fn new(ptr: UserPtr, length: usize) -> Self {
+        UserSlice { ptr, length }
+    }
+
+    /// Reads the entirety of the user slice, appending it to the end of the provided buffer.
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address.
+    pub fn read_all<A: Allocator>(self, buf: &mut Vec<u8, A>, flags: Flags) -> Result {
+        self.reader().read_all(buf, flags)
+    }
+
+    /// Constructs a [`UserSliceReader`].
+    pub fn reader(self) -> UserSliceReader {
+        UserSliceReader {
+            ptr: self.ptr,
+            length: self.length,
+        }
+    }
+
+    /// Constructs a [`UserSliceWriter`].
+    pub fn writer(self) -> UserSliceWriter {
+        UserSliceWriter {
+            ptr: self.ptr,
+            length: self.length,
+        }
+    }
+
+    /// Constructs both a [`UserSliceReader`] and a [`UserSliceWriter`].
+    ///
+    /// Usually when this is used, you will first read the data, and then overwrite it afterwards.
+    pub fn reader_writer(self) -> (UserSliceReader, UserSliceWriter) {
+        (
+            UserSliceReader {
+                ptr: self.ptr,
+                length: self.length,
+            },
+            UserSliceWriter {
+                ptr: self.ptr,
+                length: self.length,
+            },
+        )
+    }
+}
+
+/// A reader for [`UserSlice`].
+///
+/// Used to incrementally read from the user slice.
+pub struct UserSliceReader {
+    ptr: UserPtr,
+    length: usize,
+}
+
+impl UserSliceReader {
+    /// Skip the provided number of bytes.
+    ///
+    /// Returns an error if skipping more than the length of the buffer.
+    pub fn skip(&mut self, num_skip: usize) -> Result {
+        // Update `self.length` first since that's the fallible part of this operation.
+        self.length = self.length.checked_sub(num_skip).ok_or(EFAULT)?;
+        self.ptr = self.ptr.wrapping_byte_add(num_skip);
+        Ok(())
+    }
+
+    /// Create a reader that can access the same range of data.
+    ///
+    /// Reading from the clone does not advance the current reader.
+    ///
+    /// The caller should take care to not introduce TOCTOU issues, as described in the
+    /// documentation for [`UserSlice`].
+    pub fn clone_reader(&self) -> UserSliceReader {
+        UserSliceReader {
+            ptr: self.ptr,
+            length: self.length,
+        }
+    }
+
+    /// Returns the number of bytes left to be read from this reader.
+    ///
+    /// Note that even reading less than this number of bytes may fail.
+    pub fn len(&self) -> usize {
+        self.length
+    }
+
+    /// Returns `true` if no data is available in the I/O buffer.
+    pub fn is_empty(&self) -> bool {
+        self.length == 0
+    }
+
+    /// Reads raw data from the user slice into a kernel buffer.
+    ///
+    /// For a version that uses `&mut [u8]`, please see [`UserSliceReader::read_slice`].
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of
+    /// bounds of this [`UserSliceReader`]. This call may modify `out` even if it returns an error.
+    ///
+    /// # Guarantees
+    ///
+    /// After a successful call to this method, all bytes in `out` are initialized.
+    pub fn read_raw(&mut self, out: &mut [MaybeUninit<u8>]) -> Result {
+        let len = out.len();
+        let out_ptr = out.as_mut_ptr().cast::<c_void>();
+        if len > self.length {
+            return Err(EFAULT);
+        }
+        // SAFETY: `out_ptr` points into a mutable slice of length `len`, so we may write
+        // that many bytes to it.
+        let res = unsafe { bindings::copy_from_user(out_ptr, self.ptr.as_const_ptr(), len) };
+        if res != 0 {
+            return Err(EFAULT);
+        }
+        self.ptr = self.ptr.wrapping_byte_add(len);
+        self.length -= len;
+        Ok(())
+    }
+
+    /// Reads raw data from the user slice into a kernel buffer.
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of
+    /// bounds of this [`UserSliceReader`]. This call may modify `out` even if it returns an error.
+    pub fn read_slice(&mut self, out: &mut [u8]) -> Result {
+        // SAFETY: The types are compatible and `read_raw` doesn't write uninitialized bytes to
+        // `out`.
+        let out = unsafe { &mut *(core::ptr::from_mut(out) as *mut [MaybeUninit<u8>]) };
+        self.read_raw(out)
+    }
+
+    /// Reads raw data from the user slice into a kernel buffer partially.
+    ///
+    /// This is the same as [`Self::read_slice`] but considers the given `offset` into `out` and
+    /// truncates the read to the boundaries of `self` and `out`.
+    ///
+    /// On success, returns the number of bytes read.
+    pub fn read_slice_partial(&mut self, out: &mut [u8], offset: usize) -> Result<usize> {
+        let end = offset.saturating_add(self.len()).min(out.len());
+
+        let Some(dst) = out.get_mut(offset..end) else {
+            return Ok(0);
+        };
+
+        self.read_slice(dst)?;
+        Ok(dst.len())
+    }
+
+    /// Reads raw data from the user slice into a kernel buffer partially.
+    ///
+    /// This is the same as [`Self::read_slice_partial`] but updates the given [`file::Offset`] by
+    /// the number of bytes read.
+    ///
+    /// This is equivalent to C's `simple_write_to_buffer()`.
+    ///
+    /// On success, returns the number of bytes read.
+    pub fn read_slice_file(&mut self, out: &mut [u8], offset: &mut file::Offset) -> Result<usize> {
+        if offset.is_negative() {
+            return Err(EINVAL);
+        }
+
+        let Ok(offset_index) = (*offset).try_into() else {
+            return Ok(0);
+        };
+
+        let read = self.read_slice_partial(out, offset_index)?;
+
+        // OVERFLOW: `offset + read <= out.len() <= isize::MAX <= Offset::MAX`
+        *offset += read as i64;
+
+        Ok(read)
+    }
+
+    /// Reads a value of the specified type.
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of
+    /// bounds of this [`UserSliceReader`].
+    pub fn read<T: FromBytes>(&mut self) -> Result<T> {
+        let len = size_of::<T>();
+        if len > self.length {
+            return Err(EFAULT);
+        }
+        let mut out: MaybeUninit<T> = MaybeUninit::uninit();
+        // SAFETY: The local variable `out` is valid for writing `size_of::<T>()` bytes.
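+        // The userspace side is range-checked against `self.length` above, and
+        // `_copy_from_user` itself validates the userspace address.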
+        //
+        // By using the _copy_from_user variant, we skip the check_object_size check that verifies
+        // the kernel pointer. This mirrors the logic on the C side that skips the check when the
+        // length is a compile-time constant.
+        let res = unsafe {
+            bindings::_copy_from_user(
+                out.as_mut_ptr().cast::<c_void>(),
+                self.ptr.as_const_ptr(),
+                len,
+            )
+        };
+        if res != 0 {
+            return Err(EFAULT);
+        }
+        self.ptr = self.ptr.wrapping_byte_add(len);
+        self.length -= len;
+        // SAFETY: The read above has initialized all bytes in `out`, and since `T` implements
+        // `FromBytes`, any bit-pattern is a valid value for this type.
+        Ok(unsafe { out.assume_init() })
+    }
+
+    /// Reads the entirety of the user slice, appending it to the end of the provided buffer.
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address.
+    pub fn read_all<A: Allocator>(mut self, buf: &mut Vec<u8, A>, flags: Flags) -> Result {
+        let len = self.length;
+        buf.reserve(len, flags)?;
+
+        // The call to `reserve` was successful, so the spare capacity is at least `len` bytes long.
+        self.read_raw(&mut buf.spare_capacity_mut()[..len])?;
+
+        // SAFETY: Since the call to `read_raw` was successful, the next `len` bytes of the
+        // vector have been initialized.
+        unsafe { buf.inc_len(len) };
+        Ok(())
+    }
+
+    /// Read a NUL-terminated string from userspace and return it.
+    ///
+    /// The string is read into `buf` and a NUL-terminator is added if the end of `buf` is reached.
+    /// Since there must be space to add a NUL-terminator, the buffer must not be empty. The
+    /// returned `&CStr` points into `buf`.
+    ///
+    /// Fails with [`EFAULT`] if the read happens on a bad address (some data may have been
+    /// copied).
+    #[doc(alias = "strncpy_from_user")]
+    pub fn strcpy_into_buf<'buf>(self, buf: &'buf mut [u8]) -> Result<&'buf CStr> {
+        if buf.is_empty() {
+            return Err(EINVAL);
+        }
+
+        // SAFETY: The types are compatible and `strncpy_from_user` doesn't write uninitialized
+        // bytes to `buf`.
+        let mut dst = unsafe { &mut *(core::ptr::from_mut(buf) as *mut [MaybeUninit<u8>]) };
+
+        // We never read more than `self.length` bytes.
+        if dst.len() > self.length {
+            dst = &mut dst[..self.length];
+        }
+
+        let mut len = raw_strncpy_from_user(dst, self.ptr)?;
+        if len < dst.len() {
+            // Add one to include the NUL-terminator.
+            len += 1;
+        } else if len < buf.len() {
+            // This implies that `len == dst.len() < buf.len()`.
+            //
+            // This means that we could not fill the entire buffer, but we had to stop reading
+            // because we hit the `self.length` limit of this `UserSliceReader`. Since we did not
+            // fill the buffer, we treat this case as if we tried to read past the `self.length`
+            // limit and received a page fault, which is consistent with other `UserSliceReader`
+            // methods that also return page faults when you exceed `self.length`.
+            return Err(EFAULT);
+        } else {
+            // This implies that `len == buf.len()`.
+            //
+            // This means that we filled the buffer exactly. In this case, we add a NUL-terminator
+            // and return it. Unlike the `len < dst.len()` branch, don't modify `len` because it
+            // already represents the length including the NUL-terminator.
+            //
+            // SAFETY: Due to the check at the beginning, the buffer is not empty.
+            unsafe { *buf.last_mut().unwrap_unchecked() = 0 };
+        }
+
+        // This method consumes `self`, so it can only be called once, thus we do not need to
+        // update `self.length`. This sidesteps concerns such as whether `self.length` should be
+        // incremented by `len` or `len-1` in the `len == buf.len()` case.
+ + // SAFETY: There are two cases: + // * If we hit the `len < dst.len()` case, then `raw_strncpy_from_user` guarantees that + // this slice contains exactly one NUL byte at the end of the string. + // * Otherwise, `raw_strncpy_from_user` guarantees that the string contained no NUL bytes, + // and we have since added a NUL byte at the end. + Ok(unsafe { CStr::from_bytes_with_nul_unchecked(&buf[..len]) }) + } +} + +/// A writer for [`UserSlice`]. +/// +/// Used to incrementally write into the user slice. +pub struct UserSliceWriter { + ptr: UserPtr, + length: usize, +} + +impl UserSliceWriter { + /// Returns the amount of space remaining in this buffer. + /// + /// Note that even writing less than this number of bytes may fail. + pub fn len(&self) -> usize { + self.length + } + + /// Returns `true` if no more data can be written to this buffer. + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Writes raw data to this user pointer from a kernel buffer. + /// + /// Fails with [`EFAULT`] if the write happens on a bad address, or if the write goes out of + /// bounds of this [`UserSliceWriter`]. This call may modify the associated userspace slice even + /// if it returns an error. + pub fn write_slice(&mut self, data: &[u8]) -> Result { + let len = data.len(); + let data_ptr = data.as_ptr().cast::<c_void>(); + if len > self.length { + return Err(EFAULT); + } + // SAFETY: `data_ptr` points into an immutable slice of length `len`, so we may read + // that many bytes from it. + let res = unsafe { bindings::copy_to_user(self.ptr.as_mut_ptr(), data_ptr, len) }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_byte_add(len); + self.length -= len; + Ok(()) + } + + /// Writes raw data to this user pointer from a kernel buffer partially. + /// + /// This is the same as [`Self::write_slice`] but considers the given `offset` into `data` and + /// truncates the write to the boundaries of `self` and `data`. + /// + /// On success, returns the number of bytes written. + pub fn write_slice_partial(&mut self, data: &[u8], offset: usize) -> Result<usize> { + let end = offset.saturating_add(self.len()).min(data.len()); + + let Some(src) = data.get(offset..end) else { + return Ok(0); + }; + + self.write_slice(src)?; + Ok(src.len()) + } + + /// Writes raw data to this user pointer from a kernel buffer partially. + /// + /// This is the same as [`Self::write_slice_partial`] but updates the given [`file::Offset`] by + /// the number of bytes written. + /// + /// This is equivalent to C's `simple_read_from_buffer()`. + /// + /// On success, returns the number of bytes written. + pub fn write_slice_file(&mut self, data: &[u8], offset: &mut file::Offset) -> Result<usize> { + if offset.is_negative() { + return Err(EINVAL); + } + + let Ok(offset_index) = (*offset).try_into() else { + return Ok(0); + }; + + let written = self.write_slice_partial(data, offset_index)?; + + // OVERFLOW: `offset + written <= data.len() <= isize::MAX <= Offset::MAX` + *offset += written as i64; + + Ok(written) + } + + /// Writes the provided Rust value to this userspace pointer. + /// + /// Fails with [`EFAULT`] if the write happens on a bad address, or if the write goes out of + /// bounds of this [`UserSliceWriter`]. This call may modify the associated userspace slice even + /// if it returns an error. 
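+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of copying a fixed-size value out to userspace; the `uptr` and `len`
+    /// arguments are assumed to be supplied by the caller (e.g. from an ioctl):
+    ///
+    /// ```no_run
+    /// use kernel::uaccess::{UserPtr, UserSlice};
+    ///
+    /// fn give_answer(uptr: UserPtr, len: usize) -> Result {
+    ///     let value: u64 = 42;
+    ///     // Fails with `EFAULT` if the user buffer is too small or the address is bad.
+    ///     UserSlice::new(uptr, len).writer().write(&value)
+    /// }
+    /// ```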
+ pub fn write<T: AsBytes>(&mut self, value: &T) -> Result { + let len = size_of::<T>(); + if len > self.length { + return Err(EFAULT); + } + // SAFETY: The reference points to a value of type `T`, so it is valid for reading + // `size_of::<T>()` bytes. + // + // By using the _copy_to_user variant, we skip the check_object_size check that verifies the + // kernel pointer. This mirrors the logic on the C side that skips the check when the length + // is a compile-time constant. + let res = unsafe { + bindings::_copy_to_user( + self.ptr.as_mut_ptr(), + core::ptr::from_ref(value).cast::<c_void>(), + len, + ) + }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_byte_add(len); + self.length -= len; + Ok(()) + } +} + +/// Reads a nul-terminated string into `dst` and returns the length. +/// +/// This reads from userspace until a NUL byte is encountered, or until `dst.len()` bytes have been +/// read. Fails with [`EFAULT`] if a read happens on a bad address (some data may have been +/// copied). When the end of the buffer is encountered, no NUL byte is added, so the string is +/// *not* guaranteed to be NUL-terminated when `Ok(dst.len())` is returned. +/// +/// # Guarantees +/// +/// When this function returns `Ok(len)`, it is guaranteed that the first `len` bytes of `dst` are +/// initialized and non-zero. Furthermore, if `len < dst.len()`, then `dst[len]` is a NUL byte. +#[inline] +fn raw_strncpy_from_user(dst: &mut [MaybeUninit<u8>], src: UserPtr) -> Result<usize> { + // CAST: Slice lengths are guaranteed to be `<= isize::MAX`. + let len = dst.len() as isize; + + // SAFETY: `dst` is valid for writing `dst.len()` bytes. + let res = unsafe { + bindings::strncpy_from_user( + dst.as_mut_ptr().cast::<c_char>(), + src.as_const_ptr().cast::<c_char>(), + len, + ) + }; + + if res < 0 { + return Err(Error::from_errno(res as i32)); + } + + #[cfg(CONFIG_RUST_OVERFLOW_CHECKS)] + assert!(res <= len); + + // GUARANTEES: `strncpy_from_user` was successful, so `dst` has contents in accordance with the + // guarantees of this function. + Ok(res as usize) +} diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs new file mode 100644 index 000000000000..d10b65e9fb6a --- /dev/null +++ b/rust/kernel/usb.rs @@ -0,0 +1,469 @@ +// SPDX-License-Identifier: GPL-2.0 +// SPDX-FileCopyrightText: Copyright (C) 2025 Collabora Ltd. + +//! Abstractions for the USB bus. +//! +//! C header: [`include/linux/usb.h`](srctree/include/linux/usb.h) + +use crate::{ + bindings, device, + device_id::{RawDeviceId, RawDeviceIdIndex}, + driver, + error::{from_result, to_result, Result}, + prelude::*, + str::CStr, + types::{AlwaysRefCounted, Opaque}, + ThisModule, +}; +use core::{ + marker::PhantomData, + mem::{ + offset_of, + MaybeUninit, // + }, + ptr::NonNull, +}; + +/// An adapter for the registration of USB drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::usb_driver; + + unsafe fn register( + udrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: It's safe to set the fields of `struct usb_driver` on initialization. 
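+        // Before registration, nothing else accesses `udrv`, so these writes cannot race.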
+ unsafe { + (*udrv.get()).name = name.as_char_ptr(); + (*udrv.get()).probe = Some(Self::probe_callback); + (*udrv.get()).disconnect = Some(Self::disconnect_callback); + (*udrv.get()).id_table = T::ID_TABLE.as_ptr(); + } + + // SAFETY: `udrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { + bindings::usb_register_driver(udrv.get(), module.0, name.as_char_ptr()) + }) + } + + unsafe fn unregister(udrv: &Opaque<Self::RegType>) { + // SAFETY: `udrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::usb_deregister(udrv.get()) }; + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback( + intf: *mut bindings::usb_interface, + id: *const bindings::usb_device_id, + ) -> kernel::ffi::c_int { + // SAFETY: The USB core only ever calls the probe callback with a valid pointer to a + // `struct usb_interface` and `struct usb_device_id`. + // + // INVARIANT: `intf` is valid for the duration of `probe_callback()`. + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + + from_result(|| { + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct usb_device_id` and + // does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*id.cast::<DeviceId>() }; + + let info = T::ID_TABLE.info(id.index()); + let data = T::probe(intf, id, info); + + let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + dev.set_drvdata(data)?; + Ok(0) + }) + } + + extern "C" fn disconnect_callback(intf: *mut bindings::usb_interface) { + // SAFETY: The USB core only ever calls the disconnect callback with a valid pointer to a + // `struct usb_interface`. + // + // INVARIANT: `intf` is valid for the duration of `disconnect_callback()`. + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + + let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + + // SAFETY: `disconnect_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { dev.drvdata_obtain::<T>() }; + + T::disconnect(intf, data.as_ref()); + } +} + +/// Abstraction for the USB device ID structure, i.e. [`struct usb_device_id`]. +/// +/// [`struct usb_device_id`]: https://docs.kernel.org/driver-api/basics.html#c.usb_device_id +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::usb_device_id); + +impl DeviceId { + /// Equivalent to C's `USB_DEVICE` macro. + pub const fn from_id(vendor: u16, product: u16) -> Self { + Self(bindings::usb_device_id { + match_flags: bindings::USB_DEVICE_ID_MATCH_DEVICE as u16, + idVendor: vendor, + idProduct: product, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_VER` macro. + pub const fn from_device_ver(vendor: u16, product: u16, bcd_lo: u16, bcd_hi: u16) -> Self { + Self(bindings::usb_device_id { + match_flags: bindings::USB_DEVICE_ID_MATCH_DEVICE_AND_VERSION as u16, + idVendor: vendor, + idProduct: product, + bcdDevice_lo: bcd_lo, + bcdDevice_hi: bcd_hi, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_INFO` macro. 
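+    ///
+    /// For instance, a match entry for any vendor-specific device (illustrative values only)
+    /// could be built as:
+    ///
+    /// ```
+    /// use kernel::usb;
+    ///
+    /// // Device class 0xff means "vendor specific".
+    /// let _id = usb::DeviceId::from_device_info(0xff, 0x00, 0x00);
+    /// ```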
+ pub const fn from_device_info(class: u8, subclass: u8, protocol: u8) -> Self { + Self(bindings::usb_device_id { + match_flags: bindings::USB_DEVICE_ID_MATCH_DEV_INFO as u16, + bDeviceClass: class, + bDeviceSubClass: subclass, + bDeviceProtocol: protocol, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_INTERFACE_INFO` macro. + pub const fn from_interface_info(class: u8, subclass: u8, protocol: u8) -> Self { + Self(bindings::usb_device_id { + match_flags: bindings::USB_DEVICE_ID_MATCH_INT_INFO as u16, + bInterfaceClass: class, + bInterfaceSubClass: subclass, + bInterfaceProtocol: protocol, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_INTERFACE_CLASS` macro. + pub const fn from_device_interface_class(vendor: u16, product: u16, class: u8) -> Self { + Self(bindings::usb_device_id { + match_flags: (bindings::USB_DEVICE_ID_MATCH_DEVICE + | bindings::USB_DEVICE_ID_MATCH_INT_CLASS) as u16, + idVendor: vendor, + idProduct: product, + bInterfaceClass: class, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_INTERFACE_PROTOCOL` macro. + pub const fn from_device_interface_protocol(vendor: u16, product: u16, protocol: u8) -> Self { + Self(bindings::usb_device_id { + match_flags: (bindings::USB_DEVICE_ID_MATCH_DEVICE + | bindings::USB_DEVICE_ID_MATCH_INT_PROTOCOL) as u16, + idVendor: vendor, + idProduct: product, + bInterfaceProtocol: protocol, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_INTERFACE_NUMBER` macro. + pub const fn from_device_interface_number(vendor: u16, product: u16, number: u8) -> Self { + Self(bindings::usb_device_id { + match_flags: (bindings::USB_DEVICE_ID_MATCH_DEVICE + | bindings::USB_DEVICE_ID_MATCH_INT_NUMBER) as u16, + idVendor: vendor, + idProduct: product, + bInterfaceNumber: number, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } + + /// Equivalent to C's `USB_DEVICE_AND_INTERFACE_INFO` macro. + pub const fn from_device_and_interface_info( + vendor: u16, + product: u16, + class: u8, + subclass: u8, + protocol: u8, + ) -> Self { + Self(bindings::usb_device_id { + match_flags: (bindings::USB_DEVICE_ID_MATCH_INT_INFO + | bindings::USB_DEVICE_ID_MATCH_DEVICE) as u16, + idVendor: vendor, + idProduct: product, + bInterfaceClass: class, + bInterfaceSubClass: subclass, + bInterfaceProtocol: protocol, + // SAFETY: It is safe to use all zeroes for the other fields of `usb_device_id`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }) + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `usb_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::usb_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_info` field. 
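+// `DeviceId` is `#[repr(transparent)]` over `struct usb_device_id`, so the offset applies to
+// `DeviceId` as well.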
+unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::usb_device_id, driver_info); + + fn index(&self) -> usize { + self.0.driver_info + } +} + +/// [`IdTable`](kernel::device_id::IdTable) type for USB. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// Create a USB `IdTable` with its alias for modpost. +#[macro_export] +macro_rules! usb_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::usb::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("usb", $module_table_name, $table_name); + }; +} + +/// The USB driver trait. +/// +/// # Examples +/// +///``` +/// # use kernel::{bindings, device::Core, usb}; +/// use kernel::prelude::*; +/// +/// struct MyDriver; +/// +/// kernel::usb_device_table!( +/// USB_TABLE, +/// MODULE_USB_TABLE, +/// <MyDriver as usb::Driver>::IdInfo, +/// [ +/// (usb::DeviceId::from_id(0x1234, 0x5678), ()), +/// (usb::DeviceId::from_id(0xabcd, 0xef01), ()), +/// ] +/// ); +/// +/// impl usb::Driver for MyDriver { +/// type IdInfo = (); +/// const ID_TABLE: usb::IdTable<Self::IdInfo> = &USB_TABLE; +/// +/// fn probe( +/// _interface: &usb::Interface<Core>, +/// _id: &usb::DeviceId, +/// _info: &Self::IdInfo, +/// ) -> impl PinInit<Self, Error> { +/// Err(ENODEV) +/// } +/// +/// fn disconnect(_interface: &usb::Interface<Core>, _data: Pin<&Self>) {} +/// } +///``` +pub trait Driver { + /// The type holding information about each one of the device ids supported by the driver. + type IdInfo: 'static; + + /// The table of device ids supported by the driver. + const ID_TABLE: IdTable<Self::IdInfo>; + + /// USB driver probe. + /// + /// Called when a new USB interface is bound to this driver. + /// Implementers should attempt to initialize the interface here. + fn probe( + interface: &Interface<device::Core>, + id: &DeviceId, + id_info: &Self::IdInfo, + ) -> impl PinInit<Self, Error>; + + /// USB driver disconnect. + /// + /// Called when the USB interface is about to be unbound from this driver. + fn disconnect(interface: &Interface<device::Core>, data: Pin<&Self>); +} + +/// A USB interface. +/// +/// This structure represents the Rust abstraction for a C [`struct usb_interface`]. +/// The implementation abstracts the usage of a C [`struct usb_interface`] passed +/// in from the C side. +/// +/// # Invariants +/// +/// An [`Interface`] instance represents a valid [`struct usb_interface`] created +/// by the C portion of the kernel. +/// +/// [`struct usb_interface`]: https://www.kernel.org/doc/html/latest/driver-api/usb/usb.html#c.usb_interface +#[repr(transparent)] +pub struct Interface<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::usb_interface>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> Interface<Ctx> { + fn as_raw(&self) -> *mut bindings::usb_interface { + self.0.get() + } +} + +// SAFETY: `usb::Interface` is a transparent wrapper of `struct usb_interface`. +// The offset is guaranteed to point to a valid device field inside `usb::Interface`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Interface<Ctx> { + const OFFSET: usize = offset_of!(bindings::usb_interface, dev); +} + +// SAFETY: `Interface` is a transparent wrapper of a type that doesn't depend on +// `Interface`'s generic argument. 
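+// Hence it is sound for the macros below to convert between the different `DeviceContext`
+// states of an `Interface`.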
+kernel::impl_device_context_deref!(unsafe { Interface });
+kernel::impl_device_context_into_aref!(Interface);
+
+impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Interface<Ctx> {
+ fn as_ref(&self) -> &device::Device<Ctx> {
+ // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+ // `struct usb_interface`.
+ let dev = unsafe { &raw mut ((*self.as_raw()).dev) };
+
+ // SAFETY: `dev` points to a valid `struct device`.
+ unsafe { device::Device::from_raw(dev) }
+ }
+}
+
+impl<Ctx: device::DeviceContext> AsRef<Device> for Interface<Ctx> {
+ fn as_ref(&self) -> &Device {
+ // SAFETY: `self.as_raw()` is valid by the type invariants.
+ let usb_dev = unsafe { bindings::interface_to_usbdev(self.as_raw()) };
+
+ // SAFETY: For a valid `struct usb_interface` pointer, the above call to
+ // `interface_to_usbdev()` is guaranteed to return a valid pointer to a `struct usb_device`.
+ unsafe { &*(usb_dev.cast()) }
+ }
+}
+
+// SAFETY: Instances of `Interface` are always reference-counted.
+unsafe impl AlwaysRefCounted for Interface {
+ fn inc_ref(&self) {
+ // SAFETY: The invariants of `Interface` guarantee that `self.as_raw()`
+ // returns a valid `struct usb_interface` pointer, for which we will
+ // acquire a new refcount.
+ unsafe { bindings::usb_get_intf(self.as_raw()) };
+ }
+
+ unsafe fn dec_ref(obj: NonNull<Self>) {
+ // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+ unsafe { bindings::usb_put_intf(obj.cast().as_ptr()) }
+ }
+}
+
+// SAFETY: An `Interface` is always reference-counted and can be released from any thread.
+unsafe impl Send for Interface {}
+
+// SAFETY: It is safe to send a `&Interface` to another thread because we do not
+// allow any mutation through a shared reference.
+unsafe impl Sync for Interface {}
+
+/// A USB device.
+///
+/// This structure represents the Rust abstraction for a C [`struct usb_device`].
+/// The implementation abstracts the usage of a C [`struct usb_device`] passed in
+/// from the C side.
+///
+/// # Invariants
+///
+/// A [`Device`] instance represents a valid [`struct usb_device`] created by the C portion of the
+/// kernel.
+///
+/// [`struct usb_device`]: https://www.kernel.org/doc/html/latest/driver-api/usb/usb.html#c.usb_device
+#[repr(transparent)]
+struct Device<Ctx: device::DeviceContext = device::Normal>(
+ Opaque<bindings::usb_device>,
+ PhantomData<Ctx>,
+);
+
+impl<Ctx: device::DeviceContext> Device<Ctx> {
+ fn as_raw(&self) -> *mut bindings::usb_device {
+ self.0.get()
+ }
+}
+
+// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic
+// argument.
+kernel::impl_device_context_deref!(unsafe { Device });
+kernel::impl_device_context_into_aref!(Device);
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl AlwaysRefCounted for Device {
+ fn inc_ref(&self) {
+ // SAFETY: The invariants of `Device` guarantee that `self.as_raw()`
+ // returns a valid `struct usb_device` pointer, for which we will
+ // acquire a new refcount.
+ unsafe { bindings::usb_get_dev(self.as_raw()) };
+ }
+
+ unsafe fn dec_ref(obj: NonNull<Self>) {
+ // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+ unsafe { bindings::usb_put_dev(obj.cast().as_ptr()) }
+ }
+}
+
+impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> {
+ fn as_ref(&self) -> &device::Device<Ctx> {
+ // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+ // `struct usb_device`.
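+ // (`&raw mut` takes the field's address directly, without materialising an
+ // intermediate reference to it.)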
+ let dev = unsafe { &raw mut ((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::from_raw(dev) } + } +} + +// SAFETY: A `Device` is always reference-counted and can be released from any thread. +unsafe impl Send for Device {} + +// SAFETY: It is safe to send a &Device to another thread because we do not +// allow any mutation through a shared reference. +unsafe impl Sync for Device {} + +/// Declares a kernel module that exposes a single USB driver. +/// +/// # Examples +/// +/// ```ignore +/// module_usb_driver! { +/// type: MyDriver, +/// name: "Module name", +/// author: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +/// ``` +#[macro_export] +macro_rules! module_usb_driver { + ($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::usb::Adapter<T>, { $($f)* }); + } +} diff --git a/rust/kernel/workqueue.rs b/rust/kernel/workqueue.rs new file mode 100644 index 000000000000..706e833e9702 --- /dev/null +++ b/rust/kernel/workqueue.rs @@ -0,0 +1,1024 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Work queues. +//! +//! This file has two components: The raw work item API, and the safe work item API. +//! +//! One pattern that is used in both APIs is the `ID` const generic, which exists to allow a single +//! type to define multiple `work_struct` fields. This is done by choosing an id for each field, +//! and using that id to specify which field you wish to use. (The actual value doesn't matter, as +//! long as you use different values for different fields of the same struct.) Since these IDs are +//! generic, they are used only at compile-time, so they shouldn't exist in the final binary. +//! +//! # The raw API +//! +//! The raw API consists of the [`RawWorkItem`] trait, where the work item needs to provide an +//! arbitrary function that knows how to enqueue the work item. It should usually not be used +//! directly, but if you want to, you can use it without using the pieces from the safe API. +//! +//! # The safe API +//! +//! The safe API is used via the [`Work`] struct and [`WorkItem`] traits. Furthermore, it also +//! includes a trait called [`WorkItemPointer`], which is usually not used directly by the user. +//! +//! * The [`Work`] struct is the Rust wrapper for the C `work_struct` type. +//! * The [`WorkItem`] trait is implemented for structs that can be enqueued to a workqueue. +//! * The [`WorkItemPointer`] trait is implemented for the pointer type that points at a something +//! that implements [`WorkItem`]. +//! +//! ## Examples +//! +//! This example defines a struct that holds an integer and can be scheduled on the workqueue. When +//! the struct is executed, it will print the integer. Since there is only one `work_struct` field, +//! we do not need to specify ids for the fields. +//! +//! ``` +//! use kernel::sync::Arc; +//! use kernel::workqueue::{self, impl_has_work, new_work, Work, WorkItem}; +//! +//! #[pin_data] +//! struct MyStruct { +//! value: i32, +//! #[pin] +//! work: Work<MyStruct>, +//! } +//! +//! impl_has_work! { +//! impl HasWork<Self> for MyStruct { self.work } +//! } +//! +//! impl MyStruct { +//! fn new(value: i32) -> Result<Arc<Self>> { +//! Arc::pin_init(pin_init!(MyStruct { +//! value, +//! work <- new_work!("MyStruct::work"), +//! }), GFP_KERNEL) +//! } +//! } +//! +//! impl WorkItem for MyStruct { +//! type Pointer = Arc<MyStruct>; +//! +//! fn run(this: Arc<MyStruct>) { +//! pr_info!("The value is: {}\n", this.value); +//! } +//! } +//! +//! 
/// This method will enqueue the struct for execution on the system workqueue, where its value +//! /// will be printed. +//! fn print_later(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue(val); +//! } +//! # print_later(MyStruct::new(42).unwrap()); +//! ``` +//! +//! The following example shows how multiple `work_struct` fields can be used: +//! +//! ``` +//! use kernel::sync::Arc; +//! use kernel::workqueue::{self, impl_has_work, new_work, Work, WorkItem}; +//! +//! #[pin_data] +//! struct MyStruct { +//! value_1: i32, +//! value_2: i32, +//! #[pin] +//! work_1: Work<MyStruct, 1>, +//! #[pin] +//! work_2: Work<MyStruct, 2>, +//! } +//! +//! impl_has_work! { +//! impl HasWork<Self, 1> for MyStruct { self.work_1 } +//! impl HasWork<Self, 2> for MyStruct { self.work_2 } +//! } +//! +//! impl MyStruct { +//! fn new(value_1: i32, value_2: i32) -> Result<Arc<Self>> { +//! Arc::pin_init(pin_init!(MyStruct { +//! value_1, +//! value_2, +//! work_1 <- new_work!("MyStruct::work_1"), +//! work_2 <- new_work!("MyStruct::work_2"), +//! }), GFP_KERNEL) +//! } +//! } +//! +//! impl WorkItem<1> for MyStruct { +//! type Pointer = Arc<MyStruct>; +//! +//! fn run(this: Arc<MyStruct>) { +//! pr_info!("The value is: {}\n", this.value_1); +//! } +//! } +//! +//! impl WorkItem<2> for MyStruct { +//! type Pointer = Arc<MyStruct>; +//! +//! fn run(this: Arc<MyStruct>) { +//! pr_info!("The second value is: {}\n", this.value_2); +//! } +//! } +//! +//! fn print_1_later(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue::<Arc<MyStruct>, 1>(val); +//! } +//! +//! fn print_2_later(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue::<Arc<MyStruct>, 2>(val); +//! } +//! # print_1_later(MyStruct::new(24, 25).unwrap()); +//! # print_2_later(MyStruct::new(41, 42).unwrap()); +//! ``` +//! +//! This example shows how you can schedule delayed work items: +//! +//! ``` +//! use kernel::sync::Arc; +//! use kernel::workqueue::{self, impl_has_delayed_work, new_delayed_work, DelayedWork, WorkItem}; +//! +//! #[pin_data] +//! struct MyStruct { +//! value: i32, +//! #[pin] +//! work: DelayedWork<MyStruct>, +//! } +//! +//! impl_has_delayed_work! { +//! impl HasDelayedWork<Self> for MyStruct { self.work } +//! } +//! +//! impl MyStruct { +//! fn new(value: i32) -> Result<Arc<Self>> { +//! Arc::pin_init( +//! pin_init!(MyStruct { +//! value, +//! work <- new_delayed_work!("MyStruct::work"), +//! }), +//! GFP_KERNEL, +//! ) +//! } +//! } +//! +//! impl WorkItem for MyStruct { +//! type Pointer = Arc<MyStruct>; +//! +//! fn run(this: Arc<MyStruct>) { +//! pr_info!("The value is: {}\n", this.value); +//! } +//! } +//! +//! /// This method will enqueue the struct for execution on the system workqueue, where its value +//! /// will be printed 12 jiffies later. +//! fn print_later(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue_delayed(val, 12); +//! } +//! +//! /// It is also possible to use the ordinary `enqueue` method together with `DelayedWork`. This +//! /// is equivalent to calling `enqueue_delayed` with a delay of zero. +//! fn print_now(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue(val); +//! } +//! # print_later(MyStruct::new(42).unwrap()); +//! # print_now(MyStruct::new(42).unwrap()); +//! ``` +//! +//! 
C header: [`include/linux/workqueue.h`](srctree/include/linux/workqueue.h) + +use crate::{ + alloc::{AllocError, Flags}, + container_of, + prelude::*, + sync::Arc, + sync::LockClassKey, + time::Jiffies, + types::Opaque, +}; +use core::marker::PhantomData; + +/// Creates a [`Work`] initialiser with the given name and a newly-created lock class. +#[macro_export] +macro_rules! new_work { + ($($name:literal)?) => { + $crate::workqueue::Work::new($crate::optional_name!($($name)?), $crate::static_lock_class!()) + }; +} +pub use new_work; + +/// Creates a [`DelayedWork`] initialiser with the given name and a newly-created lock class. +#[macro_export] +macro_rules! new_delayed_work { + () => { + $crate::workqueue::DelayedWork::new( + $crate::optional_name!(), + $crate::static_lock_class!(), + $crate::c_str!(::core::concat!( + ::core::file!(), + ":", + ::core::line!(), + "_timer" + )), + $crate::static_lock_class!(), + ) + }; + ($name:literal) => { + $crate::workqueue::DelayedWork::new( + $crate::c_str!($name), + $crate::static_lock_class!(), + $crate::c_str!(::core::concat!($name, "_timer")), + $crate::static_lock_class!(), + ) + }; +} +pub use new_delayed_work; + +/// A kernel work queue. +/// +/// Wraps the kernel's C `struct workqueue_struct`. +/// +/// It allows work items to be queued to run on thread pools managed by the kernel. Several are +/// always available, for example, `system`, `system_highpri`, `system_long`, etc. +#[repr(transparent)] +pub struct Queue(Opaque<bindings::workqueue_struct>); + +// SAFETY: Accesses to workqueues used by [`Queue`] are thread-safe. +unsafe impl Send for Queue {} +// SAFETY: Accesses to workqueues used by [`Queue`] are thread-safe. +unsafe impl Sync for Queue {} + +impl Queue { + /// Use the provided `struct workqueue_struct` with Rust. + /// + /// # Safety + /// + /// The caller must ensure that the provided raw pointer is not dangling, that it points at a + /// valid workqueue, and that it remains valid until the end of `'a`. + pub unsafe fn from_raw<'a>(ptr: *const bindings::workqueue_struct) -> &'a Queue { + // SAFETY: The `Queue` type is `#[repr(transparent)]`, so the pointer cast is valid. The + // caller promises that the pointer is not dangling. + unsafe { &*ptr.cast::<Queue>() } + } + + /// Enqueues a work item. + /// + /// This may fail if the work item is already enqueued in a workqueue. + /// + /// The work item will be submitted using `WORK_CPU_UNBOUND`. + pub fn enqueue<W, const ID: u64>(&self, w: W) -> W::EnqueueOutput + where + W: RawWorkItem<ID> + Send + 'static, + { + let queue_ptr = self.0.get(); + + // SAFETY: We only return `false` if the `work_struct` is already in a workqueue. The other + // `__enqueue` requirements are not relevant since `W` is `Send` and static. + // + // The call to `bindings::queue_work_on` will dereference the provided raw pointer, which + // is ok because `__enqueue` guarantees that the pointer is valid for the duration of this + // closure. + // + // Furthermore, if the C workqueue code accesses the pointer after this call to + // `__enqueue`, then the work item was successfully enqueued, and `bindings::queue_work_on` + // will have returned true. In this case, `__enqueue` promises that the raw pointer will + // stay valid until we call the function pointer in the `work_struct`, so the access is ok. + unsafe { + w.__enqueue(move |work_ptr| { + bindings::queue_work_on( + bindings::wq_misc_consts_WORK_CPU_UNBOUND as ffi::c_int, + queue_ptr, + work_ptr, + ) + }) + } + } + + /// Enqueues a delayed work item. 
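+ ///
+ /// A minimal sketch, reusing the `MyStruct` type from the module-level
+ /// [`DelayedWork`] example (any [`RawDelayedWorkItem`] works the same way):
+ ///
+ /// ```ignore
+ /// // Run `MyStruct::run` roughly 100 jiffies from now.
+ /// let _ = kernel::workqueue::system().enqueue_delayed(val, 100);
+ /// ```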
+ ///
+ /// This may fail if the work item is already enqueued in a workqueue.
+ ///
+ /// The work item will be submitted using `WORK_CPU_UNBOUND`.
+ pub fn enqueue_delayed<W, const ID: u64>(&self, w: W, delay: Jiffies) -> W::EnqueueOutput
+ where
+ W: RawDelayedWorkItem<ID> + Send + 'static,
+ {
+ let queue_ptr = self.0.get();
+
+ // SAFETY: We only return `false` if the `work_struct` is already in a workqueue. The other
+ // `__enqueue` requirements are not relevant since `W` is `Send` and static.
+ //
+ // The call to `bindings::queue_delayed_work_on` will dereference the provided raw pointer,
+ // which is ok because `__enqueue` guarantees that the pointer is valid for the duration of
+ // this closure, and the safety requirements of `RawDelayedWorkItem` expand this
+ // requirement to apply to the entire `delayed_work`.
+ //
+ // Furthermore, if the C workqueue code accesses the pointer after this call to
+ // `__enqueue`, then the work item was successfully enqueued, and
+ // `bindings::queue_delayed_work_on` will have returned true. In this case, `__enqueue`
+ // promises that the raw pointer will stay valid until we call the function pointer in the
+ // `work_struct`, so the access is ok.
+ unsafe {
+ w.__enqueue(move |work_ptr| {
+ bindings::queue_delayed_work_on(
+ bindings::wq_misc_consts_WORK_CPU_UNBOUND as ffi::c_int,
+ queue_ptr,
+ container_of!(work_ptr, bindings::delayed_work, work),
+ delay,
+ )
+ })
+ }
+ }
+
+ /// Tries to spawn the given function or closure as a work item.
+ ///
+ /// This method can fail because it allocates memory to store the work item.
+ pub fn try_spawn<T: 'static + Send + FnOnce()>(
+ &self,
+ flags: Flags,
+ func: T,
+ ) -> Result<(), AllocError> {
+ let init = pin_init!(ClosureWork {
+ work <- new_work!("Queue::try_spawn"),
+ func: Some(func),
+ });
+
+ self.enqueue(KBox::pin_init(init, flags).map_err(|_| AllocError)?);
+ Ok(())
+ }
+}
+
+/// A helper type used in [`try_spawn`].
+///
+/// [`try_spawn`]: Queue::try_spawn
+#[pin_data]
+struct ClosureWork<T> {
+ #[pin]
+ work: Work<ClosureWork<T>>,
+ func: Option<T>,
+}
+
+impl<T: FnOnce()> WorkItem for ClosureWork<T> {
+ type Pointer = Pin<KBox<Self>>;
+
+ fn run(mut this: Pin<KBox<Self>>) {
+ if let Some(func) = this.as_mut().project().func.take() {
+ (func)()
+ }
+ }
+}
+
+/// A raw work item.
+///
+/// This is the low-level trait that is designed to be as general as possible.
+///
+/// The `ID` parameter to this trait exists so that a single type can provide multiple
+/// implementations of this trait. For example, if a struct has multiple `work_struct` fields, then
+/// you will implement this trait once for each field, using a different id for each field. The
+/// actual value of the id is not important as long as you use different ids for different fields
+/// of the same struct. (Fields of different structs need not use different ids.)
+///
+/// Note that the id is used only to select the right method to call during compilation. It won't
+/// be part of the final executable.
+///
+/// # Safety
+///
+/// Implementers must ensure that any pointers passed to a `queue_work_on` closure by [`__enqueue`]
+/// remain valid for the duration specified in the guarantees section of the documentation for
+/// [`__enqueue`].
+///
+/// [`__enqueue`]: RawWorkItem::__enqueue
+pub unsafe trait RawWorkItem<const ID: u64> {
+ /// The return type of [`Queue::enqueue`].
+ type EnqueueOutput;
+
+ /// Enqueues this work item on a queue using the provided `queue_work_on` method.
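+ ///
+ /// Conceptually, callers such as [`Queue::enqueue`] drive it like this (sketch):
+ ///
+ /// ```ignore
+ /// unsafe {
+ /// w.__enqueue(|work_ptr| {
+ /// // Hand the raw `work_struct` pointer to the C workqueue core; the
+ /// // boolean result reports whether the item was newly enqueued.
+ /// bindings::queue_work_on(cpu, queue_ptr, work_ptr)
+ /// })
+ /// }
+ /// ```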
+ /// + /// # Guarantees + /// + /// If this method calls the provided closure, then the raw pointer is guaranteed to point at a + /// valid `work_struct` for the duration of the call to the closure. If the closure returns + /// true, then it is further guaranteed that the pointer remains valid until someone calls the + /// function pointer stored in the `work_struct`. + /// + /// # Safety + /// + /// The provided closure may only return `false` if the `work_struct` is already in a workqueue. + /// + /// If the work item type is annotated with any lifetimes, then you must not call the function + /// pointer after any such lifetime expires. (Never calling the function pointer is okay.) + /// + /// If the work item type is not [`Send`], then the function pointer must be called on the same + /// thread as the call to `__enqueue`. + unsafe fn __enqueue<F>(self, queue_work_on: F) -> Self::EnqueueOutput + where + F: FnOnce(*mut bindings::work_struct) -> bool; +} + +/// A raw delayed work item. +/// +/// # Safety +/// +/// If the `__enqueue` method in the `RawWorkItem` implementation calls the closure, then the +/// provided pointer must point at the `work` field of a valid `delayed_work`, and the guarantees +/// that `__enqueue` provides about accessing the `work_struct` must also apply to the rest of the +/// `delayed_work` struct. +pub unsafe trait RawDelayedWorkItem<const ID: u64>: RawWorkItem<ID> {} + +/// Defines the method that should be called directly when a work item is executed. +/// +/// This trait is implemented by `Pin<KBox<T>>` and [`Arc<T>`], and is mainly intended to be +/// implemented for smart pointer types. For your own structs, you would implement [`WorkItem`] +/// instead. The [`run`] method on this trait will usually just perform the appropriate +/// `container_of` translation and then call into the [`run`][WorkItem::run] method from the +/// [`WorkItem`] trait. +/// +/// This trait is used when the `work_struct` field is defined using the [`Work`] helper. +/// +/// # Safety +/// +/// Implementers must ensure that [`__enqueue`] uses a `work_struct` initialized with the [`run`] +/// method of this trait as the function pointer. +/// +/// [`__enqueue`]: RawWorkItem::__enqueue +/// [`run`]: WorkItemPointer::run +pub unsafe trait WorkItemPointer<const ID: u64>: RawWorkItem<ID> { + /// Run this work item. + /// + /// # Safety + /// + /// The provided `work_struct` pointer must originate from a previous call to [`__enqueue`] + /// where the `queue_work_on` closure returned true, and the pointer must still be valid. + /// + /// [`__enqueue`]: RawWorkItem::__enqueue + unsafe extern "C" fn run(ptr: *mut bindings::work_struct); +} + +/// Defines the method that should be called when this work item is executed. +/// +/// This trait is used when the `work_struct` field is defined using the [`Work`] helper. +pub trait WorkItem<const ID: u64 = 0> { + /// The pointer type that this struct is wrapped in. This will typically be `Arc<Self>` or + /// `Pin<KBox<Self>>`. + type Pointer: WorkItemPointer<ID>; + + /// The method that should be called when this work item is executed. + fn run(this: Self::Pointer); +} + +/// Links for a work item. +/// +/// This struct contains a function pointer to the [`run`] function from the [`WorkItemPointer`] +/// trait, and defines the linked list pointers necessary to enqueue a work item in a workqueue. +/// +/// Wraps the kernel's C `struct work_struct`. +/// +/// This is a helper type used to associate a `work_struct` with the [`WorkItem`] that uses it. 
+/// +/// [`run`]: WorkItemPointer::run +#[pin_data] +#[repr(transparent)] +pub struct Work<T: ?Sized, const ID: u64 = 0> { + #[pin] + work: Opaque<bindings::work_struct>, + _inner: PhantomData<T>, +} + +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Send for Work<T, ID> {} +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Sync for Work<T, ID> {} + +impl<T: ?Sized, const ID: u64> Work<T, ID> { + /// Creates a new instance of [`Work`]. + #[inline] + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> + where + T: WorkItem<ID>, + { + pin_init!(Self { + work <- Opaque::ffi_init(|slot| { + // SAFETY: The `WorkItemPointer` implementation promises that `run` can be used as + // the work item function. + unsafe { + bindings::init_work_with_key( + slot, + Some(T::Pointer::run), + false, + name.as_char_ptr(), + key.as_ptr(), + ) + } + }), + _inner: PhantomData, + }) + } + + /// Get a pointer to the inner `work_struct`. + /// + /// # Safety + /// + /// The provided pointer must not be dangling and must be properly aligned. (But the memory + /// need not be initialized.) + #[inline] + pub unsafe fn raw_get(ptr: *const Self) -> *mut bindings::work_struct { + // SAFETY: The caller promises that the pointer is aligned and not dangling. + // + // A pointer cast would also be ok due to `#[repr(transparent)]`. We use `addr_of!` so that + // the compiler does not complain that the `work` field is unused. + unsafe { Opaque::cast_into(core::ptr::addr_of!((*ptr).work)) } + } +} + +/// Declares that a type contains a [`Work<T, ID>`]. +/// +/// The intended way of using this trait is via the [`impl_has_work!`] macro. You can use the macro +/// like this: +/// +/// ```no_run +/// use kernel::workqueue::{impl_has_work, Work}; +/// +/// struct MyWorkItem { +/// work_field: Work<MyWorkItem, 1>, +/// } +/// +/// impl_has_work! { +/// impl HasWork<MyWorkItem, 1> for MyWorkItem { self.work_field } +/// } +/// ``` +/// +/// Note that since the [`Work`] type is annotated with an id, you can have several `work_struct` +/// fields by using a different id for each one. +/// +/// # Safety +/// +/// The methods [`raw_get_work`] and [`work_container_of`] must return valid pointers and must be +/// true inverses of each other; that is, they must satisfy the following invariants: +/// - `work_container_of(raw_get_work(ptr)) == ptr` for any `ptr: *mut Self`. +/// - `raw_get_work(work_container_of(ptr)) == ptr` for any `ptr: *mut Work<T, ID>`. +/// +/// [`impl_has_work!`]: crate::impl_has_work +/// [`raw_get_work`]: HasWork::raw_get_work +/// [`work_container_of`]: HasWork::work_container_of +pub unsafe trait HasWork<T, const ID: u64 = 0> { + /// Returns a pointer to the [`Work<T, ID>`] field. + /// + /// # Safety + /// + /// The provided pointer must point at a valid struct of type `Self`. + unsafe fn raw_get_work(ptr: *mut Self) -> *mut Work<T, ID>; + + /// Returns a pointer to the struct containing the [`Work<T, ID>`] field. + /// + /// # Safety + /// + /// The pointer must point at a [`Work<T, ID>`] field in a struct of type `Self`. + unsafe fn work_container_of(ptr: *mut Work<T, ID>) -> *mut Self; +} + +/// Used to safely implement the [`HasWork<T, ID>`] trait. 
+/// +/// # Examples +/// +/// ``` +/// use kernel::sync::Arc; +/// use kernel::workqueue::{self, impl_has_work, Work}; +/// +/// struct MyStruct<'a, T, const N: usize> { +/// work_field: Work<MyStruct<'a, T, N>, 17>, +/// f: fn(&'a [T; N]), +/// } +/// +/// impl_has_work! { +/// impl{'a, T, const N: usize} HasWork<MyStruct<'a, T, N>, 17> +/// for MyStruct<'a, T, N> { self.work_field } +/// } +/// ``` +#[macro_export] +macro_rules! impl_has_work { + ($(impl$({$($generics:tt)*})? + HasWork<$work_type:ty $(, $id:tt)?> + for $self:ty + { self.$field:ident } + )*) => {$( + // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right + // type. + unsafe impl$(<$($generics)+>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self { + #[inline] + unsafe fn raw_get_work(ptr: *mut Self) -> *mut $crate::workqueue::Work<$work_type $(, $id)?> { + // SAFETY: The caller promises that the pointer is not dangling. + unsafe { + ::core::ptr::addr_of_mut!((*ptr).$field) + } + } + + #[inline] + unsafe fn work_container_of( + ptr: *mut $crate::workqueue::Work<$work_type $(, $id)?>, + ) -> *mut Self { + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + unsafe { $crate::container_of!(ptr, Self, $field) } + } + } + )*}; +} +pub use impl_has_work; + +impl_has_work! { + impl{T} HasWork<Self> for ClosureWork<T> { self.work } +} + +/// Links for a delayed work item. +/// +/// This struct contains a function pointer to the [`run`] function from the [`WorkItemPointer`] +/// trait, and defines the linked list pointers necessary to enqueue a work item in a workqueue in +/// a delayed manner. +/// +/// Wraps the kernel's C `struct delayed_work`. +/// +/// This is a helper type used to associate a `delayed_work` with the [`WorkItem`] that uses it. +/// +/// [`run`]: WorkItemPointer::run +#[pin_data] +#[repr(transparent)] +pub struct DelayedWork<T: ?Sized, const ID: u64 = 0> { + #[pin] + dwork: Opaque<bindings::delayed_work>, + _inner: PhantomData<T>, +} + +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Send for DelayedWork<T, ID> {} +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Sync for DelayedWork<T, ID> {} + +impl<T: ?Sized, const ID: u64> DelayedWork<T, ID> { + /// Creates a new instance of [`DelayedWork`]. + #[inline] + pub fn new( + work_name: &'static CStr, + work_key: Pin<&'static LockClassKey>, + timer_name: &'static CStr, + timer_key: Pin<&'static LockClassKey>, + ) -> impl PinInit<Self> + where + T: WorkItem<ID>, + { + pin_init!(Self { + dwork <- Opaque::ffi_init(|slot: *mut bindings::delayed_work| { + // SAFETY: The `WorkItemPointer` implementation promises that `run` can be used as + // the work item function. + unsafe { + bindings::init_work_with_key( + core::ptr::addr_of_mut!((*slot).work), + Some(T::Pointer::run), + false, + work_name.as_char_ptr(), + work_key.as_ptr(), + ) + } + + // SAFETY: The `delayed_work_timer_fn` function pointer can be used here because + // the timer is embedded in a `struct delayed_work`, and only ever scheduled via + // the core workqueue code, and configured to run in irqsafe context. 
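+ // (Only the timer callback is configured here; the actual delay is supplied
+ // later, when `queue_delayed_work_on` arms this timer at enqueue time.)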
+ unsafe { + bindings::timer_init_key( + core::ptr::addr_of_mut!((*slot).timer), + Some(bindings::delayed_work_timer_fn), + bindings::TIMER_IRQSAFE, + timer_name.as_char_ptr(), + timer_key.as_ptr(), + ) + } + }), + _inner: PhantomData, + }) + } + + /// Get a pointer to the inner `delayed_work`. + /// + /// # Safety + /// + /// The provided pointer must not be dangling and must be properly aligned. (But the memory + /// need not be initialized.) + #[inline] + pub unsafe fn raw_as_work(ptr: *const Self) -> *mut Work<T, ID> { + // SAFETY: The caller promises that the pointer is aligned and not dangling. + let dw: *mut bindings::delayed_work = + unsafe { Opaque::cast_into(core::ptr::addr_of!((*ptr).dwork)) }; + // SAFETY: The caller promises that the pointer is aligned and not dangling. + let wrk: *mut bindings::work_struct = unsafe { core::ptr::addr_of_mut!((*dw).work) }; + // CAST: Work and work_struct have compatible layouts. + wrk.cast() + } +} + +/// Declares that a type contains a [`DelayedWork<T, ID>`]. +/// +/// # Safety +/// +/// The `HasWork<T, ID>` implementation must return a `work_struct` that is stored in the `work` +/// field of a `delayed_work` with the same access rules as the `work_struct`. +pub unsafe trait HasDelayedWork<T, const ID: u64 = 0>: HasWork<T, ID> {} + +/// Used to safely implement the [`HasDelayedWork<T, ID>`] trait. +/// +/// This macro also implements the [`HasWork`] trait, so you do not need to use [`impl_has_work!`] +/// when using this macro. +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::Arc; +/// use kernel::workqueue::{self, impl_has_delayed_work, DelayedWork}; +/// +/// struct MyStruct<'a, T, const N: usize> { +/// work_field: DelayedWork<MyStruct<'a, T, N>, 17>, +/// f: fn(&'a [T; N]), +/// } +/// +/// impl_has_delayed_work! { +/// impl{'a, T, const N: usize} HasDelayedWork<MyStruct<'a, T, N>, 17> +/// for MyStruct<'a, T, N> { self.work_field } +/// } +/// ``` +#[macro_export] +macro_rules! impl_has_delayed_work { + ($(impl$({$($generics:tt)*})? + HasDelayedWork<$work_type:ty $(, $id:tt)?> + for $self:ty + { self.$field:ident } + )*) => {$( + // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right + // type. + unsafe impl$(<$($generics)+>)? + $crate::workqueue::HasDelayedWork<$work_type $(, $id)?> for $self {} + + // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right + // type. + unsafe impl$(<$($generics)+>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self { + #[inline] + unsafe fn raw_get_work( + ptr: *mut Self + ) -> *mut $crate::workqueue::Work<$work_type $(, $id)?> { + // SAFETY: The caller promises that the pointer is not dangling. + let ptr: *mut $crate::workqueue::DelayedWork<$work_type $(, $id)?> = unsafe { + ::core::ptr::addr_of_mut!((*ptr).$field) + }; + + // SAFETY: The caller promises that the pointer is not dangling. + unsafe { $crate::workqueue::DelayedWork::raw_as_work(ptr) } + } + + #[inline] + unsafe fn work_container_of( + ptr: *mut $crate::workqueue::Work<$work_type $(, $id)?>, + ) -> *mut Self { + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + let ptr = unsafe { $crate::workqueue::Work::raw_get(ptr) }; + + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. 
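+ // (Two hops are needed: from the inner `work_struct` to the containing
+ // `delayed_work`, and from the `DelayedWork` field back to `Self`.)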
+ let delayed_work = unsafe { + $crate::container_of!(ptr, $crate::bindings::delayed_work, work) + }; + + let delayed_work: *mut $crate::workqueue::DelayedWork<$work_type $(, $id)?> = + delayed_work.cast(); + + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + unsafe { $crate::container_of!(delayed_work, Self, $field) } + } + } + )*}; +} +pub use impl_has_delayed_work; + +// SAFETY: The `__enqueue` implementation in RawWorkItem uses a `work_struct` initialized with the +// `run` method of this trait as the function pointer because: +// - `__enqueue` gets the `work_struct` from the `Work` field, using `T::raw_get_work`. +// - The only safe way to create a `Work` object is through `Work::new`. +// - `Work::new` makes sure that `T::Pointer::run` is passed to `init_work_with_key`. +// - Finally `Work` and `RawWorkItem` guarantee that the correct `Work` field +// will be used because of the ID const generic bound. This makes sure that `T::raw_get_work` +// uses the correct offset for the `Work` field, and `Work::new` picks the correct +// implementation of `WorkItemPointer` for `Arc<T>`. +unsafe impl<T, const ID: u64> WorkItemPointer<ID> for Arc<T> +where + T: WorkItem<ID, Pointer = Self>, + T: HasWork<T, ID>, +{ + unsafe extern "C" fn run(ptr: *mut bindings::work_struct) { + // The `__enqueue` method always uses a `work_struct` stored in a `Work<T, ID>`. + let ptr = ptr.cast::<Work<T, ID>>(); + // SAFETY: This computes the pointer that `__enqueue` got from `Arc::into_raw`. + let ptr = unsafe { T::work_container_of(ptr) }; + // SAFETY: This pointer comes from `Arc::into_raw` and we've been given back ownership. + let arc = unsafe { Arc::from_raw(ptr) }; + + T::run(arc) + } +} + +// SAFETY: The `work_struct` raw pointer is guaranteed to be valid for the duration of the call to +// the closure because we get it from an `Arc`, which means that the ref count will be at least 1, +// and we don't drop the `Arc` ourselves. If `queue_work_on` returns true, it is further guaranteed +// to be valid until a call to the function pointer in `work_struct` because we leak the memory it +// points to, and only reclaim it if the closure returns false, or in `WorkItemPointer::run`, which +// is what the function pointer in the `work_struct` must be pointing to, according to the safety +// requirements of `WorkItemPointer`. +unsafe impl<T, const ID: u64> RawWorkItem<ID> for Arc<T> +where + T: WorkItem<ID, Pointer = Self>, + T: HasWork<T, ID>, +{ + type EnqueueOutput = Result<(), Self>; + + unsafe fn __enqueue<F>(self, queue_work_on: F) -> Self::EnqueueOutput + where + F: FnOnce(*mut bindings::work_struct) -> bool, + { + // Casting between const and mut is not a problem as long as the pointer is a raw pointer. + let ptr = Arc::into_raw(self).cast_mut(); + + // SAFETY: Pointers into an `Arc` point at a valid value. + let work_ptr = unsafe { T::raw_get_work(ptr) }; + // SAFETY: `raw_get_work` returns a pointer to a valid value. + let work_ptr = unsafe { Work::raw_get(work_ptr) }; + + if queue_work_on(work_ptr) { + Ok(()) + } else { + // SAFETY: The work queue has not taken ownership of the pointer. + Err(unsafe { Arc::from_raw(ptr) }) + } + } +} + +// SAFETY: By the safety requirements of `HasDelayedWork`, the `work_struct` returned by methods in +// `HasWork` provides a `work_struct` that is the `work` field of a `delayed_work`, and the rest of +// the `delayed_work` has the same access rules as its `work` field. 
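+// (The impl blocks below are empty: `RawDelayedWorkItem` is a marker trait, so
+// there are no methods to provide.)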
+unsafe impl<T, const ID: u64> RawDelayedWorkItem<ID> for Arc<T>
+where
+ T: WorkItem<ID, Pointer = Self>,
+ T: HasDelayedWork<T, ID>,
+{
+}
+
+// SAFETY: The `__enqueue` implementation uses a `work_struct` initialized with the `run` method
+// of this trait as the function pointer, for the same reasons as the `Arc<T>` implementation
+// above: the only safe way to create a `Work` is `Work::new`, which passes `T::Pointer::run` to
+// `init_work_with_key`, and the `ID` const generic ensures that the matching `Work` field and
+// `WorkItemPointer` implementation are selected.
+unsafe impl<T, const ID: u64> WorkItemPointer<ID> for Pin<KBox<T>>
+where
+ T: WorkItem<ID, Pointer = Self>,
+ T: HasWork<T, ID>,
+{
+ unsafe extern "C" fn run(ptr: *mut bindings::work_struct) {
+ // The `__enqueue` method always uses a `work_struct` stored in a `Work<T, ID>`.
+ let ptr = ptr.cast::<Work<T, ID>>();
+ // SAFETY: This computes the pointer that `__enqueue` got from `KBox::into_raw`.
+ let ptr = unsafe { T::work_container_of(ptr) };
+ // SAFETY: This pointer comes from `KBox::into_raw` and we've been given back ownership.
+ let boxed = unsafe { KBox::from_raw(ptr) };
+ // SAFETY: The box was already pinned when it was enqueued.
+ let pinned = unsafe { Pin::new_unchecked(boxed) };
+
+ T::run(pinned)
+ }
+}
+
+// SAFETY: The `work_struct` raw pointer is guaranteed to be valid for the duration of the call to
+// the closure because we leak the `KBox` before the call and only reclaim it in
+// `WorkItemPointer::run`, which is what the function pointer in the `work_struct` must point to
+// according to the safety requirements of `WorkItemPointer`. (Unlike the `Arc<T>` case, the
+// closure can never return false here, since `__enqueue` takes exclusive ownership of the box, so
+// it cannot already be in a workqueue.)
+unsafe impl<T, const ID: u64> RawWorkItem<ID> for Pin<KBox<T>>
+where
+ T: WorkItem<ID, Pointer = Self>,
+ T: HasWork<T, ID>,
+{
+ type EnqueueOutput = ();
+
+ unsafe fn __enqueue<F>(self, queue_work_on: F) -> Self::EnqueueOutput
+ where
+ F: FnOnce(*mut bindings::work_struct) -> bool,
+ {
+ // SAFETY: We're not going to move `self` or any of its fields, so it's okay to temporarily
+ // remove the `Pin` wrapper.
+ let boxed = unsafe { Pin::into_inner_unchecked(self) };
+ let ptr = KBox::into_raw(boxed);
+
+ // SAFETY: Pointers into a `KBox` point at a valid value.
+ let work_ptr = unsafe { T::raw_get_work(ptr) };
+ // SAFETY: `raw_get_work` returns a pointer to a valid value.
+ let work_ptr = unsafe { Work::raw_get(work_ptr) };
+
+ if !queue_work_on(work_ptr) {
+ // SAFETY: This method requires exclusive ownership of the box, so it cannot be in a
+ // workqueue.
+ unsafe { ::core::hint::unreachable_unchecked() }
+ }
+ }
+}
+
+// SAFETY: By the safety requirements of `HasDelayedWork`, the `work_struct` returned by methods in
+// `HasWork` provides a `work_struct` that is the `work` field of a `delayed_work`, and the rest of
+// the `delayed_work` has the same access rules as its `work` field.
+unsafe impl<T, const ID: u64> RawDelayedWorkItem<ID> for Pin<KBox<T>>
+where
+ T: WorkItem<ID, Pointer = Self>,
+ T: HasDelayedWork<T, ID>,
+{
+}
+
+/// Returns the system work queue (`system_wq`).
+///
+/// It is the one used by `schedule[_delayed]_work[_on]()`. Multi-CPU multi-threaded. There are
+/// users which expect relatively short queue flush time.
+///
+/// Callers shouldn't queue work items which can run for too long.
+pub fn system() -> &'static Queue {
+ // SAFETY: `system_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_wq) }
+}
+
+/// Returns the system high-priority work queue (`system_highpri_wq`).
+///
+/// It is similar to the one returned by [`system`] but for work items which require higher
+/// scheduling priority.
+pub fn system_highpri() -> &'static Queue {
+ // SAFETY: `system_highpri_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_highpri_wq) }
+}
+
+/// Returns the system work queue for potentially long-running work items (`system_long_wq`).
+///
+/// It is similar to the one returned by [`system`] but may host long running work items. Queue
+/// flushing might take relatively long.
+pub fn system_long() -> &'static Queue {
+ // SAFETY: `system_long_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_long_wq) }
+}
+
+/// Returns the system unbound work queue (`system_unbound_wq`).
+///
+/// Workers are not bound to any specific CPU, are not concurrency managed, and all queued work
+/// items are executed immediately as long as the `max_active` limit is not reached and resources
+/// are available.
+pub fn system_unbound() -> &'static Queue {
+ // SAFETY: `system_unbound_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_unbound_wq) }
+}
+
+/// Returns the system freezable work queue (`system_freezable_wq`).
+///
+/// It is equivalent to the one returned by [`system`] except that it's freezable.
+///
+/// A freezable workqueue participates in the freeze phase of the system suspend operations. Work
+/// items on the workqueue are drained and no new work item starts execution until thawed.
+pub fn system_freezable() -> &'static Queue {
+ // SAFETY: `system_freezable_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_freezable_wq) }
+}
+
+/// Returns the system power-efficient work queue (`system_power_efficient_wq`).
+///
+/// It is inclined towards saving power and is converted to "unbound" variants if the
+/// `workqueue.power_efficient` kernel parameter is specified; otherwise, it is similar to the one
+/// returned by [`system`].
+pub fn system_power_efficient() -> &'static Queue {
+ // SAFETY: `system_power_efficient_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_power_efficient_wq) }
+}
+
+/// Returns the system freezable power-efficient work queue (`system_freezable_power_efficient_wq`).
+///
+/// It is similar to the one returned by [`system_power_efficient`] except that it is freezable.
+///
+/// A freezable workqueue participates in the freeze phase of the system suspend operations. Work
+/// items on the workqueue are drained and no new work item starts execution until thawed.
+pub fn system_freezable_power_efficient() -> &'static Queue {
+ // SAFETY: `system_freezable_power_efficient_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_freezable_power_efficient_wq) }
+}
+
+/// Returns the system bottom halves work queue (`system_bh_wq`).
+///
+/// It is similar to the one returned by [`system`] but for work items which
+/// need to run from a softirq context.
+pub fn system_bh() -> &'static Queue {
+ // SAFETY: `system_bh_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_bh_wq) }
+}
+
+/// Returns the system bottom halves high-priority work queue (`system_bh_highpri_wq`).
+///
+/// It is similar to the one returned by [`system_bh`] but for work items which
+/// require higher scheduling priority.
+pub fn system_bh_highpri() -> &'static Queue {
+ // SAFETY: `system_bh_highpri_wq` is a C global, always available.
+ unsafe { Queue::from_raw(bindings::system_bh_highpri_wq) }
+}
diff --git a/rust/kernel/xarray.rs b/rust/kernel/xarray.rs
new file mode 100644
index 000000000000..a49d6db28845
--- /dev/null
+++ b/rust/kernel/xarray.rs
@@ -0,0 +1,276 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! XArray abstraction.
+//!
+//! C header: [`include/linux/xarray.h`](srctree/include/linux/xarray.h)
+
+use crate::{
+ alloc, bindings, build_assert,
+ error::{Error, Result},
+ ffi::c_void,
+ types::{ForeignOwnable, NotThreadSafe, Opaque},
+};
+use core::{iter, marker::PhantomData, pin::Pin, ptr::NonNull};
+use pin_init::{pin_data, pin_init, pinned_drop, PinInit};
+
+/// An array which efficiently maps sparse integer indices to owned objects.
+/// +/// This is similar to a [`crate::alloc::kvec::Vec<Option<T>>`], but more efficient when there are +/// holes in the index space, and can be efficiently grown. +/// +/// # Invariants +/// +/// `self.xa` is always an initialized and valid [`bindings::xarray`] whose entries are either +/// `XA_ZERO_ENTRY` or came from `T::into_foreign`. +/// +/// # Examples +/// +/// ```rust +/// use kernel::alloc::KBox; +/// use kernel::xarray::{AllocKind, XArray}; +/// +/// let xa = KBox::pin_init(XArray::new(AllocKind::Alloc1), GFP_KERNEL)?; +/// +/// let dead = KBox::new(0xdead, GFP_KERNEL)?; +/// let beef = KBox::new(0xbeef, GFP_KERNEL)?; +/// +/// let mut guard = xa.lock(); +/// +/// assert_eq!(guard.get(0), None); +/// +/// assert_eq!(guard.store(0, dead, GFP_KERNEL)?.as_deref(), None); +/// assert_eq!(guard.get(0).copied(), Some(0xdead)); +/// +/// *guard.get_mut(0).unwrap() = 0xffff; +/// assert_eq!(guard.get(0).copied(), Some(0xffff)); +/// +/// assert_eq!(guard.store(0, beef, GFP_KERNEL)?.as_deref().copied(), Some(0xffff)); +/// assert_eq!(guard.get(0).copied(), Some(0xbeef)); +/// +/// guard.remove(0); +/// assert_eq!(guard.get(0), None); +/// +/// # Ok::<(), Error>(()) +/// ``` +#[pin_data(PinnedDrop)] +pub struct XArray<T: ForeignOwnable> { + #[pin] + xa: Opaque<bindings::xarray>, + _p: PhantomData<T>, +} + +#[pinned_drop] +impl<T: ForeignOwnable> PinnedDrop for XArray<T> { + fn drop(self: Pin<&mut Self>) { + self.iter().for_each(|ptr| { + let ptr = ptr.as_ptr(); + // SAFETY: `ptr` came from `T::into_foreign`. + // + // INVARIANT: we own the only reference to the array which is being dropped so the + // broken invariant is not observable on function exit. + drop(unsafe { T::from_foreign(ptr) }) + }); + + // SAFETY: `self.xa` is always valid by the type invariant. + unsafe { bindings::xa_destroy(self.xa.get()) }; + } +} + +/// Flags passed to [`XArray::new`] to configure the array's allocation tracking behavior. +pub enum AllocKind { + /// Consider the first element to be at index 0. + Alloc, + /// Consider the first element to be at index 1. + Alloc1, +} + +impl<T: ForeignOwnable> XArray<T> { + /// Creates a new initializer for this type. + pub fn new(kind: AllocKind) -> impl PinInit<Self> { + let flags = match kind { + AllocKind::Alloc => bindings::XA_FLAGS_ALLOC, + AllocKind::Alloc1 => bindings::XA_FLAGS_ALLOC1, + }; + pin_init!(Self { + // SAFETY: `xa` is valid while the closure is called. + // + // INVARIANT: `xa` is initialized here to an empty, valid [`bindings::xarray`]. + xa <- Opaque::ffi_init(|xa| unsafe { + bindings::xa_init_flags(xa, flags) + }), + _p: PhantomData, + }) + } + + fn iter(&self) -> impl Iterator<Item = NonNull<c_void>> + '_ { + let mut index = 0; + + // SAFETY: `self.xa` is always valid by the type invariant. + iter::once(unsafe { + bindings::xa_find(self.xa.get(), &mut index, usize::MAX, bindings::XA_PRESENT) + }) + .chain(iter::from_fn(move || { + // SAFETY: `self.xa` is always valid by the type invariant. + Some(unsafe { + bindings::xa_find_after(self.xa.get(), &mut index, usize::MAX, bindings::XA_PRESENT) + }) + })) + .map_while(|ptr| NonNull::new(ptr.cast())) + } + + /// Attempts to lock the [`XArray`] for exclusive access. + pub fn try_lock(&self) -> Option<Guard<'_, T>> { + // SAFETY: `self.xa` is always valid by the type invariant. + if (unsafe { bindings::xa_trylock(self.xa.get()) } != 0) { + Some(Guard { + xa: self, + _not_send: NotThreadSafe, + }) + } else { + None + } + } + + /// Locks the [`XArray`] for exclusive access. 
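+ ///
+ /// A short usage sketch (see the type-level example above for the full setup):
+ ///
+ /// ```ignore
+ /// let mut guard = xa.lock();
+ /// let value = guard.get(0);
+ /// // The lock is released when `guard` is dropped.
+ /// ```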
+ pub fn lock(&self) -> Guard<'_, T> { + // SAFETY: `self.xa` is always valid by the type invariant. + unsafe { bindings::xa_lock(self.xa.get()) }; + + Guard { + xa: self, + _not_send: NotThreadSafe, + } + } +} + +/// A lock guard. +/// +/// The lock is unlocked when the guard goes out of scope. +#[must_use = "the lock unlocks immediately when the guard is unused"] +pub struct Guard<'a, T: ForeignOwnable> { + xa: &'a XArray<T>, + _not_send: NotThreadSafe, +} + +impl<T: ForeignOwnable> Drop for Guard<'_, T> { + fn drop(&mut self) { + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock, so it is safe to unlock it. + unsafe { bindings::xa_unlock(self.xa.xa.get()) }; + } +} + +/// The error returned by [`store`](Guard::store). +/// +/// Contains the underlying error and the value that was not stored. +pub struct StoreError<T> { + /// The error that occurred. + pub error: Error, + /// The value that was not stored. + pub value: T, +} + +impl<T> From<StoreError<T>> for Error { + fn from(value: StoreError<T>) -> Self { + value.error + } +} + +impl<'a, T: ForeignOwnable> Guard<'a, T> { + fn load<F, U>(&self, index: usize, f: F) -> Option<U> + where + F: FnOnce(NonNull<c_void>) -> U, + { + // SAFETY: `self.xa.xa` is always valid by the type invariant. + let ptr = unsafe { bindings::xa_load(self.xa.xa.get(), index) }; + let ptr = NonNull::new(ptr.cast())?; + Some(f(ptr)) + } + + /// Provides a reference to the element at the given index. + pub fn get(&self, index: usize) -> Option<T::Borrowed<'_>> { + self.load(index, |ptr| { + // SAFETY: `ptr` came from `T::into_foreign`. + unsafe { T::borrow(ptr.as_ptr()) } + }) + } + + /// Provides a mutable reference to the element at the given index. + pub fn get_mut(&mut self, index: usize) -> Option<T::BorrowedMut<'_>> { + self.load(index, |ptr| { + // SAFETY: `ptr` came from `T::into_foreign`. + unsafe { T::borrow_mut(ptr.as_ptr()) } + }) + } + + /// Removes and returns the element at the given index. + pub fn remove(&mut self, index: usize) -> Option<T> { + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock. + let ptr = unsafe { bindings::__xa_erase(self.xa.xa.get(), index) }.cast(); + // SAFETY: + // - `ptr` is either NULL or came from `T::into_foreign`. + // - `&mut self` guarantees that the lifetimes of [`T::Borrowed`] and [`T::BorrowedMut`] + // borrowed from `self` have ended. + unsafe { T::try_from_foreign(ptr) } + } + + /// Stores an element at the given index. + /// + /// May drop the lock if needed to allocate memory, and then reacquire it afterwards. + /// + /// On success, returns the element which was previously at the given index. + /// + /// On failure, returns the element which was attempted to be stored. + pub fn store( + &mut self, + index: usize, + value: T, + gfp: alloc::Flags, + ) -> Result<Option<T>, StoreError<T>> { + build_assert!( + T::FOREIGN_ALIGN >= 4, + "pointers stored in XArray must be 4-byte aligned" + ); + let new = value.into_foreign(); + + let old = { + let new = new.cast(); + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock. + // + // INVARIANT: `new` came from `T::into_foreign`. + unsafe { bindings::__xa_store(self.xa.xa.get(), index, new, gfp.as_raw()) } + }; + + // SAFETY: `__xa_store` returns the old entry at this index on success or `xa_err` if an + // error happened. 
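+ // (A non-zero `errno` therefore means `old` encodes an error value rather
+ // than a previous entry.)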
+ let errno = unsafe { bindings::xa_err(old) };
+ if errno != 0 {
+ // SAFETY: `new` came from `T::into_foreign` and `__xa_store` does not take
+ // ownership of the value on error.
+ let value = unsafe { T::from_foreign(new) };
+ Err(StoreError {
+ value,
+ error: Error::from_errno(errno),
+ })
+ } else {
+ let old = old.cast();
+ // SAFETY: `old` is either NULL or came from `T::into_foreign`.
+ //
+ // NB: `XA_ZERO_ENTRY` is never returned by functions belonging to the Normal XArray
+ // API; such entries present as `NULL`.
+ Ok(unsafe { T::try_from_foreign(old) })
+ }
+ }
+}
+
+// SAFETY: `XArray<T>` has no shared mutable state so it is `Send` iff `T` is `Send`.
+unsafe impl<T: ForeignOwnable + Send> Send for XArray<T> {}
+
+// SAFETY: `XArray<T>` serialises the interior mutability it provides so it is `Sync` iff `T` is
+// `Send`.
+unsafe impl<T: ForeignOwnable + Send> Sync for XArray<T> {}
