chore: normalize Sym APIs (#12088)

This PR cleanups the Sym APIs for `apply` and `simp`.
This commit is contained in:
Leonardo de Moura 2026-01-21 09:02:22 -08:00 committed by GitHub
parent b8f8dde0b3
commit 08e6f714ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 181 additions and 53 deletions

View file

@ -96,6 +96,10 @@ def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr
else
mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args
public inductive ApplyResult where
| notApplicable
| goals (mvarId : List MVarId)
/--
Applies a backward rule to a goal, returning new subgoals.
@ -103,27 +107,23 @@ Applies a backward rule to a goal, returning new subgoals.
2. Assigns the goal metavariable to the theorem application
3. Returns new goals for unassigned arguments (per `resultPos`)
Returns `none` if unification fails.
Returns `.notApplicable` if unification fails.
-/
public def BackwardRule.apply? (mvarId : MVarId) (rule : BackwardRule) : SymM (Option (List MVarId)) := mvarId.withContext do
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM ApplyResult := mvarId.withContext do
let decl ← mvarId.getDecl
if let some result ← rule.pattern.unify? decl.type then
mvarId.assign (mkValue rule.expr rule.pattern result)
return some <| rule.resultPos.map fun i =>
return .goals <| rule.resultPos.map fun i =>
result.args[i]!.mvarId!
else
return none
return .notApplicable
/--
Similar to `BackwardRule.apply?`, but throws an error if unification fails.
Similar to `BackwardRule.apply', but throws an error if unification fails.
-/
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := mvarId.withContext do
let decl ← mvarId.getDecl
if let some result ← rule.pattern.unify? decl.type then
mvarId.assign (mkValue rule.expr rule.pattern result)
return rule.resultPos.map fun i =>
result.args[i]!.mvarId!
else
throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
public def BackwardRule.apply' (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := do
let .goals mvarIds ← rule.apply mvarId
| throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
return mvarIds
end Lean.Meta.Sym

View file

@ -21,3 +21,4 @@ public import Lean.Meta.Sym.Simp.Debug
public import Lean.Meta.Sym.Simp.EvalGround
public import Lean.Meta.Sym.Simp.Discharger
public import Lean.Meta.Sym.Simp.ControlFlow
public import Lean.Meta.Sym.Simp.Goal

View file

@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.SimpM
public import Lean.Meta.Sym.Simp.Discharger
import Lean.Meta.Sym.Simp.Theorems
import Lean.Meta.Sym.Simp.Rewrite
import Lean.Meta.Sym.Simp.Goal
import Lean.Meta.Sym.Util
import Lean.Meta.Tactic.Util
import Lean.Meta.AppBuilder
@ -27,24 +28,9 @@ public def mkSimprocFor (declNames : Array Name) (d : Discharger := dischargeNon
public def mkMethods (declNames : Array Name) : MetaM Methods := do
return { post := (← mkSimprocFor declNames) }
public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let mvarId ← preprocessMVar mvarId
let decl ← mvarId.getDecl
let target := decl.type
match (← k target) with
| .rfl _ => throwError "`Sym.simp` made no progress "
| .step target' h _ =>
let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName
let h ← mkAppM ``Eq.mpr #[h, mvarNew]
mvarId.assign h
if target'.isTrue then
mvarNew.mvarId!.assign (mkConst ``True.intro)
return none
else
return some mvarNew.mvarId!
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do mvarId.withContext do
public def simpGoalUsing (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let methods ← mkMethods declNames
simpWith (simp · methods) mvarId
let mvarId ← preprocessMVar mvarId
(← simpGoal mvarId methods).toOption
end Lean.Meta.Sym

View file

@ -81,7 +81,7 @@ public def simpForall (e : Expr) : SimpM Result := do
else if (← isProp e) then
let n := getForallTelescopeSize e.bindingBody! 1
forallBoundedTelescope e n fun xs b => withoutModifyingCacheIfNotWellBehaved do
main xs (← share b)
main xs (← shareCommon b)
else
return .rfl
where
@ -90,7 +90,7 @@ where
| .rfl _ => return .rfl
| .step b' h _ =>
let h ← mkLambdaFVars xs h
let e' ← share (← mkForallFVars xs b')
let e' ← shareCommon (← mkForallFVars xs b')
-- **Note**: consider caching the forall-congr theorems
let hcongr ← mkForallCongrFor xs
return .step e' (mkApp3 hcongr (← mkLambdaFVars xs b) (← mkLambdaFVars xs b') h)

View file

@ -0,0 +1,69 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Tactic.Util
import Lean.Meta.AppBuilder
namespace Lean.Meta.Sym
/-!
# Goal simplification
Applies `Sym.simp` to a goal's target type, producing a simplified goal or closing it if
the result is `True`.
-/
/-- Result of simplifying a goal with `Sym.simp`. -/
public inductive SimpGoalResult where
/-- No simplification was possible. -/
| noProgress
/-- The goal was closed (simplified to `True`). -/
| closed
/-- The goal was simplified to a new goal. -/
| goal (mvarId : MVarId)
/--
Converts a `SimpGoalResult` to an optional goal.
Returns `none` if closed, `some mvarId` if simplified, or throws an error if no progress.
-/
public def SimpGoalResult.toOption : SimpGoalResult → CoreM (Option MVarId)
| .noProgress => throwError "`Sym.simp` made no progress "
| .closed => return none
| .goal mvarId => return some mvarId
/--
Simplifies the target of `mvarId` using `Sym.simp`.
Returns `.closed` if the target simplifies to `True`, `.simp mvarId'` if simplified
to a new goal, or `.noProgress` if no simplification occurred.
This function assumed the input goal is a valid `Sym` goal (e.g., expressions are maximally shared).
-/
public def simpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
: SymM SimpGoalResult := mvarId.withContext do
let decl ← mvarId.getDecl
let target := decl.type
match (← simp target methods config) with
| .rfl _ => return .noProgress
| .step target' h _ =>
let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName
let h ← mkAppM ``Eq.mpr #[h, mvarNew]
mvarId.assign h
if target'.isTrue then
mvarNew.mvarId!.assign (mkConst ``True.intro)
return .closed
else
return .goal mvarNew.mvarId!
/--
Similar to `simpGoal`, but returns `.goal mvarId` if no progress was made.
-/
public def trySimpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
: SymM SimpGoalResult := do
match (← simpGoal mvarId methods config) with
| .noProgress => return .goal mvarId
| r => return r
end Lean.Meta.Sym

View file

@ -48,14 +48,14 @@ def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
public def simpLambda (e : Expr) : SimpM Result := do
lambdaTelescope e fun xs b => withoutModifyingCacheIfNotWellBehaved do
main xs (← share b)
main xs (← shareCommon b)
where
main (xs : Array Expr) (b : Expr) : SimpM Result := do
match (← simp b) with
| .rfl _ => return .rfl
| .step b' h _ =>
let h ← mkLambdaFVars xs h
let e' ← share (← mkLambdaFVars xs b')
let e' ← shareCommon (← mkLambdaFVars xs b')
let funext ← getFunext xs b
return .step e' (mkApp3 funext e e' h)

