use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::grpc_sys::{self, grpc_call_error, grpc_server};
use futures::future::Future;
use futures::ready;
use futures::task::{Context, Poll};
use crate::call::server::*;
use crate::call::{MessageReader, Method, MethodType};
use crate::channel::ChannelArgs;
use crate::cq::CompletionQueue;
use crate::env::Environment;
use crate::error::{Error, Result};
use crate::task::{CallTag, CqFuture};
use crate::RpcContext;
use crate::RpcStatus;
/// Default value for `ServerBuilder::slots_per_cq`: the number of
/// `request_call` invocations `Server::start` issues per completion queue.
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;
/// Pairs an RPC callback with the `MethodType` it was registered for.
///
/// Cloning the handler clones the callback (requires `F: Clone`).
#[derive(Clone)]
pub struct Handler<F> {
/// Kind of RPC this handler serves (reported by `CloneableHandler::method_type`).
method_type: MethodType,
/// The callback invoked by `CloneableHandler::handle`.
cb: F,
}
impl<F> Handler<F> {
pub fn new(method_type: MethodType, cb: F) -> Handler<F> {
Handler { method_type, cb }
}
}
/// Object-safe interface for RPC handlers that can be duplicated behind a
/// `Box<dyn CloneableHandler>` (see `BoxHandler`).
pub trait CloneableHandler: Send {
/// Handles one call. `reqs` is optional; the unary and server-streaming
/// closures built by `ServiceBuilder` unwrap it, while the client-streaming
/// and duplex closures ignore it.
fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>);
/// Returns a boxed copy of this handler.
fn box_clone(&self) -> Box<dyn CloneableHandler>;
/// Returns the kind of RPC this handler serves.
fn method_type(&self) -> MethodType;
}
impl<F: 'static> CloneableHandler for Handler<F>
where
F: FnMut(RpcContext<'_>, Option<MessageReader>) + Send + Clone,
{
#[inline]
fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>) {
(self.cb)(ctx, reqs)
}
#[inline]
fn box_clone(&self) -> Box<dyn CloneableHandler> {
Box::new(self.clone())
}
#[inline]
fn method_type(&self) -> MethodType {
self.method_type
}
}
/// Builds the NUL-terminated address string that is passed to the gRPC
/// port-binding functions.
///
/// * Hosts starting with `unix:` are used verbatim; `port` is ignored.
/// * Hosts that parse as an IP address are formatted through `SocketAddr`,
///   which wraps IPv6 addresses in brackets (`[::1]:80`).
/// * Anything else is joined as `host:port`.
///
/// The returned `String` always ends with a `\0` byte so its pointer can be
/// handed to C code directly.
fn join_host_port(host: &str, port: u16) -> String {
    let mut target = if host.starts_with("unix:") {
        host.to_owned()
    } else {
        match host.parse::<IpAddr>() {
            Ok(ip) => SocketAddr::new(ip, port).to_string(),
            Err(_) => format!("{}:{}", host, port),
        }
    };
    // Terminator for the C side.
    target.push('\0');
    target
}
#[cfg(feature = "secure")]
mod imp {
    use super::join_host_port;
    use crate::grpc_sys::{self, grpc_server};
    use crate::security::ServerCredentialsFetcher;
    use crate::ServerCredentials;

    /// A listening address together with the optional credentials used to
    /// secure it.
    pub struct Binder {
        pub host: String,
        pub port: u16,
        /// Consumed by the first call to `bind`.
        cred: Option<ServerCredentials>,
        /// Only stored, never read in this module; keeps the boxed fetcher
        /// owned by the binder.
        _fetcher: Option<Box<Box<dyn ServerCredentialsFetcher + Send + Sync>>>,
    }

    impl Binder {
        /// Creates a binder without credentials (binds an insecure port).
        pub fn new(host: String, port: u16) -> Binder {
            Binder {
                host,
                port,
                cred: None,
                _fetcher: None,
            }
        }

        /// Creates a binder that binds a secure port using `cred`, optionally
        /// taking ownership of a credentials fetcher.
        pub fn with_cred(
            host: String,
            port: u16,
            cred: ServerCredentials,
            _fetcher: Option<Box<Box<dyn ServerCredentialsFetcher + Send + Sync>>>,
        ) -> Binder {
            Binder {
                host,
                port,
                cred: Some(cred),
                _fetcher,
            }
        }

        /// Adds this address to `server` and returns the port number reported
        /// by gRPC, truncated to `u16` (`ServerBuilder::build` treats `0` as a
        /// bind failure).
        ///
        /// The credentials are taken out of the binder, so a second call
        /// would bind insecurely.
        ///
        /// # Safety
        ///
        /// `server` is passed straight to the gRPC C API.
        /// NOTE(review): presumably it must be a live, not-yet-started server
        /// handle — confirm against the gRPC core documentation.
        pub unsafe fn bind(&mut self, server: *mut grpc_server) -> u16 {
            // `target` is NUL-terminated by `join_host_port`.
            let target = join_host_port(&self.host, self.port);
            let bound = if let Some(mut cred) = self.cred.take() {
                grpc_sys::grpc_server_add_secure_http2_port(
                    server,
                    target.as_ptr() as _,
                    cred.as_mut_ptr(),
                )
            } else {
                grpc_sys::grpc_server_add_insecure_http2_port(server, target.as_ptr() as _)
            };
            bound as u16
        }
    }
}
#[cfg(not(feature = "secure"))]
mod imp {
    use super::join_host_port;
    use crate::grpc_sys::{self, grpc_server};

    /// A listening address (insecure-only build).
    pub struct Binder {
        pub host: String,
        pub port: u16,
    }

    impl Binder {
        /// Creates a binder for `host:port`.
        pub fn new(host: String, port: u16) -> Binder {
            Self { host, port }
        }

