1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

//! People implementing RPN functions with fixed argument type and count don't necessarily
//! need to understand how `Evaluator` and `RpnDef` work. There's a procedural macro called
//! `rpn_fn` defined in `tidb_query_codegen` to help you create RPN functions. For example:
//!
//! ```ignore
//! use tidb_query_codegen::rpn_fn;
//!
//! #[rpn_fn(nullable)]
//! fn foo(lhs: &Option<Int>, rhs: &Option<Int>) -> Result<Option<Int>> {
//!     // Your RPN function logic
//! }
//! ```
//!
//! You can still call the `foo` function directly; the macro preserves the original function.
//! It also creates a `foo_fn_meta()` function (the original function name with `_fn_meta`
//! appended), which generates an `RpnFnMeta` struct.
//!
//! For more information on the procedural macro, see the documentation in
//! `components/tidb_query_codegen/src/rpn_function`.

use static_assertions::assert_eq_size;
use std::any::Any;
use std::convert::TryFrom;
use std::marker::PhantomData;

use tidb_query_datatype::{EvalType, FieldTypeAccessor};
use tipb::{Expr, FieldType};

use super::expr_eval::LogicalRows;
use super::RpnStackNode;
use tidb_query_common::Result;
use tidb_query_datatype::codec::data_type::*;
use tidb_query_datatype::expr::EvalContext;

/// Metadata describing a single RPN function.
///
/// Bundles the function's display name with the pointers needed to validate, prepare and
/// evaluate it.
#[derive(Clone, Copy)]
pub struct RpnFnMeta {
    /// Display name of the function; it is also the only thing `Debug` prints.
    /// Mainly used in tests.
    pub name: &'static str,

    /// Checks an input expression tree before the function is used.
    pub validator_ptr: fn(expr: &Expr) -> Result<()>,

    /// Builds the function's metadata object from its expression.
    pub metadata_expr_ptr: fn(expr: &mut Expr) -> Result<Box<dyn Any + Send>>,

    /// Evaluates the function.
    ///
    /// `ctx`, `output_rows` and `args` are the arguments every call needs; the less common
    /// ones are grouped in `extra`.
    #[allow(clippy::type_complexity)]
    pub fn_ptr: fn(
        ctx: &mut EvalContext,
        output_rows: usize,
        args: &[RpnStackNode<'_>],
        extra: &mut RpnFnCallExtra<'_>,
        metadata: &(dyn Any + Send),
    ) -> Result<VectorValue>,
}

impl std::fmt::Debug for RpnFnMeta {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the name is printed; the function pointers are omitted.
        f.write_str(self.name)
    }
}

/// Less frequently needed information about an RPN function call, kept apart so the common
/// call signature stays short.
pub struct RpnFnCallExtra<'a> {
    /// Field type of the value the function returns.
    pub ret_field_type: &'a FieldType,
}

/// One argument of an RPN function, read row by row.
pub trait RpnFnArg: std::fmt::Debug {
    /// Type of the per-row value.
    type Type;

    /// Returns the argument's value for row `row`.
    fn get(&self, row: usize) -> Self::Type;

    /// Returns the argument's null-bitmap information as a pair.
    ///
    /// - A scalar argument has no bitmap and returns `(None, non_null)`. The flag is `true`
    ///   when the scalar is non-null, which then holds for every row.
    /// - A vector argument returns `(Some(bitmap), identical)`. The flag is `true` when the
    ///   logical rows are identical to the physical layout, i.e. the stored bitmap has the same
    ///   layout as the logical elements.
    fn get_bit_vec(&self) -> (Option<&BitVec>, bool);
}

/// An RPN function argument backed by a single scalar value that every row shares.
#[derive(Clone, Copy, Debug)]
pub struct ScalarArg<'a, T: EvaluableRef<'a>>(Option<T>, PhantomData<&'a T>);

impl<'a, T: EvaluableRef<'a>> ScalarArg<'a, T> {
    /// Wraps an optional scalar as an argument.
    pub fn new(data: Option<T>) -> Self {
        ScalarArg(data, PhantomData)
    }
}

