feat: simpForall and simpArrow in Sym.simp (#11950)
This PR implements `simpForall` and `simpArrow` in `Sym.simp`.
This commit is contained in:
parent
7d5a96941e
commit
d92cdae8e9
5 changed files with 314 additions and 10 deletions
|
|
@ -38,6 +38,12 @@ theorem eq_false_of_decide {p : Prop} {_ : Decidable p} (h : decide p = false) :
|
|||
theorem implies_congr {p₁ p₂ : Sort u} {q₁ q₂ : Sort v} (h₁ : p₁ = p₂) (h₂ : q₁ = q₂) : (p₁ → q₁) = (p₂ → q₂) :=
|
||||
h₁ ▸ h₂ ▸ rfl
|
||||
|
||||
theorem implies_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : (p₁ → q) = (p₂ → q) :=
|
||||
h ▸ rfl
|
||||
|
||||
theorem implies_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : (p → q₁) = (p → q₂) :=
|
||||
h ▸ rfl
|
||||
|
||||
theorem iff_congr {p₁ p₂ q₁ q₂ : Prop} (h₁ : p₁ ↔ p₂) (h₂ : q₁ ↔ q₂) : (p₁ ↔ q₁) ↔ (p₂ ↔ q₂) :=
|
||||
Iff.of_eq (propext h₁ ▸ propext h₂ ▸ rfl)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import Lean.Meta.InferType
|
|||
import Lean.Meta.Closure
|
||||
import Lean.Meta.AppBuilder
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
/--
|
||||
Given `xs` containing free variables
|
||||
`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])`
|
||||
|
|
@ -29,11 +30,8 @@ public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
|
|||
let w ← getLevel type
|
||||
withLocalDeclD `f type fun f =>
|
||||
withLocalDeclD `g type fun g => do
|
||||
let lhs := mkAppN f xs
|
||||
let rhs := mkAppN g xs
|
||||
let eq := mkApp3 (mkConst ``Eq [v]) β lhs rhs
|
||||
let p ← mkForallFVars xs eq
|
||||
withLocalDeclD `h p fun h => do
|
||||
let eq := mkApp3 (mkConst ``Eq [v]) β (mkAppN f xs) (mkAppN g xs)
|
||||
withLocalDeclD `h (← mkForallFVars xs eq) fun h => do
|
||||
let eqv ← mkLambdaFVars #[f, g] (← mkForallFVars xs eq)
|
||||
let quotEqv := mkApp2 (mkConst ``Quot [w]) type eqv
|
||||
withLocalDeclD `f' quotEqv fun f' => do
|
||||
|
|
@ -50,4 +48,46 @@ public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
|
|||
let result ← mkLambdaFVars #[f, g, h] result
|
||||
return result
|
||||
|
||||
/--
|
||||
Given `xs` containing free variables
|
||||
`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])`,
|
||||
creates the custom forall congruence theorem
|
||||
```
|
||||
∀ (p q : (x₁ : α₁) → (x₂ : α₂[x₁]) → ... → (xₙ : αₙ[x₁, ..., x_{n-1}]) → Prop)
|
||||
(h : ∀ x₁ ... xₙ, p x₁ ... xₙ = q x₁ ... xₙ),
|
||||
(∀ x₁ ... xₙ, p x₁ ... xₙ) = (∀ x₁ ... xₙ, q x₁ ... xₙ)
|
||||
```
|
||||
The theorem has three arguments `p`, `q`, and `h`.
|
||||
This auxiliary theorem is used by the simplifier when visiting forall expressions.
|
||||
The proof uses the approach used in `mkFunextFor` followed by an `Eq.ndrec`.
|
||||
-/
|
||||
public def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do
|
||||
let prop := mkSort 0
|
||||
let type ← mkForallFVars xs prop
|
||||
let w ← getLevel type
|
||||
withLocalDeclD `p type fun p =>
|
||||
withLocalDeclD `q type fun q => do
|
||||
let eq := mkApp3 (mkConst ``Eq [1]) prop (mkAppN p xs) (mkAppN q xs)
|
||||
withLocalDeclD `h (← mkForallFVars xs eq) fun h => do
|
||||
let eqv ← mkLambdaFVars #[p, q] (← mkForallFVars xs eq)
|
||||
let quotEqv := mkApp2 (mkConst ``Quot [w]) type eqv
|
||||
withLocalDeclD `p' quotEqv fun p' => do
|
||||
let lift := mkApp6 (mkConst ``Quot.lift [w, 1]) type eqv prop
|
||||
(mkLambda `p .default type (mkAppN (.bvar 0) xs))
|
||||
(mkLambda `p .default type (mkLambda `q .default type (mkLambda `h .default (mkApp2 eqv (.bvar 1) (.bvar 0)) (mkAppN (.bvar 0) xs))))
|
||||
p'
|
||||
let extfunAppVal ← mkLambdaFVars (#[p'] ++ xs) lift
|
||||
let extfunApp := extfunAppVal
|
||||
let quotSound := mkApp5 (mkConst ``Quot.sound [w]) type eqv p q h
|
||||
let Quot_mk_p := mkApp3 (mkConst ``Quot.mk [w]) type eqv p
|
||||
let Quot_mk_q := mkApp3 (mkConst ``Quot.mk [w]) type eqv q
|
||||
let p_eq_q := mkApp6 (mkConst ``congrArg [w, w]) quotEqv type Quot_mk_p Quot_mk_q extfunApp quotSound
|
||||
let lhs ← mkForallFVars xs (mkAppN p xs)
|
||||
let rhs ← mkForallFVars xs (mkAppN q xs)
|
||||
let motive ← mkLambdaFVars #[q] (mkApp3 (mkConst ``Eq [1]) prop lhs rhs)
|
||||
let rfl := mkApp2 (mkConst ``Eq.refl [1]) prop lhs
|
||||
let result := mkApp6 (mkConst ``Eq.ndrec [0, w]) type p motive rfl q p_eq_q
|
||||
let result ← mkLambdaFVars #[p, q, h] result
|
||||
return result
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -25,13 +25,11 @@ instance : MonadSimp SimpM where
|
|||
| .step e' h _ => return .step e' h
|
||||
|
||||
def simpLambda (e : Expr) : SimpM Result := do
|
||||
-- **TODO**: Add free variable reuse
|
||||
lambdaTelescope e fun xs b => do
|
||||
match (← simp b) with
|
||||
| .rfl _ => return .rfl
|
||||
| .step b' h _ =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
-- **TODO**: Add `mkLambdaFVarsS`?
|
||||
let e' ← shareCommonInc (← mkLambdaFVars xs b')
|
||||
let funext ← getFunext xs b
|
||||
return .step e' (mkApp3 funext e e' h)
|
||||
|
|
@ -46,9 +44,50 @@ where
|
|||
modify fun s => { s with funext := s.funext.insert { expr := key } h }
|
||||
return h
|
||||
|
||||
def simpForall (_ : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return .rfl
|
||||
def simpArrow (e : Expr) : SimpM Result := do
|
||||
let p := e.bindingDomain!
|
||||
let q := e.bindingBody!
|
||||
match (← simp p), (← simp q) with
|
||||
| .rfl _, .rfl _ =>
|
||||
return .rfl
|
||||
| .step p' h _, .rfl _ =>
|
||||
let u ← getLevel p
|
||||
let v ← getLevel q
|
||||
let e' ← e.updateForallS! p' q
|
||||
return .step e' <| mkApp4 (mkConst ``implies_congr_left [u, v]) p p' q h
|
||||
| .rfl _, .step q' h _ =>
|
||||
let u ← getLevel p
|
||||
let v ← getLevel q
|
||||
let e' ← e.updateForallS! p q'
|
||||
return .step e' <| mkApp4 (mkConst ``implies_congr_right [u, v]) p q q' h
|
||||
| .step p' h₁ _, .step q' h₂ _ =>
|
||||
let u ← getLevel p
|
||||
let v ← getLevel q
|
||||
let e' ← e.updateForallS! p' q'
|
||||
return .step e' <| mkApp6 (mkConst ``implies_congr [u, v]) p p' q q' h₁ h₂
|
||||
|
||||
def simpForall (e : Expr) : SimpM Result := do
|
||||
if e.isArrow then
|
||||
simpArrow e
|
||||
else if (← isProp e) then
|
||||
let n := getForallTelescopeSize e.bindingBody! 1
|
||||
forallBoundedTelescope e n fun xs b => do
|
||||
match (← simp b) with
|
||||
| .rfl _ => return .rfl
|
||||
| .step b' h _ =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
let e' ← shareCommonInc (← mkForallFVars xs b')
|
||||
-- **Note**: consider caching the forall-congr theorems
|
||||
let hcongr ← mkForallCongrFor xs
|
||||
return .step e' (mkApp3 hcongr (← mkLambdaFVars xs b) (← mkLambdaFVars xs b') h)
|
||||
else
|
||||
return .rfl
|
||||
where
|
||||
-- **Note**: Optimize if this is quadratic in practice
|
||||
getForallTelescopeSize (e : Expr) (n : Nat) : Nat :=
|
||||
match e with
|
||||
| .forallE _ _ b _ => if b.hasLooseBVar 0 then getForallTelescopeSize b (n+1) else n
|
||||
| _ => n
|
||||
|
||||
def simpLet (e : Expr) : SimpM Result := do
|
||||
if !e.letNondep! then
|
||||
|
|
|
|||
102
tests/bench/sym/meta_simp_4.lean
Normal file
102
tests/bench/sym/meta_simp_4.lean
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import Lean
|
||||
open Lean Meta
|
||||
opaque f : Nat → Nat
|
||||
|
||||
namespace SimpBench
|
||||
|
||||
/-!
|
||||
## `MetaM` Simplifier benchmarks
|
||||
-/
|
||||
def getProofSize (r : Simp.Result) : MetaM Nat := do
|
||||
(← r.getProof).numObjs
|
||||
|
||||
def checkWithKernel (r : Simp.Result) : MetaM Float := do
|
||||
let p := ShareCommon.shareCommon' (← r.getProof)
|
||||
let startTime ← IO.monoNanosNow
|
||||
Meta.checkWithKernel p
|
||||
let endTime ← IO.monoNanosNow
|
||||
return (endTime - startTime).toFloat / 1000000.0
|
||||
|
||||
def mkSimpContext (config : Simp.Config := {}) : MetaM Simp.Context := do
|
||||
let s : SimpTheorems := {}
|
||||
let s ← s.addConst ``Nat.zero_add
|
||||
let s ← s.addConst ``Nat.add_zero
|
||||
let config := { config with implicitDefEqProofs := false }
|
||||
Simp.mkContext config #[s] {}
|
||||
|
||||
def simp (e : Expr) : MetaM (Simp.Result × Float) := do
|
||||
-- let e ← Grind.shareCommon e
|
||||
let startTime ← IO.monoNanosNow
|
||||
let (r, _) ← Meta.simp e (← mkSimpContext)
|
||||
let endTime ← IO.monoNanosNow
|
||||
-- logInfo e
|
||||
-- logInfo r.expr
|
||||
-- check (← r.getProof)
|
||||
let timeMs := (endTime - startTime).toFloat / 1000000.0
|
||||
return (r, timeMs)
|
||||
|
||||
def ppExample (e : Expr) (info := false) : MetaM Unit := do
|
||||
forallTelescope e fun _ e => do
|
||||
IO.println "Example:"
|
||||
IO.println (← ppExpr e)
|
||||
IO.println "====>"
|
||||
let (r, _) ← simp e
|
||||
IO.println (← ppExpr r.expr)
|
||||
let h ← r.getProof
|
||||
IO.println "Proof:"
|
||||
if info then
|
||||
logInfo h
|
||||
else
|
||||
IO.println (← ppExpr h)
|
||||
|
||||
def mkForallPrefix (n : Nat) (k : Array Expr → MetaM Expr) : MetaM Expr := do
|
||||
let rec go (n : Nat) (xs : Array Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => mkForallFVars xs (← k xs)
|
||||
| n+1 =>
|
||||
withLocalDeclD `x (mkConst ``Nat) fun x =>
|
||||
go n (xs.push x)
|
||||
go n #[]
|
||||
|
||||
def mkForallBench (n : Nat) : MetaM Expr :=
|
||||
mkForallPrefix n fun xs => do
|
||||
let rec go (n : Nat) (e : Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => return e
|
||||
| n+1 =>
|
||||
let p₁ := mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!)
|
||||
let p₂ := mkNatEq (mkNatAdd (mkNatLit 0) xs[n]!) xs[n]!
|
||||
let q := mkNatEq (mkNatLit 1) xs[n]!
|
||||
if n % 2 == 0 then
|
||||
go n (← mkArrow p₁ (← mkArrow q e))
|
||||
else
|
||||
go n (← mkArrow (← mkArrow p₁ p₂) e)
|
||||
go n (mkConst ``True)
|
||||
|
||||
def benchForall (n : Nat) (check := false) : MetaM Unit := do
|
||||
let e ← mkForallBench n
|
||||
let (r, timeMs) ← simp e
|
||||
let proofSize ← getProofSize r
|
||||
if check then
|
||||
let kMs ← checkWithKernel r
|
||||
IO.println s!"forall_{n}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}"
|
||||
else
|
||||
IO.println s!"forall_{n}: {timeMs}ms, proof_size={proofSize}"
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
|
||||
/-! ## Run all benchmarks -/
|
||||
def runAllBenchmarks : MetaM Unit := do
|
||||
IO.println "=== Simplifier Forall Telescope Stress Tests ==="
|
||||
IO.println ""
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Forall Telescope block ---"
|
||||
ppExample (← mkForallBench 10)
|
||||
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120] do
|
||||
benchForall n (n < 500)
|
||||
|
||||
#eval runAllBenchmarks
|
||||
|
||||
end SimpBench
|
||||
117
tests/bench/sym/simp_4.lean
Normal file
117
tests/bench/sym/simp_4.lean
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
import Lean
|
||||
open Lean Meta
|
||||
opaque f : Nat → Nat
|
||||
|
||||
namespace SimpBench
|
||||
/-!
|
||||
## `SymM` Simplifier benchmarks
|
||||
-/
|
||||
|
||||
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
|
||||
match r with
|
||||
| .rfl _ => return 0
|
||||
| .step _ p _ => (ShareCommon.shareCommon' p).numObjs
|
||||
|
||||
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
|
||||
match r with
|
||||
| .rfl _ => return 0.0
|
||||
| .step _ p _ =>
|
||||
let p := ShareCommon.shareCommon' p
|
||||
let startTime ← IO.monoNanosNow
|
||||
Meta.checkWithKernel p
|
||||
let endTime ← IO.monoNanosNow
|
||||
return (endTime - startTime).toFloat / 1000000.0
|
||||
|
||||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||||
let thms : Sym.Simp.Theorems := {}
|
||||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
|
||||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero)
|
||||
return { post := thms.rewrite }
|
||||
|
||||
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
|
||||
let e ← Grind.shareCommon e
|
||||
let methods ← mkSimpMethods
|
||||
let startTime ← IO.monoNanosNow
|
||||
let r ← Sym.simp e methods { maxSteps := 100000000 }
|
||||
let endTime ← IO.monoNanosNow
|
||||
let timeMs := (endTime - startTime).toFloat / 1000000.0
|
||||
-- logInfo e
|
||||
-- match r with
|
||||
-- | .rfl _ => logInfo "rfl"
|
||||
-- | .step e' h _ =>
|
||||
-- logInfo e'; logInfo h
|
||||
return (r, timeMs)
|
||||
|
||||
def ppExample (e : Expr) (info := false) : MetaM Unit := do
|
||||
IO.println "Example:"
|
||||
IO.println (← ppExpr e)
|
||||
IO.println "====>"
|
||||
match (← simp e).1 with
|
||||
| .rfl _ => IO.println "<no change>"
|
||||
| .step e' h _ =>
|
||||
IO.println (← ppExpr e')
|
||||
IO.println "Proof:"
|
||||
if info then
|
||||
logInfo h
|
||||
else
|
||||
IO.println (← ppExpr h)
|
||||
IO.println ""
|
||||
|
||||
def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := do
|
||||
let (r, timeMs) ← simp e
|
||||
let proofSize ← getProofSize r
|
||||
if check then
|
||||
let kMs ← checkWithKernel r
|
||||
IO.println s!"{name}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}"
|
||||
else
|
||||
IO.println s!"{name}: {timeMs}ms, proof_size={proofSize}"
|
||||
|
||||
def mkForallPrefix (n : Nat) (k : Array Expr → MetaM Expr) : MetaM Expr := do
|
||||
let rec go (n : Nat) (xs : Array Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => mkForallFVars xs (← k xs)
|
||||
| n+1 =>
|
||||
withLocalDeclD `x (mkConst ``Nat) fun x =>
|
||||
go n (xs.push x)
|
||||
go n #[]
|
||||
|
||||
def mkForallBench (n : Nat) : MetaM Expr :=
|
||||
mkForallPrefix n fun xs => do
|
||||
let rec go (n : Nat) (e : Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => return e
|
||||
| n+1 =>
|
||||
let p₁ := mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!)
|
||||
let p₂ := mkNatEq (mkNatAdd (mkNatLit 0) xs[n]!) xs[n]!
|
||||
let q := mkNatEq (mkNatLit 1) xs[n]!
|
||||
if n % 2 == 0 then
|
||||
go n (← mkArrow p₁ (← mkArrow q e))
|
||||
else
|
||||
go n (← mkArrow (← mkArrow p₁ p₂) e)
|
||||
go n (mkConst ``True)
|
||||
|
||||
def benchForall (n : Nat) (check := false) : MetaM Unit := do
|
||||
let e ← mkForallBench n
|
||||
benchSimp s!"forall_{n}" e check
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
|
||||
/-! ## Run all benchmarks -/
|
||||
def runAllBenchmarks : MetaM Unit := do
|
||||
IO.println "=== Simplifier Forall Telescope Stress Tests ==="
|
||||
IO.println ""
|
||||
|
||||
benchForall 600 false
|
||||
|
||||
if true then return ()
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Forall Telescope block ---"
|
||||
ppExample (← mkForallBench 10)
|
||||
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
|
||||
benchForall n (n < 500)
|
||||
|
||||
#eval runAllBenchmarks
|
||||
|
||||
end SimpBench
|
||||
Loading…
Add table
Reference in a new issue