diff --git a/spec-trait-impl/crates/spec-trait-bin/src/examples/mod.rs b/spec-trait-impl/crates/spec-trait-bin/src/examples/mod.rs index bda4c07..eced03f 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/examples/mod.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/examples/mod.rs @@ -1,19 +1,19 @@ mod all; mod base; +mod higher; mod lifetimes; -mod nested; mod order; +mod repeating; mod sself; mod traits; mod types; -mod higher; pub fn run() { base::run(); types::run(); traits::run(); lifetimes::run(); - nested::run(); + repeating::run(); sself::run(); higher::run(); order::run(); diff --git a/spec-trait-impl/crates/spec-trait-bin/src/examples/nested.rs b/spec-trait-impl/crates/spec-trait-bin/src/examples/nested.rs deleted file mode 100644 index e6521bd..0000000 --- a/spec-trait-impl/crates/spec-trait-bin/src/examples/nested.rs +++ /dev/null @@ -1,37 +0,0 @@ -use spec_trait_macro::{spec, when}; - -struct BaseType; - -trait NestedTrait { - fn nested_method(&self, _x: T, _y: U); -} - -impl NestedTrait for BaseType { - fn nested_method(&self, _x: T, _y: U) { - println!("Default"); - } -} - -#[when(all(U = Vec, T = u8))] -impl NestedTrait for BaseType { - fn nested_method(&self, _x: T, _y: U) { - println!("T is u8 and U is Vec"); - } -} - -#[when(U = Vec)] -impl NestedTrait for BaseType { - fn nested_method(&self, _x: T, _y: U) { - println!("U is Vec"); - } -} - -pub fn run() { - println!("\n- Nested Examples:"); - - let x = BaseType; - - spec! { x.nested_method(1u8, vec![2u8]); BaseType; [u8, Vec] } // "T is u8 and U is Vec" - spec! { x.nested_method(1i32, vec![2i32]); BaseType; [i32, Vec] } // "U is Vec" - spec! { x.nested_method(1i32, 2i32); BaseType; [i32, i32] } // -> "Default" -} diff --git a/spec-trait-impl/crates/spec-trait-bin/src/examples/repeating.rs b/spec-trait-impl/crates/spec-trait-bin/src/examples/repeating.rs new file mode 100644 index 0000000..a89d3b6 --- /dev/null +++ b/spec-trait-impl/crates/spec-trait-bin/src/examples/repeating.rs @@ -0,0 +1,103 @@ +use spec_trait_macro::{spec, when}; + +struct BaseType; + +trait RepeatingGenericsTrait { + fn repeating_generics_method(&self, _x: T, _y: U); +} + +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("Default"); + } +} + +#[when(all(U = Vec, T = u8))] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("T is u8 and U is Vec"); + } +} + +#[when(U = Vec)] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("U is Vec"); + } +} + +#[when(T = Vec)] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("T is Vec"); + } +} + +#[when(T: Copy)] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: T) { + println!("T implements Copy for both parameters"); + } +} + +#[when(all(T = String, T = i32))] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: T) { + panic!("T cannot be both String and i32"); + } +} + +#[when(all(T = &str, T: 'a, U = &'a i32))] +impl<'a, T, U> RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("T and U have the same lifetime"); + } +} + +#[when(all(T = &str, T: 'a, U = &'b i32))] +impl<'a, 'b, T, U> RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, _x: T, _y: U) { + println!("T and U have different lifetimes"); + } +} + +#[when(all(T = &str, T: 'static, U = for<'a> fn(T, &'a i32)))] +impl RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, x: T, y: U) { + println!("T is &'static str and U is for<'a> fn(T, &'a i32)"); + } +} + +#[when(all(T = &str, T: 'b, U = for<'a> fn(T, &'a i32)))] +impl<'b, T, U> RepeatingGenericsTrait for BaseType { + fn repeating_generics_method(&self, x: T, y: U) { + println!("T is &'b str and U is for<'a> fn(T, &'a i32)"); + } +} + +pub fn run() { + println!("\n- Repeating Generics Examples:"); + + let x = BaseType; + + spec! { x.repeating_generics_method(1u8, vec![2u8]); BaseType; [u8, Vec] } // "T is u8 and U is Vec" + spec! { x.repeating_generics_method(1i32, vec![2i32]); BaseType; [i32, Vec] } // "U is Vec" + spec! { x.repeating_generics_method(vec![1i32], 2i32); BaseType; [Vec, i32] } // "T is Vec" + spec! { x.repeating_generics_method(1i32, vec![2u8]); BaseType; [i32, Vec] } // "Default" + spec! { x.repeating_generics_method(vec![1i32], 2u8); BaseType; [Vec, u8] } // "Default" + + spec! { x.repeating_generics_method(1i32, 2i32); BaseType; [i32, i32]; i32: Copy } // "T implements Copy for both parameters" + spec! { x.repeating_generics_method(1i32, 2u8); BaseType; [i32, u8]; i32: Copy; u8: Copy } // "Default" + spec! { x.repeating_generics_method(1i32, 2u8); BaseType; [i32, u8]; i32: Copy } // "Default" + + spec! { x.repeating_generics_method("test", &2i32); BaseType; [&'static str, &'static i32] } // "T and U have the same lifetimes" + spec! { x.repeating_generics_method("test", &2i32); BaseType; [&'a str, &'a i32] } // "T and U have the same lifetimes" + spec! { x.repeating_generics_method("test", &2i32); BaseType; [&'static str, &'a i32] } // "T and U have different lifetimes" + spec! { x.repeating_generics_method("test", &2i32); BaseType; [&'a str, &'b i32] } // "T and U have different lifetimes" + + let lambda = |_a: &str, _b: &i32| {}; + spec! { x.repeating_generics_method("test", lambda); BaseType; [&'static str, for<'a> fn(&'static str, &'a i32)] } // "T is &'static str and U is for<'a> fn(T, &'a i32)" + spec! { x.repeating_generics_method("test", lambda); BaseType; [&str, for<'a> fn(&str, &'a i32)] } // "T is &'b str and U is for<'a> fn(T, &'a i32)" + spec! { x.repeating_generics_method("test", lambda); BaseType; [&'static str, for<'a> fn(&str, &'a i32)] } // "Default" + spec! { x.repeating_generics_method("test", lambda); BaseType; [&'b str, for<'a> fn(&'c str, &'a i32)] } // "Default" +} diff --git a/spec-trait-impl/crates/spec-trait-bin/src/examples/sself.rs b/spec-trait-impl/crates/spec-trait-bin/src/examples/sself.rs index 89404d0..d42b378 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/examples/sself.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/examples/sself.rs @@ -60,7 +60,7 @@ impl SelfTrait for (U, V) { } pub fn run() { - println!("\n- Self Examples:"); + println!("\n- Self Type Examples:"); let x = BaseType1; spec! { x.self_method(42u8); BaseType1; [u8] } // -> "BaseType1: T is u8" diff --git a/spec-trait-impl/crates/spec-trait-bin/src/examples/types.rs b/spec-trait-impl/crates/spec-trait-bin/src/examples/types.rs index b314cb5..9dd4f03 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/examples/types.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/examples/types.rs @@ -3,6 +3,9 @@ use spec_trait_macro::{spec, when}; type MyAlias = u8; type MyVecAlias = Vec; +trait Tr {} +impl Tr for u8 {} + struct BaseType; trait TypesTrait { @@ -64,6 +67,13 @@ impl TypesTrait for BaseType { } } +#[when(T = &dyn Tr)] +impl TypesTrait for BaseType { + fn types_method(&self, _x: T) { + println!("T is &dyn Tr"); + } +} + pub fn run() { println!("\n- Types Examples:"); @@ -77,5 +87,6 @@ pub fn run() { spec! { x.types_method((1, 2)); BaseType; [(i32, i32)] } // -> "T is (i32, _)" spec! { x.types_method(&[1i32]); BaseType; [&[i32]] } // -> "T is &[i32]" spec! { x.types_method(|x: &u8| x); BaseType; [fn(&u8) -> &u8] } // -> "T is a function pointer from &u8 to &u8" + spec! { x.types_method(&1u8); BaseType; [&dyn Tr] } // -> "T is & dyn Tr" spec! { x.types_method(1i8); BaseType; [i8] } // -> "Default" } diff --git a/spec-trait-impl/crates/spec-trait-macro/src/constraints.rs b/spec-trait-impl/crates/spec-trait-macro/src/constraints.rs index 3f807f7..ba7e903 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/constraints.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/constraints.rs @@ -1,6 +1,9 @@ +use crate::SpecBody; use proc_macro2::TokenStream; use spec_trait_utils::conversions::{str_to_generics, str_to_type_name, to_string}; -use spec_trait_utils::parsing::get_generics_types; +use spec_trait_utils::impls::ImplBody; +use spec_trait_utils::parsing::{count_generics, get_generics_types}; +use spec_trait_utils::traits::TraitBody; use spec_trait_utils::types::{Aliases, replace_type, strip_lifetimes, type_assignable}; use std::cmp::Ordering; use std::collections::HashMap; @@ -38,6 +41,7 @@ impl Ord for Constraint { .then(self.traits.len().cmp(&other.traits.len())) .then(self.not_types.len().cmp(&other.not_types.len())) .then(self.not_traits.len().cmp(&other.not_traits.len())) + .then(count_generics(&other.generics).cmp(&count_generics(&self.generics))) } } @@ -174,6 +178,38 @@ impl FromIterator<(String, Constraint)> for Constraints { } } +impl Constraints { + /// fills the constraints for each generic parameter from trait and specialized trait, + /// based on the constraints from impl and type + pub fn fill_trait_constraints(&self, spec: &SpecBody) -> Constraints { + let mut constraints = self.clone(); + constraints.from_trait = self.get_trait_constraints(&spec.trait_, &spec.impl_); + constraints.from_specialized_trait = self.get_trait_constraints( + spec.trait_.specialized.as_ref().unwrap(), + spec.impl_.specialized.as_ref().unwrap(), + ); + constraints.from_type = Constraint { + generics: spec.impl_.impl_generics.clone(), + type_: Some(spec.impl_.type_name.clone()), + ..Default::default() + }; + constraints + } + + /// gets the constraints for each generic parameter from trait based on the constraints from impl + fn get_trait_constraints(&self, trait_: &TraitBody, impl_: &ImplBody) -> ConstraintMap { + let mut new_constraints = ConstraintMap::new(); + for trait_generic in get_generics_types::>(&trait_.generics) { + let corresponding_generic = impl_ + .get_corresponding_generic(&str_to_generics(&trait_.generics), &trait_generic) + .unwrap(); + let constraint = self.from_impl.get(&corresponding_generic); + new_constraints.insert(trait_generic, constraint.cloned().unwrap_or_default()); + } + new_constraints + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs index 437f0ca..a210a52 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs @@ -1,17 +1,15 @@ use crate::annotations::AnnotationBody; -use crate::constraints::{Constraint, ConstraintMap, Constraints}; +use crate::constraints::Constraints; use crate::vars::VarBody; use proc_macro2::TokenStream; use quote::quote; use spec_trait_utils::conditions::WhenCondition; -use spec_trait_utils::conversions::{ - str_to_expr, str_to_generics, str_to_trait_name, str_to_type_name, to_string, -}; +use spec_trait_utils::conversions::{str_to_expr, str_to_trait_name, str_to_type_name, to_string}; use spec_trait_utils::impls::ImplBody; use spec_trait_utils::parsing::{get_generics_lifetimes, get_generics_types}; use spec_trait_utils::traits::TraitBody; use spec_trait_utils::types::{ - assign_lifetimes, get_concrete_type, trait_assignable, type_assignable, + GenericsMap, assign_lifetimes, get_concrete_type, trait_assignable, type_assignable, type_assignable_generic_constraints, }; use std::cmp::Ordering; @@ -101,12 +99,12 @@ fn get_constraints(default: SpecBody) -> Option { None => Some(default), // from when macro Some(cond) => { - let var = VarBody::from(&default); // TODO: handle conflicting vars (early return None) + let var = VarBody::try_from(&default).ok()?; let (satisfied, constraints) = satisfies_condition(cond, &var, &default.constraints); if satisfied { let mut with_constraints = default.clone(); - with_constraints.constraints = fill_trait_constraints(&constraints, &default); + with_constraints.constraints = constraints.fill_trait_constraints(&default); Some(with_constraints) } else { None @@ -161,9 +159,9 @@ fn satisfies_condition( .iter() .any(|t| type_assignable(&declared_type, t, &var.generics, &var.aliases)) || // generic parameter should implement a trait that the type does not implement - declared_type_var.is_none_or(|v| - constraint.traits.iter().any(|t| !v.traits.contains(t)) - ); + declared_type_var.is_none_or(|v| constraint.traits.iter().any(|t| !v.traits.contains(t))) || + // the condition implies a different type for some other generic than what we already inferred + has_cross_generic_conflict(generic, &declared_type, var); constraint.generics = var.generics.clone(); if violates_constraints { @@ -266,39 +264,61 @@ fn satisfies_condition( } } -/// fills the constraints for each generic parameter from trait and specialized trait, -/// based on the constraints from impl and type -fn fill_trait_constraints(constraints: &Constraints, spec: &SpecBody) -> Constraints { - let mut constraints = constraints.clone(); - constraints.from_trait = get_trait_constraints(&spec.trait_, &spec.impl_, &constraints); - constraints.from_specialized_trait = get_trait_constraints( - spec.trait_.specialized.as_ref().unwrap(), - spec.impl_.specialized.as_ref().unwrap(), - &constraints, - ); - constraints.from_type = Constraint { - generics: spec.impl_.impl_generics.clone(), - type_: Some(spec.impl_.type_name.clone()), - ..Default::default() +/** + Checks if the condition implies a different type for some other generic than what we already inferred + # Example + For condition `U = Vec` with concrete types `T = i32` and `U = Vec`. + The condition implies `T = u8` which conflicts with `T = i32` +*/ +fn has_cross_generic_conflict(generic: &str, declared_type: &str, var: &VarBody) -> bool { + let Some(generic_var) = var.vars.iter().find(|v| v.impl_generic == generic) else { + return false; }; - constraints + + let Some(generics_map) = type_assignable_generic_constraints( + &generic_var.concrete_type, + declared_type, + &var.generics, + &var.aliases, + ) else { + return false; + }; + + has_inner_generic_conflict(generic, &generics_map.types, var) + || has_inner_generic_conflict(generic, &generics_map.lifetimes, var) } -/// gets the constraints for each generic parameter from trait based on the constraints from impl -fn get_trait_constraints( - trait_: &TraitBody, - impl_: &ImplBody, - constraints: &Constraints, -) -> ConstraintMap { - let mut new_constraints = ConstraintMap::new(); - for trait_generic in get_generics_types::>(&trait_.generics) { - let corresponding_generic = impl_ - .get_corresponding_generic(&str_to_generics(&trait_.generics), &trait_generic) - .unwrap(); - let constraint = constraints.from_impl.get(&corresponding_generic); - new_constraints.insert(trait_generic, constraint.cloned().unwrap_or_default()); - } - new_constraints +fn has_inner_generic_conflict(generic: &str, generic_map: &GenericsMap, var: &VarBody) -> bool { + generic_map + .iter() + .filter_map(|(inner_generic, constraint)| { + constraint + .as_ref() + .map(|concrete_type| (inner_generic, concrete_type)) + }) + .any(|(inner_generic, concrete_type)| { + if inner_generic == generic { + return false; + } + + // If we already know a concrete type for this inner generic from vars, + // it must be compatible with the one implied by the condition. + if let Some(existing) = var.vars.iter().find(|v| v.impl_generic == *inner_generic) { + !type_assignable( + &existing.concrete_type, + concrete_type, + &var.generics, + &var.aliases, + ) || !type_assignable( + concrete_type, + &existing.concrete_type, + &var.generics, + &var.aliases, + ) + } else { + false + } + }) } impl From<&SpecBody> for TokenStream { @@ -593,6 +613,33 @@ mod tests { assert!(c.not_types.contains(&"u32".to_string())); } + #[test] + fn cross_generic_conflict() { + let condition = WhenCondition::Type("U".into(), "Vec".into()); + let var = VarBody { + aliases: Aliases::default(), + generics: "".to_string(), + vars: vec![ + VarInfo { + impl_generic: "T".into(), + trait_generic: Some("T".into()), + concrete_type: "i32".into(), + traits: vec![], + }, + VarInfo { + impl_generic: "U".into(), + trait_generic: Some("U".into()), + concrete_type: "Vec".into(), + traits: vec![], + }, + ], + }; + + let (satisfies, _) = satisfies_condition(&condition, &var, &Constraints::default()); + + assert!(!satisfies); + } + #[test] fn default_impl() { let impls = vec![get_impl_body(None)]; @@ -812,7 +859,6 @@ mod tests { assert!(!result.is_ok()); } - #[test] fn test_fill_trait_constraints() { let impl_ = get_impl_body(None); @@ -838,7 +884,7 @@ mod tests { constraints: Constraints::default(), }; - let filled = fill_trait_constraints(&constraints, &spec); + let filled = constraints.fill_trait_constraints(&spec); let c = filled.from_trait.get("A").unwrap(); let sc = filled.from_specialized_trait.get("__G_0__").unwrap(); @@ -869,7 +915,7 @@ mod tests { constraints: Constraints::default(), }; - let filled = fill_trait_constraints(&constraints, &spec); + let filled = constraints.fill_trait_constraints(&spec); let c = filled.from_trait.get("A").unwrap(); assert_eq!(c, &Constraint::default()); diff --git a/spec-trait-impl/crates/spec-trait-macro/src/vars.rs b/spec-trait-impl/crates/spec-trait-macro/src/vars.rs index 8860c34..351d23b 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/vars.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/vars.rs @@ -1,16 +1,16 @@ -use std::collections::HashSet; - use crate::SpecBody; use crate::annotations::{Annotation, AnnotationBody}; use spec_trait_utils::conversions::{ str_to_generics, str_to_lifetime, str_to_type_name, to_string, }; use spec_trait_utils::impls::ImplBody; -use spec_trait_utils::parsing::get_generics_types; +use spec_trait_utils::parsing::{get_generics_lifetimes, get_generics_types}; use spec_trait_utils::traits::TraitBody; use spec_trait_utils::types::{ - Aliases, get_concrete_type, type_assignable, type_assignable_generic_constraints, type_contains, + Aliases, GenericsMap, get_concrete_type, type_assignable, type_assignable_generic_constraints, + type_contains, }; +use std::collections::HashSet; use syn::{FnArg, TraitItemFn, Type}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,16 +35,18 @@ pub struct VarBody { pub vars: Vec, } -impl From<&SpecBody> for VarBody { - fn from(spec: &SpecBody) -> Self { +impl TryFrom<&SpecBody> for VarBody { + type Error = String; + + fn try_from(spec: &SpecBody) -> Result { let aliases = get_type_aliases(&spec.annotations.annotations); let generics = spec.impl_.impl_generics.clone(); - let vars = get_vars(&spec.annotations, &spec.impl_, &spec.trait_, &aliases); - VarBody { + let vars = get_vars(&spec.annotations, &spec.impl_, &spec.trait_, &aliases)?; + Ok(VarBody { aliases, generics, vars, - } + }) } } @@ -70,69 +72,93 @@ fn get_vars( impl_: &ImplBody, trait_: &TraitBody, aliases: &Aliases, -) -> Vec { - get_generics_types::>(&impl_.impl_generics) - .iter() - .flat_map(|g| { - let from_type = get_generic_constraints_from_type(g, impl_, ann, aliases); - let from_type_specialized = get_generic_constraints_from_type( - g, - impl_.specialized.as_ref().unwrap(), - ann, - aliases, - ); - - match trait_.get_corresponding_generic(&str_to_generics(&impl_.trait_generics), g) { - // get type from trait (generic already existed before specialization) - Some(trait_generic) => { +) -> Result, String> { + let generics = get_generics_types::>(&impl_.impl_generics); + let lifetimes = get_generics_lifetimes::>(&impl_.impl_generics); + + let mut res: HashSet = HashSet::new(); + + for g in generics.iter().chain(lifetimes.iter()) { + let from_type = get_generic_constraints_from_type(g, impl_, ann, aliases); + let from_type_specialized = + get_generic_constraints_from_type(g, impl_.specialized.as_ref().unwrap(), ann, aliases); + + let corresponding_generics = + trait_.get_corresponding_generics(&str_to_generics(&impl_.trait_generics), g); + + if corresponding_generics.len() > 1 { + let mut resulting_vars = Vec::new(); + + for trait_generic in &corresponding_generics { + let mut from_trait = + get_generic_constraints_from_trait(trait_generic, trait_, impl_, ann, aliases)?; + resulting_vars.append(&mut from_trait); + } + + resulting_vars.extend(from_type.iter().cloned()); + + let conflicting_vars = resulting_vars + .iter() + .map(|v| &v.concrete_type) + .collect::>(); + + if conflicting_vars.len() > 1 { + return Err(format!( + "Multiple conflicting constraints found for generic {}: {:?}", + g, conflicting_vars + )); + } + } + + match corresponding_generics.first() { + // get type from trait (generic already existed before specialization) + Some(trait_generic) => { + let from_trait = + get_generic_constraints_from_trait(trait_generic, trait_, impl_, ann, aliases)?; + + res.extend(from_trait.into_iter().chain(from_type)); + } + + // get from specialized trait (generic was added during specialization) + None => { + let trait_generic = trait_ + .specialized + .as_ref() + .unwrap() + .get_corresponding_generic( + &str_to_generics(&impl_.specialized.as_ref().unwrap().trait_generics), + g, + ); + + if let Some(trait_generic) = trait_generic { let from_trait = get_generic_constraints_from_trait( &trait_generic, - trait_, - impl_, + trait_.specialized.as_ref().unwrap(), + impl_.specialized.as_ref().unwrap(), ann, aliases, - ); - - from_trait.into_iter().chain(from_type).collect::>() - } - - // get from specialized trait (generic was added during specialization) - None => { - let trait_generic = trait_ - .specialized - .as_ref() - .unwrap() - .get_corresponding_generic( - &str_to_generics(&impl_.specialized.as_ref().unwrap().trait_generics), - g, - ); - - if let Some(trait_generic) = trait_generic { - let from_trait = get_generic_constraints_from_trait( - &trait_generic, - trait_.specialized.as_ref().unwrap(), - impl_.specialized.as_ref().unwrap(), - ann, - aliases, - ); + )?; + res.extend( from_trait .into_iter() .chain(from_type_specialized) - .collect::>() - } else { - // get from type only + .collect::>(), + ); + } else { + // get from type only + res.extend( from_type .into_iter() .chain(from_type_specialized) - .collect::>() - } + .collect::>(), + ); } } - }) - .collect::>() - .into_iter() - .collect() + } + } + + Ok(res.into_iter().collect()) } /** @@ -152,17 +178,18 @@ fn get_param_types(trait_fn: &TraitItemFn) -> Vec { .collect() } +/// Get generic constraints from the trait function parameters fn get_generic_constraints_from_trait( trait_generic: &str, trait_: &TraitBody, impl_: &ImplBody, ann: &AnnotationBody, aliases: &Aliases, -) -> Vec { +) -> Result, String> { let trait_fn = trait_.find_fn(&ann.fn_, ann.args.len()).unwrap(); let param_types = get_param_types(&trait_fn); - // find all params that use the generic + // find all params that use the generic (type or lifetime) let params_with_trait_generic = param_types .iter() .enumerate() @@ -171,40 +198,104 @@ fn get_generic_constraints_from_trait( // generic passed but not used if params_with_trait_generic.is_empty() { - return vec![]; + return Ok(vec![]); } - let (pos, trait_type_definition) = params_with_trait_generic.first().unwrap(); - let concrete_type = &ann.args_types[*pos]; + // unify constraints for all generics (types + lifetimes) appearing in those params + let mut unified_generics = GenericsMap::new(); - let mut res = HashSet::new(); + for (pos, trait_type_definition) in params_with_trait_generic { + let concrete_type = &ann.args_types[pos]; - let constrained_generics = type_assignable_generic_constraints( - concrete_type, - trait_type_definition, - &trait_.generics, - aliases, - ); + let constrained_generics = type_assignable_generic_constraints( + concrete_type, + trait_type_definition, + &trait_.generics, + aliases, + ) + .ok_or_else(|| { + format!( + "Type {} is not assignable to parameter {} of function {} (expected {})", + concrete_type, pos, ann.fn_, trait_type_definition + ) + })?; + + // unify type generics + for (generic, constraint) in constrained_generics.types { + if let Some(existing) = unified_generics.get(&generic) { + match (existing, &constraint) { + (Some(e), Some(c)) if e != c => { + return Err(format!( + "Multiple conflicting constraints found for generic {}: {:?} vs {:?}", + generic, e, c + )); + } + (None, Some(_)) => { + unified_generics.insert(generic, constraint); + } + _ => {} + } + } else { + unified_generics.insert(generic, constraint); + } + } - if let Some(generics_map) = constrained_generics { - for (generic, constraint) in generics_map.types { - if let Some(constraint) = constraint { - let impl_generic = impl_ - .get_corresponding_generic(&str_to_generics(&trait_.generics), &generic) - .unwrap(); - res.insert((constraint, impl_generic, generic)); + // unify lifetime generics + for (generic, constraint) in constrained_generics.lifetimes { + if let Some(existing) = unified_generics.get(&generic) { + match (existing, &constraint) { + (Some(e), Some(c)) if e != c => { + return Err(format!( + "Multiple conflicting constraints found for lifetime generic {}: {:?} vs {:?}", + generic, e, c + )); + } + (None, Some(_)) => { + unified_generics.insert(generic, constraint); + } + _ => {} + } + } else { + unified_generics.insert(generic, constraint); } } } - res.into_iter() - .map(|(constraint, impl_generic, trait_generic)| VarInfo { - impl_generic, - trait_generic: Some(trait_generic), - concrete_type: get_concrete_type_with_lifetime(&constraint, &ann.annotations, aliases), - traits: get_type_traits(&constraint, &ann.annotations, aliases), - }) - .collect::>() + // we now build VarInfo for both type and lifetime generics + let mut res: HashSet = HashSet::new(); + + let lifetime_generics = get_generics_lifetimes::>(&trait_.generics); + let trait_generics = str_to_generics(&trait_.generics); + + for (generic, constraint) in unified_generics { + if let Some(constraint) = constraint + && let Some(impl_generic) = impl_.get_corresponding_generic(&trait_generics, &generic) + { + let is_lifetime = lifetime_generics.contains(&generic); + + let concrete_type = if is_lifetime { + // for lifetime generics, keep the lifetime itself as "concrete type" + constraint.clone() + } else { + get_concrete_type_with_lifetime(&constraint, &ann.annotations, aliases) + }; + + let traits = if is_lifetime { + Vec::new() + } else { + get_type_traits(&constraint, &ann.annotations, aliases) + }; + + res.insert(VarInfo { + impl_generic, + trait_generic: Some(generic), + concrete_type, + traits, + }); + } + } + + Ok(res.into_iter().collect()) } fn get_generic_constraints_from_type( @@ -344,7 +435,7 @@ mod tests { let impl_body = ImplBody::try_from(( syn ::parse_str::( - "impl MyTrait for V { fn foo(&self, x: T, y: u32, z: Vec) {} }" + "impl<'a, W, T, U: Debug, V> MyTrait<'a, T, U> for V { fn foo(&self, x: T, y: &'a u32, z: Vec) {} }" ) .unwrap(), None, @@ -352,7 +443,7 @@ mod tests { let trait_body = TraitBody::try_from( syn::parse_str::( - "trait MyTrait { fn foo(&self, x: A, y: u32, z: Vec); }", + "trait MyTrait<'b, A, B> { fn foo(&self, x: A, y: &'b u32, z: Vec); }", ) .unwrap(), ) @@ -363,10 +454,14 @@ mod tests { fn_: "foo".to_string(), args_types: vec![ "i32".to_string(), - "u32".to_string(), + "&'static u32".to_string(), "Vec<&'static i32>".to_string(), ], - args: vec!["1i32".to_string(), "2u32".to_string(), "vec![]".to_string()], + args: vec![ + "1i32".to_string(), + "&2u32".to_string(), + "vec![]".to_string(), + ], var: "x".to_string(), var_type: "MyType".to_string(), annotations: vec![Annotation::Trait("i32".into(), vec!["Debug".into()])], @@ -374,12 +469,13 @@ mod tests { let aliases = Aliases::new(); - let result = get_vars(&ann, &impl_body, &trait_body, &aliases); + let result = get_vars(&ann, &impl_body, &trait_body, &aliases).unwrap(); - assert_eq!(result.len(), 3); + assert_eq!(result.len(), 4); let t = result.iter().find(|v| v.impl_generic == "T").unwrap(); let u = result.iter().find(|v| v.impl_generic == "U").unwrap(); let v = result.iter().find(|v| v.impl_generic == "V").unwrap(); + let a = result.iter().find(|v| v.impl_generic == "'a").unwrap(); assert_eq!( t, &(VarInfo { @@ -407,6 +503,15 @@ mod tests { traits: vec![], }) ); + assert_eq!( + a, + &(VarInfo { + impl_generic: "'a".to_string(), + trait_generic: Some("'b".to_string()), + concrete_type: "'static".to_string(), + traits: vec![], + }) + ); } #[test] @@ -420,7 +525,7 @@ mod tests { Some( WhenCondition::All( vec![ - WhenCondition::Type("W".into(), "Vec".into()), + WhenCondition::Type("W".into(), "&Vec".into()), WhenCondition::Trait("V".into(), vec!["Debug".into()]) ] ) @@ -461,13 +566,13 @@ mod tests { let aliases = Aliases::new(); - let result = get_vars(&ann, &impl_body, &trait_body, &aliases); + let result = get_vars(&ann, &impl_body, &trait_body, &aliases).unwrap(); println!("{:#?}", result); - assert_eq!(result.len(), 5); + assert_eq!(result.len(), 6); let t = result.iter().find(|v| v.impl_generic == "T").unwrap(); let u = result.iter().find(|v| v.impl_generic == "U").unwrap(); - let v = result.iter().find(|v| v.impl_generic == "V"); + let v = result.iter().find(|v| v.impl_generic == "V").unwrap(); let w = result.iter().find(|v| v.impl_generic == "W").unwrap(); let x = result.iter().find(|v| v.impl_generic == "X").unwrap(); let y = result.iter().find(|v| v.impl_generic == "Y").unwrap(); @@ -489,7 +594,15 @@ mod tests { traits: vec![], }) ); - assert!(v.is_none()); + assert_eq!( + v, + &(VarInfo { + impl_generic: "V".to_string(), + trait_generic: Some("V".to_string()), + concrete_type: "i32".to_string(), + traits: vec![], + }) + ); assert_eq!( w, &(VarInfo { diff --git a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs index 18e8f6f..7278592 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs @@ -8,10 +8,10 @@ use crate::parsing::{ parse_generics, }; use crate::specialize::{ - Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition, + Specializable, add_generic_lifetime, add_generic_type, apply_type_condition, get_assignable_conditions, get_used_generics, handle_generics, remove_generic, }; -use crate::types::{replace_type, type_contains, type_contains_lifetime}; +use crate::types::{TypeReplacer, replace_type, type_contains, type_contains_lifetime}; use proc_macro2::TokenStream; use quote::quote; use serde::{Deserialize, Serialize}; @@ -142,8 +142,8 @@ impl ImplBody { // set specialized trait name specialized.trait_name = specialized.get_spec_trait_name(); let mut replacer = TypeReplacer { - generic: self.trait_name.clone(), - type_: str_to_type_name(&specialized.trait_name), + prev: self.trait_name.clone(), + new: str_to_type_name(&specialized.trait_name), }; specialized.handle_items_replace(&mut replacer); diff --git a/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs b/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs index e0ef675..e3a93f5 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs @@ -201,6 +201,11 @@ pub fn get_generics_lifetimes>(generics_str: &str) -> T collect_generics_lifetimes(&generics) } +pub fn count_generics(generics_str: &str) -> usize { + let generics = str_to_generics(generics_str); + generics.params.len() +} + pub fn get_relevant_generics_names(generics: &Generics, generic: &str) -> Vec { let get_lifetimes = generic.starts_with('\''); let get_types = !get_lifetimes; diff --git a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs index ee21b73..0483f2c 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs @@ -1,14 +1,14 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use crate::conditions::WhenCondition; -use crate::conversions::{str_to_generics, str_to_lifetime, str_to_type_name}; +use crate::conversions::{str_to_generics, str_to_lifetime, str_to_type_name, to_string}; use crate::types::{ - Aliases, replace_infers, replace_type, type_assignable, type_contains, type_contains_lifetime, + Aliases, TypeReplacer, replace_infers, type_assignable, type_contains, type_contains_lifetime, }; use proc_macro2::Span; use syn::punctuated::Punctuated; use syn::visit::Visit; -use syn::visit_mut::{self, VisitMut}; +use syn::visit_mut::VisitMut; use syn::{GenericParam, Generics, Ident, LifetimeParam, Type, TypeParam}; // TODO: infer lifetimes as well @@ -29,12 +29,13 @@ pub fn get_assignable_conditions( conditions: &[WhenCondition], generics: &str, ) -> Vec { + let conditions = merge_ref_and_lifetime(conditions.to_vec()); conditions .iter() .filter_map(|c| match c { WhenCondition::Trait(_, _) => Some(c.clone()), WhenCondition::Type(g, t) => { - let types = get_generic_types_from_conditions(g, conditions); + let types = get_generic_types_from_conditions(g, &conditions); let most_specific = types.last() == Some(t); let diff_types = types.iter().any(|other_t| { !type_assignable(t, other_t, generics, &Aliases::default()) @@ -53,6 +54,124 @@ pub fn get_assignable_conditions( .collect() } +/// Splits a reference type `& 'a Type` into its lifetime and inner type components. +/// Returns `None` if the type is not a reference. +fn split_ref_type(ty: &str) -> Option<(Option, String)> { + let ty = str_to_type_name(ty); + if let syn::Type::Reference(type_ref) = ty { + let lifetime = type_ref.lifetime.as_ref().map(|lt| lt.to_string()); + let inner_type = to_string(&type_ref.elem); + Some((lifetime, inner_type)) + } else { + None + } +} + +/// Merge pairs of type conditions of the form +/// `T = &Type` and `T = & 'a _` into a single `T = & 'a Type`. +fn merge_ref_and_lifetime(conditions: Vec) -> Vec { + let mut generic_to_lifetimes: HashMap> = HashMap::new(); + let mut generic_to_types: HashMap> = HashMap::new(); + + // classify conditions + for (idx, cond) in conditions.iter().enumerate() { + if let WhenCondition::Type(g, t) = cond + && let Some((lt, inner)) = split_ref_type(t) + { + if lt.is_some() && inner == "_" { + generic_to_lifetimes.entry(g.clone()).or_default().push(idx); + } else if lt.is_none() && inner != "_" { + generic_to_types.entry(g.clone()).or_default().push(idx); + } + } + } + + let mut removed = HashSet::new(); + let mut new_conditions = Vec::new(); + + // for each generic, merge only when there is exactly one lifetime and one type + for (generic, lifetime_indices) in generic_to_lifetimes { + let type_indices = match generic_to_types.get(&generic) { + Some(v) if !v.is_empty() => v, + _ => continue, + }; + + // collect distinct lifetimes + let mut lifetimes = HashSet::new(); + for &idx in &lifetime_indices { + if let WhenCondition::Type(_, t) = &conditions[idx] + && let Some((Some(lt), _)) = split_ref_type(t) + { + lifetimes.insert(lt); + } + } + + // collect distinct inner types + let mut types = HashSet::new(); + for &idx in type_indices { + if let WhenCondition::Type(_, t) = &conditions[idx] + && let Some((_, inner)) = split_ref_type(t) + { + types.insert(inner); + } + } + + // if more than one distinct lifetime or type -> conflicting, no merge + if lifetimes.len() != 1 || types.len() != 1 { + continue; + } + + // we know we can merge one pair + let lifetime_idx = lifetime_indices[0]; + let type_idx = type_indices[0]; + + if let Some(merged) = + build_merged_ref_condition(&conditions, lifetime_idx, type_idx, &generic) + { + removed.insert(lifetime_idx); + removed.insert(type_idx); + new_conditions.push(merged); + } + } + + // keep all non-removed originals, then append merged ones + let mut result = vec![]; + for (idx, cond) in conditions.into_iter().enumerate() { + if !removed.contains(&idx) { + result.push(cond); + } + } + result.extend(new_conditions); + result +} + +/// Build `T = &'a Type` from indices of `T = &'a _` and `T = &Type`. +fn build_merged_ref_condition( + conditions: &[WhenCondition], + lifetime_idx: usize, + type_idx: usize, + generic: &str, +) -> Option { + let lifetime = match &conditions[lifetime_idx] { + WhenCondition::Type(_, t) => t, + _ => return None, + }; + let type_ = match &conditions[type_idx] { + WhenCondition::Type(_, t) => t, + _ => return None, + }; + + let (lt_opt, _) = split_ref_type(lifetime)?; + let (_, ty) = split_ref_type(type_)?; + + let lt = lt_opt?; + + Some(WhenCondition::Type( + generic.to_string(), + format!("& {} {}", lt, ty), + )) +} + /// Returns the list of types assigned to the given generic in the conditions, /// ordered from least specific to most specific. fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition]) -> Vec { @@ -68,18 +187,6 @@ fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition] types } -pub struct TypeReplacer { - pub generic: String, - pub type_: Type, -} - -impl VisitMut for TypeReplacer { - fn visit_type_mut(&mut self, node: &mut Type) { - replace_type(node, &self.generic, &self.type_); - visit_mut::visit_type_mut(self, node); - } -} - pub fn apply_type_condition( target: &mut T, generics: &mut Generics, @@ -118,8 +225,8 @@ pub fn apply_type_condition( // replace generic type with type in the items let mut replacer = TypeReplacer { - generic: item_generic.clone(), - type_: new_type.clone(), + prev: item_generic.clone(), + new: new_type.clone(), }; target.handle_items_replace(&mut replacer); @@ -231,7 +338,10 @@ pub fn handle_generics(generics_str: &str, mut generics_fn: F) { #[cfg(test)] mod tests { use super::*; - use crate::conversions::{str_to_generics, to_string}; + use crate::{ + conversions::{str_to_generics, to_string}, + types::TypeReplacer, + }; use syn::{Generics, Type}; #[test] @@ -252,8 +362,8 @@ mod tests { #[test] fn type_replacer() { let mut replacer = TypeReplacer { - generic: "T".into(), - type_: str_to_type_name("u32"), + prev: "T".into(), + new: str_to_type_name("u32"), }; let mut type_ = str_to_type_name("Vec"); @@ -373,4 +483,72 @@ mod tests { let types_v = get_generic_types_from_conditions("V", &conditions); assert!(types_v.is_empty()); } + + #[test] + fn split_ref_type_with_lifetime() { + let (lt, inner) = split_ref_type("&'a i32").expect("expected reference type"); + assert_eq!(lt, Some("'a".to_string())); + assert_eq!(inner.replace(" ", ""), "i32"); + } + + #[test] + fn split_ref_type_non_reference() { + assert!(split_ref_type("i32").is_none()); + } + + #[test] + fn build_merged_ref_condition_basic() { + let conditions = vec![ + WhenCondition::Type("T".into(), "&'a _".into()), + WhenCondition::Type("T".into(), "&i32".into()), + ]; + + let merged = + build_merged_ref_condition(&conditions, 0, 1, "T").expect("expected merged condition"); + + assert_eq!(merged, WhenCondition::Type("T".into(), "& 'a i32".into())); + } + + #[test] + fn build_merged_ref_condition_invalid_indices() { + let conditions = vec![WhenCondition::Trait("T".into(), vec!["Copy".into()])]; + + let merged = build_merged_ref_condition(&conditions, 0, 0, "T"); + assert!(merged.is_none()); + } + + #[test] + fn merge_ref_and_lifetime_single_pair() { + let conditions = vec![ + WhenCondition::Type("T".into(), "&'a _".into()), + WhenCondition::Type("T".into(), "&str".into()), + ]; + + let res = merge_ref_and_lifetime(conditions); + + assert_eq!(res.len(), 1); + assert_eq!(res[0], WhenCondition::Type("T".into(), "& 'a str".into())); + } + + #[test] + fn merge_ref_and_lifetime_multiple_pairs() { + let conditions = vec![ + WhenCondition::Type("T".into(), "&'a _".into()), + WhenCondition::Type("T".into(), "&'b _".into()), + WhenCondition::Type("T".into(), "&Foo".into()), + WhenCondition::Type("T".into(), "&Bar".into()), + ]; + + let res = merge_ref_and_lifetime(conditions); + + assert_eq!(res.len(), 4); + + let expected = vec![ + WhenCondition::Type("T".into(), "&'a _".into()), + WhenCondition::Type("T".into(), "&'b _".into()), + WhenCondition::Type("T".into(), "&Foo".into()), + WhenCondition::Type("T".into(), "&Bar".into()), + ]; + assert_eq!(res, expected); + } } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs index b40eb30..0a57ebe 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs @@ -8,10 +8,10 @@ use crate::parsing::{ get_generics_lifetimes, get_generics_types, get_relevant_generics_names, parse_generics, }; use crate::specialize::{ - Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition, + Specializable, add_generic_lifetime, add_generic_type, apply_type_condition, get_assignable_conditions, get_used_generics, handle_generics, remove_generic, }; -use crate::types::{get_unique_generic_name, replace_type}; +use crate::types::{LifetimeReplacer, TypeReplacer, get_unique_generic_name, replace_type}; use proc_macro2::TokenStream; use quote::quote; use serde::{Deserialize, Serialize}; @@ -129,8 +129,8 @@ impl TraitBody { // set specialized trait name specialized.name = impl_body.specialized.as_ref().unwrap().trait_name.clone(); let mut replacer = TypeReplacer { - generic: self.name.clone(), - type_: str_to_type_name(&specialized.name), + prev: self.name.clone(), + new: str_to_type_name(&specialized.name), }; specialized.handle_items_replace(&mut replacer); @@ -236,8 +236,8 @@ impl TraitBody { let type_ = str_to_type_name(&new_generic_name); let mut replacer = TypeReplacer { - generic: generic.to_owned(), - type_, + prev: generic.to_owned(), + new: type_, }; self.handle_items_replace(&mut replacer); } @@ -249,10 +249,9 @@ impl TraitBody { add_generic_lifetime(&mut trait_generics, &new_generic_name); remove_generic(&mut trait_generics, &generic); - let type_ = str_to_type_name(&new_generic_name); - let mut replacer = TypeReplacer { - generic: generic.to_owned(), - type_, + let mut replacer = LifetimeReplacer { + prev: generic.to_owned(), + new: new_generic_name, }; self.handle_items_replace(&mut replacer); } @@ -260,28 +259,57 @@ impl TraitBody { self.generics = to_string(&trait_generics); } + /** + get the generics in the trait corresponding to the impl_generic in the impl. + In most cases there is a single corresponding generic; however, if the same + impl generic is used multiple times in the trait instantiation (e.g. + `trait Trait` with `impl Trait for ...`), this will return + all corresponding trait generics (e.g. `vec!["T", "U"]`). + + # Example: + for trait `TraitName` and impl `impl TraitName for MyType` + - impl_generic = T -> trait_generics = ["A", "C"] + - impl_generic = U -> trait_generics = ["B"] + - impl_generic = V -> [] + */ + pub fn get_corresponding_generics( + &self, + impl_generics: &Generics, + impl_generic: &str, + ) -> Vec { + let trait_generics = str_to_generics(&self.generics); + + let impl_names = get_relevant_generics_names(impl_generics, impl_generic); + let trait_names = get_relevant_generics_names(&trait_generics, impl_generic); + + let mut result = Vec::new(); + for (idx, name) in impl_names.iter().enumerate() { + if name == impl_generic + && let Some(trait_name) = trait_names.get(idx).cloned() + { + result.push(trait_name); + } + } + + result + } + /** get the generic in the trait corresponding to the impl_generic in the impl # Example: for trait `TraitName` and impl `impl TraitName for MyType` - impl_generic = T -> trait_generic = A - impl_generic = U -> trait_generic = B - - impl_generic = C -> None + - impl_generic = V -> None */ pub fn get_corresponding_generic( &self, impl_generics: &Generics, impl_generic: &str, ) -> Option { - let trait_generics = str_to_generics(&self.generics); - - let impl_generic_param = get_relevant_generics_names(impl_generics, impl_generic) - .iter() - .position(|param| param == impl_generic)?; - - get_relevant_generics_names(&trait_generics, impl_generic) - .get(impl_generic_param) - .cloned() + self.get_corresponding_generics(impl_generics, impl_generic) + .into_iter() + .next() } } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs b/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs index 5abe67d..3f1ee27 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs @@ -26,6 +26,15 @@ pub trait VisitTypeInDepth { } } + // `T: Trait` + fn visit_bound(&mut self, bound: &mut TypeParamBound) { + if let TypeParamBound::Trait(trait_bound) = bound { + for segment in &mut trait_bound.path.segments { + self.visit_segment(segment); + } + } + } + /// General Type visitor fn visit_type(&mut self, ty: &mut Type) { self.before_visit_type(ty); @@ -38,6 +47,7 @@ pub trait VisitTypeInDepth { Type::BareFn(type_bare_fn) => self.visit_type_function(type_bare_fn), Type::ImplTrait(type_impl_trait) => self.visit_type_impl_trait(type_impl_trait), Type::Path(type_path) => self.visit_type_path(type_path), + Type::TraitObject(type_trait_object) => self.visit_type_trait_object(type_trait_object), _ => self.visit_type_default(ty), } } @@ -91,11 +101,7 @@ pub trait VisitTypeInDepth { /// `impl T` fn visit_type_impl_trait(&mut self, ty: &mut TypeImplTrait) { for bound in &mut ty.bounds { - if let TypeParamBound::Trait(trait_bound) = bound { - for segment in &mut trait_bound.path.segments { - self.visit_segment(segment); - } - } + self.visit_bound(bound); } } @@ -105,4 +111,11 @@ pub trait VisitTypeInDepth { self.visit_segment(segment); } } + + /// `dyn T` + fn visit_type_trait_object(&mut self, ty: &mut syn::TypeTraitObject) { + for bound in &mut ty.bounds { + self.visit_bound(bound); + } + } } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/types.rs b/spec-trait-impl/crates/spec-trait-utils/src/types.rs index 1276b9f..a90517f 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/types.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/types.rs @@ -5,6 +5,7 @@ use crate::{ }; use proc_macro2::Span; use std::collections::{HashMap, HashSet}; +use syn::visit_mut::{self, VisitMut}; use syn::{ Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments, ReturnType, Type, TypeParamBound, TypeReference, @@ -39,7 +40,7 @@ fn resolve_type(ty: &Type, aliases: &Aliases) -> Type { ty } -type GenericsMap = HashMap>; +pub type GenericsMap = HashMap>; #[derive(Debug, Default)] pub struct ConstrainedGenerics { @@ -71,15 +72,24 @@ impl From for ConstrainedGenerics { } } +fn get_type_from_generic(generic: &str, aliases: &Aliases) -> Type { + let ty = if generic.starts_with('\'') { + format!("&{} _", generic) + } else { + generic.to_string() + }; + + str_to_type_name(&get_concrete_type(&ty, aliases)) +} + pub fn type_assignable_generic_constraints( concrete_type: &str, declared_or_concrete_type: &str, generics: &str, aliases: &Aliases, ) -> Option { - let concrete_type = str_to_type_name(&get_concrete_type(concrete_type, aliases)); - let declared_or_concrete_type = - str_to_type_name(&get_concrete_type(declared_or_concrete_type, aliases)); + let concrete_type = get_type_from_generic(concrete_type, aliases); + let declared_or_concrete_type = get_type_from_generic(declared_or_concrete_type, aliases); let generics = str_to_generics(generics); let mut generics = ConstrainedGenerics::from(generics); @@ -286,6 +296,41 @@ fn can_assign( }) } + // `dyn T` + (Type::TraitObject(to1), Type::TraitObject(to2)) => { + to1.bounds.len() == to2.bounds.len() + && to1 + .bounds + .iter() + .zip(&to2.bounds) + .all(|(b1, b2)| match (b1, b2) { + (TypeParamBound::Trait(tb1), TypeParamBound::Trait(tb2)) => { + tb1.path.segments.len() == tb2.path.segments.len() + && tb1.path.segments.iter().zip(&tb2.path.segments).all( + |(seg1, seg2)| match (&seg1.arguments, &seg2.arguments) { + ( + PathArguments::AngleBracketed(args1), + PathArguments::AngleBracketed(args2), + ) => { + args1.args.len() == args2.args.len() + && args1.args.iter().zip(&args2.args).all( + |(arg1, arg2)| match (arg1, arg2) { + ( + GenericArgument::Type(t1), + GenericArgument::Type(t2), + ) => can_assign(t1, t2, generics), + _ => false, + }, + ) + } + _ => seg1.arguments.is_empty() && seg2.arguments.is_empty(), + }, + ) + } + _ => false, + }) + } + _ => false, } } @@ -351,6 +396,10 @@ fn check_and_assign_lifetime_generic( } pub fn type_contains(ty: &Type, generic: &str) -> bool { + if generic.starts_with('\'') { + return type_contains_lifetime(ty, generic); + } + let mut type_ = ty.clone(); let replacement = str_to_type_name("__G__"); @@ -368,24 +417,32 @@ pub fn type_contains_lifetime(ty: &Type, lifetime: &str) -> bool { to_string(&type_) != to_string(ty) } -/// Replaces all occurrences of `prev` in the given type with `new`. -pub fn replace_type(ty: &mut Type, prev: &str, new: &Type) { - struct TypeReplacer { - prev: String, - new: Type, - } - impl VisitTypeInDepth for TypeReplacer { - fn visit_ident(&mut self, ident: &mut Ident) { - if ident == &self.prev { - *ident = Ident::new(&to_string(&self.new.clone()), Span::call_site()); - } +pub struct TypeReplacer { + pub prev: String, + pub new: Type, +} + +impl VisitMut for TypeReplacer { + fn visit_type_mut(&mut self, node: &mut Type) { + self.visit_type(node); + visit_mut::visit_type_mut(self, node); + } +} +impl VisitTypeInDepth for TypeReplacer { + fn visit_ident(&mut self, ident: &mut Ident) { + if ident == &self.prev { + *ident = Ident::new(&to_string(&self.new.clone()), Span::call_site()); } - fn before_visit_type(&mut self, ty: &mut Type) { - if to_string(ty) == to_string(&str_to_type_name(self.prev.as_str())) { - *ty = self.new.clone(); - } + } + fn before_visit_type(&mut self, ty: &mut Type) { + if to_string(ty) == to_string(&str_to_type_name(self.prev.as_str())) { + *ty = self.new.clone(); } } +} + +/// Replaces all occurrences of `prev` in the given type with `new`. +pub fn replace_type(ty: &mut Type, prev: &str, new: &Type) { TypeReplacer { prev: prev.to_string(), new: new.clone(), @@ -393,19 +450,27 @@ pub fn replace_type(ty: &mut Type, prev: &str, new: &Type) { .visit_type(ty); } -/// Replaces all occurrences of `prev` lifetime in the given type with `new`. -pub fn replace_lifetime(ty: &mut Type, prev: &str, new: &str) { - struct LifetimeReplacer { - prev: String, - new: String, - } - impl VisitTypeInDepth for LifetimeReplacer { - fn visit_lifetime(&mut self, lifetime: &mut Lifetime) { - if lifetime.to_string() == self.prev { - *lifetime = Lifetime::new(&self.new, Span::call_site()); - } +pub struct LifetimeReplacer { + pub prev: String, + pub new: String, +} + +impl VisitMut for LifetimeReplacer { + fn visit_type_mut(&mut self, node: &mut Type) { + self.visit_type(node); + visit_mut::visit_type_mut(self, node); + } +} +impl VisitTypeInDepth for LifetimeReplacer { + fn visit_lifetime(&mut self, lifetime: &mut Lifetime) { + if lifetime.to_string() == self.prev { + *lifetime = Lifetime::new(&self.new, Span::call_site()); } } +} + +/// Replaces all occurrences of `prev` lifetime in the given type with `new`. +pub fn replace_lifetime(ty: &mut Type, prev: &str, new: &str) { LifetimeReplacer { prev: prev.to_string(), new: new.to_string(), @@ -548,6 +613,31 @@ pub fn assign_lifetimes(t1: &mut Type, t2: &Type, generics: &mut ConstrainedGene } } + // `dyn T` + (Type::TraitObject(trait1), Type::TraitObject(trait2)) => { + for (bound1, bound2) in trait1.bounds.iter_mut().zip(&trait2.bounds) { + if let (TypeParamBound::Trait(trait1), TypeParamBound::Trait(trait2)) = + (bound1, bound2) + { + for (seg1, seg2) in trait1.path.segments.iter_mut().zip(&trait2.path.segments) { + if let ( + PathArguments::AngleBracketed(args1), + PathArguments::AngleBracketed(args2), + ) = (&mut seg1.arguments, &seg2.arguments) + { + for (arg1, arg2) in args1.args.iter_mut().zip(&args2.args) { + if let (GenericArgument::Type(t1), GenericArgument::Type(t2)) = + (arg1, arg2) + { + assign_lifetimes(t1, t2, generics); + } + } + } + } + } + } + } + _ => {} } } @@ -1060,6 +1150,20 @@ mod tests { assert!(can_assign(&t1, &t2, &mut g)); } + #[test] + fn compare_types_dyn_trait() { + let mut g = ConstrainedGenerics::default(); + + let t1 = str_to_type_name("dyn Into"); + let t2 = str_to_type_name("dyn Into"); + assert!(can_assign(&t1, &t2, &mut g)); + + g.types.insert("T".to_string(), None); + let t1 = str_to_type_name("dyn Into"); + let t2 = str_to_type_name("dyn Into"); + assert!(can_assign(&t1, &t2, &mut g)); + } + #[test] fn contains_type_true() { let types = vec![ @@ -1072,6 +1176,8 @@ mod tests { "Other", "T", "fn(T) -> Other", + "impl Trait", + "dyn Trait", ]; for ty in types { let type_ = str_to_type_name(ty); @@ -1091,6 +1197,8 @@ mod tests { "Other", "T", "fn(T) -> Other", + "impl Trait", + "dyn Trait", ]; for ty in types { let type_ = str_to_type_name(ty); @@ -1267,6 +1375,19 @@ mod tests { ); } + #[test] + fn replace_type_dyn_trait() { + let new_ty: Type = parse2(quote! { String }).unwrap(); + + let mut ty: Type = parse2(quote! { dyn Into }).unwrap(); + replace_type(&mut ty, "T", &new_ty); + + assert_eq!( + to_string(&ty).replace(" ", ""), + "dyn Into".to_string().replace(" ", "") + ); + } + #[test] fn replace_infers_simple() { let mut ty: Type = parse2(quote! { _ }).unwrap(); @@ -1427,6 +1548,22 @@ mod tests { ); } + #[test] + fn replace_infers_dyn_trait() { + let mut ty: Type = parse2(quote! { dyn Trait<_> }).unwrap(); + let mut generics = HashSet::new(); + let mut counter = 0; + let mut new_generics = vec![]; + + replace_infers(&mut ty, &mut generics, &mut counter, &mut new_generics); + + assert_eq!( + to_string(&ty).replace(" ", ""), + "dyn Trait<__G_0__>".to_string().replace(" ", "") + ); + assert_eq!(new_generics, vec!["__G_0__".to_string()]); + } + #[test] fn strip_lifetimes_simple() { let mut ty: Type = parse2(quote! { &'a u8 }).unwrap();