impl<'a, T: EvaluableRef<'a>> RpnFnArg for ScalarArg<'a, T> {
    type Type = Option<T>;

    /// Returns the scalar. The row index is ignored because all rows share one value.
    #[inline]
    fn get(&self, _row: usize) -> Option<T> {
        let ScalarArg(value, _) = self;
        value.clone()
    }

    /// A scalar has no bitmap: it is either null in every row or non-null in every row,
    /// so only the non-null flag is reported.
    #[inline]
    fn get_bit_vec(&self) -> (Option<&BitVec>, bool) {
        let ScalarArg(value, _) = self;
        (None, value.is_some())
    }
}

/// An RPN function argument backed by a column, read through a logical-row mapping.
#[derive(Clone, Copy, Debug)]
pub struct VectorArg<'a, T: 'a + EvaluableRef<'a>, C: 'a + ChunkRef<'a, T>> {
    physical_col: C,
    logical_rows: LogicalRows<'a>,
    _phantom: PhantomData<T>,
}

impl<'a, T: EvaluableRef<'a>, C: 'a + ChunkRef<'a, T>> RpnFnArg for VectorArg<'a, T, C> {
    type Type = Option<T>;

    /// Maps `row` through the logical rows and reads the column at the resulting index.
    #[inline]
    fn get(&self, row: usize) -> Option<T> {
        self.physical_col.get_option_ref(self.logical_rows.get_idx(row))
    }

    /// Returns the column's bit vector, plus whether the logical rows are identical to the
    /// physical layout.
    #[inline]
    fn get_bit_vec(&self) -> (Option<&BitVec>, bool) {
        let bitmap = self.physical_col.get_bit_vec();
        let identical_layout = self.logical_rows.is_ident();
        (Some(bitmap), identical_layout)
    }
}

/// A partial or complete list of argument definitions for an RPN function.
///
/// The `ArgDef` is built at the start of evaluating an RPN function, and building it fixes the
/// concrete `RpnFnArg` type of each argument. Applying the function to each input row
/// therefore needs neither dynamic dispatch nor enum matching.
pub trait ArgDef: std::fmt::Debug {}

/// One link of the argument-definition list: a head argument followed by the rest of the list.
///
/// For example, applying `foo(Int, Real, Decimal)` to a scalar integer, a vector of reals and
/// a vector of decimals gives an `ArgDef` shaped like (type parameters abbreviated)
/// `Arg<ScalarArg<Int>, Arg<VectorArg<Real>, Arg<VectorArg<Decimal>, Null>>>`.
/// `Null` marks the end of the list.
#[derive(Debug)]
pub struct Arg<A: RpnFnArg, Rem: ArgDef> {
    arg: A,
    rem: Rem,
}

impl<A: RpnFnArg, Rem: ArgDef> ArgDef for Arg<A, Rem> {}

impl<A: RpnFnArg, Rem: ArgDef> Arg<A, Rem> {
    /// Reads the head argument at `row` and also returns the remaining list.
    #[inline]
    pub fn extract(&self, row: usize) -> (A::Type, &Rem) {
        let Arg { arg, rem } = self;
        (arg.get(row), rem)
    }

    /// Returns the head argument's bit-vector information and the remaining list.
    #[inline]
    pub fn get_bit_vec(&self) -> ((Option<&BitVec>, bool), &Rem) {
        let Arg { arg, rem } = self;
        (arg.get_bit_vec(), rem)
    }
}

/// Terminator of the argument-definition list.
#[derive(Debug)]
pub struct Null;

impl ArgDef for Null {}

