diff --git a/src/Init/SimpLemmas.lean b/src/Init/SimpLemmas.lean index 5d510b2528..d948aa32bc 100644 --- a/src/Init/SimpLemmas.lean +++ b/src/Init/SimpLemmas.lean @@ -38,6 +38,12 @@ theorem eq_false_of_decide {p : Prop} {_ : Decidable p} (h : decide p = false) : theorem implies_congr {p₁ p₂ : Sort u} {q₁ q₂ : Sort v} (h₁ : p₁ = p₂) (h₂ : q₁ = q₂) : (p₁ → q₁) = (p₂ → q₂) := h₁ ▸ h₂ ▸ rfl +theorem implies_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : (p₁ → q) = (p₂ → q) := + h ▸ rfl + +theorem implies_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : (p → q₁) = (p → q₂) := + h ▸ rfl + theorem iff_congr {p₁ p₂ q₁ q₂ : Prop} (h₁ : p₁ ↔ p₂) (h₂ : q₁ ↔ q₂) : (p₁ ↔ q₁) ↔ (p₂ ↔ q₂) := Iff.of_eq (propext h₁ ▸ propext h₂ ▸ rfl) diff --git a/src/Lean/Meta/Sym/Simp/Funext.lean b/src/Lean/Meta/Sym/Simp/Funext.lean index ecd810d426..aba235c4a6 100644 --- a/src/Lean/Meta/Sym/Simp/Funext.lean +++ b/src/Lean/Meta/Sym/Simp/Funext.lean @@ -10,6 +10,7 @@ import Lean.Meta.InferType import Lean.Meta.Closure import Lean.Meta.AppBuilder namespace Lean.Meta.Sym.Simp + /-- Given `xs` containing free variables `(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])` @@ -29,11 +30,8 @@ public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do let w ← getLevel type withLocalDeclD `f type fun f => withLocalDeclD `g type fun g => do - let lhs := mkAppN f xs - let rhs := mkAppN g xs - let eq := mkApp3 (mkConst ``Eq [v]) β lhs rhs - let p ← mkForallFVars xs eq - withLocalDeclD `h p fun h => do + let eq := mkApp3 (mkConst ``Eq [v]) β (mkAppN f xs) (mkAppN g xs) + withLocalDeclD `h (← mkForallFVars xs eq) fun h => do let eqv ← mkLambdaFVars #[f, g] (← mkForallFVars xs eq) let quotEqv := mkApp2 (mkConst ``Quot [w]) type eqv withLocalDeclD `f' quotEqv fun f' => do @@ -50,4 +48,46 @@ public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do let result ← mkLambdaFVars #[f, g, h] result return result +/-- +Given `xs` containing free variables +`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])`, +creates the custom forall congruence theorem +``` +∀ (p q : (x₁ : α₁) → (x₂ : α₂[x₁]) → ... → (xₙ : αₙ[x₁, ..., x_{n-1}]) → Prop) + (h : ∀ x₁ ... xₙ, p x₁ ... xₙ = q x₁ ... xₙ), + (∀ x₁ ... xₙ, p x₁ ... xₙ) = (∀ x₁ ... xₙ, q x₁ ... xₙ) +``` +The theorem has three arguments `p`, `q`, and `h`. +This auxiliary theorem is used by the simplifier when visiting forall expressions. +The proof uses the approach used in `mkFunextFor` followed by an `Eq.ndrec`. +-/ +public def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do + let prop := mkSort 0 + let type ← mkForallFVars xs prop + let w ← getLevel type + withLocalDeclD `p type fun p => + withLocalDeclD `q type fun q => do + let eq := mkApp3 (mkConst ``Eq [1]) prop (mkAppN p xs) (mkAppN q xs) + withLocalDeclD `h (← mkForallFVars xs eq) fun h => do + let eqv ← mkLambdaFVars #[p, q] (← mkForallFVars xs eq) + let quotEqv := mkApp2 (mkConst ``Quot [w]) type eqv + withLocalDeclD `p' quotEqv fun p' => do + let lift := mkApp6 (mkConst ``Quot.lift [w, 1]) type eqv prop + (mkLambda `p .default type (mkAppN (.bvar 0) xs)) + (mkLambda `p .default type (mkLambda `q .default type (mkLambda `h .default (mkApp2 eqv (.bvar 1) (.bvar 0)) (mkAppN (.bvar 0) xs)))) + p' + let extfunAppVal ← mkLambdaFVars (#[p'] ++ xs) lift + let extfunApp := extfunAppVal + let quotSound := mkApp5 (mkConst ``Quot.sound [w]) type eqv p q h + let Quot_mk_p := mkApp3 (mkConst ``Quot.mk [w]) type eqv p + let Quot_mk_q := mkApp3 (mkConst ``Quot.mk [w]) type eqv q + let p_eq_q := mkApp6 (mkConst ``congrArg [w, w]) quotEqv type Quot_mk_p Quot_mk_q extfunApp quotSound + let lhs ← mkForallFVars xs (mkAppN p xs) + let rhs ← mkForallFVars xs (mkAppN q xs) + let motive ← mkLambdaFVars #[q] (mkApp3 (mkConst ``Eq [1]) prop lhs rhs) + let rfl := mkApp2 (mkConst ``Eq.refl [1]) prop lhs + let result := mkApp6 (mkConst ``Eq.ndrec [0, w]) type p motive rfl q p_eq_q + let result ← mkLambdaFVars #[p, q, h] result + return result + end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp/Main.lean b/src/Lean/Meta/Sym/Simp/Main.lean index d5ae2629bd..0dc0f447aa 100644 --- a/src/Lean/Meta/Sym/Simp/Main.lean +++ b/src/Lean/Meta/Sym/Simp/Main.lean @@ -25,13 +25,11 @@ instance : MonadSimp SimpM where | .step e' h _ => return .step e' h def simpLambda (e : Expr) : SimpM Result := do - -- **TODO**: Add free variable reuse lambdaTelescope e fun xs b => do match (← simp b) with | .rfl _ => return .rfl | .step b' h _ => let h ← mkLambdaFVars xs h - -- **TODO**: Add `mkLambdaFVarsS`? let e' ← shareCommonInc (← mkLambdaFVars xs b') let funext ← getFunext xs b return .step e' (mkApp3 funext e e' h) @@ -46,9 +44,50 @@ where modify fun s => { s with funext := s.funext.insert { expr := key } h } return h -def simpForall (_ : Expr) : SimpM Result := do - -- **TODO** - return .rfl +def simpArrow (e : Expr) : SimpM Result := do + let p := e.bindingDomain! + let q := e.bindingBody! + match (← simp p), (← simp q) with + | .rfl _, .rfl _ => + return .rfl + | .step p' h _, .rfl _ => + let u ← getLevel p + let v ← getLevel q + let e' ← e.updateForallS! p' q + return .step e' <| mkApp4 (mkConst ``implies_congr_left [u, v]) p p' q h + | .rfl _, .step q' h _ => + let u ← getLevel p + let v ← getLevel q + let e' ← e.updateForallS! p q' + return .step e' <| mkApp4 (mkConst ``implies_congr_right [u, v]) p q q' h + | .step p' h₁ _, .step q' h₂ _ => + let u ← getLevel p + let v ← getLevel q + let e' ← e.updateForallS! p' q' + return .step e' <| mkApp6 (mkConst ``implies_congr [u, v]) p p' q q' h₁ h₂ + +def simpForall (e : Expr) : SimpM Result := do + if e.isArrow then + simpArrow e + else if (← isProp e) then + let n := getForallTelescopeSize e.bindingBody! 1 + forallBoundedTelescope e n fun xs b => do + match (← simp b) with + | .rfl _ => return .rfl + | .step b' h _ => + let h ← mkLambdaFVars xs h + let e' ← shareCommonInc (← mkForallFVars xs b') + -- **Note**: consider caching the forall-congr theorems + let hcongr ← mkForallCongrFor xs + return .step e' (mkApp3 hcongr (← mkLambdaFVars xs b) (← mkLambdaFVars xs b') h) + else + return .rfl +where + -- **Note**: Optimize if this is quadratic in practice + getForallTelescopeSize (e : Expr) (n : Nat) : Nat := + match e with + | .forallE _ _ b _ => if b.hasLooseBVar 0 then getForallTelescopeSize b (n+1) else n + | _ => n def simpLet (e : Expr) : SimpM Result := do if !e.letNondep! then diff --git a/tests/bench/sym/meta_simp_4.lean b/tests/bench/sym/meta_simp_4.lean new file mode 100644 index 0000000000..5afeafde73 --- /dev/null +++ b/tests/bench/sym/meta_simp_4.lean @@ -0,0 +1,102 @@ +import Lean +open Lean Meta +opaque f : Nat → Nat + +namespace SimpBench + +/-! +## `MetaM` Simplifier benchmarks +-/ +def getProofSize (r : Simp.Result) : MetaM Nat := do + (← r.getProof).numObjs + +def checkWithKernel (r : Simp.Result) : MetaM Float := do + let p := ShareCommon.shareCommon' (← r.getProof) + let startTime ← IO.monoNanosNow + Meta.checkWithKernel p + let endTime ← IO.monoNanosNow + return (endTime - startTime).toFloat / 1000000.0 + +def mkSimpContext (config : Simp.Config := {}) : MetaM Simp.Context := do + let s : SimpTheorems := {} + let s ← s.addConst ``Nat.zero_add + let s ← s.addConst ``Nat.add_zero + let config := { config with implicitDefEqProofs := false } + Simp.mkContext config #[s] {} + +def simp (e : Expr) : MetaM (Simp.Result × Float) := do + -- let e ← Grind.shareCommon e + let startTime ← IO.monoNanosNow + let (r, _) ← Meta.simp e (← mkSimpContext) + let endTime ← IO.monoNanosNow + -- logInfo e + -- logInfo r.expr + -- check (← r.getProof) + let timeMs := (endTime - startTime).toFloat / 1000000.0 + return (r, timeMs) + +def ppExample (e : Expr) (info := false) : MetaM Unit := do + forallTelescope e fun _ e => do + IO.println "Example:" + IO.println (← ppExpr e) + IO.println "====>" + let (r, _) ← simp e + IO.println (← ppExpr r.expr) + let h ← r.getProof + IO.println "Proof:" + if info then + logInfo h + else + IO.println (← ppExpr h) + +def mkForallPrefix (n : Nat) (k : Array Expr → MetaM Expr) : MetaM Expr := do + let rec go (n : Nat) (xs : Array Expr) : MetaM Expr := do + match n with + | 0 => mkForallFVars xs (← k xs) + | n+1 => + withLocalDeclD `x (mkConst ``Nat) fun x => + go n (xs.push x) + go n #[] + +def mkForallBench (n : Nat) : MetaM Expr := + mkForallPrefix n fun xs => do + let rec go (n : Nat) (e : Expr) : MetaM Expr := do + match n with + | 0 => return e + | n+1 => + let p₁ := mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!) + let p₂ := mkNatEq (mkNatAdd (mkNatLit 0) xs[n]!) xs[n]! + let q := mkNatEq (mkNatLit 1) xs[n]! + if n % 2 == 0 then + go n (← mkArrow p₁ (← mkArrow q e)) + else + go n (← mkArrow (← mkArrow p₁ p₂) e) + go n (mkConst ``True) + +def benchForall (n : Nat) (check := false) : MetaM Unit := do + let e ← mkForallBench n + let (r, timeMs) ← simp e + let proofSize ← getProofSize r + if check then + let kMs ← checkWithKernel r + IO.println s!"forall_{n}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}" + else + IO.println s!"forall_{n}: {timeMs}ms, proof_size={proofSize}" + +set_option maxRecDepth 100000 + +/-! ## Run all benchmarks -/ +def runAllBenchmarks : MetaM Unit := do + IO.println "=== Simplifier Forall Telescope Stress Tests ===" + IO.println "" + + IO.println "" + IO.println "--- Benchmark 1: Forall Telescope block ---" + ppExample (← mkForallBench 10) + + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120] do + benchForall n (n < 500) + +#eval runAllBenchmarks + +end SimpBench diff --git a/tests/bench/sym/simp_4.lean b/tests/bench/sym/simp_4.lean new file mode 100644 index 0000000000..ae20e29f6c --- /dev/null +++ b/tests/bench/sym/simp_4.lean @@ -0,0 +1,117 @@ +import Lean +open Lean Meta +opaque f : Nat → Nat + +namespace SimpBench +/-! +## `SymM` Simplifier benchmarks +-/ + +def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do + match r with + | .rfl _ => return 0 + | .step _ p _ => (ShareCommon.shareCommon' p).numObjs + +def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do + match r with + | .rfl _ => return 0.0 + | .step _ p _ => + let p := ShareCommon.shareCommon' p + let startTime ← IO.monoNanosNow + Meta.checkWithKernel p + let endTime ← IO.monoNanosNow + return (endTime - startTime).toFloat / 1000000.0 + +def mkSimpMethods : MetaM Sym.Simp.Methods := do + let thms : Sym.Simp.Theorems := {} + let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add) + let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero) + return { post := thms.rewrite } + +def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do + let e ← Grind.shareCommon e + let methods ← mkSimpMethods + let startTime ← IO.monoNanosNow + let r ← Sym.simp e methods { maxSteps := 100000000 } + let endTime ← IO.monoNanosNow + let timeMs := (endTime - startTime).toFloat / 1000000.0 + -- logInfo e + -- match r with + -- | .rfl _ => logInfo "rfl" + -- | .step e' h _ => + -- logInfo e'; logInfo h + return (r, timeMs) + +def ppExample (e : Expr) (info := false) : MetaM Unit := do + IO.println "Example:" + IO.println (← ppExpr e) + IO.println "====>" + match (← simp e).1 with + | .rfl _ => IO.println "" + | .step e' h _ => + IO.println (← ppExpr e') + IO.println "Proof:" + if info then + logInfo h + else + IO.println (← ppExpr h) + IO.println "" + +def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := do + let (r, timeMs) ← simp e + let proofSize ← getProofSize r + if check then + let kMs ← checkWithKernel r + IO.println s!"{name}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}" + else + IO.println s!"{name}: {timeMs}ms, proof_size={proofSize}" + +def mkForallPrefix (n : Nat) (k : Array Expr → MetaM Expr) : MetaM Expr := do + let rec go (n : Nat) (xs : Array Expr) : MetaM Expr := do + match n with + | 0 => mkForallFVars xs (← k xs) + | n+1 => + withLocalDeclD `x (mkConst ``Nat) fun x => + go n (xs.push x) + go n #[] + +def mkForallBench (n : Nat) : MetaM Expr := + mkForallPrefix n fun xs => do + let rec go (n : Nat) (e : Expr) : MetaM Expr := do + match n with + | 0 => return e + | n+1 => + let p₁ := mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!) + let p₂ := mkNatEq (mkNatAdd (mkNatLit 0) xs[n]!) xs[n]! + let q := mkNatEq (mkNatLit 1) xs[n]! + if n % 2 == 0 then + go n (← mkArrow p₁ (← mkArrow q e)) + else + go n (← mkArrow (← mkArrow p₁ p₂) e) + go n (mkConst ``True) + +def benchForall (n : Nat) (check := false) : MetaM Unit := do + let e ← mkForallBench n + benchSimp s!"forall_{n}" e check + +set_option maxRecDepth 100000 + +/-! ## Run all benchmarks -/ +def runAllBenchmarks : MetaM Unit := do + IO.println "=== Simplifier Forall Telescope Stress Tests ===" + IO.println "" + + benchForall 600 false + + if true then return () + + IO.println "" + IO.println "--- Benchmark 1: Forall Telescope block ---" + ppExample (← mkForallBench 10) + + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do + benchForall n (n < 500) + +#eval runAllBenchmarks + +end SimpBench