diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index ca6aa9066b..e7eafa5d0d 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -213,6 +213,8 @@ inductive OwnReason where | jpArgPropagation (jpFVar : FVarId) /-- Tail call preservation at a join point jump. -/ | jpTailCallPreservation (jpFVar : FVarId) + /-- Annotated as an owned parameter (currently only triggerable through `@[export]`)-/ + | ownedAnnotation def OwnReason.toString (reason : OwnReason) : CompilerM String := do PP.run do @@ -229,6 +231,7 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do | .tailCallPreservation funcName => return s!"tail call preservation of {funcName}" | .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}" | .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}" + | .ownedAnnotation => return s!"Annotated as owned" /-- Determine whether an `OwnReason` is necessary for correctness (forced) or just an optimization @@ -245,7 +248,7 @@ def OwnReason.isForced (reason : OwnReason) : Bool := | .constructorResult .. | .functionCallResult .. -- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for -- correctness reasons. - | .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation .. + | .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation .. | .ownedAnnotation | .forwardProjectionProp .. | .backwardProjectionProp .. => true /-- @@ -256,10 +259,19 @@ partial def infer (decls : Array (Decl .impure)) : CompilerM ParamMap := do return map.paramMap where go : InferM Unit := do + for (_, params) in (← get).paramMap.map do + for param in params do + if !param.borrow && param.type.isPossibleRef then + -- if the param already disqualifies as borrow now this is because of an annotation + ownFVar param.fvarId .ownedAnnotation + modify fun s => { s with modified := false } + loop + + loop : InferM Unit := do step if (← get).modified then modify fun s => { s with modified := false } - go + loop else return ()