use std::future::Future;
use std::sync::{Arc, Mutex};

use futures::channel::oneshot;
use futures::future::TryFutureExt;
use kvproto::kvrpcpb::CommandPri;
use prometheus::IntGauge;
use thiserror::Error;
use yatp::pool::Remote;
use yatp::queue::Extras;
use yatp::task::future::TaskCell;

use file_system::{set_io_type, IOType};
use tikv_util::yatp_pool::{self, FuturePool, PoolTicker, YatpPoolBuilder};

use self::metrics::*;
use crate::config::UnifiedReadPoolConfig;
use crate::storage::kv::{destroy_tls_engine, set_tls_engine, Engine, FlowStatsReporter};
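
/// A pool of threads dedicated to handling read requests. It is either a
/// set of three priority-specific `FuturePool`s, or a single multi-level
/// `yatp` pool (the "unified read pool").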
pub enum ReadPool {
    FuturePools {
        read_pool_high: FuturePool,
        read_pool_normal: FuturePool,
        read_pool_low: FuturePool,
    },
    Yatp {
        pool: yatp::ThreadPool<TaskCell>,
        running_tasks: IntGauge,
        max_tasks: usize,
        pool_size: usize,
    },
}

impl ReadPool {
    /// Returns a cloneable handle for spawning tasks on this pool without
    /// holding on to the pool itself.
    pub fn handle(&self) -> ReadPoolHandle {
        match self {
            ReadPool::FuturePools {
                read_pool_high,
                read_pool_normal,
                read_pool_low,
            } => ReadPoolHandle::FuturePools {
                read_pool_high: read_pool_high.clone(),
                read_pool_normal: read_pool_normal.clone(),
                read_pool_low: read_pool_low.clone(),
            },
            ReadPool::Yatp {
                pool,
                running_tasks,
                max_tasks,
                pool_size,
            } => ReadPoolHandle::Yatp {
                remote: pool.remote().clone(),
                running_tasks: running_tasks.clone(),
                max_tasks: *max_tasks,
                pool_size: *pool_size,
            },
        }
    }
}
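
/// A cheaply cloneable handle to a `ReadPool`, used to spawn read tasks.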
#[derive(Clone)]
pub enum ReadPoolHandle {
    FuturePools {
        read_pool_high: FuturePool,
        read_pool_normal: FuturePool,
        read_pool_low: FuturePool,
    },
    Yatp {
        remote: Remote<TaskCell>,
        running_tasks: IntGauge,
        max_tasks: usize,
        pool_size: usize,
    },
}

impl ReadPoolHandle {
    /// Spawns `f` on the pool with the given priority. On the unified pool,
    /// `task_id` identifies the task group whose accumulated runtime the
    /// multi-level queue tracks when choosing a queue level.
    pub fn spawn<F>(&self, f: F, priority: CommandPri, task_id: u64) -> Result<(), ReadPoolError>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        match self {
            ReadPoolHandle::FuturePools {
                read_pool_high,
                read_pool_normal,
                read_pool_low,
            } => {
                let pool = match priority {
                    CommandPri::High => read_pool_high,
                    CommandPri::Normal => read_pool_normal,
                    CommandPri::Low => read_pool_low,
                };
                pool.spawn(f)?;
            }
            ReadPoolHandle::Yatp {
                remote,
                running_tasks,
                max_tasks,
                ..
            } => {
                // The check and the increment below are not one atomic
                // operation, so the `max_tasks` limit is enforced on a
                // best-effort basis.
                let running_tasks = running_tasks.clone();
                if running_tasks.get() as usize >= *max_tasks {
                    return Err(ReadPoolError::UnifiedReadPoolFull);
                }
                running_tasks.inc();
                // Pin high-priority tasks to the highest level (0) and
                // low-priority tasks to the lowest level (2); normal tasks
                // are left to the multi-level queue's adaptive scheduling.
                let fixed_level = match priority {
                    CommandPri::High => Some(0),
                    CommandPri::Normal => None,
                    CommandPri::Low => Some(2),
                };
                let extras = Extras::new_multilevel(task_id, fixed_level);
                let task_cell = TaskCell::new(
                    async move {
                        f.await;
                        running_tasks.dec();
                    },
                    extras,
                );
                remote.spawn(task_cell);
            }
        }
        Ok(())
    }
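
    /// Like [`spawn`](Self::spawn), but returns a future that resolves to
    /// the task's output. The future fails with `ReadPoolError` if the task
    /// could not be spawned, or if it was dropped before sending back its
    /// result.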
    pub fn spawn_handle<F, T>(
        &self,
        f: F,
        priority: CommandPri,
        task_id: u64,
    ) -> impl Future<Output = Result<T, ReadPoolError>>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<T>();
        let res = self.spawn(
            async move {
                let res = f.await;
                let _ = tx.send(res);
            },
            priority,
            task_id,
        );
        async move {
            res?;
            rx.map_err(ReadPoolError::from).await
        }
    }

    /// Returns the number of threads serving normal-priority tasks.
    pub fn get_normal_pool_size(&self) -> usize {
        match self {
            ReadPoolHandle::FuturePools {
                read_pool_normal, ..
            } => read_pool_normal.get_pool_size(),
            ReadPoolHandle::Yatp { pool_size, .. } => *pool_size,
        }
    }

    /// Returns the average number of unfinished tasks per worker thread,
    /// a rough measure of how backlogged the pool is.
    pub fn get_queue_size_per_worker(&self) -> usize {
        match self {
            ReadPoolHandle::FuturePools {
                read_pool_normal, ..
            } => {
                read_pool_normal.get_running_task_count() as usize
                    / read_pool_normal.get_pool_size()
            }
            ReadPoolHandle::Yatp {
                running_tasks,
                pool_size,
                ..
            } => running_tasks.get() as usize / *pool_size,
        }
    }
}
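
/// A `PoolTicker` that flushes the thread-local storage and coprocessor
/// metrics to a `FlowStatsReporter` on each tick of the pool.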
#[derive(Clone)]
pub struct ReporterTicker<R: FlowStatsReporter> {
    reporter: R,
}

impl<R: FlowStatsReporter> PoolTicker for ReporterTicker<R> {
    fn on_tick(&mut self) {
        self.flush_metrics_on_tick();
    }
}

impl<R: FlowStatsReporter> ReporterTicker<R> {
    fn flush_metrics_on_tick(&mut self) {
        crate::storage::metrics::tls_flush(&self.reporter);
        crate::coprocessor::metrics::tls_flush(&self.reporter);
    }
}
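
// In tests, every pool gets a unique name so that per-name metric labels
// (and thus the `running_tasks` gauges) do not collide across test cases.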
#[cfg(test)]
fn get_unified_read_pool_name() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    format!(
        "unified-read-pool-test-{}",
        COUNTER.fetch_add(1, Ordering::Relaxed)
    )
}

#[cfg(not(test))]
fn get_unified_read_pool_name() -> String {
    "unified-read-pool".to_string()
}
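
/// Builds the unified read pool: a multi-level `yatp` pool whose worker
/// threads each hold a thread-local clone of `engine` for serving reads.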
pub fn build_yatp_read_pool<E: Engine, R: FlowStatsReporter>(
    config: &UnifiedReadPoolConfig,
    reporter: R,
    engine: E,
) -> ReadPool {
    let unified_read_pool_name = get_unified_read_pool_name();
    let mut builder = YatpPoolBuilder::new(ReporterTicker { reporter });
    // The engine is shared behind a mutex so that each worker thread can
    // clone it into its own thread-local storage on startup.
    let raftkv = Arc::new(Mutex::new(engine));
    let pool = builder
        .name_prefix(&unified_read_pool_name)
        .stack_size(config.stack_size.0 as usize)
        .thread_count(config.min_thread_count, config.max_thread_count)
        .after_start(move || {
            let engine = raftkv.lock().unwrap().clone();
            set_tls_engine(engine);
            set_io_type(IOType::ForegroundRead);
        })
        .before_stop(|| unsafe {
            // Safety: the TLS engine is set when the thread starts and is
            // only destroyed here, when the thread stops.
            destroy_tls_engine::<E>();
        })
        .build_multi_level_pool();
    ReadPool::Yatp {
        pool,
        running_tasks: UNIFIED_READ_POOL_RUNNING_TASKS
            .with_label_values(&[&unified_read_pool_name]),
        // The task limit scales with the maximum number of worker threads.
        max_tasks: config
            .max_tasks_per_worker
            .saturating_mul(config.max_thread_count),
        pool_size: config.max_thread_count,
    }
}
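
/// Converts a vector of exactly three `FuturePool`s, ordered
/// `[low, normal, high]`, into `ReadPool::FuturePools`.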
impl From<Vec<FuturePool>> for ReadPool {
    fn from(mut v: Vec<FuturePool>) -> ReadPool {
        assert_eq!(v.len(), 3);
        let read_pool_high = v.remove(2);
        let read_pool_normal = v.remove(1);
        let read_pool_low = v.remove(0);
        ReadPool::FuturePools {
            read_pool_high,
            read_pool_normal,
            read_pool_low,
        }
    }
}
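
/// Errors produced when spawning a task on a read pool.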
#[derive(Debug, Error)]
pub enum ReadPoolError {
    #[error("{0}")]
    FuturePoolFull(#[from] yatp_pool::Full),

    #[error("Unified read pool is full")]
    UnifiedReadPoolFull,

    #[error("{0}")]
    Canceled(#[from] oneshot::Canceled),
}

mod metrics {
    use prometheus::*;

    lazy_static! {
        pub static ref UNIFIED_READ_POOL_RUNNING_TASKS: IntGaugeVec = register_int_gauge_vec!(
            "tikv_unified_read_pool_running_tasks",
            "The number of running tasks in the unified read pool",
            &["name"]
        )
        .unwrap();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TestEngineBuilder;
    use futures::channel::oneshot;
    use raftstore::store::ReadStats;
    use std::thread;
    use std::time::Duration;

    #[derive(Clone)]
    struct DummyReporter;

    impl FlowStatsReporter for DummyReporter {
        fn report_read_stats(&self, _read_stats: ReadStats) {}
    }

    #[test]
    fn test_yatp_full() {
        // max_tasks_per_worker * max_thread_count = 2, so at most two tasks
        // may be pending in the pool at once.
        let config = UnifiedReadPoolConfig {
            min_thread_count: 1,
            max_thread_count: 2,
            max_tasks_per_worker: 1,
            ..Default::default()
        };
        let engine = TestEngineBuilder::new().build().unwrap();
        let pool = build_yatp_read_pool(&config, DummyReporter, engine);

        // Each task blocks on a oneshot receiver, so it stays "running"
        // until its sender is consumed or dropped.
        let gen_task = || {
            let (tx, rx) = oneshot::channel::<()>();
            let task = async move {
                let _ = rx.await;
            };
            (task, tx)
        };

        let handle = pool.handle();
        let (task1, tx1) = gen_task();
        let (task2, _tx2) = gen_task();
        let (task3, _tx3) = gen_task();
        let (task4, _tx4) = gen_task();

        assert!(handle.spawn(task1, CommandPri::Normal, 1).is_ok());
        assert!(handle.spawn(task2, CommandPri::Normal, 2).is_ok());

        // With two tasks in flight the pool is at capacity, so the third
        // spawn must be rejected.
        thread::sleep(Duration::from_millis(300));
        match handle.spawn(task3, CommandPri::Normal, 3) {
            Err(ReadPoolError::UnifiedReadPoolFull) => {}
            _ => panic!("should return full error"),
        }

        // Completing task1 frees a slot, after which spawning succeeds again.
        tx1.send(()).unwrap();
        thread::sleep(Duration::from_millis(300));
        assert!(handle.spawn(task4, CommandPri::Normal, 4).is_ok());
    }
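
    // A minimal sketch of the round trip through `spawn_handle`: spawn a
    // computation on the unified pool and read its result back through the
    // returned future.
    #[test]
    fn test_spawn_handle_sketch() {
        let config = UnifiedReadPoolConfig {
            min_thread_count: 1,
            max_thread_count: 1,
            ..Default::default()
        };
        let engine = TestEngineBuilder::new().build().unwrap();
        let pool = build_yatp_read_pool(&config, DummyReporter, engine);
        let handle = pool.handle();
        // The task's output is sent back over a oneshot channel; awaiting
        // the returned future yields `Ok(2)` once the task has run.
        let fut = handle.spawn_handle(async move { 1 + 1 }, CommandPri::Normal, 1);
        assert_eq!(futures::executor::block_on(fut).unwrap(), 2);
    }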
}