use std::convert::TryFrom;
use std::sync::Arc;
use tidb_query_datatype::{EvalType, FieldTypeAccessor};
use tipb::{Expr, FieldType};
use crate::interface::*;
use tidb_query_aggr::*;
use tidb_query_common::storage::IntervalRange;
use tidb_query_common::Result;
use tidb_query_datatype::codec::batch::{LazyBatchColumn, LazyBatchColumnVec};
use tidb_query_datatype::codec::data_type::*;
use tidb_query_datatype::expr::{EvalConfig, EvalContext};
use tidb_query_expr::RpnExpression;
/// The strategy part of [`AggregationExecutor`]: implementors supply how input
/// rows are folded into aggregate function states and how finished groups are
/// emitted, while the executor drives batching, schema setup and output assembly.
pub trait AggregationExecutorImpl<Src: BatchExecutor>: Send {
    /// Called exactly once during executor construction, before any data is
    /// processed, giving the implementor a chance to adjust the shared
    /// `entities` (e.g. append extra columns to `entities.schema`).
    fn prepare_entities(&mut self, entities: &mut Entities<Src>);

    /// Consumes one batch emitted by the source executor.
    ///
    /// `input_logical_rows` lists which rows of `input_physical_columns` are
    /// valid; implementors should update their aggregate function states from
    /// exactly those rows. Never called with an empty logical-row list.
    fn process_batch_input(
        &mut self,
        entities: &mut Entities<Src>,
        input_physical_columns: LazyBatchColumnVec,
        input_logical_rows: &[usize],
    ) -> Result<()>;

    /// Returns the current number of groups.
    ///
    /// Only used as a capacity hint when allocating result columns, so an
    /// approximate value is acceptable.
    fn groups_len(&self) -> usize;

    /// Invokes `iteratee` once per available group, passing that group's
    /// aggregate function states, and returns the group-by key columns (may
    /// be empty when there are no group-by columns).
    ///
    /// `src_is_drained` tells the implementor whether the source executor has
    /// finished, i.e. whether all remaining groups can be flushed.
    fn iterate_available_groups(
        &mut self,
        entities: &mut Entities<Src>,
        src_is_drained: bool,
        iteratee: impl FnMut(&mut Entities<Src>, &[Box<dyn AggrFunctionState>]) -> Result<()>,
    ) -> Result<Vec<LazyBatchColumn>>;

    /// Whether partial aggregate results can be emitted before the source is
    /// drained. Checked after every processed batch.
    fn is_partial_results_ready(&self) -> bool;
}
/// Common state shared between [`AggregationExecutor`] and its
/// [`AggregationExecutorImpl`] strategy.
pub struct Entities<Src: BatchExecutor> {
    /// The underlying executor that supplies input batches.
    pub src: Src,
    /// Evaluation context used for aggregation; also accumulates warnings
    /// carried over from the source.
    pub context: EvalContext,

    /// Output schema: the result field types of all aggregate functions,
    /// possibly extended by the implementor in `prepare_entities` (e.g. with
    /// group-by columns).
    pub schema: Vec<FieldType>,

    /// The parsed aggregate function instances, in definition order.
    pub each_aggr_fn: Vec<Box<dyn AggrFunction>>,

    /// Number of result columns each aggregate function occupies in the
    /// schema (always >= 1).
    pub each_aggr_cardinality: Vec<usize>,

    /// The input expression of each aggregate function (exactly one per
    /// function).
    pub each_aggr_exprs: Vec<RpnExpression>,

