use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicU16, AtomicU32, AtomicUsize};
use crate::loom::sync::{Arc, Mutex};
use crate::runtime::task;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ptr::{self, NonNull};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
/// Owner handle of a local run queue.
///
/// Writes to the queue's `tail` and to its buffer slots only happen through a
/// `&mut Local`: in `push_back`, and as the `dst` argument of
/// `Steal::steal_into`. `pop` also takes `&mut self`.
pub(super) struct Local<T: 'static> {
    /// State shared with the `Steal` handle(s) of this queue.
    inner: Arc<Inner<T>>,
}
/// Handle through which tasks are stolen out of a local run queue (see
/// `steal_into`). Clones refer to the same underlying queue.
pub(super) struct Steal<T: 'static>(Arc<Inner<T>>);
/// Shared FIFO queue of tasks, implemented as an intrusive singly-linked list
/// threaded through each task header's `queue_next` field.
///
/// Tasks are appended at `tail` and removed from `head`. Local queues spill
/// into it on overflow (`Local::push_back`).
pub(super) struct Inject<T: 'static> {
    /// List ends plus the closed flag; every list modification holds this lock.
    pointers: Mutex<Pointers>,

    /// Number of queued tasks. Only written while `pointers` is locked, but
    /// read without the lock by `len()`.
    len: AtomicUsize,

    /// Ties the queue to the task type `T` without storing one.
    _p: PhantomData<T>,
}
/// State shared between a `Local` owner and its `Steal` handles.
pub(super) struct Inner<T: 'static> {
    /// Two `u16` positions packed by `pack`: `steal` in the upper 16 bits and
    /// `real` in the lower 16 bits.
    ///
    /// `real` is the position of the next task to pop. When `steal != real`,
    /// a stealer has claimed the positions `steal..real` and is still copying
    /// them out, so those slots may not be reused yet. When both halves are
    /// equal, no steal is in progress.
    head: AtomicU32,

    /// Position one past the last pushed task. Only written through `&mut Local`.
    tail: AtomicU16,

    /// Ring buffer of `LOCAL_QUEUE_CAPACITY` slots, indexed by `position & MASK`.
    /// A slot holds an initialized task only between its push and the
    /// matching pop or steal.
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>]>,
}
/// Mutex-protected part of `Inject`.
struct Pointers {
    /// Set by `Inject::close`. While set, `Inject::push` drops tasks instead of queueing them.
    is_closed: bool,

    /// First task in the linked list, or `None` when the list is empty.
    head: Option<NonNull<task::Header>>,

    /// Last task in the linked list, or `None` when the list is empty.
    tail: Option<NonNull<task::Header>>,
}
// NOTE(review): these impls carry no `T: Send` / `T: Sync` bounds.
// `Inner` keeps tasks in `UnsafeCell` slots whose access is coordinated by
// the `head`/`tail` atomics. `Inject` holds raw task-header pointers, and its
// shared list is only modified while `pointers` is locked. Whether that is
// sufficient for every `T` depends on the thread-safety of
// `task::Notified<T>`, which is not visible here. Confirm.
unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}
unsafe impl<T> Send for Inject<T> {}
unsafe impl<T> Sync for Inject<T> {}
/// Number of slots in each local run queue's ring buffer.
///
/// Must be a power of two: slot indices are computed as `position & MASK`.
#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 1 << 8;

// Much smaller under loom. Presumably this lets the model checker reach the
// "queue full" paths quickly. Confirm.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 1 << 2;

