use std::future::Future;
use std::sync::Arc;
use fail::fail_point;
use futures::channel::oneshot::{self, Canceled};
use prometheus::{Histogram, IntCounter, IntGauge};
use yatp::task::future;
/// Thread pool type that backs a [`FuturePool`]: a yatp pool whose tasks are
/// `yatp::task::future::TaskCell`s.
pub type ThreadPool = yatp::ThreadPool<future::TaskCell>;
use super::metrics;
use crate::time::Instant;
/// Metric handles for a single pool, resolved once from the labelled metric
/// vectors when the pool is constructed.
#[derive(Clone)]
struct Env {
    /// Tasks admitted but not yet finished; also read by the `max_tasks` admission check.
    metrics_running_task_count: IntGauge,
    /// Tasks whose futures have run to completion.
    metrics_handled_task_count: IntCounter,
    /// Delay between a task being spawned and the start of its first poll, in seconds.
    metrics_pool_schedule_duration: Histogram,
}
/// A handle for running futures on a yatp thread pool, with an admission limit
/// (`max_tasks`) and per-pool metrics. Clones share the same underlying pool,
/// which is held behind an `Arc`.
#[derive(Clone)]
pub struct FuturePool {
    /// The underlying yatp pool, shared between clones.
    pool: Arc<ThreadPool>,
    /// Metric handles labelled with this pool's name.
    env: Env,
    /// Size reported by `get_pool_size`; not otherwise used by this type.
    pool_size: usize,
    /// Maximum number of running tasks before `spawn` returns `Full`;
    /// `usize::MAX` disables the check.
    max_tasks: usize,
}
impl std::fmt::Debug for FuturePool {
    /// Prints just the type name; the pool handle and metrics are not rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("FuturePool")
    }
}
// Compile-time marker impls. NOTE(review): presumably `crate::AssertSend` / `crate::AssertSync`
// carry `Send` / `Sync` supertrait bounds, so these fail to compile if `FuturePool` loses
// either property — confirm against their definitions.
impl crate::AssertSend for FuturePool {}
impl crate::AssertSync for FuturePool {}
impl FuturePool {
pub fn from_pool(pool: ThreadPool, name: &str, pool_size: usize, max_tasks: usize) -> Self {
let env = Env {
metrics_running_task_count: metrics::FUTUREPOOL_RUNNING_TASK_VEC
.with_label_values(&[name]),
metrics_handled_task_count: metrics::FUTUREPOOL_HANDLED_TASK_VEC
.with_label_values(&[name]),
metrics_pool_schedule_duration: metrics::FUTUREPOOL_SCHEDULE_DURATION_VEC
.with_label_values(&[name]),
};
FuturePool {
pool: Arc::new(pool),
env,
pool_size,
max_tasks,
}
}
#[inline]
pub fn get_pool_size(&self) -> usize {
self.pool_size
}
#[inline]
pub fn get_running_task_count(&self) -> usize {
self.env.metrics_running_task_count.get() as usize
}
fn gate_spawn(&self) -> Result<(), Full> {
fail_point!("future_pool_spawn_full", |_| Err(Full {
current_tasks: 100,
max_tasks: 100,
}));
if self.max_tasks == std::usize::MAX {
return Ok(());
}
let current_tasks = self.get_running_task_count();
if current_tasks >= self.max_tasks {
Err(Full {
current_tasks,
max_tasks: self.max_tasks,
})
} else {
Ok(())
}
}
pub fn spawn<F>(&self, future: F) -> Result<(), Full>
where
F: Future + Send + 'static,
{
let timer = Instant::now_coarse();
let h_schedule = self.env.metrics_pool_schedule_duration.clone();
let metrics_handled_task_count = self.env.metrics_handled_task_count.clone();
let metrics_running_task_count = self.env.metrics_running_task_count.clone();
self.gate_spawn()?;
metrics_running_task_count.inc();
self.pool.spawn(async move {
h_schedule.observe(timer.elapsed_secs());
let _ = future.await;
metrics_handled_task_count.inc();
metrics_running_task_count.dec();
});
Ok(())
}
pub fn spawn_handle<F>(
&self,
future: F,
) -> Result<impl Future<Output = Result<F::Output, Canceled>>, Full>
where
F: Future + Send + 'static,
F::Output: Send,
{
let timer = Instant::now_coarse();
let h_schedule = self.env.metrics_pool_schedule_duration.clone();
let metrics_handled_task_count = self.env.metrics_handled_task_count.clone();
let metrics_running_task_count = self.env.metrics_running_task_count.clone();
self.gate_spawn()?;
let (tx, rx) = oneshot::channel();
metrics_running_task_count.inc();
self.pool.spawn(async move {
h_schedule.observe(timer.elapsed_secs());
let res = future.await;
metrics_handled_task_count.inc();
metrics_running_task_count.dec();
let _ = tx.send(res);
});
Ok(rx)
}
}
/// Error returned when a task is rejected because the pool already has
/// `max_tasks` unfinished tasks.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Full {
    /// Unfinished task count observed when the task was rejected.
    pub current_tasks: usize,
    /// The pool's configured admission limit.
    pub max_tasks: usize,
}