        /// Adds this address to `server` as an insecure HTTP/2 port and
        /// returns the port number reported by gRPC, truncated to `u16`
        /// (`ServerBuilder::build` treats `0` as a bind failure).
        ///
        /// # Safety
        ///
        /// `server` is passed straight to the gRPC C API.
        /// NOTE(review): presumably it must be a live, not-yet-started server
        /// handle — confirm against the gRPC core documentation.
        pub unsafe fn bind(&mut self, server: *mut grpc_server) -> u16 {
            // `target` is NUL-terminated by `join_host_port`.
            let target = join_host_port(&self.host, self.port);
            let bound =
                grpc_sys::grpc_server_add_insecure_http2_port(server, target.as_ptr() as _);
            bound as u16
        }
    }
}
use self::imp::Binder;
impl Debug for Binder {
    /// Prints only the host and port; credentials are never shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let (host, port) = (&self.host, self.port);
        f.write_fmt(format_args!(
            "Binder {{ host: {}, port: {} }}",
            host, port
        ))
    }
}
/// Collects RPC handlers and turns them into a `Service` via `build`.
pub struct ServiceBuilder {
/// Handlers keyed by the bytes of the method name they were added for.
handlers: HashMap<&'static [u8], BoxHandler>,
}
impl ServiceBuilder {
pub fn new() -> ServiceBuilder {
ServiceBuilder {
handlers: HashMap::new(),
}
}
pub fn add_unary_handler<Req, Resp, F>(
mut self,
method: &Method<Req, Resp>,
mut handler: F,
) -> ServiceBuilder
where
Req: 'static,
Resp: 'static,
F: FnMut(RpcContext<'_>, Req, UnarySink<Resp>) + Send + Clone + 'static,
{
let (ser, de) = (method.resp_ser(), method.req_de());
let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
execute_unary(ctx, ser, de, payload.unwrap(), &mut handler)
};
let ch = Box::new(Handler::new(MethodType::Unary, h));
self.handlers.insert(method.name.as_bytes(), ch);
self
}
pub fn add_client_streaming_handler<Req, Resp, F>(
mut self,
method: &Method<Req, Resp>,
mut handler: F,
) -> ServiceBuilder
where
Req: 'static,
Resp: 'static,
F: FnMut(RpcContext<'_>, RequestStream<Req>, ClientStreamingSink<Resp>)
+ Send
+ Clone
+ 'static,
{
let (ser, de) = (method.resp_ser(), method.req_de());
let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
execute_client_streaming(ctx, ser, de, &mut handler)
};
let ch = Box::new(Handler::new(MethodType::ClientStreaming, h));
self.handlers.insert(method.name.as_bytes(), ch);
self
}
pub fn add_server_streaming_handler<Req, Resp, F>(
mut self,
method: &Method<Req, Resp>,
mut handler: F,
) -> ServiceBuilder
where
Req: 'static,
Resp: 'static,
F: FnMut(RpcContext<'_>, Req, ServerStreamingSink<Resp>) + Send + Clone + 'static,
{
let (ser, de) = (method.resp_ser(), method.req_de());
let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
execute_server_streaming(ctx, ser, de, payload.unwrap(), &mut handler)
};
let ch = Box::new(Handler::new(MethodType::ServerStreaming, h));
self.handlers.insert(method.name.as_bytes(), ch);
self
}
pub fn add_duplex_streaming_handler<Req, Resp, F>(
mut self,
method: &Method<Req, Resp>,
mut handler: F,
) -> ServiceBuilder
where
Req: 'static,
Resp: 'static,
F: FnMut(RpcContext<'_>, RequestStream<Req>, DuplexSink<Resp>) + Send + Clone + 'static,
{
let (ser, de) = (method.resp_ser(), method.req_de());
let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
execute_duplex_streaming(ctx, ser, de, &mut handler)
};
let ch = Box::new(Handler::new(MethodType::Duplex, h));
self.handlers.insert(method.name.as_bytes(), ch);
self
}
pub fn build(self) -> Service {
Service {
handlers: self.handlers,
}
}
}
/// Verdict returned by `ServerChecker::check`.
///
/// NOTE(review): how each variant is acted upon is decided outside this file;
/// the per-variant descriptions below are inferred from the names — confirm
/// against the call-handling code.
pub enum CheckResult {
/// Presumably: let the call proceed.
Continue,
/// Presumably: stop the call and report the contained status.
Abort(RpcStatus),
}
/// A cloneable hook that inspects an `RpcContext` and yields a `CheckResult`.
///
/// Checkers are stored boxed in `ServerBuilder`, `Server` and
/// `RequestCallContext`.
pub trait ServerChecker: Send {
/// Examines `ctx` and returns a verdict.
fn check(&mut self, ctx: &RpcContext) -> CheckResult;
/// Returns a boxed copy of this checker; backs `Clone` for
/// `Box<dyn ServerChecker>`.
fn box_clone(&self) -> Box<dyn ServerChecker>;
}
impl Clone for Box<dyn ServerChecker> {
    /// Delegates to `ServerChecker::box_clone` on the boxed checker.
    fn clone(&self) -> Self {
        (**self).box_clone()
    }
}
/// A set of RPC handlers produced by `ServiceBuilder::build` and consumed by
/// `ServerBuilder::register_service`.
pub struct Service {
/// Handlers keyed by the bytes of their method name.
handlers: HashMap<&'static [u8], BoxHandler>,
}
/// Accumulates the configuration needed to create a `Server`.
pub struct ServerBuilder {
/// Environment whose completion queues are registered with the server.
env: Arc<Environment>,
/// Addresses to bind when `build` runs.
binders: Vec<Binder>,
/// Optional channel arguments passed to server creation.
args: Option<ChannelArgs>,
/// Number of `request_call` invocations issued per completion queue on start.
slots_per_cq: usize,
/// Handlers from every registered service, keyed by method name bytes.
handlers: HashMap<&'static [u8], BoxHandler>,
/// Checkers handed on to the built `Server`.
checkers: Vec<Box<dyn ServerChecker>>,
}
impl ServerBuilder {
pub fn new(env: Arc<Environment>) -> ServerBuilder {
ServerBuilder {
env,
binders: Vec::new(),
args: None,
slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
handlers: HashMap::new(),
checkers: Vec::new(),
}
}
pub fn bind<S: Into<String>>(mut self, host: S, port: u16) -> ServerBuilder {
self.binders.push(Binder::new(host.into(), port));
self
}
pub fn channel_args(mut self, args: ChannelArgs) -> ServerBuilder {
self.args = Some(args);
self
}
pub fn requests_slot_per_cq(mut self, slots: usize) -> ServerBuilder {
self.slots_per_cq = slots;
self
}
pub fn register_service(mut self, service: Service) -> ServerBuilder {
self.handlers.extend(service.handlers);
self
}
pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
self.checkers.push(Box::new(checker));
self
}
pub fn build(mut self) -> Result<Server> {
let args = self
.args
.as_ref()
.map_or_else(ptr::null, ChannelArgs::as_ptr);
unsafe {
let server = grpc_sys::grpc_server_create(args, ptr::null_mut());
for binder in self.binders.iter_mut() {
let bind_port = binder.bind(server);
if bind_port == 0 {
grpc_sys::grpc_server_destroy(server);
return Err(Error::BindFail(binder.host.clone(), binder.port));
}
binder.port = bind_port;
}
for cq in self.env.completion_queues() {
let cq_ref = cq.borrow()?;
grpc_sys::grpc_server_register_completion_queue(
server,
cq_ref.as_ptr(),
ptr::null_mut(),
);
}
Ok(Server {
env: self.env,
core: Arc::new(ServerCore {
server,
shutdown: AtomicBool::new(false),
binders: self.binders,
slots_per_cq: self.slots_per_cq,
}),
handlers: self.handlers,
checkers: self.checkers,
})
}
}
}
#[cfg(feature = "secure")]
mod secure_server {
    use super::{Binder, ServerBuilder};
    use crate::grpc_sys;
    use crate::security::{
        server_cert_fetcher_wrapper, CertificateRequestType, ServerCredentials,
        ServerCredentialsFetcher,
    };