/// A generic evaluator of an RPN function.
///
/// For every RPN function, create the evaluator first, then call its `eval` method with the
/// input to obtain the result vector.
///
/// Evaluators generally come in two kinds:
/// - `ArgConstructor`, a provided `Evaluator` used by the `rpn_fn` attribute macro to build
///   the `ArgDef`. Its `def` parameter is the `ArgDef` constructed so far. It adds one more
///   argument and hands the result to its inner evaluator. The outermost one receives `Null`
///   as `def`.
/// - Custom evaluators that actually execute the RPN function. Their `def` parameter is the
///   fully constructed `ArgDef`. They extract values from the arguments, run the function and
///   fill in the result vector.
pub trait Evaluator<'a> {
    /// Runs this evaluator with the argument definitions built so far in `def`.
    ///
    /// `args` holds the input stack nodes of the function call; `extra` and `metadata` carry
    /// the less common per-call information.
    fn eval(
        self,
        def: impl ArgDef,
        ctx: &mut EvalContext,
        output_rows: usize,
        args: &'a [RpnStackNode<'a>],
        extra: &mut RpnFnCallExtra<'_>,
        metadata: &(dyn Any + Send),
    ) -> Result<VectorValue>;
}

/// An `Evaluator` that turns the argument at `arg_index` into a typed `RpnFnArg`, pushes it
/// onto the argument-definition list and forwards evaluation to `inner`.
pub struct ArgConstructor<'a, A: EvaluableRef<'a>, E: Evaluator<'a>> {
    arg_index: usize,
    inner: E,
    _phantom: PhantomData<&'a A>,
}

impl<'a, A: EvaluableRef<'a>, E: Evaluator<'a>> ArgConstructor<'a, A, E> {
    /// Creates a constructor for `args[arg_index]` that delegates to `inner`.
    #[inline]
    pub fn new(arg_index: usize, inner: E) -> Self {
        Self {
            arg_index,
            inner,
            _phantom: PhantomData,
        }
    }
}

impl<'a, A: EvaluableRef<'a>, E: Evaluator<'a>> Evaluator<'a> for ArgConstructor<'a, A, E> {
    fn eval(
        self,
        def: impl ArgDef,
        ctx: &mut EvalContext,
        output_rows: usize,
        args: &'a [RpnStackNode<'a>],
        extra: &mut RpnFnCallExtra<'_>,
        metadata: &(dyn Any + Send),
    ) -> Result<VectorValue> {
        let Self {
            arg_index, inner, ..
        } = self;

        // Each arm builds a differently-typed `Arg` (scalar vs. vector), so `inner.eval` is
        // called separately per arm; that keeps the argument type static for the inner
        // evaluator.
        match &args[arg_index] {
            RpnStackNode::Scalar { value, .. } => {
                let scalar = A::borrow_scalar_value_ref(value.as_scalar_value_ref());
                let def = Arg {
                    arg: ScalarArg::new(scalar),
                    rem: def,
                };
                inner.eval(def, ctx, output_rows, args, extra, metadata)
            }
            RpnStackNode::Vector { value, .. } => {
                let logical_rows = value.logical_rows_struct();
                let column = A::borrow_vector_value(value.as_ref());
                let def = Arg {
                    arg: VectorArg {
                        physical_col: column,
                        logical_rows,
                        _phantom: PhantomData,
                    },
                    rem: def,
                };
                inner.eval(def, ctx, output_rows, args, extra, metadata)
            }
        }
    }
}

/// Validates that the return type of an expression node is `et`.
///
/// A received `Enum` type is also accepted when `Int` or `Bytes` is expected.
///
/// Returns an error if the node's field type cannot be converted to an `EvalType`, or if the
/// converted type is not acceptable.
pub fn validate_expr_return_type(expr: &Expr, et: EvalType) -> Result<()> {
    let received_et = box_try!(EvalType::try_from(expr.get_field_type().as_accessor().tp()));
    match (et, received_et) {
        _ if et == received_et => Ok(()),
        (EvalType::Int, EvalType::Enum) | (EvalType::Bytes, EvalType::Enum) => Ok(()),
        _ => Err(other_err!("Expect `{}`, received `{}`", et, received_et)),
    }
}

