diff --git a/spec-trait-impl/crates/spec-trait-bin/src/main.rs b/spec-trait-impl/crates/spec-trait-bin/src/main.rs index d4fb9f0..0409e63 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/main.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/main.rs @@ -244,10 +244,10 @@ impl Foo for T { } } -#[when(all(not(T = i32), not(T = ZST)))] +#[when(all(not(T = i32), not(T = ZST), not(U = i32)))] impl Foo for T { fn foo(&self, _x: U) { - println!("Foo impl T where T is not i32 or ZST"); + println!("Foo impl T where T is not i32 or ZST and U is not i32"); } } @@ -326,7 +326,7 @@ fn main() { spec! { x.foo(1u8); Vec; [u8]; u8 = MyType } // -> "Foo impl T where T is Vec<_> and U is MyType" spec! { 1i32.foo("str"); i32; [&str] } // -> "Foo impl T where U is &str" spec! { zst.foo("str"); ZST; [&str] } // -> "Foo impl T where U is &str" - spec! { 1u8.foo(1u8); u8; [u8] } // -> "Foo impl T where T is not i32 or ZST" + spec! { 1u8.foo(1u8); u8; [u8] } // -> "Foo impl T where T is not i32 or ZST and U is not i32" println!(); // T - Foo4 diff --git a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs index 2b200a9..034f4c5 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs @@ -8,8 +8,8 @@ use crate::parsing::{ parse_generics, }; use crate::specialize::{ - Specializable, add_generic_lifetime, add_generic_type, apply_type_condition, - get_assignable_conditions, get_used_generics, remove_generic, + Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition, + get_assignable_conditions, get_used_generics, handle_generics, remove_generic, }; use crate::types::{replace_type, type_contains, type_contains_lifetime}; use proc_macro2::TokenStream; @@ -130,58 +130,65 @@ impl ImplBody { let mut new_impl = self.clone(); let mut specialized = new_impl.clone(); - // set specialized trait name - specialized.trait_name = specialized.get_spec_trait_name(); - // apply condition if let Some(condition) = &self.condition { specialized.apply_condition(condition); } - // set missing generics - let mut trait_generics = str_to_generics(&specialized.trait_generics); - let curr_generics_types = get_generics_types::>(&specialized.trait_generics); - let curr_generics_lifetimes = - get_generics_lifetimes::>(&specialized.trait_generics); - for generic in get_generics_types::>(&specialized.impl_generics) { - if !curr_generics_types.contains(&generic) { - add_generic_type(&mut trait_generics, &generic); + // fix generics + specialized.add_missing_generics(); + specialized.remove_unused_generics(); + + // set specialized trait name + specialized.trait_name = specialized.get_spec_trait_name(); + let mut replacer = TypeReplacer { + generic: self.trait_name.clone(), + type_: str_to_type_name(&specialized.trait_name), + }; + specialized.handle_items_replace(&mut replacer); + + // TODO: fix generics for impl Trait types + + // set specialized trait + new_impl.specialized = Some(Box::new(specialized)); + new_impl + } + + /// add missing generics from impl to trait + fn add_missing_generics(&mut self) { + let mut trait_generics = str_to_generics(&self.trait_generics); + let curr_generics_types = get_generics_types::>(&self.trait_generics); + let curr_generics_lifetimes = get_generics_lifetimes::>(&self.trait_generics); + handle_generics(&self.impl_generics, |generic| { + if generic.starts_with("'") && !curr_generics_lifetimes.contains(generic) { + add_generic_lifetime(&mut trait_generics, generic); } - } - for generic in get_generics_lifetimes::>(&specialized.impl_generics) { - if !curr_generics_lifetimes.contains(&generic) { - add_generic_lifetime(&mut trait_generics, &generic); + if !generic.starts_with("'") && !curr_generics_types.contains(generic) { + add_generic_type(&mut trait_generics, generic); } - } - specialized.trait_generics = to_string(&trait_generics); - - // clean unused generics - let used_generics = - get_used_generics(&specialized, &str_to_generics(&specialized.impl_generics)); - - let mut impl_generics = str_to_generics(&specialized.impl_generics); - let mut trait_generics = str_to_generics(&specialized.trait_generics); - for generic in get_generics_lifetimes::>(&specialized.impl_generics) { - if !used_generics.contains(&generic) { - remove_generic(&mut trait_generics, &generic); - if !type_contains_lifetime(&str_to_type_name(&specialized.type_name), &generic) { - remove_generic(&mut impl_generics, &generic); + }); + self.trait_generics = to_string(&trait_generics); + } + + /// remove unused generics from impl and trait + fn remove_unused_generics(&mut self) { + let used_generics = get_used_generics(self, &str_to_generics(&self.impl_generics)); + let mut impl_generics = str_to_generics(&self.impl_generics); + let mut trait_generics = str_to_generics(&self.trait_generics); + handle_generics(&self.impl_generics, |generic| { + if !used_generics.contains(generic) { + remove_generic(&mut trait_generics, generic); + let ty = str_to_type_name(&self.type_name); + if !generic.starts_with("'") && !type_contains(&ty, generic) { + remove_generic(&mut impl_generics, generic); } - } - } - for generic in get_generics_types::>(&specialized.impl_generics) { - if !used_generics.contains(&generic) { - remove_generic(&mut trait_generics, &generic); - if !type_contains(&str_to_type_name(&specialized.type_name), &generic) { - remove_generic(&mut impl_generics, &generic); + if generic.starts_with("'") && !type_contains_lifetime(&ty, generic) { + remove_generic(&mut impl_generics, generic); } } - } - specialized.impl_generics = to_string(&impl_generics); - specialized.trait_generics = to_string(&trait_generics); - - new_impl.specialized = Some(Box::new(specialized)); - new_impl + }); + self.impl_generics = to_string(&impl_generics); + self.trait_generics = to_string(&trait_generics); } /// apply a condition to the impl body, modifying its generics and items diff --git a/spec-trait-impl/crates/spec-trait-utils/src/lib.rs b/spec-trait-impl/crates/spec-trait-utils/src/lib.rs index 459530f..b10677d 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/lib.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/lib.rs @@ -6,4 +6,5 @@ pub mod impls; pub mod parsing; mod specialize; pub mod traits; +mod type_visitor; pub mod types; diff --git a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs index 61a5e2c..99f0885 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use crate::conditions::WhenCondition; -use crate::conversions::{str_to_lifetime, str_to_type_name}; +use crate::conversions::{str_to_generics, str_to_lifetime, str_to_type_name}; use crate::types::{ Aliases, replace_infers, replace_type, type_assignable, type_contains, type_contains_lifetime, }; @@ -101,11 +101,13 @@ pub fn apply_type_condition( for generic in new_generics { add_generic_type(generics, &generic); add_generic_type(other_generics, &generic); + // TODO: add generic to impl Trait types } // remove generic type remove_generic(generics, &item_generic); remove_generic(other_generics, impl_generic); + // TODO: remove generic from impl Trait types // replace generic type with type in the items let mut replacer = TypeReplacer { @@ -209,6 +211,16 @@ pub fn get_used_generics(target: &T, generics: &Generics) -> H visitor.used_generics } +pub fn handle_generics(generics_str: &str, mut generics_fn: F) { + let generics = str_to_generics(generics_str); + for g in collect_generics_lifetimes::>(&generics) { + generics_fn(&g); + } + for g in collect_generics_types::>(&generics) { + generics_fn(&g); + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs index 9978afb..45ee2a4 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs @@ -9,7 +9,7 @@ use crate::parsing::{ }; use crate::specialize::{ Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition, - get_assignable_conditions, get_used_generics, remove_generic, + get_assignable_conditions, get_used_generics, handle_generics, remove_generic, }; use crate::types::get_unique_generic_name; use proc_macro2::TokenStream; @@ -110,25 +110,11 @@ impl TraitBody { let mut new_trait = self.clone(); let mut specialized = new_trait.clone(); - // set specialized trait name - specialized.name = impl_body.specialized.as_ref().unwrap().trait_name.clone(); - // replace generics with unique generic name specialized.replace_generics_names(); - // set missing generic lifetimes - let mut generics = str_to_generics(&specialized.generics); - let impl_generics = &impl_body.specialized.as_ref().unwrap().trait_generics; - let specialized_impl_generics = str_to_generics(impl_generics); - for generic in get_generics_lifetimes::>(impl_generics) { - if specialized - .get_corresponding_generic(&specialized_impl_generics, &generic) - .is_none() - { - add_generic_lifetime(&mut generics, &generic); - } - } - specialized.generics = to_string(&generics); + // add missing generics from impl + specialized.add_missing_generics(impl_body); // apply condition if let Some(condition) = &impl_body.condition { @@ -136,39 +122,57 @@ impl TraitBody { specialized.apply_condition(&mut impl_generics, condition); } - // set missing generic types - let mut generics = str_to_generics(&specialized.generics); + // fix generics + specialized.add_missing_generics(impl_body); + specialized.remove_unused_generics(); + + // set specialized trait name + specialized.name = impl_body.specialized.as_ref().unwrap().trait_name.clone(); + let mut replacer = TypeReplacer { + generic: self.name.clone(), + type_: str_to_type_name(&specialized.name), + }; + specialized.handle_items_replace(&mut replacer); + + // TODO: fix generics for impl Trait types + + // set specialized trait + new_trait.specialized = Some(Box::new(specialized)); + new_trait + } + + /// add missing generics from impl to trait + fn add_missing_generics(&mut self, impl_body: &ImplBody) { + let mut generics = str_to_generics(&self.generics); let impl_generics = &impl_body.specialized.as_ref().unwrap().trait_generics; let specialized_impl_generics = str_to_generics(impl_generics); - for generic in get_generics_types::>(impl_generics) { - if specialized - .get_corresponding_generic(&specialized_impl_generics, &generic) + handle_generics(impl_generics, |generic| { + if self + .get_corresponding_generic(&specialized_impl_generics, generic) .is_none() { - add_generic_type(&mut generics, &generic); + if generic.starts_with("'") { + add_generic_lifetime(&mut generics, generic); + } else { + add_generic_type(&mut generics, generic); + } } - } - specialized.generics = to_string(&generics); + }); + self.generics = to_string(&generics); + self.generics = to_string(&generics); + } + /// remove unused generics from trait + fn remove_unused_generics(&mut self) { // clean unused generics - let used_generics = - get_used_generics(&specialized, &str_to_generics(&specialized.generics)); - - let mut generics = str_to_generics(&specialized.generics); - for generic in get_generics_lifetimes::>(&specialized.generics) { - if !used_generics.contains(&generic) { - remove_generic(&mut generics, &generic); - } - } - for generic in get_generics_types::>(&specialized.generics) { - if !used_generics.contains(&generic) { - remove_generic(&mut generics, &generic); + let used_generics = get_used_generics(self, &str_to_generics(&self.generics)); + let mut generics = str_to_generics(&self.generics); + handle_generics(&self.generics, |generic| { + if !used_generics.contains(generic) { + remove_generic(&mut generics, generic); } - } - specialized.generics = to_string(&generics); - - new_trait.specialized = Some(Box::new(specialized)); - new_trait + }); + self.generics = to_string(&generics); } /// apply a condition to the trait body, modifying its generics and items diff --git a/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs b/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs new file mode 100644 index 0000000..5abe67d --- /dev/null +++ b/spec-trait-impl/crates/spec-trait-utils/src/type_visitor.rs @@ -0,0 +1,108 @@ +use syn::{ + GenericArgument, Ident, Lifetime, PathArguments, PathSegment, ReturnType, Type, TypeArray, + TypeBareFn, TypeImplTrait, TypeParamBound, TypeParen, TypePath, TypeReference, TypeSlice, + TypeTuple, +}; + +pub trait VisitTypeInDepth { + /// `T` + fn visit_ident(&mut self, _ident: &mut Ident) {} + + /// `'a` + fn visit_lifetime(&mut self, _lifetime: &mut Lifetime) {} + + /// `T` + fn visit_segment(&mut self, segment: &mut PathSegment) { + // `T` + self.visit_ident(&mut segment.ident); + + // `` + if let PathArguments::AngleBracketed(angle_bracketed) = &mut segment.arguments { + for arg in &mut angle_bracketed.args { + if let GenericArgument::Type(ty) = arg { + self.visit_type(ty); + } + } + } + } + + /// General Type visitor + fn visit_type(&mut self, ty: &mut Type) { + self.before_visit_type(ty); + match ty { + Type::Tuple(type_tuple) => self.visit_type_tuple(type_tuple), + Type::Reference(type_reference) => self.visit_type_reference(type_reference), + Type::Array(type_array) => self.visit_type_array(type_array), + Type::Slice(type_slice) => self.visit_type_slice(type_slice), + Type::Paren(type_paren) => self.visit_type_parentheses(type_paren), + Type::BareFn(type_bare_fn) => self.visit_type_function(type_bare_fn), + Type::ImplTrait(type_impl_trait) => self.visit_type_impl_trait(type_impl_trait), + Type::Path(type_path) => self.visit_type_path(type_path), + _ => self.visit_type_default(ty), + } + } + + /// Common visitor called before specific type visitors + fn before_visit_type(&mut self, _ty: &mut Type) {} + + /// Default visitor for types not explicitly handled + fn visit_type_default(&mut self, _ty: &mut Type) {} + + /// `(T, U)` + fn visit_type_tuple(&mut self, ty: &mut TypeTuple) { + for element in &mut ty.elems { + self.visit_type(element); + } + } + + /// `&T` + fn visit_type_reference(&mut self, ty: &mut TypeReference) { + if let Some(lifetime) = &mut ty.lifetime { + self.visit_lifetime(lifetime); + } + self.visit_type(&mut ty.elem); + } + + /// `[T; N]` + fn visit_type_array(&mut self, ty: &mut TypeArray) { + self.visit_type(&mut ty.elem); + } + + /// `[T]` + fn visit_type_slice(&mut self, ty: &mut TypeSlice) { + self.visit_type(&mut ty.elem); + } + + /// `(T)` + fn visit_type_parentheses(&mut self, ty: &mut TypeParen) { + self.visit_type(&mut ty.elem); + } + + /// `fn(T) -> U` + fn visit_type_function(&mut self, ty: &mut TypeBareFn) { + for input in &mut ty.inputs { + self.visit_type(&mut input.ty); + } + if let ReturnType::Type(_, ty) = &mut ty.output { + self.visit_type(ty); + } + } + + /// `impl T` + fn visit_type_impl_trait(&mut self, ty: &mut TypeImplTrait) { + for bound in &mut ty.bounds { + if let TypeParamBound::Trait(trait_bound) = bound { + for segment in &mut trait_bound.path.segments { + self.visit_segment(segment); + } + } + } + } + + /// `T, T` + fn visit_type_path(&mut self, ty: &mut TypePath) { + for segment in &mut ty.path.segments { + self.visit_segment(segment); + } + } +} diff --git a/spec-trait-impl/crates/spec-trait-utils/src/types.rs b/spec-trait-impl/crates/spec-trait-utils/src/types.rs index f79a37c..44c5f02 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/types.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/types.rs @@ -1,12 +1,13 @@ use crate::{ conversions::{str_to_generics, str_to_lifetime, str_to_type_name, to_string}, specialize::collect_generics_lifetimes, + type_visitor::VisitTypeInDepth, }; use proc_macro2::Span; use std::collections::{HashMap, HashSet}; use syn::{ - Expr, GenericArgument, GenericParam, Generics, Ident, PathArguments, ReturnType, Type, - TypeArray, TypeReference, TypeSlice, TypeTuple, + Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments, ReturnType, + Type, TypeParamBound, TypeReference, }; pub type Aliases = HashMap>; @@ -18,88 +19,24 @@ pub fn get_concrete_type(type_or_alias: &str, aliases: &Aliases) -> String { } fn resolve_type(ty: &Type, aliases: &Aliases) -> Type { - match unwrap_paren(ty) { - // (T, U) - Type::Tuple(tuple) => { - let resolved_elems = tuple - .elems + struct TypeResolver<'a> { + aliases: &'a Aliases, + } + impl<'a> VisitTypeInDepth for TypeResolver<'a> { + fn before_visit_type(&mut self, ty: &mut Type) { + *ty = unwrap_paren(ty).clone(); + if let Some((k, _)) = self + .aliases .iter() - .map(|elem| resolve_type(elem, aliases)) - .collect(); - Type::Tuple(TypeTuple { - elems: resolved_elems, - ..tuple.clone() - }) - } - - // &T - Type::Reference(reference) => { - let resolved_elem = resolve_type(&reference.elem, aliases); - Type::Reference(TypeReference { - elem: Box::new(resolved_elem), - ..reference.clone() - }) - } - - // [T; N] - Type::Array(array) => { - let resolved_elem = resolve_type(&array.elem, aliases); - Type::Array(TypeArray { - elem: Box::new(resolved_elem), - ..array.clone() - }) - } - - // [T] - Type::Slice(slice) => { - let resolved_elem = resolve_type(&slice.elem, aliases); - Type::Slice(TypeSlice { - elem: Box::new(resolved_elem), - ..slice.clone() - }) - } - - // fn(T) -> U - Type::BareFn(bare_fn) => { - let mut resolved = bare_fn.clone(); - - for input in &mut resolved.inputs { - input.ty = resolve_type(&input.ty, aliases); - } - - if let ReturnType::Type(arrow, ty) = &resolved.output { - let new_ty = resolve_type(ty, aliases); - resolved.output = ReturnType::Type(*arrow, Box::new(new_ty)); - } - - Type::BareFn(resolved) - } - - // T, T - Type::Path(type_path) if type_path.qself.is_none() => { - let mut resolved_path = type_path.clone(); - - let ident = type_path.path.segments.last().unwrap().ident.to_string(); - if let Some((k, _)) = aliases.iter().find(|(_, v)| v.contains(&ident)) { - return str_to_type_name(k); - } - - for segment in &mut resolved_path.path.segments { - if let PathArguments::AngleBracketed(args) = &mut segment.arguments { - for arg in &mut args.args { - if let GenericArgument::Type(inner_ty) = arg { - *inner_ty = resolve_type(inner_ty, aliases); - } - } - } + .find(|(_, v)| v.contains(&to_string(ty))) + { + *ty = str_to_type_name(k); } - - Type::Path(resolved_path) } - - // Default case: return the type as-is - _ => ty.clone(), } + let mut ty = ty.clone(); + TypeResolver { aliases }.visit_type(&mut ty); + ty } type GenericsMap = HashMap>; @@ -288,6 +225,41 @@ fn can_assign( }) } + // impl T + (Type::ImplTrait(tr1), Type::ImplTrait(tr2)) => { + tr1.bounds.len() == tr2.bounds.len() + && tr1 + .bounds + .iter() + .zip(&tr2.bounds) + .all(|(b1, b2)| match (b1, b2) { + (TypeParamBound::Trait(tb1), TypeParamBound::Trait(tb2)) => { + tb1.path.segments.len() == tb2.path.segments.len() + && tb1.path.segments.iter().zip(&tb2.path.segments).all( + |(seg1, seg2)| match (&seg1.arguments, &seg2.arguments) { + ( + PathArguments::AngleBracketed(args1), + PathArguments::AngleBracketed(args2), + ) => { + args1.args.len() == args2.args.len() + && args1.args.iter().zip(&args2.args).all( + |(arg1, arg2)| match (arg1, arg2) { + ( + GenericArgument::Type(t1), + GenericArgument::Type(t2), + ) => can_assign(t1, t2, generics), + _ => false, + }, + ) + } + _ => seg1.arguments.is_empty() && seg2.arguments.is_empty(), + }, + ) + } + _ => false, + }) + } + _ => false, } } @@ -371,181 +343,71 @@ pub fn type_contains_lifetime(ty: &Type, lifetime: &str) -> bool { /// Replaces all occurrences of `prev` in the given type with `new`. pub fn replace_type(ty: &mut Type, prev: &str, new: &Type) { - if to_string(ty) == to_string(&str_to_type_name(prev)) { - *ty = new.clone(); - return; + struct TypeReplacer { + prev: String, + new: Type, } - - match ty { - // (T, U) - Type::Tuple(t) => { - for elem in &mut t.elems { - replace_type(elem, prev, new); - } - } - - // &T - Type::Reference(r) => replace_type(&mut r.elem, prev, new), - - // [T; N] - Type::Array(a) => replace_type(&mut a.elem, prev, new), - - // [T] - Type::Slice(s) => replace_type(&mut s.elem, prev, new), - - // (T) - Type::Paren(s) => replace_type(&mut s.elem, prev, new), - - // _ - Type::Infer(_) if prev == "_" => { - *ty = new.clone(); - } - - // T, T - Type::Path(type_path) => { - // T - if type_path.qself.is_none() - && type_path.path.segments.len() == 1 - && type_path.path.segments[0].ident == prev - && type_path.path.segments[0].arguments.is_empty() - { - *ty = new.clone(); - return; - } - - // T - for seg in &mut type_path.path.segments { - // T - if seg.ident == prev { - seg.ident = Ident::new(&to_string(&new.clone()), Span::call_site()); - } - - // - if let PathArguments::AngleBracketed(ref mut ab) = seg.arguments { - for arg in ab.args.iter_mut() { - if let GenericArgument::Type(inner_ty) = arg { - replace_type(inner_ty, prev, new); - } - } - } + impl VisitTypeInDepth for TypeReplacer { + fn visit_ident(&mut self, ident: &mut Ident) { + if ident == &self.prev { + *ident = Ident::new(&to_string(&self.new.clone()), Span::call_site()); } } - - // fn(T) -> U - Type::BareFn(bare_fn) => { - for input in &mut bare_fn.inputs { - replace_type(&mut input.ty, prev, new); - } - - if let ReturnType::Type(_, ty) = &mut bare_fn.output { - replace_type(ty, prev, new); + fn before_visit_type(&mut self, ty: &mut Type) { + if to_string(ty) == to_string(&str_to_type_name(self.prev.as_str())) { + *ty = self.new.clone(); } } - - _ => {} } + TypeReplacer { + prev: prev.to_string(), + new: new.clone(), + } + .visit_type(ty); } /// Replaces all occurrences of `prev` lifetime in the given type with `new`. pub fn replace_lifetime(ty: &mut Type, prev: &str, new: &str) { - match ty { - Type::Reference(r) => { - if r.lifetime.as_ref().is_some_and(|l| l.to_string() == prev) { - r.lifetime = Some(str_to_lifetime(new)); - } - replace_lifetime(&mut r.elem, prev, new); - } - Type::Tuple(t) => { - for elem in &mut t.elems { - replace_lifetime(elem, prev, new); - } - } - Type::Array(a) => replace_lifetime(&mut a.elem, prev, new), - Type::Slice(s) => replace_lifetime(&mut s.elem, prev, new), - Type::Paren(p) => replace_lifetime(&mut p.elem, prev, new), - Type::Path(type_path) => { - for seg in &mut type_path.path.segments { - if let PathArguments::AngleBracketed(ref mut ab) = seg.arguments { - for arg in ab.args.iter_mut() { - if let GenericArgument::Type(inner_ty) = arg { - replace_lifetime(inner_ty, prev, new); - } - } - } - } - } - Type::BareFn(bare_fn) => { - for input in &mut bare_fn.inputs { - replace_lifetime(&mut input.ty, prev, new); - } - - if let ReturnType::Type(_, ty) = &mut bare_fn.output { - replace_lifetime(ty, prev, new); + struct LifetimeReplacer { + prev: String, + new: String, + } + impl VisitTypeInDepth for LifetimeReplacer { + fn visit_lifetime(&mut self, lifetime: &mut Lifetime) { + if lifetime.to_string() == self.prev { + *lifetime = Lifetime::new(&self.new, Span::call_site()); } } - _ => {} } + LifetimeReplacer { + prev: prev.to_string(), + new: new.to_string(), + } + .visit_type(ty); } /// removes all lifetimes present in generics pub fn strip_lifetimes(ty: &mut Type, generics: &Generics) { - match ty { - // (T, U) - Type::Tuple(t) => { - for elem in &mut t.elems { - strip_lifetimes(elem, generics); - } - } - - // &T - Type::Reference(r) => { - let generics_lifetimes = collect_generics_lifetimes::>(generics); - - if r.lifetime + struct LifetimeStripper { + generics: HashSet, + } + impl VisitTypeInDepth for LifetimeStripper { + fn visit_type_reference(&mut self, ty: &mut TypeReference) { + if ty + .lifetime .as_ref() - .is_some_and(|l| generics_lifetimes.contains(&l.to_string())) + .is_some_and(|l| self.generics.contains(&l.to_string())) { - r.lifetime = None; + ty.lifetime = None; } - strip_lifetimes(&mut r.elem, generics); + self.visit_type(&mut ty.elem); } - - // [T; N] - Type::Array(a) => strip_lifetimes(&mut a.elem, generics), - - // [T] - Type::Slice(s) => strip_lifetimes(&mut s.elem, generics), - - // (T) - Type::Paren(s) => strip_lifetimes(&mut s.elem, generics), - - // T, T - Type::Path(type_path) => { - for seg in &mut type_path.path.segments { - if let PathArguments::AngleBracketed(ref mut ab) = seg.arguments { - for arg in ab.args.iter_mut() { - if let GenericArgument::Type(inner_ty) = arg { - strip_lifetimes(inner_ty, generics); - } - } - } - } - } - - // fn(T) -> U - Type::BareFn(bare_fn) => { - for input in &mut bare_fn.inputs { - strip_lifetimes(&mut input.ty, generics); - } - - if let ReturnType::Type(_, ty) = &mut bare_fn.output { - strip_lifetimes(ty, generics); - } - } - - _ => {} } + LifetimeStripper { + generics: collect_generics_lifetimes::>(generics), + } + .visit_type(ty); } /// replaces all lifetimes with the most specific one in two types @@ -634,11 +496,35 @@ pub fn assign_lifetimes(t1: &mut Type, t2: &Type, generics: &mut ConstrainedGene } } + // impl T + (Type::ImplTrait(impl1), Type::ImplTrait(impl2)) => { + for (bound1, bound2) in impl1.bounds.iter_mut().zip(&impl2.bounds) { + if let (TypeParamBound::Trait(trait1), TypeParamBound::Trait(trait2)) = + (bound1, bound2) + { + for (seg1, seg2) in trait1.path.segments.iter_mut().zip(&trait2.path.segments) { + if let ( + PathArguments::AngleBracketed(args1), + PathArguments::AngleBracketed(args2), + ) = (&mut seg1.arguments, &seg2.arguments) + { + for (arg1, arg2) in args1.args.iter_mut().zip(&args2.args) { + if let (GenericArgument::Type(t1), GenericArgument::Type(t2)) = + (arg1, arg2) + { + assign_lifetimes(t1, t2, generics); + } + } + } + } + } + } + } + _ => {} } } -// TODO: use replace_type to simplify this function /// Replaces all occurrences of `_` (inferred types) in the given type with fresh generic type parameters. pub fn replace_infers( ty: &mut Type, @@ -646,59 +532,26 @@ pub fn replace_infers( counter: &mut usize, new_generics: &mut Vec, ) { - match ty { - // (T, U, _) - Type::Tuple(t) => { - for elem in &mut t.elems { - replace_infers(elem, generics, counter, new_generics); - } - } - - // &_ - Type::Reference(r) => replace_infers(&mut r.elem, generics, counter, new_generics), - - // [_; N] - Type::Array(a) => replace_infers(&mut a.elem, generics, counter, new_generics), - - // [_] - Type::Slice(s) => replace_infers(&mut s.elem, generics, counter, new_generics), - - // (_) - Type::Paren(p) => replace_infers(&mut p.elem, generics, counter, new_generics), - - // T<_> - Type::Path(type_path) => { - for seg in &mut type_path.path.segments { - if let PathArguments::AngleBracketed(ref mut ab) = seg.arguments { - for arg in ab.args.iter_mut() { - if let GenericArgument::Type(inner_ty) = arg { - replace_infers(inner_ty, generics, counter, new_generics); - } - } - } - } - } - - // fn(T) -> U - Type::BareFn(bare_fn) => { - for input in &mut bare_fn.inputs { - replace_infers(&mut input.ty, generics, counter, new_generics); - } - - if let ReturnType::Type(_, ty) = &mut bare_fn.output { - replace_infers(ty, generics, counter, new_generics); + struct InferReplacer<'a> { + generics: &'a mut HashSet, + counter: &'a mut usize, + new_generics: &'a mut Vec, + } + impl VisitTypeInDepth for InferReplacer<'_> { + fn visit_type_default(&mut self, ty: &mut Type) { + if matches!(ty, Type::Infer(_)) { + let name = get_unique_generic_name(self.generics, self.counter, None); + *ty = str_to_type_name(&name); + self.new_generics.push(name); } } - - // _ - Type::Infer(_) => { - let name = get_unique_generic_name(generics, counter, None); - *ty = str_to_type_name(&name); - new_generics.push(name); - } - - _ => {} } + InferReplacer { + generics, + counter, + new_generics, + } + .visit_type(ty); } pub fn get_unique_generic_name( @@ -788,6 +641,23 @@ mod tests { ); } + #[test] + fn resolve_type_fn() { + let ty = str_to_type_name("fn(MyType) -> MyType"); + let resolved = resolve_type(&ty, &get_aliases()); + assert_eq!(to_string(&resolved).replace(" ", ""), "fn(u8)->u8"); + } + + #[test] + fn resolve_type_impl_trait() { + let ty = str_to_type_name("impl Into"); + let resolved = resolve_type(&ty, &get_aliases()); + assert_eq!( + to_string(&resolved).replace(" ", ""), + "impl Into".replace(" ", "") + ); + } + #[test] fn compare_types_simple() { let mut g = ConstrainedGenerics::default(); @@ -1107,6 +977,20 @@ mod tests { assert!(!can_assign(&t1, &t2, &mut g)); } + #[test] + fn compare_types_impl_trait() { + let mut g = ConstrainedGenerics::default(); + + let t1 = str_to_type_name("impl Into"); + let t2 = str_to_type_name("impl Into"); + assert!(can_assign(&t1, &t2, &mut g)); + + g.types.insert("T".to_string(), None); + let t1 = str_to_type_name("impl Into"); + let t2 = str_to_type_name("impl Into"); + assert!(can_assign(&t1, &t2, &mut g)); + } + #[test] fn contains_type_true() { let types = vec![ @@ -1301,6 +1185,19 @@ mod tests { assert_eq!(to_string(&ty).replace(" ", ""), "String".to_string()); } + #[test] + fn replace_type_impl_trait() { + let new_ty: Type = parse2(quote! { String }).unwrap(); + + let mut ty: Type = parse2(quote! { impl Into }).unwrap(); + replace_type(&mut ty, "T", &new_ty); + + assert_eq!( + to_string(&ty).replace(" ", ""), + "impl Into".to_string().replace(" ", "") + ); + } + #[test] fn replace_infers_simple() { let mut ty: Type = parse2(quote! { _ }).unwrap();