    impl ServerBuilder {
        /// Queues an address that will be bound as a secure port using the
        /// credentials `c`. May be called several times.
        pub fn bind_with_cred<S: Into<String>>(
            mut self,
            host: S,
            port: u16,
            c: ServerCredentials,
        ) -> ServerBuilder {
            let binder = Binder::with_cred(host.into(), port, c, None);
            self.binders.push(binder);
            self
        }

        /// Queues an address that will be bound as a secure port whose
        /// credentials are supplied through `fetcher`.
        ///
        /// The fetcher is double-boxed; its raw pointer is handed to gRPC as
        /// the callback's user data, while ownership of the box is kept in the
        /// binder.
        pub fn bind_with_fetcher<S: Into<String>>(
            mut self,
            host: S,
            port: u16,
            fetcher: Box<dyn ServerCredentialsFetcher + Send + Sync>,
            cer_request_type: CertificateRequestType,
        ) -> ServerBuilder {
            let raw_fetcher = Box::into_raw(Box::new(fetcher));
            let (creds, fetcher_owner) = unsafe {
                let options =
                    grpc_sys::grpc_ssl_server_credentials_create_options_using_config_fetcher(
                        cer_request_type.to_native(),
                        Some(server_cert_fetcher_wrapper),
                        raw_fetcher as _,
                    );
                let raw_creds =
                    grpcio_sys::grpc_ssl_server_credentials_create_with_options(options);
                let creds = ServerCredentials::frow_raw(raw_creds);
                // Re-own the allocation so it is released with the binder;
                // the heap address given to gRPC above does not change.
                (creds, Box::from_raw(raw_fetcher))
            };
            let binder = Binder::with_cred(host.into(), port, creds, Some(fetcher_owner));
            self.binders.push(binder);
            self
        }
    }
}
/// Owns the raw gRPC server handle; the handle is destroyed when the last
/// reference to the core is dropped.
struct ServerCore {
/// Raw handle returned by `grpc_server_create`.
server: *mut grpc_server,
/// Addresses the server was bound to, with ports as reported by gRPC.
binders: Vec<Binder>,
/// Number of `request_call` invocations issued per completion queue on start.
slots_per_cq: usize,
/// Set once `Server::shutdown` has been called; checked by `request_call`.
shutdown: AtomicBool,
}
impl Drop for ServerCore {
    /// Destroys the underlying gRPC server handle.
    fn drop(&mut self) {
        let raw_server = self.server;
        unsafe {
            grpc_sys::grpc_server_destroy(raw_server);
        }
    }
}
// `ServerCore` holds a raw `*mut grpc_server`, which makes it neither `Send`
// nor `Sync` automatically; these impls opt in manually.
// NOTE(review): soundness relies on the gRPC server handle being usable from
// multiple threads — confirm against the gRPC core documentation.
unsafe impl Send for ServerCore {}
unsafe impl Sync for ServerCore {}
/// A type-erased, heap-allocated RPC handler.
pub type BoxHandler = Box<dyn CloneableHandler>;
/// State handed to `request_call`: the server core, the handler registry and
/// the checkers.
///
/// Cloning shares the same `server` and `registry` (both are `Arc`s) but
/// deep-copies the checkers.
#[derive(Clone)]
pub struct RequestCallContext {
/// Shared server core; gives access to the raw handle and the shutdown flag.
server: Arc<ServerCore>,
/// Handler registry shared by all clones of this context. The `UnsafeCell`
/// lets `get_handler` obtain mutable access through the shared `Arc`.
registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
/// Checkers owned by this context.
checkers: Vec<Box<dyn ServerChecker>>,
}
impl RequestCallContext {
#[inline]
pub unsafe fn get_handler(&mut self, path: &[u8]) -> Option<&mut BoxHandler> {
let registry = &mut *self.registry.get();
registry.get_mut(path)
}
pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
self.checkers.clone()
}
}
// `Arc<UnsafeCell<_>>` is not `Send` because `UnsafeCell` is not `Sync`, so
// `Send` is asserted manually here.
// NOTE(review): presumably sound because all clones sharing one registry are
// only ever used from the thread polling their completion queue — confirm
// against the completion-queue polling code.
unsafe impl Send for RequestCallContext {}
/// Asks the gRPC server to deliver the next incoming call on `cq`.
///
/// Returns without doing anything when the server's shutdown flag is set or
/// when `cq` can no longer be borrowed. Otherwise a request tag is built from
/// `ctx`, leaked into a raw pointer and handed to gRPC as the completion tag.
///
/// # Panics
///
/// * If the tag created from `ctx` carries no request context.
/// * If gRPC rejects the request (any code other than `GRPC_CALL_OK`). The
///   tag is released before panicking.
pub fn request_call(ctx: RequestCallContext, cq: &CompletionQueue) {
    if ctx.server.shutdown.load(Ordering::Relaxed) {
        return;
    }
    let cq_ref = match cq.borrow() {
        Err(_) => return,
        Ok(c) => c,
    };
    // Copy the raw handle out before `ctx` is moved into the tag.
    let server_ptr = ctx.server.server;
    let prom = CallTag::request(ctx);
    let request_ptr = prom.request_ctx().unwrap().as_ptr();
    let tag = Box::into_raw(Box::new(prom));
    let code = unsafe {
        grpc_sys::grpcwrap_server_request_call(
            server_ptr,
            cq_ref.as_ptr(),
            request_ptr,
            tag as *mut _,
        )
    };
    if code != grpc_call_error::GRPC_CALL_OK {
        // Reclaim the tag instead of leaking it. `Box::from(tag)` would only
        // box the raw pointer itself and leave the allocation untouched.
        // SAFETY: `tag` was produced by `Box::into_raw` just above and has not
        // been freed. NOTE(review): assumes gRPC does not retain the tag when
        // the request is rejected.
        drop(unsafe { Box::from_raw(tag) });
        panic!("failed to request call: {:?}", code);
    }
}
/// Future returned by `Server::shutdown`; resolves once the shutdown
/// notification arrives through the completion queue.
pub struct ShutdownFuture {
/// Completion-queue future paired with the tag passed to gRPC.
cq_f: CqFuture<bool>,
}
impl Future for ShutdownFuture {
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match ready!(Pin::new(&mut self.cq_f).poll(cx)) {
Ok(true) => Poll::Ready(Ok(())),
Ok(false) => Poll::Ready(Err(Error::ShutdownFailed)),
Err(e) => unreachable!("action future should never resolve to error: {}", e),
}
}
}
/// A gRPC server created by `ServerBuilder::build`.
///
/// Dropping the server requests shutdown (if that has not happened yet),
/// cancels all calls and blocks until the shutdown future completes.
pub struct Server {
/// Environment providing the completion queues.
env: Arc<Environment>,
/// Shared core owning the raw server handle.
core: Arc<ServerCore>,
/// Template handlers; `start` clones them into one registry per completion queue.
handlers: HashMap<&'static [u8], BoxHandler>,
/// Template checkers; `start` clones them into each request context.
checkers: Vec<Box<dyn ServerChecker>>,
}
impl Server {
pub fn shutdown(&mut self) -> ShutdownFuture {
let (cq_f, prom) = CallTag::action_pair();
let prom_box = Box::new(prom);
let tag = Box::into_raw(prom_box);
unsafe {
let cq_ref = self.env.completion_queues()[0].borrow().unwrap();
grpc_sys::grpc_server_shutdown_and_notify(
self.core.server,
cq_ref.as_ptr(),
tag as *mut _,
)
}
self.core.shutdown.store(true, Ordering::SeqCst);
ShutdownFuture { cq_f }
}
pub fn cancel_all_calls(&mut self) {
unsafe { grpc_sys::grpc_server_cancel_all_calls(self.core.server) }
}
pub fn start(&mut self) {
unsafe {
grpc_sys::grpc_server_start(self.core.server);
for cq in self.env.completion_queues() {
let registry = self
.handlers
.iter()
.map(|(k, v)| (k.to_owned(), v.box_clone()))
.collect();
let rc = RequestCallContext {
server: self.core.clone(),
registry: Arc::new(UnsafeCell::new(registry)),
checkers: self.checkers.clone(),
};
for _ in 0..self.core.slots_per_cq {
request_call(rc.clone(), cq);
}
}
}
}
pub fn bind_addrs(&self) -> impl ExactSizeIterator<Item = (&String, u16)> {
self.core.binders.iter().map(|b| (&b.host, b.port))
}
#[cfg(unix)]
pub unsafe fn add_insecure_channel_from_fd(&self, fd: ::std::os::raw::c_int) {
grpc_sys::grpc_server_add_insecure_channel_from_fd(self.core.server, ptr::null_mut(), fd)
}
}
impl Drop for Server {
    /// Requests shutdown if it has not been requested yet, cancels all calls,
    /// then blocks the current thread until that shutdown completes.
    ///
    /// The result of the shutdown future is discarded.
    fn drop(&mut self) {
        let already_requested = self.core.shutdown.load(Ordering::SeqCst);
        let pending = if already_requested {
            None
        } else {
            Some(self.shutdown())
        };
        // Cancel first so the shutdown awaited below can finish.
        self.cancel_all_calls();
        if let Some(shutdown) = pending {
            let _ = futures::executor::block_on(shutdown);
        }
    }
}
impl Debug for Server {
    /// Prints `Server` followed by the debug form of its bound addresses.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let binders = &self.core.binders;
        f.write_fmt(format_args!("Server {:?}", binders))
    }
}
#[cfg(test)]
mod tests {
    use super::join_host_port;

    /// Checks host names, IPv4 and IPv6 literals against the expected
    /// NUL-terminated address strings.
    #[test]
    fn test_join_host_port() {
        let cases: [(&str, u16, &str); 4] = [
            ("localhost", 0, "localhost:0\0"),
            ("127.0.0.1", 100, "127.0.0.1:100\0"),
            ("::1", 0, "[::1]:0\0"),
            (
                "fe80::7376:45d5:fb08:61e3",
                10028,
                "[fe80::7376:45d5:fb08:61e3]:10028\0",
            ),
        ];
        for &(host, port, expected) in cases.iter() {
            assert_eq!(join_host_port(host, port), expected);
        }
    }
}