feat: hash consing with alpha equivalence in grind (#8479)
This PR implements hash-consing for `grind` that takes alpha equivalence into account.
This commit is contained in:
parent
383f68f806
commit
03e905d994
5 changed files with 132 additions and 11 deletions
112
src/Lean/Meta/Tactic/Grind/AlphaShareCommon.lean
Normal file
112
src/Lean/Meta/Tactic/Grind/AlphaShareCommon.lean
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.ENodeKey
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
private def hashChild (e : Expr) : UInt64 :=
|
||||
match e with
|
||||
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
|
||||
hash e
|
||||
| .app .. | .letE .. | .forallE .. | .lam .. | .mdata .. | .proj .. =>
|
||||
(unsafe ptrAddrUnsafe e).toUInt64
|
||||
|
||||
private def alphaHash (e : Expr) : UInt64 :=
|
||||
match e with
|
||||
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
|
||||
hash e
|
||||
| .app f a => mixHash (hashChild f) (hashChild a)
|
||||
| .letE _ _ v b _ => mixHash (hashChild v) (hashChild b)
|
||||
| .forallE _ d b _ | .lam _ d b _ => mixHash (hashChild d) (hashChild b)
|
||||
| .mdata _ b => mixHash 13 (hashChild b)
|
||||
| .proj n i b => mixHash (mixHash (hash n) (hash i)) (hashChild b)
|
||||
|
||||
private def alphaEq (e₁ e₂ : Expr) : Bool := Id.run do
|
||||
match e₁ with
|
||||
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
|
||||
e₁ == e₂
|
||||
| .app f₁ a₁ =>
|
||||
let .app f₂ a₂ := e₂ | false
|
||||
isSameExpr f₁ f₂ && isSameExpr a₁ a₂
|
||||
| .letE _ _ v₁ b₁ _ =>
|
||||
let .letE _ _ v₂ b₂ _ := e₂ | false
|
||||
isSameExpr v₁ v₂ && isSameExpr b₁ b₂
|
||||
| .forallE _ d₁ b₁ _ =>
|
||||
let .forallE _ d₂ b₂ _ := e₂ | false
|
||||
isSameExpr d₁ d₂ && isSameExpr b₁ b₂
|
||||
| .lam _ d₁ b₁ _ =>
|
||||
let .lam _ d₂ b₂ _ := e₂ | false
|
||||
isSameExpr d₁ d₂ && isSameExpr b₁ b₂
|
||||
| .mdata d₁ b₁ =>
|
||||
let .mdata d₂ b₂ := e₂ | false
|
||||
return isSameExpr b₁ b₂ && d₁ == d₂
|
||||
| .proj n₁ i₁ b₁ =>
|
||||
let .proj n₂ i₂ b₂ := e₂ | false
|
||||
n₁ == n₂ && i₁ == i₂ && isSameExpr b₁ b₂
|
||||
|
||||
structure AlphaKey where
|
||||
expr : Expr
|
||||
|
||||
instance : Hashable AlphaKey where
|
||||
hash k := alphaHash k.expr
|
||||
|
||||
instance : BEq AlphaKey where
|
||||
beq k₁ k₂ := alphaEq k₁.expr k₂.expr
|
||||
|
||||
structure AlphaShareCommon.State where
|
||||
map : PHashMap ENodeKey Expr := {}
|
||||
set : PHashSet AlphaKey := {}
|
||||
|
||||
abbrev AlphaShareCommonM := StateM AlphaShareCommon.State
|
||||
|
||||
private def save (e : Expr) (r : Expr) : AlphaShareCommonM Expr := do
|
||||
if let some r := (← get).set.find? { expr := r } then
|
||||
let r := r.expr
|
||||
modify fun { set, map } => {
|
||||
set
|
||||
map := map.insert { expr := e } r
|
||||
}
|
||||
return r
|
||||
else
|
||||
modify fun { set, map } => {
|
||||
set := set.insert { expr := r }
|
||||
map := map.insert { expr := e } r |>.insert { expr := r } r
|
||||
}
|
||||
return r
|
||||
|
||||
private abbrev visit (e : Expr) (k : AlphaShareCommonM Expr) : AlphaShareCommonM Expr := do
|
||||
if let some r := (← get).map.find? { expr := e } then
|
||||
return r
|
||||
else
|
||||
save e (← k)
|
||||
|
||||
/-- Similar to `shareCommon`, but handles alpha-equivalence. -/
|
||||
def shareCommonAlpha (e : Expr) : AlphaShareCommonM Expr :=
|
||||
go e
|
||||
where
|
||||
go (e : Expr) : AlphaShareCommonM Expr := do
|
||||
match e with
|
||||
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
|
||||
if let some r := (← get).set.find? { expr := e } then
|
||||
return r.expr
|
||||
else
|
||||
modify fun { set, map } => { set := set.insert { expr := e }, map }
|
||||
return e
|
||||
| .app f a =>
|
||||
visit e (return mkApp (← go f) (← go a))
|
||||
| .letE n t v b nd =>
|
||||
visit e (return mkLet n t (← go v) (← go b) nd)
|
||||
| .forallE n d b bi =>
|
||||
visit e (return mkForall n bi (← go d) (← go b))
|
||||
| .lam n d b bi =>
|
||||
visit e (return mkLambda n bi (← go d) (← go b))
|
||||
| .mdata d b =>
|
||||
visit e (return mkMData d (← go b))
|
||||
| .proj n i b =>
|
||||
visit e (return mkProj n i (← go b))
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -55,12 +55,11 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
|
|||
}
|
||||
|
||||
def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM α := do
|
||||
let scState := ShareCommon.State.mk _
|
||||
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
|
||||
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
|
||||
let (bfalseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.false)
|
||||
let (btrueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.true)
|
||||
let (natZExpr, scState) := ShareCommon.State.shareCommon scState (mkNatLit 0)
|
||||
let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {}
|
||||
let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState
|
||||
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
|
||||
let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState
|
||||
let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState
|
||||
let simprocs := params.normProcs
|
||||
let simp := params.norm
|
||||
let config := params.config
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ prelude
|
|||
import Init.Grind.Tactics
|
||||
import Init.Data.Queue
|
||||
import Std.Data.TreeSet
|
||||
import Lean.Util.ShareCommon
|
||||
import Lean.HeadIndex
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Meta.CongrTheorems
|
||||
|
|
@ -16,6 +15,7 @@ import Lean.Meta.Tactic.Simp.Types
|
|||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.Tactic.Ext
|
||||
import Lean.Meta.Tactic.Grind.ENodeKey
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareCommon
|
||||
import Lean.Meta.Tactic.Grind.Attr
|
||||
import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
|
|
@ -117,7 +117,7 @@ private def emptySC : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCo
|
|||
/-- State for the `GrindM` monad. -/
|
||||
structure State where
|
||||
/-- `ShareCommon` (aka `Hashconsing`) state. -/
|
||||
scState : ShareCommon.State.{0} ShareCommon.objectFactory := emptySC
|
||||
scState : AlphaShareCommon.State := {}
|
||||
/--
|
||||
Congruence theorems generated so far. Recall that for constant symbols
|
||||
we rely on the reserved name feature (i.e., `mkHCongrWithArityForConst?`).
|
||||
|
|
@ -232,8 +232,8 @@ Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
|
|||
been hash-consed. We perform this step before we internalize expressions.
|
||||
-/
|
||||
def shareCommon (e : Expr) : GrindM Expr := do
|
||||
let scState ← modifyGet fun s => (s.scState, { s with scState := emptySC })
|
||||
let (e, scState) := ShareCommon.State.shareCommon scState e
|
||||
let scState ← modifyGet fun s => (s.scState, { s with scState := {} })
|
||||
let (e, scState) := shareCommonAlpha e scState
|
||||
modify fun s => { s with scState }
|
||||
return e
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ macro_rules
|
|||
| `(gen! $n:num) => `(op (f $n) (gen! $(Lean.quote (n.getNat - 1))))
|
||||
|
||||
/--
|
||||
trace: [grind.issues] (deterministic) timeout at `simp`, maximum number of heartbeats (5000) has been reached
|
||||
trace: [grind.issues] (deterministic) timeout at `isDefEq`, maximum number of heartbeats (5000) has been reached
|
||||
Use `set_option maxHeartbeats <num>` to set the limit.
|
||||
⏎
|
||||
Additional diagnostic information may be available using the `set_option diagnostics true` command.
|
||||
|
|
|
|||
|
|
@ -457,3 +457,13 @@ example (h : ∀ i, (¬i > 0) ∨ ∀ h : i ≠ 10, p i h) : p 5 (by decide) :=
|
|||
-- Similar to previous test.
|
||||
example (h : ∀ i, (∀ h : i ≠ 10, p i h) ∨ (¬i > 0)) : p 5 (by decide) := by
|
||||
grind
|
||||
|
||||
-- `grind` performs hash-consing modulo alpha-equivalence
|
||||
/--
|
||||
trace: [grind.assert] (f fun x => x) = a
|
||||
[grind.assert] ¬a = f fun x => x
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
example (f : (Nat → Nat) → Nat) : f (fun x => x) = a → a = f (fun y => y) := by
|
||||
set_option trace.grind.assert true in
|
||||
grind
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue