lean4-htt/src/Lean/Meta/Sym/Pattern.lean
Leonardo de Moura 48bb954e4e
feat: structural isDefEq for Sym (#11819)
This PR adds some basic infrastructure for a structural (and cheaper)
`isDefEq` predicate for pattern matching and unification in `Sym`.
2025-12-28 22:37:21 +00:00

402 lines
14 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.SymM
import Lean.Util.FoldConsts
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.IsClass
import Lean.Meta.Sym.ProofInstInfo
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
namespace Lean.Meta.Sym
open Grind
/-!
This module implements efficient pattern matching and unification for the symbolic simulation
framework (`Sym`). The design prioritizes performance by using a two-phase approach:
# Phase 1 (Syntactic Matching)
- Patterns use de Bruijn indices for expression variables and renamed level params (`_uvar.0`, `_uvar.1`, ...) for universe variables
- Matching is purely structural after reducible definitions are unfolded during preprocessing
- Universe levels treat `max` and `imax` as uninterpreted functions (no AC reasoning)
- Binders and term metavariables are deferred to Phase 2
# Phase 2 (Pending Constraints) [WIP]
- Handles binders (Miller patterns) and metavariable unification
- Converts remaining de Bruijn variables to metavariables
- Falls back to `isDefEq` when necessary
# Key design decisions:
- Preprocessing unfolds reducible definitions and performs beta/zeta reduction
- Kernel projections are expected to be folded as projection applications before matching
- Assignment conflicts are deferred to pending rather than invoking `isDefEq` inline
- `instantiateRevS` ensures maximal sharing of result expressions
-/
/--
Builds the table of `ProofInstInfo` for the function symbols (constants) used in `pattern`.
A constant gets an entry only when `mkProofInstInfo?` produces one for it, i.e., only
functions that have at least one proof or instance argument are included.
-/
def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo) := do
  let mut result : AssocList Name ProofInstInfo := {}
  for constName in pattern.getUsedConstants do
    match (← mkProofInstInfo? constName) with
    | some info => result := result.insertNew constName info
    | none => pure ()
  return result
/--
A preprocessed pattern used for matching and unification (see `mkPatternFromTheorem`).
-/
public structure Pattern where
  /-- Universe variables of the pattern. `mkPatternFromTheorem` names them `_uvar.0`, `_uvar.1`, ... -/
  levelParams : List Name
  /--
  Types of the pattern variables, outermost binder first. They are stored uninstantiated, so
  entry `i` may refer to earlier variables through de Bruijn indices.
  -/
  varTypes : Array Expr
  /-- The term to match against; pattern variables occur in it as loose de Bruijn indices. -/
  pattern : Expr
  /-- `ProofInstInfo` for function symbols occurring in `pattern` (see `mkProofInstInfoMapFor`). -/
  fnInfos : AssocList Name ProofInstInfo
  deriving Inhabited