    /// Eval types of the aggregate-result columns only (computed before
    /// `prepare_entities`, so it excludes any columns appended later).
    pub all_result_column_types: Vec<EvalType>,
}
/// Generic batch aggregation executor. The variant-specific aggregation
/// behaviour is injected through `I: AggregationExecutorImpl<Src>`.
pub struct AggregationExecutor<Src: BatchExecutor, I: AggregationExecutorImpl<Src>> {
    /// Variant-specific aggregation strategy.
    imp: I,
    /// Set once the source is drained or an error occurred; `next_batch`
    /// asserts it is not called afterwards.
    is_ended: bool,
    /// State shared between this driver and `imp`.
    entities: Entities<Src>,
}
impl<Src: BatchExecutor, I: AggregationExecutorImpl<Src>> AggregationExecutor<Src, I> {
    /// Builds the executor by parsing every aggregate definition in
    /// `aggr_defs` with `aggr_def_parser`.
    ///
    /// For each definition the parser must append one or more result field
    /// types to `schema` and exactly one input expression to
    /// `each_aggr_exprs`; the number of appended field types is recorded as
    /// that function's result cardinality.
    ///
    /// # Errors / Panics
    ///
    /// Returns an error when a definition fails to parse. Panics (via the
    /// `assert!`s below) if the parser violates the append contract.
    pub fn new(
        mut imp: I,
        src: Src,
        config: Arc<EvalConfig>,
        aggr_defs: Vec<Expr>,
        aggr_def_parser: impl AggrDefinitionParser,
    ) -> Result<Self> {
        let aggr_fn_len = aggr_defs.len();
        let src_schema = src.schema();

        // `* 2` is only a capacity heuristic: each aggregate function may
        // append more than one result column to the schema.
        let mut schema = Vec::with_capacity(aggr_fn_len * 2);
        let mut each_aggr_fn = Vec::with_capacity(aggr_fn_len);
        let mut each_aggr_cardinality = Vec::with_capacity(aggr_fn_len);
        let mut each_aggr_exprs = Vec::with_capacity(aggr_fn_len);
        let mut ctx = EvalContext::new(config.clone());

        for aggr_def in aggr_defs {
            // Remember lengths so we can measure how much the parser appended.
            let schema_len = schema.len();
            let each_aggr_exprs_len = each_aggr_exprs.len();

            // The parser pushes the function's result field types into
            // `schema` and its input expression into `each_aggr_exprs`.
            let aggr_fn = aggr_def_parser.parse(
                aggr_def,
                &mut ctx,
                src_schema,
                &mut schema,
                &mut each_aggr_exprs,
            )?;

            // Contract: at least one result column and exactly one input
            // expression per aggregate function.
            assert!(schema.len() > schema_len);
            assert_eq!(each_aggr_exprs.len(), each_aggr_exprs_len + 1);

            each_aggr_fn.push(aggr_fn);
            // How many result columns this function occupies in the schema.
            each_aggr_cardinality.push(schema.len() - schema_len);
        }

        // Computed before `prepare_entities`, so this covers only the
        // aggregate result columns (the implementor may append more columns
        // to the schema below).
        let all_result_column_types = schema
            .iter()
            .map(|ft| {
                // NOTE(review): assumes the parser never emits a field type
                // without a valid `EvalType`; otherwise this panics — confirm.
                EvalType::try_from(ft.as_accessor().tp()).unwrap()
            })
            .collect();

        let mut entities = Entities {
            src,
            context: EvalContext::new(config),
            schema,
            each_aggr_fn,
            each_aggr_cardinality,
            each_aggr_exprs,
            all_result_column_types,
        };
        // Let the implementor finish setup (e.g. extend the schema with
        // group-by columns).
        imp.prepare_entities(&mut entities);

        Ok(Self {
            imp,
            is_ended: false,
            entities,
        })
    }

    /// Pulls one batch from the source, feeds it to the implementor and,
    /// when results are available, aggregates them.
    ///
    /// Returns `(partial_results, src_is_drained)`; `partial_results` is
    /// `Some` when the source is drained or the implementor reports partial
    /// results as ready.
    #[inline]
    fn handle_next_batch(&mut self) -> Result<(Option<LazyBatchColumnVec>, bool)> {
        let src_result = self.entities.src.next_batch(crate::runner::BATCH_MAX_SIZE);

        // Carry the source's warnings in our own context so they are
        // surfaced together with this executor's output.
        self.entities.context.warnings = src_result.warnings;

        let src_is_drained = src_result.is_drained?;

        // Only feed non-empty batches to the implementor.
        if !src_result.logical_rows.is_empty() {
            self.imp.process_batch_input(
                &mut self.entities,
                src_result.physical_columns,
                &src_result.logical_rows,
            )?;
        }

        // Results are emitted once the source is exhausted, or earlier if
        // the implementor supports partial results.
        let result = if src_is_drained || self.imp.is_partial_results_ready() {
            Some(self.aggregate_partial_results(src_is_drained)?)
        } else {
            None
        };

        Ok((result, src_is_drained))
    }