/// Bit mask mapping a ring position to a slot index in `Inner::buffer`.
const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
/// Creates an empty local run queue.
///
/// Returns `(stealer, owner)`: a `Steal` handle for other parties and the
/// `Local` handle used to push and pop. Both share the same `Inner`.
pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    // Every slot starts uninitialized. A slot only holds a task between a
    // push and the matching pop or steal.
    let buffer: Box<[_]> = (0..LOCAL_QUEUE_CAPACITY)
        .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
        .collect();

    let shared = Arc::new(Inner {
        head: AtomicU32::new(0),
        tail: AtomicU16::new(0),
        buffer,
    });

    let owner = Local {
        inner: shared.clone(),
    };

    (Steal(shared), owner)
}
impl<T> Local<T> {
    /// Returns `true` if the queue holds at least one task, i.e. there is
    /// something a `Steal` handle could take.
    pub(super) fn is_stealable(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes `task` onto the back of the local queue.
    ///
    /// When the ring buffer has no free slot, the task is redirected to `inject`:
    /// * if a stealer is mid-steal (`steal != real`), only `task` is pushed to
    ///   `inject`, because the in-progress steal will free up capacity;
    /// * otherwise, half of the local queue plus `task` are moved to `inject`
    ///   as one batch via `push_overflow`. This is retried if it races with a
    ///   stealer.
    pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: `tail` is only ever written through a `&mut Local`
            // (here, and as `dst` in `Steal::steal_into`), and we hold one.
            let tail = unsafe { self.inner.tail.unsync_load() };

            // Capacity is measured from `steal`, not `real`: slots in
            // `steal..real` may still be being copied out by a stealer.
            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 {
                break tail;
            } else if steal != real {
                // A steal is in progress and will free capacity, so only this
                // task is diverted to the injection queue.
                inject.push(task);
                return;
            } else {
                match self.push_overflow(task, real, tail, inject) {
                    Ok(_) => return,
                    // Lost the race for `head`; retry with the same task.
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // safety: the capacity check above guarantees that slot `idx` lies
            // outside `steal..tail`, so no consumer can reach it. Buffer writes
            // only happen through `&mut Local`.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Publish the task. Pairs with the `Acquire` loads of `tail` in
        // `steal_into2` and `Inner::is_empty`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves the first half of a full local queue, followed by `task`, onto
    /// `inject` as a single linked batch of `LOCAL_QUEUE_CAPACITY / 2 + 1` tasks.
    ///
    /// `head` and `tail` are the positions `push_back` observed. They must
    /// describe a full queue with no steal in progress (`steal == real == head`).
    ///
    /// Returns `Err(task)`, handing the task back, if `head` changed before the
    /// tasks could be claimed. The caller should then retry.
    ///
    /// # Panics
    ///
    /// Panics if `tail - head` is not exactly `LOCAL_QUEUE_CAPACITY`.
    #[inline(never)]
    fn push_overflow(
        &mut self,
        task: task::Notified<T>,
        head: u16,
        tail: u16,
        inject: &Inject<T>,
    ) -> Result<(), task::Notified<T>> {
        const BATCH_LEN: usize = LOCAL_QUEUE_CAPACITY / 2 + 1;

        let n = (LOCAL_QUEUE_CAPACITY / 2) as u16;
        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim the first `n` tasks by advancing both halves of `head` past
        // them. This succeeds only if `head` is unchanged since `push_back`
        // read it, i.e. no stealer has claimed anything in between. The tasks
        // are claimed *before* their slots are read. That is fine because only
        // this thread writes new tasks into the buffer.
        //
        // `compare_exchange(.., Release, Relaxed)` is the documented
        // replacement for the deprecated `compare_and_swap(.., Release)`.
        let claimed = self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(head.wrapping_add(n), head.wrapping_add(n)),
                Release,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_ok();

        if !claimed {
            return Err(task);
        }

        // Link the claimed tasks through `queue_next`, in queue order,
        // ending with `task`.
        for i in 0..n {
            let j = i + 1;

            let i_idx = i.wrapping_add(head) as usize & MASK;
            let j_idx = j.wrapping_add(head) as usize & MASK;

            let next = if j == n {
                // The last claimed task points at the task being pushed.
                task.header().into()
            } else {
                // safety: the CAS above keeps stealers away from these slots,
                // and this is the only producer.
                self.inner.buffer[j_idx].with(|ptr| unsafe {
                    let value = (*ptr).as_ptr();
                    (*value).header().into()
                })
            };

            // safety: as above, these slots are exclusively ours after the CAS.
            self.inner.buffer[i_idx].with_mut(|ptr| unsafe {
                let ptr = (*ptr).as_ptr();
                (*ptr).header().queue_next.with_mut(|ptr| *ptr = Some(next));
            });
        }

        // safety: the first claimed slot is exclusively ours after the CAS.
        // Reading it out takes ownership of the head of the linked batch.
        let head = self.inner.buffer[head as usize & MASK]
            .with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        inject.push_batch(head, task, BATCH_LEN);

        Ok(())
    }

    /// Pops the task at the front of the local queue, or returns `None` if the
    /// queue is empty.
    pub(super) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: `tail` is only written through `&mut Local`, which we hold.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                return None;
            }

            let next_real = real.wrapping_add(1);

            // With no stealer active (`steal == real`), both halves advance
            // together. Otherwise only `real` moves, preserving the stealer's
            // in-progress marker.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                // Sanity check: `real` must never advance onto the steal marker.
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        // safety: the successful CAS gave this thread ownership of slot `idx`.
        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}
impl<T> Steal<T> {
    /// Returns `true` if the queue currently looks empty (real head == tail).
    pub(super) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals about half of the tasks in this queue into `dst`.
    ///
    /// The last stolen task is returned directly instead of being made visible
    /// in `dst`. Any others are published by advancing `dst`'s tail.
    ///
    /// Returns `None` in three cases:
    /// * nothing was stolen;
    /// * another steal from this queue was already in progress;
    /// * `dst` is more than half full. Occupancy is counted from `dst`'s
    ///   `steal` index, so it includes slots a stealer of `dst` may still be
    ///   reading.
    pub(super) fn steal_into(&self, dst: &mut Local<T>) -> Option<task::Notified<T>> {
        // safety: `dst.inner.tail` is only written through a `&mut Local`,
        // and the caller's is borrowed here.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // `dst` may look empty yet still have slots being read by a stealer
        // of `dst`, so its occupancy is measured from its `steal` index.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 {
            // Not enough room for a half-queue batch. Give up rather than
            // steal a smaller amount.
            return None;
        }

        // Copy tasks into `dst`'s buffer starting at `dst_tail`. They are
        // not yet visible to `dst`'s consumers.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            return None;
        }

        // The last copied task is handed back to the caller.
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: this slot was written by `steal_into2` and lies beyond
        // `dst`'s published tail, so no other thread can read it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // Exactly one task was stolen; there is nothing to publish.
            return Some(ret);
        }

