feat: respect user provided borrow annotations (#12830)
This PR enables support for respecting user provided borrow annotations. This allows user to mark arguments of their definitions or local functions with `(x : @&Ty)` and have the borrow inference try its best to preserve this annotation, thus potentially reducing RC pressure. Note that in some cases this might not be possible. For example, the compiler prioritizes preserving tail calls over preserving borrow annotations. A precise reasoning of why the compiler chose to make its inference decisions can be obtained with `trace.Compiler.inferBorrow`. The implementation consists of two parts: 1. A propagator in ToLCNF. This is required because the elaborator does not place the borrow annotations in the function binders themselves but just in type annotations of let binders/global declarations while LCNF expects them in the lambda variable binders themselves. Thus ToLCNF now implements a (very weak but strong enough for this purpose) propagator of the borrow annotations of a type annotation into the variable binders of the term affected by the annotations 2. A weakening of the InferBorrow heuristic. It now has a set of "forced" and "non-forced" reasons to mark a variable as owned instead of borrowed. If a variable is user annotated as borrowed, it will only be marked as owned if the reason is a forced one, e.g. preservation of tail calls.
This commit is contained in:
parent
0b9ad3fb8d
commit
511be304d7
5 changed files with 406 additions and 48 deletions
|
|
@ -93,7 +93,7 @@ where
|
|||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
let d := d.instantiateRev xs
|
||||
let p ← mkAuxParam d
|
||||
let p ← mkAuxParam d (isMarkedBorrowed d)
|
||||
go b (xs.push (.fvar p.fvarId)) (ps.push p)
|
||||
| _ =>
|
||||
let type := type.instantiateRev xs
|
||||
|
|
|
|||
|
|
@ -41,6 +41,11 @@ For performance we:
|
|||
particular when `f` is partially applied we ensure that all arguments are owned.
|
||||
- When passing a parameter into a constructor we ensure it is passed as owned so we do not have
|
||||
to `inc` before calling the constructor.
|
||||
|
||||
Furthermore, the performance related heuristics will be ignored if there is a user-defined
|
||||
borrow annotation. This allows the user to override the ABI of their function in exchange for
|
||||
potentially worse code in the function that is being analyzed. Doing so can benefit other functions
|
||||
that call the current one and thus reduce overall RC pressure.
|
||||
-/
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
|
@ -56,11 +61,25 @@ inductive Key where
|
|||
|
||||
end ParamMap
|
||||
|
||||
abbrev ParamMap := Std.HashMap ParamMap.Key (Array (Param .impure))
|
||||
structure ParamMap where
|
||||
map : Std.HashMap ParamMap.Key (Array (Param .impure)) := {}
|
||||
/--
|
||||
The set of fvars that were already annotated as borrowed before arriving at this pass. We try to
|
||||
preserve the annotations here if possible.
|
||||
-/
|
||||
annoatedBorrows : Std.HashSet FVarId := {}
|
||||
|
||||
/-- Mark parameters that take a reference as borrow -/
|
||||
def initBorrow (ps : Array (Param .impure)) : Array (Param .impure) :=
|
||||
ps.map fun p => { p with borrow := p.type.isPossibleRef }
|
||||
namespace ParamMap
|
||||
|
||||
@[inline]
|
||||
def insert (pm : ParamMap) (k : Key) (ps : Array (Param .impure)) : ParamMap :=
|
||||
{ pm with map := pm.map.insert k ps }
|
||||
|
||||
@[inline]
|
||||
def erase (pm : ParamMap) (k : Key) : ParamMap :=
|
||||
{ pm with map := pm.map.erase k }
|
||||
|
||||
end ParamMap
|
||||
|
||||
abbrev InitM := StateRefT ParamMap CompilerM
|
||||
|
||||
|
|
@ -73,7 +92,12 @@ where
|
|||
match decl.value with
|
||||
| .code code =>
|
||||
let exported := isExport (← getEnv) decl.name
|
||||
modify fun m => m.insert (.decl decl.name) (initParamsIfNotExported exported decl.params)
|
||||
modify fun m =>
|
||||
{ m with
|
||||
map := m.map.insert (.decl decl.name) (initParamsIfNotExported exported decl.params),
|
||||
annoatedBorrows := decl.params.foldl (init := m.annoatedBorrows) fun acc p =>
|
||||
if p.borrow then acc.insert p.fvarId else acc
|
||||
}
|
||||
goCode decl.name code
|
||||
| .extern .. => return ()
|
||||
|
||||
|
|
@ -89,7 +113,12 @@ where
|
|||
goCode (declName : Name) (code : Code .impure) : InitM Unit := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
modify fun m => m.insert (.jp declName decl.fvarId) (initParams decl.params)
|
||||
modify fun m =>
|
||||
{ m with
|
||||
map := m.map.insert (.jp declName decl.fvarId) (initParams decl.params),
|
||||
annoatedBorrows := decl.params.foldl (init := m.annoatedBorrows) fun acc p =>
|
||||
if p.borrow then acc.insert p.fvarId else acc
|
||||
}
|
||||
goCode declName decl.value
|
||||
goCode declName k
|
||||
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
|
||||
|
|
@ -105,7 +134,7 @@ partial def apply (decls : Array (Decl .impure)) (map : ParamMap) : CompilerM (A
|
|||
match decl.value with
|
||||
| .code code =>
|
||||
let code ← go decl.name code
|
||||
let newParams ← updateParams decl.params map[ParamMap.Key.decl decl.name]!
|
||||
let newParams ← updateParams decl.params map.map[ParamMap.Key.decl decl.name]!
|
||||
return { decl with value := .code code, params := newParams }
|
||||
| _ => return decl
|
||||
where
|
||||
|
|
@ -119,7 +148,7 @@ where
|
|||
go (declName : Name) (code : Code .impure) : CompilerM (Code .impure) := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
let ps ← updateParams decl.params map[ParamMap.Key.jp declName decl.fvarId]!
|
||||
let ps ← updateParams decl.params map.map[ParamMap.Key.jp declName decl.fvarId]!
|
||||
let decl ← decl.update decl.type ps (← go declName decl.value)
|
||||
return code.updateFun! decl (← go declName k)
|
||||
| .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName))
|
||||
|
|
@ -166,8 +195,10 @@ inductive OwnReason where
|
|||
| constructorResult (resultFVar : FVarId)
|
||||
/-- Parameter packed into a constructor. -/
|
||||
| constructorArg (resultFVar : FVarId)
|
||||
/-- Bidirectional ownership propagation through `oproj`. -/
|
||||
| projectionPropagation (resultFVar : FVarId)
|
||||
/-- Forward ownership propagation through `oproj`. -/
|
||||
| forwardProjectionProp (resultFVar : FVarId)
|
||||
/-- Backward ownership propagation through `oproj`. -/
|
||||
| backwardProjectionProp (resultFVar : FVarId)
|
||||
/-- Result of a function application. -/
|
||||
| functionCallResult (resultFVar : FVarId)
|
||||
/-- Argument to a function whose corresponding parameter is owned. -/
|
||||
|
|
@ -189,7 +220,8 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do
|
|||
| .resetReuse resultFVar => return s!"used in reset reuse {← PP.ppFVar resultFVar}"
|
||||
| .constructorResult resultFVar => return s!"result of ctor call {← PP.ppFVar resultFVar}"
|
||||
| .constructorArg resultFVar => return s!"argument to constructor call {← PP.ppFVar resultFVar}"
|
||||
| .projectionPropagation resultFVar => return s!"projection propagation {← PP.ppFVar resultFVar}"
|
||||
| .forwardProjectionProp resultFVar => return s!"fwd projection propagation {← PP.ppFVar resultFVar}"
|
||||
| .backwardProjectionProp resultFVar => return s!"bkwd projection propagation {← PP.ppFVar resultFVar}"
|
||||
| .functionCallResult resultFVar => return s!"result of function call {← PP.ppFVar resultFVar}"
|
||||
| .functionCallArg resultFVar => return s!"owned function argument {← PP.ppFVar resultFVar}"
|
||||
| .fvarCall resultFVar => return s!"argument to closure call {← PP.ppFVar resultFVar}"
|
||||
|
|
@ -198,6 +230,29 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do
|
|||
| .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}"
|
||||
| .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}"
|
||||
|
||||
/--
|
||||
Determine whether an `OwnReason` is necessary for correctness (forced) or just an optimization
|
||||
(not-forced). If we attempt to own a variable that has been previously annotated as borrow for a
|
||||
non-forced reason we ignore it.
|
||||
-/
|
||||
def OwnReason.isForced (reason : OwnReason) : Bool :=
|
||||
match reason with
|
||||
-- All of these reasons propagate through ABI decisions and can thus safely be ignored as they
|
||||
-- will be accounted for by the reference counting pass.
|
||||
| .constructorArg .. | .functionCallArg .. | .fvarCall .. | .partialApplication ..
|
||||
| .jpArgPropagation ..
|
||||
-- If a projection of a value is used in an owned fashion that does not necessarily mean we have
|
||||
-- to make the value itself owned. Note that this will however prevent potential reset-reuse
|
||||
-- opportunities.
|
||||
| .backwardProjectionProp .. => false
|
||||
-- Results of functions and constructors are naturally owned.
|
||||
| .constructorResult .. | .functionCallResult ..
|
||||
-- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for
|
||||
-- correctness reasons.
|
||||
| .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation ..
|
||||
-- If a value is owned and we project from it its projectee is always owned as well.
|
||||
| .forwardProjectionProp .. => true
|
||||
|
||||
/--
|
||||
Infer the borrowing annotations in a SCC through dataflow analysis.
|
||||
-/
|
||||
|
|
@ -218,8 +273,11 @@ where
|
|||
|
||||
ownFVar (fvarId : FVarId) (reason : OwnReason) : InferM Unit := do
|
||||
unless (← get).owned.contains fvarId do
|
||||
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
|
||||
if !reason.isForced && (← get).paramMap.annoatedBorrows.contains fvarId then
|
||||
trace[Compiler.inferBorrow] "user annotation blocked owning {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
else
|
||||
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
|
||||
|
||||
ownArg (reason : OwnReason) (a : Arg .impure) : InferM Unit := do
|
||||
a.forFVarM (ownFVar · reason)
|
||||
|
|
@ -240,7 +298,7 @@ where
|
|||
|
||||
/-- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
updateParamMap (k : ParamMap.Key) : InferM Unit := do
|
||||
if let some ps := (← get).paramMap[k]? then
|
||||
if let some ps := (← get).paramMap.map[k]? then
|
||||
-- This is to ensure linearity over ps in the following code, if you know how to make this
|
||||
-- linear in a nice fashion please make a PR
|
||||
modify fun s => { s with paramMap := s.paramMap.erase k }
|
||||
|
|
@ -255,7 +313,7 @@ where
|
|||
modify fun s => { s with paramMap := s.paramMap.insert k ps }
|
||||
|
||||
getParamInfo (k : ParamMap.Key) : InferM (Array (Param .impure)) := do
|
||||
match (← get).paramMap[k]? with
|
||||
match (← get).paramMap.map[k]? with
|
||||
| some ps => return ps
|
||||
| none =>
|
||||
let .decl fn := k | unreachable!
|
||||
|
|
@ -306,8 +364,8 @@ where
|
|||
| .reuse x _ _ args => ownFVar z (.resetReuse z); ownFVar x (.resetReuse z); ownArgsIfParam z args
|
||||
| .ctor _ args => ownFVar z (.constructorResult z); ownArgsIfParam z args
|
||||
| .oproj _ x _ =>
|
||||
if ← isOwned x then ownFVar z (.projectionPropagation z)
|
||||
if ← isOwned z then ownFVar x (.projectionPropagation z)
|
||||
if ← isOwned x then ownFVar z (.forwardProjectionProp z)
|
||||
if ← isOwned z then ownFVar x (.backwardProjectionProp z)
|
||||
| .fap f args =>
|
||||
let ps ← getParamInfo (.decl f)
|
||||
ownFVar z (.functionCallResult z)
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ def toDecl (declName : Name) : CompilerM (Decl .pure) := do
|
|||
/- Recall that `inlineMatchers` may have exposed `ite`s and `dite`s which are tagged as `[macro_inline]`. -/
|
||||
let value ← macroInline value
|
||||
return (type, value)
|
||||
let code ← toLCNF value
|
||||
let code ← toLCNF value type
|
||||
let mut decl ← if let .fun decl (.return _) := code then
|
||||
eraseFunDecl decl (recursive := false)
|
||||
pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure }
|
||||
|
|
|
|||
|
|
@ -206,12 +206,18 @@ structure Context where
|
|||
eventually.
|
||||
-/
|
||||
ignoreNoncomputable : Bool := false
|
||||
/--
|
||||
The expected type of the expression that is currently being handled if available. This type is
|
||||
only used to propagate potential borrow annotations as they are not propagated everywhere by the
|
||||
elaborator.
|
||||
-/
|
||||
expectedType : Option Expr
|
||||
|
||||
structure State where
|
||||
/-- Local context containing the original Lean types (not LCNF ones). -/
|
||||
lctx : LocalContext := {}
|
||||
/-- Cache from Lean regular expression to LCNF argument. -/
|
||||
cache : PHashMap Expr (Arg .pure) := {}
|
||||
cache : PHashMap (Expr × Option Expr) (Arg .pure) := {}
|
||||
/--
|
||||
Determines whether caching has been disabled due to finding a use of
|
||||
a constant marked with `never_extract`.
|
||||
|
|
@ -265,8 +271,18 @@ def toCode (result : Arg .pure) : M (Code .pure) := do
|
|||
let fvarId ← mkAuxLetDecl .erased
|
||||
seqToCode (← get).seq (.return fvarId)
|
||||
|
||||
def run (x : M α) : CompilerM α :=
|
||||
x.run {} |>.run' {}
|
||||
def run (expectedType : Expr) (x : M α) : CompilerM α :=
|
||||
x.run { expectedType } |>.run' {}
|
||||
|
||||
@[inline]
|
||||
def withExpectedType (e : Option Expr) (x : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with expectedType := e }) do
|
||||
x
|
||||
|
||||
@[inline]
|
||||
def withoutExpectedType (x : M α) : M α :=
|
||||
withExpectedType none do
|
||||
x
|
||||
|
||||
/--
|
||||
Return true iff `type` is `Sort _` or `As → Sort _`.
|
||||
|
|
@ -340,9 +356,9 @@ def cleanupBinderName (binderName : Name) : CompilerM Name :=
|
|||
return binderName
|
||||
|
||||
/-- Create a new local declaration using a Lean regular type. -/
|
||||
def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool := isMarkedBorrowed type) :
|
||||
M (Param .pure) := do
|
||||
let binderName ← cleanupBinderName binderName
|
||||
let borrow := isMarkedBorrowed type
|
||||
let type' ← toLCNFType type
|
||||
let param ← LCNF.mkParam binderName type' borrow
|
||||
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
|
||||
|
|
@ -361,16 +377,22 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a
|
|||
}
|
||||
return letDecl
|
||||
|
||||
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) :=
|
||||
go e #[] #[]
|
||||
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr × Option Expr) := do
|
||||
go e #[] #[] (← read).expectedType
|
||||
where
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) (eType? : Option Expr) := do
|
||||
if let .lam binderName type body _ := e then
|
||||
let type := type.instantiateRev xs
|
||||
let p ← mkParam binderName type
|
||||
go body (xs.push p.toExpr) (ps.push p)
|
||||
if let some (.forallE _ type' eType _) := eType? then
|
||||
let borrow := isMarkedBorrowed type || isMarkedBorrowed type'
|
||||
let p ← mkParam binderName type borrow
|
||||
-- no need to instantiate eType, we only ever check it for `isMarkedBorrowed`
|
||||
go body (xs.push p.toExpr) (ps.push p) (some eType)
|
||||
else
|
||||
let p ← mkParam binderName type
|
||||
go body (xs.push p.toExpr) (ps.push p) none
|
||||
else
|
||||
return (ps, e.instantiateRev xs)
|
||||
return (ps, e.instantiateRev xs, eType?.map (·.instantiateRev xs))
|
||||
|
||||
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) :=
|
||||
go e n #[] #[]
|
||||
|
|
@ -446,11 +468,12 @@ Put the given expression in `LCNF`.
|
|||
- Eta-expand applications of declarations that satisfy `shouldEtaExpand`.
|
||||
- Put computationally relevant expressions in A-normal form.
|
||||
-/
|
||||
partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do
|
||||
run do toCode (← visit e)
|
||||
partial def toLCNF (e : Expr) (eType : Expr) : CompilerM (Code .pure) := do
|
||||
run eType do toCode (← visit e)
|
||||
where
|
||||
visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do
|
||||
if let some arg := (← get).cache.find? e then
|
||||
let eType? := (← read).expectedType
|
||||
if let some arg := (← get).cache.find? (e, eType?) then
|
||||
return arg
|
||||
let r : Arg .pure ← match e with
|
||||
| .app .. => visitApp e
|
||||
|
|
@ -462,7 +485,7 @@ where
|
|||
| .lit lit => visitLit lit
|
||||
| .fvar fvarId => if (← get).toAny.contains fvarId then pure .erased else pure (.fvar fvarId)
|
||||
| .forallE .. | .mvar .. | .bvar .. | .sort .. => unreachable!
|
||||
modify fun s => if s.shouldCache then { s with cache := s.cache.insert e r } else s
|
||||
modify fun s => if s.shouldCache then { s with cache := s.cache.insert (e, eType?) r } else s
|
||||
return r
|
||||
|
||||
visit (e : Expr) : M (Arg .pure) := withIncRecDepth do
|
||||
|
|
@ -505,7 +528,7 @@ where
|
|||
visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do
|
||||
let env ← getEnv
|
||||
let .const declName us ← CSimp.replaceConstant env f | unreachable!
|
||||
let args ← args.mapM visitAppArg
|
||||
let args ← args.mapM (withoutExpectedType do visitAppArg ·)
|
||||
if hasNeverExtractAttribute env declName then
|
||||
modify fun s => {s with shouldCache := false }
|
||||
letValueToArg <| .const declName us args
|
||||
|
|
@ -549,10 +572,12 @@ where
|
|||
let altType ← c.inferType
|
||||
return (altType, .default c)
|
||||
| .ctor ctorName numParams =>
|
||||
let mut (ps, e) ← visitBoundedLambda e numParams
|
||||
let mut (ps, e) ← withoutExpectedType do
|
||||
visitBoundedLambda e numParams
|
||||
if ps.size < numParams then
|
||||
e ← etaExpandN e (numParams - ps.size)
|
||||
let (ps', e') ← ToLCNF.visitLambda e
|
||||
let (ps', e', _) ← withoutExpectedType do
|
||||
ToLCNF.visitLambda e
|
||||
ps := ps ++ ps'
|
||||
e := e'
|
||||
/-
|
||||
|
|
@ -609,11 +634,17 @@ where
|
|||
fieldArgs := fieldArgs.push fieldArg
|
||||
return fieldArgs
|
||||
let f := args[casesInfo.altsRange.lower]!
|
||||
let result ← visit (mkAppN f fieldArgs)
|
||||
mkOverApplication result args casesInfo.arity
|
||||
let arity := casesInfo.arity
|
||||
if args.size == arity then
|
||||
visit (mkAppN f fieldArgs)
|
||||
else
|
||||
withoutExpectedType do
|
||||
let result ← visit (mkAppN f fieldArgs)
|
||||
mkOverApplication result args casesInfo.arity
|
||||
else
|
||||
let mut alts := #[]
|
||||
let discr ← visitAppArg args[casesInfo.discrPos]!
|
||||
let discr ← withoutExpectedType do
|
||||
visitAppArg args[casesInfo.discrPos]!
|
||||
let discrFVarId ← match discr with
|
||||
| .fvar discrFVarId => pure discrFVarId
|
||||
| .erased | .type .. => mkAuxLetDecl .erased
|
||||
|
|
@ -625,9 +656,11 @@ where
|
|||
let auxDecl ← mkAuxParam resultType
|
||||
pushElement (.cases auxDecl cases)
|
||||
let result := .fvar auxDecl.fvarId
|
||||
mkOverApplication result args casesInfo.arity
|
||||
withoutExpectedType do
|
||||
mkOverApplication result args casesInfo.arity
|
||||
|
||||
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let f := e.getAppFn
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -638,7 +671,7 @@ where
|
|||
-- We can rely on `toMono` erasing ctor params eventually; we do not do so here so that type
|
||||
-- inference on the value is preserved.
|
||||
withReader (fun ctx =>
|
||||
{ ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
|
||||
{ ctx with ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
|
||||
visitAppArg arg
|
||||
if hasNeverExtractAttribute env declName then
|
||||
modify fun s => {s with shouldCache := false }
|
||||
|
|
@ -646,6 +679,7 @@ where
|
|||
|
||||
visitQuotLift (e : Expr) : M (Arg .pure) := do
|
||||
let arity := 6
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let mut args := e.getAppArgs
|
||||
let α ← visitAppArg args[0]!
|
||||
|
|
@ -661,6 +695,7 @@ where
|
|||
|
||||
visitEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 6
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let minor := if e.isAppOf ``Eq.rec || e.isAppOf ``Eq.ndrec then args[3]! else args[5]!
|
||||
|
|
@ -669,6 +704,7 @@ where
|
|||
|
||||
visitHEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 7
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let minor := if e.isAppOf ``HEq.rec || e.isAppOf ``HEq.ndrec then args[3]! else args[6]!
|
||||
|
|
@ -677,18 +713,21 @@ where
|
|||
|
||||
visitFalseRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 2
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitLcUnreachable (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 1
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) :=
|
||||
let arity := 5
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr.
|
||||
|
|
@ -701,6 +740,7 @@ where
|
|||
let .const declName _ := e.getAppFn | unreachable!
|
||||
let info := getNoConfusionInfo (← getEnv) declName
|
||||
let typeName := declName.getPrefix
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e info.arity do
|
||||
let args := e.getAppArgs
|
||||
let visitMajor (numNonPropFields : Nat) := do
|
||||
|
|
@ -786,10 +826,10 @@ where
|
|||
e.withApp visitAppDefaultConst
|
||||
else
|
||||
e.withApp fun f args => do
|
||||
match (← visit f) with
|
||||
match (← withoutExpectedType do visit f) with
|
||||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId =>
|
||||
let args ← args.mapM visitAppArg
|
||||
let args ← args.mapM (withoutExpectedType do visitAppArg ·)
|
||||
letValueToArg <| .fvar fvarId args
|
||||
|
||||
visitLambda (e : Expr) : M (Arg .pure) := do
|
||||
|
|
@ -821,8 +861,9 @@ where
|
|||
visit b
|
||||
else
|
||||
let funDecl ← withNewScope do
|
||||
let (ps, e) ← ToLCNF.visitLambda e
|
||||
let e ← visit e
|
||||
let (ps, e, eType?) ← ToLCNF.visitLambda e
|
||||
let e ← withExpectedType eType? do
|
||||
visit e
|
||||
let c ← toCode e
|
||||
mkAuxFunDecl ps c
|
||||
pushElement (.fun funDecl)
|
||||
|
|
@ -837,7 +878,7 @@ where
|
|||
let projExpr ← liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]!
|
||||
visitApp projExpr
|
||||
else
|
||||
match (← visit e) with
|
||||
match (← withoutExpectedType do visit e) with
|
||||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId => letValueToArg <| .proj s i fvarId
|
||||
|
||||
|
|
@ -850,7 +891,9 @@ where
|
|||
visitLet body (xs.push value)
|
||||
else
|
||||
let type' ← toLCNFType type
|
||||
let letDecl ← mkLetDecl binderName type value type' (← visit value)
|
||||
let value' ← withExpectedType type' do
|
||||
visit value
|
||||
let letDecl ← mkLetDecl binderName type value type' value'
|
||||
visitLet body (xs.push (.fvar letDecl.fvarId))
|
||||
| _ =>
|
||||
let e := e.instantiateRev xs
|
||||
|
|
|
|||
257
tests/elab/lcnf_borrow_expected_type.lean
Normal file
257
tests/elab/lcnf_borrow_expected_type.lean
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
module
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
Tests that borrow annotations from declaration/let-binding types survive LCNF conversion.
|
||||
The `@&` annotations live in the forall type, not in the lambda binders, and are based on the
|
||||
(rather brittle) mdata so LCNF must infer them to a degree.
|
||||
-/
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 1
|
||||
def borrowTop @&xs : Nat :=
|
||||
let _x.1 := @List.lengthTR _ xs;
|
||||
return _x.1
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def borrowTop (xs : @& List Nat) : Nat := xs.length
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 3
|
||||
def borrowMixed n @&xs m : Nat :=
|
||||
let _x.1 := @List.lengthTR _ xs;
|
||||
let _x.2 := Nat.add n _x.1;
|
||||
let _x.3 := Nat.add _x.2 m;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def borrowMixed (n : Nat) (xs : @& List Nat) (m : Nat) : Nat :=
|
||||
n + xs.length + m
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 5
|
||||
def borrowLet n xs ys : Nat :=
|
||||
fun f @&ys : Nat :=
|
||||
let _x.1 := @List.lengthTR _ ys;
|
||||
let _x.2 := Nat.add _x.1 n;
|
||||
return _x.2;
|
||||
let _x.3 := f xs;
|
||||
let _x.4 := f ys;
|
||||
let _x.5 := Nat.add _x.3 _x.4;
|
||||
return _x.5
|
||||
[Compiler.lambdaLifting] size: 2
|
||||
def borrowLet._lam_0 n @&ys : Nat :=
|
||||
let _x.1 := List.lengthTR._redArg ys;
|
||||
let _x.2 := Nat.add _x.1 n;
|
||||
return _x.2
|
||||
[Compiler.lambdaLifting] size: 4
|
||||
def borrowLet n xs ys : Nat :=
|
||||
let f := borrowLet._lam_0 n;
|
||||
let _x.1 := f xs;
|
||||
let _x.2 := f ys;
|
||||
let _x.3 := Nat.add _x.1 _x.2;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
set_option trace.Compiler.lambdaLifting true in
|
||||
def borrowLet (n : Nat) (xs ys : List Nat) : Nat :=
|
||||
let f : (@& List Nat) → Nat := fun ys => ys.length + n
|
||||
f xs + f ys
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 2
|
||||
def applyTwice f @&a.1 : Nat :=
|
||||
let _x.2 := f a.1;
|
||||
let _x.3 := f _x.2;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def applyTwice (f : Nat → Nat) : (@& Nat) → Nat :=
|
||||
let g := f ∘ f
|
||||
g
|
||||
|
||||
structure Ctx where
|
||||
values : List Nat
|
||||
|
||||
abbrev MyReaderM (α : Type) := (@& Ctx) → α
|
||||
|
||||
@[inline]
|
||||
def MyReaderM.bind (f : MyReaderM α) (g : α → MyReaderM β) : MyReaderM β :=
|
||||
fun ctx => g (f ctx) ctx
|
||||
|
||||
instance : Monad MyReaderM where
|
||||
pure a := fun _ => a
|
||||
bind := MyReaderM.bind
|
||||
|
||||
@[inline] def MyReaderM.read : MyReaderM Ctx := fun ctx => ctx
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 2
|
||||
def withMyReader α f x @&ctx : α :=
|
||||
let _x.1 := f ctx;
|
||||
let _x.2 := x _x.1;
|
||||
return _x.2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
@[noinline]
|
||||
def withMyReader (f : Ctx → Ctx) (x : MyReaderM α) : MyReaderM α :=
|
||||
fun ctx => x (f ctx)
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 6
|
||||
def getLength other @&a.1 : Nat :=
|
||||
fun _f.2 ctx : Ctx :=
|
||||
let _x.3 := ctx # 0;
|
||||
let _x.4 := @List.appendTR _ _x.3 other;
|
||||
let _x.5 := Ctx.mk _x.4;
|
||||
return _x.5;
|
||||
fun _f.6 _y.7 : Nat :=
|
||||
let _x.8 := _y.7 # 0;
|
||||
let _x.9 := @List.lengthTR _ _x.8;
|
||||
return _x.9;
|
||||
let _x.10 := @withMyReader _ _f.2 _f.6 a.1;
|
||||
return _x.10
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def getLength (other : List Nat) : MyReaderM Nat := do
|
||||
withMyReader (fun ctx => { ctx with values := ctx.values ++ other }) do
|
||||
let ctx ← MyReaderM.read
|
||||
return ctx.values.length
|
||||
|
||||
structure Pair where
|
||||
fst : List Nat
|
||||
snd : List Nat
|
||||
|
||||
structure Quad where
|
||||
left : Pair
|
||||
right : Pair
|
||||
|
||||
/-- Packs `xs` into a constructor → parameter inferred as owned. -/
|
||||
@[noinline] def wrap (xs : List Nat) : List (List Nat) := [xs]
|
||||
|
||||
/-- Only traverses → parameter stays borrowed. -/
|
||||
@[noinline] def measuree (xs : List Nat) : Nat := xs.length
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 22
|
||||
def cascadeDemo @&t : tobj :=
|
||||
let left := oproj[0] t;
|
||||
let right := oproj[1] t;
|
||||
let fst := oproj[0] left;
|
||||
let snd := oproj[1] left;
|
||||
let fst := oproj[0] right;
|
||||
let snd := oproj[1] right;
|
||||
inc fst;
|
||||
let _x.1 := wrap fst;
|
||||
let res := List.lengthTR._redArg _x.1;
|
||||
dec _x.1;
|
||||
let _x.2 := measuree snd;
|
||||
let _x.3 := Nat.add res _x.2;
|
||||
dec _x.2;
|
||||
dec res;
|
||||
let _x.4 := measuree fst;
|
||||
let _x.5 := Nat.add _x.3 _x.4;
|
||||
dec _x.4;
|
||||
dec _x.3;
|
||||
let _x.6 := measuree snd;
|
||||
let _x.7 := Nat.add _x.5 _x.6;
|
||||
dec _x.6;
|
||||
dec _x.5;
|
||||
return _x.7
|
||||
[Compiler.explicitRc] size: 2
|
||||
def cascadeDemo._boxed t : tobj :=
|
||||
let res := cascadeDemo t;
|
||||
dec t;
|
||||
return res
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.explicitRc true in
|
||||
def cascadeDemo (t : @&Quad) : Nat :=
|
||||
let l := t.left
|
||||
let r := t.right
|
||||
let res := (wrap l.fst).length
|
||||
res + measuree l.snd + measuree r.fst + measuree r.snd
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 33
|
||||
def cascadeDemo' t : tobj :=
|
||||
let left := oproj[0] t;
|
||||
inc left;
|
||||
let right := oproj[1] t;
|
||||
inc right;
|
||||
dec t;
|
||||
let fst := oproj[0] left;
|
||||
inc fst;
|
||||
let snd := oproj[1] left;
|
||||
inc snd;
|
||||
dec left;
|
||||
let fst := oproj[0] right;
|
||||
inc fst;
|
||||
let snd := oproj[1] right;
|
||||
inc snd;
|
||||
dec right;
|
||||
let _x.1 := wrap fst;
|
||||
let res := List.lengthTR._redArg _x.1;
|
||||
dec _x.1;
|
||||
let _x.2 := measuree snd;
|
||||
dec snd;
|
||||
let _x.3 := Nat.add res _x.2;
|
||||
dec _x.2;
|
||||
dec res;
|
||||
let _x.4 := measuree fst;
|
||||
dec fst;
|
||||
let _x.5 := Nat.add _x.3 _x.4;
|
||||
dec _x.4;
|
||||
dec _x.3;
|
||||
let _x.6 := measuree snd;
|
||||
dec snd;
|
||||
let _x.7 := Nat.add _x.5 _x.6;
|
||||
dec _x.6;
|
||||
dec _x.5;
|
||||
return _x.7
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.explicitRc true in
|
||||
def cascadeDemo' (t : Quad) : Nat :=
|
||||
let l := t.left
|
||||
let r := t.right
|
||||
let res := (wrap l.fst).length
|
||||
res + measuree l.snd + measuree r.fst + measuree r.snd
|
||||
|
||||
@[noinline]
|
||||
public def mkNewProd (x : Prod Nat Nat) (a : Nat) := { x with fst := a }
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 15
|
||||
def preserveTailCall x a : tobj :=
|
||||
let zero := 0;
|
||||
let isZero := Nat.decEq a zero;
|
||||
cases isZero : tobj
|
||||
| Bool.true =>
|
||||
dec a;
|
||||
let fst := oproj[0] x;
|
||||
inc fst;
|
||||
dec x;
|
||||
return fst
|
||||
| Bool.false =>
|
||||
let one := 1;
|
||||
let n.1 := Nat.sub a one;
|
||||
dec a;
|
||||
inc n.1;
|
||||
let _x.2 := mkNewProd x n.1;
|
||||
let _x.3 := preserveTailCall _x.2 n.1;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.explicitRc true in
|
||||
def preserveTailCall (x : @&Prod Nat Nat) (a : Nat) : Nat :=
|
||||
match a with
|
||||
| 0 => x.fst
|
||||
| a + 1 => preserveTailCall (mkNewProd x a) a
|
||||
Loading…
Add table
Reference in a new issue