    /// Produces the output columns for all currently available groups.
    ///
    /// The column order matches the schema: aggregate result columns first
    /// (in definition order), followed by the group-by columns returned by
    /// the implementor.
    fn aggregate_partial_results(&mut self, src_is_drained: bool) -> Result<LazyBatchColumnVec> {
        let groups_len = self.imp.groups_len();

        // One decoded result vector per aggregate result column;
        // `groups_len` is only a capacity hint.
        let mut all_result_columns: Vec<_> = self
            .entities
            .all_result_column_types
            .iter()
            .map(|eval_type| VectorValue::with_capacity(groups_len, *eval_type))
            .collect();

        // For each group, every state appends its result(s) into the slice
        // of columns it owns; `offset` walks across the columns because one
        // aggregate function may produce several columns.
        let group_by_columns = self.imp.iterate_available_groups(
            &mut self.entities,
            src_is_drained,
            |entities, states| {
                assert_eq!(states.len(), entities.each_aggr_cardinality.len());
                let mut offset = 0;
                for (state, result_cardinality) in
                    states.iter().zip(&entities.each_aggr_cardinality)
                {
                    assert!(*result_cardinality > 0);
                    state.push_result(
                        &mut entities.context,
                        &mut all_result_columns[offset..offset + *result_cardinality],
                    )?;
                    offset += *result_cardinality;
                }
                Ok(())
            },
        )?;

        // Aggregate result columns come first, then group-by columns, so the
        // order agrees with the schema.
        let columns: Vec<_> = all_result_columns
            .into_iter()
            .map(LazyBatchColumn::Decoded)
            .chain(group_by_columns)
            .collect();
        let ret = LazyBatchColumnVec::from(columns);
        ret.assert_columns_equal_length();
        Ok(ret)
    }
}
impl<Src: BatchExecutor, I: AggregationExecutorImpl<Src>> BatchExecutor
    for AggregationExecutor<Src, I>
{
    type StorageStats = Src::StorageStats;

    /// Output schema: aggregate result columns followed by any columns the
    /// implementor appended during `prepare_entities`.
    #[inline]
    fn schema(&self) -> &[FieldType] {
        &self.entities.schema
    }

    #[inline]
    fn next_batch(&mut self, _scan_rows: usize) -> BatchExecuteResult {
        // Callers must stop pulling once a drained or error result is seen.
        assert!(!self.is_ended);

        match self.handle_next_batch() {
            Ok((data, src_is_drained)) => {
                // No further output once the source is fully consumed.
                self.is_ended = src_is_drained;
                let physical_columns = data.unwrap_or_else(LazyBatchColumnVec::empty);
                // Every produced row is valid output, so logical rows are
                // simply 0..n in order.
                let logical_rows = (0..physical_columns.rows_len()).collect();
                BatchExecuteResult {
                    physical_columns,
                    logical_rows,
                    warnings: self.entities.context.take_warnings(),
                    is_drained: Ok(src_is_drained),
                }
            }
            Err(e) => {
                // An error is fatal: mark the executor as ended and forward
                // the error in `is_drained`, with an empty payload.
                self.is_ended = true;
                BatchExecuteResult {
                    physical_columns: LazyBatchColumnVec::empty(),
                    logical_rows: Vec::new(),
                    warnings: self.entities.context.take_warnings(),
                    is_drained: Err(e),
                }
            }
        }
    }

    /// Delegates execution-statistics collection to the source executor.
    #[inline]
    fn collect_exec_stats(&mut self, dest: &mut ExecuteStats) {
        self.entities.src.collect_exec_stats(dest)
    }

    /// Delegates storage-statistics collection to the source executor.
    #[inline]
    fn collect_storage_stats(&mut self, dest: &mut Self::StorageStats) {
        self.entities.src.collect_storage_stats(dest)
    }

    /// Delegates scanned-range retrieval to the source executor.
    #[inline]
    fn take_scanned_range(&mut self) -> IntervalRange {
        self.entities.src.take_scanned_range()
    }

    /// Cacheable exactly when the source executor is cacheable.
    #[inline]
    fn can_be_cached(&self) -> bool {
        self.entities.src.can_be_cached()
    }
}
#[cfg(test)]
pub mod tests {
    use tidb_query_codegen::AggrFunction;
    use tidb_query_datatype::builder::FieldTypeBuilder;
    use tidb_query_datatype::{Collation, FieldTypeTp};

    use crate::interface::*;
    use crate::util::mock_executor::MockExecutor;
    use tidb_query_aggr::*;
    use tidb_query_common::Result;
    use tidb_query_datatype::codec::batch::LazyBatchColumnVec;
    use tidb_query_datatype::codec::data_type::*;
    use tidb_query_datatype::expr::{EvalContext, EvalWarnings};

    /// An aggregate function whose state must never be touched.
    ///
    /// Useful as a placeholder in tests that assert the update/result paths
    /// are never exercised: any call into its state panics.
    #[derive(Debug, AggrFunction)]
    #[aggr_function(state = AggrFnUnreachableState)]
    pub struct AggrFnUnreachable;

    /// State of [`AggrFnUnreachable`]; every operation hits `unreachable!()`.
    #[derive(Debug)]
    pub struct AggrFnUnreachableState;

    impl ConcreteAggrFunctionState for AggrFnUnreachableState {
        type ParameterType = &'static Real;

