//! A pool-backed `Arc`: `std::sync::Arc`-like reference counting on top of a
//! lock-free memory pool instead of a global allocator, which makes it usable
//! on `no_std` targets.

use core::{
    fmt,
    hash::{Hash, Hasher},
    mem::{ManuallyDrop, MaybeUninit},
    ops, ptr,
    sync::atomic::{self, AtomicUsize, Ordering},
};

use super::treiber::{NonNullPtr, Stack, UnionNode};
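/// Creates an `Arc` pool with the given singleton name and data type.
///
/// The macro expands to a zero-sized `struct $name` plus an `ArcPool` impl
/// whose `singleton` returns a `static` `ArcPoolImpl`. A minimal usage sketch
/// follows; the pool name `P` and the `u128` payload are illustrative, and the
/// paths assume this module is reachable as `pool::arc`:
///
/// ```ignore
/// use pool::arc::ArcBlock;
///
/// arc_pool!(P: u128);
///
/// // Give the pool one block of memory to manage.
/// static mut B: ArcBlock<u128> = ArcBlock::new();
/// P.manage(unsafe { &mut *core::ptr::addr_of_mut!(B) });
///
/// // Allocation succeeds while a free block is available ...
/// let arc = P.alloc(1).unwrap();
/// assert_eq!(1, *arc);
///
/// // ... and fails, handing the value back, once the pool is exhausted.
/// assert_eq!(Err(2), P.alloc(2));
/// ```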
#[macro_export]
macro_rules! arc_pool {
($name:ident: $data_type:ty) => {
pub struct $name;
impl $crate::pool::arc::ArcPool for $name {
type Data = $data_type;
fn singleton() -> &'static $crate::pool::arc::ArcPoolImpl<$data_type> {
static $name: $crate::pool::arc::ArcPoolImpl<$data_type> =
$crate::pool::arc::ArcPoolImpl::new();
&$name
}
}
impl $name {
#[allow(dead_code)]
pub fn alloc(
&self,
value: $data_type,
) -> Result<$crate::pool::arc::Arc<$name>, $data_type> {
<$name as $crate::pool::arc::ArcPool>::alloc(value)
}
#[allow(dead_code)]
pub fn manage(&self, block: &'static mut $crate::pool::arc::ArcBlock<$data_type>) {
<$name as $crate::pool::arc::ArcPool>::manage(block)
}
}
};
}
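/// A singleton that hands out pool-allocated `Arc` values.
///
/// Implementations are meant to be generated by the `arc_pool!` macro, which
/// also creates the backing `static` storage; the only state a pool needs is
/// the free list returned by `singleton`.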
pub trait ArcPool: Sized {
type Data: 'static;
#[doc(hidden)]
fn singleton() -> &'static ArcPoolImpl<Self::Data>;
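    /// Allocates a new `Arc` initialized to `value`.
    ///
    /// Returns the value back as `Err` when the pool has no free memory blocks.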
fn alloc(value: Self::Data) -> Result<Arc<Self>, Self::Data> {
Ok(Arc {
node_ptr: Self::singleton().alloc(value)?,
})
}
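    /// Adds a statically allocated memory block to the pool's free list.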
fn manage(block: &'static mut ArcBlock<Self::Data>) {
Self::singleton().manage(block)
}
}
#[doc(hidden)]
pub struct ArcPoolImpl<T> {
stack: Stack<UnionNode<MaybeUninit<ArcInner<T>>>>,
}
impl<T> ArcPoolImpl<T> {
#[doc(hidden)]
pub const fn new() -> Self {
Self {
stack: Stack::new(),
}
}
    fn alloc(&self, value: T) -> Result<NonNullPtr<UnionNode<MaybeUninit<ArcInner<T>>>>, T> {
        if let Some(node_ptr) = self.stack.try_pop() {
            // Take a free node off the stack and initialize it in place with
            // a strong count of one.
            let inner = ArcInner {
                data: value,
                strong: AtomicUsize::new(1),
            };
            unsafe { node_ptr.as_ptr().cast::<ArcInner<T>>().write(inner) }

            Ok(node_ptr)
        } else {
            // The pool is exhausted; hand the value back to the caller.
            Err(value)
        }
    }
fn manage(&self, block: &'static mut ArcBlock<T>) {
let node: &'static mut _ = &mut block.node;
unsafe { self.stack.push(NonNullPtr::from_static_mut_ref(node)) }
}
}
// SAFETY: `ArcPoolImpl` is only a handle to a lock-free Treiber stack, which
// is safe to operate on from multiple threads.
unsafe impl<T> Sync for ArcPoolImpl<T> {}
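/// A pool-backed, atomically reference-counted pointer, akin to `std::sync::Arc`.
///
/// Cloning only increments a counter. When the last `Arc` is dropped, the
/// inner value's destructor runs and the memory block is returned to the pool.
/// Unlike `std::sync::Arc`, there is no weak-reference support.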
pub struct Arc<P>
where
P: ArcPool,
{
node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>,
}
impl<P> Arc<P>
where
P: ArcPool,
{
fn inner(&self) -> &ArcInner<P::Data> {
unsafe { &*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>() }
}
fn from_inner(node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>) -> Self {
Self { node_ptr }
}
    /// # Safety
    ///
    /// The caller must have exclusive access to the inner data, e.g. because
    /// `this` is the only live `Arc` pointing at it.
    unsafe fn get_mut_unchecked(this: &mut Self) -> &mut P::Data {
        &mut *ptr::addr_of_mut!((*this.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data)
    }

    // Cold path: runs only when the last `Arc` is dropped.
    #[inline(never)]
    unsafe fn drop_slow(&mut self) {
        // Run the destructor of the inner value, then return the memory block
        // to the pool's free list.
        ptr::drop_in_place(Self::get_mut_unchecked(self));
        P::singleton().stack.push(self.node_ptr);
    }
}
impl<P> AsRef<P::Data> for Arc<P>
where
P: ArcPool,
{
fn as_ref(&self) -> &P::Data {
&**self
}
}
/// Soft cap on the reference count, mirroring `std::sync::Arc`: keeping the
/// count at or below `isize::MAX` leaves enough headroom that racing
/// `fetch_add`s cannot overflow `usize` before the check in `clone` fires.
const MAX_REFCOUNT: usize = (isize::MAX) as usize;
impl<P> Clone for Arc<P>
where
P: ArcPool,
{
    fn clone(&self) -> Self {
        // Relaxed suffices: a new reference can only be created from an
        // existing one, which already keeps the data alive (the same argument
        // `std::sync::Arc` makes).
        let old_size = self.inner().strong.fetch_add(1, Ordering::Relaxed);

        if old_size > MAX_REFCOUNT {
            // Reference-count overflow. `std` aborts here; `panic!` is the
            // closest portable equivalent in a `no_std` environment.
            panic!();
        }

        Self::from_inner(self.node_ptr)
    }
}
impl<A> fmt::Debug for Arc<A>
where
A: ArcPool,
A::Data: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
A::Data::fmt(self, f)
}
}
impl<P> ops::Deref for Arc<P>
where
P: ArcPool,
{
type Target = P::Data;
fn deref(&self) -> &Self::Target {
unsafe { &*ptr::addr_of!((*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data) }
}
}
impl<A> fmt::Display for Arc<A>
where
A: ArcPool,
A::Data: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
A::Data::fmt(self, f)
}
}
impl<A> Drop for Arc<A>
where
A: ArcPool,
{
    fn drop(&mut self) {
        // Release on the decrement publishes all prior uses of the data to
        // whichever thread ends up destroying it.
        if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }

        // This Acquire fence pairs with the Release above, ensuring the
        // destructor sees every write made through the other handles (the
        // same fence `std::sync::Arc` uses).
        atomic::fence(Ordering::Acquire);

        unsafe { self.drop_slow() }
    }
}
impl<A> Eq for Arc<A>
where
A: ArcPool,
A::Data: Eq,
{
}
impl<A> Hash for Arc<A>
where
A: ArcPool,
A::Data: Hash,
{
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
(**self).hash(state)
}
}
impl<A> Ord for Arc<A>
where
A: ArcPool,
A::Data: Ord,
{
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
A::Data::cmp(self, other)
}
}
impl<A, B> PartialEq<Arc<B>> for Arc<A>
where
A: ArcPool,
B: ArcPool,
A::Data: PartialEq<B::Data>,
{
fn eq(&self, other: &Arc<B>) -> bool {
A::Data::eq(self, &**other)
}
}
impl<A, B> PartialOrd<Arc<B>> for Arc<A>
where
A: ArcPool,
B: ArcPool,
A::Data: PartialOrd<B::Data>,
{
fn partial_cmp(&self, other: &Arc<B>) -> Option<core::cmp::Ordering> {
A::Data::partial_cmp(self, &**other)
}
}
// SAFETY: same bounds as `std::sync::Arc`. Sending an `Arc<A>` across threads
// can both share the data (`&A::Data`) and, on last drop, take ownership of
// it, so `A::Data` must be both `Sync` and `Send`.
unsafe impl<A> Send for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

// SAFETY: `&Arc<A>` allows cloning a new `Arc<A>`, which is equivalent to
// sending one, so the same bounds apply.
unsafe impl<A> Sync for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}
impl<A> Unpin for Arc<A> where A: ArcPool {}
/// The shared allocation: the user's data plus the strong reference count.
struct ArcInner<T> {
data: T,
strong: AtomicUsize,
}
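/// A chunk of memory that an `ArcPool` singleton can manage.
///
/// `manage` requires a `'static` mutable reference, so blocks normally live in
/// a `static`. A sketch, where `P` is a hypothetical pool created with
/// `arc_pool!(P: u32)`:
///
/// ```ignore
/// static mut BLOCK: ArcBlock<u32> = ArcBlock::new();
/// P.manage(unsafe { &mut *core::ptr::addr_of_mut!(BLOCK) });
/// ```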
pub struct ArcBlock<T> {
node: UnionNode<MaybeUninit<ArcInner<T>>>,
}
impl<T> ArcBlock<T> {
pub const fn new() -> Self {
Self {
node: UnionNode {
data: ManuallyDrop::new(MaybeUninit::uninit()),
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn cannot_alloc_if_empty() {
arc_pool!(P: i32);
        assert_eq!(Err(42), P.alloc(42));
}
#[test]
fn can_alloc_if_manages_one_block() {
arc_pool!(P: i32);
let block = unsafe {
static mut B: ArcBlock<i32> = ArcBlock::new();
            // Avoid taking `&mut` of a `static mut` directly (denied by the
            // `static_mut_refs` lint); go through a raw pointer instead.
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
assert_eq!(42, *P.alloc(42).unwrap());
}
#[test]
fn alloc_drop_alloc() {
arc_pool!(P: i32);
let block = unsafe {
static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(1).unwrap();
drop(arc);
assert_eq!(2, *P.alloc(2).unwrap());
}
#[test]
fn strong_count_starts_at_one() {
arc_pool!(P: i32);
let block = unsafe {
static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(1).ok().unwrap();
assert_eq!(1, arc.inner().strong.load(Ordering::Relaxed));
}
#[test]
fn clone_increases_strong_count() {
arc_pool!(P: i32);
let block = unsafe {
static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(1).ok().unwrap();
let before = arc.inner().strong.load(Ordering::Relaxed);
let arc2 = arc.clone();
let expected = before + 1;
assert_eq!(expected, arc.inner().strong.load(Ordering::Relaxed));
assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
}
#[test]
fn drop_decreases_strong_count() {
arc_pool!(P: i32);
let block = unsafe {
static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(1).ok().unwrap();
let arc2 = arc.clone();
let before = arc.inner().strong.load(Ordering::Relaxed);
drop(arc);
let expected = before - 1;
assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
}
#[test]
fn runs_destructor_exactly_once_when_strong_count_reaches_zero() {
static COUNT: AtomicUsize = AtomicUsize::new(0);
pub struct S;
impl Drop for S {
fn drop(&mut self) {
COUNT.fetch_add(1, Ordering::Relaxed);
}
}
arc_pool!(P: S);
let block = unsafe {
static mut B: ArcBlock<S> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(S).ok().unwrap();
assert_eq!(0, COUNT.load(Ordering::Relaxed));
drop(arc);
assert_eq!(1, COUNT.load(Ordering::Relaxed));
}
#[test]
fn zst_is_well_aligned() {
#[repr(align(4096))]
pub struct Zst4096;
arc_pool!(P: Zst4096);
let block = unsafe {
static mut B: ArcBlock<Zst4096> = ArcBlock::new();
            &mut *core::ptr::addr_of_mut!(B)
};
P.manage(block);
let arc = P.alloc(Zst4096).ok().unwrap();
let raw = &*arc as *const Zst4096;
assert_eq!(0, raw as usize % 4096);
}
}