feat: structural isDefEq for Sym (#11819)
This PR adds some basic infrastructure for a structural (and cheaper) `isDefEq` predicate for pattern matching and unification in `Sym`.
This commit is contained in:
parent
96160e553a
commit
48bb954e4e
5 changed files with 191 additions and 57 deletions
|
|
@ -14,6 +14,8 @@ public import Lean.Meta.Sym.LooseBVarsS
|
|||
public import Lean.Meta.Sym.InstantiateS
|
||||
public import Lean.Meta.Sym.IsClass
|
||||
public import Lean.Meta.Sym.Intro
|
||||
public import Lean.Meta.Sym.InstantiateMVarsS
|
||||
public import Lean.Meta.Sym.ProofInstInfo
|
||||
public import Lean.Meta.Sym.Pattern
|
||||
|
||||
/-!
|
||||
|
|
|
|||
21
src/Lean/Meta/Sym/InstantiateMVarsS.lean
Normal file
21
src/Lean/Meta/Sym/InstantiateMVarsS.lean
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SymM
|
||||
namespace Lean.Meta.Sym
|
||||
|
||||
/--
|
||||
Instantiates metavariables occurring in `e`, and returns a maximally shared term.
|
||||
-/
|
||||
def instantiateMVarsS (e : Expr) : SymM Expr := do
|
||||
if e.hasMVar then
|
||||
-- **Note**: If this is a bottleneck, write a new function that combines both steps.
|
||||
Grind.shareCommon (← instantiateMVars e)
|
||||
else
|
||||
return e
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
@ -9,6 +9,8 @@ public import Lean.Meta.Sym.SymM
|
|||
import Lean.Util.FoldConsts
|
||||
import Lean.Meta.Sym.InstantiateS
|
||||
import Lean.Meta.Sym.IsClass
|
||||
import Lean.Meta.Sym.ProofInstInfo
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
|
||||
namespace Lean.Meta.Sym
|
||||
open Grind
|
||||
|
||||
|
|
@ -34,67 +36,16 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro
|
|||
- `instantiateRevS` ensures maximal sharing of result expressions
|
||||
-/
|
||||
|
||||
def preprocessType (type : Expr) : MetaM Expr := do
|
||||
let type ← unfoldReducible type
|
||||
let type ← Core.betaReduce type
|
||||
zetaReduce type
|
||||
|
||||
/--
|
||||
Information about a single argument position in a function's type signature.
|
||||
|
||||
This is used during pattern matching to identify arguments that can be skipped
|
||||
or handled specially (e.g., instance arguments can be synthesized, proof arguments
|
||||
can be inferred).
|
||||
-/
|
||||
public structure PatternArgInfo where
|
||||
/-- `true` if this argument is a proof (its type is a `Prop`). -/
|
||||
isProof : Bool
|
||||
/-- `true` if this argument is a type class instance. -/
|
||||
isInstance : Bool
|
||||
|
||||
/--
|
||||
Information about a function symbol occurring in a pattern.
|
||||
|
||||
Stores which argument positions are proofs or instances, enabling optimizations
|
||||
during pattern matching such as skipping proof arguments or deferring instance synthesis.
|
||||
-/
|
||||
public structure FunPatternInfo where
|
||||
/-- Information about each argument position. -/
|
||||
argsInfo : Array PatternArgInfo
|
||||
|
||||
/--
|
||||
Analyzes the type signature of `declName` and returns information about which arguments
|
||||
are proofs or instances. Returns `none` if no arguments are proofs or instances.
|
||||
-/
|
||||
def mkFunPatternInfo? (declName : Name) : MetaM (Option FunPatternInfo) := do
|
||||
let info ← getConstInfo declName
|
||||
let type ← preprocessType info.type
|
||||
forallTelescopeReducing type fun xs _ => do
|
||||
let env ← getEnv
|
||||
let mut argsInfo := #[]
|
||||
let mut found := false
|
||||
for x in xs do
|
||||
let type ← inferType x
|
||||
let isInstance := isClass? env type |>.isSome
|
||||
let isProof ← isProp type
|
||||
if isInstance || isProof then
|
||||
found := true
|
||||
argsInfo := argsInfo.push { isInstance, isProof }
|
||||
if found then
|
||||
return some { argsInfo }
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Collects `FunPatternInfo` for all function symbols occurring in `pattern`.
|
||||
Collects `ProofInstInfo` for all function symbols occurring in `pattern`.
|
||||
|
||||
Only includes functions that have at least one proof or instance argument.
|
||||
-/
|
||||
def mkFunInfosFor (pattern : Expr) : MetaM (AssocList Name FunPatternInfo) := do
|
||||
def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo) := do
|
||||
let cs := pattern.getUsedConstants
|
||||
let mut fnInfos := {}
|
||||
for declName in cs do
|
||||
if let some info ← mkFunPatternInfo? declName then
|
||||
if let some info ← mkProofInstInfo? declName then
|
||||
fnInfos := fnInfos.insertNew declName info
|
||||
return fnInfos
|
||||
|
||||
|
|
@ -102,7 +53,7 @@ public structure Pattern where
|
|||
levelParams : List Name
|
||||
varTypes : Array Expr
|
||||
pattern : Expr
|
||||
fnInfos : AssocList Name FunPatternInfo
|
||||
fnInfos : AssocList Name ProofInstInfo
|
||||
deriving Inhabited
|
||||
|
||||
def uvarPrefix : Name := `_uvar
|
||||
|
|
@ -124,7 +75,7 @@ public def mkPatternFromTheorem (declName : Name) : MetaM Pattern := do
|
|||
| .forallE _ d b _ => go b (varTypes.push d)
|
||||
| _ =>
|
||||
let pattern := type
|
||||
let fnInfos ← mkFunInfosFor pattern
|
||||
let fnInfos ← mkProofInstInfoMapFor pattern
|
||||
return { levelParams, varTypes, pattern, fnInfos }
|
||||
go type #[]
|
||||
|
||||
|
|
@ -248,7 +199,7 @@ where
|
|||
let numArgs := p.getAppNumArgs
|
||||
processAppWithInfo p e (numArgs - 1) info
|
||||
|
||||
processAppWithInfo (p : Expr) (e : Expr) (i : Nat) (info : FunPatternInfo) : UnifyM Bool := do
|
||||
processAppWithInfo (p : Expr) (e : Expr) (i : Nat) (info : ProofInstInfo) : UnifyM Bool := do
|
||||
let .app fp ap := p | process p e
|
||||
let .app fe ae := e | return false
|
||||
unless (← processAppWithInfo fp fe (i - 1) info) do return false
|
||||
|
|
@ -283,6 +234,86 @@ where
|
|||
let .sort v := e | return false
|
||||
processLevel u v
|
||||
|
||||
def isLevelDefEqS (u : Level) (v : Level) : MetaM Bool := do
|
||||
match u, v with
|
||||
| .param u, .param v => return u == v
|
||||
| .zero, .zero => return true
|
||||
| .succ u, .succ v => isLevelDefEqS u v
|
||||
| .zero, .succ _ => return false
|
||||
| .succ _, .zero => return false
|
||||
| .zero, .max v₁ v₂ => isLevelDefEqS .zero v₁ <&&> isLevelDefEqS .zero v₂
|
||||
| .max u₁ u₂, .zero => isLevelDefEqS u₁ .zero <&&> isLevelDefEqS u₂ .zero
|
||||
| .zero, .imax _ v => isLevelDefEqS .zero v
|
||||
| .imax _ u, .zero => isLevelDefEqS u .zero
|
||||
| .max u₁ u₂, .max v₁ v₂ => isLevelDefEqS u₁ v₁ <&&> isLevelDefEqS u₂ v₂
|
||||
| .imax u₁ u₂, .imax v₁ v₂ => isLevelDefEqS u₁ v₁ <&&> isLevelDefEqS u₂ v₂
|
||||
| .mvar mvarId, v => assignLevelMVar mvarId v; return true
|
||||
| u, .mvar mvarId => assignLevelMVar mvarId u; return true
|
||||
| _, _ => return false
|
||||
|
||||
structure DefEqM.Context where
|
||||
unify : Bool := true
|
||||
zetaDelta : Bool := true
|
||||
/--
|
||||
If `unit
|
||||
-/
|
||||
mvarsNew : Array MVarId := #[]
|
||||
|
||||
abbrev DefEqM := ReaderT DefEqM.Context SymM
|
||||
|
||||
/--
|
||||
Structural definitional equality. It is much cheaper than `isDefEq`.
|
||||
-/
|
||||
@[extern "lean_sym_def_eq"] -- Forward definition
|
||||
opaque isDefEqS : Expr → Expr → DefEqM Bool
|
||||
|
||||
/--
|
||||
Structural definitional equality for `forall` and `lambda` binders.
|
||||
-/
|
||||
def isDefEqBindingS (a b : Expr) : DefEqM Bool := do
|
||||
let lctx ← getLCtx
|
||||
let localInsts ← getLocalInstances
|
||||
go lctx localInsts #[] a b #[]
|
||||
where
|
||||
checkDomains (fvars : Array Expr) (ds₂ : Array Expr) : DefEqM Bool := do
|
||||
for fvar in fvars, d in ds₂ do
|
||||
let fvarType ← fvar.fvarId!.getType
|
||||
unless (← isDefEqS fvarType d) do return false
|
||||
return true
|
||||
|
||||
go (lctx : LocalContext) (localInsts : LocalInstances) (fvars : Array Expr) (e₁ e₂ : Expr) (ds₂ : Array Expr) : DefEqM Bool := do
|
||||
match e₁, e₂ with
|
||||
| .forallE n d₁ b₁ _, .forallE _ d₂ b₂ _
|
||||
| .lam n d₁ b₁ _, .lam _ d₂ b₂ _ =>
|
||||
let d₁ ← instantiateRevS d₁ fvars
|
||||
let d₂ ← instantiateRevS d₂ fvars
|
||||
let fvarId ← mkFreshFVarId
|
||||
let fvar ← mkFVarS fvarId
|
||||
let lctx := lctx.mkLocalDecl fvarId n d₁
|
||||
let localInsts := if let some className := isClass? (← getEnv) d₁ then
|
||||
localInsts.push { className, fvar }
|
||||
else
|
||||
localInsts
|
||||
go lctx localInsts (fvars.push fvar) b₁ b₂ (ds₂.push d₂)
|
||||
| _, _ => withLCtx lctx localInsts do
|
||||
unless (← checkDomains fvars ds₂) do return false
|
||||
isDefEqS (← instantiateRevS e₁ fvars) (← instantiateRevS e₂ fvars)
|
||||
|
||||
/--
|
||||
`isDefEqS` implementation.
|
||||
-/
|
||||
@[export lean_sym_def_eq]
|
||||
def isDefEqSImpl (t : Expr) (s : Expr) : DefEqM Bool := do
|
||||
if isSameExpr t s then return true
|
||||
match t, s with
|
||||
| .lit l₁, .lit l₂ => return l₁ == l₂
|
||||
| .sort u, .sort v => isLevelDefEqS u v
|
||||
| .lam .., .lam .. => isDefEqBindingS t s
|
||||
| .forallE .., .forallE .. => isDefEqBindingS t s
|
||||
| _, _ =>
|
||||
-- **TODO**
|
||||
return false
|
||||
|
||||
def noPending : UnifyM Bool := do
|
||||
let s ← get
|
||||
return s.ePending.isEmpty && s.uPending.isEmpty && s.iPending.isEmpty
|
||||
|
|
|
|||
55
src/Lean/Meta/Sym/ProofInstInfo.lean
Normal file
55
src/Lean/Meta/Sym/ProofInstInfo.lean
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SymM
|
||||
import Lean.Meta.Sym.IsClass
|
||||
namespace Lean.Meta.Sym
|
||||
|
||||
/--
|
||||
Preprocesses types that used for pattern matching and unification.
|
||||
-/
|
||||
public def preprocessType (type : Expr) : MetaM Expr := do
|
||||
let type ← Grind.unfoldReducible type
|
||||
let type ← Core.betaReduce type
|
||||
zetaReduce type
|
||||
|
||||
/--
|
||||
Analyzes the type signature of `declName` and returns information about which arguments
|
||||
are proofs or instances. Returns `none` if no arguments are proofs or instances.
|
||||
-/
|
||||
public def mkProofInstInfo? (declName : Name) : MetaM (Option ProofInstInfo) := do
|
||||
let info ← getConstInfo declName
|
||||
let type ← preprocessType info.type
|
||||
forallTelescopeReducing type fun xs _ => do
|
||||
let env ← getEnv
|
||||
let mut argsInfo := #[]
|
||||
let mut found := false
|
||||
for x in xs do
|
||||
let type ← inferType x
|
||||
let isInstance := isClass? env type |>.isSome
|
||||
let isProof ← isProp type
|
||||
if isInstance || isProof then
|
||||
found := true
|
||||
argsInfo := argsInfo.push { isInstance, isProof }
|
||||
if found then
|
||||
return some { argsInfo }
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Returns information about the type signature of `declName`. It contains information about which arguments
|
||||
are proofs or instances. Returns `none` if no arguments are proofs or instances.
|
||||
-/
|
||||
public def getProofInstInfo? (declName : Name) : SymM (Option ProofInstInfo) := do
|
||||
if let some r := (← get).proofInstInfo.find? declName then
|
||||
return r
|
||||
else
|
||||
let r ← mkProofInstInfo? declName
|
||||
modify fun s => { s with proofInstInfo := s.proofInstInfo.insert declName r }
|
||||
return r
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
@ -11,8 +11,33 @@ public section
|
|||
namespace Lean.Meta.Sym
|
||||
export Grind (ExprPtr Goal)
|
||||
|
||||
/--
|
||||
Information about a single argument position in a function's type signature.
|
||||
|
||||
This is used during pattern matching and structural definitional equality tests
|
||||
to identify arguments that can be skipped or handled specially
|
||||
(e.g., instance arguments can be synthesized, proof arguments can be inferred).
|
||||
-/
|
||||
public structure ProofInstArgInfo where
|
||||
/-- `true` if this argument is a proof (its type is a `Prop`). -/
|
||||
isProof : Bool
|
||||
/-- `true` if this argument is a type class instance. -/
|
||||
isInstance : Bool
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Information about a function symbol. It stores which argument positions are proofs or instances,
|
||||
enabling optimizations during pattern matching and structural definitional equality tests
|
||||
such as skipping proof arguments or deferring instance synthesis.
|
||||
-/
|
||||
public structure ProofInstInfo where
|
||||
/-- Information about each argument position. -/
|
||||
argsInfo : Array ProofInstArgInfo
|
||||
deriving Inhabited
|
||||
|
||||
structure State where
|
||||
maxFVar : PHashMap ExprPtr (Option FVarId) := {}
|
||||
proofInstInfo : PHashMap Name (Option ProofInstInfo) := {}
|
||||
|
||||
abbrev SymM := ReaderT Grind.Params StateRefT State Grind.GrindM
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue