use std::{
collections::{HashMap, HashSet},
fmt::Display,
iter::FromIterator as _,
ops::Deref as _,
str::FromStr as _,
};
use crate::utils::add_extra_where_clauses;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{
parse::{Error, Parser as _, Result},
punctuated::Punctuated,
spanned::Spanned as _,
};
pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> Result<TokenStream> {
let trait_name = trait_name.trim_end_matches("Custom");
let trait_ident = syn::Ident::new(trait_name, Span::call_site());
let trait_path = "e!(::core::fmt::#trait_ident);
let trait_attr = trait_name_to_attribute_name(trait_name);
let type_params = input
.generics
.type_params()
.map(|t| t.ident.clone())
.collect();
let ParseResult {
arms,
bounds,
requires_helper,
} = State {
trait_path,
trait_attr,
input,
type_params,
}
.get_match_arms_and_extra_bounds()?;
let generics = if !bounds.is_empty() {
let bounds: Vec<_> = bounds
.into_iter()
.map(|(ty, trait_names)| {
let bounds: Vec<_> = trait_names
.into_iter()
.map(|bound| quote!(#bound))
.collect();
quote!(#ty: #(#bounds)+*)
})
.collect();
let where_clause = quote_spanned!(input.span()=> where #(#bounds),*);
add_extra_where_clauses(&input.generics, where_clause)
} else {
input.generics.clone()
};
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let name = &input.ident;
let helper_struct = if requires_helper {
display_as_helper_struct()
} else {
TokenStream::new()
};
Ok(quote! {
impl #impl_generics #trait_path for #name #ty_generics #where_clause
{
#[allow(unused_variables)]
#[inline]
fn fmt(&self, _derive_more_display_formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
#helper_struct
match self {
#arms
_ => Ok(())
}
}
}
})
}
/// Maps a `core::fmt` trait name (e.g. `LowerHex`) to the name of the attribute that
/// configures its derive (e.g. `lower_hex`).
///
/// # Panics
/// Panics via `unimplemented!()` for any trait name that is not in the table below.
fn trait_name_to_attribute_name(trait_name: &str) -> &'static str {
    const ATTRIBUTE_NAMES: [(&str, &str); 9] = [
        ("Display", "display"),
        ("Binary", "binary"),
        ("Octal", "octal"),
        ("LowerHex", "lower_hex"),
        ("UpperHex", "upper_hex"),
        ("LowerExp", "lower_exp"),
        ("UpperExp", "upper_exp"),
        ("Pointer", "pointer"),
        ("Debug", "debug"),
    ];
    ATTRIBUTE_NAMES
        .iter()
        .find(|(known, _)| *known == trait_name)
        .map(|&(_, attribute)| attribute)
        .unwrap_or_else(|| unimplemented!())
}
/// Inverse of `trait_name_to_attribute_name`: maps an attribute name such as `lower_hex` back
/// to the name of its `core::fmt` trait (`LowerHex`).
///
/// # Panics
/// Panics via `unreachable!()` on an attribute name that `trait_name_to_attribute_name`
/// never produces.
fn attribute_name_to_trait_name(attribute_name: &str) -> &'static str {
    match attribute_name {
        "display" => "Display",
        "binary" => "Binary",
        "octal" => "Octal",
        "lower_hex" => "LowerHex",
        "upper_hex" => "UpperHex",
        "lower_exp" => "LowerExp",
        "upper_exp" => "UpperExp",
        "pointer" => "Pointer",
        // `trait_name_to_attribute_name` maps `Debug` (e.g. from `DebugCustom`) to `debug`, so
        // the inverse mapping must exist too. Without this arm, inferring bounds for a generic
        // type derived with `DebugCustom` hit the `unreachable!()` below.
        "debug" => "Debug",
        _ => unreachable!(),
    }
}
/// Builds the trait bound `::core::fmt::<trait_name>`, with no `for<...>` lifetimes, no `?`
/// modifier, and no parentheses.
fn trait_name_to_trait_bound(trait_name: &str) -> syn::TraitBound {
    let segments = ["core", "fmt", trait_name]
        .iter()
        .map(|name| syn::PathSegment::from(Ident::new(name, Span::call_site())));
    let path = syn::Path {
        leading_colon: Some(syn::Token![::](Span::call_site())),
        segments: syn::punctuated::Punctuated::from_iter(segments),
    };
    syn::TraitBound {
        paren_token: None,
        modifier: syn::TraitBoundModifier::None,
        lifetimes: None,
        path,
    }
}
/// Emits the `_derive_more_DisplayAs` helper used when an enum-level format string has a
/// placeholder.
///
/// The helper wraps a closure `Fn(&mut Formatter) -> fmt::Result` and implements
/// `core::fmt::Display` by calling that closure. This lets a variant's formatting code be
/// passed as the argument of the enum-level format string. `expand` splices these tokens into
/// the generated `fmt` method body.
fn display_as_helper_struct() -> TokenStream {
    quote! {
        struct _derive_more_DisplayAs<F>(F)
        where
            F: ::core::ops::Fn(&mut ::core::fmt::Formatter) -> ::core::fmt::Result;
        const _derive_more_DisplayAs_impl: () = {
            impl<F> ::core::fmt::Display for _derive_more_DisplayAs<F>
            where
                F: ::core::ops::Fn(&mut ::core::fmt::Formatter) -> ::core::fmt::Result
            {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    (self.0)(f)
                }
            }
        };
    }
}
/// Output of `State::get_match_arms_and_extra_bounds`, consumed by `expand`.
#[derive(Default)]
struct ParseResult {
    /// Match arms for the body of the generated `fmt` method.
    arms: TokenStream,
    /// Extra `where`-clause bounds: each type mapped to the trait bounds it must satisfy.
    bounds: HashMap<syn::Type, HashSet<syn::TraitBound>>,
    /// Whether the arms use `_derive_more_DisplayAs`, so `expand` must emit the helper.
    requires_helper: bool,
}
/// Context for generating a single derive: the target trait and the input item.
struct State<'a, 'b> {
    /// Path of the derived trait, e.g. `::core::fmt::Display`.
    trait_path: &'b TokenStream,
    /// Name of the attribute that configures this trait, e.g. `display` or `lower_hex`.
    trait_attr: &'static str,
    /// The item the derive was applied to.
    input: &'a syn::DeriveInput,
    /// Identifiers of the input's generic type parameters.
    type_params: HashSet<Ident>,
}
impl<'a, 'b> State<'a, 'b> {
    /// Error hint describing the accepted `fmt` attribute syntax for this trait.
    fn get_proper_fmt_syntax(&self) -> impl Display {
        format!(
            r#"Proper syntax: #[{}(fmt = "My format", "arg1", "arg2")]"#,
            self.trait_attr
        )
    }

    /// Error hint describing the accepted `bound` attribute syntax for this trait.
    fn get_proper_bound_syntax(&self) -> impl Display {
        format!(
            "Proper syntax: #[{}(bound = \"T, U: Trait1 + Trait2, V: Trait3\")]",
            self.trait_attr
        )
    }

    /// Builds the pattern that destructures `fields` in a match arm.
    ///
    /// Unit fields produce an empty pattern. Tuple fields are bound as `_0`, `_1`, ..., and
    /// named fields are bound by their own names. `infer_fmt` and
    /// `get_used_type_params_bounds` rely on this naming scheme.
    fn get_matcher(&self, fields: &syn::Fields) -> TokenStream {
        match fields {
            syn::Fields::Unit => TokenStream::new(),
            syn::Fields::Unnamed(fields) => {
                let fields: TokenStream = (0..fields.unnamed.len())
                    .map(|n| {
                        let i = Ident::new(&format!("_{}", n), Span::call_site());
                        quote!(#i,)
                    })
                    .collect();
                quote!((#fields))
            }
            syn::Fields::Named(fields) => {
                let fields: TokenStream = fields
                    .named
                    .iter()
                    .map(|f| {
                        let i = f.ident.as_ref().expect("named fields always have an ident");
                        quote!(#i,)
                    })
                    .collect();
                quote!({#fields})
            }
        }
    }

    /// Finds the single `#[<trait_attr>(<meta_key> = <literal>, ...)]` attribute in `attrs`.
    ///
    /// An attribute matches only if it is a list named `self.trait_attr` whose first nested
    /// item is a `<meta_key> = <literal>` pair. Attributes that do not parse as `syn::Meta`
    /// are skipped.
    ///
    /// # Errors
    /// Returns an error pointing at the second matching attribute if there is more than one.
    fn find_meta(
        &self,
        attrs: &[syn::Attribute],
        meta_key: &str,
    ) -> Result<Option<syn::Meta>> {
        let mut iterator = attrs
            .iter()
            .filter_map(|attr| attr.parse_meta().ok())
            .filter(|meta| {
                let meta = match meta {
                    syn::Meta::List(meta) => meta,
                    _ => return false,
                };
                if !meta.path.is_ident(self.trait_attr) || meta.nested.is_empty() {
                    return false;
                }
                let meta = match &meta.nested[0] {
                    syn::NestedMeta::Meta(meta) => meta,
                    _ => return false,
                };
                let meta = match meta {
                    syn::Meta::NameValue(meta) => meta,
                    _ => return false,
                };
                meta.path.is_ident(meta_key)
            });
        let meta = iterator.next();
        match iterator.next() {
            None => Ok(meta),
            // Point at the duplicate itself, not at the first (valid) attribute.
            Some(duplicate) => Err(Error::new(
                duplicate.span(),
                "Too many attributes specified",
            )),
        }
    }

    /// Parses the string literal of a `bound = "..."` attribute into extra trait bounds.
    ///
    /// The literal must be a comma-separated list of generic type parameters with bounds, e.g.
    /// `"T: Trait1 + Trait2, U: Trait3"`. Every parameter must be one of the input's own type
    /// parameters and must end up with at least one plain trait bound.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the literal does not tokenize or parse;
    /// - the literal is empty;
    /// - a parameter is unknown, or is a lifetime/const parameter;
    /// - a parameter has attributes or a default;
    /// - a bound is a lifetime bound or a higher-rank (`for<...>`) bound.
    fn parse_meta_bounds(
        &self,
        bounds: &syn::LitStr,
    ) -> Result<HashMap<syn::Type, HashSet<syn::TraitBound>>> {
        let span = bounds.span();

        let input = bounds.value();
        let tokens = TokenStream::from_str(&input)?;
        let parser = Punctuated::<syn::GenericParam, syn::Token![,]>::parse_terminated;

        let generic_params = parser
            .parse2(tokens)
            .map_err(|error| Error::new(span, error.to_string()))?;

        if generic_params.is_empty() {
            return Err(Error::new(span, "No bounds specified"));
        }

        let mut bounds = HashMap::new();

        for generic_param in generic_params {
            let type_param = match generic_param {
                syn::GenericParam::Type(type_param) => type_param,
                _ => return Err(Error::new(span, "Only trait bounds allowed")),
            };

            if !self.type_params.contains(&type_param.ident) {
                return Err(Error::new(
                    span,
                    "Unknown generic type argument specified",
                ));
            } else if !type_param.attrs.is_empty() {
                return Err(Error::new(span, "Attributes aren't allowed"));
            } else if type_param.eq_token.is_some() || type_param.default.is_some() {
                return Err(Error::new(span, "Default type parameters aren't allowed"));
            }

            let ident = type_param.ident.to_string();

            let ty = syn::Type::Path(syn::TypePath {
                qself: None,
                path: type_param.ident.into(),
            });
            // The same parameter may appear several times; its bounds accumulate.
            let bounds = bounds.entry(ty).or_insert_with(HashSet::new);

            for bound in type_param.bounds {
                let bound = match bound {
                    syn::TypeParamBound::Trait(bound) => bound,
                    _ => return Err(Error::new(span, "Only trait bounds allowed")),
                };

                if bound.lifetimes.is_some() {
                    return Err(Error::new(
                        span,
                        "Higher-rank trait bounds aren't allowed",
                    ));
                }

                bounds.insert(bound);
            }

            if bounds.is_empty() {
                return Err(Error::new(
                    span,
                    format!("No bounds specified for type parameter {}", ident),
                ));
            }
        }

        Ok(bounds)
    }

    /// Turns a `#[<trait_attr>(fmt = "...", args...)]` attribute into formatting code.
    ///
    /// Returns the generated tokens and a flag:
    /// - `(write!(_derive_more_display_formatter, "...", args...), false)` for an ordinary
    ///   format string. String-literal args are parsed as token streams; path args are used
    ///   as-is.
    /// - `(<format literal>, true)` only when `outer_enum` is set and the format string has
    ///   exactly one placeholder. The caller then substitutes each variant's output into it.
    ///
    /// When `outer_enum` is set, the attribute may contain only the `fmt` argument, and the
    /// format string may have at most one placeholder.
    ///
    /// # Errors
    /// Returns an error if `meta` is not a list starting with `fmt = "<string>"`, if an
    /// argument is neither a string literal nor a path, if a string argument fails to parse,
    /// or if the `outer_enum` restrictions are violated.
    fn parse_meta_fmt(
        &self,
        meta: &syn::Meta,
        outer_enum: bool,
    ) -> Result<(TokenStream, bool)> {
        let list = match meta {
            syn::Meta::List(list) => list,
            _ => {
                return Err(Error::new(meta.span(), self.get_proper_fmt_syntax()));
            }
        };

        // `find_meta` only yields non-empty lists, but report an empty one as a syntax error
        // rather than panicking on `nested[0]`.
        let first = match list.nested.first() {
            Some(first) => first,
            None => return Err(Error::new(list.span(), self.get_proper_fmt_syntax())),
        };

        match first {
            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
                path,
                lit: syn::Lit::Str(fmt),
                ..
            })) if path.segments.first().expect("path shouldn't be empty").ident == "fmt" => {
                if outer_enum {
                    if list.nested.len() > 1 {
                        return Err(Error::new(
                            list.nested[1].span(),
                            "`fmt` formatting requires a single `fmt` argument",
                        ));
                    }
                    let num_placeholders = Placeholder::parse_fmt_string(&fmt.value()).len();
                    if num_placeholders > 1 {
                        // `fmt` is the only nested item at this point (checked just above), so
                        // the error must point at the format string itself. Indexing
                        // `nested[1]` here would panic.
                        return Err(Error::new(
                            fmt.span(),
                            "fmt string for enum should have at most 1 placeholder",
                        ));
                    }
                    if num_placeholders == 1 {
                        return Ok((quote_spanned!(fmt.span()=> #fmt), true));
                    }
                }
                let args = list
                    .nested
                    .iter()
                    .skip(1)
                    .try_fold(TokenStream::new(), |args, arg| {
                        let arg = match arg {
                            syn::NestedMeta::Lit(syn::Lit::Str(s)) => s,
                            syn::NestedMeta::Meta(syn::Meta::Path(i)) => {
                                return Ok(quote_spanned!(list.span()=> #args #i,));
                            }
                            _ => {
                                return Err(Error::new(
                                    arg.span(),
                                    self.get_proper_fmt_syntax(),
                                ))
                            }
                        };
                        let arg: TokenStream =
                            arg.parse().map_err(|e| Error::new(arg.span(), e))?;
                        Ok(quote_spanned!(list.span()=> #args #arg,))
                    })?;

                Ok((
                    quote_spanned!(meta.span()=> write!(_derive_more_display_formatter, #fmt, #args)),
                    false,
                ))
            }
            _ => Err(Error::new(first.span(), self.get_proper_fmt_syntax())),
        }
    }

    /// Infers formatting code when no `fmt` attribute is given.
    ///
    /// Types with no fields write their own name (`name`). A single field is delegated to the
    /// derived trait's `fmt`, referenced through the binding created by `get_matcher`: its
    /// name, or `_0` for a tuple field.
    ///
    /// # Errors
    /// Returns an error for types with more than one field.
    fn infer_fmt(&self, fields: &syn::Fields, name: &Ident) -> Result<TokenStream> {
        let fields = match fields {
            syn::Fields::Unit => {
                return Ok(quote!(
                    _derive_more_display_formatter.write_str(stringify!(#name))
                ))
            }
            syn::Fields::Named(fields) => &fields.named,
            syn::Fields::Unnamed(fields) => &fields.unnamed,
        };
        let field = match fields.len() {
            0 => {
                return Ok(quote!(
                    _derive_more_display_formatter.write_str(stringify!(#name))
                ))
            }
            1 => &fields[0],
            _ => {
                return Err(Error::new(
                    fields.span(),
                    "Can not automatically infer format for types with more than 1 field",
                ))
            }
        };

        let trait_path = self.trait_path;
        if let Some(ident) = &field.ident {
            Ok(quote!(#trait_path::fmt(#ident, _derive_more_display_formatter)))
        } else {
            Ok(quote!(#trait_path::fmt(_0, _derive_more_display_formatter)))
        }
    }

    /// Computes the match arms of the generated `fmt` body and the extra bounds it needs.
    ///
    /// - Structs: one arm, from the struct's `fmt` attribute or else `infer_fmt`.
    /// - Enums with an enum-level `fmt` that has no placeholder: a single `_ =>` arm.
    ///   Variant-level `fmt` attributes are rejected in this case.
    /// - Enums with an enum-level `fmt` that has one placeholder: one arm per variant. Each
    ///   variant's output is substituted into the placeholder via `_derive_more_DisplayAs`,
    ///   and `requires_helper` is set.
    /// - Other enums: one arm per variant, each from its own `fmt` attribute or `infer_fmt`.
    /// - Unions: a single `_ =>` arm; a `fmt` attribute is mandatory.
    ///
    /// Bounds are derived from the format strings (or from inference). A type-level
    /// `bound = "..."` attribute adds further bounds in every case.
    fn get_match_arms_and_extra_bounds(&self) -> Result<ParseResult> {
        let result: Result<_> = match &self.input.data {
            syn::Data::Enum(e) => {
                match self
                    .find_meta(&self.input.attrs, "fmt")
                    .and_then(|m| m.map(|m| self.parse_meta_fmt(&m, true)).transpose())?
                {
                    // Enum-level format without a placeholder replaces all variant formatting.
                    Some((fmt, false)) => {
                        e.variants.iter().try_for_each(|v| {
                            if let Some(meta) = self.find_meta(&v.attrs, "fmt")? {
                                Err(Error::new(
                                    meta.span(),
                                    "`fmt` cannot be used on variant when the whole enum has a format string without a placeholder, maybe you want to add a placeholder?",
                                ))
                            } else {
                                Ok(())
                            }
                        })?;

                        Ok(ParseResult {
                            arms: quote_spanned!(self.input.span()=> _ => #fmt,),
                            bounds: HashMap::new(),
                            requires_helper: false,
                        })
                    }
                    // Enum-level format with one placeholder wraps each variant's output.
                    Some((outer_fmt, true)) => {
                        let fmt: Result<TokenStream> = e.variants.iter().try_fold(TokenStream::new(), |arms, v| {
                            let matcher = self.get_matcher(&v.fields);
                            let fmt = if let Some(meta) = self.find_meta(&v.attrs, "fmt")? {
                                self.parse_meta_fmt(&meta, false)?.0
                            } else {
                                self.infer_fmt(&v.fields, &v.ident)?
                            };
                            let name = &self.input.ident;
                            let v_name = &v.ident;
                            Ok(quote_spanned!(fmt.span()=> #arms #name::#v_name #matcher => write!(
                                _derive_more_display_formatter,
                                #outer_fmt,
                                _derive_more_DisplayAs(|_derive_more_display_formatter| #fmt)
                            ),))
                        });
                        let fmt = fmt?;
                        Ok(ParseResult {
                            arms: quote_spanned!(self.input.span()=> #fmt),
                            bounds: HashMap::new(),
                            requires_helper: true,
                        })
                    }
                    None => e.variants.iter().try_fold(ParseResult::default(), |result, v| {
                        let ParseResult{ arms, mut bounds, requires_helper } = result;
                        let matcher = self.get_matcher(&v.fields);
                        let name = &self.input.ident;
                        let v_name = &v.ident;
                        let fmt: TokenStream;
                        let these_bounds: HashMap<_, _>;

                        if let Some(meta) = self.find_meta(&v.attrs, "fmt")? {
                            fmt = self.parse_meta_fmt(&meta, false)?.0;
                            these_bounds = self.get_used_type_params_bounds(&v.fields, &meta);
                        } else {
                            fmt = self.infer_fmt(&v.fields, v_name)?;
                            these_bounds = self.infer_type_params_bounds(&v.fields);
                        };
                        these_bounds.into_iter().for_each(|(ty, trait_names)| {
                            bounds.entry(ty).or_default().extend(trait_names)
                        });
                        let arms = quote_spanned!(self.input.span()=> #arms #name::#v_name #matcher => #fmt,);

                        Ok(ParseResult{ arms, bounds, requires_helper })
                    }),
                }
            }
            syn::Data::Struct(s) => {
                let matcher = self.get_matcher(&s.fields);
                let name = &self.input.ident;
                let fmt: TokenStream;
                let bounds: HashMap<_, _>;

                if let Some(meta) = self.find_meta(&self.input.attrs, "fmt")? {
                    fmt = self.parse_meta_fmt(&meta, false)?.0;
                    bounds = self.get_used_type_params_bounds(&s.fields, &meta);
                } else {
                    fmt = self.infer_fmt(&s.fields, name)?;
                    bounds = self.infer_type_params_bounds(&s.fields);
                }

                Ok(ParseResult {
                    arms: quote_spanned!(self.input.span()=> #name #matcher => #fmt,),
                    bounds,
                    requires_helper: false,
                })
            }
            syn::Data::Union(_) => {
                let meta =
                    self.find_meta(&self.input.attrs, "fmt")?.ok_or_else(|| {
                        Error::new(
                            self.input.span(),
                            "Can not automatically infer format for unions",
                        )
                    })?;
                let fmt = self.parse_meta_fmt(&meta, false)?.0;

                Ok(ParseResult {
                    arms: quote_spanned!(self.input.span()=> _ => #fmt,),
                    bounds: HashMap::new(),
                    requires_helper: false,
                })
            }
        };

        let mut result = result?;

        // Merge any user-provided `#[<trait_attr>(bound = "...")]` on the type itself.
        let meta = match self.find_meta(&self.input.attrs, "bound")? {
            Some(meta) => meta,
            _ => return Ok(result),
        };

        let span = meta.span();

        let meta = match meta {
            syn::Meta::List(meta) => meta.nested,
            _ => return Err(Error::new(span, self.get_proper_bound_syntax())),
        };

        if meta.len() != 1 {
            return Err(Error::new(span, self.get_proper_bound_syntax()));
        }

        let meta = match &meta[0] {
            syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => meta,
            _ => return Err(Error::new(span, self.get_proper_bound_syntax())),
        };

        let extra_bounds = match &meta.lit {
            syn::Lit::Str(extra_bounds) => extra_bounds,
            _ => return Err(Error::new(span, self.get_proper_bound_syntax())),
        };

        let extra_bounds = self.parse_meta_bounds(extra_bounds)?;

        extra_bounds.into_iter().for_each(|(ty, trait_names)| {
            result.bounds.entry(ty).or_default().extend(trait_names)
        });

        Ok(result)
    }

    /// Computes the bounds required by an explicit `fmt` attribute (`meta`).
    ///
    /// A placeholder produces a bound when both of these hold:
    /// - it refers, by position, to an argument that names a field binding from `get_matcher`;
    /// - that field's type involves one of the input's type parameters.
    ///
    /// The field's type (one reference layer stripped) must then implement the placeholder's
    /// `core::fmt` trait.
    ///
    /// # Panics
    /// `meta` must already have been validated by `parse_meta_fmt`. Otherwise this hits
    /// `unreachable!()`.
    fn get_used_type_params_bounds(
        &self,
        fields: &syn::Fields,
        meta: &syn::Meta,
    ) -> HashMap<syn::Type, HashSet<syn::TraitBound>> {
        if self.type_params.is_empty() {
            return HashMap::new();
        }

        // Field binding (as created by `get_matcher`) -> generic-involving field type.
        let fields_type_params: HashMap<syn::Path, _> = fields
            .iter()
            .enumerate()
            .filter_map(|(i, field)| {
                self.get_type_param(&field.ty).map(|ty| {
                    (
                        field
                            .ident
                            .clone()
                            .unwrap_or_else(|| {
                                Ident::new(&format!("_{}", i), Span::call_site())
                            })
                            .into(),
                        ty,
                    )
                })
            })
            .collect();
        if fields_type_params.is_empty() {
            return HashMap::new();
        }

        let list = match meta {
            syn::Meta::List(list) => list,
            // `parse_meta_fmt` would have failed otherwise.
            _ => unreachable!(),
        };
        // Format argument position -> argument parsed as a path (non-path strings are skipped).
        let fmt_args: HashMap<_, _> = list
            .nested
            .iter()
            .skip(1) // skip fmt string
            .enumerate()
            .filter_map(|(i, arg)| match arg {
                syn::NestedMeta::Lit(syn::Lit::Str(ref s)) => {
                    syn::parse_str(&s.value()).ok().map(|id| (i, id))
                }
                syn::NestedMeta::Meta(syn::Meta::Path(ref id)) => Some((i, id.clone())),
                // `parse_meta_fmt` would have failed otherwise.
                _ => unreachable!(),
            })
            .collect();
        if fmt_args.is_empty() {
            return HashMap::new();
        }
        let fmt_string = match &list.nested[0] {
            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
                path,
                lit: syn::Lit::Str(s),
                ..
            })) if path
                .segments
                .first()
                .expect("path shouldn't be empty")
                .ident
                == "fmt" =>
            {
                s.value()
            }
            // `parse_meta_fmt` would have failed otherwise.
            _ => unreachable!(),
        };

        Placeholder::parse_fmt_string(&fmt_string).into_iter().fold(
            HashMap::new(),
            |mut bounds, pl| {
                let field_ty = fmt_args
                    .get(&pl.position)
                    .and_then(|arg| fields_type_params.get(arg));
                if let Some(ty) = field_ty {
                    bounds
                        .entry(ty.clone())
                        .or_insert_with(HashSet::new)
                        .insert(trait_name_to_trait_bound(pl.trait_name));
                }
                bounds
            },
        )
    }

    /// Infers the bounds needed by `infer_fmt`. If the first field's type involves a type
    /// parameter, that type (minus one reference layer) must implement the derived trait.
    fn infer_type_params_bounds(
        &self,
        fields: &syn::Fields,
    ) -> HashMap<syn::Type, HashSet<syn::TraitBound>> {
        if self.type_params.is_empty() {
            return HashMap::new();
        }
        if let syn::Fields::Unit = fields {
            return HashMap::new();
        }
        // Only the first field matters: callers run `infer_fmt` first, and it rejects types
        // with more than one field.
        fields
            .iter()
            .take(1)
            .filter_map(|field| {
                self.get_type_param(&field.ty).map(|ty| {
                    let bound = trait_name_to_trait_bound(attribute_name_to_trait_name(
                        self.trait_attr,
                    ));
                    (ty, std::iter::once(bound).collect())
                })
            })
            .collect()
    }

    /// Returns the type that needs a bound if `ty` involves one of the input's type parameters.
    /// A top-level reference is stripped, so `&T` yields `T`.
    fn get_type_param(&self, ty: &syn::Type) -> Option<syn::Type> {
        if self.has_type_param_in(ty) {
            match ty {
                syn::Type::Reference(syn::TypeReference { elem: ty, .. }) => {
                    Some(ty.deref().clone())
                }
                ty => Some(ty.clone()),
            }
        } else {
            None
        }
    }

    /// Returns whether `ty` mentions one of the input's type parameters.
    ///
    /// A parameter is detected in any of these positions:
    /// - as the first segment of a path;
    /// - inside a qualified self type;
    /// - inside angle-bracketed generic arguments;
    /// - behind a reference.
    ///
    /// Other type forms are not inspected and yield `false`.
    fn has_type_param_in(&self, ty: &syn::Type) -> bool {
        match ty {
            syn::Type::Path(ty) => {
                if let Some(qself) = &ty.qself {
                    if self.has_type_param_in(&qself.ty) {
                        return true;
                    }
                }

                if let Some(segment) = ty.path.segments.first() {
                    if self.type_params.contains(&segment.ident) {
                        return true;
                    }
                }

                ty.path.segments.iter().any(|segment| {
                    if let syn::PathArguments::AngleBracketed(arguments) =
                        &segment.arguments
                    {
                        arguments.args.iter().any(|argument| match argument {
                            syn::GenericArgument::Type(ty) => {
                                self.has_type_param_in(ty)
                            }
                            // NOTE(review): this compares the associated item's name (e.g.
                            // `Item` in `Item: Bound`) against the type parameters, not the
                            // bound itself. Confirm this is intended.
                            syn::GenericArgument::Constraint(constraint) => {
                                self.type_params.contains(&constraint.ident)
                            }
                            _ => false,
                        })
                    } else {
                        false
                    }
                })
            }

            syn::Type::Reference(ty) => self.has_type_param_in(&ty.elem),

            _ => false,
        }
    }
}
/// A single placeholder of a format string, as produced by `Placeholder::parse_fmt_string`.
#[derive(Debug, PartialEq)]
struct Placeholder {
    /// Index of the format argument the placeholder refers to: its explicit position, or the
    /// next implicit one.
    position: usize,
    /// Name of the `core::fmt` trait required by the placeholder's format type, e.g.
    /// `"LowerHex"` for `{:x}`.
    trait_name: &'static str,
}
impl Placeholder {
    /// Lists the placeholders of a format string in order of appearance. For each one it
    /// resolves the argument position and the `core::fmt` trait its format type requires.
    ///
    /// A placeholder without an explicit position takes the next implicit position. Only
    /// other implicit placeholders advance that counter, so `"{} {0} {}"` yields positions
    /// 0, 0, 1. Escaped braces (`{{`, `}}`) are not placeholders.
    ///
    /// # Panics
    /// Panics if `crate::parsing::format` rejects a placeholder returned by
    /// `crate::parsing::all_placeholders`, or on a format type not listed below.
    fn parse_fmt_string(s: &str) -> Vec<Placeholder> {
        let mut next_implicit_position = 0;
        let mut placeholders = Vec::new();
        for raw in crate::parsing::all_placeholders(s).into_iter().flatten() {
            let (explicit_position, format_type) = crate::parsing::format(raw).unwrap();
            let position = match explicit_position {
                Some(position) => position,
                None => {
                    next_implicit_position += 1;
                    next_implicit_position - 1
                }
            };
            let trait_name = match format_type.unwrap_or_default() {
                "" => "Display",
                "?" | "x?" | "X?" => "Debug",
                "o" => "Octal",
                "x" => "LowerHex",
                "X" => "UpperHex",
                "p" => "Pointer",
                "b" => "Binary",
                "e" => "LowerExp",
                "E" => "UpperExp",
                _ => unreachable!(),
            };
            placeholders.push(Placeholder {
                position,
                trait_name,
            });
        }
        placeholders
    }
}
#[cfg(test)]
mod regex_maybe_placeholder_spec {
    /// `all_placeholders` yields every real placeholder in order and skips `{{`/`}}` escapes.
    #[test]
    fn parses_placeholders_and_omits_escaped() {
        let fmt_string = "{}, {:?}, {{}}, {{{1:0$}}}";
        let placeholders: Vec<_> = crate::parsing::all_placeholders(fmt_string)
            .into_iter()
            .flatten()
            .collect();
        assert_eq!(placeholders, vec!["{}", "{:?}", "{1:0$}"]);
    }
}
#[cfg(test)]
mod regex_placeholder_format_spec {
    /// `format` reports a placeholder's format type, or none for the plain `Display` case.
    #[test]
    fn detects_type() {
        for &(p, expected) in &[
            ("{}", ""),
            ("{:?}", "?"),
            ("{:x?}", "x?"),
            ("{:X?}", "X?"),
            ("{:o}", "o"),
            ("{:x}", "x"),
            ("{:X}", "X"),
            ("{:p}", "p"),
            ("{:b}", "b"),
            ("{:e}", "e"),
            ("{:E}", "E"),
            ("{:.*}", ""),
            ("{8}", ""),
            ("{:04}", ""),
            ("{1:0$}", ""),
            ("{:width$}", ""),
            ("{9:>8.*}", ""),
            ("{2:.1$x}", "x"),
        ] {
            let typ = crate::parsing::format(p).unwrap().1.unwrap_or_default();
            assert_eq!(typ, expected);
        }
    }

    /// `format` reports a placeholder's explicit argument position, or none if implicit.
    #[test]
    fn detects_arg() {
        for &(p, expected) in &[
            ("{}", ""),
            ("{0:?}", "0"),
            ("{12:x?}", "12"),
            ("{3:X?}", "3"),
            ("{5:o}", "5"),
            ("{6:x}", "6"),
            ("{:X}", ""),
            ("{8}", "8"),
            ("{:04}", ""),
            ("{1:0$}", "1"),
            ("{:width$}", ""),
            ("{9:>8.*}", "9"),
            ("{2:.1$x}", "2"),
        ] {
            let arg = crate::parsing::format(p)
                .unwrap()
                .0
                .map(|s| s.to_string())
                .unwrap_or_default();
            assert_eq!(arg, expected);
        }
    }
}
#[cfg(test)]
mod placeholder_parse_fmt_string_spec {
    use super::*;

    /// Covers implicit vs explicit positions, escaped braces, and format-type-to-trait mapping.
    #[test]
    fn indicates_position_and_trait_name_for_each_fmt_placeholder() {
        let fmt_string = "{},{:?},{{}},{{{1:0$}}}-{2:.1$x}{0:#?}{:width$}";
        assert_eq!(
            Placeholder::parse_fmt_string(fmt_string),
            vec![
                Placeholder {
                    position: 0,
                    trait_name: "Display",
                },
                Placeholder {
                    position: 1,
                    trait_name: "Debug",
                },
                Placeholder {
                    position: 1,
                    trait_name: "Display",
                },
                Placeholder {
                    position: 2,
                    trait_name: "LowerHex",
                },
                Placeholder {
                    position: 0,
                    trait_name: "Debug",
                },
                Placeholder {
                    position: 2,
                    trait_name: "Display",
                },
            ],
        )
    }
}