diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 874c8de9c3..c249f0507b 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -93,7 +93,7 @@ where match type with | .forallE _ d b _ => let d := d.instantiateRev xs - let p ← mkAuxParam d + let p ← mkAuxParam d (isMarkedBorrowed d) go b (xs.push (.fvar p.fvarId)) (ps.push p) | _ => let type := type.instantiateRev xs diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index e062ea18ef..85a68efef3 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -41,6 +41,11 @@ For performance we: particular when `f` is partially applied we ensure that all arguments are owned. - When passing a parameter into a constructor we ensure it is passed as owned so we do not have to `inc` before calling the constructor. + +Furthermore, the performance related heuristics will be ignored if there is a user-defined +borrow annotation. This allows the user to override the ABI of their function in exchange for +potentially worse code in the function that is being analyzed. Doing so can benefit other functions +that call the current one and thus reduce overall RC pressure. -/ namespace Lean.Compiler.LCNF @@ -56,11 +61,25 @@ inductive Key where end ParamMap -abbrev ParamMap := Std.HashMap ParamMap.Key (Array (Param .impure)) +structure ParamMap where + map : Std.HashMap ParamMap.Key (Array (Param .impure)) := {} + /-- + The set of fvars that were already annotated as borrowed before arriving at this pass. We try to + preserve the annotations here if possible. + -/ + annoatedBorrows : Std.HashSet FVarId := {} -/-- Mark parameters that take a reference as borrow -/ -def initBorrow (ps : Array (Param .impure)) : Array (Param .impure) := - ps.map fun p => { p with borrow := p.type.isPossibleRef } +namespace ParamMap + +@[inline] +def insert (pm : ParamMap) (k : Key) (ps : Array (Param .impure)) : ParamMap := + { pm with map := pm.map.insert k ps } + +@[inline] +def erase (pm : ParamMap) (k : Key) : ParamMap := + { pm with map := pm.map.erase k } + +end ParamMap abbrev InitM := StateRefT ParamMap CompilerM @@ -73,7 +92,12 @@ where match decl.value with | .code code => let exported := isExport (← getEnv) decl.name - modify fun m => m.insert (.decl decl.name) (initParamsIfNotExported exported decl.params) + modify fun m => + { m with + map := m.map.insert (.decl decl.name) (initParamsIfNotExported exported decl.params), + annoatedBorrows := decl.params.foldl (init := m.annoatedBorrows) fun acc p => + if p.borrow then acc.insert p.fvarId else acc + } goCode decl.name code | .extern .. => return () @@ -89,7 +113,12 @@ where goCode (declName : Name) (code : Code .impure) : InitM Unit := do match code with | .jp decl k => - modify fun m => m.insert (.jp declName decl.fvarId) (initParams decl.params) + modify fun m => + { m with + map := m.map.insert (.jp declName decl.fvarId) (initParams decl.params), + annoatedBorrows := decl.params.foldl (init := m.annoatedBorrows) fun acc p => + if p.borrow then acc.insert p.fvarId else acc + } goCode declName decl.value goCode declName k | .cases cs => cs.alts.forM (·.forCodeM (goCode declName)) @@ -105,7 +134,7 @@ partial def apply (decls : Array (Decl .impure)) (map : ParamMap) : CompilerM (A match decl.value with | .code code => let code ← go decl.name code - let newParams ← updateParams decl.params map[ParamMap.Key.decl decl.name]! + let newParams ← updateParams decl.params map.map[ParamMap.Key.decl decl.name]! return { decl with value := .code code, params := newParams } | _ => return decl where @@ -119,7 +148,7 @@ where go (declName : Name) (code : Code .impure) : CompilerM (Code .impure) := do match code with | .jp decl k => - let ps ← updateParams decl.params map[ParamMap.Key.jp declName decl.fvarId]! + let ps ← updateParams decl.params map.map[ParamMap.Key.jp declName decl.fvarId]! let decl ← decl.update decl.type ps (← go declName decl.value) return code.updateFun! decl (← go declName k) | .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName)) @@ -166,8 +195,10 @@ inductive OwnReason where | constructorResult (resultFVar : FVarId) /-- Parameter packed into a constructor. -/ | constructorArg (resultFVar : FVarId) - /-- Bidirectional ownership propagation through `oproj`. -/ - | projectionPropagation (resultFVar : FVarId) + /-- Forward ownership propagation through `oproj`. -/ + | forwardProjectionProp (resultFVar : FVarId) + /-- Backward ownership propagation through `oproj`. -/ + | backwardProjectionProp (resultFVar : FVarId) /-- Result of a function application. -/ | functionCallResult (resultFVar : FVarId) /-- Argument to a function whose corresponding parameter is owned. -/ @@ -189,7 +220,8 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do | .resetReuse resultFVar => return s!"used in reset reuse {← PP.ppFVar resultFVar}" | .constructorResult resultFVar => return s!"result of ctor call {← PP.ppFVar resultFVar}" | .constructorArg resultFVar => return s!"argument to constructor call {← PP.ppFVar resultFVar}" - | .projectionPropagation resultFVar => return s!"projection propagation {← PP.ppFVar resultFVar}" + | .forwardProjectionProp resultFVar => return s!"fwd projection propagation {← PP.ppFVar resultFVar}" + | .backwardProjectionProp resultFVar => return s!"bkwd projection propagation {← PP.ppFVar resultFVar}" | .functionCallResult resultFVar => return s!"result of function call {← PP.ppFVar resultFVar}" | .functionCallArg resultFVar => return s!"owned function argument {← PP.ppFVar resultFVar}" | .fvarCall resultFVar => return s!"argument to closure call {← PP.ppFVar resultFVar}" @@ -198,6 +230,29 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do | .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}" | .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}" +/-- +Determine whether an `OwnReason` is necessary for correctness (forced) or just an optimization +(not-forced). If we attempt to own a variable that has been previously annotated as borrow for a +non-forced reason we ignore it. +-/ +def OwnReason.isForced (reason : OwnReason) : Bool := + match reason with + -- All of these reasons propagate through ABI decisions and can thus safely be ignored as they + -- will be accounted for by the reference counting pass. + | .constructorArg .. | .functionCallArg .. | .fvarCall .. | .partialApplication .. + | .jpArgPropagation .. + -- If a projection of a value is used in an owned fashion that does not necessarily mean we have + -- to make the value itself owned. Note that this will however prevent potential reset-reuse + -- opportunities. + | .backwardProjectionProp .. => false + -- Results of functions and constructors are naturally owned. + | .constructorResult .. | .functionCallResult .. + -- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for + -- correctness reasons. + | .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation .. + -- If a value is owned and we project from it its projectee is always owned as well. + | .forwardProjectionProp .. => true + /-- Infer the borrowing annotations in a SCC through dataflow analysis. -/ @@ -218,8 +273,11 @@ where ownFVar (fvarId : FVarId) (reason : OwnReason) : InferM Unit := do unless (← get).owned.contains fvarId do - trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" - modify fun s => { s with owned := s.owned.insert fvarId, modified := true } + if !reason.isForced && (← get).paramMap.annoatedBorrows.contains fvarId then + trace[Compiler.inferBorrow] "user annotation blocked owning {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" + else + trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" + modify fun s => { s with owned := s.owned.insert fvarId, modified := true } ownArg (reason : OwnReason) (a : Arg .impure) : InferM Unit := do a.forFVarM (ownFVar · reason) @@ -240,7 +298,7 @@ where /-- Updates `map[k]` using the current set of `owned` variables. -/ updateParamMap (k : ParamMap.Key) : InferM Unit := do - if let some ps := (← get).paramMap[k]? then + if let some ps := (← get).paramMap.map[k]? then -- This is to ensure linearity over ps in the following code, if you know how to make this -- linear in a nice fashion please make a PR modify fun s => { s with paramMap := s.paramMap.erase k } @@ -255,7 +313,7 @@ where modify fun s => { s with paramMap := s.paramMap.insert k ps } getParamInfo (k : ParamMap.Key) : InferM (Array (Param .impure)) := do - match (← get).paramMap[k]? with + match (← get).paramMap.map[k]? with | some ps => return ps | none => let .decl fn := k | unreachable! @@ -306,8 +364,8 @@ where | .reuse x _ _ args => ownFVar z (.resetReuse z); ownFVar x (.resetReuse z); ownArgsIfParam z args | .ctor _ args => ownFVar z (.constructorResult z); ownArgsIfParam z args | .oproj _ x _ => - if ← isOwned x then ownFVar z (.projectionPropagation z) - if ← isOwned z then ownFVar x (.projectionPropagation z) + if ← isOwned x then ownFVar z (.forwardProjectionProp z) + if ← isOwned z then ownFVar x (.backwardProjectionProp z) | .fap f args => let ps ← getParamInfo (.decl f) ownFVar z (.functionCallResult z) diff --git a/src/Lean/Compiler/LCNF/ToDecl.lean b/src/Lean/Compiler/LCNF/ToDecl.lean index a1aa84e88b..02b57145b5 100644 --- a/src/Lean/Compiler/LCNF/ToDecl.lean +++ b/src/Lean/Compiler/LCNF/ToDecl.lean @@ -159,7 +159,7 @@ def toDecl (declName : Name) : CompilerM (Decl .pure) := do /- Recall that `inlineMatchers` may have exposed `ite`s and `dite`s which are tagged as `[macro_inline]`. -/ let value ← macroInline value return (type, value) - let code ← toLCNF value + let code ← toLCNF value type let mut decl ← if let .fun decl (.return _) := code then eraseFunDecl decl (recursive := false) pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure } diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 31ddaf3f30..fb7d10e253 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -206,12 +206,18 @@ structure Context where eventually. -/ ignoreNoncomputable : Bool := false + /-- + The expected type of the expression that is currently being handled if available. This type is + only used to propagate potential borrow annotations as they are not propagated everywhere by the + elaborator. + -/ + expectedType : Option Expr structure State where /-- Local context containing the original Lean types (not LCNF ones). -/ lctx : LocalContext := {} /-- Cache from Lean regular expression to LCNF argument. -/ - cache : PHashMap Expr (Arg .pure) := {} + cache : PHashMap (Expr × Option Expr) (Arg .pure) := {} /-- Determines whether caching has been disabled due to finding a use of a constant marked with `never_extract`. @@ -265,8 +271,18 @@ def toCode (result : Arg .pure) : M (Code .pure) := do let fvarId ← mkAuxLetDecl .erased seqToCode (← get).seq (.return fvarId) -def run (x : M α) : CompilerM α := - x.run {} |>.run' {} +def run (expectedType : Expr) (x : M α) : CompilerM α := + x.run { expectedType } |>.run' {} + +@[inline] +def withExpectedType (e : Option Expr) (x : M α) : M α := + withReader (fun ctx => { ctx with expectedType := e }) do + x + +@[inline] +def withoutExpectedType (x : M α) : M α := + withExpectedType none do + x /-- Return true iff `type` is `Sort _` or `As → Sort _`. @@ -340,9 +356,9 @@ def cleanupBinderName (binderName : Name) : CompilerM Name := return binderName /-- Create a new local declaration using a Lean regular type. -/ -def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do +def mkParam (binderName : Name) (type : Expr) (borrow : Bool := isMarkedBorrowed type) : + M (Param .pure) := do let binderName ← cleanupBinderName binderName - let borrow := isMarkedBorrowed type let type' ← toLCNFType type let param ← LCNF.mkParam binderName type' borrow modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default } @@ -361,16 +377,22 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a } return letDecl -def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) := - go e #[] #[] +def visitLambda (e : Expr) : M (Array (Param .pure) × Expr × Option Expr) := do + go e #[] #[] (← read).expectedType where - go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do + go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) (eType? : Option Expr) := do if let .lam binderName type body _ := e then let type := type.instantiateRev xs - let p ← mkParam binderName type - go body (xs.push p.toExpr) (ps.push p) + if let some (.forallE _ type' eType _) := eType? then + let borrow := isMarkedBorrowed type || isMarkedBorrowed type' + let p ← mkParam binderName type borrow + -- no need to instantiate eType, we only ever check it for `isMarkedBorrowed` + go body (xs.push p.toExpr) (ps.push p) (some eType) + else + let p ← mkParam binderName type + go body (xs.push p.toExpr) (ps.push p) none else - return (ps, e.instantiateRev xs) + return (ps, e.instantiateRev xs, eType?.map (·.instantiateRev xs)) def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) := go e n #[] #[] @@ -446,11 +468,12 @@ Put the given expression in `LCNF`. - Eta-expand applications of declarations that satisfy `shouldEtaExpand`. - Put computationally relevant expressions in A-normal form. -/ -partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do - run do toCode (← visit e) +partial def toLCNF (e : Expr) (eType : Expr) : CompilerM (Code .pure) := do + run eType do toCode (← visit e) where visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do - if let some arg := (← get).cache.find? e then + let eType? := (← read).expectedType + if let some arg := (← get).cache.find? (e, eType?) then return arg let r : Arg .pure ← match e with | .app .. => visitApp e @@ -462,7 +485,7 @@ where | .lit lit => visitLit lit | .fvar fvarId => if (← get).toAny.contains fvarId then pure .erased else pure (.fvar fvarId) | .forallE .. | .mvar .. | .bvar .. | .sort .. => unreachable! - modify fun s => if s.shouldCache then { s with cache := s.cache.insert e r } else s + modify fun s => if s.shouldCache then { s with cache := s.cache.insert (e, eType?) r } else s return r visit (e : Expr) : M (Arg .pure) := withIncRecDepth do @@ -505,7 +528,7 @@ where visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do let env ← getEnv let .const declName us ← CSimp.replaceConstant env f | unreachable! - let args ← args.mapM visitAppArg + let args ← args.mapM (withoutExpectedType do visitAppArg ·) if hasNeverExtractAttribute env declName then modify fun s => {s with shouldCache := false } letValueToArg <| .const declName us args @@ -549,10 +572,12 @@ where let altType ← c.inferType return (altType, .default c) | .ctor ctorName numParams => - let mut (ps, e) ← visitBoundedLambda e numParams + let mut (ps, e) ← withoutExpectedType do + visitBoundedLambda e numParams if ps.size < numParams then e ← etaExpandN e (numParams - ps.size) - let (ps', e') ← ToLCNF.visitLambda e + let (ps', e', _) ← withoutExpectedType do + ToLCNF.visitLambda e ps := ps ++ ps' e := e' /- @@ -609,11 +634,17 @@ where fieldArgs := fieldArgs.push fieldArg return fieldArgs let f := args[casesInfo.altsRange.lower]! - let result ← visit (mkAppN f fieldArgs) - mkOverApplication result args casesInfo.arity + let arity := casesInfo.arity + if args.size == arity then + visit (mkAppN f fieldArgs) + else + withoutExpectedType do + let result ← visit (mkAppN f fieldArgs) + mkOverApplication result args casesInfo.arity else let mut alts := #[] - let discr ← visitAppArg args[casesInfo.discrPos]! + let discr ← withoutExpectedType do + visitAppArg args[casesInfo.discrPos]! let discrFVarId ← match discr with | .fvar discrFVarId => pure discrFVarId | .erased | .type .. => mkAuxLetDecl .erased @@ -625,9 +656,11 @@ where let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) let result := .fvar auxDecl.fvarId - mkOverApplication result args casesInfo.arity + withoutExpectedType do + mkOverApplication result args casesInfo.arity visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) := + withoutExpectedType do etaIfUnderApplied e arity do let f := e.getAppFn let args := e.getAppArgs @@ -638,7 +671,7 @@ where -- We can rely on `toMono` erasing ctor params eventually; we do not do so here so that type -- inference on the value is preserved. withReader (fun ctx => - { ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do + { ctx with ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do visitAppArg arg if hasNeverExtractAttribute env declName then modify fun s => {s with shouldCache := false } @@ -646,6 +679,7 @@ where visitQuotLift (e : Expr) : M (Arg .pure) := do let arity := 6 + withoutExpectedType do etaIfUnderApplied e arity do let mut args := e.getAppArgs let α ← visitAppArg args[0]! @@ -661,6 +695,7 @@ where visitEqRec (e : Expr) : M (Arg .pure) := let arity := 6 + withoutExpectedType do etaIfUnderApplied e arity do let args := e.getAppArgs let minor := if e.isAppOf ``Eq.rec || e.isAppOf ``Eq.ndrec then args[3]! else args[5]! @@ -669,6 +704,7 @@ where visitHEqRec (e : Expr) : M (Arg .pure) := let arity := 7 + withoutExpectedType do etaIfUnderApplied e arity do let args := e.getAppArgs let minor := if e.isAppOf ``HEq.rec || e.isAppOf ``HEq.ndrec then args[3]! else args[6]! @@ -677,18 +713,21 @@ where visitFalseRec (e : Expr) : M (Arg .pure) := let arity := 2 + withoutExpectedType do etaIfUnderApplied e arity do let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type visitLcUnreachable (e : Expr) : M (Arg .pure) := let arity := 1 + withoutExpectedType do etaIfUnderApplied e arity do let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) := let arity := 5 + withoutExpectedType do etaIfUnderApplied e arity do let args := e.getAppArgs let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr. @@ -701,6 +740,7 @@ where let .const declName _ := e.getAppFn | unreachable! let info := getNoConfusionInfo (← getEnv) declName let typeName := declName.getPrefix + withoutExpectedType do etaIfUnderApplied e info.arity do let args := e.getAppArgs let visitMajor (numNonPropFields : Nat) := do @@ -786,10 +826,10 @@ where e.withApp visitAppDefaultConst else e.withApp fun f args => do - match (← visit f) with + match (← withoutExpectedType do visit f) with | .erased | .type .. => return .erased | .fvar fvarId => - let args ← args.mapM visitAppArg + let args ← args.mapM (withoutExpectedType do visitAppArg ·) letValueToArg <| .fvar fvarId args visitLambda (e : Expr) : M (Arg .pure) := do @@ -821,8 +861,9 @@ where visit b else let funDecl ← withNewScope do - let (ps, e) ← ToLCNF.visitLambda e - let e ← visit e + let (ps, e, eType?) ← ToLCNF.visitLambda e + let e ← withExpectedType eType? do + visit e let c ← toCode e mkAuxFunDecl ps c pushElement (.fun funDecl) @@ -837,7 +878,7 @@ where let projExpr ← liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]! visitApp projExpr else - match (← visit e) with + match (← withoutExpectedType do visit e) with | .erased | .type .. => return .erased | .fvar fvarId => letValueToArg <| .proj s i fvarId @@ -850,7 +891,9 @@ where visitLet body (xs.push value) else let type' ← toLCNFType type - let letDecl ← mkLetDecl binderName type value type' (← visit value) + let value' ← withExpectedType type' do + visit value + let letDecl ← mkLetDecl binderName type value type' value' visitLet body (xs.push (.fvar letDecl.fvarId)) | _ => let e := e.instantiateRev xs diff --git a/tests/elab/lcnf_borrow_expected_type.lean b/tests/elab/lcnf_borrow_expected_type.lean new file mode 100644 index 0000000000..fef9a3d711 --- /dev/null +++ b/tests/elab/lcnf_borrow_expected_type.lean @@ -0,0 +1,257 @@ +module + +public section + +/-! +Tests that borrow annotations from declaration/let-binding types survive LCNF conversion. +The `@&` annotations live in the forall type, not in the lambda binders, and are based on the +(rather brittle) mdata so LCNF must infer them to a degree. +-/ + +/-- +trace: [Compiler.saveBase] size: 1 + def borrowTop @&xs : Nat := + let _x.1 := @List.lengthTR _ xs; + return _x.1 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +def borrowTop (xs : @& List Nat) : Nat := xs.length + +/-- +trace: [Compiler.saveBase] size: 3 + def borrowMixed n @&xs m : Nat := + let _x.1 := @List.lengthTR _ xs; + let _x.2 := Nat.add n _x.1; + let _x.3 := Nat.add _x.2 m; + return _x.3 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +def borrowMixed (n : Nat) (xs : @& List Nat) (m : Nat) : Nat := + n + xs.length + m + +/-- +trace: [Compiler.saveBase] size: 5 + def borrowLet n xs ys : Nat := + fun f @&ys : Nat := + let _x.1 := @List.lengthTR _ ys; + let _x.2 := Nat.add _x.1 n; + return _x.2; + let _x.3 := f xs; + let _x.4 := f ys; + let _x.5 := Nat.add _x.3 _x.4; + return _x.5 +[Compiler.lambdaLifting] size: 2 + def borrowLet._lam_0 n @&ys : Nat := + let _x.1 := List.lengthTR._redArg ys; + let _x.2 := Nat.add _x.1 n; + return _x.2 +[Compiler.lambdaLifting] size: 4 + def borrowLet n xs ys : Nat := + let f := borrowLet._lam_0 n; + let _x.1 := f xs; + let _x.2 := f ys; + let _x.3 := Nat.add _x.1 _x.2; + return _x.3 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +set_option trace.Compiler.lambdaLifting true in +def borrowLet (n : Nat) (xs ys : List Nat) : Nat := + let f : (@& List Nat) → Nat := fun ys => ys.length + n + f xs + f ys + +/-- +trace: [Compiler.saveBase] size: 2 + def applyTwice f @&a.1 : Nat := + let _x.2 := f a.1; + let _x.3 := f _x.2; + return _x.3 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +def applyTwice (f : Nat → Nat) : (@& Nat) → Nat := + let g := f ∘ f + g + +structure Ctx where + values : List Nat + +abbrev MyReaderM (α : Type) := (@& Ctx) → α + +@[inline] +def MyReaderM.bind (f : MyReaderM α) (g : α → MyReaderM β) : MyReaderM β := + fun ctx => g (f ctx) ctx + +instance : Monad MyReaderM where + pure a := fun _ => a + bind := MyReaderM.bind + +@[inline] def MyReaderM.read : MyReaderM Ctx := fun ctx => ctx + +/-- +trace: [Compiler.saveBase] size: 2 + def withMyReader α f x @&ctx : α := + let _x.1 := f ctx; + let _x.2 := x _x.1; + return _x.2 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +@[noinline] +def withMyReader (f : Ctx → Ctx) (x : MyReaderM α) : MyReaderM α := + fun ctx => x (f ctx) + +/-- +trace: [Compiler.saveBase] size: 6 + def getLength other @&a.1 : Nat := + fun _f.2 ctx : Ctx := + let _x.3 := ctx # 0; + let _x.4 := @List.appendTR _ _x.3 other; + let _x.5 := Ctx.mk _x.4; + return _x.5; + fun _f.6 _y.7 : Nat := + let _x.8 := _y.7 # 0; + let _x.9 := @List.lengthTR _ _x.8; + return _x.9; + let _x.10 := @withMyReader _ _f.2 _f.6 a.1; + return _x.10 +-/ +#guard_msgs in +set_option trace.Compiler.saveBase true in +def getLength (other : List Nat) : MyReaderM Nat := do + withMyReader (fun ctx => { ctx with values := ctx.values ++ other }) do + let ctx ← MyReaderM.read + return ctx.values.length + +structure Pair where + fst : List Nat + snd : List Nat + +structure Quad where + left : Pair + right : Pair + +/-- Packs `xs` into a constructor → parameter inferred as owned. -/ +@[noinline] def wrap (xs : List Nat) : List (List Nat) := [xs] + +/-- Only traverses → parameter stays borrowed. -/ +@[noinline] def measuree (xs : List Nat) : Nat := xs.length + +/-- +trace: [Compiler.explicitRc] size: 22 + def cascadeDemo @&t : tobj := + let left := oproj[0] t; + let right := oproj[1] t; + let fst := oproj[0] left; + let snd := oproj[1] left; + let fst := oproj[0] right; + let snd := oproj[1] right; + inc fst; + let _x.1 := wrap fst; + let res := List.lengthTR._redArg _x.1; + dec _x.1; + let _x.2 := measuree snd; + let _x.3 := Nat.add res _x.2; + dec _x.2; + dec res; + let _x.4 := measuree fst; + let _x.5 := Nat.add _x.3 _x.4; + dec _x.4; + dec _x.3; + let _x.6 := measuree snd; + let _x.7 := Nat.add _x.5 _x.6; + dec _x.6; + dec _x.5; + return _x.7 +[Compiler.explicitRc] size: 2 + def cascadeDemo._boxed t : tobj := + let res := cascadeDemo t; + dec t; + return res +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def cascadeDemo (t : @&Quad) : Nat := + let l := t.left + let r := t.right + let res := (wrap l.fst).length + res + measuree l.snd + measuree r.fst + measuree r.snd + +/-- +trace: [Compiler.explicitRc] size: 33 + def cascadeDemo' t : tobj := + let left := oproj[0] t; + inc left; + let right := oproj[1] t; + inc right; + dec t; + let fst := oproj[0] left; + inc fst; + let snd := oproj[1] left; + inc snd; + dec left; + let fst := oproj[0] right; + inc fst; + let snd := oproj[1] right; + inc snd; + dec right; + let _x.1 := wrap fst; + let res := List.lengthTR._redArg _x.1; + dec _x.1; + let _x.2 := measuree snd; + dec snd; + let _x.3 := Nat.add res _x.2; + dec _x.2; + dec res; + let _x.4 := measuree fst; + dec fst; + let _x.5 := Nat.add _x.3 _x.4; + dec _x.4; + dec _x.3; + let _x.6 := measuree snd; + dec snd; + let _x.7 := Nat.add _x.5 _x.6; + dec _x.6; + dec _x.5; + return _x.7 +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def cascadeDemo' (t : Quad) : Nat := + let l := t.left + let r := t.right + let res := (wrap l.fst).length + res + measuree l.snd + measuree r.fst + measuree r.snd + +@[noinline] +public def mkNewProd (x : Prod Nat Nat) (a : Nat) := { x with fst := a } + +/-- +trace: [Compiler.explicitRc] size: 15 + def preserveTailCall x a : tobj := + let zero := 0; + let isZero := Nat.decEq a zero; + cases isZero : tobj + | Bool.true => + dec a; + let fst := oproj[0] x; + inc fst; + dec x; + return fst + | Bool.false => + let one := 1; + let n.1 := Nat.sub a one; + dec a; + inc n.1; + let _x.2 := mkNewProd x n.1; + let _x.3 := preserveTailCall _x.2 n.1; + return _x.3 +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def preserveTailCall (x : @&Prod Nat Nat) (a : Nat) : Nat := + match a with + | 0 => x.fst + | a + 1 => preserveTailCall (mkNewProd x a) a