From 08e6f714caf50719bfbdee301dcc2d86e84fea81 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 21 Jan 2026 09:02:22 -0800 Subject: [PATCH] chore: normalize `Sym` APIs (#12088) This PR cleanups the Sym APIs for `apply` and `simp`. --- src/Lean/Meta/Sym/Apply.lean | 26 ++++----- src/Lean/Meta/Sym/Simp.lean | 1 + src/Lean/Meta/Sym/Simp/Debug.lean | 22 ++------ src/Lean/Meta/Sym/Simp/Forall.lean | 4 +- src/Lean/Meta/Sym/Simp/Goal.lean | 69 ++++++++++++++++++++++++ src/Lean/Meta/Sym/Simp/Lambda.lean | 4 +- tests/bench/sym/sym_add_sub_cancel.lean | 71 +++++++++++++++++++++++++ tests/lean/run/sym_pattern.lean | 16 +++--- tests/lean/run/sym_simp_1.lean | 4 +- tests/lean/run/sym_simp_2.lean | 4 +- tests/lean/run/sym_simp_3.lean | 4 +- tests/lean/sym/perf_sym_apply.lean | 9 ++-- 12 files changed, 181 insertions(+), 53 deletions(-) create mode 100644 src/Lean/Meta/Sym/Simp/Goal.lean diff --git a/src/Lean/Meta/Sym/Apply.lean b/src/Lean/Meta/Sym/Apply.lean index 58ce7b9e70..ef4dbdc9e7 100644 --- a/src/Lean/Meta/Sym/Apply.lean +++ b/src/Lean/Meta/Sym/Apply.lean @@ -96,6 +96,10 @@ def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr else mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args +public inductive ApplyResult where + | notApplicable + | goals (mvarId : List MVarId) + /-- Applies a backward rule to a goal, returning new subgoals. @@ -103,27 +107,23 @@ Applies a backward rule to a goal, returning new subgoals. 2. Assigns the goal metavariable to the theorem application 3. Returns new goals for unassigned arguments (per `resultPos`) -Returns `none` if unification fails. +Returns `.notApplicable` if unification fails. -/ -public def BackwardRule.apply? (mvarId : MVarId) (rule : BackwardRule) : SymM (Option (List MVarId)) := mvarId.withContext do +public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM ApplyResult := mvarId.withContext do let decl ← mvarId.getDecl if let some result ← rule.pattern.unify? decl.type then mvarId.assign (mkValue rule.expr rule.pattern result) - return some <| rule.resultPos.map fun i => + return .goals <| rule.resultPos.map fun i => result.args[i]!.mvarId! else - return none + return .notApplicable /-- -Similar to `BackwardRule.apply?`, but throws an error if unification fails. +Similar to `BackwardRule.apply', but throws an error if unification fails. -/ -public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := mvarId.withContext do - let decl ← mvarId.getDecl - if let some result ← rule.pattern.unify? decl.type then - mvarId.assign (mkValue rule.expr rule.pattern result) - return rule.resultPos.map fun i => - result.args[i]!.mvarId! - else - throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}" +public def BackwardRule.apply' (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := do + let .goals mvarIds ← rule.apply mvarId + | throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}" + return mvarIds end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index 1b3f11b926..a4aa483b38 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -21,3 +21,4 @@ public import Lean.Meta.Sym.Simp.Debug public import Lean.Meta.Sym.Simp.EvalGround public import Lean.Meta.Sym.Simp.Discharger public import Lean.Meta.Sym.Simp.ControlFlow +public import Lean.Meta.Sym.Simp.Goal diff --git a/src/Lean/Meta/Sym/Simp/Debug.lean b/src/Lean/Meta/Sym/Simp/Debug.lean index e9fd29df3d..63102e498f 100644 --- a/src/Lean/Meta/Sym/Simp/Debug.lean +++ b/src/Lean/Meta/Sym/Simp/Debug.lean @@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.SimpM public import Lean.Meta.Sym.Simp.Discharger import Lean.Meta.Sym.Simp.Theorems import Lean.Meta.Sym.Simp.Rewrite +import Lean.Meta.Sym.Simp.Goal import Lean.Meta.Sym.Util import Lean.Meta.Tactic.Util import Lean.Meta.AppBuilder @@ -27,24 +28,9 @@ public def mkSimprocFor (declNames : Array Name) (d : Discharger := dischargeNon public def mkMethods (declNames : Array Name) : MetaM Methods := do return { post := (← mkSimprocFor declNames) } -public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do - let mvarId ← preprocessMVar mvarId - let decl ← mvarId.getDecl - let target := decl.type - match (← k target) with - | .rfl _ => throwError "`Sym.simp` made no progress " - | .step target' h _ => - let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName - let h ← mkAppM ``Eq.mpr #[h, mvarNew] - mvarId.assign h - if target'.isTrue then - mvarNew.mvarId!.assign (mkConst ``True.intro) - return none - else - return some mvarNew.mvarId! - -public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do mvarId.withContext do +public def simpGoalUsing (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do let methods ← mkMethods declNames - simpWith (simp · methods) mvarId + let mvarId ← preprocessMVar mvarId + (← simpGoal mvarId methods).toOption end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp/Forall.lean b/src/Lean/Meta/Sym/Simp/Forall.lean index 4032430bd7..4e37f4220c 100644 --- a/src/Lean/Meta/Sym/Simp/Forall.lean +++ b/src/Lean/Meta/Sym/Simp/Forall.lean @@ -81,7 +81,7 @@ public def simpForall (e : Expr) : SimpM Result := do else if (← isProp e) then let n := getForallTelescopeSize e.bindingBody! 1 forallBoundedTelescope e n fun xs b => withoutModifyingCacheIfNotWellBehaved do - main xs (← share b) + main xs (← shareCommon b) else return .rfl where @@ -90,7 +90,7 @@ where | .rfl _ => return .rfl | .step b' h _ => let h ← mkLambdaFVars xs h - let e' ← share (← mkForallFVars xs b') + let e' ← shareCommon (← mkForallFVars xs b') -- **Note**: consider caching the forall-congr theorems let hcongr ← mkForallCongrFor xs return .step e' (mkApp3 hcongr (← mkLambdaFVars xs b) (← mkLambdaFVars xs b') h) diff --git a/src/Lean/Meta/Sym/Simp/Goal.lean b/src/Lean/Meta/Sym/Simp/Goal.lean new file mode 100644 index 0000000000..0db2f442c1 --- /dev/null +++ b/src/Lean/Meta/Sym/Simp/Goal.lean @@ -0,0 +1,69 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Tactic.Util +import Lean.Meta.AppBuilder +namespace Lean.Meta.Sym +/-! +# Goal simplification + +Applies `Sym.simp` to a goal's target type, producing a simplified goal or closing it if +the result is `True`. +-/ + +/-- Result of simplifying a goal with `Sym.simp`. -/ +public inductive SimpGoalResult where + /-- No simplification was possible. -/ + | noProgress + /-- The goal was closed (simplified to `True`). -/ + | closed + /-- The goal was simplified to a new goal. -/ + | goal (mvarId : MVarId) + +/-- +Converts a `SimpGoalResult` to an optional goal. +Returns `none` if closed, `some mvarId` if simplified, or throws an error if no progress. +-/ +public def SimpGoalResult.toOption : SimpGoalResult → CoreM (Option MVarId) + | .noProgress => throwError "`Sym.simp` made no progress " + | .closed => return none + | .goal mvarId => return some mvarId + +/-- +Simplifies the target of `mvarId` using `Sym.simp`. +Returns `.closed` if the target simplifies to `True`, `.simp mvarId'` if simplified +to a new goal, or `.noProgress` if no simplification occurred. + +This function assumed the input goal is a valid `Sym` goal (e.g., expressions are maximally shared). +-/ +public def simpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {}) + : SymM SimpGoalResult := mvarId.withContext do + let decl ← mvarId.getDecl + let target := decl.type + match (← simp target methods config) with + | .rfl _ => return .noProgress + | .step target' h _ => + let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName + let h ← mkAppM ``Eq.mpr #[h, mvarNew] + mvarId.assign h + if target'.isTrue then + mvarNew.mvarId!.assign (mkConst ``True.intro) + return .closed + else + return .goal mvarNew.mvarId! + +/-- +Similar to `simpGoal`, but returns `.goal mvarId` if no progress was made. +-/ +public def trySimpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {}) + : SymM SimpGoalResult := do + match (← simpGoal mvarId methods config) with + | .noProgress => return .goal mvarId + | r => return r + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp/Lambda.lean b/src/Lean/Meta/Sym/Simp/Lambda.lean index 98a4d0f223..78050f2997 100644 --- a/src/Lean/Meta/Sym/Simp/Lambda.lean +++ b/src/Lean/Meta/Sym/Simp/Lambda.lean @@ -48,14 +48,14 @@ def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do public def simpLambda (e : Expr) : SimpM Result := do lambdaTelescope e fun xs b => withoutModifyingCacheIfNotWellBehaved do - main xs (← share b) + main xs (← shareCommon b) where main (xs : Array Expr) (b : Expr) : SimpM Result := do match (← simp b) with | .rfl _ => return .rfl | .step b' h _ => let h ← mkLambdaFVars xs h - let e' ← share (← mkLambdaFVars xs b') + let e' ← shareCommon (← mkLambdaFVars xs b') let funext ← getFunext xs b return .step e' (mkApp3 funext e e' h) diff --git a/tests/bench/sym/sym_add_sub_cancel.lean b/tests/bench/sym/sym_add_sub_cancel.lean index ed6d326ec9..7455d1b830 100644 --- a/tests/bench/sym/sym_add_sub_cancel.lean +++ b/tests/bench/sym/sym_add_sub_cancel.lean @@ -235,3 +235,74 @@ def runBenchUsingMeta : MetaM Unit := do solveUsingMeta n #eval runBenchUsingMeta +-- goal_80: 1467.414291 ms, kernel: 120.162250 ms + +/-! +`SymM` Solution +-/ + +theorem exists_eq_True (a : α) : (∃ x, x = a) = True := by + simp + +open Sym + +def mkMethods (declNames : Array Name) : MetaM Sym.Simp.Methods := do + let rewrite ← Sym.mkSimprocFor declNames + return { + post := Sym.Simp.evalGround.andThen rewrite + } + +elab "sym_simp" "[" declNames:ident,* "]" : tactic => do + let rewrite ← Sym.mkSimprocFor (← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf + let methods : Sym.Simp.Methods := { + pre := Sym.Simp.simpControl + post := Sym.Simp.evalGround.andThen rewrite + } + liftMetaTactic1 fun mvarId => Sym.SymM.run do + let mvarId ← Sym.preprocessMVar mvarId + (← Sym.simpGoal mvarId methods).toOption + +example (l : PartialMap String Word) : ((l.put "b" x).put "a" y).get "b" = x := by + sym_simp [PartialMap.get_put_diff, PartialMap.get_put] + +partial def solve (mvarId : MVarId) : SymM Unit := do + let exec_cpsRule ← mkBackwardRuleFromDecl ``Exec.seq_cps + let inputRule ← mkBackwardRuleFromDecl ``Exec.input + let skipRule ← mkBackwardRuleFromDecl ``Exec.skip + let setRule ← mkBackwardRuleFromDecl ``Exec.set + let rflRule ← mkBackwardRuleFromDecl ``Eq.refl + let unfoldMethods ← mkMethods #[``generated_cmd.eq_1, ``repeated_cmds.eq_1, ``repeated_cmds.eq_2] + let evalMethods ← mkMethods #[``Expr.eval.eq_1, ``Expr.eval.eq_2, ``Expr.eval.eq_3] + let simpMethods ← mkMethods #[``PartialMap.get_put_diff, ``PartialMap.get_put, ``PartialMap.put_put, ``Binop.interp_add, + ``Binop.interp_sub, ``Word.add_sub_cancel, ``Option.some.injEq, ``not_false_eq_true, ``ne_eq] + let finalSimpMethods ← mkMethods #[``List.cons.injEq, ``IOEvent.IN.injEq, ``and_true, ``true_and, ``PartialMap.put_put, ``PartialMap.get_put, + ``Option.some.injEq, ``and_self, ``exists_eq_True] + -- Initialize + let mvarId ← preprocessMVar mvarId + let (_, mvarId) ← Sym.introN mvarId 2 + let .goal mvarId ← Sym.simpGoal mvarId unfoldMethods | failure + let .goals [mvarId] ← exec_cpsRule.apply mvarId | failure + let .goals [mvarId] ← inputRule.apply mvarId | failure + let (_, mvarId) ← Sym.introN mvarId 1 + -- Loop + let rec loop (mvarId : MVarId) : SymM MVarId := do + -- mvarId.withContext do logInfo m!"{← mvarId.getType}" + let .goals [mvarId] ← exec_cpsRule.apply mvarId | return mvarId + let .goals [mvarId', mvarId, _] ← setRule.apply mvarId | failure + let .goal mvarId' ← Sym.simpGoal mvarId' evalMethods | failure + let .goal mvarId' ← Sym.simpGoal mvarId' simpMethods | failure + let .goals [] ← rflRule.apply mvarId' | failure + loop mvarId + + let mvarId ← loop mvarId + let .goals [mvarId] ← skipRule.apply mvarId | failure + let .goal mvarId ← Sym.simpGoal mvarId finalSimpMethods { maxSteps := 100000 } | failure + logInfo mvarId -- **TODO**: get_put theorem is not behaving correctly + mvarId.admit + return + +def solveUsingSym (n : Nat) (check := true) : MetaM Unit := do + driver n check fun mvarId => SymM.run do solve mvarId + +set_option maxRecDepth 100000 +#eval solveUsingSym 4 diff --git a/tests/lean/run/sym_pattern.lean b/tests/lean/run/sym_pattern.lean index 2c8aa403ac..1f9ca3326d 100644 --- a/tests/lean/run/sym_pattern.lean +++ b/tests/lean/run/sym_pattern.lean @@ -13,9 +13,9 @@ def test1 : SymM Unit := do let e ← shareCommon (← getConstInfo ``ex).value! let some r₁ ← pEx.match? e | throwError "failed" logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args - let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | throwError "failed" + let some r₂ ← pAnd.match? (← Sym.inferType r₁.args[3]!) | failure logInfo <| mkAppN (mkConst ``And.intro r₂.us) r₂.args - let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | throwError "failed" + let some r₃ ← pEq.unify? (← Sym.inferType r₂.args[3]!) | failure logInfo <| mkAppN (mkConst ``Eq.refl r₃.us) r₃.args /-- @@ -36,10 +36,10 @@ def test2 : SymM Unit := do let rulePax ← mkBackwardRuleFromDecl ``pax let mvar ← mkFreshExprMVar (← getConstInfo ``ex).value! let mvarId ← preprocessMVar mvar.mvarId! - let [mvarId, _] ← ruleEx.apply mvarId | throwError "Failed" - let [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | throwError "Failed" - let [] ← rulePax.apply mvarId₁ | throwError "Failed" - let [] ← ruleRefl.apply mvarId₂ | throwError "Failed" + let .goals [mvarId, _] ← ruleEx.apply mvarId | failure + let .goals [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | failure + let .goals [] ← rulePax.apply mvarId₁ | failure + let .goals [] ← ruleRefl.apply mvarId₂ | failure logInfo mvar /-- @@ -62,7 +62,7 @@ def test3 : SymM Unit := do let mvar ← mkFreshExprMVar target let mvarId ← preprocessMVar mvar.mvarId! let rule ← mkBackwardRuleFromDecl ``pFoo - let [] ← rule.apply mvarId | throwError "failed" + let .goals [] ← rule.apply mvarId | failure logInfo mvar /-- info: pFoo (3 + y) -/ @@ -78,7 +78,7 @@ def test4 : SymM Unit := do let target := mkApp (mkConst ``p) (mkApp2 (mkConst ``foo) x m1) let target ← shareCommon target let p ← mkPatternFromDecl ``pFoo - let some r ← p.match? target | throwError "failed" + let some r ← p.match? target | failure logInfo <| mkAppN (mkConst ``pFoo r.us) r.args /-- info: pFoo (3 + y) -/ diff --git a/tests/lean/run/sym_simp_1.lean b/tests/lean/run/sym_simp_1.lean index 3de843bb07..dc7911c3db 100644 --- a/tests/lean/run/sym_simp_1.lean +++ b/tests/lean/run/sym_simp_1.lean @@ -7,7 +7,7 @@ set_option warn.sorry false elab "sym_simp" "[" declNames:ident,* "]" : tactic => do let declNames ← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw - liftMetaTactic1 <| Sym.simpGoal declNames + liftMetaTactic1 <| Sym.simpGoalUsing declNames theorem heq_self : (x ≍ x) = True := by simp theorem forall_true {α : Sort u} : (∀ _ : α, True) = True := by simp @@ -115,7 +115,7 @@ example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b) /-- trace: c a : Nat g : Nat → Nat -h : ite (c > 0) a = g +h : ite (0 < c) a = g ⊢ ite (0 < c) a = g -/ #guard_msgs in diff --git a/tests/lean/run/sym_simp_2.lean b/tests/lean/run/sym_simp_2.lean index 06e368dd82..8059827846 100644 --- a/tests/lean/run/sym_simp_2.lean +++ b/tests/lean/run/sym_simp_2.lean @@ -3,7 +3,9 @@ open Lean Meta Elab Tactic elab "sym_simp" : tactic => do let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround } - liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods) + liftMetaTactic1 fun mvarId => Sym.SymM.run do + let mvarId ← Sym.preprocessMVar mvarId + (← Sym.simpGoal mvarId methods).toOption -- Basic arithmetic: Nat example : 2 + 3 = 5 := by sym_simp diff --git a/tests/lean/run/sym_simp_3.lean b/tests/lean/run/sym_simp_3.lean index 6be9341436..7062859f45 100644 --- a/tests/lean/run/sym_simp_3.lean +++ b/tests/lean/run/sym_simp_3.lean @@ -7,7 +7,9 @@ elab "sym_simp" "[" declNames:ident,* "]" : tactic => do pre := Sym.Simp.simpControl post := Sym.Simp.evalGround.andThen rewrite } - liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods) + liftMetaTactic1 fun mvarId => Sym.SymM.run do + let mvarId ← Sym.preprocessMVar mvarId + (← Sym.simpGoal mvarId methods).toOption example : (1-1) + x*1 + (2-1)*0 = x := by sym_simp [Nat.add_zero, Nat.zero_add, Nat.mul_one] diff --git a/tests/lean/sym/perf_sym_apply.lean b/tests/lean/sym/perf_sym_apply.lean index 4ef2bb804e..b4af3b9645 100644 --- a/tests/lean/sym/perf_sym_apply.lean +++ b/tests/lean/sym/perf_sym_apply.lean @@ -3,7 +3,7 @@ import Lean.Meta.Sym open Lean Meta Sym def profileM {α : Type} (k : MetaM α) (msg : String := "experiment") : MetaM α := - profileitM Exception msg ({ : Options }.set `profiler true |>.setNat `profiler.threshold 0) k + profileitM Exception msg (Options.empty.set `profiler true |>.set `profiler.threshold 0) k def genTerm (n : Nat) : Expr := Id.run do let mut e := mkConst ``True @@ -33,11 +33,8 @@ def tryIntros? (goals : List MVarId) : SymM (Option (List MVarId)) := do def tryApply? (rule : BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do let goal :: goals := goals | return none - try - let goals' ← rule.apply goal - return some (goals' ++ goals) - catch _ => - return none + let .goals goals' ← rule.apply goal | return none + return some (goals' ++ goals) def tryApplyAny? (rules : List BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do match rules with