feat: add Sym.Simp.Theorem.rewrite? (#11868)
This PR implements `Sym.Simp.Theorem.rewrite?` for rewriting terms using equational theorems in `Sym`.
This commit is contained in:
parent
97c23abf8e
commit
b82f969e5b
6 changed files with 127 additions and 15 deletions
|
|
@ -68,6 +68,15 @@ def isUVar? (n : Name) : Option Nat := Id.run do
|
|||
unless p == uvarPrefix do return none
|
||||
return some idx
|
||||
|
||||
/-- Helper function for implementing `mkPatternFromDecl` and `mkEqPatternFromDecl` -/
|
||||
def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do
|
||||
let info ← getConstInfo declName
|
||||
let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i
|
||||
let us := levelParams.map mkLevelParam
|
||||
let type ← instantiateTypeLevelParams info.toConstantVal us
|
||||
let type ← preprocessType type
|
||||
return (levelParams, type)
|
||||
|
||||
/--
|
||||
Creates a `Pattern` from the type of a theorem.
|
||||
|
||||
|
|
@ -82,11 +91,7 @@ If `num?` is `some n`, at most `n` leading quantifiers are stripped.
|
|||
If `num?` is `none`, all leading quantifiers are stripped.
|
||||
-/
|
||||
public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : MetaM Pattern := do
|
||||
let info ← getConstInfo declName
|
||||
let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i
|
||||
let us := levelParams.map mkLevelParam
|
||||
let type ← instantiateTypeLevelParams info.toConstantVal us
|
||||
let type ← preprocessType type
|
||||
let (levelParams, type) ← preprocessPattern declName
|
||||
let hugeNumber := 10000000
|
||||
let num := num?.getD hugeNumber
|
||||
let rec go (i : Nat) (type : Expr) (varTypes : Array Expr) (isInstance : Array Bool) : MetaM Pattern := do
|
||||
|
|
@ -98,6 +103,29 @@ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : Met
|
|||
return { levelParams, varTypes, isInstance, pattern, fnInfos }
|
||||
go 0 type #[] #[]
|
||||
|
||||
/--
|
||||
Creates a `Pattern` from an equational theorem, using the left-hand side of the equation.
|
||||
It also returns the right-hand side of the equation
|
||||
|
||||
Like `mkPatternFromDecl`, this strips all leading universal quantifiers, recording variable
|
||||
types and instance status. However, instead of using the entire resulting type as the pattern,
|
||||
it extracts just the LHS of the equation.
|
||||
|
||||
For a theorem `∀ x₁ ... xₙ, lhs = rhs`, returns a pattern matching `lhs` with `n` pattern variables.
|
||||
Throws an error if the theorem's conclusion is not an equality.
|
||||
-/
|
||||
public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do
|
||||
let (levelParams, type) ← preprocessPattern declName
|
||||
let rec go (type : Expr) (varTypes : Array Expr) (isInstance : Array Bool) : MetaM (Pattern × Expr) := do
|
||||
if let .forallE _ d b _ := type then
|
||||
return (← go b (varTypes.push d) (isInstance.push (isClass? (← getEnv) d).isSome))
|
||||
else
|
||||
let_expr Eq _ lhs rhs := type | throwError "resulting type for `{.ofConstName declName}` is not an equality"
|
||||
let pattern := lhs
|
||||
let fnInfos ← mkProofInstInfoMapFor pattern
|
||||
return ({ levelParams, varTypes, isInstance, pattern, fnInfos }, rhs)
|
||||
go type #[] #[]
|
||||
|
||||
structure UnifyM.Context where
|
||||
pattern : Pattern
|
||||
unify : Bool := true
|
||||
|
|
|
|||
50
src/Lean/Meta/Sym/Rewrite.lean
Normal file
50
src/Lean/Meta/Sym/Rewrite.lean
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
public import Lean.Meta.Sym.SimpFun
|
||||
import Lean.Meta.Sym.InstantiateS
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
||||
public def mkTheoremFromDecl (declName : Name) : MetaM Theorem := do
|
||||
let (pattern, rhs) ← mkEqPatternFromDecl declName
|
||||
return { expr := mkConst declName, pattern, rhs }
|
||||
|
||||
/--
|
||||
Creates proof term for a rewriting step.
|
||||
Handles both constant expressions (common case, avoids `instantiateLevelParams`)
|
||||
and general expressions.
|
||||
-/
|
||||
def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr :=
|
||||
if let .const declName [] := expr then
|
||||
mkAppN (mkConst declName result.us) result.args
|
||||
else
|
||||
mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args
|
||||
|
||||
/--
|
||||
Tries to rewrite `e` using the given theorem.
|
||||
-/
|
||||
-- **TODO**: Define `Step` result?
|
||||
public def Theorem.rewrite? (thm : Theorem) (e : Expr) : SimpM (Option Result) := do
|
||||
if let some result ← thm.pattern.match? e then
|
||||
let proof? := some <| mkValue thm.expr thm.pattern result
|
||||
let rhs := thm.rhs.instantiateLevelParams thm.pattern.levelParams result.us
|
||||
let rhs ← shareCommonInc rhs
|
||||
let expr ← instantiateRevBetaS rhs result.args
|
||||
return some { expr, proof? }
|
||||
else
|
||||
return none
|
||||
|
||||
public def rewrite : SimpFun := fun e => do
|
||||
-- **TODO**: use indexing
|
||||
for thm in (← read).thms.thms do
|
||||
if let some result ← thm.rewrite? e then
|
||||
return result
|
||||
return { expr := e }
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
@ -8,6 +8,9 @@ prelude
|
|||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.EqTrans
|
||||
import Lean.Meta.Sym.Rewrite
|
||||
import Lean.Meta.Sym.SimpResult
|
||||
import Lean.Meta.Sym.SimpFun
|
||||
import Lean.Meta.Sym.Congr
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
|
@ -39,7 +42,7 @@ def simpMVar (e : Expr) : SimpM Result := do
|
|||
def simpApp (e : Expr) : SimpM Result := do
|
||||
congrArgs e
|
||||
|
||||
def simpStep (e : Expr) : SimpM Result := do
|
||||
def simpStep : SimpFun := fun e => do
|
||||
match e with
|
||||
| .lit _ | .sort _ | .bvar _ => return { expr := e }
|
||||
| .proj .. =>
|
||||
|
|
@ -57,10 +60,6 @@ def simpStep (e : Expr) : SimpM Result := do
|
|||
| .mvar .. => simpMVar e
|
||||
| .app .. => simpApp e
|
||||
|
||||
def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do
|
||||
let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof?
|
||||
return { r₂ with proof? }
|
||||
|
||||
def cacheResult (e : Expr) (r : Result) : SimpM Result := do
|
||||
modify fun s => { s with cache := s.cache.insert { expr := e } r }
|
||||
return r
|
||||
|
|
@ -71,7 +70,7 @@ def simpImpl (e : Expr) : SimpM Result := do
|
|||
throwError "`simp` failed: maximum number of steps exceeded"
|
||||
if let some result := (← getCache).find? { expr := e } then
|
||||
return result
|
||||
let r ← simpStep e
|
||||
let r ← (simpStep >> rewrite) e
|
||||
if isSameExpr r.expr e then
|
||||
cacheResult e r
|
||||
else
|
||||
|
|
|
|||
29
src/Lean/Meta/Sym/SimpFun.lean
Normal file
29
src/Lean/Meta/Sym/SimpFun.lean
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Sym.EqTrans
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
public def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do
|
||||
let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof?
|
||||
return { r₂ with proof? }
|
||||
|
||||
public abbrev SimpFun := Expr → SimpM Result
|
||||
|
||||
public abbrev SimpFun.andThen (f g : SimpFun) : SimpFun := fun e => do
|
||||
let r₁ ← f e
|
||||
if isSameExpr e r₁.expr then
|
||||
g e
|
||||
else
|
||||
let r₂ ← g r₁.expr
|
||||
mkEqTrans e r₁ r₂
|
||||
|
||||
public instance : AndThen SimpFun where
|
||||
andThen f g := SimpFun.andThen f (g ())
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
@ -123,9 +123,11 @@ during rewriting.
|
|||
-/
|
||||
structure Theorem where
|
||||
/-- The theorem expression, typically `Expr.const declName` for a named theorem. -/
|
||||
expr : Expr
|
||||
expr : Expr
|
||||
/-- Precomputed pattern extracted from the theorem's type for efficient matching. -/
|
||||
pattern : Pattern
|
||||
pattern : Pattern
|
||||
/-- Right-hand side of the equation. -/
|
||||
rhs : Expr
|
||||
|
||||
/-- Collection of simplification theorems available to the simplifier. -/
|
||||
structure Theorems where
|
||||
|
|
@ -178,4 +180,9 @@ def getConfig : SimpM Config :=
|
|||
abbrev getCache : SimpM Cache :=
|
||||
return (← get).cache
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
end Simp
|
||||
|
||||
public def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do
|
||||
Simp.SimpM.run (Simp.simp e) thms config
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Sym.InferType
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
public def Result.getProof (result : Result) : SymM Expr := do
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue