From fe96911368a3cc5b829e02af6613842a96f9c175 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 17 Dec 2025 12:05:24 +0100 Subject: [PATCH] feat: proper recursive specialization (#11479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR enables the specializer to also recursively specialize in some non trivial higher order situations. The main motivation for this change is the upcoming changes to do notation by sgraf. In there he uses combinators such as ```lean @[specialize, expose] def List.newForIn {α β γ} (l : List α) (b : β) (kcons : α → (β → γ) → β → γ) (knil : β → γ) : γ := match l with | [] => knil b | a :: l => kcons a (l.newForIn · kcons knil) b ``` in programs such as ```lean def testing := let x := 42; List.newForIn (β := Nat) (γ := Id Nat) [1,2,3] x (fun i kcontinue s => let x := s; List.newForIn [i:10].toList x (fun j kcontinue s => let x := s; let x := x + i + j; kcontinue x) kcontinue) pure ``` inspecting this IR right before we get to the specializer in the current compiler we get: ``` [Compiler.eagerLambdaLifting] size: 22 def testing : Nat := fun _f.1 _y.2 : Nat := return _y.2; let x := 42; let _x.3 := 1; fun _f.4 i kcontinue s : Nat := fun _f.5 j kcontinue s : Nat := let _x.6 := Nat.add s i; let x := Nat.add _x.6 j; let _x.7 := kcontinue x; return _x.7; let _x.8 := 10; let _x.9 := Nat.sub _x.8 i; let _x.10 := Nat.add _x.9 _x.3; let _x.11 := 1; let _x.12 := Nat.sub _x.10 _x.11; let _x.13 := Nat.mul _x.3 _x.12; let _x.14 := Nat.add i _x.13; let _x.15 := @List.nil _; let _x.16 := List.range'TR.go _x.3 _x.12 _x.14 _x.15; let _x.17 := @List.newForIn _ _ _ _x.16 s _f.5 kcontinue; return _x.17; let _x.18 := 2; let _x.19 := 3; let _x.20 := @List.nil _; let _x.21 := @List.cons _ _x.19 _x.20; let _x.22 := @List.cons _ _x.18 _x.21; let _x.23 := @List.cons _ _x.3 _x.22; let _x.24 := @List.newForIn _ _ _ _x.23 x _f.4 _f.1; return _x.24 ``` Here the `kcontinue` higher order functions pose a special challenge because they delay the discovery of new specialization opportunities. Inspecting the IR after the current specializer (and a cleanup simp step) we get functions that look as follows: ``` [simp] size: 7 def List.newForIn._at_.testing.spec_0 i kcontinue l b : Nat := cases l : Nat | List.nil => let _x.1 := kcontinue b; return _x.1 | List.cons head.2 tail.3 => let _x.4 := Nat.add b i; let x := Nat.add _x.4 head.2; let _x.5 := List.newForIn._at_.testing.spec_0 i kcontinue tail.3 x; return _x.5 [simp] size: 14 def List.newForIn._at_.List.newForIn._at_.testing.spec_1.spec_1 _x.1 l b : Nat := cases l : Nat | List.nil => return b | List.cons head.2 tail.3 => fun _f.4 x.5 : Nat := let _x.6 := List.newForIn._at_.List.newForIn._at_.testing.spec_1.spec_1 _x.1 tail.3 x.5; return _x.6; let _x.7 := 10; let _x.8 := Nat.sub _x.7 head.2; let _x.9 := Nat.add _x.8 _x.1; let _x.10 := 1; let _x.11 := Nat.sub _x.9 _x.10; let _x.12 := Nat.mul _x.1 _x.11; let _x.13 := Nat.add head.2 _x.12; let _x.14 := @List.nil _; let _x.15 := List.range'TR.go _x.1 _x.11 _x.13 _x.14; let _x.16 := List.newForIn._at_.testing.spec_0 head.2 _f.4 _x.15 b; return _x.16 ``` Observe that the specializer decided to abstract over `kcontinue` instead of specializing further recursively. Thus this tight loop is now going through an indirect call. This PR now changes the specializer somewhat fundamentally to handle situations like this. The most notable change is going to a fixpoint loop of: 1. Specialize all current declarations in the worklist 2. If a declaration - succeeded in specializing run the simplifier on it and put it back onto the worklist - if it didn't don't put it back onto the worklist anymore 3. Put all newly generated specialisations on the worklist 4. Recompute fixed parameters for the current SCC 5. Repeat until the worklist is empty Furthermore, declarations that were already specialized: - only consider `fixedHO` parameters for specialization, in order to avoid termination issues with repeated specialization and abstraction of type class parameters under binders - recursively specialized declarations only allow specialization if at least one of their fixedHO arguments is not a parameter itself. The reason for allowing this in first generation specialization is that we refrain from specializing inside the body of a declaration marked as `@[specialize]`. Thus we need to specialize them even if their arguments don't actually contain anything of interest in order to ensure that type classes etc. are correctly cleaned up within their bodies. There is one last trade-off to consider. When specializing code generated by the new do elaborator we sometimes generate intermediate specializations that are not actually part of any call graph after we are done specializing. We could in principle detect these functions and delete them but having them in cache is potentially helpful for further specializations later. Once the new do elaborator lands we plan to test this trade-off. Closes #10924 --- src/Lean/Compiler/LCNF/ConfigOptions.lean | 11 + src/Lean/Compiler/LCNF/Passes.lean | 1 + src/Lean/Compiler/LCNF/SpecInfo.lean | 104 ++++++-- src/Lean/Compiler/LCNF/Specialize.lean | 229 +++++++++++++--- stage0/src/stdlib_flags.h | 2 + tests/lean/run/do_for_loop_compiler_test.lean | 239 +++++++++++++++++ .../do_for_loop_levenstein_compiler_test.lean | 250 ++++++++++++++++++ tests/lean/run/more_jps.lean | 26 +- .../run/specFixedHOParamModuloErased.lean | 5 +- tests/lean/run/spec_limit.lean | 11 + 10 files changed, 793 insertions(+), 85 deletions(-) create mode 100644 tests/lean/run/do_for_loop_compiler_test.lean create mode 100644 tests/lean/run/do_for_loop_levenstein_compiler_test.lean create mode 100644 tests/lean/run/spec_limit.lean diff --git a/src/Lean/Compiler/LCNF/ConfigOptions.lean b/src/Lean/Compiler/LCNF/ConfigOptions.lean index 4bcc931551..7dcf157fb4 100644 --- a/src/Lean/Compiler/LCNF/ConfigOptions.lean +++ b/src/Lean/Compiler/LCNF/ConfigOptions.lean @@ -39,6 +39,11 @@ structure ConfigOptions where Cache closed terms and evaluate them at initialization time. -/ extractClosed : Bool := true + /-- + Maximum number of times a definition tagged with `@[specialize]` can be recursively specialized + before generating an error during compilation. + -/ + maxRecSpecialize : Nat := 64 deriving Inhabited register_builtin_option compiler.small : Nat := { @@ -66,12 +71,18 @@ register_builtin_option compiler.extract_closed : Bool := { descr := "(compiler) enable/disable closed term caching" } +register_builtin_option compiler.maxRecSpecialize : Nat := { + defValue := 64 + descr := "(compiler) maximum number of times a definition tagged with `@[specialize]` can be recursively specialized before generating an error during compilation." +} + def toConfigOptions (opts : Options) : ConfigOptions := { smallThreshold := compiler.small.get opts maxRecInline := compiler.maxRecInline.get opts maxRecInlineIfReduce := compiler.maxRecInlineIfReduce.get opts checkTypes := compiler.checkTypes.get opts extractClosed := compiler.extract_closed.get opts + maxRecSpecialize := compiler.maxRecSpecialize.get opts } end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 249dca5e42..f72ef04318 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -18,6 +18,7 @@ public import Lean.Compiler.LCNF.ElimDeadBranches public import Lean.Compiler.LCNF.StructProjCases public import Lean.Compiler.LCNF.ExtractClosed public import Lean.Compiler.LCNF.Visibility +public import Lean.Compiler.LCNF.Simp public section diff --git a/src/Lean/Compiler/LCNF/SpecInfo.lean b/src/Lean/Compiler/LCNF/SpecInfo.lean index 54ed217392..aa1a6c2e5e 100644 --- a/src/Lean/Compiler/LCNF/SpecInfo.lean +++ b/src/Lean/Compiler/LCNF/SpecInfo.lean @@ -45,6 +45,15 @@ inductive SpecParamInfo where | other deriving Inhabited, Repr +namespace SpecParamInfo + +@[inline] +def causesSpecialization : SpecParamInfo → Bool + | .fixedInst | .fixedHO | .user => true + | .fixedNeutral | .other => false + +end SpecParamInfo + instance : ToMessageData SpecParamInfo where toMessageData | .fixedInst => "I" @@ -53,20 +62,35 @@ instance : ToMessageData SpecParamInfo where | .user => "U" | .other => "O" -structure SpecState where - specInfo : PHashMap Name (Array SpecParamInfo) := {} +structure SpecEntry where + /-- + The name of the declaration. + -/ + declName : Name + /-- + Information about which parameters of the declaration qualify for specialization. + -/ + paramsInfo : Array SpecParamInfo + /-- + True if `declName` was already specialized before. This is relevant because we specialize + declarations that have already been specialized less aggressively than declarations that have not. + -/ + alreadySpecialized : Bool deriving Inhabited -structure SpecEntry where - declName : Name - paramsInfo : Array SpecParamInfo +instance : ToMessageData SpecEntry where + toMessageData := fun { declName, paramsInfo, alreadySpecialized } => + m!"{declName}, alreadySpecialized? {alreadySpecialized}, info: {paramsInfo}" + +structure SpecState where + specInfo : PHashMap Name SpecEntry := {} deriving Inhabited namespace SpecState def addEntry (s : SpecState) (e : SpecEntry) : SpecState := match s with - | { specInfo } => { specInfo := specInfo.insert e.declName e.paramsInfo } + | { specInfo } => { specInfo := specInfo.insert e.declName e } end SpecState @@ -77,7 +101,7 @@ private abbrev sortEntries (entries : Array SpecEntry) : Array SpecEntry := entries.qsort declLt private abbrev findAtSorted? (entries : Array SpecEntry) (declName : Name) : Option SpecEntry := - entries.binSearch { declName, paramsInfo := #[] } declLt + entries.binSearch { declName, paramsInfo := #[], alreadySpecialized := false } declLt /-- Extension for storing `SpecParamInfo` for declarations being compiled. @@ -136,20 +160,23 @@ See comment at `.fixedNeutral`. private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do let param := decl.params[j]! for h : k in (j+1)...decl.params.size do - if paramsInfo[k]! matches .user | .fixedHO | .fixedInst then + if paramsInfo[k]!.causesSpecialization then let param' := decl.params[k] if param'.type.containsFVar param.fvarId then return true return false /-- -Save parameter information for `decls`. - -Remark: this function, similarly to `mkFixedArgMap`, -assumes that if a function `f` was declared in a mutual block, then `decls` -contains all (computationally relevant) functions in the mutual block. +Compute specialization information for `decls`. We assume that `decls` contains a full SCC of +computationally relevant declarations. Furthermore this function takes: +- `autoSpecialize` which determines whether we apply "automated" specialization to a decl, that is + whether we automatically specialize for all fixedHO parameters. It receives both the name and + the array of arguments mentioned in `@[specialize]` if any. +- `alreadySpecialized` which is a mask that says whether a decl is already a specialized declaration + itself. -/ -def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do +def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (Array Nat) → Bool) + (alreadySpecialized : Array Bool) : CompilerM (Array SpecEntry) := do let mut declsInfo := #[] for decl in decls do if hasNospecializeAttribute (← getEnv) decl.name then @@ -178,20 +205,20 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do specify which arguments must be specialized besides instances. In this case, we try to specialize any "fixed higher-order argument" -/ - else if specArgs? == some #[] && param.type matches .forallE .. then + else if autoSpecialize decl.name specArgs? && param.type matches .forallE .. then pure .fixedHO else pure .other paramsInfo := paramsInfo.push info pure () declsInfo := declsInfo.push paramsInfo - if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then + if declsInfo.any fun paramsInfo => paramsInfo.any SpecParamInfo.causesSpecialization then let m := mkFixedParamsMap decls + let mut entries := Array.emptyWithCapacity decls.size for hi : i in *...decls.size do let decl := decls[i] let mut paramsInfo := declsInfo[i]! let some mask := m.find? decl.name | unreachable! - trace[Compiler.specialize.info] "{decl.name} {mask}" paramsInfo := Array.zipWith (as := paramsInfo) (bs := mask) fun info fixed => if fixed || info matches .user then info @@ -201,24 +228,43 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do let mut info := paramsInfo[j]! if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then paramsInfo := paramsInfo.set! j .other - if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then - trace[Compiler.specialize.info] "{decl.name} {paramsInfo}" - modifyEnv fun env => specExtension.addEntry env { declName := decl.name, paramsInfo } + entries := entries.push { + declName := decl.name, + paramsInfo, + alreadySpecialized := alreadySpecialized[i]! + } + return entries + else + return decls.mapIdx fun i decl => { + declName := decl.name, + paramsInfo := Array.replicate decl.params.size .other + alreadySpecialized := alreadySpecialized[i]! + } -def getSpecParamInfoCore? (env : Environment) (declName : Name) : Option (Array SpecParamInfo) := +/-- +Compute and save specialization information for `decls`. Assumes that `decls` is an SCC of user +defined declarations. +-/ +def saveSpecEntries (decls : Array Decl) : CompilerM Unit := do + let entries ← computeSpecEntries + decls + (fun _ specArgs? => specArgs? == some #[]) + (Array.replicate decls.size false) + for entry in entries do + if entry.paramsInfo.any SpecParamInfo.causesSpecialization then + trace[Compiler.specialize.info] "{entry.declName} {entry.paramsInfo}" + modifyEnv fun env => specExtension.addEntry env entry + +def getSpecEntryCore? (env : Environment) (declName : Name) : Option SpecEntry := match env.getModuleIdxFor? declName with - | some modIdx => - if let some entry := findAtSorted? (specExtension.getModuleEntries env modIdx) declName then - some entry.paramsInfo - else - none + | some modIdx => findAtSorted? (specExtension.getModuleEntries env modIdx) declName | none => (specExtension.getState env).specInfo.find? declName -def getSpecParamInfo? [Monad m] [MonadEnv m] (declName : Name) : m (Option (Array SpecParamInfo)) := - return getSpecParamInfoCore? (← getEnv) declName +def getSpecEntry? [Monad m] [MonadEnv m] (declName : Name) : m (Option SpecEntry) := + return getSpecEntryCore? (← getEnv) declName def isSpecCandidate [Monad m] [MonadEnv m] (declName : Name) : m Bool := do - return getSpecParamInfoCore? (← getEnv) declName |>.isSome + return getSpecEntryCore? (← getEnv) declName |>.isSome builtin_initialize registerTraceClass `Compiler.specialize.info diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index 76d9739014..502040c55a 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -6,28 +6,25 @@ Authors: Leonardo de Moura module prelude -public import Lean.Compiler.LCNF.Simp public import Lean.Compiler.LCNF.SpecInfo -public import Lean.Compiler.LCNF.ToExpr -public import Lean.Compiler.LCNF.Level public import Lean.Compiler.LCNF.MonadScope -public import Lean.Compiler.LCNF.Closure public import Lean.Compiler.LCNF.FVarUtil -import all Lean.Compiler.LCNF.ToExpr - -public section +import Lean.Compiler.LCNF.Simp +import Lean.Compiler.LCNF.ToExpr +import Lean.Compiler.LCNF.Level +import Lean.Compiler.LCNF.Closure namespace Lean.Compiler.LCNF namespace Specialize -abbrev Cache := SMap Expr Name +public abbrev Cache := SMap Expr Name -structure CacheEntry where +public structure CacheEntry where key : Expr declName : Name deriving Inhabited -def addEntry (cache : Cache) (e : CacheEntry) : Cache := +public def addEntry (cache : Cache) (e : CacheEntry) : Cache := cache.insert e.key e.declName builtin_initialize specCacheExt : SimplePersistentEnvExtension CacheEntry Cache ← @@ -44,10 +41,10 @@ builtin_initialize specCacheExt : SimplePersistentEnvExtension CacheEntry Cache (!·.contains ·.key) addEntry } -def cacheSpec (key : Expr) (declName : Name) : CoreM Unit := +public def cacheSpec (key : Expr) (declName : Name) : CoreM Unit := modifyEnv fun env => specCacheExt.addEntry env { key, declName } -def findSpecCache? (key : Expr) : CoreM (Option Name) := +public def findSpecCache? (key : Expr) : CoreM (Option Name) := return specCacheExt.getState (← getEnv) |>.find? key structure Context where @@ -67,7 +64,29 @@ structure Context where declName : Name structure State where - decls : Array Decl := #[] + /-- + The set of `Decl` that we are done processing. + -/ + processedDecls : Array Decl := #[] + /-- + The set of `Decl` that we will attempt recursive specialization on in the next iteration. + -/ + workingDecls : Array Decl := #[] + /-- + Specialization information about specialized declarations generated in this SCC so far. + -/ + localSpecParamInfo : Std.HashMap Name (Array SpecParamInfo) := {} + /-- + If we specialize a declaration but leave some specializable parameters unspecialized, we store + them as a mask here. This mask is used to determine which parameters we specialize for + recursively. + -/ + parentMasks : Std.HashMap Name (Array Bool) := {} + /-- + Whether we made a change to a declaration in this specialization run so far. This is periodically + reset in the fixpoint loop and the signal for the loop to continue running. + -/ + changed : Bool := false abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM @@ -199,13 +218,47 @@ end Collector /-- Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information. -/ -def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do - for paramInfo in paramsInfo, arg in args do +def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bool := do + let hoCheck := + if specEntry.alreadySpecialized then + fun arg => do + /- + If we have `f p` where `p` is a param it makes no sense to specialize as we will just + close over `p` again and will have made no progress. + + The reason for doing this only for declarations which have have already been specialised + themselves is, that we *must* always specialize declarations that are marked with + `@[specialize]`. This is because the specializer will not specialize their bodies because it + waits for the bodies to be specialized at the call site. This is for example important in + the following situation: + ``` + @[specialize] + def test (f : ... -> ...) := + ... + HashMap.get? inst1 inst2 xs ys + ``` + Here the call to `HashMap.get?` will not be specialized unless `test` is specialized. Thus, + even when `f` is just going to be re-abstracted, it makes sense to specialize a call to `test` + that closes over parameters, in order to optimize the `HashMap` invocation. + + We thought about lifting this restriction and instead always specializing `@[specialize]` + decls twice, once at their definition site and once at their call site. However, almost all + `@[specialize]` function declarations will indeed get specialized for a non-trivial function + instead of just an argument. Hence keeping the first version around is likely a waste of + space because it will often end up going unused. + -/ + match arg with + | .erased | .type .. => return false + | .fvar fvar => return (← findParam? fvar).isNone + else + fun _ => pure true + for paramInfo in specEntry.paramsInfo, arg in args do match paramInfo with | .other => pure () | .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here - | .fixedInst | .user => if (← isGround arg) then return true - | .fixedHO => return true -- TODO: check whether this is too aggressive + | .fixedInst | .user => if ← isGround arg then return true + | .fixedHO => if ← hoCheck arg then return true + return false /-- @@ -257,7 +310,10 @@ Specialize `decl` using - `levelParamsNew`: the universe level parameters for the new declaration. -/ def mkSpecDecl (decl : Decl) (us : List Level) (argMask : Array (Option Arg)) (params : Array Param) (decls : Array CodeDecl) (levelParamsNew : List Name) : SpecializeM Decl := do - let nameNew := decl.name.appendCore `_at_ |>.appendCore (← read).declName |>.appendCore `spec |>.appendIndexAfter (← get).decls.size + let nameNew := decl.name.appendCore `_at_ + |>.appendCore (← read).declName + |>.appendCore `spec + |>.appendIndexAfter ((← get).processedDecls.size + (← get).workingDecls.size) /- Recall that we have just retrieved `decl` using `getDecl?`, and it may have free variable identifiers that overlap with the free-variables in `params` and `decls` (i.e., the "closure"). @@ -314,6 +370,24 @@ def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet := else return r +def getSpecEntry? (declName : Name) : SpecializeM (Option SpecEntry) := do + if let some paramsInfo := (← get).localSpecParamInfo[declName]? then + return some { declName, paramsInfo, alreadySpecialized := true } + else + LCNF.getSpecEntry? declName + +@[inline] +def markChanged : SpecializeM Unit := + modify fun s => { s with changed := true } + +@[inline] +def resetChanged : SpecializeM Unit := + modify fun s => { s with changed := false } + +@[inline] +def hasChanged : SpecializeM Bool := + return (← get).changed + mutual /-- Try to specialize the function application in the given let-declaration. @@ -323,11 +397,12 @@ mutual let .const declName us args := e | return none if args.isEmpty then return none if (← Meta.isInstance declName) then return none - let some paramsInfo ← getSpecParamInfo? declName | return none - unless (← shouldSpecialize paramsInfo args) do return none + let some specEntry ← getSpecEntry? declName | return none + unless (← shouldSpecialize specEntry args) do return none let some decl ← getDecl? declName | return none let .code _ := decl.value | return none - trace[Compiler.specialize.candidate] "{e.toExpr}, {paramsInfo}" + trace[Compiler.specialize.candidate] "{e.toExpr}, {specEntry}" + let paramsInfo := specEntry.paramsInfo let (argMask, params, decls) ← Collector.collect paramsInfo args let keyBody := .const declName us (argMask.filterMap id) let (key, levelParamsNew) ← mkKey params decls keyBody @@ -341,18 +416,31 @@ mutual return some (.const declName usNew argsNew) else let specDecl ← mkSpecDecl decl us argMask params decls levelParamsNew - trace[Compiler.specialize.step] "new: {specDecl.name}" + let parentMask ← argsNew.mapM + fun + | .type .. | .erased => return false + | .fvar fvar => do + if let some param ← findParam? fvar then + /- + For now we only allow recursive specialization on non class parameters, reason: + We can encounter situations where we repeatedly re-abstract over type classes + recursively and would end up in a loop because of that. + -/ + return (param.type matches .forallE ..) && !(← isArrowClass? param.type).isSome + else + return false cacheSpec key specDecl.name specDecl.saveBase let specDecl ← specDecl.etaExpand specDecl.saveBase let specDecl ← specDecl.simp {} let specDecl ← specDecl.simp { etaPoly := true, inlinePartial := true, implementedBy := true } - let ground ← paramsToGroundVars specDecl.params - let value ← withReader (fun _ => { declName := specDecl.name, ground }) do - withParams specDecl.params <| specDecl.value.mapCodeM visitCode - let specDecl := { specDecl with value } - modify fun s => { s with decls := s.decls.push specDecl } + trace[Compiler.specialize.step] "new: {specDecl.name}: {← ppDecl specDecl}" + modify fun s => { + s with + workingDecls := s.workingDecls.push specDecl, + parentMasks := s.parentMasks.insert specDecl.name parentMask + } return some (.const specDecl.name usNew argsNew) partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do @@ -364,6 +452,7 @@ mutual | .let decl k => let mut decl := decl if let some value ← specializeApp? decl.value then + markChanged decl ← decl.updateValue value let k ← withLetDecl decl <| visitCode k return code.updateLet! decl k @@ -385,26 +474,88 @@ mutual end -def main (decl : Decl) : SpecializeM Decl := do +/-- +Run specialization on the body of `decl`. +-/ +def specializeDecl (decl : Decl) : SpecializeM (Decl × Bool) := do + trace[Compiler.specialize.step] m!"Working {decl.name}" if (← decl.isTemplateLike) then - return decl + return (decl, false) else + resetChanged let value ← withParams decl.params <| decl.value.mapCodeM visitCode - return { decl with value } + let changed ← hasChanged + let mut updated := { decl with value } + if changed then + updated ← updated.simp {} + trace[Compiler.specialize.step] m!"Result {decl.name}: {← ppDecl updated}" + return (updated, changed) + +/-- +Recompute specialization information for the current SCC. +-/ +def updateLocalSpecParamInfo : SpecializeM Unit := do + let decls := (← get).processedDecls ++ (← get).workingDecls + let masks := (← get).parentMasks + let infos ← computeSpecEntries + decls + (fun declName specArgs? => specArgs? == some #[] || (masks[declName]?.getD #[] |>.any (· == true))) + (decls.map (masks.contains ·.name)) + + for entry in infos do + if let some mask := (← get).parentMasks[entry.declName]? then + let maskInfo info := + mask.zipWith info (f := fun b i => if !b && i.causesSpecialization then .other else i) + let entry := { entry with paramsInfo := maskInfo entry.paramsInfo } + modify fun s => { + s with + localSpecParamInfo := s.localSpecParamInfo.insert entry.declName entry.paramsInfo + } + + trace[Compiler.specialize.step] m!"Info for next round: {(← get).localSpecParamInfo.toList}" + +partial def loop (round : Nat := 0) : SpecializeM Unit := do + let targets ← modifyGet (fun s => (s.workingDecls, { s with workingDecls := #[] })) + let limit := (← getConfig).maxRecSpecialize + if targets.isEmpty then + trace[Compiler.specialize.step] m!"Termination after {round} rounds" + for (declName, paramsInfo) in (← get).localSpecParamInfo do + if paramsInfo.any SpecParamInfo.causesSpecialization then + trace[Compiler.specialize.info] "{declName} {paramsInfo}" + modifyEnv fun env => specExtension.addEntry env { + declName, + paramsInfo, + alreadySpecialized := true + } + return () + else if round > limit then + throwError m!"Exceeded recursive specialization limit ({limit}), consider increasing it with `set_option compiler.maxRecSpecialize {limit}`" + + trace[Compiler.specialize.step] m!"Round: {round}" + for decl in targets do + let ground ← Specialize.paramsToGroundVars decl.params + let (newDecl, changed) ← withReader (fun ctx => { ctx with ground, declName := decl.name }) do + specializeDecl decl + if changed then + modify fun s => { s with workingDecls := s.workingDecls.push newDecl } + else + modify fun s => { s with processedDecls := s.processedDecls.push newDecl } + + updateLocalSpecParamInfo + + loop (round + 1) + +def main (decls : Array Decl) : CompilerM (Array Decl) := do + saveSpecEntries decls + let (_, s) ← loop |>.run { declName := .anonymous } |>.run { workingDecls := decls } + return s.processedDecls end Specialize -partial def Decl.specialize (decl : Decl) : CompilerM (Array Decl) := do - let ground ← Specialize.paramsToGroundVars decl.params - let (decl, s) ← Specialize.main decl |>.run { declName := decl.name, ground } |>.run {} - return s.decls.push decl - -def specialize : Pass where +public def specialize : Pass where phase := .base name := `specialize - run := fun decls => do - saveSpecParamInfo decls - decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.specialize) + run := Specialize.main builtin_initialize registerTraceClass `Compiler.specialize (inherited := true) diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index 79a0e58edd..93a5422fe7 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,5 +1,7 @@ #include "util/options.h" +// please update thy + namespace lean { options get_default_options() { options opts; diff --git a/tests/lean/run/do_for_loop_compiler_test.lean b/tests/lean/run/do_for_loop_compiler_test.lean new file mode 100644 index 0000000000..1cf37ec576 --- /dev/null +++ b/tests/lean/run/do_for_loop_compiler_test.lean @@ -0,0 +1,239 @@ +import Std.Do.Triple.SpecLemmas + +@[specialize, expose] +def List.newForIn (l : List α) (b : β) (kcons : α → (β → γ) → β → γ) (knil : β → γ) : γ := + match l with + | [] => knil b + | a :: l => kcons a (l.newForIn · kcons knil) b + +/-- +trace: [Compiler.saveMono] size: 7 + def List.newForIn._at_.List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing.spec_1.spec_2.spec_2 i _x.1 tail.2 l b : Nat := + cases l : Nat + | List.nil => + let _x.3 := List.newForIn._at_.testing.spec_1 _x.1 tail.2 b; + return _x.3 + | List.cons head.4 tail.5 => + let _x.6 := Nat.add b i; + let x := Nat.add _x.6 head.4; + let _x.7 := List.newForIn._at_.List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing.spec_1.spec_2.spec_2 i _x.1 tail.2 tail.5 x; + return _x.7 +[Compiler.saveMono] size: 7 + def List.newForIn._at_.testing.spec_0 i kcontinue l b : Nat := + cases l : Nat + | List.nil => + let _x.1 := kcontinue b; + return _x.1 + | List.cons head.2 tail.3 => + let _x.4 := Nat.add b i; + let x := Nat.add _x.4 head.2; + let _x.5 := List.newForIn._at_.testing.spec_0 i kcontinue tail.3 x; + return _x.5 +[Compiler.saveMono] size: 7 + def List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing.spec_1.spec_2 _x.1 tail.2 i l b : Nat := + cases l : Nat + | List.nil => + let _x.3 := List.newForIn._at_.testing.spec_1 _x.1 tail.2 b; + return _x.3 + | List.cons head.4 tail.5 => + let _x.6 := Nat.add b i; + let x := Nat.add _x.6 head.4; + let _x.7 := List.newForIn._at_.List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing.spec_1.spec_2.spec_2 i _x.1 tail.2 tail.5 x; + return _x.7 +[Compiler.saveMono] size: 9 + def testing : Nat := + let x := 42; + let _x.1 := 1; + let _x.2 := 2; + let _x.3 := 3; + let _x.4 := [] ◾; + let _x.5 := List.cons ◾ _x.3 _x.4; + let _x.6 := List.cons ◾ _x.2 _x.5; + let _x.7 := List.cons ◾ _x.1 _x.6; + let _x.8 := List.newForIn._at_.testing.spec_1 _x.1 _x.7 x; + return _x.8 +[Compiler.saveMono] size: 12 + def List.newForIn._at_.testing.spec_1 _x.1 l b : Nat := + cases l : Nat + | List.nil => + return b + | List.cons head.2 tail.3 => + let _x.4 := 10; + let _x.5 := Nat.sub _x.4 head.2; + let _x.6 := Nat.add _x.5 _x.1; + let _x.7 := 1; + let _x.8 := Nat.sub _x.6 _x.7; + let _x.9 := Nat.add head.2 _x.8; + let _x.10 := [] ◾; + let _x.11 := List.range'TR.go _x.1 _x.8 _x.9 _x.10; + let _x.12 := List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing.spec_1.spec_2 _x.1 tail.3 head.2 _x.11 b; + return _x.12 +-/ +#guard_msgs in +set_option trace.Compiler.saveMono true in +def testing := + let x := 42; + List.newForIn (β := Nat) (γ := Id Nat) + [1,2,3] + x + (fun i kcontinue s => + let x := s; + List.newForIn + [i:10].toList x + (fun j kcontinue s => + let x := s; + let x := x + i + j; + kcontinue x) + kcontinue) + pure + + +/-- +trace: [Compiler.saveMono] size: 7 + def List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing2.spec_0.spec_1 _x.1 tail.2 i l b : Nat := + cases l : Nat + | List.nil => + let _x.3 := List.newForIn._at_.testing2.spec_0 _x.1 tail.2 b; + return _x.3 + | List.cons head.4 tail.5 => + let _x.6 := Nat.add b i; + let x := Nat.add _x.6 head.4; + let _x.7 := List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing2.spec_0.spec_1 _x.1 tail.2 i tail.5 x; + return _x.7 +[Compiler.saveMono] size: 9 + def testing2 : Nat := + let x := 42; + let _x.1 := 1; + let _x.2 := 2; + let _x.3 := 3; + let _x.4 := [] ◾; + let _x.5 := List.cons ◾ _x.3 _x.4; + let _x.6 := List.cons ◾ _x.2 _x.5; + let _x.7 := List.cons ◾ _x.1 _x.6; + let _x.8 := List.newForIn._at_.testing2.spec_0 _x.1 _x.7 x; + return _x.8 +[Compiler.saveMono] size: 14 + def List.newForIn._at_.testing2.spec_0 _x.1 l b : Nat := + cases l : Nat + | List.nil => + return b + | List.cons head.2 tail.3 => + let _x.4 := 37; + let x := Nat.add b _x.4; + let _x.5 := 10; + let _x.6 := Nat.sub _x.5 head.2; + let _x.7 := Nat.add _x.6 _x.1; + let _x.8 := 1; + let _x.9 := Nat.sub _x.7 _x.8; + let _x.10 := Nat.add head.2 _x.9; + let _x.11 := [] ◾; + let _x.12 := List.range'TR.go _x.1 _x.9 _x.10 _x.11; + let _x.13 := List.newForIn._at_.testing.spec_0._at_.List.newForIn._at_.testing2.spec_0.spec_1 _x.1 tail.3 head.2 _x.12 x; + return _x.13 +-/ +#guard_msgs in +set_option trace.Compiler.saveMono true in +def testing2 := + let x := 42; + List.newForIn (β := Nat) (γ := Id Nat) + [1,2,3] + x + (fun i kcontinue s => + -- difference to testing1 here + let x := s + 37; + List.newForIn + [i:10].toList x + (fun j kcontinue s => + let x := s; + let x := x + i + j; + kcontinue x) + kcontinue) + pure + +/-- +trace: [Compiler.saveMono] size: 9 + def List.newForIn._at_.List.newForIn._at_.testing3.spec_0._at_.List.newForIn._at_.testing3.spec_1.spec_2.spec_2 s i _x.1 tail.2 l b : Nat := + cases l : Nat + | List.nil => + let _x.3 := List.newForIn._at_.testing3.spec_1 _x.1 tail.2 b; + return _x.3 + | List.cons head.4 tail.5 => + let _x.6 := Nat.add b b; + let x := Nat.add _x.6 s; + let _x.7 := Nat.add x i; + let x := Nat.add _x.7 head.4; + let _x.8 := List.newForIn._at_.List.newForIn._at_.testing3.spec_0._at_.List.newForIn._at_.testing3.spec_1.spec_2.spec_2 s i _x.1 tail.2 tail.5 x; + return _x.8 +[Compiler.saveMono] size: 9 + def List.newForIn._at_.testing3.spec_0 s i kcontinue l b : Nat := + cases l : Nat + | List.nil => + let _x.1 := kcontinue b; + return _x.1 + | List.cons head.2 tail.3 => + let _x.4 := Nat.add b b; + let x := Nat.add _x.4 s; + let _x.5 := Nat.add x i; + let x := Nat.add _x.5 head.2; + let _x.6 := List.newForIn._at_.testing3.spec_0 s i kcontinue tail.3 x; + return _x.6 +[Compiler.saveMono] size: 9 + def List.newForIn._at_.testing3.spec_0._at_.List.newForIn._at_.testing3.spec_1.spec_2 _x.1 tail.2 s i l b : Nat := + cases l : Nat + | List.nil => + let _x.3 := List.newForIn._at_.testing3.spec_1 _x.1 tail.2 b; + return _x.3 + | List.cons head.4 tail.5 => + let _x.6 := Nat.add b b; + let x := Nat.add _x.6 s; + let _x.7 := Nat.add x i; + let x := Nat.add _x.7 head.4; + let _x.8 := List.newForIn._at_.List.newForIn._at_.testing3.spec_0._at_.List.newForIn._at_.testing3.spec_1.spec_2.spec_2 s i _x.1 tail.2 tail.5 x; + return _x.8 +[Compiler.saveMono] size: 9 + def testing3 : Nat := + let x := 42; + let _x.1 := 1; + let _x.2 := 2; + let _x.3 := 3; + let _x.4 := [] ◾; + let _x.5 := List.cons ◾ _x.3 _x.4; + let _x.6 := List.cons ◾ _x.2 _x.5; + let _x.7 := List.cons ◾ _x.1 _x.6; + let _x.8 := List.newForIn._at_.testing3.spec_1 _x.1 _x.7 x; + return _x.8 +[Compiler.saveMono] size: 12 + def List.newForIn._at_.testing3.spec_1 _x.1 l b : Nat := + cases l : Nat + | List.nil => + return b + | List.cons head.2 tail.3 => + let _x.4 := 10; + let _x.5 := Nat.sub _x.4 head.2; + let _x.6 := Nat.add _x.5 _x.1; + let _x.7 := 1; + let _x.8 := Nat.sub _x.6 _x.7; + let _x.9 := Nat.add head.2 _x.8; + let _x.10 := [] ◾; + let _x.11 := List.range'TR.go _x.1 _x.8 _x.9 _x.10; + let _x.12 := List.newForIn._at_.testing3.spec_0._at_.List.newForIn._at_.testing3.spec_1.spec_2 _x.1 tail.3 b head.2 _x.11 b; + return _x.12 +-/ +#guard_msgs in +set_option trace.Compiler.saveMono true in +def testing3 := + let x := 42; + List.newForIn (β := Nat) (γ := Id Nat) + [1,2,3] + x + (fun i kcontinue s => + let x := s; + List.newForIn + [i:10].toList x + (fun j kcontinue s => + -- difference to testing1 here + let x := s + s + x; + let x := x + i + j; + kcontinue x) + kcontinue) + pure diff --git a/tests/lean/run/do_for_loop_levenstein_compiler_test.lean b/tests/lean/run/do_for_loop_levenstein_compiler_test.lean new file mode 100644 index 0000000000..c1995fa8f0 --- /dev/null +++ b/tests/lean/run/do_for_loop_levenstein_compiler_test.lean @@ -0,0 +1,250 @@ +@[inline] unsafe def Array.forInNew'Unsafe {α : Type u} {σ β : Type v} {m : Type v → Type w} + (as : Array α) (s : σ) (kcons : (a : α) → (h : a ∈ as) → (σ → m β) → σ → m β) (knil : σ → m β) : m β := + let sz := as.usize + let rec @[specialize] loop (i : USize) (s : σ) : m β := + if i < sz then + let a := as.uget i lcProof + kcons a lcProof (loop (i+1)) s + else + knil s + loop 0 s + +@[inline] protected def Std.Range.forInNew' {m : Type u → Type v} {σ β} (range : Range) (init : σ) + (kcons : (i : Nat) → i ∈ range → (σ → m β) → σ → m β) (knil : σ → m β) : m β := + have := range.step_pos + let rec @[specialize] loop (i : Nat) + (hs : (i - range.start) % range.step = 0) (hl : range.start ≤ i := by omega) : σ → m β := + if h : i < range.stop then + kcons i ⟨hl, by omega, hs⟩ (loop (i + range.step) (by rwa [Nat.add_comm, Nat.add_sub_assoc hl, Nat.add_mod_left])) + else + knil + loop range.start (by simp) (by simp) init + +/-- +trace: [Compiler.saveMono] size: 1 + def Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4 s' _x.1 _x.2 as sz _x.3 range this i hs hl a.4 : Array + String := + let _x.5 := Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.1 _x.2 as sz _x.3 range i a.4; + return _x.5 +[Compiler.saveMono] size: 1 + def Std.Range.forInNew'.loop._at_.deletions.spec_1 s' _x.1 _x.2 kcontinue range this i hs hl a.3 : Array String := + let _x.4 := Std.Range.forInNew'.loop._at_.deletions.spec_1._redArg s' _x.1 _x.2 kcontinue range i a.3; + return _x.4 +[Compiler.saveMono] size: 1 + def Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4 as sz _x.1 s' _x.2 _x.3 range this i hs hl a.4 : Array + String := + let _x.5 := Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4._redArg as sz _x.1 s' _x.2 _x.3 range i a.4; + return _x.5 +[Compiler.saveMono] size: 12 + def Array.contains._at_.deletions.spec_0 as a : Bool := + let _x.1 := 0; + let _x.2 := Array.size ◾ as; + let _x.3 := Nat.decLt _x.1 _x.2; + cases _x.3 : Bool + | Bool.false => + return _x.3 + | Bool.true => + cases _x.3 : Bool + | Bool.false => + return _x.3 + | Bool.true => + let _x.4 := 0; + let _x.5 := USize.ofNat _x.2; + let _x.6 := Array.anyMUnsafe.any._at_.Array.contains._at_.deletions.spec_0.spec_0.2 a as _x.4 _x.5; + return _x.6 +[Compiler.saveMono] size: 13 + def _private.Init.Data.Array.Basic.0.Array.anyMUnsafe.any._at_.Array.contains._at_.deletions.spec_0.spec_0 a as i stop : Bool := + let _x.1 := USize.decEq i stop; + cases _x.1 : Bool + | Bool.false => + let _x.2 := Array.uget ◾ as i ◾; + let _x.3 := String.decEq a _x.2; + cases _x.3 : Bool + | Bool.false => + let _x.4 := 1; + let _x.5 := USize.add i _x.4; + let _x.6 := Array.anyMUnsafe.any._at_.Array.contains._at_.deletions.spec_0.spec_0.2 a as _x.5 stop; + return _x.6 + | Bool.true => + return _x.3 + | Bool.true => + let _x.7 := false; + return _x.7 +[Compiler.saveMono] size: 15 + def deletions n s : Array String := + let zero := 0; + let isZero := Nat.decEq n zero; + cases isZero : Array String + | Bool.true => + let _x.1 := 1; + let _x.2 := Array.mkEmpty ◾ _x.1; + let _x.3 := Array.push ◾ _x.2 s; + return _x.3 + | Bool.false => + let one := 1; + let n.4 := Nat.sub n one; + let out := Array.mkEmpty ◾ zero; + let _x.5 := deletions n.4 s; + let sz := Array.usize ◾ _x.5; + let _x.6 := 0; + let _x.7 := Array.forInNew'Unsafe.loop._at_.deletions.spec_2 _x.5 sz _x.6 out; + return _x.7 +[Compiler.saveMono] size: 19 + def Array.forInNew'Unsafe.loop._at_.deletions.spec_2 as sz i s : Array String := + let _x.1 := USize.decLt i sz; + cases _x.1 : Array String + | Bool.false => + let _x.2 := Array.reverse._redArg s; + return _x.2 + | Bool.true => + let a := Array.uget ◾ as i ◾; + let _x.3 := String.utf8ByteSize a; + let _x.4 := 0; + let _x.5 := Nat.decEq _x.3 _x.4; + cases _x.5 : Array String + | Bool.false => + let _x.6 := 1; + let _x.7 := USize.add i _x.6; + let _x.8 := String.length a; + let _x.9 := 1; + let _x.10 := Std.Range.mk _x.4 _x.8 _x.9 ◾; + let _x.11 := Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4._redArg as sz _x.7 a _x.9 _x.5 _x.10 _x.4 s; + return _x.11 + | Bool.true => + let _x.12 := Array.reverse._redArg s; + return _x.12 +[Compiler.saveMono] size: 29 + def Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.1 _x.2 as sz _x.3 range i a.4 : Array + String := + cases range : Array String + | Std.Range.mk start stop step step_pos => + let _x.5 := Nat.decLt i stop; + cases _x.5 : Array String + | Bool.false => + let _x.6 := Array.forInNew'Unsafe.loop._at_.deletions.spec_2 as sz _x.3 a.4; + return _x.6 + | Bool.true => + let _x.7 := Nat.add i step; + let _x.8 := 0; + let _x.9 := String.utf8ByteSize s'; + let _x.10 := String.Slice.mk s' _x.8 _x.9 ◾; + let _x.11 := @String.Slice.Pos.nextn _x.10 _x.8 i; + let _x.12 := @String.extract s' _x.8 _x.11; + let _x.13 := Nat.add i _x.1; + let _x.14 := @String.Slice.Pos.nextn _x.10 _x.8 _x.13; + let _x.15 := @String.extract s' _x.14 _x.9; + let d := String.append _x.12 _x.15; + jp _jp.16 : Array String := + let out := Array.push ◾ a.4 d; + let _x.17 := Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.1 _x.2 as sz _x.3 range _x.7 out; + return _x.17; + let _x.18 := Array.contains._at_.deletions.spec_0 a.4 d; + cases _x.18 : Array String + | Bool.false => + goto _jp.16 + | Bool.true => + cases _x.2 : Array String + | Bool.false => + let _x.19 := Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.1 _x.2 as sz _x.3 range _x.7 a.4; + return _x.19 + | Bool.true => + goto _jp.16 +[Compiler.saveMono] size: 29 + def Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4._redArg as sz _x.1 s' _x.2 _x.3 range i a.4 : Array + String := + cases range : Array String + | Std.Range.mk start stop step step_pos => + let _x.5 := Nat.decLt i stop; + cases _x.5 : Array String + | Bool.false => + let _x.6 := Array.forInNew'Unsafe.loop._at_.deletions.spec_2 as sz _x.1 a.4; + return _x.6 + | Bool.true => + let _x.7 := Nat.add i step; + let _x.8 := 0; + let _x.9 := String.utf8ByteSize s'; + let _x.10 := String.Slice.mk s' _x.8 _x.9 ◾; + let _x.11 := @String.Slice.Pos.nextn _x.10 _x.8 i; + let _x.12 := @String.extract s' _x.8 _x.11; + let _x.13 := Nat.add i _x.2; + let _x.14 := @String.Slice.Pos.nextn _x.10 _x.8 _x.13; + let _x.15 := @String.extract s' _x.14 _x.9; + let d := String.append _x.12 _x.15; + jp _jp.16 : Array String := + let out := Array.push ◾ a.4 d; + let _x.17 := Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.2 _x.3 as sz _x.1 range _x.7 out; + return _x.17; + let _x.18 := Array.contains._at_.deletions.spec_0 a.4 d; + cases _x.18 : Array String + | Bool.false => + goto _jp.16 + | Bool.true => + cases _x.3 : Array String + | Bool.false => + let _x.19 := Std.Range.forInNew'.loop._at_.Std.Range.forInNew'.loop._at_.deletions.spec_1._at_.Array.forInNew'Unsafe.loop._at_.deletions.spec_2.spec_4.spec_4._redArg s' _x.2 _x.3 as sz _x.1 range _x.7 a.4; + return _x.19 + | Bool.true => + goto _jp.16 +[Compiler.saveMono] size: 29 + def Std.Range.forInNew'.loop._at_.deletions.spec_1._redArg s' _x.1 _x.2 kcontinue range i a.3 : Array String := + cases range : Array String + | Std.Range.mk start stop step step_pos => + let _x.4 := Nat.decLt i stop; + cases _x.4 : Array String + | Bool.false => + let _x.5 := kcontinue a.3; + return _x.5 + | Bool.true => + let _x.6 := Nat.add i step; + let _x.7 := 0; + let _x.8 := String.utf8ByteSize s'; + let _x.9 := String.Slice.mk s' _x.7 _x.8 ◾; + let _x.10 := @String.Slice.Pos.nextn _x.9 _x.7 i; + let _x.11 := @String.extract s' _x.7 _x.10; + let _x.12 := Nat.add i _x.1; + let _x.13 := @String.Slice.Pos.nextn _x.9 _x.7 _x.12; + let _x.14 := @String.extract s' _x.13 _x.8; + let d := String.append _x.11 _x.14; + jp _jp.15 : Array String := + let out := Array.push ◾ a.3 d; + let _x.16 := Std.Range.forInNew'.loop._at_.deletions.spec_1._redArg s' _x.1 _x.2 kcontinue range _x.6 out; + return _x.16; + let _x.17 := Array.contains._at_.deletions.spec_0 a.3 d; + cases _x.17 : Array String + | Bool.false => + goto _jp.15 + | Bool.true => + cases _x.2 : Array String + | Bool.false => + let _x.18 := Std.Range.forInNew'.loop._at_.deletions.spec_1._redArg s' _x.1 _x.2 kcontinue range _x.6 a.3; + return _x.18 + | Bool.true => + goto _jp.15 +-/ +#guard_msgs in +set_option trace.Compiler.saveMono true in +unsafe def deletions (n : Nat) (s : String) : Array String := + match n with + | 0 => #[s] + | n' + 1 => Id.run do + let out := #[]; + have kbreak := fun (s : Array String) => + let out := s; + pure out.reverse; + (deletions n' s).forInNew'Unsafe out + (fun s' _ kcontinue s => + let out := s; + if s'.isEmpty = true then kbreak out + else + [:s'.length].forInNew' out + (fun i _ kcontinue s => + let out := s; + let d := (s'.take i).copy ++ s'.drop (i + 1); + if (!out.contains d) = true then + let out := out.push d; + kcontinue out + else kcontinue out) + fun s => + let out := s; + kcontinue out) + kbreak diff --git a/tests/lean/run/more_jps.lean b/tests/lean/run/more_jps.lean index 39705c5855..6b6985c0bf 100644 --- a/tests/lean/run/more_jps.lean +++ b/tests/lean/run/more_jps.lean @@ -16,7 +16,19 @@ def List.forBreak_ {α : Type u} {m : Type w → Type x} [Monad m] (xs : List α s /-- -trace: [Compiler.saveBase] size: 25 +trace: [Compiler.saveBase] size: 9 + def _example : Nat := + let x := 42; + let _x.1 := 1; + let _x.2 := 2; + let _x.3 := 3; + let _x.4 := @List.nil _; + let _x.5 := @List.cons _ _x.3 _x.4; + let _x.6 := @List.cons _ _x.2 _x.5; + let _x.7 := @List.cons _ _x.1 _x.6; + let _x.8 := List.foldrNonTR._at_._example.spec_0 _x.7 x; + return _x.8 +[Compiler.saveBase] size: 25 def List.foldrNonTR._at_._example.spec_0 x.1 _y.2 : Nat := jp _jp.3 x : Nat := let _x.4 := 13; @@ -47,18 +59,6 @@ trace: [Compiler.saveBase] size: 25 goto _jp.3 _y.2 | Decidable.isTrue x.16 => return _y.2 -[Compiler.saveBase] size: 9 - def _example : Nat := - let x := 42; - let _x.1 := 1; - let _x.2 := 2; - let _x.3 := 3; - let _x.4 := @List.nil _; - let _x.5 := @List.cons _ _x.3 _x.4; - let _x.6 := @List.cons _ _x.2 _x.5; - let _x.7 := @List.cons _ _x.1 _x.6; - let _x.8 := List.foldrNonTR._at_._example.spec_0 _x.7 x; - return _x.8 -/ #guard_msgs in set_option trace.Compiler.saveBase true in diff --git a/tests/lean/run/specFixedHOParamModuloErased.lean b/tests/lean/run/specFixedHOParamModuloErased.lean index dc902da5c0..2f74b98f2a 100644 --- a/tests/lean/run/specFixedHOParamModuloErased.lean +++ b/tests/lean/run/specFixedHOParamModuloErased.lean @@ -1,7 +1,4 @@ -/-- -trace: [Compiler.specialize.info] pmap [true, true, false, true] -[Compiler.specialize.info] pmap [N, N, O, H] --/ +/-- trace: [Compiler.specialize.info] pmap [N, N, O, H] -/ #guard_msgs in set_option trace.Compiler.specialize.info true in @[specialize] diff --git a/tests/lean/run/spec_limit.lean b/tests/lean/run/spec_limit.lean new file mode 100644 index 0000000000..20990bb123 --- /dev/null +++ b/tests/lean/run/spec_limit.lean @@ -0,0 +1,11 @@ +/-! This test asserts that the compiler respects compiler.maxRecSpecialize -/ + +@[specialize, noinline] +def aux2 (f : Nat → Nat) := f 12 + +/-- +error: Exceeded recursive specialization limit (0), consider increasing it with `set_option compiler.maxRecSpecialize 0` +-/ +#guard_msgs in +set_option compiler.maxRecSpecialize 0 in +def test := aux2 Nat.succ