/// Validates that an expression node has exactly `args` children.
pub fn validate_expr_arguments_eq(expr: &Expr, args: usize) -> Result<()> {
    let received_args = expr.get_children().len();
    if received_args != args {
        return Err(other_err!(
            "Expect {} arguments, received {}",
            args,
            received_args
        ));
    }
    Ok(())
}

/// Validates that an expression node has at least `args` children.
pub fn validate_expr_arguments_gte(expr: &Expr, args: usize) -> Result<()> {
    let received_args = expr.get_children().len();
    if received_args < args {
        return Err(other_err!(
            "Expect at least {} arguments, received {}",
            args,
            received_args
        ));
    }
    Ok(())
}

/// Validates that an expression node has at most `args` children.
pub fn validate_expr_arguments_lte(expr: &Expr, args: usize) -> Result<()> {
    let received_args = expr.get_children().len();
    if received_args > args {
        return Err(other_err!(
            "Expect at most {} arguments, received {}",
            args,
            received_args
        ));
    }
    Ok(())
}

// `VARG_PARAM_BUF` and the buffers next to it are thread-local caches used when evaluating
// vararg `rpn_fn` functions, so that evaluation can reuse an existing `Vec` instead of
// allocating a new one each time.
//
// Per https://doc.rust-lang.org/std/mem/fn.size_of.html , `&T` and `Option<&T>` have the
// same size. The assertions below additionally pin that size to `usize` for each listed type.
// NOTE(review): presumably this is what allows `Option<&T>` values to be kept in the
// `usize`-typed `VARG_PARAM_BUF`; confirm against the `rpn_fn` codegen.
assert_eq_size!(usize, Option<&Int>);
assert_eq_size!(usize, Option<&Real>);
assert_eq_size!(usize, Option<&Decimal>);
assert_eq_size!(usize, Option<&Bytes>);
assert_eq_size!(usize, Option<&DateTime>);
assert_eq_size!(usize, Option<&Duration>);
assert_eq_size!(usize, Option<&Json>);

/// Initial capacity reserved by each vararg parameter buffer.
const VARG_PARAM_BUF_CAPACITY: usize = 20;

thread_local! {
    pub static VARG_PARAM_BUF: std::cell::RefCell<Vec<usize>> =
        std::cell::RefCell::new(Vec::with_capacity(VARG_PARAM_BUF_CAPACITY));

    pub static VARG_PARAM_BUF_BYTES_REF: std::cell::RefCell<Vec<Option<BytesRef<'static>>>> =
        std::cell::RefCell::new(Vec::with_capacity(VARG_PARAM_BUF_CAPACITY));

    pub static VARG_PARAM_BUF_JSON_REF: std::cell::RefCell<Vec<Option<JsonRef<'static>>>> =
        std::cell::RefCell::new(Vec::with_capacity(VARG_PARAM_BUF_CAPACITY));

    pub static RAW_VARG_PARAM_BUF: std::cell::RefCell<Vec<ScalarValueRef<'static>>> =
        std::cell::RefCell::new(Vec::with_capacity(VARG_PARAM_BUF_CAPACITY));
}

/// Decodes a protobuf message of type `T` from `val`.
///
/// An empty `val` yields `T::default()` without invoking the decoder.
///
/// Returns an error wrapping the decoder's message when a non-empty `val` fails to decode.
pub fn extract_metadata_from_val<T: protobuf::Message + Default>(val: &[u8]) -> Result<T> {
    if val.is_empty() {
        return Ok(T::default());
    }
    let decoded = protobuf::parse_from_bytes::<T>(val);
    decoded.map_err(|e| other_err!("Decode metadata failed: {}", e))
}