use core::{
    fmt,
    hash::{Hash, Hasher},
    mem::{ManuallyDrop, MaybeUninit},
    ops, ptr,
    sync::atomic::{self, AtomicUsize, Ordering},
};

use super::treiber::{NonNullPtr, Stack, UnionNode};

#[macro_export]
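/// Creates an `Arc` pool singleton with the given `$name` that manages values of type
/// `$data_type`.
///
/// The sketch below shows typical usage; it is illustrative rather than a doc-test, and
/// the names `MyPool` and `B` are made up. It assumes `ArcBlock` is in scope (e.g. via
/// this crate's `pool::arc` module):
///
/// ```ignore
/// arc_pool!(MyPool: [u8; 128]);
///
/// // Hand the pool one block of static backing storage.
/// let block = unsafe {
///     static mut B: ArcBlock<[u8; 128]> = ArcBlock::new();
///     &mut B
/// };
/// MyPool.manage(block);
///
/// let arc = MyPool.alloc([0; 128]).unwrap(); // strong count = 1
/// let arc2 = arc.clone();                    // strong count = 2
/// drop(arc);                                 // strong count = 1
/// drop(arc2);                                // destructor runs; block returns to the pool
/// ```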
macro_rules! arc_pool {
    ($name:ident: $data_type:ty) => {
        pub struct $name;

        impl $crate::pool::arc::ArcPool for $name {
            type Data = $data_type;

            fn singleton() -> &'static $crate::pool::arc::ArcPoolImpl<$data_type> {
                static $name: $crate::pool::arc::ArcPoolImpl<$data_type> =
                    $crate::pool::arc::ArcPoolImpl::new();

                &$name
            }
        }

        impl $name {
            /// Allocates a new `Arc`; inherent-method form of `ArcPool::alloc`.
            #[allow(dead_code)]
            pub fn alloc(
                &self,
                value: $data_type,
            ) -> Result<$crate::pool::arc::Arc<$name>, $data_type> {
                <$name as $crate::pool::arc::ArcPool>::alloc(value)
            }

            /// Adds a memory block to the pool; inherent-method form of `ArcPool::manage`.
            #[allow(dead_code)]
            pub fn manage(&self, block: &'static mut $crate::pool::arc::ArcBlock<$data_type>) {
                <$name as $crate::pool::arc::ArcPool>::manage(block)
            }
        }
    };
}

/// A singleton that manages a pool of `Arc`s.
pub trait ArcPool: Sized {
    /// The data type managed by the pool.
    type Data: 'static;

    /// Implementation detail; not part of the public API.
    #[doc(hidden)]
    fn singleton() -> &'static ArcPoolImpl<Self::Data>;

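    /// Allocates a new `Arc` from the pool's free-list, handing the value back as `Err`
    /// if the pool has no free memory block.
    ///
    /// A sketch, assuming a pool declared as `arc_pool!(P: i32)` that already manages at
    /// least one block (see `manage` below):
    ///
    /// ```ignore
    /// let arc = P.alloc(42).unwrap(); // would be `Err(42)` on an exhausted pool
    /// assert_eq!(42, *arc);
    /// ```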
    fn alloc(value: Self::Data) -> Result<Arc<Self>, Self::Data> {
        Ok(Arc {
            node_ptr: Self::singleton().alloc(value)?,
        })
    }

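    /// Adds a statically allocated memory block to the pool's free-list.
    ///
    /// A sketch of how the required `&'static mut ArcBlock` is typically obtained (the
    /// names `P` and `B` are illustrative; the tests below use the same pattern):
    ///
    /// ```ignore
    /// let block = unsafe {
    ///     static mut B: ArcBlock<i32> = ArcBlock::new();
    ///     &mut B
    /// };
    /// P.manage(block);
    /// ```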
    fn manage(block: &'static mut ArcBlock<Self::Data>) {
        Self::singleton().manage(block)
    }
}

/// `ArcPool` implementation detail.
#[doc(hidden)]
pub struct ArcPoolImpl<T> {
    stack: Stack<UnionNode<MaybeUninit<ArcInner<T>>>>,
}

impl<T> ArcPoolImpl<T> {
    /// `ArcPool` implementation detail.
    #[doc(hidden)]
    pub const fn new() -> Self {
        Self {
            stack: Stack::new(),
        }
    }

    fn alloc(&self, value: T) -> Result<NonNullPtr<UnionNode<MaybeUninit<ArcInner<T>>>>, T> {
        if let Some(node_ptr) = self.stack.try_pop() {
            // A free node was popped off the Treiber stack; initialize it in place with
            // the value and a strong count of one.
            let inner = ArcInner {
                data: value,
                strong: AtomicUsize::new(1),
            };
            unsafe { node_ptr.as_ptr().cast::<ArcInner<T>>().write(inner) }

            Ok(node_ptr)
        } else {
            // The free-list is empty; hand the value back to the caller.
            Err(value)
        }
    }

    fn manage(&self, block: &'static mut ArcBlock<T>) {
        let node: &'static mut _ = &mut block.node;

        unsafe { self.stack.push(NonNullPtr::from_static_mut_ref(node)) }
    }
}

// SAFETY: the pool's only state is a lock-free Treiber stack, which supports concurrent
// pushes and pops; `Arc` itself guards the data with its own `Send`/`Sync` bounds.
unsafe impl<T> Sync for ArcPoolImpl<T> {}

/// Like `std::sync::Arc`, but backed by a memory pool rather than a global allocator.
pub struct Arc<P>
where
    P: ArcPool,
{
    node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>,
}

impl<P> Arc<P>
where
    P: ArcPool,
{
    fn inner(&self) -> &ArcInner<P::Data> {
        unsafe { &*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>() }
    }

    fn from_inner(node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>) -> Self {
        Self { node_ptr }
    }

    unsafe fn get_mut_unchecked(this: &mut Self) -> &mut P::Data {
        &mut *ptr::addr_of_mut!((*this.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data)
    }

    #[inline(never)]
    unsafe fn drop_slow(&mut self) {
        // Run the destructor of the wrapped data in place.
        ptr::drop_in_place(Self::get_mut_unchecked(self));

        // Return the memory block to the pool's free-list.
        P::singleton().stack.push(self.node_ptr);
    }
}

impl<P> AsRef<P::Data> for Arc<P>
where
    P: ArcPool,
{
    fn as_ref(&self) -> &P::Data {
        &**self
    }
}

// Cap the strong count at `isize::MAX`, as `std::sync::Arc` does, so the counter is
// stopped long before it can wrap around.
const MAX_REFCOUNT: usize = isize::MAX as usize;

impl<P> Clone for Arc<P>
where
    P: ArcPool,
{
    fn clone(&self) -> Self {
        // `Relaxed` suffices here: a new reference can only be created out of an
        // existing one, so no synchronization with other threads is required.
        let old_size = self.inner().strong.fetch_add(1, Ordering::Relaxed);

        if old_size > MAX_REFCOUNT {
            // The strong count is in danger of overflowing; panic rather than risk a
            // use-after-free.
            panic!();
        }

        Self::from_inner(self.node_ptr)
    }
}

impl<A> fmt::Debug for Arc<A>
where
    A: ArcPool,
    A::Data: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        A::Data::fmt(self, f)
    }
}

impl<P> ops::Deref for Arc<P>
where
    P: ArcPool,
{
    type Target = P::Data;

    fn deref(&self) -> &Self::Target {
        unsafe { &*ptr::addr_of!((*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data) }
    }
}

impl<A> fmt::Display for Arc<A>
where
    A: ArcPool,
    A::Data: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        A::Data::fmt(self, f)
    }
}

impl<A> Drop for Arc<A>
where
    A: ArcPool,
{
    fn drop(&mut self) {
        // Decrement the strong count with `Release` so that all uses of the data
        // happen-before the deallocation performed by the last `Arc`.
        if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }

        // This `Acquire` fence pairs with the `Release` decrements above, ensuring the
        // last reference observes all other drops before running the destructor.
        atomic::fence(Ordering::Acquire);

        unsafe { self.drop_slow() }
    }
}

impl<A> Eq for Arc<A>
where
    A: ArcPool,
    A::Data: Eq,
{
}

impl<A> Hash for Arc<A>
where
    A: ArcPool,
    A::Data: Hash,
{
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        (**self).hash(state)
    }
}

impl<A> Ord for Arc<A>
where
    A: ArcPool,
    A::Data: Ord,
{
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        A::Data::cmp(self, other)
    }
}

impl<A, B> PartialEq<Arc<B>> for Arc<A>
where
    A: ArcPool,
    B: ArcPool,
    A::Data: PartialEq<B::Data>,
{
    fn eq(&self, other: &Arc<B>) -> bool {
        A::Data::eq(self, &**other)
    }
}

impl<A, B> PartialOrd<Arc<B>> for Arc<A>
where
    A: ArcPool,
    B: ArcPool,
    A::Data: PartialOrd<B::Data>,
{
    fn partial_cmp(&self, other: &Arc<B>) -> Option<core::cmp::Ordering> {
        A::Data::partial_cmp(self, &**other)
    }
}

unsafe impl<A> Send for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

unsafe impl<A> Sync for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

impl<A> Unpin for Arc<A> where A: ArcPool {}

struct ArcInner<T> {
    data: T,
    // The number of `Arc`s pointing at this allocation.
    strong: AtomicUsize,
}

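/// A memory block suitable for use in an `Arc` pool.
///
/// Blocks are typically placed in `static` storage and handed over to a pool exactly
/// once. A sketch (the names `P` and `B` are illustrative):
///
/// ```ignore
/// arc_pool!(P: u128);
///
/// let block = unsafe {
///     static mut B: ArcBlock<u128> = ArcBlock::new();
///     &mut B
/// };
/// P.manage(block);
/// ```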
pub struct ArcBlock<T> {
    node: UnionNode<MaybeUninit<ArcInner<T>>>,
}

impl<T> ArcBlock<T> {
    /// Creates a new, yet-unmanaged memory block.
    pub const fn new() -> Self {
        Self {
            node: UnionNode {
                data: ManuallyDrop::new(MaybeUninit::uninit()),
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cannot_alloc_if_empty() {
        arc_pool!(P: i32);

        assert_eq!(Err(42), P.alloc(42));
    }

    #[test]
    fn can_alloc_if_manages_one_block() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        assert_eq!(42, *P.alloc(42).unwrap());
    }

    #[test]
    fn alloc_drop_alloc() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).unwrap();

        drop(arc);

        assert_eq!(2, *P.alloc(2).unwrap());
    }

    #[test]
    fn strong_count_starts_at_one() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        assert_eq!(1, arc.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn clone_increases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        let arc2 = arc.clone();

        let expected = before + 1;
        assert_eq!(expected, arc.inner().strong.load(Ordering::Relaxed));
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn drop_decreases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();
        let arc2 = arc.clone();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        drop(arc);

        let expected = before - 1;
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn runs_destructor_exactly_once_when_strong_count_reaches_zero() {
        static COUNT: AtomicUsize = AtomicUsize::new(0);

        pub struct S;

        impl Drop for S {
            fn drop(&mut self) {
                COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        arc_pool!(P: S);

        let block = unsafe {
            static mut B: ArcBlock<S> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(S).ok().unwrap();

        assert_eq!(0, COUNT.load(Ordering::Relaxed));

        drop(arc);

        assert_eq!(1, COUNT.load(Ordering::Relaxed));
    }

    #[test]
    fn zst_is_well_aligned() {
        #[repr(align(4096))]
        pub struct Zst4096;

        arc_pool!(P: Zst4096);

        let block = unsafe {
            static mut B: ArcBlock<Zst4096> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(Zst4096).ok().unwrap();

        let raw = &*arc as *const Zst4096;
        assert_eq!(0, raw as usize % 4096);
    }
}