diff --git a/src/Lean/Meta/Sym/Pattern.lean b/src/Lean/Meta/Sym/Pattern.lean index 964a4cc419..1960483ff1 100644 --- a/src/Lean/Meta/Sym/Pattern.lean +++ b/src/Lean/Meta/Sym/Pattern.lean @@ -68,6 +68,15 @@ def isUVar? (n : Name) : Option Nat := Id.run do unless p == uvarPrefix do return none return some idx +/-- Helper function for implementing `mkPatternFromDecl` and `mkEqPatternFromDecl` -/ +def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do + let info ← getConstInfo declName + let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i + let us := levelParams.map mkLevelParam + let type ← instantiateTypeLevelParams info.toConstantVal us + let type ← preprocessType type + return (levelParams, type) + /-- Creates a `Pattern` from the type of a theorem. @@ -82,11 +91,7 @@ If `num?` is `some n`, at most `n` leading quantifiers are stripped. If `num?` is `none`, all leading quantifiers are stripped. -/ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : MetaM Pattern := do - let info ← getConstInfo declName - let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i - let us := levelParams.map mkLevelParam - let type ← instantiateTypeLevelParams info.toConstantVal us - let type ← preprocessType type + let (levelParams, type) ← preprocessPattern declName let hugeNumber := 10000000 let num := num?.getD hugeNumber let rec go (i : Nat) (type : Expr) (varTypes : Array Expr) (isInstance : Array Bool) : MetaM Pattern := do @@ -98,6 +103,29 @@ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : Met return { levelParams, varTypes, isInstance, pattern, fnInfos } go 0 type #[] #[] +/-- +Creates a `Pattern` from an equational theorem, using the left-hand side of the equation. +It also returns the right-hand side of the equation + +Like `mkPatternFromDecl`, this strips all leading universal quantifiers, recording variable +types and instance status. However, instead of using the entire resulting type as the pattern, +it extracts just the LHS of the equation. + +For a theorem `∀ x₁ ... xₙ, lhs = rhs`, returns a pattern matching `lhs` with `n` pattern variables. +Throws an error if the theorem's conclusion is not an equality. +-/ +public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do + let (levelParams, type) ← preprocessPattern declName + let rec go (type : Expr) (varTypes : Array Expr) (isInstance : Array Bool) : MetaM (Pattern × Expr) := do + if let .forallE _ d b _ := type then + return (← go b (varTypes.push d) (isInstance.push (isClass? (← getEnv) d).isSome)) + else + let_expr Eq _ lhs rhs := type | throwError "resulting type for `{.ofConstName declName}` is not an equality" + let pattern := lhs + let fnInfos ← mkProofInstInfoMapFor pattern + return ({ levelParams, varTypes, isInstance, pattern, fnInfos }, rhs) + go type #[] #[] + structure UnifyM.Context where pattern : Pattern unify : Bool := true diff --git a/src/Lean/Meta/Sym/Rewrite.lean b/src/Lean/Meta/Sym/Rewrite.lean new file mode 100644 index 0000000000..2a230a7e95 --- /dev/null +++ b/src/Lean/Meta/Sym/Rewrite.lean @@ -0,0 +1,50 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.SimpM +public import Lean.Meta.Sym.SimpFun +import Lean.Meta.Sym.InstantiateS +namespace Lean.Meta.Sym.Simp +open Grind + +public def mkTheoremFromDecl (declName : Name) : MetaM Theorem := do + let (pattern, rhs) ← mkEqPatternFromDecl declName + return { expr := mkConst declName, pattern, rhs } + +/-- +Creates proof term for a rewriting step. +Handles both constant expressions (common case, avoids `instantiateLevelParams`) +and general expressions. +-/ +def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr := + if let .const declName [] := expr then + mkAppN (mkConst declName result.us) result.args + else + mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args + +/-- +Tries to rewrite `e` using the given theorem. +-/ +-- **TODO**: Define `Step` result? +public def Theorem.rewrite? (thm : Theorem) (e : Expr) : SimpM (Option Result) := do + if let some result ← thm.pattern.match? e then + let proof? := some <| mkValue thm.expr thm.pattern result + let rhs := thm.rhs.instantiateLevelParams thm.pattern.levelParams result.us + let rhs ← shareCommonInc rhs + let expr ← instantiateRevBetaS rhs result.args + return some { expr, proof? } + else + return none + +public def rewrite : SimpFun := fun e => do + -- **TODO**: use indexing + for thm in (← read).thms.thms do + if let some result ← thm.rewrite? e then + return result + return { expr := e } + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index d76446cf61..bed1f2abb7 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -8,6 +8,9 @@ prelude public import Lean.Meta.Sym.SimpM import Lean.Meta.Tactic.Grind.AlphaShareBuilder import Lean.Meta.Sym.EqTrans +import Lean.Meta.Sym.Rewrite +import Lean.Meta.Sym.SimpResult +import Lean.Meta.Sym.SimpFun import Lean.Meta.Sym.Congr namespace Lean.Meta.Sym.Simp open Grind @@ -39,7 +42,7 @@ def simpMVar (e : Expr) : SimpM Result := do def simpApp (e : Expr) : SimpM Result := do congrArgs e -def simpStep (e : Expr) : SimpM Result := do +def simpStep : SimpFun := fun e => do match e with | .lit _ | .sort _ | .bvar _ => return { expr := e } | .proj .. => @@ -57,10 +60,6 @@ def simpStep (e : Expr) : SimpM Result := do | .mvar .. => simpMVar e | .app .. => simpApp e -def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do - let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof? - return { r₂ with proof? } - def cacheResult (e : Expr) (r : Result) : SimpM Result := do modify fun s => { s with cache := s.cache.insert { expr := e } r } return r @@ -71,7 +70,7 @@ def simpImpl (e : Expr) : SimpM Result := do throwError "`simp` failed: maximum number of steps exceeded" if let some result := (← getCache).find? { expr := e } then return result - let r ← simpStep e + let r ← (simpStep >> rewrite) e if isSameExpr r.expr e then cacheResult e r else diff --git a/src/Lean/Meta/Sym/SimpFun.lean b/src/Lean/Meta/Sym/SimpFun.lean new file mode 100644 index 0000000000..a663a452c1 --- /dev/null +++ b/src/Lean/Meta/Sym/SimpFun.lean @@ -0,0 +1,29 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.SimpM +import Lean.Meta.Sym.EqTrans +namespace Lean.Meta.Sym.Simp +open Grind +public def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do + let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof? + return { r₂ with proof? } + +public abbrev SimpFun := Expr → SimpM Result + +public abbrev SimpFun.andThen (f g : SimpFun) : SimpFun := fun e => do + let r₁ ← f e + if isSameExpr e r₁.expr then + g e + else + let r₂ ← g r₁.expr + mkEqTrans e r₁ r₂ + +public instance : AndThen SimpFun where + andThen f g := SimpFun.andThen f (g ()) + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/SimpM.lean b/src/Lean/Meta/Sym/SimpM.lean index 5dc53eb8d0..8e8ba3bcc6 100644 --- a/src/Lean/Meta/Sym/SimpM.lean +++ b/src/Lean/Meta/Sym/SimpM.lean @@ -123,9 +123,11 @@ during rewriting. -/ structure Theorem where /-- The theorem expression, typically `Expr.const declName` for a named theorem. -/ - expr : Expr + expr : Expr /-- Precomputed pattern extracted from the theorem's type for efficient matching. -/ - pattern : Pattern + pattern : Pattern + /-- Right-hand side of the equation. -/ + rhs : Expr /-- Collection of simplification theorems available to the simplifier. -/ structure Theorems where @@ -178,4 +180,9 @@ def getConfig : SimpM Config := abbrev getCache : SimpM Cache := return (← get).cache -end Lean.Meta.Sym.Simp +end Simp + +public def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do + Simp.SimpM.run (Simp.simp e) thms config + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/SimpResult.lean b/src/Lean/Meta/Sym/SimpResult.lean index 3950c69dcc..b5d4550d12 100644 --- a/src/Lean/Meta/Sym/SimpResult.lean +++ b/src/Lean/Meta/Sym/SimpResult.lean @@ -6,7 +6,6 @@ Authors: Leonardo de Moura module prelude public import Lean.Meta.Sym.SimpM -import Lean.Meta.Sym.InferType namespace Lean.Meta.Sym.Simp public def Result.getProof (result : Result) : SymM Expr := do