use crate::elision::{have_elision, AtomicElisionExt};
use crate::raw_mutex::{TOKEN_HANDOFF, TOKEN_NORMAL};
use crate::util;
use core::{
cell::Cell,
sync::atomic::{AtomicUsize, Ordering},
};
use instant::Instant;
use lock_api::{RawRwLock as RawRwLock_, RawRwLockUpgrade};
use parking_lot_core::{
self, deadlock, FilterOp, ParkResult, ParkToken, SpinWait, UnparkResult, UnparkToken,
};
use std::time::Duration;
// Layout of the `state` word: the low four bits are flags and all higher
// bits hold the reader count (see `READERS_MASK` / `ONE_READER`).

// Set by `lock_common` before a thread parks on the lock's own address.
const PARKED_BIT: usize = 0b0001;
// Set by `wait_for_readers` before a thread that already owns `WRITER_BIT`
// parks on the second key (lock address + 1).
const WRITER_PARKED_BIT: usize = 0b0010;
// Added together with `ONE_READER` when an upgradable lock is taken.
const UPGRADABLE_BIT: usize = 0b0100;
// Set when a thread takes (or is in the process of taking) the exclusive lock.
const WRITER_BIT: usize = 0b1000;
// Selects every bit above the four flag bits, i.e. the reader count.
const READERS_MASK: usize = !0b1111;
// The amount added to `state` for each reader.
const ONE_READER: usize = 0b10000;
// Park tokens handed to `lock_common`. `wake_parked_threads` adds a woken
// thread's token value to its running state tally, so each token equals the
// state bits that kind of lock contributes.
const TOKEN_SHARED: ParkToken = ParkToken(ONE_READER);
const TOKEN_EXCLUSIVE: ParkToken = ParkToken(WRITER_BIT);
const TOKEN_UPGRADABLE: ParkToken = ParkToken(ONE_READER | UPGRADABLE_BIT);
/// Raw reader-writer lock whose whole state lives in a single atomic word.
pub struct RawRwLock {
// Flag bits in the low four bits, reader count in the remaining bits.
state: AtomicUsize,
}
/// Basic shared/exclusive locking for [`RawRwLock`].
///
/// The locking methods try one atomic operation on `state` first and fall
/// back to the matching `*_slow` routine only when that attempt fails.
unsafe impl lock_api::RawRwLock for RawRwLock {
    /// An unlocked lock: no flags set and no readers.
    const INIT: RawRwLock = RawRwLock {
        state: AtomicUsize::new(0),
    };

    type GuardMarker = crate::GuardMarker;

    /// Takes the exclusive lock, going through the slow path (without a
    /// timeout) when the lock is not completely free.
    #[inline]
    fn lock_exclusive(&self) {
        let fast_path =
            self.state
                .compare_exchange_weak(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed);
        if fast_path.is_err() {
            let acquired = self.lock_exclusive_slow(None);
            debug_assert!(acquired);
        }
        self.deadlock_acquire();
    }

    /// Takes the exclusive lock only if `state` is exactly zero right now.
    #[inline]
    fn try_lock_exclusive(&self) -> bool {
        let acquired = self
            .state
            .compare_exchange(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Releases the exclusive lock.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// exclusive lock.
    #[inline]
    unsafe fn unlock_exclusive(&self) {
        self.deadlock_release();
        // When only WRITER_BIT is set there is nobody to wake up.
        let released = self
            .state
            .compare_exchange(WRITER_BIT, 0, Ordering::Release, Ordering::Relaxed)
            .is_ok();
        if !released {
            self.unlock_exclusive_slow(false);
        }
    }

    /// Takes a (non-recursive) shared lock.
    #[inline]
    fn lock_shared(&self) {
        let fast = self.try_lock_shared_fast(false);
        if !fast {
            let acquired = self.lock_shared_slow(false, None);
            debug_assert!(acquired);
        }
        self.deadlock_acquire();
    }

    /// Attempts to take a (non-recursive) shared lock without parking.
    #[inline]
    fn try_lock_shared(&self) -> bool {
        let acquired = self.try_lock_shared_fast(false) || self.try_lock_shared_slow(false);
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Releases one shared lock.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold a shared
    /// lock.
    #[inline]
    unsafe fn unlock_shared(&self) {
        self.deadlock_release();
        let previous = if have_elision() {
            self.state.elision_fetch_sub_release(ONE_READER)
        } else {
            self.state.fetch_sub(ONE_READER, Ordering::Release)
        };
        // The slow path is needed only when we were the final reader and
        // WRITER_PARKED_BIT was set.
        let relevant_bits = READERS_MASK | WRITER_PARKED_BIT;
        let last_reader_with_parked_writer = ONE_READER | WRITER_PARKED_BIT;
        if previous & relevant_bits == last_reader_with_parked_writer {
            self.unlock_shared_slow();
        }
    }

    /// Reports whether WRITER_BIT is set or the reader count is non-zero.
    #[inline]
    fn is_locked(&self) -> bool {
        let held_bits = WRITER_BIT | READERS_MASK;
        self.state.load(Ordering::Relaxed) & held_bits != 0
    }
}
/// Fair unlocking and "bump" operations for [`RawRwLock`].
unsafe impl lock_api::RawRwLockFair for RawRwLock {
    /// Releases a shared lock; this simply forwards to `unlock_shared`.
    #[inline]
    unsafe fn unlock_shared_fair(&self) {
        self.unlock_shared();
    }

    /// Releases the exclusive lock, forcing the fair protocol in the slow
    /// path.
    #[inline]
    unsafe fn unlock_exclusive_fair(&self) {
        self.deadlock_release();
        let released = self
            .state
            .compare_exchange(WRITER_BIT, 0, Ordering::Release, Ordering::Relaxed)
            .is_ok();
        if !released {
            self.unlock_exclusive_slow(true);
        }
    }

    /// Releases and re-takes the shared lock when the state shows exactly
    /// one reader together with WRITER_BIT.
    #[inline]
    unsafe fn bump_shared(&self) {
        let snapshot = self.state.load(Ordering::Relaxed);
        let sole_reader_with_writer = ONE_READER | WRITER_BIT;
        if snapshot & (READERS_MASK | WRITER_BIT) == sole_reader_with_writer {
            self.bump_shared_slow();
        }
    }

    /// Releases and re-takes the exclusive lock when PARKED_BIT is set.
    #[inline]
    unsafe fn bump_exclusive(&self) {
        let has_parked = self.state.load(Ordering::Relaxed) & PARKED_BIT != 0;
        if has_parked {
            self.bump_exclusive_slow();
        }
    }
}
/// Exclusive-to-shared downgrade for [`RawRwLock`].
unsafe impl lock_api::RawRwLockDowngrade for RawRwLock {
    /// Turns the exclusive lock into a shared lock in one atomic step.
    #[inline]
    unsafe fn downgrade(&self) {
        // Adding this delta to a state that has WRITER_BIT set clears that
        // bit and adds one reader at the same time.
        let delta = ONE_READER - WRITER_BIT;
        let previous = self.state.fetch_add(delta, Ordering::Release);
        let has_parked = previous & PARKED_BIT != 0;
        if has_parked {
            self.downgrade_slow();
        }
    }
}
/// Timed acquisition of shared and exclusive locks for [`RawRwLock`].
///
/// Each method returns `true` when the lock was obtained and `false` when
/// the slow path reported failure.
unsafe impl lock_api::RawRwLockTimed for RawRwLock {
    type Duration = Duration;
    type Instant = Instant;

    /// Shared lock, with the deadline derived via `util::to_deadline`.
    #[inline]
    fn try_lock_shared_for(&self, timeout: Self::Duration) -> bool {
        let acquired = self.try_lock_shared_fast(false)
            || self.lock_shared_slow(false, util::to_deadline(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Shared lock, giving up at the instant `timeout`.
    #[inline]
    fn try_lock_shared_until(&self, timeout: Self::Instant) -> bool {
        let acquired =
            self.try_lock_shared_fast(false) || self.lock_shared_slow(false, Some(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Exclusive lock, with the deadline derived via `util::to_deadline`.
    #[inline]
    fn try_lock_exclusive_for(&self, timeout: Duration) -> bool {
        let fast = self
            .state
            .compare_exchange_weak(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        let acquired = fast || self.lock_exclusive_slow(util::to_deadline(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Exclusive lock, giving up at the instant `timeout`.
    #[inline]
    fn try_lock_exclusive_until(&self, timeout: Instant) -> bool {
        let fast = self
            .state
            .compare_exchange_weak(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        let acquired = fast || self.lock_exclusive_slow(Some(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }
}
/// Recursive shared locking for [`RawRwLock`].
///
/// These differ from the plain shared variants only in passing
/// `recursive = true` to the fast and slow helpers.
unsafe impl lock_api::RawRwLockRecursive for RawRwLock {
    /// Takes a shared lock in recursive mode.
    #[inline]
    fn lock_shared_recursive(&self) {
        let fast = self.try_lock_shared_fast(true);
        if !fast {
            let acquired = self.lock_shared_slow(true, None);
            debug_assert!(acquired);
        }
        self.deadlock_acquire();
    }

    /// Attempts to take a shared lock in recursive mode without parking.
    #[inline]
    fn try_lock_shared_recursive(&self) -> bool {
        let acquired = self.try_lock_shared_fast(true) || self.try_lock_shared_slow(true);
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }
}
/// Timed recursive shared locking for [`RawRwLock`].
unsafe impl lock_api::RawRwLockRecursiveTimed for RawRwLock {
    /// Recursive shared lock, with the deadline derived via
    /// `util::to_deadline`.
    #[inline]
    fn try_lock_shared_recursive_for(&self, timeout: Self::Duration) -> bool {
        let acquired = self.try_lock_shared_fast(true)
            || self.lock_shared_slow(true, util::to_deadline(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }

    /// Recursive shared lock, giving up at the instant `timeout`.
    #[inline]
    fn try_lock_shared_recursive_until(&self, timeout: Self::Instant) -> bool {
        let acquired =
            self.try_lock_shared_fast(true) || self.lock_shared_slow(true, Some(timeout));
        if acquired {
            self.deadlock_acquire();
        }
        acquired
    }
}
/// Upgradable locking for [`RawRwLock`].
///
/// An upgradable lock contributes `ONE_READER | UPGRADABLE_BIT` to `state`.
unsafe impl lock_api::RawRwLockUpgrade for RawRwLock {
    /// Takes the upgradable lock, using the slow path without a timeout when
    /// the fast attempt fails.
    #[inline]
    fn lock_upgradable(&self) {
        if !self.try_lock_upgradable_fast() {
            let result = self.lock_upgradable_slow(None);
            debug_assert!(result);
        }
        self.deadlock_acquire();
    }

    /// Attempts to take the upgradable lock without parking.
    #[inline]
    fn try_lock_upgradable(&self) -> bool {
        let result = self.try_lock_upgradable_fast() || self.try_lock_upgradable_slow();
        if result {
            self.deadlock_acquire();
        }
        result
    }

    /// Releases the upgradable lock.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn unlock_upgradable(&self) {
        self.deadlock_release();
        let state = self.state.load(Ordering::Relaxed);
        // With no parked threads a single CAS attempt may be enough; a
        // failed attempt is retried inside the slow path.
        if state & PARKED_BIT == 0 {
            if self
                .state
                .compare_exchange_weak(
                    state,
                    state - (ONE_READER | UPGRADABLE_BIT),
                    Ordering::Release,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return;
            }
        }
        self.unlock_upgradable_slow(false);
    }

    /// Converts the upgradable lock into the exclusive lock, waiting for the
    /// remaining readers if there are any.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn upgrade(&self) {
        // Acquire (not Relaxed): this operation takes exclusive ownership, so
        // it has to synchronize with the Release decrement performed by
        // readers that have already left the lock.
        let state = self.state.fetch_sub(
            (ONE_READER | UPGRADABLE_BIT) - WRITER_BIT,
            Ordering::Acquire,
        );
        // Anything other than exactly one reader (ourselves) means other
        // readers are still inside.
        if state & READERS_MASK != ONE_READER {
            let result = self.upgrade_slow(None);
            debug_assert!(result);
        }
    }

    /// Attempts the upgrade without waiting for other readers.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn try_upgrade(&self) -> bool {
        // Success ordering is Acquire for the same reason as in `upgrade`.
        if self
            .state
            .compare_exchange_weak(
                ONE_READER | UPGRADABLE_BIT,
                WRITER_BIT,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            true
        } else {
            self.try_upgrade_slow()
        }
    }
}
/// Fair unlocking and bumping of the upgradable lock for [`RawRwLock`].
unsafe impl lock_api::RawRwLockUpgradeFair for RawRwLock {
    /// Releases the upgradable lock using the fair protocol.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn unlock_upgradable_fair(&self) {
        self.deadlock_release();
        let state = self.state.load(Ordering::Relaxed);
        // With no parked threads there is nobody to hand the lock to, so a
        // plain release is enough.
        if state & PARKED_BIT == 0 {
            if self
                .state
                .compare_exchange_weak(
                    state,
                    state - (ONE_READER | UPGRADABLE_BIT),
                    Ordering::Release,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return;
            }
        }
        // Force the fair hand-off, mirroring `unlock_exclusive_fair`; passing
        // `false` here would make this identical to `unlock_upgradable`.
        self.unlock_upgradable_slow(true);
    }

    /// Releases and re-takes the upgradable lock when the state is exactly
    /// "one upgradable reader plus PARKED_BIT".
    #[inline]
    unsafe fn bump_upgradable(&self) {
        if self.state.load(Ordering::Relaxed) == ONE_READER | UPGRADABLE_BIT | PARKED_BIT {
            self.bump_upgradable_slow();
        }
    }
}
/// Downgrades involving the upgradable lock for [`RawRwLock`].
unsafe impl lock_api::RawRwLockUpgradeDowngrade for RawRwLock {
    /// Upgradable -> shared: drops UPGRADABLE_BIT and keeps the reader count.
    #[inline]
    unsafe fn downgrade_upgradable(&self) {
        let previous = self.state.fetch_sub(UPGRADABLE_BIT, Ordering::Relaxed);
        let has_parked = previous & PARKED_BIT != 0;
        if has_parked {
            self.downgrade_slow();
        }
    }

    /// Exclusive -> upgradable, performed with one atomic addition.
    #[inline]
    unsafe fn downgrade_to_upgradable(&self) {
        let delta = (ONE_READER | UPGRADABLE_BIT) - WRITER_BIT;
        let previous = self.state.fetch_add(delta, Ordering::Release);
        let has_parked = previous & PARKED_BIT != 0;
        if has_parked {
            self.downgrade_to_upgradable_slow();
        }
    }
}
/// Timed upgradable locking and timed upgrades for [`RawRwLock`].
unsafe impl lock_api::RawRwLockUpgradeTimed for RawRwLock {
    /// Upgradable lock, giving up at the instant `timeout`.
    #[inline]
    fn try_lock_upgradable_until(&self, timeout: Instant) -> bool {
        let result =
            self.try_lock_upgradable_fast() || self.lock_upgradable_slow(Some(timeout));
        if result {
            self.deadlock_acquire();
        }
        result
    }

    /// Upgradable lock, with the deadline derived via `util::to_deadline`.
    #[inline]
    fn try_lock_upgradable_for(&self, timeout: Duration) -> bool {
        let result = self.try_lock_upgradable_fast()
            || self.lock_upgradable_slow(util::to_deadline(timeout));
        if result {
            self.deadlock_acquire();
        }
        result
    }

    /// Upgrades to the exclusive lock, giving up at the instant `timeout`.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn try_upgrade_until(&self, timeout: Instant) -> bool {
        // Acquire (not Relaxed): when no other reader is left this operation
        // alone grants exclusive ownership, so it must synchronize with the
        // Release decrement of readers that have already left.
        let state = self.state.fetch_sub(
            (ONE_READER | UPGRADABLE_BIT) - WRITER_BIT,
            Ordering::Acquire,
        );
        if state & READERS_MASK == ONE_READER {
            true
        } else {
            self.upgrade_slow(Some(timeout))
        }
    }

    /// Upgrades to the exclusive lock, with the deadline derived via
    /// `util::to_deadline`.
    ///
    /// # Safety
    ///
    /// Per the `lock_api` contract the caller is expected to hold the
    /// upgradable lock.
    #[inline]
    unsafe fn try_upgrade_for(&self, timeout: Duration) -> bool {
        // Acquire for the same reason as in `try_upgrade_until`.
        let state = self.state.fetch_sub(
            (ONE_READER | UPGRADABLE_BIT) - WRITER_BIT,
            Ordering::Acquire,
        );
        if state & READERS_MASK == ONE_READER {
            true
        } else {
            self.upgrade_slow(util::to_deadline(timeout))
        }
    }
}
impl RawRwLock {
/// Makes one attempt to add a reader to `state`.
///
/// Fails when WRITER_BIT is set, unless `recursive` is true and the reader
/// count is non-zero. Also fails if the reader count would overflow or the
/// single compare-exchange does not succeed.
#[inline(always)]
fn try_lock_shared_fast(&self, recursive: bool) -> bool {
    let snapshot = self.state.load(Ordering::Relaxed);

    let writer_present = snapshot & WRITER_BIT != 0;
    if writer_present && (!recursive || snapshot & READERS_MASK == 0) {
        return false;
    }

    // The elision variant is only used on a completely empty lock.
    if have_elision() && snapshot == 0 {
        return self
            .state
            .elision_compare_exchange_acquire(0, ONE_READER)
            .is_ok();
    }

    match snapshot.checked_add(ONE_READER) {
        Some(with_reader) => self
            .state
            .compare_exchange_weak(snapshot, with_reader, Ordering::Acquire, Ordering::Relaxed)
            .is_ok(),
        None => false,
    }
}
/// Retries the shared-lock compare-exchange until it succeeds or a writer
/// blocks it; never parks.
///
/// Panics with "RwLock reader count overflow" if adding a reader overflows.
#[cold]
fn try_lock_shared_slow(&self, recursive: bool) -> bool {
    let mut current = self.state.load(Ordering::Relaxed);
    loop {
        // Same admission rule as `try_lock_shared_fast`.
        let writer_present = current & WRITER_BIT != 0;
        if writer_present && (!recursive || current & READERS_MASK == 0) {
            return false;
        }

        let attempt = if have_elision() && current == 0 {
            self.state.elision_compare_exchange_acquire(0, ONE_READER)
        } else {
            let with_reader = current
                .checked_add(ONE_READER)
                .expect("RwLock reader count overflow");
            self.state.compare_exchange_weak(
                current,
                with_reader,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
        };

        match attempt {
            Ok(_) => return true,
            Err(observed) => current = observed,
        }
    }
}
/// Makes one attempt to take the upgradable lock.
///
/// Fails when WRITER_BIT or UPGRADABLE_BIT is already set, when the reader
/// count would overflow, or when the single compare-exchange fails.
#[inline(always)]
fn try_lock_upgradable_fast(&self) -> bool {
    let snapshot = self.state.load(Ordering::Relaxed);
    let blocked = snapshot & (WRITER_BIT | UPGRADABLE_BIT) != 0;
    if blocked {
        return false;
    }

    match snapshot.checked_add(ONE_READER | UPGRADABLE_BIT) {
        Some(upgradable) => self
            .state
            .compare_exchange_weak(snapshot, upgradable, Ordering::Acquire, Ordering::Relaxed)
            .is_ok(),
        None => false,
    }
}
/// Retries the upgradable-lock compare-exchange until it succeeds or
/// WRITER_BIT / UPGRADABLE_BIT shows up; never parks.
///
/// Panics with "RwLock reader count overflow" if adding a reader overflows.
#[cold]
fn try_lock_upgradable_slow(&self) -> bool {
    let mut current = self.state.load(Ordering::Relaxed);
    while current & (WRITER_BIT | UPGRADABLE_BIT) == 0 {
        let upgradable = current
            .checked_add(ONE_READER | UPGRADABLE_BIT)
            .expect("RwLock reader count overflow");
        let attempt = self.state.compare_exchange_weak(
            current,
            upgradable,
            Ordering::Acquire,
            Ordering::Relaxed,
        );
        match attempt {
            Ok(_) => return true,
            Err(observed) => current = observed,
        }
    }
    false
}
/// Slow path for exclusive locking.
///
/// First obtains WRITER_BIT through `lock_common`, then waits for the
/// remaining readers via `wait_for_readers`. Returns `false` if either step
/// reports failure.
#[cold]
fn lock_exclusive_slow(&self, timeout: Option<Instant>) -> bool {
    let blocking_bits = WRITER_BIT | UPGRADABLE_BIT;

    // Tries to set WRITER_BIT for as long as no writer or upgradable holder
    // is present; `state` is refreshed after every failed compare-exchange.
    let grab_writer_bit = |state: &mut usize| {
        while *state & blocking_bits == 0 {
            let attempt = self.state.compare_exchange_weak(
                *state,
                *state | WRITER_BIT,
                Ordering::Acquire,
                Ordering::Relaxed,
            );
            match attempt {
                Ok(_) => return true,
                Err(observed) => *state = observed,
            }
        }
        false
    };

    self.lock_common(timeout, TOKEN_EXCLUSIVE, grab_writer_bit, blocking_bits)
        && self.wait_for_readers(timeout, 0)
}
/// Slow path for exclusive unlocking: wakes parked threads and stores the
/// resulting state.
///
/// When at least one thread was unparked and fairness applies
/// (`force_fair` or `result.be_fair`), the state computed for the woken
/// threads is stored and they receive `TOKEN_HANDOFF`. Otherwise the lock
/// is fully released (keeping only PARKED_BIT if threads remain parked).
#[cold]
fn unlock_exclusive_slow(&self, force_fair: bool) {
    let on_unparked = |handoff_state: usize, result: UnparkResult| {
        let hand_off = result.unparked_threads != 0 && (force_fair || result.be_fair);
        let parked_flag = if result.have_more_threads {
            PARKED_BIT
        } else {
            0
        };
        if hand_off {
            self.state
                .store(handoff_state | parked_flag, Ordering::Release);
            TOKEN_HANDOFF
        } else {
            self.state.store(parked_flag, Ordering::Release);
            TOKEN_NORMAL
        }
    };
    unsafe {
        self.wake_parked_threads(0, on_unparked);
    }
}
/// Slow path for shared locking; parks through `lock_common` while
/// WRITER_BIT blocks the acquisition.
///
/// Panics with "RwLock reader count overflow" if adding a reader overflows.
#[cold]
fn lock_shared_slow(&self, recursive: bool, timeout: Option<Instant>) -> bool {
    let add_reader = |state: &mut usize| {
        let mut backoff = SpinWait::new();
        loop {
            // Elision attempt on an empty lock; on failure continue below
            // with the value that was observed.
            if have_elision() && *state == 0 {
                match self.state.elision_compare_exchange_acquire(0, ONE_READER) {
                    Ok(_) => return true,
                    Err(observed) => *state = observed,
                }
            }

            // Same admission rule as `try_lock_shared_fast`.
            let writer_present = *state & WRITER_BIT != 0;
            if writer_present && (!recursive || *state & READERS_MASK == 0) {
                return false;
            }

            let with_reader = (*state)
                .checked_add(ONE_READER)
                .expect("RwLock reader count overflow");
            let swapped = self
                .state
                .compare_exchange_weak(*state, with_reader, Ordering::Acquire, Ordering::Relaxed)
                .is_ok();
            if swapped {
                return true;
            }

            // Back off briefly before reloading and retrying.
            backoff.spin_no_yield();
            *state = self.state.load(Ordering::Relaxed);
        }
    };
    self.lock_common(timeout, TOKEN_SHARED, add_reader, WRITER_BIT)
}
/// Wakes one thread parked on the second key (lock address + 1) and clears
/// WRITER_PARKED_BIT from inside the unpark callback.
#[cold]
fn unlock_shared_slow(&self) {
    let clear_writer_parked = |_outcome: UnparkResult| {
        self.state.fetch_and(!WRITER_PARKED_BIT, Ordering::Relaxed);
        TOKEN_NORMAL
    };
    // `wait_for_readers` parks on this same key.
    let writer_key = self as *const _ as usize + 1;
    unsafe {
        parking_lot_core::unpark_one(writer_key, clear_writer_parked);
    }
}
/// Slow path for upgradable locking; parks through `lock_common` while
/// WRITER_BIT or UPGRADABLE_BIT blocks the acquisition.
///
/// Panics with "RwLock reader count overflow" if adding a reader overflows.
#[cold]
fn lock_upgradable_slow(&self, timeout: Option<Instant>) -> bool {
    let blocking_bits = WRITER_BIT | UPGRADABLE_BIT;

    let take_upgradable = |state: &mut usize| {
        let mut backoff = SpinWait::new();
        while *state & blocking_bits == 0 {
            let upgradable = (*state)
                .checked_add(ONE_READER | UPGRADABLE_BIT)
                .expect("RwLock reader count overflow");
            let swapped = self
                .state
                .compare_exchange_weak(*state, upgradable, Ordering::Acquire, Ordering::Relaxed)
                .is_ok();
            if swapped {
                return true;
            }

            // Back off briefly before reloading and retrying.
            backoff.spin_no_yield();
            *state = self.state.load(Ordering::Relaxed);
        }
        false
    };

    self.lock_common(timeout, TOKEN_UPGRADABLE, take_upgradable, blocking_bits)
}
/// Slow path for releasing the upgradable lock.
///
/// While PARKED_BIT is clear this just retries removing
/// `ONE_READER | UPGRADABLE_BIT` from `state`. Once PARKED_BIT is seen it
/// wakes parked threads and updates `state` from the unpark callback.
/// `force_fair` requests the hand-off branch regardless of `result.be_fair`.
#[cold]
fn unlock_upgradable_slow(&self, force_fair: bool) {
// No parked threads: a plain compare-exchange release is sufficient.
let mut state = self.state.load(Ordering::Relaxed);
while state & PARKED_BIT == 0 {
match self.state.compare_exchange_weak(
state,
state - (ONE_READER | UPGRADABLE_BIT),
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return,
Err(x) => state = x,
}
}
// `new_state` is the tally computed by `wake_parked_threads` for the
// threads it selected.
let callback = |new_state, result: UnparkResult| {
let mut state = self.state.load(Ordering::Relaxed);
if force_fair || result.be_fair {
// Hand-off: replace our contribution with that of the woken threads.
// If the addition overflows, the `while let` ends and control falls
// through to the plain release below.
while let Some(mut new_state) =
(state - (ONE_READER | UPGRADABLE_BIT)).checked_add(new_state)
{
if result.have_more_threads {
new_state |= PARKED_BIT;
} else {
new_state &= !PARKED_BIT;
}
match self.state.compare_exchange_weak(
state,
new_state,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return TOKEN_HANDOFF,
Err(x) => state = x,
}
}
}
// Plain release: remove our contribution and refresh PARKED_BIT.
loop {
let mut new_state = state - (ONE_READER | UPGRADABLE_BIT);
if result.have_more_threads {
new_state |= PARKED_BIT;
} else {
new_state &= !PARKED_BIT;
}
match self.state.compare_exchange_weak(
state,
new_state,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return TOKEN_NORMAL,
Err(x) => state = x,
}
}
};
unsafe {
self.wake_parked_threads(0, callback);
}
}
/// Retries the upgradable -> exclusive compare-exchange for as long as the
/// reader count is exactly one (the caller itself); never parks.
///
/// Returns `false` as soon as any other reader count is observed.
#[cold]
fn try_upgrade_slow(&self) -> bool {
    let mut state = self.state.load(Ordering::Relaxed);
    loop {
        if state & READERS_MASK != ONE_READER {
            return false;
        }
        match self.state.compare_exchange_weak(
            state,
            state - (ONE_READER | UPGRADABLE_BIT) + WRITER_BIT,
            // Acquire on success (not Relaxed): this exchange grants
            // exclusive ownership, so it has to synchronize with the Release
            // decrement of readers that have already left the lock.
            Ordering::Acquire,
            Ordering::Relaxed,
        ) {
            Ok(_) => return true,
            Err(x) => state = x,
        }
    }
}
/// Slow path for upgrades: waits for the remaining readers, telling
/// `wait_for_readers` to restore the upgradable hold if the wait times out.
#[cold]
fn upgrade_slow(&self, timeout: Option<Instant>) -> bool {
    let upgradable_hold = ONE_READER | UPGRADABLE_BIT;
    self.wait_for_readers(timeout, upgradable_hold)
}
/// Wakes parked threads after a downgrade to a shared lock, starting the
/// wake-up tally at `ONE_READER`, and clears PARKED_BIT if the queue ends
/// up empty.
#[cold]
fn downgrade_slow(&self) {
    let after_wake = |_, outcome: UnparkResult| {
        let queue_empty = !outcome.have_more_threads;
        if queue_empty {
            self.state.fetch_and(!PARKED_BIT, Ordering::Relaxed);
        }
        TOKEN_NORMAL
    };
    unsafe {
        self.wake_parked_threads(ONE_READER, after_wake);
    }
}
/// Wakes parked threads after a downgrade to the upgradable lock, starting
/// the wake-up tally at `ONE_READER | UPGRADABLE_BIT`, and clears
/// PARKED_BIT if the queue ends up empty.
#[cold]
fn downgrade_to_upgradable_slow(&self) {
    let after_wake = |_, outcome: UnparkResult| {
        let queue_empty = !outcome.have_more_threads;
        if queue_empty {
            self.state.fetch_and(!PARKED_BIT, Ordering::Relaxed);
        }
        TOKEN_NORMAL
    };
    let held = ONE_READER | UPGRADABLE_BIT;
    unsafe {
        self.wake_parked_threads(held, after_wake);
    }
}
/// Slow path of `bump_shared`: releases the shared lock and immediately
/// takes it again.
#[cold]
unsafe fn bump_shared_slow(&self) {
self.unlock_shared();
self.lock_shared();
}
/// Slow path of `bump_exclusive`: releases the exclusive lock with the fair
/// protocol forced, then takes it again.
#[cold]
fn bump_exclusive_slow(&self) {
// `unlock_exclusive_slow` does not touch the deadlock bookkeeping, so
// release it here; `lock_exclusive` re-registers it.
self.deadlock_release();
self.unlock_exclusive_slow(true);
self.lock_exclusive();
}
/// Slow path of `bump_upgradable`: releases the upgradable lock with the
/// fair protocol forced, then takes it again.
#[cold]
fn bump_upgradable_slow(&self) {
// `unlock_upgradable_slow` does not touch the deadlock bookkeeping, so
// release it here; `lock_upgradable` re-registers it.
self.deadlock_release();
self.unlock_upgradable_slow(true);
self.lock_upgradable();
}
/// Unparks threads queued on the lock's address.
///
/// `new_state` seeds a running tally of the bits the lock will hold; the
/// token of every thread chosen for wake-up is added to it. Selection
/// stops once the tally contains WRITER_BIT, and writer/upgradable tokens
/// are skipped while the tally already contains UPGRADABLE_BIT.
/// `callback` receives the final tally together with the unpark result.
///
/// # Safety
///
/// Forwards to `parking_lot_core::unpark_filter`, so `callback` has to meet
/// that function's requirements for unpark callbacks.
#[inline]
unsafe fn wake_parked_threads(
    &self,
    new_state: usize,
    callback: impl FnOnce(usize, UnparkResult) -> UnparkToken,
) {
    let tally = Cell::new(new_state);
    let select = |ParkToken(token)| {
        let current = tally.get();
        if current & WRITER_BIT != 0 {
            FilterOp::Stop
        } else if current & UPGRADABLE_BIT != 0 && token & (UPGRADABLE_BIT | WRITER_BIT) != 0 {
            FilterOp::Skip
        } else {
            tally.set(current + token);
            FilterOp::Unpark
        }
    };
    let finish = |result| callback(tally.get(), result);
    let key = self as *const _ as usize;
    parking_lot_core::unpark_filter(key, select, finish);
}
/// Waits until the reader count in `state` drops to zero.
///
/// Called once the current thread owns WRITER_BIT. `prev_value` is what the
/// thread contributed to `state` before taking WRITER_BIT: `0` for
/// exclusive lockers and `ONE_READER | UPGRADABLE_BIT` for upgraders.
///
/// Returns `true` when no readers remain. On timeout WRITER_BIT and
/// WRITER_PARKED_BIT are removed, `prev_value` is restored, parked threads
/// are woken if PARKED_BIT was set, and `false` is returned.
#[inline]
fn wait_for_readers(&self, timeout: Option<Instant>, prev_value: usize) -> bool {
    let mut spinwait = SpinWait::new();
    // Every load that can observe "no readers left" uses Acquire so that it
    // synchronizes with the Release decrement in `unlock_shared`; the loop
    // exit is the point where exclusive access begins.
    let mut state = self.state.load(Ordering::Acquire);
    while state & READERS_MASK != 0 {
        // Spin for a while before resorting to parking.
        if spinwait.spin() {
            state = self.state.load(Ordering::Acquire);
            continue;
        }

        // Announce that a writer is about to park so that the last reader
        // takes `unlock_shared_slow`.
        if state & WRITER_PARKED_BIT == 0 {
            // The failure value feeds the loop condition, hence Acquire; the
            // success ordering matches it because older compilers reject a
            // failure ordering stronger than the success ordering.
            if let Err(x) = self.state.compare_exchange_weak(
                state,
                state | WRITER_PARKED_BIT,
                Ordering::Acquire,
                Ordering::Acquire,
            ) {
                state = x;
                continue;
            }
        }

        // Park on the second key (lock address + 1), the one that
        // `unlock_shared_slow` unparks.
        let addr = self as *const _ as usize + 1;
        let validate = || {
            let state = self.state.load(Ordering::Relaxed);
            state & READERS_MASK != 0 && state & WRITER_PARKED_BIT != 0
        };
        let before_sleep = || {};
        let timed_out = |_, _| {};
        let park_result = unsafe {
            parking_lot_core::park(
                addr,
                validate,
                before_sleep,
                timed_out,
                TOKEN_EXCLUSIVE,
                timeout,
            )
        };
        match park_result {
            // Always re-check the state after waking up.
            ParkResult::Unparked(_) | ParkResult::Invalid => {
                state = self.state.load(Ordering::Acquire);
                continue;
            }
            ParkResult::TimedOut => {
                // Give up WRITER_BIT / WRITER_PARKED_BIT and put back what
                // this thread held before.
                let state = self.state.fetch_add(
                    prev_value.wrapping_sub(WRITER_BIT | WRITER_PARKED_BIT),
                    Ordering::Relaxed,
                );
                if state & PARKED_BIT != 0 {
                    let callback = |_, result: UnparkResult| {
                        if !result.have_more_threads {
                            self.state.fetch_and(!PARKED_BIT, Ordering::Relaxed);
                        }
                        TOKEN_NORMAL
                    };
                    // Seed the wake-up with what this thread still holds.
                    // A timed-out exclusive locker holds nothing, so parked
                    // writers and upgradable lockers must be eligible;
                    // seeding with an upgradable hold would skip them all,
                    // and `unlock_shared` never wakes this queue.
                    unsafe {
                        self.wake_parked_threads(prev_value, callback);
                    }
                }
                return false;
            }
        }
    }
    true
}
/// Shared acquisition loop: try, spin, then park on the lock's address.
///
/// * `timeout` - optional deadline forwarded to `parking_lot_core::park`.
/// * `token` - park token under which this thread is queued.
/// * `try_lock` - called with the most recently observed state; returns
///   `true` once the lock has been taken and may update the state it was
///   given.
/// * `validate_flags` - parking only goes ahead while at least one of these
///   bits (and PARKED_BIT) is still set.
///
/// Returns `true` when `try_lock` succeeds or the lock is handed over with
/// `TOKEN_HANDOFF`, and `false` when parking times out.
#[inline]
fn lock_common(
    &self,
    timeout: Option<Instant>,
    token: ParkToken,
    mut try_lock: impl FnMut(&mut usize) -> bool,
    validate_flags: usize,
) -> bool {
    let mut spinwait = SpinWait::new();
    let mut state = self.state.load(Ordering::Relaxed);
    loop {
        if try_lock(&mut state) {
            return true;
        }

        // Spin only while nobody is parked yet.
        let nobody_parked = state & (PARKED_BIT | WRITER_PARKED_BIT) == 0;
        if nobody_parked && spinwait.spin() {
            state = self.state.load(Ordering::Relaxed);
            continue;
        }

        // PARKED_BIT has to be set before this thread parks.
        if state & PARKED_BIT == 0 {
            let marked = self.state.compare_exchange_weak(
                state,
                state | PARKED_BIT,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            if let Err(observed) = marked {
                state = observed;
                continue;
            }
        }

        let key = self as *const _ as usize;
        let still_blocked = || {
            let current = self.state.load(Ordering::Relaxed);
            current & PARKED_BIT != 0 && current & validate_flags != 0
        };
        let before_sleep = || {};
        let on_timeout = |_, was_last_thread| {
            // The last thread leaving the queue clears PARKED_BIT.
            if was_last_thread {
                self.state.fetch_and(!PARKED_BIT, Ordering::Relaxed);
            }
        };
        let outcome = unsafe {
            parking_lot_core::park(key, still_blocked, before_sleep, on_timeout, token, timeout)
        };
        match outcome {
            ParkResult::Unparked(TOKEN_HANDOFF) => return true,
            ParkResult::TimedOut => return false,
            // Woken without a hand-off, or parking was refused: try again.
            ParkResult::Unparked(_) | ParkResult::Invalid => {}
        }

        spinwait.reset();
        state = self.state.load(Ordering::Relaxed);
    }
}
/// Registers both parking keys of this lock (its address and address + 1)
/// with the `deadlock` module as acquired.
#[inline]
fn deadlock_acquire(&self) {
    let key = self as *const _ as usize;
    unsafe {
        deadlock::acquire_resource(key);
        deadlock::acquire_resource(key + 1);
    }
}
/// Registers both parking keys of this lock (its address and address + 1)
/// with the `deadlock` module as released.
#[inline]
fn deadlock_release(&self) {
    let key = self as *const _ as usize;
    unsafe {
        deadlock::release_resource(key);
        deadlock::release_resource(key + 1);
    }
}
}