Skip to content

Commit 2183340

Browse files
Auto merge of #150485 - dianqk:gvn-ssa-borrow, r=<try>
[DRAFT] GVN: Only propagate borrows from SSA-locals
2 parents fcd6309 + c090eb9 commit 2183340

26 files changed

+746
-328
lines changed

compiler/rustc_middle/src/mir/terminator.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -695,28 +695,6 @@ impl<'tcx> TerminatorKind<'tcx> {
695695
_ => None,
696696
}
697697
}
698-
699-
/// Returns true if the terminator can write to memory.
700-
pub fn can_write_to_memory(&self) -> bool {
701-
match self {
702-
TerminatorKind::Goto { .. }
703-
| TerminatorKind::SwitchInt { .. }
704-
| TerminatorKind::UnwindResume
705-
| TerminatorKind::UnwindTerminate(_)
706-
| TerminatorKind::Return
707-
| TerminatorKind::Assert { .. }
708-
| TerminatorKind::CoroutineDrop
709-
| TerminatorKind::FalseEdge { .. }
710-
| TerminatorKind::FalseUnwind { .. }
711-
| TerminatorKind::Unreachable => false,
712-
TerminatorKind::Call { .. }
713-
| TerminatorKind::Drop { .. }
714-
| TerminatorKind::TailCall { .. }
715-
// Yield writes to the resume_arg place.
716-
| TerminatorKind::Yield { .. }
717-
| TerminatorKind::InlineAsm { .. } => true,
718-
}
719-
}
720698
}
721699

722700
#[derive(Copy, Clone, Debug)]

compiler/rustc_mir_transform/src/gvn.rs

Lines changed: 93 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,18 @@ impl<'tcx> crate::MirPass<'tcx> for GVN {
129129
let ssa = SsaLocals::new(tcx, body, typing_env);
130130
// Clone dominators because we need them while mutating the body.
131131
let dominators = body.basic_blocks.dominators().clone();
132-
let maybe_loop_headers = loops::maybe_loop_headers(body);
133132

134133
let arena = DroplessArena::default();
135134
let mut state =
136135
VnState::new(tcx, body, typing_env, &ssa, dominators, &body.local_decls, &arena);
137136

138137
for local in body.args_iter().filter(|&local| ssa.is_ssa(local)) {
139-
let opaque = state.new_opaque(body.local_decls[local].ty);
138+
let opaque = state.new_argument(body.local_decls[local].ty);
140139
state.assign(local, opaque);
141140
}
142141

143142
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
144143
for bb in reverse_postorder {
145-
// N.B. With loops, reverse postorder cannot produce a valid topological order.
146-
// A statement or terminator from inside the loop, that is not processed yet, may have performed an indirect write.
147-
if maybe_loop_headers.contains(bb) {
148-
state.invalidate_derefs();
149-
}
150144
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
151145
state.visit_basic_block_data(bb, data);
152146
}
@@ -206,6 +200,7 @@ enum Value<'a, 'tcx> {
206200
/// Used to represent values we know nothing about.
207201
/// The `usize` is a counter incremented by `new_opaque`.
208202
Opaque(VnOpaque),
203+
Argument(VnOpaque),
209204
/// Evaluated or unevaluated constant value.
210205
Constant {
211206
value: Const<'tcx>,
@@ -243,7 +238,10 @@ enum Value<'a, 'tcx> {
243238

244239
// Extractions.
245240
/// This is the *value* obtained by projecting another value.
246-
Projection(VnIndex, ProjectionElem<VnIndex, ()>),
241+
Projection {
242+
base: VnIndex,
243+
elem: ProjectionElem<VnIndex, ()>,
244+
},
247245
/// Discriminant of the given value.
248246
Discriminant(VnIndex),
249247

@@ -290,7 +288,7 @@ impl<'a, 'tcx> ValueSet<'a, 'tcx> {
290288
let value = value(VnOpaque);
291289

292290
debug_assert!(match value {
293-
Value::Opaque(_) | Value::Address { .. } => true,
291+
Value::Opaque(_) | Value::Argument(_) | Value::Address { .. } => true,
294292
Value::Constant { disambiguator, .. } => disambiguator.is_some(),
295293
_ => false,
296294
});
@@ -350,12 +348,6 @@ impl<'a, 'tcx> ValueSet<'a, 'tcx> {
350348
fn ty(&self, index: VnIndex) -> Ty<'tcx> {
351349
self.types[index]
352350
}
353-
354-
/// Replace the value associated with `index` with an opaque value.
355-
#[inline]
356-
fn forget(&mut self, index: VnIndex) {
357-
self.values[index] = Value::Opaque(VnOpaque);
358-
}
359351
}
360352

361353
struct VnState<'body, 'a, 'tcx> {
@@ -374,8 +366,6 @@ struct VnState<'body, 'a, 'tcx> {
374366
/// - `Some(None)` are values for which computation has failed;
375367
/// - `Some(Some(op))` are successful computations.
376368
evaluated: IndexVec<VnIndex, Option<Option<&'a OpTy<'tcx>>>>,
377-
/// Cache the deref values.
378-
derefs: Vec<VnIndex>,
379369
ssa: &'body SsaLocals,
380370
dominators: Dominators<BasicBlock>,
381371
reused_locals: DenseBitSet<Local>,
@@ -408,7 +398,6 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
408398
rev_locals: IndexVec::with_capacity(num_values),
409399
values: ValueSet::new(num_values),
410400
evaluated: IndexVec::with_capacity(num_values),
411-
derefs: Vec::new(),
412401
ssa,
413402
dominators,
414403
reused_locals: DenseBitSet::new_empty(local_decls.len()),
@@ -455,6 +444,13 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
455444
index
456445
}
457446

447+
#[instrument(level = "trace", skip(self), ret)]
448+
fn new_argument(&mut self, ty: Ty<'tcx>) -> VnIndex {
449+
let index = self.insert_unique(ty, Value::Argument);
450+
self.evaluated[index] = Some(None);
451+
index
452+
}
453+
458454
/// Create a new `Value::Address` distinct from all the others.
459455
#[instrument(level = "trace", skip(self), ret)]
460456
fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> {
@@ -541,18 +537,6 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
541537
self.insert(ty, Value::Aggregate(VariantIdx::ZERO, self.arena.alloc_slice(values)))
542538
}
543539

544-
fn insert_deref(&mut self, ty: Ty<'tcx>, value: VnIndex) -> VnIndex {
545-
let value = self.insert(ty, Value::Projection(value, ProjectionElem::Deref));
546-
self.derefs.push(value);
547-
value
548-
}
549-
550-
fn invalidate_derefs(&mut self) {
551-
for deref in std::mem::take(&mut self.derefs) {
552-
self.values.forget(deref);
553-
}
554-
}
555-
556540
#[instrument(level = "trace", skip(self), ret)]
557541
fn eval_to_const_inner(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
558542
use Value::*;
@@ -566,7 +550,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
566550
let op = match self.get(value) {
567551
_ if ty.is_zst() => ImmTy::uninit(ty).into(),
568552

569-
Opaque(_) => return None,
553+
Opaque(_) | Argument(_) => return None,
570554
// Keep runtime check constants as symbolic.
571555
RuntimeChecks(..) => return None,
572556

@@ -648,7 +632,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
648632
ImmTy::from_immediate(ptr_imm, ty).into()
649633
}
650634

651-
Projection(base, elem) => {
635+
Projection { base, elem, .. } => {
652636
let base = self.eval_to_const(base)?;
653637
// `Index` by constants should have been replaced by `ConstantIndex` by
654638
// `simplify_place_projection`.
@@ -818,7 +802,13 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
818802

819803
// An immutable borrow `_x` always points to the same value for the
820804
// lifetime of the borrow, so we can merge all instances of `*_x`.
821-
return Some((projection_ty, self.insert_deref(projection_ty.ty, value)));
805+
return Some((
806+
projection_ty,
807+
self.insert(
808+
projection_ty.ty,
809+
Value::Projection { base: value, elem: ProjectionElem::Deref },
810+
),
811+
));
822812
} else {
823813
return None;
824814
}
@@ -827,8 +817,8 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
827817
ProjectionElem::Field(f, _) => match self.get(value) {
828818
Value::Aggregate(_, fields) => return Some((projection_ty, fields[f.as_usize()])),
829819
Value::Union(active, field) if active == f => return Some((projection_ty, field)),
830-
Value::Projection(outer_value, ProjectionElem::Downcast(_, read_variant))
831-
if let Value::Aggregate(written_variant, fields) = self.get(outer_value)
820+
Value::Projection { base, elem: ProjectionElem::Downcast(_, read_variant) }
821+
if let Value::Aggregate(written_variant, fields) = self.get(base)
832822
// This pass is not aware of control-flow, so we do not know whether the
833823
// replacement we are doing is actually reachable. We could be in any arm of
834824
// ```
@@ -881,7 +871,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
881871
ProjectionElem::UnwrapUnsafeBinder(_) => ProjectionElem::UnwrapUnsafeBinder(()),
882872
};
883873

884-
let value = self.insert(projection_ty.ty, Value::Projection(value, proj));
874+
let value = self.insert(projection_ty.ty, Value::Projection { base: value, elem: proj });
885875
Some((projection_ty, value))
886876
}
887877

@@ -1114,11 +1104,14 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
11141104
fields: &[VnIndex],
11151105
) -> Option<VnIndex> {
11161106
let Some(&first_field) = fields.first() else { return None };
1117-
let Value::Projection(copy_from_value, _) = self.get(first_field) else { return None };
1107+
let Value::Projection { base: copy_from_value, .. } = self.get(first_field) else {
1108+
return None;
1109+
};
11181110

11191111
// All fields must correspond one-to-one and come from the same aggregate value.
11201112
if fields.iter().enumerate().any(|(index, &v)| {
1121-
if let Value::Projection(pointer, ProjectionElem::Field(from_index, _)) = self.get(v)
1113+
if let Value::Projection { base: pointer, elem: ProjectionElem::Field(from_index, _) } =
1114+
self.get(v)
11221115
&& copy_from_value == pointer
11231116
&& from_index.index() == index
11241117
{
@@ -1130,7 +1123,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
11301123
}
11311124

11321125
let mut copy_from_local_value = copy_from_value;
1133-
if let Value::Projection(pointer, proj) = self.get(copy_from_value)
1126+
if let Value::Projection { base: pointer, elem: proj } = self.get(copy_from_value)
11341127
&& let ProjectionElem::Downcast(_, read_variant) = proj
11351128
{
11361129
if variant_index == read_variant {
@@ -1142,6 +1135,45 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
11421135
}
11431136
}
11441137

1138+
// We can introduce a new dereference if the source value cannot be changed in the body.
1139+
let mut copy_root = copy_from_local_value;
1140+
loop {
1141+
match self.get(copy_root) {
1142+
Value::Projection { base, .. } => {
1143+
copy_root = base;
1144+
}
1145+
Value::Address {
1146+
base,
1147+
projection,
1148+
kind: AddressKind::Ref(BorrowKind::Shared),
1149+
..
1150+
} if projection.iter().all(ProjectionElem::is_stable_offset) => match base {
1151+
AddressBase::Local(local) => {
1152+
if !self.ssa.is_ssa(local) {
1153+
return None;
1154+
}
1155+
break;
1156+
}
1157+
AddressBase::Deref(index) => {
1158+
copy_root = index;
1159+
}
1160+
},
1161+
Value::Argument(_) if !self.ty(copy_root).is_mutable_ptr() => {
1162+
break;
1163+
}
1164+
Value::Opaque(_) => {
1165+
let ty = self.ty(copy_root);
1166+
if ty.is_fn() || !ty.is_any_ptr() {
1167+
break;
1168+
}
1169+
return None;
1170+
}
1171+
_ => {
1172+
return None;
1173+
}
1174+
}
1175+
}
1176+
11451177
// Both must be variants of the same type.
11461178
if self.ty(copy_from_local_value) == ty { Some(copy_from_local_value) } else { None }
11471179
}
@@ -1843,7 +1875,7 @@ impl<'tcx> VnState<'_, '_, 'tcx> {
18431875
// If we are here, we failed to find a local, and we already have a `Deref`.
18441876
// Trying to add projections will only result in an ill-formed place.
18451877
return None;
1846-
} else if let Value::Projection(pointer, proj) = self.get(index)
1878+
} else if let Value::Projection { base: pointer, elem: proj } = self.get(index)
18471879
&& (allow_complex_projection || proj.is_stable_offset())
18481880
&& let Some(proj) = self.try_as_place_elem(self.ty(index), proj, loc)
18491881
{
@@ -1873,10 +1905,6 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
18731905

18741906
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
18751907
self.simplify_place_projection(place, location);
1876-
if context.is_mutating_use() && place.is_indirect() {
1877-
// Non-local mutation maybe invalidate deref.
1878-
self.invalidate_derefs();
1879-
}
18801908
self.super_place(place, context, location);
18811909
}
18821910

@@ -1893,7 +1921,7 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
18931921
) {
18941922
self.simplify_place_projection(lhs, location);
18951923

1896-
let value = self.simplify_rvalue(lhs, rvalue, location);
1924+
let mut value = self.simplify_rvalue(lhs, rvalue, location);
18971925
if let Some(value) = value {
18981926
if let Some(const_) = self.try_as_constant(value) {
18991927
*rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
@@ -1906,14 +1934,30 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
19061934
}
19071935
}
19081936

1909-
if lhs.is_indirect() {
1910-
// Non-local mutation maybe invalidate deref.
1911-
self.invalidate_derefs();
1937+
let rvalue_ty = rvalue.ty(self.local_decls, self.tcx);
1938+
// DO NOT reason the pointer value if it may point to a non-SSA local.
1939+
// For instance, we cannot unify two pointers that dereference same local, because they may
1940+
// have different lifetimes.
1941+
if rvalue_ty.is_ref()
1942+
&& let Some(index) = value
1943+
{
1944+
match self.get(index) {
1945+
Value::Opaque(_) | Value::Projection { .. } => {
1946+
value = None;
1947+
}
1948+
Value::Constant { .. } | Value::Address { .. } | Value::Argument(_) => {}
1949+
Value::RawPtr { .. } | Value::BinaryOp(..) | Value::Cast { .. } => {}
1950+
Value::Aggregate(..)
1951+
| Value::Union(..)
1952+
| Value::Repeat(..)
1953+
| Value::Discriminant(..)
1954+
| Value::RuntimeChecks(..)
1955+
| Value::UnaryOp(..) => unreachable!(),
1956+
}
19121957
}
19131958

19141959
if let Some(local) = lhs.as_local()
19151960
&& self.ssa.is_ssa(local)
1916-
&& let rvalue_ty = rvalue.ty(self.local_decls, self.tcx)
19171961
// FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark
19181962
// `local` as reusable if we have an exact type match.
19191963
&& self.local_decls[local].ty == rvalue_ty
@@ -1933,10 +1977,6 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
19331977
self.assign(local, opaque);
19341978
}
19351979
}
1936-
// Terminators that can write to memory may invalidate (nested) derefs.
1937-
if terminator.kind.can_write_to_memory() {
1938-
self.invalidate_derefs();
1939-
}
19401980
self.super_terminator(terminator, location);
19411981
}
19421982
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
- // MIR for `dereference_reborrow` before GVN
2+
+ // MIR for `dereference_reborrow` after GVN
3+
4+
fn dereference_reborrow(_1: &mut u8) -> () {
5+
debug mut_a => _1;
6+
let mut _0: ();
7+
let _2: &u8;
8+
scope 1 {
9+
debug a => _2;
10+
let _3: u8;
11+
scope 2 {
12+
debug b => _3;
13+
let _4: u8;
14+
scope 3 {
15+
debug c => _4;
16+
}
17+
}
18+
}
19+
20+
bb0: {
21+
StorageLive(_2);
22+
_2 = &(*_1);
23+
- StorageLive(_3);
24+
+ nop;
25+
_3 = copy (*_2);
26+
StorageLive(_4);
27+
- _4 = copy (*_2);
28+
+ _4 = copy _3;
29+
_0 = const ();
30+
StorageDead(_4);
31+
- StorageDead(_3);
32+
+ nop;
33+
StorageDead(_2);
34+
return;
35+
}
36+
}
37+

0 commit comments

Comments
 (0)