feat: eta-reduction support in SymM (#12168)
This PR adds support for eta-reduction in `SymM`.
This commit is contained in:
parent
e8870da205
commit
62d2688579
5 changed files with 129 additions and 3 deletions
|
|
@ -23,6 +23,7 @@ public import Lean.Meta.Sym.Apply
|
|||
public import Lean.Meta.Sym.InferType
|
||||
public import Lean.Meta.Sym.Simp
|
||||
public import Lean.Meta.Sym.Util
|
||||
public import Lean.Meta.Sym.Eta
|
||||
public import Lean.Meta.Sym.Grind
|
||||
|
||||
/-!
|
||||
|
|
|
|||
53
src/Lean/Meta/Sym/Eta.lean
Normal file
53
src/Lean/Meta/Sym/Eta.lean
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.ExprPtr
|
||||
public import Lean.Meta.Basic
|
||||
import Lean.Meta.Transform
|
||||
namespace Lean.Meta.Sym
|
||||
/--
|
||||
Checks if `body` is eta-expanded with `n` applications: `f (.bvar (n-1)) ... (.bvar 0)`.
|
||||
Returns `f` if so and `f` has no loose bvars; otherwise returns `default`.
|
||||
- `n`: number of remaining applications to check
|
||||
- `i`: expected bvar index (starts at 0, increments with each application)
|
||||
- `default`: returned when not eta-reducible (enables pointer equality check)
|
||||
-/
|
||||
def etaReduceAux (body : Expr) (n : Nat) (i : Nat) (default : Expr) : Expr := Id.run do
|
||||
match n with
|
||||
| 0 => if body.hasLooseBVars then default else body
|
||||
| n+1 =>
|
||||
let .app f (.bvar j) := body | default
|
||||
if j == i then etaReduceAux f n (i+1) default else default
|
||||
|
||||
/--
|
||||
If `e` is of the form `(fun x₁ ... xₙ => f x₁ ... xₙ)` and `f` does not contain `x₁`, ..., `xₙ`,
|
||||
then returns `f`. Otherwise, returns `e`.
|
||||
|
||||
Returns the original expression when not reducible to enable pointer equality checks.
|
||||
-/
|
||||
public def etaReduce (e : Expr) : Expr :=
|
||||
go e 0
|
||||
where
|
||||
go (body : Expr) (n : Nat) : Expr :=
|
||||
match body with
|
||||
| .lam _ _ b _ => go b (n+1)
|
||||
| _ => if n == 0 then e else etaReduceAux body n 0 e
|
||||
|
||||
/-- Returns `true` if `e` can be eta-reduced. Uses pointer equality for efficiency. -/
|
||||
public def isEtaReducible (e : Expr) : Bool :=
|
||||
!isSameExpr e (etaReduce e)
|
||||
|
||||
/-- Applies `etaReduce` to all subexpressions. Returns `e` unchanged if no subexpression is eta-reducible. -/
|
||||
public def etaReduceAll (e : Expr) : MetaM Expr := do
|
||||
unless Option.isSome <| e.find? isEtaReducible do return e
|
||||
let pre (e : Expr) : MetaM TransformStep := do
|
||||
let e' := etaReduce e
|
||||
if isSameExpr e e' then return .continue
|
||||
else return .visit e'
|
||||
Meta.transform e (pre := pre)
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
@ -18,6 +18,7 @@ import Lean.Meta.Sym.ProofInstInfo
|
|||
import Lean.Meta.Sym.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.LitValues
|
||||
import Lean.Meta.Sym.Offset
|
||||
import Lean.Meta.Sym.Eta
|
||||
namespace Lean.Meta.Sym
|
||||
open Internal
|
||||
|
||||
|
|
@ -323,7 +324,11 @@ def isAssignedMVar (e : Expr) : MetaM Bool :=
|
|||
| _ => return false
|
||||
|
||||
partial def process (p : Expr) (e : Expr) : UnifyM Bool := do
|
||||
match p with
|
||||
let e' := etaReduce e
|
||||
if !isSameExpr e e' then
|
||||
-- **Note**: We eagerly eta reduce patterns
|
||||
process p e'
|
||||
else match p with
|
||||
| .bvar bidx => assignExpr bidx e
|
||||
| .mdata _ p => process p e
|
||||
| .const declName us =>
|
||||
|
|
@ -723,7 +728,12 @@ def isDefEqApp (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEq
|
|||
@[export lean_sym_def_eq]
|
||||
def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do
|
||||
if isSameExpr t s then return true
|
||||
match t, s with
|
||||
-- **Note**: `etaReduce` is supposed to be fast, and does not allocate memory
|
||||
let t' := etaReduce t
|
||||
let s' := etaReduce s
|
||||
if !isSameExpr t t' || !isSameExpr s s' then
|
||||
isDefEqMain t' s'
|
||||
else match t, s with
|
||||
| .lit l₁, .lit l₂ => return l₁ == l₂
|
||||
| .sort u, .sort v => isLevelDefEqS u v
|
||||
| .lam .., .lam .. => isDefEqBindingS t s
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ public import Lean.Meta.Sym.SymM
|
|||
import Lean.Meta.Sym.IsClass
|
||||
import Lean.Meta.Sym.Util
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Meta.Sym.Eta
|
||||
namespace Lean.Meta.Sym
|
||||
|
||||
/--
|
||||
|
|
@ -17,7 +18,8 @@ Preprocesses types that used for pattern matching and unification.
|
|||
public def preprocessType (type : Expr) : MetaM Expr := do
|
||||
let type ← Sym.unfoldReducible type
|
||||
let type ← Core.betaReduce type
|
||||
zetaReduce type
|
||||
let type ← zetaReduce type
|
||||
etaReduceAll type
|
||||
|
||||
/--
|
||||
Analyzes whether the given free variables (aka arguments) are proofs or instances.
|
||||
|
|
|
|||
60
tests/lean/run/sym_pattern_3.lean
Normal file
60
tests/lean/run/sym_pattern_3.lean
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import Std.Data.HashMap
|
||||
import Lean.Meta.Sym
|
||||
import Lean.Meta.DiscrTree.Basic
|
||||
open Lean Meta Sym Grind
|
||||
set_option sym.debug true
|
||||
|
||||
abbrev S := Nat
|
||||
abbrev M α := StateM S α
|
||||
|
||||
def Exec (s : S) (k : M α) (post : α → S → Prop) : Prop :=
|
||||
post (k s).1 (k s).2
|
||||
|
||||
theorem Exec.bind (k₁ : M α) (k₂ : α → M β) (post : β → S → Prop) :
|
||||
Exec s k₁ (fun a s₁ => Exec s₁ (k₂ a) post)
|
||||
→ Exec s (k₁ >>= k₂) post := by
|
||||
simp [Exec, Bind.bind, StateT.bind]
|
||||
cases k₁ s; simp
|
||||
|
||||
def goal := ∀ a b, Exec b (set a >>= fun _ => get) fun v _ => v = a
|
||||
set_option pp.explicit true
|
||||
|
||||
/-!
|
||||
Recall that `SymM` patterns are eagerly eta-reduced.
|
||||
Goals are not, but the pattern matcher/unifier performs eta whenever it is needed.
|
||||
-/
|
||||
|
||||
/--
|
||||
info: Pattern:
|
||||
@Exec #5 #4 (@bind (StateT Nat Id) (@Monad.toBind (StateT Nat Id) (@StateT.instMonad Nat Id Id.instMonad)) #6 #5 #3 #2)
|
||||
#1
|
||||
---
|
||||
info: a b : Nat
|
||||
⊢ @Exec Nat b
|
||||
(@bind (fun α => StateT Nat Id α) (@Monad.toBind (fun α => StateT Nat Id α) (@StateT.instMonad Nat Id Id.instMonad))
|
||||
PUnit Nat (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a) fun x =>
|
||||
@get Nat (fun α => StateT Nat Id α)
|
||||
(@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α)
|
||||
(@instMonadStateOfStateTOfMonad Nat Id Id.instMonad)))
|
||||
fun v x => @Eq Nat v a
|
||||
---
|
||||
info: a b : Nat
|
||||
⊢ @Exec PUnit b (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a)
|
||||
fun a_1 s₁ =>
|
||||
@Exec Nat s₁
|
||||
(@get Nat (fun α => StateT Nat Id α)
|
||||
(@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α)
|
||||
(@instMonadStateOfStateTOfMonad Nat Id Id.instMonad)))
|
||||
fun v x => @Eq Nat v a
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
let bindRule ← mkBackwardRuleFromDecl ``Exec.bind
|
||||
let a ← unfoldDefinition (mkConst ``goal)
|
||||
logInfo m!"Pattern:\n{bindRule.pattern.pattern}"
|
||||
forallTelescope a fun _ body => do
|
||||
let mvar ← mkFreshExprMVar body
|
||||
let mvarId ← preprocessMVar mvar.mvarId!
|
||||
logInfo mvarId
|
||||
let .goals [mvarId] ← bindRule.apply mvarId | failure
|
||||
logInfo mvarId
|
||||
Loading…
Add table
Reference in a new issue