feat: sanity checking at attachJp
This commit is contained in:
parent
f0370749f9
commit
1d936e2d6b
3 changed files with 20 additions and 12 deletions
|
|
@ -68,8 +68,7 @@ where
|
|||
withReader (fun ctx => { ctx with jps := if isJpBinderName n then ctx.jps.insert x.fvarId! else ctx.jps }) do
|
||||
checkBlock b (xs.push x)
|
||||
| _ =>
|
||||
if (← isJump e) then
|
||||
let .fvar fvarId := e.getAppFn | unreachable!
|
||||
if let some fvarId ← isJump? e then
|
||||
unless (← read).jps.contains fvarId do
|
||||
/-
|
||||
We cannot jump to join points defined out of the scope of a local function declaration.
|
||||
|
|
|
|||
|
|
@ -247,17 +247,20 @@ def mkJump (jp : Expr) (e : Expr) : CompilerM Expr := do
|
|||
|
||||
/--
|
||||
Given a let-declaration block `e`, return a new block that jumps to `jp` at its "exit points".
|
||||
|
||||
`e` must contain all join points declarations used in `e`.
|
||||
-/
|
||||
partial def attachJp (e : Expr) (jp : Expr) : CompilerM Expr := do
|
||||
visitLet e #[]
|
||||
withNewScope do
|
||||
mkLetUsingScope (← visitLet e #[] |>.run {})
|
||||
where
|
||||
visitLambda (e : Expr) : CompilerM Expr := do
|
||||
visitLambda (e : Expr) : ReaderT FVarIdSet CompilerM Expr := do
|
||||
withNewScope do
|
||||
let (as, e) ← Compiler.visitLambda e
|
||||
let e ← mkLetUsingScope (← visitLet e #[])
|
||||
mkLambda as e
|
||||
|
||||
visitCases (casesInfo : CasesInfo) (cases : Expr) : CompilerM Expr := do
|
||||
visitCases (casesInfo : CasesInfo) (cases : Expr) : ReaderT FVarIdSet CompilerM Expr := do
|
||||
let mut args := cases.getAppArgs
|
||||
let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow
|
||||
args := casesInfo.updateResultingType args b
|
||||
|
|
@ -265,7 +268,7 @@ where
|
|||
args ← args.modifyM i visitLambda
|
||||
return mkAppN cases.getAppFn args
|
||||
|
||||
visitLet (e : Expr) (xs : Array Expr) : CompilerM Expr := do
|
||||
visitLet (e : Expr) (xs : Array Expr) : ReaderT FVarIdSet CompilerM Expr := do
|
||||
match e with
|
||||
| .letE binderName type value body nonDep =>
|
||||
let type := type.instantiateRev xs
|
||||
|
|
@ -273,10 +276,13 @@ where
|
|||
if isJpBinderName binderName then
|
||||
value ← visitLambda value
|
||||
let x ← mkLetDecl binderName type value nonDep
|
||||
visitLet body (xs.push x)
|
||||
withReader (fun jps => if isJpBinderName binderName then jps.insert x.fvarId! else jps) do
|
||||
visitLet body (xs.push x)
|
||||
| _ =>
|
||||
let e := e.instantiateRev xs
|
||||
if (← isJump e) then
|
||||
if let some fvarId ← isJump? e then
|
||||
unless (← read).contains fvarId do
|
||||
throwError "failed to attach join point to let-block, it contains a out of scope join point"
|
||||
return e
|
||||
else if let some casesInfo ← isCasesApp? e then
|
||||
visitCases casesInfo e
|
||||
|
|
|
|||
|
|
@ -167,9 +167,12 @@ def findDecl? [Monad m] [MonadLCtx m] (fvarId : FVarId) : m (Option LocalDecl) :
|
|||
/--
|
||||
Return true if `e` is of the form `_jp.<idx> ..` where `_jp.<idx>` is a join point.
|
||||
-/
|
||||
def isJump [Monad m] [MonadLCtx m] (e : Expr) : m Bool := do
|
||||
let .fvar fvarId := e.getAppFn | return false
|
||||
let some localDecl ← findDecl? fvarId | return false
|
||||
return localDecl.isJp
|
||||
def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do
|
||||
let .fvar fvarId := e.getAppFn | return none
|
||||
let some localDecl ← findDecl? fvarId | return none
|
||||
if localDecl.isJp then
|
||||
return some fvarId
|
||||
else
|
||||
return none
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue