feat: efficient pattern matching and unification for the symbolic simulation framework (#11825)
This PR completes the new pattern matching and unification procedures for the symbolic simulation framework using a two-phase approach. **Phase 1 (Syntactic Matching):** - Patterns use de Bruijn indices for expression variables and renamed level params for universe variables - Purely structural matching after reducible definitions are unfolded - Universe levels treat `max`/`imax` as uninterpreted functions - Proof arguments skipped via proof irrelevance - Instance and binder constraints deferred to Phase 2 **Phase 2 (Pending Constraints):** - Level constraints: structural equality with mvar assignment - Instance constraints: `isDefEqI` (full `isDefEq` for TC synthesis) - Expression constraints: `isDefEqS` with Miller pattern support - Unassigned instance pattern variables synthesized via `trySynthInstance` **`isDefEqS` (Structural DefEq):** - Miller pattern detection and assignment (`?m x y z := rhs` → `?m := fun x y z => rhs`) - Scope checking via `maxFVar` to prevent out-of-scope assignments - Optional zeta-delta reduction for let-declarations - Proof irrelevance and instance delegation to `isDefEqI` **Key optimizations:** - `abstractFVars` skips metavariables and uses `maxFVar` for early cutoff - Per-pattern `ProofInstInfo` cache for fast argument classification - Maximal sharing.
This commit is contained in:
parent
5042c8cc37
commit
2bca310bea
3 changed files with 123 additions and 40 deletions
|
|
@ -27,10 +27,10 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro
|
|||
- Universe levels treat `max` and `imax` as uninterpreted functions (no AC reasoning)
|
||||
- Binders and term metavariables are deferred to Phase 2
|
||||
|
||||
# Phase 2 (Pending Constraints) [WIP]
|
||||
# Phase 2 (Pending Constraints)
|
||||
- Handles binders (Miller patterns) and metavariable unification
|
||||
- Converts remaining de Bruijn variables to metavariables
|
||||
- Falls back to structural `isDefEq` (aka `isDefEqS`) when necessary.
|
||||
- Falls back to structural `isDefEqS` when necessary.
|
||||
- It still uses the standard `isDefEq` for instances.
|
||||
|
||||
# Key design decisions:
|
||||
|
|
@ -56,6 +56,7 @@ def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo
|
|||
public structure Pattern where
|
||||
levelParams : List Name
|
||||
varTypes : Array Expr
|
||||
isInstance : Array Bool
|
||||
pattern : Expr
|
||||
fnInfos : AssocList Name ProofInstInfo
|
||||
deriving Inhabited
|
||||
|
|
@ -73,19 +74,20 @@ public def mkPatternFromTheorem (declName : Name) : MetaM Pattern := do
|
|||
let us := levelParams.map mkLevelParam
|
||||
let type ← instantiateTypeLevelParams info.toConstantVal us
|
||||
let type ← preprocessType type
|
||||
-- **TODO**: save position of instance arguments
|
||||
let rec go (type : Expr) (varTypes : Array Expr) : MetaM Pattern := do
|
||||
let rec go (type : Expr) (varTypes : Array Expr) (isInstance : Array Bool) : MetaM Pattern := do
|
||||
match type with
|
||||
| .forallE _ d b _ => go b (varTypes.push d)
|
||||
| .forallE _ d b _ =>
|
||||
go b (varTypes.push d) (isInstance.push (isClass? (← getEnv) d).isSome)
|
||||
| _ =>
|
||||
let pattern := type
|
||||
let fnInfos ← mkProofInstInfoMapFor pattern
|
||||
return { levelParams, varTypes, pattern, fnInfos }
|
||||
go type #[]
|
||||
return { levelParams, varTypes, isInstance, pattern, fnInfos }
|
||||
go type #[] #[]
|
||||
|
||||
structure UnifyM.Context where
|
||||
pattern : Pattern
|
||||
unify : Bool := true
|
||||
pattern : Pattern
|
||||
unify : Bool := true
|
||||
zetaDelta : Bool := true
|
||||
|
||||
structure UnifyM.State where
|
||||
eAssignment : Array (Option Expr) := #[]
|
||||
|
|
@ -586,25 +588,34 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do
|
|||
else
|
||||
isDefEqApp tFn t s rfl
|
||||
|
||||
abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) (x : DefEqM α) : SymM α := do
|
||||
let lctx ← getLCtx
|
||||
let lctxInitialNextIndex := lctx.decls.size
|
||||
x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew }
|
||||
|
||||
/--
|
||||
A lightweight structural definitional equality for the symbolic simulation framework.
|
||||
Unlike the full `isDefEq`, it avoids expensive operations while still supporting Miller pattern unification.
|
||||
-/
|
||||
public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) : SymM Bool := do
|
||||
let lctx ← getLCtx
|
||||
let lctxInitialNextIndex := lctx.decls.size
|
||||
isDefEqMain t s { zetaDelta, lctxInitialNextIndex, unify, mvarsNew }
|
||||
DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) do
|
||||
isDefEqMain t s
|
||||
|
||||
def noPending : UnifyM Bool := do
|
||||
let s ← get
|
||||
return s.ePending.isEmpty && s.uPending.isEmpty && s.iPending.isEmpty
|
||||
|
||||
def instantiateLevelParamsS (e : Expr) (paramNames : List Name) (us : List Level) : SymM Expr :=
|
||||
-- We do not assume `e` is maximally shared
|
||||
shareCommon (e.instantiateLevelParams paramNames us)
|
||||
|
||||
def mkPreResult : UnifyM Unit := do
|
||||
let us ← (← get).uAssignment.toList.mapM fun
|
||||
| some val => pure val
|
||||
| none => mkFreshLevelMVar
|
||||
let pattern := (← read).pattern
|
||||
let varTypes := pattern.varTypes
|
||||
let isInstance := pattern.isInstance
|
||||
let eAssignment := (← get).eAssignment
|
||||
let mut args := #[]
|
||||
for h : i in *...eAssignment.size do
|
||||
|
|
@ -612,23 +623,70 @@ def mkPreResult : UnifyM Unit := do
|
|||
args := args.push val
|
||||
else
|
||||
let type := varTypes[i]!
|
||||
let type := type.instantiateLevelParams pattern.levelParams us
|
||||
let type ← shareCommon type
|
||||
let type ← instantiateLevelParamsS type pattern.levelParams us
|
||||
let type ← instantiateRevBetaS type args
|
||||
let mvar ← mkFreshExprSyntheticOpaqueMVar type
|
||||
if isInstance[i]! then
|
||||
if let .some val ← trySynthInstance type then
|
||||
args := args.push (← shareCommon val)
|
||||
continue
|
||||
let mvar ← mkFreshExprMVar type
|
||||
let mvar ← shareCommon mvar
|
||||
args := args.push mvar
|
||||
modify fun s => { s with args, us }
|
||||
|
||||
def processPendingLevel : UnifyM Bool := do
|
||||
let uPending := (← get).uPending
|
||||
if uPending.isEmpty then return true
|
||||
let pattern := (← read).pattern
|
||||
let us := (← get).us
|
||||
for (u, v) in uPending do
|
||||
let u := u.instantiateParams pattern.levelParams us
|
||||
unless (← isLevelDefEqS u v) do
|
||||
return false
|
||||
return true
|
||||
|
||||
def processPendingInst : UnifyM Bool := do
|
||||
let iPending := (← get).iPending
|
||||
if iPending.isEmpty then return true
|
||||
let pattern := (← read).pattern
|
||||
let us := (← get).us
|
||||
let args := (← get).args
|
||||
for (t, s) in iPending do
|
||||
let t ← instantiateLevelParamsS t pattern.levelParams us
|
||||
let t ← instantiateRevBetaS t args
|
||||
unless (← isDefEqI t s) do
|
||||
return false
|
||||
return true
|
||||
|
||||
def processPendingExpr : UnifyM Bool := do
|
||||
let ePending := (← get).ePending
|
||||
if ePending.isEmpty then return true
|
||||
let pattern := (← read).pattern
|
||||
let us := (← get).us
|
||||
let args := (← get).args
|
||||
let unify := (← read).unify
|
||||
let zetaDelta := (← read).zetaDelta
|
||||
let mvarsNew := if unify then #[] else args.filterMap fun
|
||||
| .mvar mvarId => some mvarId
|
||||
| _ => none
|
||||
DefEqM.run unify zetaDelta mvarsNew do
|
||||
for (t, s) in ePending do
|
||||
let t ← instantiateLevelParamsS t pattern.levelParams us
|
||||
let t ← instantiateRevBetaS t args
|
||||
unless (← isDefEqMain t s) do
|
||||
return false
|
||||
return true
|
||||
|
||||
def processPending : UnifyM Bool := do
|
||||
if (← noPending) then
|
||||
return true
|
||||
throwError "NIY: pending constraints"
|
||||
else
|
||||
processPendingLevel <&&> processPendingInst <&&> processPendingExpr
|
||||
|
||||
abbrev run (pattern : Pattern) (unify : Bool) (k : UnifyM α) : SymM α := do
|
||||
abbrev UnifyM.run (pattern : Pattern) (unify : Bool) (zetaDelta : Bool) (k : UnifyM α) : SymM α := do
|
||||
let eAssignment := pattern.varTypes.map fun _ => none
|
||||
let uAssignment := pattern.levelParams.toArray.map fun _ => none
|
||||
k { unify, pattern } |>.run' { eAssignment, uAssignment }
|
||||
k { unify, zetaDelta, pattern } |>.run' { eAssignment, uAssignment }
|
||||
|
||||
public structure MatchUnifyResult where
|
||||
us : List Level
|
||||
|
|
@ -638,11 +696,10 @@ def mkResult : UnifyM MatchUnifyResult := do
|
|||
let s ← get
|
||||
return { s with }
|
||||
|
||||
def main (p : Pattern) (e : Expr) (unify : Bool) : SymM (Option (MatchUnifyResult)) :=
|
||||
run p unify do
|
||||
def main (p : Pattern) (e : Expr) (unify : Bool) (zetaDelta : Bool) : SymM (Option (MatchUnifyResult)) :=
|
||||
UnifyM.run p unify zetaDelta do
|
||||
unless (← process p.pattern e) do return none
|
||||
mkPreResult
|
||||
-- **TODO** synthesize instance arguments
|
||||
unless (← processPending) do return none
|
||||
return some (← mkResult)
|
||||
|
||||
|
|
@ -660,8 +717,8 @@ Matching fails if:
|
|||
Instance arguments are deferred for later synthesis. Proof arguments are
|
||||
skipped via proof irrelevance.
|
||||
-/
|
||||
public def Pattern.match? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
|
||||
main p e (unify := false)
|
||||
public def Pattern.match? (p : Pattern) (e : Expr) (zetaDelta := true) : SymM (Option (MatchUnifyResult)) :=
|
||||
main p e (unify := false) (zetaDelta := zetaDelta)
|
||||
|
||||
/--
|
||||
Attempts to unify expression `e` against pattern `p`, allowing metavariables in `e`.
|
||||
|
|
@ -677,7 +734,7 @@ expressions that may contain unsolved metavariables.
|
|||
Instance arguments are deferred for later synthesis. Proof arguments are
|
||||
skipped via proof irrelevance.
|
||||
-/
|
||||
public def Pattern.unify? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
|
||||
main p e (unify := true)
|
||||
public def Pattern.unify? (p : Pattern) (e : Expr) (zetaDelta := true) : SymM (Option (MatchUnifyResult)) :=
|
||||
main p e (unify := true) (zetaDelta := zetaDelta)
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -1,28 +1,29 @@
|
|||
import Lean.Meta.Sym
|
||||
open Lean Meta Sym
|
||||
|
||||
open Lean Meta Sym Grind
|
||||
set_option grind.debug true
|
||||
opaque p : Nat → Prop
|
||||
opaque q : Nat → Nat → Prop
|
||||
|
||||
def ex := ∃ x : Nat, p x ∧ q x .zero
|
||||
def ex := ∃ x : Nat, p x ∧ x = .zero
|
||||
|
||||
def test : SymM Unit := do
|
||||
let p ← mkPatternFromTheorem ``Exists.intro
|
||||
let e := (← getConstInfo ``ex).value!
|
||||
let some r ← p.match? e | throwError "failed"
|
||||
let app := mkAppN (mkConst ``Exists.intro r.us) r.args
|
||||
logInfo app
|
||||
for arg in r.args do
|
||||
if arg.isMVar then
|
||||
logInfo m!"{arg} : {← inferType arg}"
|
||||
return ()
|
||||
let pEx ← mkPatternFromTheorem ``Exists.intro
|
||||
let pAnd ← mkPatternFromTheorem ``And.intro
|
||||
let pEq ← mkPatternFromTheorem ``Eq.refl
|
||||
let e ← shareCommon (← getConstInfo ``ex).value!
|
||||
let some r₁ ← pEx.match? e | throwError "failed"
|
||||
logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args
|
||||
let some r₂ ← pAnd.match? (← inferType r₁.args[3]!) | throwError "failed"
|
||||
logInfo <| mkAppN (mkConst ``And.intro r₂.us) r₂.args
|
||||
let some r₃ ← pEq.unify? (← inferType r₂.args[3]!) | throwError "failed"
|
||||
logInfo <| mkAppN (mkConst ``Eq.refl r₃.us) r₃.args
|
||||
|
||||
/--
|
||||
info: @Exists.intro Nat (fun x => And (p x) (q x Nat.zero)) ?m.1 ?m.2
|
||||
info: @Exists.intro Nat (fun x => And (p x) (@Eq Nat x Nat.zero)) ?m.1 ?m.2
|
||||
---
|
||||
info: ?m.1 : Nat
|
||||
info: @And.intro (p ?m.1) (@Eq Nat ?m.1 Nat.zero) ?m.3 ?m.4
|
||||
---
|
||||
info: ?m.2 : And (p ?m.1) (q ?m.1 Nat.zero)
|
||||
info: @Eq.refl Nat Nat.zero
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.explicit true in
|
||||
|
|
|
|||
25
tests/lean/run/sym_pattern_2.lean
Normal file
25
tests/lean/run/sym_pattern_2.lean
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import Lean.Meta.Sym
|
||||
open Lean Meta Sym Grind
|
||||
set_option grind.debug true
|
||||
opaque p [Ring α] : α → α → Prop
|
||||
axiom pax [CommRing α] [NoNatZeroDivisors α] (x y : α) : p x y → p (y + 1) x
|
||||
opaque a : Int
|
||||
opaque b : Int
|
||||
def ex := p (a + 1) b
|
||||
|
||||
def test : SymM Unit := do
|
||||
let pEx ← mkPatternFromTheorem ``pax
|
||||
let e ← shareCommon (← getConstInfo ``ex).value!
|
||||
let some r₁ ← pEx.match? e | throwError "failed"
|
||||
let h := mkAppN (mkConst ``pax r₁.us) r₁.args
|
||||
check h
|
||||
logInfo h
|
||||
logInfo r₁.args
|
||||
|
||||
/--
|
||||
info: pax b a ?m.1
|
||||
---
|
||||
info: #[Int, instCommRingInt, instNoNatZeroDivisorsInt, b, a, ?m.1]
|
||||
-/
|
||||
#guard_msgs in
|
||||
#eval SymM.run' test
|
||||
Loading…
Add table
Reference in a new issue