lean4-htt/src/Lean/Compiler/LCNF/AlphaEqv.lean
Henrik Böving e96d969d59
feat: support for del, isShared, oset and setTag (#12687)
This PR implements the LCNF instructions required for the expand reset
reuse pass.
2026-02-25 10:43:15 +00:00

195 lines
7.3 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.LCNF.Basic
import Init.Omega
public section
namespace Lean.Compiler.LCNF
/-!
Alpha equivalence for LCNF Code
-/
namespace AlphaEqv
/--
Monad in which the alpha equivalence checks of this namespace run: a reader whose
environment is an `FVarIdMap FVarId`.
`withFVar` extends the map for a nested computation and `eqvFVar` consults it.
-/
abbrev EqvM := ReaderM (FVarIdMap FVarId)
/--
Compare two free variables modulo the renaming stored in the reader environment.
`fvarId₂` is first replaced by the entry recorded for it in the map (if there is one);
the result is then compared against `fvarId₁` with `==`.
-/
def eqvFVar (fvarId₁ fvarId₂ : FVarId) : EqvM Bool := do
  let renaming ← read
  match renaming.get? fvarId₂ with
  | some mapped => return fvarId₁ == mapped
  | none => return fvarId₁ == fvarId₂
/--
Compare two type expressions modulo the current free variable renaming.
Only `.fvar`, `.app` and `.forallE` nodes are traversed; every other pair of
expressions is compared with `==`.
-/
def eqvType (e₁ e₂ : Expr) : EqvM Bool := do
  match e₁, e₂ with
  | .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
  -- For applications the argument is checked before the function.
  | .app fn₁ arg₁, .app fn₂ arg₂ => eqvType arg₁ arg₂ <&&> eqvType fn₁ fn₂
  | .forallE _ dom₁ body₁ _, .forallE _ dom₂ body₂ _ => eqvType dom₁ dom₂ <&&> eqvType body₁ body₂
  | _, _ => return e₁ == e₂
/--
Pointwise version of `eqvType` for arrays.
Returns `false` if the sizes differ or if any pair of corresponding entries
fails `eqvType`; returns `true` otherwise.
-/
def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
  if es₁.size != es₂.size then
    return false
  for lhs in es₁, rhs in es₂ do
    if !(← eqvType lhs rhs) then
      return false
  return true
/--
Compare two arguments modulo the current free variable renaming.
Arguments built with different constructors are never equivalent.
-/
def eqvArg (a₁ a₂ : Arg pu) : EqvM Bool :=
  match a₁, a₂ with
  | .erased, .erased => pure true
  | .fvar fvarId₁, .fvar fvarId₂ => eqvFVar fvarId₁ fvarId₂
  | .type ty₁ _, .type ty₂ _ => eqvType ty₁ ty₂
  | _, _ => pure false
/--
Pointwise version of `eqvArg` for arrays.
Returns `false` if the sizes differ or if any pair of corresponding arguments
fails `eqvArg`; returns `true` otherwise.
-/
def eqvArgs (as₁ as₂ : Array (Arg pu)) : EqvM Bool := do
  if as₁.size != as₂.size then
    return false
  for lhs in as₁, rhs in as₂ do
    if !(← eqvArg lhs rhs) then
      return false
  return true
/--
Compare two let values modulo the current free variable renaming.
Both values must be built with the same constructor. Their plain data fields are
compared with `==`, free variables with `eqvFVar`, argument arrays with `eqvArgs`
and types with `eqvType`. Fields matched by `_` do not take part in the comparison.
-/
def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
  match e₁, e₂ with
  -- Values without any free variable.
  | .erased, .erased => return true
  | .lit lit₁, .lit lit₂ => return lit₁ == lit₂
  -- Values referring to exactly one free variable.
  | .unbox x₁ _, .unbox x₂ _ => eqvFVar x₁ x₂
  | .isShared x₁ _, .isShared x₂ _ => eqvFVar x₁ x₂
  | .box ty₁ x₁ _, .box ty₂ x₂ _ => eqvType ty₁ ty₂ <&&> eqvFVar x₁ x₂
  | .reset n₁ x₁ _, .reset n₂ x₂ _ =>
    if n₁ == n₂ then eqvFVar x₁ x₂ else return false
  | .proj struct₁ idx₁ x₁ _, .proj struct₂ idx₂ x₂ _ =>
    if struct₁ == struct₂ && idx₁ == idx₂ then eqvFVar x₁ x₂ else return false
  | .oproj idx₁ x₁ _, .oproj idx₂ x₂ _ =>
    if idx₁ == idx₂ then eqvFVar x₁ x₂ else return false
  | .uproj idx₁ x₁ _, .uproj idx₂ x₂ _ =>
    if idx₁ == idx₂ then eqvFVar x₁ x₂ else return false
  | .sproj idx₁ off₁ x₁ _, .sproj idx₂ off₂ x₂ _ =>
    if idx₁ == idx₂ && off₁ == off₂ then eqvFVar x₁ x₂ else return false
  -- Values carrying an argument array.
  | .fvar fn₁ args₁, .fvar fn₂ args₂ => eqvFVar fn₁ fn₂ <&&> eqvArgs args₁ args₂
  | .const name₁ lvls₁ args₁ _, .const name₂ lvls₂ args₂ _ =>
    if name₁ == name₂ && lvls₁ == lvls₂ then eqvArgs args₁ args₂ else return false
  | .ctor info₁ args₁ _, .ctor info₂ args₂ _ =>
    if info₁ == info₂ then eqvArgs args₁ args₂ else return false
  | .fap fn₁ args₁ _, .fap fn₂ args₂ _ =>
    if fn₁ == fn₂ then eqvArgs args₁ args₂ else return false
  | .pap fn₁ args₁ _, .pap fn₂ args₂ _ =>
    if fn₁ == fn₂ then eqvArgs args₁ args₂ else return false
  | .reuse x₁ info₁ upd₁ args₁ _, .reuse x₂ info₂ upd₂ args₂ _ =>
    if info₁ == info₂ && upd₁ == upd₂ then
      eqvFVar x₁ x₂ <&&> eqvArgs args₁ args₂
    else
      return false
  -- Different constructors.
  | _, _ => return false
/--
Run `x` in an environment where `fvarId₂` is mapped to `fvarId₁`.
Inside `x`, `eqvFVar` replaces `fvarId₂` by `fvarId₁` before comparing.
-/
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
  withReader (fun renaming => renaming.insert fvarId₂ fvarId₁) x
/--
Run `x` with the parameters `params₂` mapped, position by position, to the parameters `params₁`.

Returns `false` without running `x` if the two arrays have different sizes, or if the
types of some pair of corresponding parameters are not related by `eqvType`.
Otherwise the result is whatever `x` returns in the extended environment.
-/
@[inline] def withParams (params₁ params₂ : Array (Param pu)) (x : EqvM Bool) : EqvM Bool := do
  if h : params₂.size = params₁.size then
    -- `go i` handles the parameters at index `i` and above, then runs `x`.
    let rec @[specialize] go (i : Nat) : EqvM Bool := do
      if h : i < params₁.size then
        let p₁ := params₁[i]
        -- The size equality from the outer `if` makes `i` a valid index into `params₂` too.
        have : i < params₂.size := by simp_all +arith
        let p₂ := params₂[i]
        -- The type of parameter `i` is checked before parameter `i` itself is added to the map.
        unless (← eqvType p₁.type p₂.type) do return false
        withFVar p₁.fvarId p₂.fvarId do
          go (i+1)
      else
        -- All parameters have been processed.
        x
    termination_by params₁.size - i
    go 0
  else
    return false
/--
Sort case alternatives with `Array.qsort`.
Under the comparison used here, `.alt` alternatives are ordered by constructor name and
`.ctorAlt` alternatives by the name of their constructor info (both via `Name.lt`);
both kinds are ordered before `.default`. Every other pair compares as `false`.
-/
def sortAlts (alts : Array (Alt pu)) : Array (Alt pu) :=
  alts.qsort fun lhs rhs =>
    match lhs, rhs with
    | .alt ctor₁ .., .alt ctor₂ .. => Name.lt ctor₁ ctor₂
    | .ctorAlt info₁ .., .ctorAlt info₂ .. => Name.lt info₁.name info₂.name
    | .alt .., .default .. => true
    | .ctorAlt .., .default .. => true
    | _, _ => false
mutual
/--
Compare two arrays of case alternatives modulo the current free variable renaming.
Both arrays are sorted with `sortAlts` first and then compared position by position.
Returns `false` if the sizes differ, if two alternatives at the same position are of
different kinds or name different constructors, or if their bodies are not equivalent.
-/
partial def eqvAlts (alts₁ alts₂ : Array (Alt pu)) : EqvM Bool := do
  if alts₁.size != alts₂.size then
    return false
  for lhs in sortAlts alts₁, rhs in sortAlts alts₂ do
    match lhs, rhs with
    | .default body₁, .default body₂ =>
      unless (← eqv body₁ body₂) do return false
    | .ctorAlt info₁ body₁ _, .ctorAlt info₂ body₂ _ =>
      unless info₁ == info₂ do return false
      unless (← eqv body₁ body₂) do return false
    | .alt ctor₁ fields₁ body₁ _, .alt ctor₂ fields₂ body₂ _ =>
      unless ctor₁ == ctor₂ do return false
      -- The bodies are compared with the alternative's parameters identified.
      unless (← withParams fields₁ fields₂ (eqv body₁ body₂)) do return false
    | _, _ => return false
  return true
/--
Compare two pieces of LCNF code modulo the current free variable renaming.
Both pieces must start with the same constructor. Plain data fields are compared with
`==`, free variables with `eqvFVar`, types with `eqvType`, and continuations recursively.
Variables introduced by `.let`, `.fun` and `.jp` are identified with `withFVar` before the
continuation is compared.
-/
partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
  match code₁, code₂ with
  -- Code without a continuation.
  | .return x₁, .return x₂ => eqvFVar x₁ x₂
  | .unreach ty₁, .unreach ty₂ => eqvType ty₁ ty₂
  | .jmp jp₁ args₁, .jmp jp₂ args₂ => eqvFVar jp₁ jp₂ <&&> eqvArgs args₁ args₂
  | .cases cs₁, .cases cs₂ =>
    eqvFVar cs₁.discr cs₂.discr <&&>
    eqvType cs₁.resultType cs₂.resultType <&&>
    eqvAlts cs₁.alts cs₂.alts
  -- Declarations: the declared variable is only identified inside the continuation.
  | .let decl₁ rest₁, .let decl₂ rest₂ =>
    eqvType decl₁.type decl₂.type <&&>
    eqvLetValue decl₁.value decl₂.value <&&>
    withFVar decl₁.fvarId decl₂.fvarId (eqv rest₁ rest₂)
  | .fun decl₁ rest₁ _, .fun decl₂ rest₂ _
  | .jp decl₁ rest₁, .jp decl₂ rest₂ =>
    eqvType decl₁.type decl₂.type <&&>
    withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
    withFVar decl₁.fvarId decl₂.fvarId (eqv rest₁ rest₂)
  -- Instructions acting on existing variables, followed by a continuation.
  | .del x₁ rest₁ _, .del x₂ rest₂ _ =>
    eqvFVar x₁ x₂ <&&>
    eqv rest₁ rest₂
  | .setTag x₁ tag₁ rest₁ _, .setTag x₂ tag₂ rest₂ _ =>
    pure (tag₁ == tag₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqv rest₁ rest₂
  | .inc x₁ n₁ check₁ pers₁ rest₁ _, .inc x₂ n₂ check₂ pers₂ rest₂ _ =>
    pure (n₁ == n₂ && check₁ == check₂ && pers₁ == pers₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqv rest₁ rest₂
  | .dec x₁ n₁ check₁ pers₁ rest₁ _, .dec x₂ n₂ check₂ pers₂ rest₂ _ =>
    pure (n₁ == n₂ && check₁ == check₂ && pers₁ == pers₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqv rest₁ rest₂
  | .oset x₁ idx₁ val₁ rest₁ _, .oset x₂ idx₂ val₂ rest₂ _ =>
    pure (idx₁ == idx₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqvArg val₁ val₂ <&&>
    eqv rest₁ rest₂
  | .uset x₁ idx₁ val₁ rest₁ _, .uset x₂ idx₂ val₂ rest₂ _ =>
    pure (idx₁ == idx₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqvFVar val₁ val₂ <&&>
    eqv rest₁ rest₂
  | .sset x₁ idx₁ off₁ val₁ ty₁ rest₁ _, .sset x₂ idx₂ off₂ val₂ ty₂ rest₂ _ =>
    pure (idx₁ == idx₂ && off₁ == off₂) <&&>
    eqvFVar x₁ x₂ <&&>
    eqvFVar val₁ val₂ <&&>
    eqvType ty₁ ty₂ <&&>
    eqv rest₁ rest₂
  -- Different constructors.
  | _, _ => return false
end
end AlphaEqv
/--
Return `true` if `c₁` and `c₂` are alpha equivalent.
The check is `AlphaEqv.eqv`, started with an empty free variable map.
-/
def Code.alphaEqv (c₁ c₂ : Code pu) : Bool :=
  (AlphaEqv.eqv c₁ c₂).run {}
end Lean.Compiler.LCNF