        // Publish the remaining stolen tasks. Pairs with `Acquire` loads of `tail`.
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    /// Claims `n - n / 2` of this queue's `n` available tasks (half, rounded
    /// up) and copies them into `dst`'s buffer starting at position `dst_tail`.
    ///
    /// Returns the number of tasks copied. Returns 0 if the queue was empty or
    /// another steal was already in progress. The copied tasks are not
    /// published in `dst`; the caller does that.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // Another stealer is already active on this queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Take half of the available tasks, rounding up.
            let n = src_tail.wrapping_sub(src_head_real);
            let n = n - n / 2;

            if n == 0 {
                return 0;
            }

            // Advance only `real`. Leaving `steal` behind marks the steal as
            // in progress. That keeps slots `steal..real` from being reused by
            // the owner's `push_back` and makes other stealers back off.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n);

        let (first, _) = unpack(next_packed);

        // Copy the claimed tasks `first..first + n` into `dst`.
        for i in 0..n {
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // safety: these source slots were claimed by the CAS above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // safety: these `dst` slots lie beyond `dst`'s published tail,
            // and we hold `&mut dst`.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Finish the steal: set `steal = real` so the slots become reusable.
        // The owner may have popped in the meantime (advancing `real`), so
        // retry with the latest value. The marker can only be cleared by this
        // stealer, so `steal != real` must still hold on every retry.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}
impl<T> Clone for Steal<T> {
fn clone(&self) -> Steal<T> {
Steal(self.0.clone())
}
}
impl<T> Drop for Local<T> {
    /// Asserts that the local queue was drained before its owner handle is
    /// dropped. The check is skipped while the thread is already panicking,
    /// to avoid a double panic.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }

        // Any popped task is dropped at the end of this statement, before the assertion.
        let drained = self.pop().is_none();
        assert!(drained, "queue not empty");
    }
}
impl<T> Inner<T> {
    /// Returns `true` when the real head position has caught up with the tail.
    fn is_empty(&self) -> bool {
        // `head` is loaded before `tail`: the left operand of `==` is evaluated first.
        unpack(self.head.load(Acquire)).1 == self.tail.load(Acquire)
    }
}
impl<T: 'static> Inject<T> {
    /// Creates an empty, open injection queue.
    pub(super) fn new() -> Inject<T> {
        Inject {
            pointers: Mutex::new(Pointers {
                is_closed: false,
                head: None,
                tail: None,
            }),
            len: AtomicUsize::new(0),
            _p: PhantomData,
        }
    }

    /// Returns `true` if the length counter reads zero.
    ///
    /// This does not take the lock, so the answer may already be stale.
    pub(super) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Closes the queue. After this, `push` and `push_batch` drop their tasks
    /// instead of queueing them.
    ///
    /// Returns `true` if this call closed the queue, `false` if it was already closed.
    pub(super) fn close(&self) -> bool {
        let mut p = self.pointers.lock().unwrap();

        if p.is_closed {
            return false;
        }

        p.is_closed = true;
        true
    }

    /// Returns whether `close` has been called.
    pub(super) fn is_closed(&self) -> bool {
        self.pointers.lock().unwrap().is_closed
    }

    /// Returns the number of queued tasks, read without taking the lock.
    pub(super) fn len(&self) -> usize {
        self.len.load(Acquire)
    }

    /// Pushes a single task onto the back of the queue.
    ///
    /// If the queue is closed, the task is dropped instead.
    pub(super) fn push(&self, task: task::Notified<T>) {
        let mut p = self.pointers.lock().unwrap();

        if p.is_closed {
            // Release the lock before dropping the task, in case dropping it
            // re-enters this queue.
            drop(p);
            drop(task);
            return;
        }

        // safety: `len` is only written while `pointers` is locked.
        let len = unsafe { self.len.unsync_load() };
        let task = task.into_raw();

        // A task being pushed must not already be linked to another task.
        debug_assert!(get_next(task).is_none());

        if let Some(tail) = p.tail {
            set_next(tail, Some(task));
        } else {
            p.head = Some(task);
        }

        p.tail = Some(task);

        self.len.store(len + 1, Release);
    }

