diff --git a/src/Lean/Meta/Tactic/Grind/AlphaShareCommon.lean b/src/Lean/Meta/Tactic/Grind/AlphaShareCommon.lean new file mode 100644 index 0000000000..2402ff7693 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AlphaShareCommon.lean @@ -0,0 +1,112 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.ENodeKey + +namespace Lean.Meta.Grind + +private def hashChild (e : Expr) : UInt64 := + match e with + | .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. => + hash e + | .app .. | .letE .. | .forallE .. | .lam .. | .mdata .. | .proj .. => + (unsafe ptrAddrUnsafe e).toUInt64 + +private def alphaHash (e : Expr) : UInt64 := + match e with + | .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. => + hash e + | .app f a => mixHash (hashChild f) (hashChild a) + | .letE _ _ v b _ => mixHash (hashChild v) (hashChild b) + | .forallE _ d b _ | .lam _ d b _ => mixHash (hashChild d) (hashChild b) + | .mdata _ b => mixHash 13 (hashChild b) + | .proj n i b => mixHash (mixHash (hash n) (hash i)) (hashChild b) + +private def alphaEq (e₁ e₂ : Expr) : Bool := Id.run do + match e₁ with + | .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. => + e₁ == e₂ + | .app f₁ a₁ => + let .app f₂ a₂ := e₂ | false + isSameExpr f₁ f₂ && isSameExpr a₁ a₂ + | .letE _ _ v₁ b₁ _ => + let .letE _ _ v₂ b₂ _ := e₂ | false + isSameExpr v₁ v₂ && isSameExpr b₁ b₂ + | .forallE _ d₁ b₁ _ => + let .forallE _ d₂ b₂ _ := e₂ | false + isSameExpr d₁ d₂ && isSameExpr b₁ b₂ + | .lam _ d₁ b₁ _ => + let .lam _ d₂ b₂ _ := e₂ | false + isSameExpr d₁ d₂ && isSameExpr b₁ b₂ + | .mdata d₁ b₁ => + let .mdata d₂ b₂ := e₂ | false + return isSameExpr b₁ b₂ && d₁ == d₂ + | .proj n₁ i₁ b₁ => + let .proj n₂ i₂ b₂ := e₂ | false + n₁ == n₂ && i₁ == i₂ && isSameExpr b₁ b₂ + +structure AlphaKey where + expr : Expr + +instance : Hashable AlphaKey where + hash k := alphaHash k.expr + +instance : BEq AlphaKey where + beq k₁ k₂ := alphaEq k₁.expr k₂.expr + +structure AlphaShareCommon.State where + map : PHashMap ENodeKey Expr := {} + set : PHashSet AlphaKey := {} + +abbrev AlphaShareCommonM := StateM AlphaShareCommon.State + +private def save (e : Expr) (r : Expr) : AlphaShareCommonM Expr := do + if let some r := (← get).set.find? { expr := r } then + let r := r.expr + modify fun { set, map } => { + set + map := map.insert { expr := e } r + } + return r + else + modify fun { set, map } => { + set := set.insert { expr := r } + map := map.insert { expr := e } r |>.insert { expr := r } r + } + return r + +private abbrev visit (e : Expr) (k : AlphaShareCommonM Expr) : AlphaShareCommonM Expr := do + if let some r := (← get).map.find? { expr := e } then + return r + else + save e (← k) + +/-- Similar to `shareCommon`, but handles alpha-equivalence. -/ +def shareCommonAlpha (e : Expr) : AlphaShareCommonM Expr := + go e +where + go (e : Expr) : AlphaShareCommonM Expr := do + match e with + | .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. => + if let some r := (← get).set.find? { expr := e } then + return r.expr + else + modify fun { set, map } => { set := set.insert { expr := e }, map } + return e + | .app f a => + visit e (return mkApp (← go f) (← go a)) + | .letE n t v b nd => + visit e (return mkLet n t (← go v) (← go b) nd) + | .forallE n d b bi => + visit e (return mkForall n bi (← go d) (← go b)) + | .lam n d b bi => + visit e (return mkLambda n bi (← go d) (← go b)) + | .mdata d b => + visit e (return mkMData d (← go b)) + | .proj n i b => + visit e (return mkProj n i (← go b)) + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 5e98a99722..4a2a8ee5d2 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -55,12 +55,11 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do } def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM α := do - let scState := ShareCommon.State.mk _ - let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) - let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) - let (bfalseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.false) - let (btrueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.true) - let (natZExpr, scState) := ShareCommon.State.shareCommon scState (mkNatLit 0) + let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {} + let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState + let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState + let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState + let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState let simprocs := params.normProcs let simp := params.norm let config := params.config diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 92f76e2850..34389a4773 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -7,7 +7,6 @@ prelude import Init.Grind.Tactics import Init.Data.Queue import Std.Data.TreeSet -import Lean.Util.ShareCommon import Lean.HeadIndex import Lean.Meta.Basic import Lean.Meta.CongrTheorems @@ -16,6 +15,7 @@ import Lean.Meta.Tactic.Simp.Types import Lean.Meta.Tactic.Util import Lean.Meta.Tactic.Ext import Lean.Meta.Tactic.Grind.ENodeKey +import Lean.Meta.Tactic.Grind.AlphaShareCommon import Lean.Meta.Tactic.Grind.Attr import Lean.Meta.Tactic.Grind.ExtAttr import Lean.Meta.Tactic.Grind.Cases @@ -117,7 +117,7 @@ private def emptySC : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCo /-- State for the `GrindM` monad. -/ structure State where /-- `ShareCommon` (aka `Hashconsing`) state. -/ - scState : ShareCommon.State.{0} ShareCommon.objectFactory := emptySC + scState : AlphaShareCommon.State := {} /-- Congruence theorems generated so far. Recall that for constant symbols we rely on the reserved name feature (i.e., `mkHCongrWithArityForConst?`). @@ -232,8 +232,8 @@ Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have been hash-consed. We perform this step before we internalize expressions. -/ def shareCommon (e : Expr) : GrindM Expr := do - let scState ← modifyGet fun s => (s.scState, { s with scState := emptySC }) - let (e, scState) := ShareCommon.State.shareCommon scState e + let scState ← modifyGet fun s => (s.scState, { s with scState := {} }) + let (e, scState) := shareCommonAlpha e scState modify fun s => { s with scState } return e diff --git a/tests/lean/run/grind_heartbeats.lean b/tests/lean/run/grind_heartbeats.lean index c70dadfc4e..9e158415c1 100644 --- a/tests/lean/run/grind_heartbeats.lean +++ b/tests/lean/run/grind_heartbeats.lean @@ -12,7 +12,7 @@ macro_rules | `(gen! $n:num) => `(op (f $n) (gen! $(Lean.quote (n.getNat - 1)))) /-- -trace: [grind.issues] (deterministic) timeout at `simp`, maximum number of heartbeats (5000) has been reached +trace: [grind.issues] (deterministic) timeout at `isDefEq`, maximum number of heartbeats (5000) has been reached Use `set_option maxHeartbeats ` to set the limit. ⏎ Additional diagnostic information may be available using the `set_option diagnostics true` command. diff --git a/tests/lean/run/grind_t1.lean b/tests/lean/run/grind_t1.lean index 3b9bf3219e..6ceb0877aa 100644 --- a/tests/lean/run/grind_t1.lean +++ b/tests/lean/run/grind_t1.lean @@ -457,3 +457,13 @@ example (h : ∀ i, (¬i > 0) ∨ ∀ h : i ≠ 10, p i h) : p 5 (by decide) := -- Similar to previous test. example (h : ∀ i, (∀ h : i ≠ 10, p i h) ∨ (¬i > 0)) : p 5 (by decide) := by grind + +-- `grind` performs hash-consing modulo alpha-equivalence +/-- +trace: [grind.assert] (f fun x => x) = a +[grind.assert] ¬a = f fun x => x +-/ +#guard_msgs (trace) in +example (f : (Nat → Nat) → Nat) : f (fun x => x) = a → a = f (fun y => y) := by + set_option trace.grind.assert true in + grind