/-- Name prefix of pattern universe variables: the `i`-th one is `Name.num uvarPrefix i`. -/
def uvarPrefix : Name := `_uvar
/--
If `n` has the form `Name.num uvarPrefix idx`, returns `some idx`; otherwise returns `none`.
-/
def isUVar? (n : Name) : Option Nat :=
  match n with
  | .num pre idx => if pre == uvarPrefix then some idx else none
  | _ => none
/--
Creates a `Pattern` from the type of the declaration `declName`.

The declaration's universe parameters are replaced with `_uvar.i` level parameters, the type is
preprocessed with `preprocessType`, and the leading `∀` binders are stripped: their domains
become `varTypes` and the remaining body becomes `pattern`.
-/
public def mkPatternFromTheorem (declName : Name) : MetaM Pattern := do
  let info ← getConstInfo declName
  let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i
  let us := levelParams.map mkLevelParam
  let type ← instantiateTypeLevelParams info.toConstantVal us
  let type ← preprocessType type
  -- **TODO**: save position of instance arguments
  -- Peel off the leading `∀` telescope, collecting the (uninstantiated) binder domains.
  let mut varTypes : Array Expr := #[]
  let mut body := type
  repeat
    let .forallE _ d b _ := body | break
    varTypes := varTypes.push d
    body := b
  let fnInfos ← mkProofInstInfoMapFor body
  return { levelParams, varTypes, pattern := body, fnInfos }
/-- Read-only context of the `UnifyM` monad. -/
structure UnifyM.Context where
  /-- The pattern being matched/unified against. -/
  pattern : Pattern
  /--
  `true` when unifying (`Pattern.unify?`), `false` when matching (`Pattern.match?`).
  When `false`, `checkMVar` and `processLevel` never defer metavariable constraints.
  -/
  unify : Bool := true
/-- Mutable state of the `UnifyM` monad. -/
structure UnifyM.State where
  /--
  Assignment for the pattern variables, one slot per entry of `Pattern.varTypes`.
  The de Bruijn index `bidx` is stored at position `eAssignment.size - bidx - 1`.
  -/
  eAssignment : Array (Option Expr) := #[]
  /-- Assignment for the pattern universe variables, indexed by the `_uvar` index. -/
  uAssignment : Array (Option Level) := #[]
  /-- Deferred expression constraints (see `pushPending`). -/
  ePending : Array (Expr × Expr) := #[]
  /-- Deferred universe level constraints (see `pushLevelPending`). -/
  uPending : Array (Level × Level) := #[]
  /-- Deferred constraints on instance arguments (see `pushInstPending`). -/
  iPending : Array (Expr × Expr) := #[]
  /-- Final universe levels, populated by `mkPreResult`. -/
  us : List Level := []
  /-- Final values for the pattern variables, populated by `mkPreResult`. -/
  args : Array Expr := #[]
/-- Monad for pattern matching/unification: `SymM` extended with `UnifyM.Context` and `UnifyM.State`. -/
abbrev UnifyM := ReaderT UnifyM.Context StateRefT UnifyM.State SymM
/-- Defers the expression constraint between `p` and `e` by appending it to `ePending`. -/
def pushPending (p : Expr) (e : Expr) : UnifyM Unit :=
  modify fun st => { st with ePending := st.ePending.push ⟨p, e⟩ }
/-- Defers the universe level constraint between `u` and `v` by appending it to `uPending`. -/
def pushLevelPending (u : Level) (v : Level) : UnifyM Unit :=
  modify fun st => { st with uPending := st.uPending.push ⟨u, v⟩ }
/-- Defers the instance-argument constraint between `p` and `e` by appending it to `iPending`. -/
def pushInstPending (p : Expr) (e : Expr) : UnifyM Unit :=
  modify fun st => { st with iPending := st.iPending.push ⟨p, e⟩ }
/--
Assigns `e` to the pattern variable with de Bruijn index `bidx`, unless that variable already
has a value. An existing value is kept as is and is not compared with `e`.
-/
def assignExprIfUnassigned (bidx : Nat) (e : Expr) : UnifyM Unit := do
  let assignment := (← get).eAssignment
  -- De Bruijn indices count from the innermost binder; slots are stored outermost first.
  let pos := assignment.size - bidx - 1
  unless assignment[pos]!.isSome do
    modify fun st => { st with eAssignment := st.eAssignment.set! pos (some e) }
/--
Assigns `e` to the pattern variable with de Bruijn index `bidx`.

If the variable already has a value that is not pointer-equal to `e`, the conflict is not
resolved here: the pair (previous value, `e`) is deferred with `pushPending`.
This function always returns `true`.
-/
def assignExpr (bidx : Nat) (e : Expr) : UnifyM Bool := do
  let assignment := (← get).eAssignment
  -- De Bruijn indices count from the innermost binder; slots are stored outermost first.
  let pos := assignment.size - bidx - 1
  match assignment[pos]! with
  | none =>
    modify fun st => { st with eAssignment := st.eAssignment.set! pos (some e) }
  | some prev =>
    unless isSameExpr e prev do
      pushPending prev e
  return true
/--
Assigns `u` to the pattern universe variable with index `uidx`.
If the variable already has a value `u'`, the result is `isLevelDefEq u u'` instead.
-/
def assignLevel (uidx : Nat) (u : Level) : UnifyM Bool := do
  match (← get).uAssignment[uidx]! with
  | some prev => isLevelDefEq u prev
  | none =>
    modify fun st => { st with uAssignment := st.uAssignment.set! uidx (some u) }
    return true
/--
Fallback used after a structural mismatch between `p` and `e`.
In unification mode, if the head of `e` is a metavariable, the constraint is deferred with
`pushPending` and `true` is returned. Otherwise, returns `false`.
-/
def checkMVar (p : Expr) (e : Expr) : UnifyM Bool := do
  unless (← read).unify && e.getAppFn.isMVar do
    return false
  pushPending p e
  return true
/--
Structurally matches the pattern universe level `u` against the level `v`.

`max` and `imax` are compared argument by argument (no AC reasoning), except against `zero`.
Pattern universe variables (`_uvar.i` parameters) are assigned with `assignLevel`.
Constraints involving level metavariables are deferred with `pushLevelPending` in unification
mode, and fail in matching mode.

The order of the alternatives matters: the structural cases are tried before the
parameter and metavariable cases.
-/
def processLevel (u : Level) (v : Level) : UnifyM Bool := do
  match u, v with
  | .zero, .zero => return true
  | .succ u, .succ v => processLevel u v
  | .zero, .succ _ => return false
  | .succ _, .zero => return false
  -- `max a b` is `zero` only if both arguments are.
  | .zero, .max v₁ v₂ => processLevel .zero v₁ <&&> processLevel .zero v₂
  | .max u₁ u₂, .zero => processLevel u₁ .zero <&&> processLevel u₂ .zero
  -- `imax a b` is `zero` exactly when `b` is; the first argument is ignored.
  | .zero, .imax _ v => processLevel .zero v
  | .imax _ u, .zero => processLevel u .zero
  | .max u₁ u₂, .max v₁ v₂ => processLevel u₁ v₁ <&&> processLevel u₂ v₂
  | .imax u₁ u₂, .imax v₁ v₂ => processLevel u₁ v₁ <&&> processLevel u₂ v₂
  | .param uName, _ =>
    if let some uidx := isUVar? uName then
      -- Pattern universe variable.
      assignLevel uidx v
    else if u == v then
      return true
    else if v.isMVar && (← read).unify then
      pushLevelPending u v
      return true
    else
      return false
  | .mvar _, _ | _, .mvar _ =>
    if (← read).unify then
      pushLevelPending u v
      return true
    else
      return false
  | _, _ => return false
/--
Applies `processLevel` pairwise to `us` and `vs`, stopping at the first failure.
Returns `false` if the two lists have different lengths.
-/
def processLevels (us : List Level) (vs : List Level) : UnifyM Bool := do
  match us, vs with
  | [], [] => return true
  | u :: us', v :: vs' => processLevel u v <&&> processLevels us' vs'
  | _, _ => return false
/--
Phase 1 (syntactic) matching of the pattern term `p` against `e`.

- A pattern variable (`.bvar`) is assigned with `assignExpr`.
- Constants, sorts, and applications are compared structurally; on failure `checkMVar`
  gets a chance to defer the constraint (unification mode only).
- Binders (`.forallE`/`.lam`) are always deferred with `pushPending`.
- Kernel projections are not supported in patterns: an issue is reported and matching fails.
- `.letE` is not expected in a pattern and is treated as unreachable.

Returns `false` if `p` and `e` cannot match.
-/
partial def process (p : Expr) (e : Expr) : UnifyM Bool := do
  match p with
  | .bvar bidx => assignExpr bidx e
  | .mdata _ p => process p e
  | .const declName us =>
    processConst declName us e <||> checkMVar p e
  | .sort u =>
    processSort u e <||> checkMVar p e
  | .app .. =>
    processApp p e <||> checkMVar p e
  | .forallE .. | .lam .. =>
    pushPending p e
    return true
  | .proj .. =>
    -- The offending kernel projection is the pattern `p` (this branch matched on it), so report `p`.
    reportIssue! "unexpected kernel projection term during unification/matching{indentExpr p}\npre-process and fold them as projection applications"
    return false
  | .mvar _ | .fvar _ | .lit _ =>
    pure (p == e) <||> checkMVar p e
  | .letE .. => unreachable!
where
  -- Matches the application `p` against `e`. Uses the `ProofInstInfo` of the head symbol when
  -- the head is a constant with an entry in `fnInfos`; otherwise processes all arguments uniformly.
  processApp (p : Expr) (e : Expr) : UnifyM Bool := do
    let f := p.getAppFn
    let .const declName _ := f | processAppDefault p e
    let some info := (← read).pattern.fnInfos.find? declName | processAppDefault p e
    let numArgs := p.getAppNumArgs
    processAppWithInfo p e (numArgs - 1) info
  -- `i` is the position of the last argument of `p`. Functions are processed before arguments.
  processAppWithInfo (p : Expr) (e : Expr) (i : Nat) (info : ProofInstInfo) : UnifyM Bool := do
    let .app fp ap := p | process p e
    let .app fe ae := e | return false
    unless (← processAppWithInfo fp fe (i - 1) info) do return false
    if h : i < info.argsInfo.size then
      let argInfo := info.argsInfo[i]
      if argInfo.isInstance then
        if let .bvar bidx := ap then
          -- Instance pattern variable: take the value from `e` unless already assigned.
          assignExprIfUnassigned bidx ae
        else
          pushInstPending ap ae
        return true
      else if argInfo.isProof then
        if let .bvar bidx := ap then
          -- Proof pattern variable: take the value from `e` without comparing.
          assignExprIfUnassigned bidx ae
          return true
        else
          process ap ae
      else
        process ap ae
    else
      process ap ae
  -- Matches function and arguments pairwise, without proof/instance information.
  processAppDefault (p : Expr) (e : Expr) : UnifyM Bool := do
    let .app fp ap := p | process p e
    let .app fe ae := e | return false
    unless (← processAppDefault fp fe) do return false
    process ap ae
  -- `e` must be the same constant; universe levels are matched with `processLevels`.
  processConst (declName : Name) (us : List Level) (e : Expr) : UnifyM Bool := do
    let .const declName' us' := e | return false
    unless declName == declName' do return false
    processLevels us us'
  -- `e` must be a sort; its level is matched with `processLevel`.
  processSort (u : Level) (e : Expr) : UnifyM Bool := do
    let .sort v := e | return false
    processLevel u v
/--
Structural definitional equality test for universe levels.

`max` and `imax` are compared argument by argument, except against `zero`.
A level metavariable on either side is assigned to the other side.

NOTE(review): the metavariable cases do not check whether the metavariable is already
assigned and perform no occurs check (e.g., `?u =?= succ ?u`); only the trivial
`?u =?= ?u` case is guarded — confirm whether callers guarantee the rest.
-/
def isLevelDefEqS (u : Level) (v : Level) : MetaM Bool := do
  match u, v with
  | .param u, .param v => return u == v
  | .zero, .zero => return true
  | .succ u, .succ v => isLevelDefEqS u v
  | .zero, .succ _ => return false
  | .succ _, .zero => return false
  -- `max a b` is `zero` only if both arguments are.
  | .zero, .max v₁ v₂ => isLevelDefEqS .zero v₁ <&&> isLevelDefEqS .zero v₂
  | .max u₁ u₂, .zero => isLevelDefEqS u₁ .zero <&&> isLevelDefEqS u₂ .zero
  -- `imax a b` is `zero` exactly when `b` is; the first argument is ignored.
  | .zero, .imax _ v => isLevelDefEqS .zero v
  | .imax _ u, .zero => isLevelDefEqS u .zero
  | .max u₁ u₂, .max v₁ v₂ => isLevelDefEqS u₁ v₁ <&&> isLevelDefEqS u₂ v₂
  | .imax u₁ u₂, .imax v₁ v₂ => isLevelDefEqS u₁ v₁ <&&> isLevelDefEqS u₂ v₂
  | .mvar mvarId, v =>
    -- `?u =?= ?u` holds trivially; assigning `?u := ?u` would create a cyclic assignment.
    unless v == Level.mvar mvarId do
      assignLevelMVar mvarId v
    return true
  | u, .mvar mvarId =>
    -- `u` is not a metavariable here (previous alternative), so it differs from `?mvarId`.
    assignLevelMVar mvarId u
    return true
  | _, _ => return false
/--
Read-only context of the `DefEqM` monad.
NOTE(review): none of the definitions visible in this module read these fields yet; the
descriptions below are inferred from the field names and must be confirmed.
-/
structure DefEqM.Context where
  /-- Presumably whether metavariables may be assigned (unification) or not (matching) — TODO confirm. -/
  unify : Bool := true
  /-- Presumably whether let-variables may be unfolded (zeta-delta reduction) — TODO confirm. -/
  zetaDelta : Bool := true
  /--
  NOTE(review): the original comment for this field was cut off after "If `unit"
  (probably "If `unify` ..."). Intended meaning to be confirmed by the author.
  -/
  mvarsNew : Array MVarId := #[]
/-- Monad for the structural definitional equality test: `SymM` extended with `DefEqM.Context`. -/
abbrev DefEqM := ReaderT DefEqM.Context SymM
/--
Structural definitional equality. It is much cheaper than `isDefEq`.

This is a forward declaration so that helpers such as `isDefEqBindingS` can call it before
its definition. The implementation is `isDefEqSImpl`, which is exported under the same
symbol name (`lean_sym_def_eq`).
-/
@[extern "lean_sym_def_eq"] -- Forward definition
opaque isDefEqS : Expr → Expr → DefEqM Bool
/--
Structural definitional equality for `forall` and `lambda` binders.

Both telescopes are traversed simultaneously while the binder kinds agree. One fresh free
variable is created per binder pair, using the name and domain of the binder in `a`.
Once the telescopes end, the domains collected from `b` are compared with the types of the
free variables, and then the instantiated bodies are compared, all using `isDefEqS`.
-/
def isDefEqBindingS (a b : Expr) : DefEqM Bool := do
  let lctx ← getLCtx
  let localInsts ← getLocalInstances
  go lctx localInsts #[] a b #[]
where
  -- Checks that the type of each free variable (a domain from `a`) matches the corresponding
  -- domain from `b`.
  checkDomains (fvars : Array Expr) (ds₂ : Array Expr) : DefEqM Bool := do
    for fvar in fvars, d in ds₂ do
      let fvarType ← fvar.fvarId!.getType
      unless (← isDefEqS fvarType d) do return false
    return true
  -- `lctx` and `localInsts` are extended locally and installed only when the telescope ends.
  -- `ds₂` accumulates the (instantiated) binder domains of the right-hand side.
  go (lctx : LocalContext) (localInsts : LocalInstances) (fvars : Array Expr) (e₁ e₂ : Expr) (ds₂ : Array Expr) : DefEqM Bool := do
    match e₁, e₂ with
    | .forallE n d₁ b₁ _, .forallE _ d₂ b₂ _
    | .lam n d₁ b₁ _, .lam _ d₂ b₂ _ =>
      let d₁ ← instantiateRevS d₁ fvars
      let d₂ ← instantiateRevS d₂ fvars
      let fvarId ← mkFreshFVarId
      let fvar ← mkFVarS fvarId
      let lctx := lctx.mkLocalDecl fvarId n d₁
      -- Register the new variable as a local instance if its type is a class.
      let localInsts := if let some className := isClass? (← getEnv) d₁ then
        localInsts.push { className, fvar }
      else
        localInsts
      go lctx localInsts (fvars.push fvar) b₁ b₂ (ds₂.push d₂)
    | _, _ => withLCtx lctx localInsts do
      unless (← checkDomains fvars ds₂) do return false
      isDefEqS (← instantiateRevS e₁ fvars) (← instantiateRevS e₂ fvars)
