use crate::metrics::*;
use crossbeam::channel::{unbounded, Receiver, Sender, TryRecvError};
use crossbeam::queue::SegQueue as Queue;
use futures::select;
use futures::task::AtomicWaker;
use futures::{future::poll_fn, FutureExt, Sink, SinkExt, StreamExt};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use tikv_util::{debug, info, warn};
use tokio::sync::mpsc::{
channel as async_channel, Receiver as AsyncReceiver, Sender as AsyncSender,
};
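/// Errors returned to producers that send events through a `RateLimiter`.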
#[derive(Debug, PartialEq, Eq)]
pub enum RateLimiterError {
    /// The drainer has exited, so events can no longer be consumed.
    DisconnectedError,
    /// The queue length reached `close_sink_threshold`; carries the observed queue size.
    CongestedError(usize),
    /// The downstream for the region with this id has been closed.
    DownstreamClosed(u64),
}
#[derive(Debug, PartialEq, Eq)]
pub enum DrainerError<RpcError> {
    /// The rate limiter closed the sink because the queue exceeded `close_sink_threshold`.
    RateLimitExceededError,
    /// The underlying RPC sink returned an error.
    RpcSinkError(RpcError),
}
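/// The sending half of the pair created by `new_pair`. It is cheaply cloneable
/// and shared by all producers. Sends are rejected or blocked depending on how
/// congested the underlying queue is.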
pub struct RateLimiter<E> {
sink: Sender<E>,
close_tx: AsyncSender<()>,
state: Arc<State>,
per_downstream_state: Option<Arc<PerDownstreamState>>,
}
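/// The receiving half of the pair created by `new_pair`. `drain` forwards
/// events from the shared queue into an RPC sink until an error occurs or all
/// `RateLimiter`s are dropped.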
pub struct Drainer<E> {
receiver: Receiver<E>,
close_rx: Option<AsyncReceiver<()>>,
state: Arc<State>,
}
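/// State shared between the rate limiters and the drainer.
///
/// - `block_scan_threshold`: queue length at which scan senders start blocking.
/// - `close_sink_threshold`: queue length at which the sink is force-closed.
/// - `wait_queue`: wakers of blocked scan senders. Senders push under the read
///   lock; the drainer pops under the write lock, which makes the block/wake-up
///   handshake race-free.
/// - `recv_task`: the drainer's waker, registered while the queue is empty.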
struct State {
block_scan_threshold: usize,
close_sink_threshold: usize,
is_sink_closed: AtomicBool,
ref_count: AtomicUsize,
wait_queue: RwLock<Queue<Arc<Mutex<Option<Waker>>>>>,
recv_task: AtomicWaker,
force_flush_flag: AtomicBool,
#[cfg(test)]
blocked_sender_count: AtomicUsize,
}
struct PerDownstreamState {
region_id: u64,
is_downstream_closed: AtomicBool,
}
impl State {
    /// Registers the drainer's waker; called when the queue is empty.
    #[inline]
    fn yield_drainer(&self, cx: &mut Context<'_>) {
        self.recv_task.register(cx.waker());
    }
    /// Clears the waker registered by `yield_drainer`.
    #[inline]
    fn unyield_drainer(&self) {
        let _ = self.recv_task.take();
    }
    /// Wakes the drainer if it is parked waiting for events.
    #[inline]
    fn wake_up_drainer(&self) {
        self.recv_task.wake();
    }
    /// Wakes one blocked sender, if any. The write lock excludes senders,
    /// which push to `wait_queue` under the read lock, so a concurrently
    /// registering waker cannot be missed. Entries whose waker has already
    /// been consumed are skipped.
    fn wake_up_one_sender(&self) {
        let queue = self.wait_queue.write().unwrap();
        while let Some(waker) = queue.pop() {
            let mut waker = waker.lock().unwrap();
            if let Some(waker) = waker.take() {
                waker.wake();
                return;
            }
        }
    }
    /// Wakes all blocked senders; used when the sink is closing or has failed.
    fn wake_up_all_senders(&self) {
        let queue = self.wait_queue.write().unwrap();
        while let Some(waker) = queue.pop() {
            let mut waker = waker.lock().unwrap();
            if let Some(waker) = waker.take() {
                waker.wake();
            }
        }
    }
}
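/// Creates a connected (`RateLimiter`, `Drainer`) pair backed by an unbounded
/// queue. Scan traffic is bounded by blocking senders at
/// `block_scan_threshold`; realtime traffic is bounded by closing the sink at
/// `close_sink_threshold`.
///
/// A minimal usage sketch (`rpc_sink` and `write_flags` are placeholders for
/// illustration, not part of this module):
///
/// ```ignore
/// let (rate_limiter, drainer) = new_pair::<u64>(1024, 4096);
/// tokio::spawn(drainer.drain(rpc_sink, write_flags));
/// rate_limiter.send_realtime_event(42)?;
/// ```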
pub fn new_pair<E>(
block_scan_threshold: usize,
close_sink_threshold: usize,
) -> (RateLimiter<E>, Drainer<E>) {
let (sender, receiver) = unbounded::<E>();
let state = Arc::new(State {
is_sink_closed: AtomicBool::new(false),
block_scan_threshold,
close_sink_threshold,
ref_count: AtomicUsize::new(0),
wait_queue: RwLock::new(Queue::new()),
recv_task: AtomicWaker::new(),
force_flush_flag: AtomicBool::new(false),
#[cfg(test)]
blocked_sender_count: AtomicUsize::new(0),
});
let (close_tx, close_rx) = async_channel::<()>(1);
let rate_limiter = RateLimiter::new(sender, state.clone(), close_tx);
let drainer = Drainer::new(receiver, state, close_rx);
(rate_limiter, drainer)
}
impl<E> RateLimiter<E> {
fn new(sink: Sender<E>, state: Arc<State>, close_tx: AsyncSender<()>) -> RateLimiter<E> {
state.ref_count.fetch_add(1, Ordering::SeqCst);
RateLimiter {
sink,
close_tx,
state,
per_downstream_state: None,
}
}
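    /// Sends a realtime (non-scan) event. Never blocks: if the queue length
    /// has reached `close_sink_threshold`, the whole sink is marked closed,
    /// the drainer is told to exit, and `CongestedError` is returned.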
pub fn send_realtime_event(&self, event: E) -> Result<(), RateLimiterError> {
if self.state.is_sink_closed.load(Ordering::SeqCst) {
return Err(RateLimiterError::DisconnectedError);
}
let queue_size = self.sink.len();
debug!("cdc send_realtime_event"; "queue_size" => queue_size);
CDC_SINK_QUEUE_SIZE_HISTOGRAM.observe(queue_size as f64);
if queue_size >= self.state.close_sink_threshold {
warn!("cdc send_realtime_event queue length reached threshold"; "queue_size" => queue_size);
self.state.is_sink_closed.store(true, Ordering::SeqCst);
let _ = self.close_tx.clone().try_send(());
return Err(RateLimiterError::CongestedError(queue_size));
}
self.sink.try_send(event).map_err(|e| {
warn!("cdc send_realtime_event error"; "err" => ?e);
self.state.is_sink_closed.store(true, Ordering::SeqCst);
RateLimiterError::DisconnectedError
})?;
self.state.wake_up_drainer();
Ok(())
}
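    /// Sends an event produced by an incremental scan. Unlike
    /// `send_realtime_event`, this asynchronously blocks (via `BlockSender`)
    /// while the queue length is at or above `block_scan_threshold`, applying
    /// backpressure to the scanning task instead of closing the sink.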
pub async fn send_scan_event(&self, event: E) -> Result<(), RateLimiterError> {
let sink_clone = self.sink.clone();
let state_clone = self.state.clone();
let threshold = self.state.block_scan_threshold;
let timer = CDC_SCAN_BLOCK_DURATION_HISTOGRAM.start_coarse_timer();
BlockSender::block_sender(self.state.as_ref(), move || {
sink_clone.len() >= threshold && !state_clone.is_sink_closed.load(Ordering::SeqCst)
})
.await;
if self.state.is_sink_closed.load(Ordering::SeqCst) {
return Err(RateLimiterError::DisconnectedError);
}
timer.observe_duration();
if let Some(per_downstream_state) = self.per_downstream_state.as_ref() {
if per_downstream_state
.is_downstream_closed
.load(Ordering::SeqCst)
{
return Err(RateLimiterError::DownstreamClosed(
per_downstream_state.region_id,
));
}
}
        match self.sink.try_send(event) {
            Ok(_) => {
                self.state.wake_up_drainer();
                Ok(())
            }
            Err(_) => Err(RateLimiterError::DisconnectedError),
        }
}
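    /// Requests that the drainer flush the RPC sink regardless of batch size.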
pub fn start_flush(&self) {
self.state.force_flush_flag.store(true, Ordering::SeqCst);
self.state.wake_up_drainer();
}
pub fn with_region_id(mut self, region_id: u64) -> RateLimiter<E> {
self.per_downstream_state = Some(Arc::new(PerDownstreamState {
is_downstream_closed: AtomicBool::new(false),
region_id,
}));
self
}
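    /// Marks the downstream associated with this rate limiter as closed (so
    /// that pending scan senders fail with `DownstreamClosed`) and delivers a
    /// final `event`, which is expected to describe the error to the client.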
    pub fn close_with_error(&self, event: E) -> Result<(), RateLimiterError> {
        if let Some(per_downstream_state) = self.per_downstream_state.as_ref() {
            per_downstream_state
                .is_downstream_closed
                .store(true, Ordering::SeqCst);
        }
        self.send_realtime_event(event)
    }
#[cfg(test)]
fn inject_instant_drainer_exit(&self) {
let _ = self.close_tx.clone().try_send(());
self.state.wake_up_all_senders();
self.state.recv_task.wake();
}
}
impl<E> Clone for RateLimiter<E> {
fn clone(&self) -> Self {
self.state.ref_count.fetch_add(1, Ordering::SeqCst);
RateLimiter {
sink: self.sink.clone(),
close_tx: self.close_tx.clone(),
state: self.state.clone(),
per_downstream_state: self.per_downstream_state.clone(),
}
}
}
impl<E> Drop for RateLimiter<E> {
fn drop(&mut self) {
let prev = self.state.ref_count.fetch_sub(1, Ordering::SeqCst);
if prev == 1 {
self.state.wake_up_drainer();
}
}
}
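/// A future that blocks a scan sender while `cond` holds. The waker is
/// registered under the `wait_queue` read lock with a re-check of `cond`,
/// which pairs with the drainer taking the write lock in
/// `wake_up_one_sender`/`wake_up_all_senders` to rule out lost wake-ups.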
struct BlockSender<'a, Cond>
where
Cond: Fn() -> bool + Unpin + 'a,
{
state: &'a State,
cond: Cond,
waker: Option<Arc<Mutex<Option<Waker>>>>,
}
impl<'a, Cond> BlockSender<'a, Cond>
where
Cond: Fn() -> bool + Unpin + 'a,
{
#[inline]
fn block_sender(state: &'a State, cond: Cond) -> Self {
Self {
state,
cond,
waker: None,
}
}
#[inline]
fn unblock(&mut self) {
if let Some(waker_arc) = self.waker.take() {
waker_arc.lock().unwrap().take();
#[cfg(test)]
{
let prev_count = self
.state
.blocked_sender_count
.fetch_sub(1, Ordering::SeqCst);
debug_assert!(prev_count > 0, "prev_count = {}", prev_count);
}
}
}
}
impl<'a, Cond> Future for BlockSender<'a, Cond>
where
Cond: Fn() -> bool + Unpin + 'a,
{
type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: no lock is taken when the sender does not need to block.
        if !(self.cond)() || self.state.is_sink_closed.load(Ordering::SeqCst) {
            let mut_self = self.get_mut();
            mut_self.unblock();
            Poll::Ready(())
        } else {
            // Slow path: re-check the condition under the read lock. The
            // drainer wakes senders under the write lock, so a wake-up cannot
            // slip in between this check and the waker registration below.
            let queue = self.state.wait_queue.read().unwrap();
            if !(self.cond)() || self.state.is_sink_closed.load(Ordering::SeqCst) {
                self.get_mut().unblock();
                Poll::Ready(())
            } else {
                if let Some(ref waker_arc) = self.waker {
                    // Polled again while still blocked: refill the waker slot
                    // if a wake-up has consumed it, and re-enqueue the entry.
                    let mut waker_slot = waker_arc.lock().unwrap();
                    if waker_slot.is_none() {
                        *waker_slot = Some(cx.waker().clone());
                        queue.push(waker_arc.clone());
                    }
                    return Poll::Pending;
                }
                // First blocking poll: register a fresh waker in the queue.
                let waker_arc = Arc::new(Mutex::new(Some(cx.waker().clone())));
                queue.push(waker_arc.clone());
                let mut_self = self.get_mut();
                mut_self.waker = Some(waker_arc);
                #[cfg(test)]
                {
                    mut_self
                        .state
                        .blocked_sender_count
                        .fetch_add(1, Ordering::SeqCst);
                }
                Poll::Pending
            }
        }
    }
}
impl<E> Drainer<E> {
fn new(receiver: Receiver<E>, state: Arc<State>, close_rx: AsyncReceiver<()>) -> Drainer<E> {
Drainer {
receiver,
close_rx: Some(close_rx),
state,
}
}
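    /// Forwards events to `rpc_sink` until either the sink fails, the rate
    /// limiter requests a close (congestion), or all senders are dropped.
    ///
    /// Each iteration first waits for the sink to become ready, then pulls one
    /// event with `DrainOne`; both stages race against `close_rx` so that a
    /// congestion-triggered close is honored promptly. Events are batched and
    /// the sink is flushed every 128 events or 200 ms, whichever comes first.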
pub async fn drain<F: Copy, Error, S: Sink<(E, F), Error = Error> + Unpin>(
mut self,
mut rpc_sink: S,
flag: F,
) -> Result<(), DrainerError<Error>> {
let mut close_rx = self.close_rx.take().unwrap().fuse();
let mut unflushed_size: usize = 0;
let mut last_flushed_time = std::time::Instant::now();
loop {
let mut sink_ready = poll_fn(|cx| rpc_sink.poll_ready_unpin(cx)).fuse();
select! {
ready = sink_ready => {
ready.map_err(|err| {
self.state.wake_up_all_senders();
DrainerError::RpcSinkError(err)
})?;
},
item = close_rx.next() => {
if item.is_some() {
self.state.wake_up_all_senders();
return Err(DrainerError::RateLimitExceededError);
}
return Ok(())
},
}
let mut drain_one = DrainOne::wrap(&self.receiver, self.state.as_ref()).fuse();
select! {
next_event = drain_one => {
match next_event {
DrainOneResult::Value(v) => {
self.state.wake_up_one_sender();
rpc_sink.start_send_unpin((v, flag))
.map_err(|err| {
self.state.wake_up_all_senders();
DrainerError::RpcSinkError(err)
})?;
unflushed_size += 1;
if unflushed_size >= 128
|| std::time::Instant::now().duration_since(last_flushed_time) > Duration::from_millis(200) {
rpc_sink.flush().await.map_err(|err| {
self.state.wake_up_all_senders();
DrainerError::RpcSinkError(err)
})?;
unflushed_size = 0;
last_flushed_time = std::time::Instant::now();
}
},
DrainOneResult::FlushRequest => {
rpc_sink.flush().await.map_err(|err| {
self.state.wake_up_all_senders();
DrainerError::RpcSinkError(err)
})?;
unflushed_size = 0;
last_flushed_time = std::time::Instant::now();
},
DrainOneResult::Disconnected => {
info!("cdc rate_limiter closing");
rpc_sink.flush().await.map_err(|err| {
DrainerError::RpcSinkError(err)
})?;
return Ok(())
},
}
},
item = close_rx.next() => {
if item.is_some() {
self.state.wake_up_all_senders();
return Err(DrainerError::RateLimitExceededError);
}
return Ok(())
},
}
}
}
}
impl<E> Drop for Drainer<E> {
fn drop(&mut self) {
self.state.is_sink_closed.store(true, Ordering::SeqCst);
self.state.wake_up_all_senders();
}
}
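/// A future resolving to the drainer's next action: a value to forward, a
/// flush request, or disconnection (all senders dropped).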
struct DrainOne<'a, E> {
receiver: &'a Receiver<E>,
state: &'a State,
}
enum DrainOneResult<E> {
Value(E),
FlushRequest,
Disconnected,
}
impl<'a, E> DrainOne<'a, E> {
#[inline]
fn wrap(receiver: &'a Receiver<E>, state: &'a State) -> Self {
Self { receiver, state }
}
}
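// The poll below uses a register-then-recheck pattern: the drainer's waker is
// registered via `yield_drainer` before the queue and the ref count are
// re-checked, so a send or a sender drop racing with the empty check still
// results in a wake-up.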
impl<'a, E> Future for DrainOne<'a, E> {
type Output = DrainOneResult<E>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // A pending force-flush request takes priority over more events.
            if self.state.force_flush_flag.swap(false, Ordering::SeqCst) {
                return Poll::Ready(DrainOneResult::FlushRequest);
            }
            // The queue is drained and every rate limiter has been dropped.
            if self.receiver.is_empty() && self.state.ref_count.load(Ordering::SeqCst) == 0 {
                return Poll::Ready(DrainOneResult::Disconnected);
            }
            return match self.receiver.try_recv() {
                Ok(v) => Poll::Ready(DrainOneResult::Value(v)),
                Err(TryRecvError::Empty) => {
                    // Register the waker first, then re-check, so that a send
                    // (or the last sender dropping) racing with the emptiness
                    // check above cannot lead to a lost wake-up.
                    self.state.yield_drainer(cx);
                    if self.state.ref_count.load(Ordering::SeqCst) == 0 {
                        self.state.unyield_drainer();
                        return Poll::Ready(DrainOneResult::Disconnected);
                    }
                    if !self.receiver.is_empty() {
                        self.state.unyield_drainer();
                        continue;
                    }
                    Poll::Pending
                }
                Err(TryRecvError::Disconnected) => Poll::Ready(DrainOneResult::Disconnected),
            };
        }
    }
}
#[cfg(test)]
pub mod testing_util {
use super::*;
use crate::service::{CdcEvent, EventBatcherSink};
use futures::channel::mpsc::Receiver;
use futures::stream::StreamExt;
use kvproto::cdcpb::ChangeDataEvent;
use std::cell::RefCell;
use tokio::time::timeout;
use tokio::{runtime::Runtime, time::Elapsed};
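    /// A harness that wires a `RateLimiter`/`Drainer` pair to an
    /// `EventBatcherSink` on a dedicated runtime, with a background task that
    /// forces a flush every 200 microseconds.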
pub struct TestingHarness {
rx: Option<Receiver<(ChangeDataEvent, grpcio::WriteFlags)>>,
rate_limiter: RateLimiter<CdcEvent>,
runtime: RefCell<Runtime>,
}
impl TestingHarness {
pub fn new() -> Self {
let mut builder = tokio::runtime::Builder::new();
let runtime = builder
.threaded_scheduler()
.core_threads(4)
.enable_all()
.build()
.unwrap();
let (tx, rx) = futures::channel::mpsc::channel(16);
let (rate_limiter, drainer) = new_pair::<CdcEvent>(1024, 1024);
let batched_sink = EventBatcherSink::new(tx);
runtime.spawn(async move {
let _ = drainer
.drain(batched_sink, grpcio::WriteFlags::default())
.await;
});
let rate_limiter_clone = rate_limiter.clone();
runtime.spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_micros(200));
loop {
interval.tick().await;
rate_limiter_clone.start_flush();
}
});
Self {
rx: Some(rx),
rate_limiter,
runtime: RefCell::new(runtime),
}
}
pub fn get_rx(&mut self) -> Receiver<(ChangeDataEvent, grpcio::WriteFlags)> {
self.rx.take().unwrap()
}
pub fn get_rate_limiter(&self) -> RateLimiter<CdcEvent> {
self.rate_limiter.clone()
}
pub fn block_on<O, F: Future<Output = O> + Send>(&self, f: F) -> O {
self.runtime.borrow_mut().block_on(f)
}
pub fn recv_timeout(
&mut self,
duration: std::time::Duration,
) -> Result<ChangeDataEvent, Elapsed> {
let mut rx = self.rx.take().unwrap();
let fut = self
.runtime
.borrow_mut()
.enter(|| timeout(duration, rx.next()));
let res = self.runtime.borrow_mut().block_on(fut);
self.rx = Some(rx);
res.map(|res| {
let (ev, _) = res.unwrap();
ev
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
type MockCdcEvent = u64;
type MockWriteFlag = ();
#[derive(Debug, PartialEq, Eq)]
enum MockRpcError {
SinkClosed,
InjectedRpcError,
}
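    /// A single-slot mock RPC sink: `start_send` stores at most one event and
    /// `poll_ready` stays `Pending` until `recv` takes it, so tests can drive
    /// sink backpressure deterministically.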
#[derive(Clone)]
struct MockRpcSink {
value: Arc<Mutex<Option<MockCdcEvent>>>,
send_waker: Arc<AtomicWaker>,
recv_waker: Arc<AtomicWaker>,
injected_send_error: Arc<Mutex<Option<MockRpcError>>>,
sink_closed: Arc<AtomicBool>,
}
struct MockRpcSinkBlockRecv<'a, Cond>
where
Cond: Fn() -> bool + 'a,
{
sink: &'a MockRpcSink,
cond: Cond,
}
impl<'a, Cond> MockRpcSinkBlockRecv<'a, Cond>
where
Cond: Fn() -> bool + 'a,
{
fn new(sink: &'a MockRpcSink, cond: Cond) -> Self {
Self { sink, cond }
}
}
impl<'a, Cond> Future for MockRpcSinkBlockRecv<'a, Cond>
where
Cond: Fn() -> bool + 'a,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if !(self.cond)() {
Poll::Ready(())
} else {
self.sink.recv_waker.register(cx.waker());
if !(self.cond)() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
}
impl MockRpcSink {
fn new() -> MockRpcSink {
MockRpcSink {
value: Arc::new(Mutex::new(None)),
send_waker: Arc::new(AtomicWaker::new()),
recv_waker: Arc::new(AtomicWaker::new()),
injected_send_error: Arc::new(Mutex::new(None)),
sink_closed: Arc::new(AtomicBool::new(false)),
}
}
async fn recv(&self) -> Option<MockCdcEvent> {
let value_clone = self.value.clone();
MockRpcSinkBlockRecv::new(self, move || {
value_clone.lock().unwrap().is_none() && !self.sink_closed.load(Ordering::SeqCst)
})
.await;
let ret = self.value.lock().unwrap().take();
self.send_waker.wake();
ret
}
fn inject_send_error(&self, err: MockRpcError) {
*self.injected_send_error.lock().unwrap() = Some(err);
self.send_waker.wake();
}
}
impl Sink<(MockCdcEvent, MockWriteFlag)> for MockRpcSink {
type Error = MockRpcError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.sink_closed.load(Ordering::SeqCst) {
return Poll::Ready(Err(MockRpcError::SinkClosed));
}
if let Some(err) = self.injected_send_error.lock().unwrap().take() {
return Poll::Ready(Err(err));
}
let value_guard = self.value.lock().unwrap();
if value_guard.is_none() {
Poll::Ready(Ok(()))
} else {
self.send_waker.register(cx.waker());
Poll::Pending
}
}
fn start_send(self: Pin<&mut Self>, item: (u64, MockWriteFlag)) -> Result<(), Self::Error> {
if self.sink_closed.load(Ordering::SeqCst) {
return Err(MockRpcError::SinkClosed);
}
if let Some(err) = self.injected_send_error.lock().unwrap().take() {
return Err(err);
}
let (value, _) = item;
*self.value.lock().unwrap() = Some(value);
self.recv_waker.wake();
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
if let Some(err) = self.injected_send_error.lock().unwrap().take() {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
if let Some(err) = self.injected_send_error.lock().unwrap().take() {
return Poll::Ready(Err(err));
}
self.sink_closed.store(true, Ordering::SeqCst);
self.recv_waker.wake();
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn test_basic_realtime() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
for i in 0..10u64 {
rate_limiter.send_realtime_event(i)?;
tokio::task::yield_now().await;
}
for i in 0..10u64 {
assert_eq!(mock_sink.recv().await.unwrap(), i);
}
mock_sink.close().await.unwrap();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::RpcSinkError(MockRpcError::SinkClosed))
);
Ok(())
}
#[tokio::test]
async fn test_basic_scan() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::task::yield_now().await;
for i in 0..10u64 {
rate_limiter.send_scan_event(i).await?;
tokio::task::yield_now().await;
}
for i in 0..10u64 {
assert_eq!(mock_sink.recv().await.unwrap(), i);
}
mock_sink.close().await.unwrap();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::RpcSinkError(MockRpcError::SinkClosed))
);
Ok(())
}
#[tokio::test]
async fn test_realtime_disconnected() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::task::yield_now().await;
rate_limiter.send_realtime_event(1)?;
rate_limiter.send_realtime_event(2)?;
rate_limiter.send_realtime_event(3)?;
rate_limiter.inject_instant_drainer_exit();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::<MockRpcError>::RateLimitExceededError)
);
assert_eq!(
rate_limiter.send_realtime_event(4),
Err(RateLimiterError::DisconnectedError)
);
assert_eq!(
rate_limiter.send_realtime_event(5),
Err(RateLimiterError::DisconnectedError)
);
mock_sink.close().await.unwrap();
Ok(())
}
#[tokio::test]
async fn test_scan_disconnected() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::task::yield_now().await;
rate_limiter.send_scan_event(1).await?;
rate_limiter.send_scan_event(2).await?;
rate_limiter.send_scan_event(3).await?;
rate_limiter.inject_instant_drainer_exit();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::<MockRpcError>::RateLimitExceededError)
);
assert_eq!(
rate_limiter.send_scan_event(4).await,
Err(RateLimiterError::DisconnectedError)
);
assert_eq!(
rate_limiter.send_scan_event(5).await,
Err(RateLimiterError::DisconnectedError)
);
mock_sink.close().await.unwrap();
Ok(())
}
#[tokio::test]
async fn test_realtime_congested() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 5);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::task::yield_now().await;
rate_limiter.send_realtime_event(1)?;
rate_limiter.send_realtime_event(2)?;
rate_limiter.send_realtime_event(3)?;
rate_limiter.send_realtime_event(4)?;
rate_limiter.send_realtime_event(5)?;
match rate_limiter.send_realtime_event(6) {
Ok(_) => panic!("expected error"),
Err(RateLimiterError::CongestedError(len)) => assert_eq!(len, 5),
_ => panic!("expected CongestedError"),
}
match rate_limiter.send_realtime_event(6) {
Ok(_) => panic!("expected error"),
Err(err) => assert_eq!(err, RateLimiterError::DisconnectedError),
}
rate_limiter.start_flush();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::RateLimitExceededError)
);
mock_sink.close().await.unwrap();
Ok(())
}
#[tokio::test]
async fn test_scan_block_normal() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(5, 1024);
let mut mock_sink = MockRpcSink::new();
rate_limiter.send_realtime_event(1)?;
rate_limiter.send_realtime_event(2)?;
rate_limiter.send_scan_event(3).await?;
rate_limiter.send_scan_event(4).await?;
rate_limiter.send_scan_event(5).await?;
assert_eq!(
rate_limiter
.state
.blocked_sender_count
.load(Ordering::SeqCst),
0
);
let rate_limiter_clone = rate_limiter.clone();
let handle = tokio::spawn(async move {
rate_limiter_clone.send_scan_event(6).await.unwrap();
});
tokio::time::delay_for(std::time::Duration::from_millis(200)).await;
assert_eq!(
rate_limiter
.state
.blocked_sender_count
.load(Ordering::SeqCst),
1
);
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
assert_eq!(mock_sink.recv().await.unwrap(), 1);
assert_eq!(mock_sink.recv().await.unwrap(), 2);
assert_eq!(mock_sink.recv().await.unwrap(), 3);
assert_eq!(mock_sink.recv().await.unwrap(), 4);
assert_eq!(mock_sink.recv().await.unwrap(), 5);
assert_eq!(mock_sink.recv().await.unwrap(), 6);
mock_sink.close().await.unwrap();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::RpcSinkError(MockRpcError::SinkClosed))
);
handle.await.unwrap();
Ok(())
}
#[tokio::test]
async fn test_scan_block_disconnected() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(5, 1024);
let mut mock_sink = MockRpcSink::new();
rate_limiter.send_realtime_event(1)?;
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::time::delay_for(std::time::Duration::from_millis(200)).await;
assert_eq!(rate_limiter.sink.len(), 0);
rate_limiter.send_realtime_event(2)?;
assert_eq!(rate_limiter.sink.len(), 1);
rate_limiter.send_scan_event(3).await?;
assert_eq!(rate_limiter.sink.len(), 2);
rate_limiter.send_scan_event(4).await?;
assert_eq!(rate_limiter.sink.len(), 3);
rate_limiter.send_scan_event(5).await?;
assert_eq!(rate_limiter.sink.len(), 4);
rate_limiter.send_scan_event(6).await?;
tokio::time::delay_for(std::time::Duration::from_millis(200)).await;
assert_eq!(
rate_limiter
.state
.blocked_sender_count
.load(Ordering::SeqCst),
0
);
let rate_limiter_clone = rate_limiter.clone();
let handle = tokio::spawn(async move {
assert_eq!(rate_limiter_clone.sink.len(), 5);
match rate_limiter_clone.send_scan_event(7).await {
Ok(_) => panic!("expected error"),
Err(err) => assert_eq!(err, RateLimiterError::DisconnectedError),
}
});
tokio::time::delay_for(std::time::Duration::from_millis(200)).await;
assert_eq!(
rate_limiter
.state
.blocked_sender_count
.load(Ordering::SeqCst),
1
);
rate_limiter.inject_instant_drainer_exit();
assert_eq!(
drain_handle.await.unwrap(),
Err(DrainerError::RateLimitExceededError)
);
mock_sink.close().await.unwrap();
handle.await.unwrap();
Ok(())
}
#[tokio::test]
async fn test_rpc_sink_error() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
for i in 0..10u64 {
rate_limiter.send_realtime_event(i)?;
tokio::task::yield_now().await;
}
mock_sink.inject_send_error(MockRpcError::InjectedRpcError);
rate_limiter.start_flush();
let res = drain_handle.await.unwrap();
assert_eq!(
res,
Err(DrainerError::RpcSinkError(MockRpcError::InjectedRpcError))
);
Ok(())
}
#[tokio::test]
async fn test_close_downstream() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(2, 1024);
let mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
let rate_limiter = rate_limiter.with_region_id(1);
let rate_limiter_clone = rate_limiter.clone();
let send_task = tokio::spawn(async move {
rate_limiter_clone.send_scan_event(1).await.unwrap();
tokio::task::yield_now().await;
assert_eq!(rate_limiter_clone.sink.len(), 0);
rate_limiter_clone.send_scan_event(2).await.unwrap();
tokio::task::yield_now().await;
assert_eq!(rate_limiter_clone.sink.len(), 1);
rate_limiter_clone.send_scan_event(3).await.unwrap();
tokio::task::yield_now().await;
assert_eq!(rate_limiter_clone.sink.len(), 2);
let res = rate_limiter_clone.send_scan_event(4).await;
assert_eq!(res, Err(RateLimiterError::DownstreamClosed(1)));
});
tokio::time::delay_for(std::time::Duration::from_millis(200)).await;
rate_limiter.close_with_error(9)?;
assert_eq!(mock_sink.recv().await.unwrap(), 1);
assert_eq!(mock_sink.recv().await.unwrap(), 2);
assert_eq!(mock_sink.recv().await.unwrap(), 3);
drop(rate_limiter);
send_task.await.unwrap();
drain_handle.await.unwrap().unwrap();
Ok(())
}
#[tokio::test]
async fn test_single_thread_many_events() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, 1024);
let mut mock_sink = MockRpcSink::new();
let drain_handle = tokio::spawn(drainer.drain(mock_sink.clone(), ()));
tokio::task::yield_now().await;
        let verifier_handle = tokio::spawn(async move {
for i in 0..100000u64 {
assert_eq!(mock_sink.recv().await.unwrap(), i);
}
let close_res = mock_sink.close().await;
assert_eq!(close_res, Ok(()));
});
for i in 0..100000u64 {
rate_limiter.send_scan_event(i).await?;
}
        verifier_handle.await.unwrap();
let res = drain_handle.await.unwrap();
assert_eq!(
res,
Err(DrainerError::RpcSinkError(MockRpcError::SinkClosed))
);
Ok(())
}
#[test]
fn test_multi_thread_many_events() {
let mut builder = tokio::runtime::Builder::new();
let mut runtime = builder
.threaded_scheduler()
.core_threads(16)
.enable_all()
.build()
.unwrap();
runtime.block_on(async {
do_test_multi_thread_many_events().await.unwrap();
});
}
async fn do_test_multi_thread_many_events() -> Result<(), RateLimiterError> {
let (rate_limiter, drainer) = new_pair::<MockCdcEvent>(1024, usize::MAX);
let (sink, mut rx) = futures::channel::mpsc::unbounded::<(MockCdcEvent, MockWriteFlag)>();
let drain_handle = tokio::spawn(async move {
match drainer.drain(sink, ()).await {
Ok(_) => {}
Err(err) => panic!("drainer exited with error {:?}", err),
}
});
let verify_handle = tokio::spawn(async move {
let mut count = 0u64;
loop {
if rx.next().await.is_some() {
count += 1;
}
if count == 10 * 10000 {
return;
}
}
});
tokio::task::yield_now().await;
let mut handles = vec![];
for _i in 0..10 {
let rate_limiter = rate_limiter.clone();
let handle = tokio::spawn(async move {
for j in 0..10000u64 {
rate_limiter.send_scan_event(j).await.unwrap();
}
});
handles.push(handle);
}
rate_limiter.start_flush();
verify_handle.await.unwrap();
for handle in handles.into_iter() {
handle.await.unwrap();
}
drop(rate_limiter);
drain_handle.await.unwrap();
Ok(())
}
}