impl Full {
    /// Text shared by the `Display` output and `Error::description`.
    const MESSAGE: &'static str = "future pool is full";
}

impl std::fmt::Display for Full {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(Self::MESSAGE)
    }
}

impl std::error::Error for Full {
    fn description(&self) -> &str {
        Self::MESSAGE
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Mutex,
    };
    use std::thread;
    use std::time::Duration;

    use super::super::{DefaultTicker, PoolTicker, YatpPoolBuilder as Builder, TICK_INTERVAL};
    use futures::executor::block_on;

    /// Spawns a task that blocks its worker thread for `duration`, then blocks the calling
    /// thread until that task has finished.
    fn spawn_future_and_wait(pool: &FuturePool, duration: Duration) {
        block_on(
            pool.spawn_handle(async move {
                thread::sleep(duration);
            })
            .unwrap(),
        )
        .unwrap();
    }

    /// Spawns a task that blocks its worker thread for `duration` and returns immediately
    /// without waiting for it.
    fn spawn_future_without_wait(pool: &FuturePool, duration: Duration) {
        pool.spawn(async move {
            thread::sleep(duration);
        })
        .unwrap();
    }

    /// A `PoolTicker` that calls a shared closure on every tick, so tests can record
    /// when and how often ticks happen.
    #[derive(Clone)]
    pub struct SequenceTicker {
        tick: Arc<dyn Fn() + Send + Sync>,
    }

    impl SequenceTicker {
        /// Creates a ticker that invokes `tick` on each `on_tick`.
        pub fn new<F>(tick: F) -> SequenceTicker
        where
            F: Fn() + Send + Sync + 'static,
        {
            SequenceTicker {
                tick: Arc::new(tick),
            }
        }
    }

    impl PoolTicker for SequenceTicker {
        fn on_tick(&mut self) {
            (self.tick)();
        }
    }

    /// Single-threaded pool: pins down exactly when ticks are observed relative to the
    /// tasks the pool runs. Timing-sensitive; the expected sequence is the one asserted below.
    #[test]
    fn test_tick() {
        let tick_sequence = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = mpsc::sync_channel(1000);
        let rx = Arc::new(Mutex::new(rx));
        // Every tick sends the next sequence number (0, 1, 2, ...).
        let ticker = SequenceTicker::new(move || {
            let seq = tick_sequence.fetch_add(1, Ordering::SeqCst);
            tx.send(seq).unwrap();
        });
        let pool = Builder::new(ticker).thread_count(1, 1).build_future_pool();
        // Non-blocking check of the tick channel, performed from inside a pool task.
        let try_recv_tick = || {
            let rx = rx.clone();
            block_on(
                pool.spawn_handle(async move { rx.lock().unwrap().try_recv() })
                    .unwrap(),
            )
            .unwrap()
        };
        // Nothing recorded right after the pool is built.
        assert!(try_recv_tick().is_err());
        // Short tasks adding up to less than TICK_INTERVAL: still no tick.
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        assert!(try_recv_tick().is_err());
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        assert!(try_recv_tick().is_err());
        // Letting the interval elapse while idle is not expected to deliver a tick yet...
        thread::sleep(TICK_INTERVAL * 2);
        assert!(try_recv_tick().is_err());
        // ...the first tick is only observed after another task has run.
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        assert_eq!(try_recv_tick().unwrap(), 0);
        assert!(try_recv_tick().is_err());
        // Same pattern for the second tick.
        thread::sleep(TICK_INTERVAL * 2);
        assert!(try_recv_tick().is_err());
        spawn_future_and_wait(&pool, TICK_INTERVAL / 20);
        assert_eq!(try_recv_tick().unwrap(), 1);
        assert!(try_recv_tick().is_err());
        // A single task longer than the interval yields exactly one more tick.
        spawn_future_and_wait(&pool, TICK_INTERVAL * 2);
        assert_eq!(try_recv_tick().unwrap(), 2);
        assert!(try_recv_tick().is_err());
    }

    /// Two-threaded pool: after the interval has elapsed, running one task on each worker
    /// is expected to produce two ticks (presumably one per worker thread — timing-sensitive).
    #[test]
    fn test_tick_multi_thread() {
        let tick_sequence = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = mpsc::sync_channel(1000);
        let ticker = SequenceTicker::new(move || {
            let seq = tick_sequence.fetch_add(1, Ordering::SeqCst);
            tx.send(seq).unwrap();
        });
        let pool = Builder::new(ticker).thread_count(2, 2).build_future_pool();
        assert!(rx.try_recv().is_err());
        spawn_future_without_wait(&pool, TICK_INTERVAL / 2);
        spawn_future_without_wait(&pool, TICK_INTERVAL / 2);
        assert!(rx.try_recv().is_err());
        // Idle time alone is not expected to produce a tick.
        thread::sleep(TICK_INTERVAL * 2);
        assert!(rx.try_recv().is_err());
        spawn_future_without_wait(&pool, TICK_INTERVAL);
        spawn_future_without_wait(&pool, TICK_INTERVAL / 2);
        thread::sleep(TICK_INTERVAL * 2);
        assert_eq!(rx.try_recv().unwrap(), 0);
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert!(rx.try_recv().is_err());
    }

