diff --git a/src/Lean/Compiler/Check.lean b/src/Lean/Compiler/Check.lean index 4cce1e4736..85533b78c0 100644 --- a/src/Lean/Compiler/Check.lean +++ b/src/Lean/Compiler/Check.lean @@ -68,8 +68,7 @@ where withReader (fun ctx => { ctx with jps := if isJpBinderName n then ctx.jps.insert x.fvarId! else ctx.jps }) do checkBlock b (xs.push x) | _ => - if (← isJump e) then - let .fvar fvarId := e.getAppFn | unreachable! + if let some fvarId ← isJump? e then unless (← read).jps.contains fvarId do /- We cannot jump to join points defined out of the scope of a local function declaration. diff --git a/src/Lean/Compiler/CompilerM.lean b/src/Lean/Compiler/CompilerM.lean index ce25a85e38..7869d6f264 100644 --- a/src/Lean/Compiler/CompilerM.lean +++ b/src/Lean/Compiler/CompilerM.lean @@ -247,17 +247,20 @@ def mkJump (jp : Expr) (e : Expr) : CompilerM Expr := do /-- Given a let-declaration block `e`, return a new block that jumps to `jp` at its "exit points". + +`e` must contain all join points declarations used in `e`. -/ partial def attachJp (e : Expr) (jp : Expr) : CompilerM Expr := do - visitLet e #[] + withNewScope do + mkLetUsingScope (← visitLet e #[] |>.run {}) where - visitLambda (e : Expr) : CompilerM Expr := do + visitLambda (e : Expr) : ReaderT FVarIdSet CompilerM Expr := do withNewScope do let (as, e) ← Compiler.visitLambda e let e ← mkLetUsingScope (← visitLet e #[]) mkLambda as e - visitCases (casesInfo : CasesInfo) (cases : Expr) : CompilerM Expr := do + visitCases (casesInfo : CasesInfo) (cases : Expr) : ReaderT FVarIdSet CompilerM Expr := do let mut args := cases.getAppArgs let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow args := casesInfo.updateResultingType args b @@ -265,7 +268,7 @@ where args ← args.modifyM i visitLambda return mkAppN cases.getAppFn args - visitLet (e : Expr) (xs : Array Expr) : CompilerM Expr := do + visitLet (e : Expr) (xs : Array Expr) : ReaderT FVarIdSet CompilerM Expr := do match e with | .letE binderName type value body nonDep => let type := type.instantiateRev xs @@ -273,10 +276,13 @@ where if isJpBinderName binderName then value ← visitLambda value let x ← mkLetDecl binderName type value nonDep - visitLet body (xs.push x) + withReader (fun jps => if isJpBinderName binderName then jps.insert x.fvarId! else jps) do + visitLet body (xs.push x) | _ => let e := e.instantiateRev xs - if (← isJump e) then + if let some fvarId ← isJump? e then + unless (← read).contains fvarId do + throwError "failed to attach join point to let-block, it contains a out of scope join point" return e else if let some casesInfo ← isCasesApp? e then visitCases casesInfo e diff --git a/src/Lean/Compiler/Util.lean b/src/Lean/Compiler/Util.lean index 7d5900651a..71fa94e107 100644 --- a/src/Lean/Compiler/Util.lean +++ b/src/Lean/Compiler/Util.lean @@ -167,9 +167,12 @@ def findDecl? [Monad m] [MonadLCtx m] (fvarId : FVarId) : m (Option LocalDecl) : /-- Return true if `e` is of the form `_jp. ..` where `_jp.` is a join point. -/ -def isJump [Monad m] [MonadLCtx m] (e : Expr) : m Bool := do - let .fvar fvarId := e.getAppFn | return false - let some localDecl ← findDecl? fvarId | return false - return localDecl.isJp +def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do + let .fvar fvarId := e.getAppFn | return none + let some localDecl ← findDecl? fvarId | return none + if localDecl.isJp then + return some fvarId + else + return none end Lean.Compiler