diff --git a/src/Init/Conv.lean b/src/Init/Conv.lean index 55c24a1073..9b906d4317 100644 --- a/src/Init/Conv.lean +++ b/src/Init/Conv.lean @@ -51,6 +51,10 @@ scoped syntax (name := withAnnotateState) /-- `skip` does nothing. -/ syntax (name := skip) "skip" : conv +/-- `cbv` performs simplification that closely mimics call-by-value evaluation, +using equations associated with definitions and the matchers. -/ +syntax (name := cbv) "cbv" : conv + /-- Traverses into the left subterm of a binary operator. diff --git a/src/Lean/Elab/Tactic/Conv.lean b/src/Lean/Elab/Tactic/Conv.lean index 382ee8bec7..d2f9d50ad2 100644 --- a/src/Lean/Elab/Tactic/Conv.lean +++ b/src/Lean/Elab/Tactic/Conv.lean @@ -15,3 +15,4 @@ public import Lean.Elab.Tactic.Conv.Simp public import Lean.Elab.Tactic.Conv.Pattern public import Lean.Elab.Tactic.Conv.Delta public import Lean.Elab.Tactic.Conv.Unfold +public import Lean.Elab.Tactic.Conv.Cbv diff --git a/src/Lean/Elab/Tactic/Conv/Cbv.lean b/src/Lean/Elab/Tactic/Conv/Cbv.lean new file mode 100644 index 0000000000..f5524dee5c --- /dev/null +++ b/src/Lean/Elab/Tactic/Conv/Cbv.lean @@ -0,0 +1,29 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Wojciech Różowski +-/ + +module + +prelude +public import Lean.Meta.Tactic.Cbv +public import Lean.Elab.Tactic.Conv.Basic + +section + +namespace Lean.Elab.Tactic.Conv +open Lean.Meta.Tactic.Cbv + + +@[builtin_tactic Lean.Parser.Tactic.Conv.cbv] public def evalCbv : Tactic := fun stx => withMainContext do + if cbv.warning.get (← getOptions) then + logWarningAt stx "The `cbv` tactic is experimental and still under development. Avoid using it in production projects" + let lhs ← getLhs + let evalResult ← cbvEntry lhs + match evalResult with + | .rfl .. => return () + | .step e' proof _ => + updateLhs e' proof + +end Lean.Elab.Tactic.Conv diff --git a/src/Lean/Meta/Tactic.lean b/src/Lean/Meta/Tactic.lean index 4f7436cdba..6c1ae1122e 100644 --- a/src/Lean/Meta/Tactic.lean +++ b/src/Lean/Meta/Tactic.lean @@ -45,3 +45,4 @@ public import Lean.Meta.Tactic.Rewrites public import Lean.Meta.Tactic.Grind public import Lean.Meta.Tactic.Ext public import Lean.Meta.Tactic.Try +public import Lean.Meta.Tactic.Cbv diff --git a/src/Lean/Meta/Tactic/Cbv.lean b/src/Lean/Meta/Tactic/Cbv.lean new file mode 100644 index 0000000000..8f4d558411 --- /dev/null +++ b/src/Lean/Meta/Tactic/Cbv.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Wojciech Różowski +-/ +module + +prelude +public import Lean.Meta.Tactic.Cbv.Main +public import Lean.Meta.Tactic.Cbv.Util + +public section + +namespace Lean + +builtin_initialize registerTraceClass `Meta.Tactic.cbv +builtin_initialize registerTraceClass `Debug.Meta.Tactic.cbv + +end Lean diff --git a/src/Lean/Meta/Tactic/Cbv/Main.lean b/src/Lean/Meta/Tactic/Cbv/Main.lean new file mode 100644 index 0000000000..5e3f26ad1d --- /dev/null +++ b/src/Lean/Meta/Tactic/Cbv/Main.lean @@ -0,0 +1,157 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Wojciech Różowski +-/ + +module + +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Tactic.Cbv.Util +import Lean.Meta.Tactic.Cbv.TheoremsLookup +import Lean.Meta.Sym + +namespace Lean.Meta.Tactic.Cbv +open Lean.Meta.Sym.Simp + +public register_builtin_option cbv.warning : Bool := { + defValue := true + descr := "disable `cbv` usage warning" +} + +def skipBinders : Simproc := fun e => do + return .rfl (e.isLambda || e.isForall) + +def tryMatchEquations (appFn : Name) : Simproc := fun e => do + let thms ← getMatchTheorems appFn + thms.rewrite (d := dischargeNone) e + +def reduceRecMatcher : Simproc := fun e => do + if let some e' ← reduceRecMatcher? e then + return .step e' (← Sym.mkEqRefl e') + else + return .rfl + +def tryEquations : Simproc := fun e => do + unless e.isApp do + return .rfl + let some appFn := e.getAppFn.constName? | return .rfl + let thms ← getEqnTheorems appFn + thms.rewrite (d := dischargeNone) e + +def tryUnfold : Simproc := fun e => do + unless e.isApp do + return .rfl + let some appFn := e.getAppFn.constName? | return .rfl + let some thm ← getUnfoldTheorem appFn | return .rfl + Theorem.rewrite thm e + +def tryMatcher : Simproc := fun e => do + unless e.isApp do + return .rfl + let some appFn := e.getAppFn.constName? | return .rfl + let some info ← getMatcherInfo? appFn | return .rfl + let start := info.numParams + 1 + let stop := start + info.numDiscrs + (simpAppArgRange · start stop) + >> tryMatchEquations appFn + <|> reduceRecMatcher + <| e + +def handleConstApp : Simproc := + tryEquations <|> tryUnfold + +def betaReduce : Simproc := fun e => do + -- TODO: Improve term sharing + let new := e.headBeta + let new ← Sym.share new + return .step new (← Sym.mkEqRefl new) + +def handleApp : Simproc := fun e => do + unless e.isApp do return .rfl + let fn := e.getAppFn + match fn with + | .const constName _ => + let info ← getConstInfo constName + (guardSimproc (fun _ => info.hasValue) handleConstApp) <|> reduceRecMatcher <| e + | .lam .. => betaReduce e + | _ => return .rfl + +def foldLit : Simproc := fun e => do + let some n := e.rawNatLit? | return .rfl + -- TODO: check performance of sharing + return .step (← Sym.share <| mkNatLit n) (← Sym.mkEqRefl e) + +def zetaReduce : Simproc := fun e => do + let .letE _ _ value body _ := e | return .rfl + let new := expandLet body #[value] + -- TODO: Improve sharing + let new ← Sym.share new + return .step new (← Sym.mkEqRefl new) + +def handleProj : Simproc := fun e => do + let Expr.proj typeName idx struct := e | return .rfl + -- We recursively simplify the projection + let res ← simp struct + match res with + | .rfl _ => + let some reduced ← reduceProj? <| .proj typeName idx struct | do + return .rfl (done := true) + + -- TODO: Figure if we can share this term incrementally + let reduced ← Sym.share reduced + return .step reduced (← Sym.mkEqRefl reduced) + | .step e' proof _ => + let type ← Sym.inferType e' + let congrArgFun := Lean.mkLambda `x .default type <| .proj typeName idx <| .bvar 0 + + -- TODO: Create an efficient symbolic version of `mkCongrArg` + let newProof ← mkCongrArg congrArgFun proof + return .step (← Lean.Expr.updateProjS! e e') newProof + +def simplifyAppFn : Simproc := fun e => do + unless e.isApp do return .rfl + let fn := e.getAppFn + unless fn.isLambda || fn.isConst do + let res ← simp fn + match res with + | .rfl _ => return res + | .step e' proof _ => + let newType ← Sym.inferType e' + let congrArgFun := Lean.mkLambda `x .default newType (mkAppN (.bvar 0) e.getAppArgs) + let newValue ← mkAppNS e' e.getAppArgs + let newProof ← mkCongrArg congrArgFun proof + return .step newValue newProof + return .rfl + +def handleConst : Simproc := fun e => do + let .const n _ := e | return .rfl + let info ← getConstInfo n + unless info.isDefinition do return .rfl + let eType ← Sym.inferType e + let eType ← whnfD eType + unless eType matches .forallE .. do + return .rfl + -- TODO: Check if we need to look if we applied all the levels correctly + let some thm ← getUnfoldTheorem n | return .rfl + Theorem.rewrite thm e + +def cbvPre : Simproc := + isBuiltinValue <|> isProofTerm <|> skipBinders + >> (tryMatcher >> simpControl) <|> (handleConst <|> simplifyAppFn <|> handleProj) + +def cbvPost : Simproc := + evalGround + >> (handleApp <|> zetaReduce) + >> foldLit + +public def cbvEntry (e : Expr) : MetaM Result := do + trace[Meta.Tactic.cbv] "Called cbv tactic to simplify {e}" + let methods := {pre := cbvPre, post := cbvPost} + let e ← Sym.unfoldReducible e + Sym.SymM.run do + let e ← Sym.shareCommon e + SimpM.run' (simp e) (methods := methods) + +end Lean.Meta.Tactic.Cbv diff --git a/src/Lean/Meta/Tactic/Cbv/TheoremsLookup.lean b/src/Lean/Meta/Tactic/Cbv/TheoremsLookup.lean new file mode 100644 index 0000000000..a4f4aba33e --- /dev/null +++ b/src/Lean/Meta/Tactic/Cbv/TheoremsLookup.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Wojciech Różowski +-/ + +module + +prelude +public import Lean.Meta.Sym.Simp.Theorems +import Lean.Meta.Match.MatchEqsExt +import Lean.Meta.Eqns + + +namespace Lean.Meta.Sym.Simp + +def Theorems.insertMany (thms : Theorems) (toInsert : Array Theorem) : Theorems := + Array.foldl Theorems.insert thms toInsert + +end Lean.Meta.Sym.Simp + +namespace Lean.Meta.Tactic.Cbv +open Lean.Meta.Sym.Simp + +public structure CbvTheoremsLookupState where + eqnTheorems : PHashMap Name Theorems := {} + unfoldTheorems : PHashMap Name Theorem := {} + matchTheorems : PHashMap Name Theorems := {} + deriving Inhabited + +builtin_initialize cbvTheoremsLookup : EnvExtension CbvTheoremsLookupState ← + registerEnvExtension (pure {}) (asyncMode := .local) + +public def getEqnTheorems (fnName : Name) : MetaM Theorems := do + let env ← getEnv + let cache := cbvTheoremsLookup.getState env + if let some thms := cache.eqnTheorems.find? fnName then + return thms + else + -- Compute theorems from equation names + let some eqnNames ← getEqnsFor? fnName | return {} + let thms := Theorems.insertMany {} <| ← eqnNames.mapM mkTheoremFromDecl + -- Store in cache + modifyEnv fun env => + cbvTheoremsLookup.modifyState env fun cache => + { cache with eqnTheorems := cache.eqnTheorems.insert fnName thms } + return thms + +public def getUnfoldTheorem (fnName : Name) : MetaM (Option Theorem) := do + let env ← getEnv + let cache := cbvTheoremsLookup.getState env + if let some thm := cache.unfoldTheorems.find? fnName then + return some thm + else + let some unfoldEqn ← getUnfoldEqnFor? fnName (nonRec := true) | return none + let thm ← mkTheoremFromDecl unfoldEqn + + modifyEnv fun env => + cbvTheoremsLookup.modifyState env fun cache => + { cache with unfoldTheorems := cache.unfoldTheorems.insert fnName thm } + return some thm + +public def getMatchTheorems (matcherName : Name) : MetaM Theorems := do + let env ← getEnv + let cache := cbvTheoremsLookup.getState env + if let some thms := cache.matchTheorems.find? matcherName then + return thms + else + let eqns ← Match.getEquationsFor matcherName + let thms := Theorems.insertMany {} <| ← eqns.eqnNames.mapM mkTheoremFromDecl + + modifyEnv fun env => + cbvTheoremsLookup.modifyState env fun cache => + { cache with matchTheorems := cache.matchTheorems.insert matcherName thms } + return thms + +end Lean.Meta.Tactic.Cbv diff --git a/src/Lean/Meta/Tactic/Cbv/Util.lean b/src/Lean/Meta/Tactic/Cbv/Util.lean new file mode 100644 index 0000000000..d732e01614 --- /dev/null +++ b/src/Lean/Meta/Tactic/Cbv/Util.lean @@ -0,0 +1,92 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Wojciech Różowski +-/ + +module + +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Sym.InferType +import Lean.Meta.Sym.AlphaShareBuilder +import Lean.Meta.Sym.LitValues + +namespace Lean.Meta.Tactic.Cbv + +open Lean.Meta.Sym.Simp + +public def mkAppNS (f : Expr) (args : Array Expr) : Sym.SymM Expr := do + args.foldlM Sym.Internal.mkAppS f + +abbrev isNatValue (e : Expr) : Bool := (Sym.getNatValue? e).isSome +abbrev isStringValue (e : Expr) : Bool := (Sym.getStringValue? e).isSome +abbrev isIntValue (e : Expr) : Bool := (Sym.getIntValue? e).isSome +abbrev isBitVecValue (e : Expr) : Bool := (Sym.getBitVecValue? e).isSome +abbrev isFinValue (e : Expr) : Bool := (Sym.getFinValue? e).isSome +abbrev isCharValue (e : Expr) : Bool := (Sym.getCharValue? e).isSome +abbrev isRatValue (e : Expr) : Bool := (Sym.getRatValue? e).isSome +abbrev isUInt8Value (e : Expr) : Bool := (Sym.getUInt8Value? e).isSome +abbrev isUInt16Value (e : Expr) : Bool := (Sym.getUInt16Value? e).isSome +abbrev isUInt32Value (e : Expr) : Bool := (Sym.getUInt32Value? e).isSome +abbrev isUInt64Value (e : Expr) : Bool := (Sym.getUInt64Value? e).isSome +abbrev isInt8Value (e : Expr) : Bool := (Sym.getInt8Value? e).isSome +abbrev isInt16Value (e : Expr) : Bool := (Sym.getInt16Value? e).isSome +abbrev isInt32Value (e : Expr) : Bool := (Sym.getInt32Value? e).isSome +abbrev isInt64Value (e : Expr) : Bool := (Sym.getInt64Value? e).isSome + +public def isVal (e : Expr) : Bool := + [ + isNatValue, + isStringValue, + isIntValue, + isBitVecValue, + isFinValue, + isCharValue, + isUInt8Value, + isUInt16Value, + isUInt32Value, + isUInt64Value, + isInt8Value, + isInt16Value, + isInt32Value, + isInt64Value + ].any (· e) + +public def isBuiltinValue : Simproc := fun e => return .rfl (isVal e) + +public def guardSimproc (p : Expr → Bool) (s : Simproc) : Simproc := fun e => do + if p e then s e else return .rfl + +/-- TODO: Handle code duplication -/ +def isAlwaysZero : Level → Bool + | .zero .. => true + | .mvar .. => false + | .param .. => false + | .succ .. => false + | .max u v => isAlwaysZero u && isAlwaysZero v + | .imax _ u => isAlwaysZero u + +/- Modified for the `SymM` usage -/ +def isProp (e : Expr) : Sym.SymM Bool := do + match (← isPropQuick e) with + | .true => return true + | .false => return false + | .undef => + let type ← Sym.inferType e + let type ← whnfD type + match type with + | Expr.sort u => return isAlwaysZero (← instantiateLevelMVars u) + | _ => return false + +/- Modified for the `SymM` usage -/ +def isProof (e : Expr) : Sym.SymM Bool := do + match (← isProofQuick e) with + | .true => return true + | .false => return false + | .undef => isProp (← Sym.inferType e) + +public def isProofTerm : Simproc := fun e => do + return .rfl (← isProof e) + +end Lean.Meta.Tactic.Cbv diff --git a/tests/lean/run/cbv1.lean b/tests/lean/run/cbv1.lean new file mode 100644 index 0000000000..6e53445e70 --- /dev/null +++ b/tests/lean/run/cbv1.lean @@ -0,0 +1,173 @@ +import Std +set_option cbv.warning false + +def function (n : Nat) : Nat := match n with + | 0 => 0 + 1 + | Nat.succ n => function n + 1 +termination_by (n,0) + +example : function 150 = 151 := by + conv => + lhs + cbv + +example : ((1,1).1,1).1 = 1 := by + conv => + lhs + cbv + + +def f : Unit -> Nat × Nat := fun _ => (1, 2) + +example : (f ()).2 = 2 := by + conv => + lhs + cbv + +def g : Unit → (Nat → Nat) × (Nat → Nat) := fun _ => (fun x => x + 1, fun x => x + 3) + +example : (g ()).1 6 = 7 := by + conv => + lhs + cbv + +example : "abx" ++ "c" = "a" ++ "bxc" := by + conv => + lhs + cbv + conv => + rhs + cbv + +example : instHAdd.1 2 2 = 4 := by + conv => + lhs + cbv + +example : (fun y : Nat → Nat => (fun x => y x)) Nat.succ 1 = 2 := by + conv => + lhs + cbv + +example : (Std.TreeMap.empty.insert "a" "b" : Std.TreeMap String String).toList = [("a", "b")] := by + conv => + lhs + cbv + +theorem array_test : (List.replicate 200 5 : List Nat).reverse = List.replicate 200 5 := by + conv => + lhs + cbv + conv => + rhs + cbv + +def testFun (l : List Nat) : Nat := Id.run do + let mut i := 0 + for _ in l do + i := i + 1 + return i + +-- Possibly a good benchmark for dealing with let expressions +example : testFun [1,2,3,4,5] = 5 := by + conv => + lhs + cbv + +example : "ab".length + "ab".length = ("ab" ++ "ab").length := by + conv => + lhs + cbv + conv => + rhs + cbv + +example : (((Std.TreeMap.empty : Std.TreeMap Nat Nat).insert 2 4).toList ++ [(5, 6)]).reverse = [(5,6), (2,4)] := by + conv => + lhs + cbv + +def h := () + +example : h = () := by + conv => + lhs + cbv + +def IsSubseq (s₁ : String) (s₂ : String) : Prop := + List.Sublist s₁.toList s₂.toList + +def vowels : List Char := ['a', 'e', 'i', 'o', 'u', 'A', 'E', 'I', 'O', 'U'] + +def removeVowels (s : String) : String := + String.ofList (s.toList.filter (· ∉ vowels)) + +example : removeVowels "abcdef" = "bcdf" := by + conv => + lhs + cbv + rfl + +example : removeVowels "abcdef\nghijklm" = "bcdf\nghjklm" := by + conv => + lhs + cbv + rfl +example : removeVowels "aaaaa" = "" := by + conv => + lhs + cbv + rfl +example : removeVowels "aaBAA" = "B" := by + conv => + lhs + cbv + rfl + +example : removeVowels "zbcd" = "zbcd" := by + conv => + lhs + cbv + rfl + +def Nat.factorial : Nat → Nat + | 0 => 1 + | .succ n => Nat.succ n * factorial n + +notation:10000 n "!" => Nat.factorial n + +def Nat.brazilianFactorial : Nat → Nat + | .zero => 1 + | .succ n => (Nat.succ n)! * brazilianFactorial n + +def special_factorial (n : Nat) : Nat := + special_factorial.go n 1 1 0 +where + go (n fact brazilFact curr : Nat) : Nat := + if _h: curr >= n + then brazilFact + else + let fact' := (curr + 1) * fact + let brazilFact' := fact' * brazilFact + special_factorial.go n fact' brazilFact' (Nat.succ curr) + termination_by n - curr + +example : Nat.brazilianFactorial 4 = 288 := by + conv => + lhs + cbv + +example : special_factorial 4 = 288 := by + conv => + lhs + cbv + +example : Nat.brazilianFactorial 5 = 34560 := by + conv => + lhs + cbv + +example : Nat.brazilianFactorial 7 = 125411328000 := by + conv => + lhs + cbv