/--
Implementation of `isDefEqS`, exported as `lean_sym_def_eq`.

Pointer-equal terms are equal. Otherwise, literals are compared with `==`, sorts with
`isLevelDefEqS`, and binders with `isDefEqBindingS`. Every other combination is
currently rejected.
-/
@[export lean_sym_def_eq]
def isDefEqSImpl (t : Expr) (s : Expr) : DefEqM Bool := do
  if isSameExpr t s then
    return true
  match t, s with
  | .lit l₁, .lit l₂ => return l₁ == l₂
  | .sort u, .sort v => isLevelDefEqS u v
  | .lam .., .lam .. | .forallE .., .forallE .. => isDefEqBindingS t s
  | _, _ =>
    -- **TODO**
    return false
/-- Returns `true` if no expression, level, or instance constraint has been deferred. -/
def noPending : UnifyM Bool := do
  let st ← get
  let exprsDone := st.ePending.isEmpty
  let levelsDone := st.uPending.isEmpty
  let instsDone := st.iPending.isEmpty
  return exprsDone && levelsDone && instsDone
/--
Computes the `us` and `args` fields of the state from the current assignments.

Unassigned universe variables become fresh level metavariables. Unassigned pattern variables
become fresh synthetic opaque metavariables whose type is the variable's type, instantiated
with the chosen universe levels and with the values of the preceding variables.
-/
def mkPreResult : UnifyM Unit := do
  let us ← (← get).uAssignment.toList.mapM fun u? =>
    match u? with
    | some u => pure u
    | none => mkFreshLevelMVar
  let pat := (← read).pattern
  let assigned := (← get).eAssignment
  let mut args := #[]
  for h : i in *...assigned.size do
    match assigned[i] with
    | some val => args := args.push val
    | none =>
      let varType := pat.varTypes[i]!
      let varType := varType.instantiateLevelParams pat.levelParams us
      let varType ← shareCommon varType
      -- The type may mention earlier pattern variables; `args` holds their values so far.
      let varType ← instantiateRevBetaS varType args
      let mvar ← mkFreshExprSyntheticOpaqueMVar varType
      let mvar ← shareCommon mvar
      args := args.push mvar
  modify fun st => { st with args, us }
/--
Phase 2 entry point. Not implemented yet: it returns `true` when nothing was deferred,
and throws an error otherwise.
-/
def processPending : UnifyM Bool := do
  unless (← noPending) do
    throwError "NIY: pending constraints"
  return true
/--
Runs `k` for `pattern`, starting with every pattern variable and universe variable unassigned.
-/
abbrev run (pattern : Pattern) (unify : Bool) (k : UnifyM α) : SymM α :=
  let eInit : Array (Option Expr) := pattern.varTypes.map fun _ => none
  let uInit : Array (Option Level) := pattern.levelParams.toArray.map fun _ => none
  (k { pattern, unify }).run' { eAssignment := eInit, uAssignment := uInit }
/-- Result of `Pattern.match?` and `Pattern.unify?`. -/
public structure MatchUnifyResult where
  /-- Values for the pattern's universe variables, in the order of `Pattern.levelParams`. -/
  us : List Level
  /-- Values for the pattern variables, in the order of `Pattern.varTypes`. -/
  args : Array Expr
/-- Packages the `us` and `args` fields of the current state as a `MatchUnifyResult`. -/
def mkResult : UnifyM MatchUnifyResult := do
  let st ← get
  return { us := st.us, args := st.args }
/--
Shared driver for `Pattern.match?` and `Pattern.unify?`: runs Phase 1 (`process`), builds the
preliminary result (`mkPreResult`), and then handles deferred constraints (`processPending`).
Returns `none` if any step fails.
-/
def main (p : Pattern) (e : Expr) (unify : Bool) : SymM (Option (MatchUnifyResult)) :=
  run p unify do
    if (← process p.pattern e) then
      mkPreResult
      -- **TODO** synthesize instance arguments
      if (← processPending) then
        let result ← mkResult
        return some result
      else
        return none
    else
      return none
