use std::borrow::Cow;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use std::{cmp, i32, ptr};
use crate::{
grpc_sys::{self, gpr_timespec, grpc_arg_pointer_vtable, grpc_channel, grpc_channel_args},
Deadline,
};
use libc::{self, c_char, c_int};
use crate::call::{Call, Method};
use crate::cq::CompletionQueue;
use crate::env::Environment;
use crate::error::Result;
use crate::task::CallTag;
use crate::task::Kicker;
use crate::CallOption;
use crate::ResourceQuota;
pub use crate::grpc_sys::{
grpc_compression_algorithm as CompressionAlgorithms,
grpc_compression_level as CompressionLevel, grpc_connectivity_state as ConnectivityState,
};
/// Builds the user-agent C string advertised by this library.
///
/// Surrounding whitespace is stripped from `agent`. When nothing remains, the
/// result is `grpc-rust/<crate version>`; otherwise the trimmed agent is
/// placed in front of that token, separated by a single space.
///
/// # Panics
///
/// Panics if `agent` contains an interior NUL byte, because the result must be
/// representable as a `CString`.
fn format_user_agent_string(agent: &str) -> CString {
    let version = env!("CARGO_PKG_VERSION");
    let composed = match agent.trim() {
        "" => format!("grpc-rust/{}", version),
        custom => format!("{} grpc-rust/{}", custom, version),
    };
    CString::new(composed).unwrap()
}
/// Converts `dur` to whole milliseconds, saturating at `i32::MAX`.
///
/// Any sub-millisecond remainder is truncated. The value is clamped while it
/// is still the `u128` produced by [`Duration::as_millis`], so arbitrarily
/// large durations saturate instead of overflowing an intermediate `u64`
/// multiplication.
fn dur_to_ms(dur: Duration) -> i32 {
    // After the clamp the value is at most `i32::MAX`, so the cast is lossless.
    cmp::min(dur.as_millis(), i32::MAX as u128) as i32
}
/// Value of a single channel argument collected by `ChannelBuilder`.
///
/// Each variant is handed to a different `grpcwrap_channel_args_set_*` call
/// when the arguments are built.
enum Options {
// An integer-valued argument.
Integer(i32),
// A NUL-terminated string argument.
String(CString),
// A resource quota passed as a pointer argument, together with the vtable
// that is supplied alongside it.
Pointer(ResourceQuota, *const grpc_arg_pointer_vtable),
}
/// Optimization target selectable through `ChannelBuilder::optimize_for`.
#[derive(Clone, Copy)]
pub enum OptTarget {
/// Sent to gRPC as the string `"latency"`.
Latency,
/// Sent to gRPC as the string `"blend"`.
Blend,
/// Sent to gRPC as the string `"throughput"`.
Throughput,
}
/// Load-balancing policy selectable through `ChannelBuilder::load_balancing_policy`.
#[derive(Clone, Copy)]
pub enum LbPolicy {
/// Sent to gRPC as the policy name `"pick_first"`.
PickFirst,
/// Sent to gRPC as the policy name `"round_robin"`.
RoundRobin,
}
/// Builder that accumulates channel arguments and then creates a `Channel`.
pub struct ChannelBuilder {
// Environment that supplies the completion queue for the created channel
// and is kept alive by it.
env: Arc<Environment>,
// Channel arguments keyed by their NUL-terminated name; setting the same
// key twice keeps only the latest value.
options: HashMap<Cow<'static, [u8]>, Options>,
}
impl ChannelBuilder {
pub fn new(env: Arc<Environment>) -> ChannelBuilder {
ChannelBuilder {
env,
options: HashMap::new(),
}
}
pub fn default_authority<S: Into<Vec<u8>>>(mut self, authority: S) -> ChannelBuilder {
let authority = CString::new(authority).unwrap();
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_DEFAULT_AUTHORITY),
Options::String(authority),
);
self
}
pub fn set_resource_quota(mut self, quota: ResourceQuota) -> ChannelBuilder {
unsafe {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_RESOURCE_QUOTA),
Options::Pointer(quota, grpc_sys::grpc_resource_quota_arg_vtable()),
);
}
self
}
pub fn max_concurrent_stream(mut self, num: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_MAX_CONCURRENT_STREAMS),
Options::Integer(num),
);
self
}
pub fn max_receive_message_len(mut self, len: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH),
Options::Integer(len),
);
self
}
pub fn max_send_message_len(mut self, len: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_MAX_SEND_MESSAGE_LENGTH),
Options::Integer(len),
);
self
}
pub fn max_reconnect_backoff(mut self, backoff: Duration) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_MAX_RECONNECT_BACKOFF_MS),
Options::Integer(dur_to_ms(backoff)),
);
self
}
pub fn initial_reconnect_backoff(mut self, backoff: Duration) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS),
Options::Integer(dur_to_ms(backoff)),
);
self
}
pub fn https_initial_seq_number(mut self, number: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER),
Options::Integer(number),
);
self
}
pub fn stream_initial_window_size(mut self, window_size: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_STREAM_LOOKAHEAD_BYTES),
Options::Integer(window_size),
);
self
}
pub fn primary_user_agent(mut self, agent: &str) -> ChannelBuilder {
let agent_string = format_user_agent_string(agent);
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_PRIMARY_USER_AGENT_STRING),
Options::String(agent_string),
);
self
}
pub fn reuse_port(mut self, reuse: bool) -> ChannelBuilder {
let opt = if reuse { 1 } else { 0 };
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_ALLOW_REUSEPORT),
Options::Integer(opt),
);
self
}
pub fn tcp_read_chunk_size(mut self, bytes: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_TCP_READ_CHUNK_SIZE),
Options::Integer(bytes),
);
self
}
pub fn tcp_min_read_chunk_size(mut self, bytes: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE),
Options::Integer(bytes),
);
self
}
pub fn tcp_max_read_chunk_size(mut self, bytes: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE),
Options::Integer(bytes),
);
self
}
pub fn http2_write_buffer_size(mut self, size: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE),
Options::Integer(size),
);
self
}
pub fn http2_max_frame_size(mut self, size: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_MAX_FRAME_SIZE),
Options::Integer(size),
);
self
}
pub fn http2_bdp_probe(mut self, enable: bool) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_BDP_PROBE),
Options::Integer(enable as i32),
);
self
}
pub fn http2_min_sent_ping_interval_without_data(
mut self,
interval: Duration,
) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_MIN_SENT_PING_INTERVAL_WITHOUT_DATA_MS),
Options::Integer(dur_to_ms(interval)),
);
self
}
pub fn http2_min_recv_ping_interval_without_data(
mut self,
interval: Duration,
) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS),
Options::Integer(dur_to_ms(interval)),
);
self
}
pub fn http2_max_pings_without_data(mut self, num: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA),
Options::Integer(num),
);
self
}
pub fn http2_max_ping_strikes(mut self, num: i32) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_HTTP2_MAX_PING_STRIKES),
Options::Integer(num),
);
self
}
pub fn default_compression_algorithm(mut self, algo: CompressionAlgorithms) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM),
Options::Integer(algo as i32),
);
self
}
pub fn default_compression_level(mut self, level: CompressionLevel) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL),
Options::Integer(level as i32),
);
self
}
pub fn keepalive_time(mut self, timeout: Duration) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_KEEPALIVE_TIME_MS),
Options::Integer(dur_to_ms(timeout)),
);
self
}
pub fn keepalive_timeout(mut self, timeout: Duration) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_KEEPALIVE_TIMEOUT_MS),
Options::Integer(dur_to_ms(timeout)),
);
self
}
pub fn keepalive_permit_without_calls(mut self, allow: bool) -> ChannelBuilder {
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS),
Options::Integer(allow as i32),
);
self
}
pub fn optimize_for(mut self, target: OptTarget) -> ChannelBuilder {
let val = match target {
OptTarget::Latency => CString::new("latency"),
OptTarget::Blend => CString::new("blend"),
OptTarget::Throughput => CString::new("throughput"),
};
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_OPTIMIZATION_TARGET),
Options::String(val.unwrap()),
);
self
}
pub fn load_balancing_policy(mut self, lb_policy: LbPolicy) -> ChannelBuilder {
let val = match lb_policy {
LbPolicy::PickFirst => CString::new("pick_first"),
LbPolicy::RoundRobin => CString::new("round_robin"),
};
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_LB_POLICY_NAME),
Options::String(val.unwrap()),
);
self
}
#[doc(hidden)]
pub fn raw_cfg_int(mut self, key: CString, val: i32) -> ChannelBuilder {
self.options
.insert(Cow::Owned(key.into_bytes_with_nul()), Options::Integer(val));
self
}
#[doc(hidden)]
pub fn raw_cfg_string(mut self, key: CString, val: CString) -> ChannelBuilder {
self.options
.insert(Cow::Owned(key.into_bytes_with_nul()), Options::String(val));
self
}
#[allow(clippy::useless_conversion)]
pub fn build_args(&self) -> ChannelArgs {
let args = unsafe { grpc_sys::grpcwrap_channel_args_create(self.options.len()) };
for (i, (k, v)) in self.options.iter().enumerate() {
let key = k.as_ptr() as *const c_char;
match *v {
Options::Integer(val) => unsafe {
assert!(
val <= i32::from(libc::INT_MAX) && val >= i32::from(libc::INT_MIN),
"{} is out of range for {:?}",
val,
CStr::from_bytes_with_nul(k).unwrap()
);
grpc_sys::grpcwrap_channel_args_set_integer(args, i, key, val as c_int)
},
Options::String(ref val) => unsafe {
grpc_sys::grpcwrap_channel_args_set_string(args, i, key, val.as_ptr())
},
Options::Pointer(ref quota, vtable) => unsafe {
grpc_sys::grpcwrap_channel_args_set_pointer_vtable(
args,
i,
key,
quota.get_ptr() as _,
vtable,
)
},
}
}
ChannelArgs { args }
}
fn prepare_connect_args(&mut self) -> ChannelArgs {
if let Entry::Vacant(e) = self.options.entry(Cow::Borrowed(
grpcio_sys::GRPC_ARG_PRIMARY_USER_AGENT_STRING,
)) {
e.insert(Options::String(format_user_agent_string("")));
}
self.build_args()
}
pub fn connect(mut self, addr: &str) -> Channel {
let args = self.prepare_connect_args();
let addr = CString::new(addr).unwrap();
let addr_ptr = addr.as_ptr();
let channel =
unsafe { grpc_sys::grpc_insecure_channel_create(addr_ptr, args.args, ptr::null_mut()) };
unsafe { Channel::new(self.env.pick_cq(), self.env, channel) }
}
#[cfg(unix)]
pub unsafe fn connect_from_fd(mut self, target: &str, fd: ::std::os::raw::c_int) -> Channel {
let args = self.prepare_connect_args();
let target = CString::new(target).unwrap();
let target_ptr = target.as_ptr();
let channel = grpc_sys::grpc_insecure_channel_create_from_fd(target_ptr, fd, args.args);
Channel::new(self.env.pick_cq(), self.env, channel)
}
}
/// Credential-based channel creation, compiled only with the `secure` feature.
#[cfg(feature = "secure")]
mod secure_channel {
    use std::borrow::Cow;
    use std::ffi::CString;
    use std::ptr;
    use crate::grpc_sys;
    use crate::ChannelCredentials;
    use super::{Channel, ChannelBuilder, Options};

    /// NUL-terminated key of the SSL target name override argument.
    const OPT_SSL_TARGET_NAME_OVERRIDE: &[u8] = b"grpc.ssl_target_name_override\0";

    impl ChannelBuilder {
        /// Sets the `grpc.ssl_target_name_override` string argument.
        ///
        /// # Panics
        ///
        /// Panics if `target` contains an interior NUL byte.
        #[doc(hidden)]
        pub fn override_ssl_target<S: Into<Vec<u8>>>(mut self, target: S) -> ChannelBuilder {
            let key = Cow::Borrowed(OPT_SSL_TARGET_NAME_OVERRIDE);
            let value = Options::String(CString::new(target).unwrap());
            self.options.insert(key, value);
            self
        }

        /// Creates a channel to `addr` that is set up with the credentials `creds`.
        ///
        /// # Panics
        ///
        /// Panics if `addr` contains an interior NUL byte.
        pub fn secure_connect(mut self, addr: &str, mut creds: ChannelCredentials) -> Channel {
            let args = self.prepare_connect_args();
            let addr = CString::new(addr).unwrap();
            let channel = unsafe {
                grpc_sys::grpc_secure_channel_create(
                    creds.as_mut_ptr(),
                    addr.as_ptr(),
                    args.args,
                    ptr::null_mut(),
                )
            };
            let cq = self.env.pick_cq();
            unsafe { Channel::new(cq, self.env, channel) }
        }
    }
}
/// Owning handle to a native `grpc_channel_args` list.
///
/// The list is released with `grpcwrap_channel_args_destroy` when the handle
/// is dropped.
pub struct ChannelArgs {
// Raw pointer obtained from `grpcwrap_channel_args_create`.
args: *mut grpc_channel_args,
}
impl ChannelArgs {
    /// Returns the underlying argument list as a read-only raw pointer.
    ///
    /// The pointer is only meaningful while this `ChannelArgs` is alive,
    /// because dropping it destroys the list.
    pub fn as_ptr(&self) -> *const grpc_channel_args {
        self.args as *const grpc_channel_args
    }
}
impl Drop for ChannelArgs {
    /// Releases the native argument list held by this handle.
    fn drop(&mut self) {
        let raw = self.args;
        unsafe {
            grpc_sys::grpcwrap_channel_args_destroy(raw);
        }
    }
}
/// Shared state behind every clone of a `Channel`.
struct ChannelInner {
// Never read; held so the environment lives at least as long as the channel.
_env: Arc<Environment>,
// Raw gRPC channel, destroyed when this struct is dropped.
channel: *mut grpc_channel,
}
impl ChannelInner {
fn check_connectivity_state(&self, try_to_connect: bool) -> ConnectivityState {
let should_try = if try_to_connect { 1 } else { 0 };
unsafe { grpc_sys::grpc_channel_check_connectivity_state(self.channel, should_try) }
}
}
impl Drop for ChannelInner {
    /// Destroys the underlying gRPC channel.
    fn drop(&mut self) {
        let raw = self.channel;
        unsafe { grpc_sys::grpc_channel_destroy(raw) };
    }
}
/// Handle to a gRPC channel on which calls are created.
///
/// Cloning yields another handle to the same underlying channel; the channel
/// is destroyed when the last handle is dropped.
#[derive(Clone)]
pub struct Channel {
// Reference-counted owner of the raw channel and its environment.
inner: Arc<ChannelInner>,
// Completion queue used for the calls and connectivity watches of this channel.
cq: CompletionQueue,
}
// NOTE(review): these impls assert that the raw `grpc_channel` pointer held via
// `ChannelInner` may be moved to and used from several threads at once —
// confirm against the thread-safety guarantees of gRPC core.
unsafe impl Send for Channel {}
unsafe impl Sync for Channel {}
impl Channel {
pub unsafe fn new(
cq: CompletionQueue,
env: Arc<Environment>,
channel: *mut grpc_channel,
) -> Channel {
Channel {
inner: Arc::new(ChannelInner { _env: env, channel }),
cq,
}
}
pub fn check_connectivity_state(&self, try_to_connect: bool) -> ConnectivityState {
self.inner.check_connectivity_state(try_to_connect)
}
pub fn wait_for_state_change(
&self,
last_observed: ConnectivityState,
deadline: impl Into<Deadline>,
) -> impl Future<Output = bool> {
let (cq_f, prom) = CallTag::action_pair();
let prom_box = Box::new(prom);
let tag = Box::into_raw(prom_box);
let should_wait = if let Ok(cq_ref) = self.cq.borrow() {
unsafe {
grpcio_sys::grpc_channel_watch_connectivity_state(
self.inner.channel,
last_observed,
deadline.into().spec(),
cq_ref.as_ptr(),
tag as *mut _,
)
}
true
} else {
false
};
async move { should_wait && cq_f.await.unwrap() }
}
pub async fn wait_for_connected(&self, deadline: impl Into<Deadline>) -> bool {
let mut state = self.check_connectivity_state(true);
if ConnectivityState::GRPC_CHANNEL_READY == state {
return true;
}
let deadline = deadline.into();
loop {
if self.wait_for_state_change(state, deadline).await {
state = self.check_connectivity_state(true);
match state {
ConnectivityState::GRPC_CHANNEL_READY => return true,
ConnectivityState::GRPC_CHANNEL_SHUTDOWN => return false,
_ => (),
}
continue;
}
return false;
}
}
pub(crate) fn create_kicker(&self) -> Result<Kicker> {
let cq_ref = self.cq.borrow()?;
let raw_call = unsafe {
let ch = self.inner.channel;
let cq = cq_ref.as_ptr();
let timeout = gpr_timespec::inf_future();
grpc_sys::grpcwrap_channel_create_call(
ch,
ptr::null_mut(),
0,
cq,
ptr::null(),
0,
ptr::null(),
0,
timeout,
)
};
let call = unsafe { Call::from_raw(raw_call, self.cq.clone()) };
Ok(Kicker::from_call(call))
}
pub(crate) fn create_call<Req, Resp>(
&self,
method: &Method<Req, Resp>,
opt: &CallOption,
) -> Result<Call> {
let cq_ref = self.cq.borrow()?;
let raw_call = unsafe {
let ch = self.inner.channel;
let cq = cq_ref.as_ptr();
let method_ptr = method.name.as_ptr();
let method_len = method.name.len();
let timeout = opt
.get_timeout()
.map_or_else(gpr_timespec::inf_future, gpr_timespec::from);
grpc_sys::grpcwrap_channel_create_call(
ch,
ptr::null_mut(),
0,
cq,
method_ptr as *const _,
method_len,
ptr::null(),
0,
timeout,
)
};
unsafe { Ok(Call::from_raw(raw_call, self.cq.clone())) }
}
pub(crate) fn cq(&self) -> &CompletionQueue {
&self.cq
}
}