feat: sanity checking at attachJp

This commit is contained in:
Leonardo de Moura 2022-08-17 09:47:59 -07:00
parent f0370749f9
commit 1d936e2d6b
3 changed files with 20 additions and 12 deletions

View file

@ -68,8 +68,7 @@ where
withReader (fun ctx => { ctx with jps := if isJpBinderName n then ctx.jps.insert x.fvarId! else ctx.jps }) do
checkBlock b (xs.push x)
| _ =>
if (← isJump e) then
let .fvar fvarId := e.getAppFn | unreachable!
if let some fvarId ← isJump? e then
unless (← read).jps.contains fvarId do
/-
We cannot jump to join points defined out of the scope of a local function declaration.

View file

@ -247,17 +247,20 @@ def mkJump (jp : Expr) (e : Expr) : CompilerM Expr := do
/--
Given a let-declaration block `e`, return a new block that jumps to `jp` at its "exit points".
`e` must contain all join points declarations used in `e`.
-/
partial def attachJp (e : Expr) (jp : Expr) : CompilerM Expr := do
visitLet e #[]
withNewScope do
mkLetUsingScope (← visitLet e #[] |>.run {})
where
visitLambda (e : Expr) : CompilerM Expr := do
visitLambda (e : Expr) : ReaderT FVarIdSet CompilerM Expr := do
withNewScope do
let (as, e) ← Compiler.visitLambda e
let e ← mkLetUsingScope (← visitLet e #[])
mkLambda as e
visitCases (casesInfo : CasesInfo) (cases : Expr) : CompilerM Expr := do
visitCases (casesInfo : CasesInfo) (cases : Expr) : ReaderT FVarIdSet CompilerM Expr := do
let mut args := cases.getAppArgs
let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow
args := casesInfo.updateResultingType args b
@ -265,7 +268,7 @@ where
args ← args.modifyM i visitLambda
return mkAppN cases.getAppFn args
visitLet (e : Expr) (xs : Array Expr) : CompilerM Expr := do
visitLet (e : Expr) (xs : Array Expr) : ReaderT FVarIdSet CompilerM Expr := do
match e with
| .letE binderName type value body nonDep =>
let type := type.instantiateRev xs
@ -273,10 +276,13 @@ where
if isJpBinderName binderName then
value ← visitLambda value
let x ← mkLetDecl binderName type value nonDep
visitLet body (xs.push x)
withReader (fun jps => if isJpBinderName binderName then jps.insert x.fvarId! else jps) do
visitLet body (xs.push x)
| _ =>
let e := e.instantiateRev xs
if (← isJump e) then
if let some fvarId ← isJump? e then
unless (← read).contains fvarId do
throwError "failed to attach join point to let-block, it contains a out of scope join point"
return e
else if let some casesInfo ← isCasesApp? e then
visitCases casesInfo e

View file

@ -167,9 +167,12 @@ def findDecl? [Monad m] [MonadLCtx m] (fvarId : FVarId) : m (Option LocalDecl) :
/--
Return true if `e` is of the form `_jp.<idx> ..` where `_jp.<idx>` is a join point.
-/
def isJump [Monad m] [MonadLCtx m] (e : Expr) : m Bool := do
let .fvar fvarId := e.getAppFn | return false
let some localDecl ← findDecl? fvarId | return false
return localDecl.isJp
def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do
let .fvar fvarId := e.getAppFn | return none
let some localDecl ← findDecl? fvarId | return none
if localDecl.isJp then
return some fvarId
else
return none
end Lean.Compiler