chore: normalize Sym APIs (#12088)
This PR cleanups the Sym APIs for `apply` and `simp`.
This commit is contained in:
parent
b8f8dde0b3
commit
08e6f714ca
12 changed files with 181 additions and 53 deletions
|
|
@ -96,6 +96,10 @@ def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr
|
|||
else
|
||||
mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args
|
||||
|
||||
public inductive ApplyResult where
|
||||
| notApplicable
|
||||
| goals (mvarId : List MVarId)
|
||||
|
||||
/--
|
||||
Applies a backward rule to a goal, returning new subgoals.
|
||||
|
||||
|
|
@ -103,27 +107,23 @@ Applies a backward rule to a goal, returning new subgoals.
|
|||
2. Assigns the goal metavariable to the theorem application
|
||||
3. Returns new goals for unassigned arguments (per `resultPos`)
|
||||
|
||||
Returns `none` if unification fails.
|
||||
Returns `.notApplicable` if unification fails.
|
||||
-/
|
||||
public def BackwardRule.apply? (mvarId : MVarId) (rule : BackwardRule) : SymM (Option (List MVarId)) := mvarId.withContext do
|
||||
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM ApplyResult := mvarId.withContext do
|
||||
let decl ← mvarId.getDecl
|
||||
if let some result ← rule.pattern.unify? decl.type then
|
||||
mvarId.assign (mkValue rule.expr rule.pattern result)
|
||||
return some <| rule.resultPos.map fun i =>
|
||||
return .goals <| rule.resultPos.map fun i =>
|
||||
result.args[i]!.mvarId!
|
||||
else
|
||||
return none
|
||||
return .notApplicable
|
||||
|
||||
/--
|
||||
Similar to `BackwardRule.apply?`, but throws an error if unification fails.
|
||||
Similar to `BackwardRule.apply', but throws an error if unification fails.
|
||||
-/
|
||||
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := mvarId.withContext do
|
||||
let decl ← mvarId.getDecl
|
||||
if let some result ← rule.pattern.unify? decl.type then
|
||||
mvarId.assign (mkValue rule.expr rule.pattern result)
|
||||
return rule.resultPos.map fun i =>
|
||||
result.args[i]!.mvarId!
|
||||
else
|
||||
throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
|
||||
public def BackwardRule.apply' (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := do
|
||||
let .goals mvarIds ← rule.apply mvarId
|
||||
| throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
|
||||
return mvarIds
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -21,3 +21,4 @@ public import Lean.Meta.Sym.Simp.Debug
|
|||
public import Lean.Meta.Sym.Simp.EvalGround
|
||||
public import Lean.Meta.Sym.Simp.Discharger
|
||||
public import Lean.Meta.Sym.Simp.ControlFlow
|
||||
public import Lean.Meta.Sym.Simp.Goal
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.SimpM
|
|||
public import Lean.Meta.Sym.Simp.Discharger
|
||||
import Lean.Meta.Sym.Simp.Theorems
|
||||
import Lean.Meta.Sym.Simp.Rewrite
|
||||
import Lean.Meta.Sym.Simp.Goal
|
||||
import Lean.Meta.Sym.Util
|
||||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.AppBuilder
|
||||
|
|
@ -27,24 +28,9 @@ public def mkSimprocFor (declNames : Array Name) (d : Discharger := dischargeNon
|
|||
public def mkMethods (declNames : Array Name) : MetaM Methods := do
|
||||
return { post := (← mkSimprocFor declNames) }
|
||||
|
||||
public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
|
||||
let mvarId ← preprocessMVar mvarId
|
||||
let decl ← mvarId.getDecl
|
||||
let target := decl.type
|
||||
match (← k target) with
|
||||
| .rfl _ => throwError "`Sym.simp` made no progress "
|
||||
| .step target' h _ =>
|
||||
let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName
|
||||
let h ← mkAppM ``Eq.mpr #[h, mvarNew]
|
||||
mvarId.assign h
|
||||
if target'.isTrue then
|
||||
mvarNew.mvarId!.assign (mkConst ``True.intro)
|
||||
return none
|
||||
else
|
||||
return some mvarNew.mvarId!
|
||||
|
||||
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do mvarId.withContext do
|
||||
public def simpGoalUsing (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
|
||||
let methods ← mkMethods declNames
|
||||
simpWith (simp · methods) mvarId
|
||||
let mvarId ← preprocessMVar mvarId
|
||||
(← simpGoal mvarId methods).toOption
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ public def simpForall (e : Expr) : SimpM Result := do
|
|||
else if (← isProp e) then
|
||||
let n := getForallTelescopeSize e.bindingBody! 1
|
||||
forallBoundedTelescope e n fun xs b => withoutModifyingCacheIfNotWellBehaved do
|
||||
main xs (← share b)
|
||||
main xs (← shareCommon b)
|
||||
else
|
||||
return .rfl
|
||||
where
|
||||
|
|
@ -90,7 +90,7 @@ where
|
|||
| .rfl _ => return .rfl
|
||||
| .step b' h _ =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
let e' ← share (← mkForallFVars xs b')
|
||||
let e' ← shareCommon (← mkForallFVars xs b')
|
||||
-- **Note**: consider caching the forall-congr theorems
|
||||
let hcongr ← mkForallCongrFor xs
|
||||
return .step e' (mkApp3 hcongr (← mkLambdaFVars xs b) (← mkLambdaFVars xs b') h)
|
||||
|
|
|
|||
69
src/Lean/Meta/Sym/Simp/Goal.lean
Normal file
69
src/Lean/Meta/Sym/Simp/Goal.lean
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.AppBuilder
|
||||
namespace Lean.Meta.Sym
|
||||
/-!
|
||||
# Goal simplification
|
||||
|
||||
Applies `Sym.simp` to a goal's target type, producing a simplified goal or closing it if
|
||||
the result is `True`.
|
||||
-/
|
||||
|
||||
/-- Result of simplifying a goal with `Sym.simp`. -/
|
||||
public inductive SimpGoalResult where
|
||||
/-- No simplification was possible. -/
|
||||
| noProgress
|
||||
/-- The goal was closed (simplified to `True`). -/
|
||||
| closed
|
||||
/-- The goal was simplified to a new goal. -/
|
||||
| goal (mvarId : MVarId)
|
||||
|
||||
/--
|
||||
Converts a `SimpGoalResult` to an optional goal.
|
||||
Returns `none` if closed, `some mvarId` if simplified, or throws an error if no progress.
|
||||
-/
|
||||
public def SimpGoalResult.toOption : SimpGoalResult → CoreM (Option MVarId)
|
||||
| .noProgress => throwError "`Sym.simp` made no progress "
|
||||
| .closed => return none
|
||||
| .goal mvarId => return some mvarId
|
||||
|
||||
/--
|
||||
Simplifies the target of `mvarId` using `Sym.simp`.
|
||||
Returns `.closed` if the target simplifies to `True`, `.simp mvarId'` if simplified
|
||||
to a new goal, or `.noProgress` if no simplification occurred.
|
||||
|
||||
This function assumed the input goal is a valid `Sym` goal (e.g., expressions are maximally shared).
|
||||
-/
|
||||
public def simpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
|
||||
: SymM SimpGoalResult := mvarId.withContext do
|
||||
let decl ← mvarId.getDecl
|
||||
let target := decl.type
|
||||
match (← simp target methods config) with
|
||||
| .rfl _ => return .noProgress
|
||||
| .step target' h _ =>
|
||||
let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName
|
||||
let h ← mkAppM ``Eq.mpr #[h, mvarNew]
|
||||
mvarId.assign h
|
||||
if target'.isTrue then
|
||||
mvarNew.mvarId!.assign (mkConst ``True.intro)
|
||||
return .closed
|
||||
else
|
||||
return .goal mvarNew.mvarId!
|
||||
|
||||
/--
|
||||
Similar to `simpGoal`, but returns `.goal mvarId` if no progress was made.
|
||||
-/
|
||||
public def trySimpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
|
||||
: SymM SimpGoalResult := do
|
||||
match (← simpGoal mvarId methods config) with
|
||||
| .noProgress => return .goal mvarId
|
||||
| r => return r
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
@ -48,14 +48,14 @@ def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
|
|||
|
||||
public def simpLambda (e : Expr) : SimpM Result := do
|
||||
lambdaTelescope e fun xs b => withoutModifyingCacheIfNotWellBehaved do
|
||||
main xs (← share b)
|
||||
main xs (← shareCommon b)
|
||||
where
|
||||
main (xs : Array Expr) (b : Expr) : SimpM Result := do
|
||||
match (← simp b) with
|
||||
| .rfl _ => return .rfl
|
||||
| .step b' h _ =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
let e' ← share (← mkLambdaFVars xs b')
|
||||
let e' ← shareCommon (← mkLambdaFVars xs b')
|
||||
let funext ← getFunext xs b
|
||||
return .step e' (mkApp3 funext e e' h)
|
||||
|
||||
|
|
|
|||
|
|
@ -235,3 +235,74 @@ def runBenchUsingMeta : MetaM Unit := do
|
|||
solveUsingMeta n
|
||||
|
||||
#eval runBenchUsingMeta
|
||||
-- goal_80: 1467.414291 ms, kernel: 120.162250 ms
|
||||
|
||||
/-!
|
||||
`SymM` Solution
|
||||
-/
|
||||
|
||||
theorem exists_eq_True (a : α) : (∃ x, x = a) = True := by
|
||||
simp
|
||||
|
||||
open Sym
|
||||
|
||||
def mkMethods (declNames : Array Name) : MetaM Sym.Simp.Methods := do
|
||||
let rewrite ← Sym.mkSimprocFor declNames
|
||||
return {
|
||||
post := Sym.Simp.evalGround.andThen rewrite
|
||||
}
|
||||
|
||||
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
|
||||
let rewrite ← Sym.mkSimprocFor (← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf
|
||||
let methods : Sym.Simp.Methods := {
|
||||
pre := Sym.Simp.simpControl
|
||||
post := Sym.Simp.evalGround.andThen rewrite
|
||||
}
|
||||
liftMetaTactic1 fun mvarId => Sym.SymM.run do
|
||||
let mvarId ← Sym.preprocessMVar mvarId
|
||||
(← Sym.simpGoal mvarId methods).toOption
|
||||
|
||||
example (l : PartialMap String Word) : ((l.put "b" x).put "a" y).get "b" = x := by
|
||||
sym_simp [PartialMap.get_put_diff, PartialMap.get_put]
|
||||
|
||||
partial def solve (mvarId : MVarId) : SymM Unit := do
|
||||
let exec_cpsRule ← mkBackwardRuleFromDecl ``Exec.seq_cps
|
||||
let inputRule ← mkBackwardRuleFromDecl ``Exec.input
|
||||
let skipRule ← mkBackwardRuleFromDecl ``Exec.skip
|
||||
let setRule ← mkBackwardRuleFromDecl ``Exec.set
|
||||
let rflRule ← mkBackwardRuleFromDecl ``Eq.refl
|
||||
let unfoldMethods ← mkMethods #[``generated_cmd.eq_1, ``repeated_cmds.eq_1, ``repeated_cmds.eq_2]
|
||||
let evalMethods ← mkMethods #[``Expr.eval.eq_1, ``Expr.eval.eq_2, ``Expr.eval.eq_3]
|
||||
let simpMethods ← mkMethods #[``PartialMap.get_put_diff, ``PartialMap.get_put, ``PartialMap.put_put, ``Binop.interp_add,
|
||||
``Binop.interp_sub, ``Word.add_sub_cancel, ``Option.some.injEq, ``not_false_eq_true, ``ne_eq]
|
||||
let finalSimpMethods ← mkMethods #[``List.cons.injEq, ``IOEvent.IN.injEq, ``and_true, ``true_and, ``PartialMap.put_put, ``PartialMap.get_put,
|
||||
``Option.some.injEq, ``and_self, ``exists_eq_True]
|
||||
-- Initialize
|
||||
let mvarId ← preprocessMVar mvarId
|
||||
let (_, mvarId) ← Sym.introN mvarId 2
|
||||
let .goal mvarId ← Sym.simpGoal mvarId unfoldMethods | failure
|
||||
let .goals [mvarId] ← exec_cpsRule.apply mvarId | failure
|
||||
let .goals [mvarId] ← inputRule.apply mvarId | failure
|
||||
let (_, mvarId) ← Sym.introN mvarId 1
|
||||
-- Loop
|
||||
let rec loop (mvarId : MVarId) : SymM MVarId := do
|
||||
-- mvarId.withContext do logInfo m!"{← mvarId.getType}"
|
||||
let .goals [mvarId] ← exec_cpsRule.apply mvarId | return mvarId
|
||||
let .goals [mvarId', mvarId, _] ← setRule.apply mvarId | failure
|
||||
let .goal mvarId' ← Sym.simpGoal mvarId' evalMethods | failure
|
||||
let .goal mvarId' ← Sym.simpGoal mvarId' simpMethods | failure
|
||||
let .goals [] ← rflRule.apply mvarId' | failure
|
||||
loop mvarId
|
||||
|
||||
let mvarId ← loop mvarId
|
||||
let .goals [mvarId] ← skipRule.apply mvarId | failure
|
||||
let .goal mvarId ← Sym.simpGoal mvarId finalSimpMethods { maxSteps := 100000 } | failure
|
||||
logInfo mvarId -- **TODO**: get_put theorem is not behaving correctly
|
||||
mvarId.admit
|
||||
return
|
||||
|
||||
def solveUsingSym (n : Nat) (check := true) : MetaM Unit := do
|
||||
driver n check fun mvarId => SymM.run do solve mvarId
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
#eval solveUsingSym 4
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ def test1 : SymM Unit := do
|
|||
let e ← shareCommon (← getConstInfo ``ex).value!
|
||||
let some r₁ ← pEx.match? e | throwError "failed"
|
||||
logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args
|
||||
let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | throwError "failed"
|
||||
let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | failure
|
||||
logInfo <| mkAppN (mkConst ``And.intro r₂.us) r₂.args
|
||||
let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | throwError "failed"
|
||||
let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | failure
|
||||
logInfo <| mkAppN (mkConst ``Eq.refl r₃.us) r₃.args
|
||||
|
||||
/--
|
||||
|
|
@ -36,10 +36,10 @@ def test2 : SymM Unit := do
|
|||
let rulePax ← mkBackwardRuleFromDecl ``pax
|
||||
let mvar ← mkFreshExprMVar (← getConstInfo ``ex).value!
|
||||
let mvarId ← preprocessMVar mvar.mvarId!
|
||||
let [mvarId, _] ← ruleEx.apply mvarId | throwError "Failed"
|
||||
let [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | throwError "Failed"
|
||||
let [] ← rulePax.apply mvarId₁ | throwError "Failed"
|
||||
let [] ← ruleRefl.apply mvarId₂ | throwError "Failed"
|
||||
let .goals [mvarId, _] ← ruleEx.apply mvarId | failure
|
||||
let .goals [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | failure
|
||||
let .goals [] ← rulePax.apply mvarId₁ | failure
|
||||
let .goals [] ← ruleRefl.apply mvarId₂ | failure
|
||||
logInfo mvar
|
||||
|
||||
/--
|
||||
|
|
@ -62,7 +62,7 @@ def test3 : SymM Unit := do
|
|||
let mvar ← mkFreshExprMVar target
|
||||
let mvarId ← preprocessMVar mvar.mvarId!
|
||||
let rule ← mkBackwardRuleFromDecl ``pFoo
|
||||
let [] ← rule.apply mvarId | throwError "failed"
|
||||
let .goals [] ← rule.apply mvarId | failure
|
||||
logInfo mvar
|
||||
|
||||
/-- info: pFoo (3 + y) -/
|
||||
|
|
@ -78,7 +78,7 @@ def test4 : SymM Unit := do
|
|||
let target := mkApp (mkConst ``p) (mkApp2 (mkConst ``foo) x m1)
|
||||
let target ← shareCommon target
|
||||
let p ← mkPatternFromDecl ``pFoo
|
||||
let some r ← p.match? target | throwError "failed"
|
||||
let some r ← p.match? target | failure
|
||||
logInfo <| mkAppN (mkConst ``pFoo r.us) r.args
|
||||
|
||||
/-- info: pFoo (3 + y) -/
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ set_option warn.sorry false
|
|||
|
||||
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
|
||||
let declNames ← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw
|
||||
liftMetaTactic1 <| Sym.simpGoal declNames
|
||||
liftMetaTactic1 <| Sym.simpGoalUsing declNames
|
||||
|
||||
theorem heq_self : (x ≍ x) = True := by simp
|
||||
theorem forall_true {α : Sort u} : (∀ _ : α, True) = True := by simp
|
||||
|
|
@ -115,7 +115,7 @@ example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b)
|
|||
/--
|
||||
trace: c a : Nat
|
||||
g : Nat → Nat
|
||||
h : ite (c > 0) a = g
|
||||
h : ite (0 < c) a = g
|
||||
⊢ ite (0 < c) a = g
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ open Lean Meta Elab Tactic
|
|||
|
||||
elab "sym_simp" : tactic => do
|
||||
let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround }
|
||||
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
|
||||
liftMetaTactic1 fun mvarId => Sym.SymM.run do
|
||||
let mvarId ← Sym.preprocessMVar mvarId
|
||||
(← Sym.simpGoal mvarId methods).toOption
|
||||
|
||||
-- Basic arithmetic: Nat
|
||||
example : 2 + 3 = 5 := by sym_simp
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
|
|||
pre := Sym.Simp.simpControl
|
||||
post := Sym.Simp.evalGround.andThen rewrite
|
||||
}
|
||||
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
|
||||
liftMetaTactic1 fun mvarId => Sym.SymM.run do
|
||||
let mvarId ← Sym.preprocessMVar mvarId
|
||||
(← Sym.simpGoal mvarId methods).toOption
|
||||
|
||||
example : (1-1) + x*1 + (2-1)*0 = x := by
|
||||
sym_simp [Nat.add_zero, Nat.zero_add, Nat.mul_one]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import Lean.Meta.Sym
|
|||
|
||||
open Lean Meta Sym
|
||||
def profileM {α : Type} (k : MetaM α) (msg : String := "experiment") : MetaM α :=
|
||||
profileitM Exception msg ({ : Options }.set `profiler true |>.setNat `profiler.threshold 0) k
|
||||
profileitM Exception msg (Options.empty.set `profiler true |>.set `profiler.threshold 0) k
|
||||
|
||||
def genTerm (n : Nat) : Expr := Id.run do
|
||||
let mut e := mkConst ``True
|
||||
|
|
@ -33,11 +33,8 @@ def tryIntros? (goals : List MVarId) : SymM (Option (List MVarId)) := do
|
|||
|
||||
def tryApply? (rule : BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
|
||||
let goal :: goals := goals | return none
|
||||
try
|
||||
let goals' ← rule.apply goal
|
||||
return some (goals' ++ goals)
|
||||
catch _ =>
|
||||
return none
|
||||
let .goals goals' ← rule.apply goal | return none
|
||||
return some (goals' ++ goals)
|
||||
|
||||
def tryApplyAny? (rules : List BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
|
||||
match rules with
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue