author     Linus Torvalds <torvalds@linux-foundation.org>   2025-03-30 17:03:26 -0700
committer  Linus Torvalds <torvalds@linux-foundation.org>   2025-03-30 17:03:26 -0700
commit     4e82c87058f45e79eeaa4d5bcc3b38dd3dce7209 (patch)
tree       122868ae62bfff4d0ed9f13c853c1c9690dbe0f3 /rust/pin-init/examples
parent     01d5b167dc230cf3b6eb9dd7205f6a705026d1ce (diff)
parent     e6ea10d5dbe082c54add289b44f08c9fcfe658af (diff)
Merge tag 'rust-6.15' of git://git.kernel.org/pub/scm/linux/kernel/git/ojeda/linux
Pull Rust updates from Miguel Ojeda:
"Toolchain and infrastructure:
- Extract the 'pin-init' API from the 'kernel' crate and make it into
a standalone crate.
In order to do this, the contents are rearranged so that they can
easily be kept in sync with the version maintained out-of-tree that
other projects have started to use too (or plan to, like QEMU).
This will reduce the maintenance burden for Benno, who will now
have his own sub-tree, and will simplify future expected changes
like the move to use 'syn' to simplify the implementation.
- Add '#[test]'-like support based on KUnit.
We already had doctests support based on KUnit, which takes the
examples in our Rust documentation and runs them under KUnit.
Now, we are adding the beginning of the support for "normal" tests,
similar to the '#[test]' tests in userspace Rust. For instance:
#[kunit_tests(my_suite)]
mod tests {
    #[test]
    fn my_test() {
        assert_eq!(1 + 1, 2);
    }
}
Unlike with doctests, the 'assert*!'s do not map to the KUnit
assertion APIs yet.
- Check Rust signatures at compile time for functions called from C
by name.
In particular, introduce a new '#[export]' macro that can be placed
in the Rust function definition. It will ensure that the function
declaration on the C side matches the signature on the Rust
function:
#[export]
pub unsafe extern "C" fn my_function(a: u8, b: i32) -> usize {
    // ...
}
The macro essentially forces the compiler to compare the types of
the actual Rust function and the 'bindgen'-processed C signature.
These cases are rare so far. In the future, we may consider
introducing another tool, 'cbindgen', to generate C headers
automatically. Even then, having these functions explicitly marked
may be a good idea anyway.
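As a rough sketch of the underlying idea (userspace-only, with a
hypothetical 'BindgenMyFunction' alias standing in for the
'bindgen'-generated type), the check amounts to coercing the Rust
function to the function-pointer type derived from the C declaration,
so any signature drift fails to compile:

// Sketch only: not the macro's actual expansion.
unsafe extern "C" fn my_function(a: u8, b: i32) -> usize {
    a as usize + b as usize
}

// Hypothetical type as 'bindgen' would derive it from the C declaration.
type BindgenMyFunction = unsafe extern "C" fn(a: u8, b: i32) -> usize;

// If the Rust signature ever diverges from the C one, this stops compiling.
const _: BindgenMyFunction = my_function;

fn main() {
    // SAFETY: this sketch function has no safety requirements.
    println!("{}", unsafe { my_function(1, 2) });
}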
- Enable the 'raw_ref_op' Rust feature: it is already stable, and
allows us to use the new '&raw' syntax, avoiding a couple macros.
After everyone has migrated, we will disallow the macros.
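For illustration, a minimal userspace sketch (plain Rust, not kernel
code) of what the stabilized syntax replaces:

use core::ptr::addr_of_mut;

fn main() {
    let mut x = 42u32;
    // Old spelling via the macro:
    let p_old: *mut u32 = addr_of_mut!(x);
    // SAFETY: `x` is live and not accessed through references here.
    unsafe { *p_old += 1 };
    // New spelling enabled by `raw_ref_op`:
    let p_new: *mut u32 = &raw mut x;
    // SAFETY: same as above.
    unsafe { *p_new += 1 };
    assert_eq!(x, 44);
}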
- Pass the correct target to 'bindgen' on Usermode Linux.
- Fix the 'rusttest' build on macOS.
'kernel' crate:
- New 'hrtimer' module: add support for setting up intrusive timers
without allocating when starting the timer. Add support for
'Pin<Box<_>>', 'Arc<_>', 'Pin<&_>' and 'Pin<&mut _>' as pointer
types for use with timer callbacks. Add support for setting clock
source and timer mode.
- New 'dma' module: add a simple DMA coherent allocator abstraction
and a test sample driver.
- 'list' module: make the linked list 'Cursor' point between
elements, rather than at an element, which is more convenient to use
and allows for cursors to empty lists; and document it with
examples of how to perform common operations with the provided
methods.
- 'str' module: implement a few traits for 'BStr' as well as the
'strip_prefix()' method.
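Since 'BStr' wraps a byte slice, a rough userspace analogue of the new
method (plain 'std' code, not the kernel API) is 'strip_prefix' on
'[u8]':

fn main() {
    let s: &[u8] = b"rust/pin-init/examples";
    // Some(rest) when the prefix matches...
    assert_eq!(s.strip_prefix(b"rust/"), Some(b"pin-init/examples".as_slice()));
    // ...and None when it does not.
    assert_eq!(s.strip_prefix(b"kernel/"), None);
}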
- 'sync' module: add 'Arc::as_ptr'.
- 'alloc' module: add 'Box::into_pin'.
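Both mirror existing 'std' APIs; a small userspace sketch of the
equivalents (not the kernel types themselves):

use std::pin::Pin;
use std::sync::Arc;

fn main() {
    // Arc::as_ptr: raw pointer to the value, without touching the refcount.
    let a = Arc::new(5u32);
    let p: *const u32 = Arc::as_ptr(&a);
    // SAFETY: `a` is still alive, so the pointer is valid for reads.
    assert_eq!(unsafe { *p }, 5);

    // Box::into_pin: convert an owned Box into a pinned Box.
    let b: Pin<Box<u32>> = Box::into_pin(Box::new(7u32));
    assert_eq!(*b, 7);
}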
- 'error' module: extend the 'Result' documentation, including a few
examples on different ways of handling errors, a warning about
using methods that may panic, and links to external documentation.
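As a generic illustration of those patterns (plain Rust with a made-up
error type, not the kernel's 'Result' alias): propagate with '?', map
the error, or handle it locally instead of calling a method that may
panic such as 'unwrap()':

#[derive(Debug)]
struct MyError;

fn parse(input: &str) -> Result<u32, MyError> {
    // Map the library error into the local error type instead of unwrapping.
    input.trim().parse::<u32>().map_err(|_| MyError)
}

fn double(input: &str) -> Result<u32, MyError> {
    // Propagate the error to the caller with `?`.
    Ok(parse(input)? * 2)
}

fn main() {
    // Handle the error locally rather than panicking.
    match double(" 21 ") {
        Ok(v) => println!("ok: {v}"),
        Err(e) => println!("failed: {e:?}"),
    }
    assert!(double("not a number").is_err());
}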
'macros' crate:
- 'module' macro: add the 'authors' key to support multiple authors.
The original key will be kept until everyone has migrated.
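For instance, a module declaration could now list several authors
(illustrative sketch; the surrounding keys follow the usual 'module!'
layout and are placeholders):

module! {
    type: MyModule,
    name: "my_module",
    authors: ["First Author", "Second Author"],
    description: "Example module with multiple authors",
    license: "GPL",
}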
Documentation:
- Add error handling sections.
MAINTAINERS:
- Add Danilo Krummrich as reviewer of the Rust "subsystem".
- Add 'RUST [PIN-INIT]' entry with Benno Lossin as maintainer. It has
its own sub-tree.
- Add sub-tree for 'RUST [ALLOC]'.
- Add 'DMA MAPPING HELPERS DEVICE DRIVER API [RUST]' entry with
Abdiel Janulgue as primary maintainer. It will go through the
sub-tree of the 'RUST [ALLOC]' entry.
- Add 'HIGH-RESOLUTION TIMERS [RUST]' entry with Andreas Hindborg as
maintainer. It has its own sub-tree.
And a few other cleanups and improvements"
* tag 'rust-6.15' of git://git.kernel.org/pub/scm/linux/kernel/git/ojeda/linux: (71 commits)
rust: dma: add `Send` implementation for `CoherentAllocation`
rust: macros: fix `make rusttest` build on macOS
rust: block: refactor to use `&raw mut`
rust: enable `raw_ref_op` feature
rust: uaccess: name the correct function
rust: rbtree: fix comments referring to Box instead of KBox
rust: hrtimer: add maintainer entry
rust: hrtimer: add clocksource selection through `ClockId`
rust: hrtimer: add `HrTimerMode`
rust: hrtimer: implement `HrTimerPointer` for `Pin<Box<T>>`
rust: alloc: add `Box::into_pin`
rust: hrtimer: implement `UnsafeHrTimerPointer` for `Pin<&mut T>`
rust: hrtimer: implement `UnsafeHrTimerPointer` for `Pin<&T>`
rust: hrtimer: add `hrtimer::ScopedHrTimerPointer`
rust: hrtimer: add `UnsafeHrTimerPointer`
rust: hrtimer: allow timer restart from timer handler
rust: str: implement `strip_prefix` for `BStr`
rust: str: implement `AsRef<BStr>` for `[u8]` and `BStr`
rust: str: implement `Index` for `BStr`
rust: str: implement `PartialEq` for `BStr`
...
Diffstat (limited to 'rust/pin-init/examples')
-rw-r--r--  rust/pin-init/examples/big_struct_in_place.rs   39
-rw-r--r--  rust/pin-init/examples/error.rs                 27
-rw-r--r--  rust/pin-init/examples/linked_list.rs          161
-rw-r--r--  rust/pin-init/examples/mutex.rs                209
-rw-r--r--  rust/pin-init/examples/pthread_mutex.rs        178
-rw-r--r--  rust/pin-init/examples/static_init.rs          122
6 files changed, 736 insertions, 0 deletions
diff --git a/rust/pin-init/examples/big_struct_in_place.rs b/rust/pin-init/examples/big_struct_in_place.rs
new file mode 100644
index 000000000000..30d44a334ffd
--- /dev/null
+++ b/rust/pin-init/examples/big_struct_in_place.rs
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+use pin_init::*;
+
+// Struct with size over 1GiB
+#[derive(Debug)]
+pub struct BigStruct {
+    buf: [u8; 1024 * 1024 * 1024],
+    a: u64,
+    b: u64,
+    c: u64,
+    d: u64,
+    managed_buf: ManagedBuf,
+}
+
+#[derive(Debug)]
+pub struct ManagedBuf {
+    buf: [u8; 1024 * 1024],
+}
+
+impl ManagedBuf {
+    pub fn new() -> impl Init<Self> {
+        init!(ManagedBuf { buf <- zeroed() })
+    }
+}
+
+fn main() {
+    // we want to initialize the struct in-place, otherwise we would get a stackoverflow
+    let buf: Box<BigStruct> = Box::init(init!(BigStruct {
+        buf <- zeroed(),
+        a: 7,
+        b: 186,
+        c: 7789,
+        d: 34,
+        managed_buf <- ManagedBuf::new(),
+    }))
+    .unwrap();
+    println!("{}", core::mem::size_of_val(&*buf));
+}
diff --git a/rust/pin-init/examples/error.rs b/rust/pin-init/examples/error.rs
new file mode 100644
index 000000000000..e0cc258746ce
--- /dev/null
+++ b/rust/pin-init/examples/error.rs
@@ -0,0 +1,27 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+
+use core::convert::Infallible;
+
+#[cfg(feature = "alloc")]
+use std::alloc::AllocError;
+
+#[derive(Debug)]
+pub struct Error;
+
+impl From<Infallible> for Error {
+    fn from(e: Infallible) -> Self {
+        match e {}
+    }
+}
+
+#[cfg(feature = "alloc")]
+impl From<AllocError> for Error {
+    fn from(_: AllocError) -> Self {
+        Self
+    }
+}
+
+#[allow(dead_code)]
+fn main() {}
diff --git a/rust/pin-init/examples/linked_list.rs b/rust/pin-init/examples/linked_list.rs
new file mode 100644
index 000000000000..6d7eb0a0ec0d
--- /dev/null
+++ b/rust/pin-init/examples/linked_list.rs
@@ -0,0 +1,161 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+
+use core::{
+    cell::Cell,
+    convert::Infallible,
+    marker::PhantomPinned,
+    pin::Pin,
+    ptr::{self, NonNull},
+};
+
+use pin_init::*;
+
+#[expect(unused_attributes)]
+mod error;
+use error::Error;
+
+#[pin_data(PinnedDrop)]
+#[repr(C)]
+#[derive(Debug)]
+pub struct ListHead {
+    next: Link,
+    prev: Link,
+    #[pin]
+    pin: PhantomPinned,
+}
+
+impl ListHead {
+    #[inline]
+    pub fn new() -> impl PinInit<Self, Infallible> {
+        try_pin_init!(&this in Self {
+            next: unsafe { Link::new_unchecked(this) },
+            prev: unsafe { Link::new_unchecked(this) },
+            pin: PhantomPinned,
+        }? Infallible)
+    }
+
+    #[inline]
+    pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
+        try_pin_init!(&this in Self {
+            prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}),
+            next: list.next.replace(unsafe { Link::new_unchecked(this)}),
+            pin: PhantomPinned,
+        }? Infallible)
+    }
+
+    #[inline]
+    pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
+        try_pin_init!(&this in Self {
+            next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}),
+            prev: list.prev.replace(unsafe { Link::new_unchecked(this)}),
+            pin: PhantomPinned,
+        }? Infallible)
+    }
+
+    #[inline]
+    pub fn next(&self) -> Option<NonNull<Self>> {
+        if ptr::eq(self.next.as_ptr(), self) {
+            None
+        } else {
+            Some(unsafe { NonNull::new_unchecked(self.next.as_ptr() as *mut Self) })
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn size(&self) -> usize {
+        let mut size = 1;
+        let mut cur = self.next.clone();
+        while !ptr::eq(self, cur.cur()) {
+            cur = cur.next().clone();
+            size += 1;
+        }
+        size
+    }
+}
+
+#[pinned_drop]
+impl PinnedDrop for ListHead {
+    //#[inline]
+    fn drop(self: Pin<&mut Self>) {
+        if !ptr::eq(self.next.as_ptr(), &*self) {
+            let next = unsafe { &*self.next.as_ptr() };
+            let prev = unsafe { &*self.prev.as_ptr() };
+            next.prev.set(&self.prev);
+            prev.next.set(&self.next);
+        }
+    }
+}
+
+#[repr(transparent)]
+#[derive(Clone, Debug)]
+struct Link(Cell<NonNull<ListHead>>);
+
+impl Link {
+    /// # Safety
+    ///
+    /// The contents of the pointer should form a consistent circular
+    /// linked list; for example, a "next" link should be pointed back
+    /// by the target `ListHead`'s "prev" link and a "prev" link should be
+    /// pointed back by the target `ListHead`'s "next" link.
+    #[inline]
+    unsafe fn new_unchecked(ptr: NonNull<ListHead>) -> Self {
+        Self(Cell::new(ptr))
+    }
+
+    #[inline]
+    fn next(&self) -> &Link {
+        unsafe { &(*self.0.get().as_ptr()).next }
+    }
+
+    #[inline]
+    fn prev(&self) -> &Link {
+        unsafe { &(*self.0.get().as_ptr()).prev }
+    }
+
+    #[allow(dead_code)]
+    fn cur(&self) -> &ListHead {
+        unsafe { &*self.0.get().as_ptr() }
+    }
+
+    #[inline]
+    fn replace(&self, other: Link) -> Link {
+        unsafe { Link::new_unchecked(self.0.replace(other.0.get())) }
+    }
+
+    #[inline]
+    fn as_ptr(&self) -> *const ListHead {
+        self.0.get().as_ptr()
+    }
+
+    #[inline]
+    fn set(&self, val: &Link) {
+        self.0.set(val.0.get());
+    }
+}
+
+#[allow(dead_code)]
+#[cfg_attr(test, test)]
+fn main() -> Result<(), Error> {
+    let a = Box::pin_init(ListHead::new())?;
+    stack_pin_init!(let b = ListHead::insert_next(&a));
+    stack_pin_init!(let c = ListHead::insert_next(&a));
+    stack_pin_init!(let d = ListHead::insert_next(&b));
+    let e = Box::pin_init(ListHead::insert_next(&b))?;
+    println!("a ({a:p}): {a:?}");
+    println!("b ({b:p}): {b:?}");
+    println!("c ({c:p}): {c:?}");
+    println!("d ({d:p}): {d:?}");
+    println!("e ({e:p}): {e:?}");
+    let mut inspect = &*a;
+    while let Some(next) = inspect.next() {
+        println!("({inspect:p}): {inspect:?}");
+        inspect = unsafe { &*next.as_ptr() };
+        if core::ptr::eq(inspect, &*a) {
+            break;
+        }
+    }
+    Ok(())
+}
diff --git a/rust/pin-init/examples/mutex.rs b/rust/pin-init/examples/mutex.rs
new file mode 100644
index 000000000000..073bb79341d1
--- /dev/null
+++ b/rust/pin-init/examples/mutex.rs
@@ -0,0 +1,209 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+#![allow(clippy::missing_safety_doc)]
+
+use core::{
+    cell::{Cell, UnsafeCell},
+    marker::PhantomPinned,
+    ops::{Deref, DerefMut},
+    pin::Pin,
+    sync::atomic::{AtomicBool, Ordering},
+};
+use std::{
+    sync::Arc,
+    thread::{self, park, sleep, Builder, Thread},
+    time::Duration,
+};
+
+use pin_init::*;
+#[expect(unused_attributes)]
+#[path = "./linked_list.rs"]
+pub mod linked_list;
+use linked_list::*;
+
+pub struct SpinLock {
+    inner: AtomicBool,
+}
+
+impl SpinLock {
+    #[inline]
+    pub fn acquire(&self) -> SpinLockGuard<'_> {
+        while self
+            .inner
+            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
+            .is_err()
+        {
+            while self.inner.load(Ordering::Relaxed) {
+                thread::yield_now();
+            }
+        }
+        SpinLockGuard(self)
+    }
+
+    #[inline]
+    #[allow(clippy::new_without_default)]
+    pub const fn new() -> Self {
+        Self {
+            inner: AtomicBool::new(false),
+        }
+    }
+}
+
+pub struct SpinLockGuard<'a>(&'a SpinLock);
+
+impl Drop for SpinLockGuard<'_> {
+    #[inline]
+    fn drop(&mut self) {
+        self.0.inner.store(false, Ordering::Release);
+    }
+}
+
+#[pin_data]
+pub struct CMutex<T> {
+    #[pin]
+    wait_list: ListHead,
+    spin_lock: SpinLock,
+    locked: Cell<bool>,
+    #[pin]
+    data: UnsafeCell<T>,
+}
+
+impl<T> CMutex<T> {
+    #[inline]
+    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
+        pin_init!(CMutex {
+            wait_list <- ListHead::new(),
+            spin_lock: SpinLock::new(),
+            locked: Cell::new(false),
+            data <- unsafe {
+                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
+                    val.__pinned_init(slot.cast::<T>())
+                })
+            },
+        })
+    }
+
+    #[inline]
+    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
+        let mut sguard = self.spin_lock.acquire();
+        if self.locked.get() {
+            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
+            // println!("wait list length: {}", self.wait_list.size());
+            while self.locked.get() {
+                drop(sguard);
+                park();
+                sguard = self.spin_lock.acquire();
+            }
+            // This does have an effect, as the ListHead inside wait_entry implements Drop!
+            #[expect(clippy::drop_non_drop)]
+            drop(wait_entry);
+        }
+        self.locked.set(true);
+        unsafe {
+            Pin::new_unchecked(CMutexGuard {
+                mtx: self,
+                _pin: PhantomPinned,
+            })
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
+        // SAFETY: we have an exclusive reference and thus nobody has access to data.
+        unsafe { &mut *self.data.get() }
+    }
+}
+
+unsafe impl<T: Send> Send for CMutex<T> {}
+unsafe impl<T: Send> Sync for CMutex<T> {}
+
+pub struct CMutexGuard<'a, T> {
+    mtx: &'a CMutex<T>,
+    _pin: PhantomPinned,
+}
+
+impl<T> Drop for CMutexGuard<'_, T> {
+    #[inline]
+    fn drop(&mut self) {
+        let sguard = self.mtx.spin_lock.acquire();
+        self.mtx.locked.set(false);
+        if let Some(list_field) = self.mtx.wait_list.next() {
+            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
+            unsafe { (*wait_entry).thread.unpark() };
+        }
+        drop(sguard);
+    }
+}
+
+impl<T> Deref for CMutexGuard<'_, T> {
+    type Target = T;
+
+    #[inline]
+    fn deref(&self) -> &Self::Target {
+        unsafe { &*self.mtx.data.get() }
+    }
+}
+
+impl<T> DerefMut for CMutexGuard<'_, T> {
+    #[inline]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        unsafe { &mut *self.mtx.data.get() }
+    }
+}
+
+#[pin_data]
+#[repr(C)]
+struct WaitEntry {
+    #[pin]
+    wait_list: ListHead,
+    thread: Thread,
+}
+
+impl WaitEntry {
+    #[inline]
+    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
+        pin_init!(Self {
+            thread: thread::current(),
+            wait_list <- ListHead::insert_prev(list),
+        })
+    }
+}
+
+#[cfg(not(any(feature = "std", feature = "alloc")))]
+fn main() {}
+
+#[allow(dead_code)]
+#[cfg_attr(test, test)]
+#[cfg(any(feature = "std", feature = "alloc"))]
+fn main() {
+    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
+    let mut handles = vec![];
+    let thread_count = 20;
+    let workload = if cfg!(miri) { 100 } else { 1_000 };
+    for i in 0..thread_count {
+        let mtx = mtx.clone();
+        handles.push(
+            Builder::new()
+                .name(format!("worker #{i}"))
+                .spawn(move || {
+                    for _ in 0..workload {
+                        *mtx.lock() += 1;
+                    }
+                    println!("{i} halfway");
+                    sleep(Duration::from_millis((i as u64) * 10));
+                    for _ in 0..workload {
+                        *mtx.lock() += 1;
+                    }
+                    println!("{i} finished");
+                })
+                .expect("should not fail"),
+        );
+    }
+    for h in handles {
+        h.join().expect("thread panicked");
+    }
+    println!("{:?}", &*mtx.lock());
+    assert_eq!(*mtx.lock(), workload * thread_count * 2);
+}
diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs
new file mode 100644
index 000000000000..9164298c44c0
--- /dev/null
+++ b/rust/pin-init/examples/pthread_mutex.rs
@@ -0,0 +1,178 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+// inspired by https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+#[cfg(not(windows))]
+mod pthread_mtx {
+    #[cfg(feature = "alloc")]
+    use core::alloc::AllocError;
+    use core::{
+        cell::UnsafeCell,
+        marker::PhantomPinned,
+        mem::MaybeUninit,
+        ops::{Deref, DerefMut},
+        pin::Pin,
+    };
+    use pin_init::*;
+    use std::convert::Infallible;
+
+    #[pin_data(PinnedDrop)]
+    pub struct PThreadMutex<T> {
+        #[pin]
+        raw: UnsafeCell<libc::pthread_mutex_t>,
+        data: UnsafeCell<T>,
+        #[pin]
+        pin: PhantomPinned,
+    }
+
+    unsafe impl<T: Send> Send for PThreadMutex<T> {}
+    unsafe impl<T: Send> Sync for PThreadMutex<T> {}
+
+    #[pinned_drop]
+    impl<T> PinnedDrop for PThreadMutex<T> {
+        fn drop(self: Pin<&mut Self>) {
+            unsafe {
+                libc::pthread_mutex_destroy(self.raw.get());
+            }
+        }
+    }
+
+    #[derive(Debug)]
+    pub enum Error {
+        #[expect(dead_code)]
+        IO(std::io::Error),
+        Alloc,
+    }
+
+    impl From<Infallible> for Error {
+        fn from(e: Infallible) -> Self {
+            match e {}
+        }
+    }
+
+    #[cfg(feature = "alloc")]
+    impl From<AllocError> for Error {
+        fn from(_: AllocError) -> Self {
+            Self::Alloc
+        }
+    }
+
+    impl<T> PThreadMutex<T> {
+        pub fn new(data: T) -> impl PinInit<Self, Error> {
+            fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
+                let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
+                    // we can cast, because `UnsafeCell` has the same layout as T.
+                    let slot: *mut libc::pthread_mutex_t = slot.cast();
+                    let mut attr = MaybeUninit::uninit();
+                    let attr = attr.as_mut_ptr();
+                    // SAFETY: ptr is valid
+                    let ret = unsafe { libc::pthread_mutexattr_init(attr) };
+                    if ret != 0 {
+                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+                    }
+                    // SAFETY: attr is initialized
+                    let ret = unsafe {
+                        libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
+                    };
+                    if ret != 0 {
+                        // SAFETY: attr is initialized
+                        unsafe { libc::pthread_mutexattr_destroy(attr) };
+                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+                    }
+                    // SAFETY: slot is valid
+                    unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
+                    // SAFETY: attr and slot are valid ptrs and attr is initialized
+                    let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
+                    // SAFETY: attr was initialized
+                    unsafe { libc::pthread_mutexattr_destroy(attr) };
+                    if ret != 0 {
+                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+                    }
+                    Ok(())
+                };
+                // SAFETY: mutex has been initialized
+                unsafe { pin_init_from_closure(init) }
+            }
+            try_pin_init!(Self {
+                data: UnsafeCell::new(data),
+                raw <- init_raw(),
+                pin: PhantomPinned,
+            }? Error)
+        }
+
+        pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
+            // SAFETY: raw is always initialized
+            unsafe { libc::pthread_mutex_lock(self.raw.get()) };
+            PThreadMutexGuard { mtx: self }
+        }
+    }
+
+    pub struct PThreadMutexGuard<'a, T> {
+        mtx: &'a PThreadMutex<T>,
+    }
+
+    impl<T> Drop for PThreadMutexGuard<'_, T> {
+        fn drop(&mut self) {
+            // SAFETY: raw is always initialized
+            unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
+        }
+    }
+
+    impl<T> Deref for PThreadMutexGuard<'_, T> {
+        type Target = T;
+
+        fn deref(&self) -> &Self::Target {
+            unsafe { &*self.mtx.data.get() }
+        }
+    }
+
+    impl<T> DerefMut for PThreadMutexGuard<'_, T> {
+        fn deref_mut(&mut self) -> &mut Self::Target {
+            unsafe { &mut *self.mtx.data.get() }
+        }
+    }
+}
+
+#[cfg_attr(test, test)]
+fn main() {
+    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
+    {
+        use core::pin::Pin;
+        use pin_init::*;
+        use pthread_mtx::*;
+        use std::{
+            sync::Arc,
+            thread::{sleep, Builder},
+            time::Duration,
+        };
+        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
+        let mut handles = vec![];
+        let thread_count = 20;
+        let workload = 1_000_000;
+        for i in 0..thread_count {
+            let mtx = mtx.clone();
+            handles.push(
+                Builder::new()
+                    .name(format!("worker #{i}"))
+                    .spawn(move || {
+                        for _ in 0..workload {
+                            *mtx.lock() += 1;
+                        }
+                        println!("{i} halfway");
+                        sleep(Duration::from_millis((i as u64) * 10));
+                        for _ in 0..workload {
+                            *mtx.lock() += 1;
+                        }
+                        println!("{i} finished");
+                    })
+                    .expect("should not fail"),
+            );
+        }
+        for h in handles {
+            h.join().expect("thread panicked");
+        }
+        println!("{:?}", &*mtx.lock());
+        assert_eq!(*mtx.lock(), workload * thread_count * 2);
+    }
+}
diff --git a/rust/pin-init/examples/static_init.rs b/rust/pin-init/examples/static_init.rs
new file mode 100644
index 000000000000..3487d761aa26
--- /dev/null
+++ b/rust/pin-init/examples/static_init.rs
@@ -0,0 +1,122 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+
+use core::{
+    cell::{Cell, UnsafeCell},
+    mem::MaybeUninit,
+    ops,
+    pin::Pin,
+    time::Duration,
+};
+use pin_init::*;
+use std::{
+    sync::Arc,
+    thread::{sleep, Builder},
+};
+
+#[expect(unused_attributes)]
+mod mutex;
+use mutex::*;
+
+pub struct StaticInit<T, I> {
+    cell: UnsafeCell<MaybeUninit<T>>,
+    init: Cell<Option<I>>,
+    lock: SpinLock,
+    present: Cell<bool>,
+}
+
+unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
+unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
+
+impl<T, I: PinInit<T>> StaticInit<T, I> {
+    pub const fn new(init: I) -> Self {
+        Self {
+            cell: UnsafeCell::new(MaybeUninit::uninit()),
+            init: Cell::new(Some(init)),
+            lock: SpinLock::new(),
+            present: Cell::new(false),
+        }
+    }
+}
+
+impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
+    type Target = T;
+    fn deref(&self) -> &Self::Target {
+        if self.present.get() {
+            unsafe { (*self.cell.get()).assume_init_ref() }
+        } else {
+            println!("acquire spinlock on static init");
+            let _guard = self.lock.acquire();
+            println!("rechecking present...");
+            std::thread::sleep(std::time::Duration::from_millis(200));
+            if self.present.get() {
+                return unsafe { (*self.cell.get()).assume_init_ref() };
+            }
+            println!("doing init");
+            let ptr = self.cell.get().cast::<T>();
+            match self.init.take() {
+                Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
+                None => unsafe { core::hint::unreachable_unchecked() },
+            }
+            self.present.set(true);
+            unsafe { (*self.cell.get()).assume_init_ref() }
+        }
+    }
+}
+
+pub struct CountInit;
+
+unsafe impl PinInit<CMutex<usize>> for CountInit {
+    unsafe fn __pinned_init(
+        self,
+        slot: *mut CMutex<usize>,
+    ) -> Result<(), core::convert::Infallible> {
+        let init = CMutex::new(0);
+        std::thread::sleep(std::time::Duration::from_millis(1000));
+        unsafe { init.__pinned_init(slot) }
+    }
+}
+
+pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
+
+#[cfg(not(any(feature = "std", feature = "alloc")))]
+fn main() {}
+
+#[cfg(any(feature = "std", feature = "alloc"))]
+fn main() {
+    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
+    let mut handles = vec![];
+    let thread_count = 20;
+    let workload = 1_000;
+    for i in 0..thread_count {
+        let mtx = mtx.clone();
+        handles.push(
+            Builder::new()
+                .name(format!("worker #{i}"))
+                .spawn(move || {
+                    for _ in 0..workload {
+                        *COUNT.lock() += 1;
+                        std::thread::sleep(std::time::Duration::from_millis(10));
+                        *mtx.lock() += 1;
+                        std::thread::sleep(std::time::Duration::from_millis(10));
+                        *COUNT.lock() += 1;
+                    }
+                    println!("{i} halfway");
+                    sleep(Duration::from_millis((i as u64) * 10));
+                    for _ in 0..workload {
+                        std::thread::sleep(std::time::Duration::from_millis(10));
+                        *mtx.lock() += 1;
+                    }
+                    println!("{i} finished");
+                })
+                .expect("should not fail"),
+        );
+    }
+    for h in handles {
+        h.join().expect("thread panicked");
+    }
+    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
+    assert_eq!(*mtx.lock(), workload * thread_count * 2);
+}