        // Panics unconditionally — updating this state is a test failure.
        unsafe fn update_concrete_unsafe(
            &mut self,
            _ctx: &mut EvalContext,
            _value: Option<Self::ParameterType>,
        ) -> Result<()> {
            unreachable!()
        }

        // Panics unconditionally — querying this state is a test failure.
        fn push_result(&self, _ctx: &mut EvalContext, _target: &mut [VectorValue]) -> Result<()> {
            unreachable!()
        }
    }

    /// Builds a mock source executor with 5 columns:
    ///
    /// Col0 (Real), Col1 (Real), Col2 (Bytes), Col3 (Int),
    /// Col4 (Bytes, collation utf8mb4_general_ci).
    ///
    /// It yields three batches whose logical rows (after applying
    /// `logical_rows` to the physical columns) are:
    ///
    /// ```text
    /// == Batch #1 (not drained) ==
    /// NULL   1.0   "abc"         1     "aa"
    /// 7.0    2.0   NULL          NULL  "aaa"
    /// NULL   NULL  ""            NULL  "áá"
    /// NULL   4.5   "HelloWorld"  NULL  NULL
    /// == Batch #2 (not drained) ==
    /// (no logical rows)
    /// == Batch #3 (drained) ==
    /// 1.5    4.5   "aaaaa"       5     "ááá"
    /// ```
    pub fn make_src_executor_1() -> MockExecutor {
        MockExecutor::new(
            vec![
                FieldTypeTp::Double.into(),
                FieldTypeTp::Double.into(),
                FieldTypeTp::VarString.into(),
                FieldTypeTp::LongLong.into(),
                FieldTypeBuilder::new()
                    .tp(FieldTypeTp::VarString)
                    .collation(Collation::Utf8Mb4GeneralCi)
                    .into(),
            ],
            vec![
                BatchExecuteResult {
                    physical_columns: LazyBatchColumnVec::from(vec![
                        VectorValue::Real(
                            vec![None, None, None, Real::new(-5.0).ok(), Real::new(7.0).ok()]
                                .into(),
                        ),
                        VectorValue::Real(
                            vec![
                                None,
                                Real::new(4.5).ok(),
                                Real::new(1.0).ok(),
                                None,
                                Real::new(2.0).ok(),
                            ]
                            .into(),
                        ),
                        VectorValue::Bytes(
                            vec![
                                Some(vec![]),
                                Some(b"HelloWorld".to_vec()),
                                Some(b"abc".to_vec()),
                                None,
                                None,
                            ]
                            .into(),
                        ),
                        VectorValue::Int(vec![None, None, Some(1), Some(10), None].into()),
                        VectorValue::Bytes(
                            vec![
                                Some("áá".as_bytes().to_vec()),
                                None,
                                Some(b"aa".to_vec()),
                                Some("ááá".as_bytes().to_vec()),
                                Some(b"aaa".to_vec()),
                            ]
                            .into(),
                        ),
                    ]),
                    // Physical rows 3 (-5.0, …, 10, "ááá") and the others not
                    // listed here are intentionally excluded from the output.
                    logical_rows: vec![2, 4, 0, 1],
                    warnings: EvalWarnings::default(),
                    is_drained: Ok(false),
                },
                // A batch with physical data but no logical rows: consumers
                // must not produce output from it.
                BatchExecuteResult {
                    physical_columns: LazyBatchColumnVec::from(vec![
                        VectorValue::Real(vec![None].into()),
                        VectorValue::Real(vec![Real::new(-10.0).ok()].into()),
                        VectorValue::Bytes(vec![Some(b"foo".to_vec())].into()),
                        VectorValue::Int(vec![None].into()),
                        VectorValue::Bytes(vec![None].into()),
                    ]),
                    logical_rows: Vec::new(),
                    warnings: EvalWarnings::default(),
                    is_drained: Ok(false),
                },
                // Final batch: only physical row 1 is logically visible, and
                // the source reports itself drained.
                BatchExecuteResult {
                    physical_columns: LazyBatchColumnVec::from(vec![
                        VectorValue::Real(vec![Real::new(5.5).ok(), Real::new(1.5).ok()].into()),
                        VectorValue::Real(vec![None, Real::new(4.5).ok()].into()),
                        VectorValue::Bytes(vec![None, Some(b"aaaaa".to_vec())].into()),
                        VectorValue::Int(vec![None, Some(5)].into()),
                        VectorValue::Bytes(
                            vec![
                                Some("áá".as_bytes().to_vec()),
                                Some("ááá".as_bytes().to_vec()),
                            ]
                            .into(),
                        ),
                    ]),
                    logical_rows: vec![1],
                    warnings: EvalWarnings::default(),
                    is_drained: Ok(true),
                },
            ],
        )
    }
}