    /// `spawn_handle` delivers the spawned future's output to the caller.
    #[test]
    fn test_handle_result() {
        let pool = Builder::new(DefaultTicker {})
            .thread_count(1, 1)
            .build_future_pool();
        let handle = pool.spawn_handle(async { 42 });
        assert_eq!(block_on(handle.unwrap()).unwrap(), 42);
    }

    /// The running-task count includes queued tasks (3 tasks on 2 threads count as 3) and
    /// drops back as tasks finish. Timing-sensitive: relies on the sleep durations below.
    #[test]
    fn test_running_task_count() {
        let pool = Builder::new(DefaultTicker {})
            .name_prefix("future_pool_for_running_task_test")
            .thread_count(2, 2)
            .build_future_pool();
        assert_eq!(pool.get_running_task_count(), 0);
        spawn_future_without_wait(&pool, Duration::from_millis(500));
        assert_eq!(pool.get_running_task_count(), 1);
        spawn_future_without_wait(&pool, Duration::from_millis(1000));
        assert_eq!(pool.get_running_task_count(), 2);
        // Third task has no free thread yet but is still counted.
        spawn_future_without_wait(&pool, Duration::from_millis(1500));
        assert_eq!(pool.get_running_task_count(), 3);
        // By 700ms the 500ms task has finished.
        thread::sleep(Duration::from_millis(700));
        assert_eq!(pool.get_running_task_count(), 2);
        spawn_future_without_wait(&pool, Duration::from_millis(1500));
        assert_eq!(pool.get_running_task_count(), 3);
        // Long enough for every remaining task to complete.
        thread::sleep(Duration::from_millis(2700));
        assert_eq!(pool.get_running_task_count(), 0);
    }

    /// Spawns a task that sleeps for `future_duration_ms` and then yields `id`.
    fn spawn_long_time_future(
        pool: &FuturePool,
        id: u64,
        future_duration_ms: u64,
    ) -> Result<impl Future<Output = Result<u64, Canceled>>, Full> {
        pool.spawn_handle(async move {
            thread::sleep(Duration::from_millis(future_duration_ms));
            id
        })
    }

    /// Blocks on `future` in a new OS thread and sends its output through `sender`.
    fn wait_on_new_thread<F>(sender: mpsc::Sender<F::Output>, future: F)
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        thread::spawn(move || {
            let r = block_on(future);
            sender.send(r).unwrap();
        });
    }

    /// With `max_tasks(4)`, a fifth unfinished task is rejected with `Full`, and capacity
    /// becomes available again as soon as a task's result has been delivered.
    #[test]
    fn test_full() {
        let (tx, rx) = mpsc::channel();
        let read_pool = Builder::new(DefaultTicker {})
            .name_prefix("future_pool_test_full")
            .thread_count(2, 2)
            .max_tasks(4)
            .build_future_pool();
        // A task that has finished no longer counts toward the limit.
        wait_on_new_thread(
            tx.clone(),
            spawn_long_time_future(&read_pool, 0, 5).unwrap(),
        );
        assert_eq!(rx.recv().unwrap(), Ok(0));
        // Fill the pool to its limit of 4 unfinished tasks.
        wait_on_new_thread(
            tx.clone(),
            spawn_long_time_future(&read_pool, 1, 100).unwrap(),
        );
        wait_on_new_thread(
            tx.clone(),
            spawn_long_time_future(&read_pool, 2, 200).unwrap(),
        );
        wait_on_new_thread(
            tx.clone(),
            spawn_long_time_future(&read_pool, 3, 300).unwrap(),
        );
        wait_on_new_thread(
            tx.clone(),
            spawn_long_time_future(&read_pool, 4, 400).unwrap(),
        );
        // None has finished yet, so further spawns are rejected.
        assert!(rx.recv_timeout(Duration::from_millis(50)).is_err());
        assert!(spawn_long_time_future(&read_pool, 5, 100).is_err());
        assert!(spawn_long_time_future(&read_pool, 6, 100).is_err());
        // Once task 1's result arrives its slot is free, so exactly one more task fits.
        assert_eq!(rx.recv().unwrap(), Ok(1));
        wait_on_new_thread(tx, spawn_long_time_future(&read_pool, 7, 5).unwrap());
        assert!(spawn_long_time_future(&read_pool, 8, 100).is_err());
        // Four outstanding results (tasks 2, 3, 4 and 7), then nothing more.
        assert!(rx.recv().is_ok());
        assert!(rx.recv().is_ok());
        assert!(rx.recv().is_ok());
        assert!(rx.recv().is_ok());
        assert!(rx.recv_timeout(Duration::from_millis(500)).is_err());
    }
}