Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,33 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
.next()
.is_some()
{
let pred = refine::user_defined_pred(self.tcx, def_id);
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let generic_args = self.typeck.node_args(func_expr.hir_id);
tracing::debug!(
lhs = ?def_id,
lhs_generic_args = ?generic_args,
outer = ?self.local_def_id,
outer_generic_args = ?self.generic_args,
"resolving predicate call in formula"
);
let generic_args = mir_ty::EarlyBinder::bind(generic_args)
.instantiate(self.tcx, self.generic_args);
let instance = mir_ty::Instance::resolve(
self.tcx,
param_env,
def_id,
generic_args,
)
.unwrap();
let pred_def_id = if let Some(instance) = instance {
instance.def_id()
} else {
def_id
};
let pred = refine::user_defined_pred(self.tcx, pred_def_id);
let arg_terms = args.iter().map(|e| self.to_term(e)).collect();
let atom = chc::Atom::new(pred.into(), arg_terms);
return FormulaOrTerm::Formula(chc::Formula::Atom(atom));
Expand Down
95 changes: 57 additions & 38 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,12 +420,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
_ty,
) => {
let func_ty = match operand.const_fn_def() {
Some((def_id, args)) => self
.ctx
.def_ty_with_args(def_id, args)
.expect("unknown def")
.ty
.clone(),
Some((def_id, args)) => self.fn_def_ty(def_id, args),
_ => unimplemented!(),
};
PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null())
Expand Down Expand Up @@ -573,44 +568,68 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
});
}

fn resolve_fn_def(
&self,
def_id: DefId,
args: mir_ty::GenericArgsRef<'tcx>,
) -> (DefId, mir_ty::GenericArgsRef<'tcx>) {
if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
(*closure_def_id, args)
} else {
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
(instance.def_id(), instance.args)
} else {
(def_id, args)
}
}
}

fn fn_def_ty(
&mut self,
def_id: DefId,
args: mir_ty::GenericArgsRef<'tcx>,
) -> rty::Type<rty::Closed> {
if let Some(def_ty) = self.ctx.def_ty_with_args(def_id, args) {
return def_ty.ty;
}

let (resolved_def_id, resolved_args) = self.resolve_fn_def(def_id, args);
if resolved_def_id == def_id {
panic!(
"unknown def (and not resolved): {:?}, args: {:?}",
def_id, args
);
}
tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved");
let Some(def_ty) = self.ctx.def_ty_with_args(resolved_def_id, resolved_args) else {
panic!(
"unknown def (resolved): {:?}, args: {:?}",
resolved_def_id, resolved_args
);
};
def_ty.ty
}

fn type_call<I>(&mut self, func: Operand<'tcx>, args: I, expected_ret: &rty::RefinedType<Var>)
where
I: IntoIterator<Item = Operand<'tcx>>,
{
// TODO: handle const_fn_def on Env side
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
let (resolved_def_id, resolved_args) = if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
(*closure_def_id, args)
} else {
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
(instance.def_id(), instance.args)
} else {
(def_id, args)
}
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved");
}

self.ctx
.def_ty_with_args(resolved_def_id, resolved_args)
.expect("unknown def")
.ty
.vacuous()
self.fn_def_ty(def_id, args).vacuous()
} else {
self.operand_type(func.clone()).ty
};
Expand Down
2 changes: 1 addition & 1 deletion src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}

if analyzer.is_annotated_as_predicate() {
analyzer.analyze_predicate_definition(local_def_id);
analyzer.analyze_predicate_definition();
self.skip_analysis.insert(local_def_id);
return;
}
Expand Down
80 changes: 60 additions & 20 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,33 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
Some(self_ty)
}

pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
// predicate's name
// TODO: simply use refine::user_defined_pred for all functions
// after we dropped old annotation parser impl
let impl_type = self.impl_type();
let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();
let pred = match impl_type {
Some(t) => chc::UserDefinedPred::new(t.to_string() + "_" + &pred_item_name),
None => refine::user_defined_pred(self.tcx, local_def_id.to_def_id()),
};
pub fn analyze_predicate_definition(&self) {
self.define_as_predicate(refine::user_defined_pred(
self.tcx,
self.local_def_id.to_def_id(),
));

// For thrust::{requires,ensures} annotations which does not know DefId of the predicate
// during parsing, we also define a predicate with a name based on the self type name
//
// TODO: remove this after we dropped old annotation parser impl
// (then move this to crate_::Analyzer)
let pred_item_name = self
.tcx
.item_name(self.local_def_id.to_def_id())
.to_string();
if let Some(self_ty) = self.impl_type() {
let name = chc::UserDefinedPred::new(self_ty.to_string() + "_" + &pred_item_name);
self.define_as_predicate(name);
}
}

fn define_as_predicate(&self, pred: chc::UserDefinedPred) {
// function's body
use rustc_hir::{Block, Expr, ExprKind};

let hir_map = self.tcx.hir();
let body_id = hir_map.maybe_body_owned_by(local_def_id).unwrap();
let body_id = hir_map.maybe_body_owned_by(self.local_def_id).unwrap();
let hir_body = hir_map.body(body_id);

let predicate_body = match hir_body.value {
Expand All @@ -147,11 +158,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
// names and sorts of arguments
let arg_names = self
.tcx
.fn_arg_names(local_def_id.to_def_id())
.fn_arg_names(self.local_def_id.to_def_id())
.iter()
.map(|ident| ident.to_string());

let sig = self.ctx.fn_sig(local_def_id.to_def_id());
let sig = self.ctx.fn_sig(self.local_def_id.to_def_id());
let arg_sorts = sig
.inputs()
.iter()
Expand Down Expand Up @@ -276,17 +287,41 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
|| (all_params_annotated && has_ret)
}

pub fn trait_item_id(&self) -> Option<LocalDefId> {
pub fn local_trait_item_id(&self) -> Option<LocalDefId> {
let impl_item_assoc = self
.tcx
.opt_associated_item(self.local_def_id.to_def_id())?;
let trait_item_id = impl_item_assoc
.trait_item_def_id
.and_then(|id| id.as_local())?;

if trait_item_id == self.local_def_id {
return None;
}

Some(trait_item_id)
}

pub fn trait_item_ty(&mut self) -> Option<rty::RefinedType> {
let impl_did = self.tcx.parent(self.local_def_id.to_def_id());

if self.tcx.def_kind(impl_did) != (rustc_hir::def::DefKind::Impl { of_trait: true }) {
return None;
}

let trait_ref = self.tcx.impl_trait_ref(impl_did)?.instantiate_identity();
let trait_item_did = self
.tcx
.associated_item(self.local_def_id.to_def_id())
.trait_item_def_id
.unwrap();
self.ctx.def_ty_with_args(trait_item_did, trait_ref.args)
}

// Note that we do not expect predicate variables to be generated here
// when type params are still present in the type. Callers should ensure either
// - type params are fully instantiated, or
// - the function is fully annotated
pub fn expected_ty(&mut self) -> rty::RefinedType {
let sig = self
.ctx
Expand Down Expand Up @@ -324,7 +359,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.generic_args,
);

if let Some(trait_item_id) = self.trait_item_id() {
if let Some(trait_item_id) = self.local_trait_item_id() {
tracing::info!("trait item found: {:?}", trait_item_id);
let trait_require_annot = self.ctx.extract_require_annot(
trait_item_id,
Expand Down Expand Up @@ -364,6 +399,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
assert!(require_annot.is_none() || param_annots.is_empty());
assert!(ensure_annot.is_none() || ret_annot.is_none());

let trait_item_ty = self.trait_item_ty();
let is_fully_annotated = self.is_fully_annotated();

let mut builder = self.type_builder.for_function_template(&mut self.ctx, sig);
if let Some(AnnotFormula::Formula(require)) = require_annot {
let formula = require.map_var(|idx| {
Expand All @@ -387,11 +425,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
builder.ret_rty(ret_rty);
}

// Note that we do not expect predicate variables to be generated here
// when type params are still present in the type. Callers should ensure either
// - type params are fully instantiated, or
// - the function is fully annotated
rty::RefinedType::unrefined(builder.build().into())
if is_fully_annotated {
rty::RefinedType::unrefined(builder.build().into())
} else if let Some(trait_item_ty) = trait_item_ty {
trait_item_ty
} else {
rty::RefinedType::unrefined(builder.build().into())
}
}

/// Extract the target DefId from `#[thrust::extern_spec_fn]` function.
Expand Down
26 changes: 13 additions & 13 deletions std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ mod thrust_models {
}

#[thrust::def::mut_model]
pub struct Mut<T>(PhantomData<T>);
pub struct Mut<T: ?Sized>(PhantomData<T>);

impl<T> Mut<T> {
#[allow(dead_code)]
Expand Down Expand Up @@ -100,7 +100,7 @@ mod thrust_models {
}

#[thrust::def::box_model]
pub struct Box<T>(PhantomData<T>);
pub struct Box<T: ?Sized>(PhantomData<T>);

impl<T> Box<T> {
#[allow(dead_code)]
Expand Down Expand Up @@ -128,7 +128,7 @@ mod thrust_models {
}

#[thrust::def::array_model]
pub struct Array<I, T>(PhantomData<I>, PhantomData<T>);
pub struct Array<I: ?Sized, T: ?Sized>(PhantomData<I>, PhantomData<T>);

impl<I, T, U> PartialEq<U> for Array<I, T> where U: super::Model<Ty = Self> {
#[thrust::ignored]
Expand Down Expand Up @@ -156,9 +156,9 @@ mod thrust_models {
}

#[thrust::def::closure_model]
pub struct Closure<T>(PhantomData<T>);
pub struct Closure<T: ?Sized>(PhantomData<T>);

pub struct Vec<T>(pub Array<Int, T>, pub Int);
pub struct Vec<T: ?Sized>(pub Array<Int, T>, pub Int);

impl<T, U> PartialEq<U> for Vec<T> where U: super::Model<Ty = Self> {
#[thrust::ignored]
Expand Down Expand Up @@ -200,7 +200,7 @@ mod thrust_models {
type Ty = bool;
}

impl<T> Model for model::Closure<T> {
impl<T: ?Sized> Model for model::Closure<T> {
type Ty = model::Closure<T>;
}

Expand All @@ -224,35 +224,35 @@ mod thrust_models {
impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8);
impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9);

impl<'a, T> Model for &'a mut T where T: Model {
impl<'a, T: ?Sized> Model for &'a mut T where T: Model {
type Ty = model::Mut<<T as Model>::Ty>;
}

impl<T> Model for model::Mut<T> {
impl<T: ?Sized> Model for model::Mut<T> {
type Ty = model::Mut<T>;
}

impl<'a, T> Model for &'a T where T: Model {
impl<'a, T: ?Sized> Model for &'a T where T: Model {
type Ty = &'a <T as Model>::Ty;
}

impl<T> Model for Box<T> where T: Model {
impl<T: ?Sized> Model for Box<T> where T: Model {
type Ty = model::Box<<T as Model>::Ty>;
}

impl<T> Model for model::Box<T> {
impl<T: ?Sized> Model for model::Box<T> {
type Ty = model::Box<T>;
}

impl<I, T> Model for model::Array<I, T> {
impl<I: ?Sized, T: ?Sized> Model for model::Array<I, T> {
type Ty = model::Array<I, T>;
}

impl<T> Model for Vec<T> where T: Model {
type Ty = model::Vec<<T as Model>::Ty>;
}

impl<T> Model for model::Vec<T> {
impl<T: ?Sized> Model for model::Vec<T> {
type Ty = model::Vec<T>;
}

Expand Down
Loading