    /// Appends a batch of `num` tasks that the caller has already linked via
    /// `queue_next`, running from `batch_head` to `batch_tail`.
    ///
    /// As with `push`, a closed queue does not accept the batch. Every task in
    /// the chain is unlinked and dropped instead, so nothing is enqueued after
    /// `close()`.
    pub(super) fn push_batch(
        &self,
        batch_head: task::Notified<T>,
        batch_tail: task::Notified<T>,
        num: usize,
    ) {
        let batch_head = batch_head.into_raw();
        let batch_tail = batch_tail.into_raw();

        // The chain must be terminated at `batch_tail`.
        debug_assert!(get_next(batch_tail).is_none());

        let mut p = self.pointers.lock().unwrap();

        if p.is_closed {
            // Release the lock before dropping any task, in case dropping one
            // re-enters this queue.
            drop(p);

            let mut curr = Some(batch_head);
            while let Some(raw) = curr {
                // Read the link before the task reference is released.
                curr = get_next(raw);
                set_next(raw, None);

                // safety: every node in the chain is a task reference whose
                // ownership the caller transferred to this queue, exactly as
                // `pop` assumes for linked nodes.
                let task: task::Notified<T> = unsafe { task::Notified::from_raw(raw) };
                drop(task);
            }

            return;
        }

        if let Some(tail) = p.tail {
            set_next(tail, Some(batch_head));
        } else {
            p.head = Some(batch_head);
        }

        p.tail = Some(batch_tail);

        // safety: `len` is only written while `pointers` is locked, so a
        // plain load followed by a store cannot lose an update.
        let len = unsafe { self.len.unsync_load() };

        self.len.store(len + num, Release);
    }

    /// Pops the task at the front of the queue, if any.
    pub(super) fn pop(&self) -> Option<task::Notified<T>> {
        // Fast path: skip the lock when the counter says the queue is empty.
        if self.is_empty() {
            return None;
        }

        let mut p = self.pointers.lock().unwrap();

        // The list may have been emptied by another caller between the
        // lock-free `len` check and acquiring the lock.
        let task = p.head?;

        p.head = get_next(task);

        if p.head.is_none() {
            p.tail = None;
        }

        set_next(task, None);

        // safety: `len` is only written while `pointers` is locked.
        self.len
            .store(unsafe { self.len.unsync_load() } - 1, Release);

        // safety: the node was linked into this queue from a `Notified`
        // reference, and it has now been unlinked.
        Some(unsafe { task::Notified::from_raw(task) })
    }
}
impl<T: 'static> Drop for Inject<T> {
    /// Asserts that the injection queue was drained before it is dropped.
    /// The check is skipped while the thread is already panicking, to avoid a
    /// double panic.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }

        // Any popped task is dropped at the end of this statement, before the assertion.
        let drained = self.pop().is_none();
        assert!(drained, "queue not empty");
    }
}
/// Reads the intrusive `queue_next` link stored in the task header at `header`.
fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> {
    // safety(review): assumes `header` points at a live task header and that
    // no other thread writes its `queue_next` concurrently. Confirm at call sites.
    let hdr = unsafe { header.as_ref() };
    hdr.queue_next.with(|next| unsafe { *next })
}
/// Overwrites the intrusive `queue_next` link stored in the task header at `header`.
fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) {
    // safety(review): assumes `header` points at a live task header and that
    // the caller has exclusive access to its `queue_next`. Confirm at call sites.
    let hdr = unsafe { header.as_ref() };
    hdr.queue_next.with_mut(|slot| unsafe { *slot = val });
}
/// Splits a packed head word into `(steal, real)`.
///
/// The upper 16 bits hold `steal` and the lower 16 bits hold `real`.
fn unpack(n: u32) -> (u16, u16) {
    let steal = (n >> 16) as u16;
    // Casting to `u16` keeps exactly the low 16 bits.
    let real = n as u16;
    (steal, real)
}
/// Packs `steal` into the upper 16 bits and `real` into the lower 16 bits of a head word.
fn pack(steal: u16, real: u16) -> u32 {
    let high = u32::from(steal) << 16;
    high | u32::from(real)
}
/// Checks the invariants the local queue's capacity constants must satisfy.
#[test]
fn test_local_queue_capacity() {
    // Existing bound on the capacity.
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::max_value() as usize);

    // Slot indices are computed as `position & MASK`. That only maps every
    // ring position to a distinct slot when the capacity is a power of two.
    assert!(LOCAL_QUEUE_CAPACITY.is_power_of_two());
    assert_eq!(MASK, LOCAL_QUEUE_CAPACITY - 1);
}