From 7f8977a7fe6df9cdfe489c641058ca5534ec73eb Mon Sep 17 00:00:00 2001 From: Benno Lossin Date: Mon, 14 Aug 2023 08:47:48 +0000 Subject: rust: init: add `{pin_}chain` functions to `{Pin}Init` The `{pin_}chain` functions extend an initializer: it not only initializes the value, but also executes a closure taking a reference to the initialized value. This allows to do something with a value directly after initialization. Suggested-by: Asahi Lina Reviewed-by: Martin Rodriguez Reboredo Signed-off-by: Benno Lossin Reviewed-by: Alice Ryhl Link: https://lore.kernel.org/r/20230814084602.25699-13-benno.lossin@proton.me [ Cleaned a few trivial nits. ] Signed-off-by: Miguel Ojeda --- rust/kernel/init.rs | 142 +++++++++++++++++++++++++++++++++++++++++ rust/kernel/init/__internal.rs | 2 +- 2 files changed, 143 insertions(+), 1 deletion(-) (limited to 'rust') diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 0f9c8576697e..0071b2834b78 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -767,6 +767,79 @@ pub unsafe trait PinInit: Sized { /// deallocate. /// - `slot` will not move until it is dropped, i.e. it will be pinned. unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E>; + + /// First initializes the value using `self` then calls the function `f` with the initialized + /// value. + /// + /// If `f` returns an error the value is dropped and the initializer will forward the error. + /// + /// # Examples + /// + /// ```rust + /// # #![allow(clippy::disallowed_names)] + /// use kernel::{types::Opaque, init::pin_init_from_closure}; + /// #[repr(C)] + /// struct RawFoo([u8; 16]); + /// extern { + /// fn init_foo(_: *mut RawFoo); + /// } + /// + /// #[pin_data] + /// struct Foo { + /// #[pin] + /// raw: Opaque, + /// } + /// + /// impl Foo { + /// fn setup(self: Pin<&mut Self>) { + /// pr_info!("Setting up foo"); + /// } + /// } + /// + /// let foo = pin_init!(Foo { + /// raw <- unsafe { + /// Opaque::ffi_init(|s| { + /// init_foo(s); + /// }) + /// }, + /// }).pin_chain(|foo| { + /// foo.setup(); + /// Ok(()) + /// }); + /// ``` + fn pin_chain(self, f: F) -> ChainPinInit + where + F: FnOnce(Pin<&mut T>) -> Result<(), E>, + { + ChainPinInit(self, f, PhantomData) + } +} + +/// An initializer returned by [`PinInit::pin_chain`]. +pub struct ChainPinInit(I, F, __internal::Invariant<(E, Box)>); + +// SAFETY: The `__pinned_init` function is implemented such that it +// - returns `Ok(())` on successful initialization, +// - returns `Err(err)` on error and in this case `slot` will be dropped. +// - considers `slot` pinned. +unsafe impl PinInit for ChainPinInit +where + I: PinInit, + F: FnOnce(Pin<&mut T>) -> Result<(), E>, +{ + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: All requirements fulfilled since this function is `__pinned_init`. + unsafe { self.0.__pinned_init(slot)? }; + // SAFETY: The above call initialized `slot` and we still have unique access. + let val = unsafe { &mut *slot }; + // SAFETY: `slot` is considered pinned. + let val = unsafe { Pin::new_unchecked(val) }; + (self.1)(val).map_err(|e| { + // SAFETY: `slot` was initialized above. + unsafe { core::ptr::drop_in_place(slot) }; + e + }) + } } /// An initializer for `T`. @@ -808,6 +881,75 @@ pub unsafe trait Init: PinInit { /// - the caller does not touch `slot` when `Err` is returned, they are only permitted to /// deallocate. unsafe fn __init(self, slot: *mut T) -> Result<(), E>; + + /// First initializes the value using `self` then calls the function `f` with the initialized + /// value. + /// + /// If `f` returns an error the value is dropped and the initializer will forward the error. + /// + /// # Examples + /// + /// ```rust + /// # #![allow(clippy::disallowed_names)] + /// use kernel::{types::Opaque, init::{self, init_from_closure}}; + /// struct Foo { + /// buf: [u8; 1_000_000], + /// } + /// + /// impl Foo { + /// fn setup(&mut self) { + /// pr_info!("Setting up foo"); + /// } + /// } + /// + /// let foo = init!(Foo { + /// buf <- init::zeroed() + /// }).chain(|foo| { + /// foo.setup(); + /// Ok(()) + /// }); + /// ``` + fn chain(self, f: F) -> ChainInit + where + F: FnOnce(&mut T) -> Result<(), E>, + { + ChainInit(self, f, PhantomData) + } +} + +/// An initializer returned by [`Init::chain`]. +pub struct ChainInit(I, F, __internal::Invariant<(E, Box)>); + +// SAFETY: The `__init` function is implemented such that it +// - returns `Ok(())` on successful initialization, +// - returns `Err(err)` on error and in this case `slot` will be dropped. +unsafe impl Init for ChainInit +where + I: Init, + F: FnOnce(&mut T) -> Result<(), E>, +{ + unsafe fn __init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: All requirements fulfilled since this function is `__init`. + unsafe { self.0.__pinned_init(slot)? }; + // SAFETY: The above call initialized `slot` and we still have unique access. + (self.1)(unsafe { &mut *slot }).map_err(|e| { + // SAFETY: `slot` was initialized above. + unsafe { core::ptr::drop_in_place(slot) }; + e + }) + } +} + +// SAFETY: `__pinned_init` behaves exactly the same as `__init`. +unsafe impl PinInit for ChainInit +where + I: Init, + F: FnOnce(&mut T) -> Result<(), E>, +{ + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: `__init` has less strict requirements compared to `__pinned_init`. + unsafe { self.__init(slot) } + } } /// Creates a new [`PinInit`] from the given closure. diff --git a/rust/kernel/init/__internal.rs b/rust/kernel/init/__internal.rs index 12e195061525..db3372619ecd 100644 --- a/rust/kernel/init/__internal.rs +++ b/rust/kernel/init/__internal.rs @@ -13,7 +13,7 @@ use super::*; /// /// [nomicon]: https://doc.rust-lang.org/nomicon/subtyping.html /// [this table]: https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns -type Invariant = PhantomData *mut T>; +pub(super) type Invariant = PhantomData *mut T>; /// This is the module-internal type implementing `PinInit` and `Init`. It is unsafe to create this /// type, since the closure needs to fulfill the same safety requirement as the -- cgit