/--
Attempts to match expression `e` against pattern `p` using purely syntactic matching.

Returns `some result` if matching succeeds, where `result` contains:
- `us`: Level assignments for the pattern's universe variables
- `args`: Expression assignments for the pattern's bound variables

Variables left unassigned by matching are filled with fresh metavariables.

Returns `none` on a structural mismatch. In this mode metavariables in `e` are never
deferred: a metavariable-headed subterm of `e` only matches a pattern variable or a
syntactically equal pattern subterm (use `unify?` otherwise).

Instance and proof arguments that are pattern variables are assigned without being compared.
Other instance arguments are deferred for later synthesis.

Phase 2 is not implemented yet: if any constraint was deferred (e.g., binders in the pattern,
conflicting assignments, or non-variable instance arguments), an error is thrown.
-/
public def Pattern.match? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
  main p e (unify := false)
/--
Attempts to unify expression `e` against pattern `p`, allowing metavariables in `e`.

Returns `some result` if unification succeeds, where `result` contains:
- `us`: Level assignments for the pattern's universe variables
- `args`: Expression assignments for the pattern's bound variables

Variables left unassigned by unification are filled with fresh metavariables.

Unlike `match?`, this handles terms containing metavariables by deferring
constraints to Phase 2 unification. Use this when matching against goal
expressions that may contain unsolved metavariables.

Instance and proof arguments that are pattern variables are assigned without being compared.
Other instance arguments are deferred for later synthesis.

Phase 2 is not implemented yet: if any constraint was deferred, an error is thrown.
-/
public def Pattern.unify? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
  main p e (unify := true)
end Lean.Meta.Sym