View file

@ -235,3 +235,74 @@ def runBenchUsingMeta : MetaM Unit := do
solveUsingMeta n
#eval runBenchUsingMeta
-- goal_80: 1467.414291 ms, kernel: 120.162250 ms
/-!
`SymM` Solution
-/
theorem exists_eq_True (a : α) : (∃ x, x = a) = True := by
simp
open Sym
def mkMethods (declNames : Array Name) : MetaM Sym.Simp.Methods := do
let rewrite ← Sym.mkSimprocFor declNames
return {
post := Sym.Simp.evalGround.andThen rewrite
}
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
let rewrite ← Sym.mkSimprocFor (← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf
let methods : Sym.Simp.Methods := {
pre := Sym.Simp.simpControl
post := Sym.Simp.evalGround.andThen rewrite
}
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId ← Sym.preprocessMVar mvarId
(← Sym.simpGoal mvarId methods).toOption
example (l : PartialMap String Word) : ((l.put "b" x).put "a" y).get "b" = x := by
sym_simp [PartialMap.get_put_diff, PartialMap.get_put]
partial def solve (mvarId : MVarId) : SymM Unit := do
let exec_cpsRule ← mkBackwardRuleFromDecl ``Exec.seq_cps
let inputRule ← mkBackwardRuleFromDecl ``Exec.input
let skipRule ← mkBackwardRuleFromDecl ``Exec.skip
let setRule ← mkBackwardRuleFromDecl ``Exec.set
let rflRule ← mkBackwardRuleFromDecl ``Eq.refl
let unfoldMethods ← mkMethods #[``generated_cmd.eq_1, ``repeated_cmds.eq_1, ``repeated_cmds.eq_2]
let evalMethods ← mkMethods #[``Expr.eval.eq_1, ``Expr.eval.eq_2, ``Expr.eval.eq_3]
let simpMethods ← mkMethods #[``PartialMap.get_put_diff, ``PartialMap.get_put, ``PartialMap.put_put, ``Binop.interp_add,
``Binop.interp_sub, ``Word.add_sub_cancel, ``Option.some.injEq, ``not_false_eq_true, ``ne_eq]
let finalSimpMethods ← mkMethods #[``List.cons.injEq, ``IOEvent.IN.injEq, ``and_true, ``true_and, ``PartialMap.put_put, ``PartialMap.get_put,
``Option.some.injEq, ``and_self, ``exists_eq_True]
-- Initialize
let mvarId ← preprocessMVar mvarId
let (_, mvarId) ← Sym.introN mvarId 2
let .goal mvarId ← Sym.simpGoal mvarId unfoldMethods | failure
let .goals [mvarId] ← exec_cpsRule.apply mvarId | failure
let .goals [mvarId] ← inputRule.apply mvarId | failure
let (_, mvarId) ← Sym.introN mvarId 1
-- Loop
let rec loop (mvarId : MVarId) : SymM MVarId := do
-- mvarId.withContext do logInfo m!"{← mvarId.getType}"
let .goals [mvarId] ← exec_cpsRule.apply mvarId | return mvarId
let .goals [mvarId', mvarId, _] ← setRule.apply mvarId | failure
let .goal mvarId' ← Sym.simpGoal mvarId' evalMethods | failure
let .goal mvarId' ← Sym.simpGoal mvarId' simpMethods | failure
let .goals [] ← rflRule.apply mvarId' | failure
loop mvarId
let mvarId ← loop mvarId
let .goals [mvarId] ← skipRule.apply mvarId | failure
let .goal mvarId ← Sym.simpGoal mvarId finalSimpMethods { maxSteps := 100000 } | failure
logInfo mvarId -- **TODO**: get_put theorem is not behaving correctly
mvarId.admit
return
def solveUsingSym (n : Nat) (check := true) : MetaM Unit := do
driver n check fun mvarId => SymM.run do solve mvarId
set_option maxRecDepth 100000
#eval solveUsingSym 4

View file

@ -13,9 +13,9 @@ def test1 : SymM Unit := do
let e ← shareCommon (← getConstInfo ``ex).value!
let some r₁ ← pEx.match? e | throwError "failed"
logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args
let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | throwError "failed"
let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | failure
logInfo <| mkAppN (mkConst ``And.intro r₂.us) r₂.args
let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | throwError "failed"
let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | failure
logInfo <| mkAppN (mkConst ``Eq.refl r₃.us) r₃.args
/--
@ -36,10 +36,10 @@ def test2 : SymM Unit := do
let rulePax ← mkBackwardRuleFromDecl ``pax
let mvar ← mkFreshExprMVar (← getConstInfo ``ex).value!
let mvarId ← preprocessMVar mvar.mvarId!
let [mvarId, _] ← ruleEx.apply mvarId | throwError "Failed"
let [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | throwError "Failed"
let [] ← rulePax.apply mvarId₁ | throwError "Failed"
let [] ← ruleRefl.apply mvarId₂ | throwError "Failed"
let .goals [mvarId, _] ← ruleEx.apply mvarId | failure
let .goals [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | failure
let .goals [] ← rulePax.apply mvarId₁ | failure
let .goals [] ← ruleRefl.apply mvarId₂ | failure
logInfo mvar
/--
@ -62,7 +62,7 @@ def test3 : SymM Unit := do
let mvar ← mkFreshExprMVar target
let mvarId ← preprocessMVar mvar.mvarId!
let rule ← mkBackwardRuleFromDecl ``pFoo
let [] ← rule.apply mvarId | throwError "failed"
let .goals [] ← rule.apply mvarId | failure
logInfo mvar
/-- info: pFoo (3 + y) -/
@ -78,7 +78,7 @@ def test4 : SymM Unit := do
let target := mkApp (mkConst ``p) (mkApp2 (mkConst ``foo) x m1)
let target ← shareCommon target
let p ← mkPatternFromDecl ``pFoo
let some r ← p.match? target | throwError "failed"
let some r ← p.match? target | failure
logInfo <| mkAppN (mkConst ``pFoo r.us) r.args
/-- info: pFoo (3 + y) -/

View file

@ -7,7 +7,7 @@ set_option warn.sorry false
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
let declNames ← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw
liftMetaTactic1 <| Sym.simpGoal declNames
liftMetaTactic1 <| Sym.simpGoalUsing declNames
theorem heq_self : (x ≍ x) = True := by simp
theorem forall_true {α : Sort u} : (∀ _ : α, True) = True := by simp
@ -115,7 +115,7 @@ example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b)
/--
trace: c a : Nat
g : Nat → Nat
h : ite (c > 0) a = g
h : ite (0 < c) a = g
⊢ ite (0 < c) a = g
-/
#guard_msgs in

View file

@ -3,7 +3,9 @@ open Lean Meta Elab Tactic
elab "sym_simp" : tactic => do
let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround }
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId ← Sym.preprocessMVar mvarId
(← Sym.simpGoal mvarId methods).toOption
-- Basic arithmetic: Nat
example : 2 + 3 = 5 := by sym_simp

View file

@ -7,7 +7,9 @@ elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
pre := Sym.Simp.simpControl
post := Sym.Simp.evalGround.andThen rewrite
}
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId ← Sym.preprocessMVar mvarId
(← Sym.simpGoal mvarId methods).toOption
example : (1-1) + x*1 + (2-1)*0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, Nat.mul_one]

View file

@ -3,7 +3,7 @@ import Lean.Meta.Sym
open Lean Meta Sym
def profileM {α : Type} (k : MetaM α) (msg : String := "experiment") : MetaM α :=
profileitM Exception msg ({ : Options }.set `profiler true |>.setNat `profiler.threshold 0) k
profileitM Exception msg (Options.empty.set `profiler true |>.set `profiler.threshold 0) k
def genTerm (n : Nat) : Expr := Id.run do
let mut e := mkConst ``True
@ -33,11 +33,8 @@ def tryIntros? (goals : List MVarId) : SymM (Option (List MVarId)) := do
def tryApply? (rule : BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
let goal :: goals := goals | return none
try
let goals' ← rule.apply goal
return some (goals' ++ goals)
catch _ =>
return none
let .goals goals' ← rule.apply goal | return none
return some (goals' ++ goals)
def tryApplyAny? (rules : List BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
match rules with