fix: address code review for jp checker

This commit is contained in:
Henrik Böving 2022-08-14 11:13:22 +02:00 committed by Leonardo de Moura
parent ff9c9032b4
commit 8e29fa88eb
3 changed files with 47 additions and 59 deletions

View file

@ -41,6 +41,9 @@ def findDecl? (fvarId : FVarId) : CompilerM (Option LocalDecl) := do
let lctx := (← get).lctx
return lctx.find? fvarId
def _root_.Lean.LocalDecl.isJp (decl : LocalDecl) : Bool :=
decl.userName.getPrefix == `_jp
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
if e.isFVar then
return e

View file

@ -87,9 +87,8 @@ in order to allow the optimizer to turn them into efficient machine code.
This function ensures that inside the given declaration both of these
conditions are satisfied and throws an exception otherwise.
-/
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit := do
let tails ← JoinPointChecker.getTails decl.value
discard <| tails.mapM JoinPointChecker.checkTail
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit :=
JoinPointChecker.checkJoinPoints decl.value
/--
@ -103,7 +102,7 @@ These invariants are:
-/
def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do
Compiler.check decl.value cfg { lctx := {} }
discard <| checkJoinPoints decl |>.run {}
checkJoinPoints decl |>.run' {}
let valueType ← InferType.inferType decl.value { lctx := {} }
unless compatibleTypes decl.type valueType do
throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}"

View file

@ -23,85 +23,71 @@ partial def containsNoJp (e : Expr) : CompilerM Unit := do
match e.consumeMData with
| .proj _ _ struct => containsNoJp struct
| .lam .. =>
let (_, b) ← visitLambda e
containsNoJp b
withNewScope do
let (_, b) ← visitLambda e
containsNoJp b
| .letE .. =>
let body ← visitLet e (fun e => do containsNoJp e; pure e)
containsNoJp body
withNewScope do
let body ← visitLet e (fun e => do containsNoJp e; pure e)
containsNoJp body
| .app fn arg =>
containsNoJp fn
containsNoJp arg
| .fvar fvarId =>
let some decl ← findDecl? fvarId | unreachable!
if decl.userName.getPrefix ==`_jp then
if decl.isJp then
throwError s!"Join point {decl.userName} in forbidden position"
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
/--
Obtain all the tail `Expr`s of `e`. Already checking whether non
tail values contain a join point and throwing an exception if they do.
Check whether all join points in `e` are in a valid position that is:
1. Fully applied
2. In tail position inside of `e`
-/
partial def getTails (e : Expr) : CompilerM (Array Expr) := do
let (_, body) ← visitLambda e
let (_, tails) ← go body |>.run #[]
return tails
partial def checkJoinPoints (e : Expr) : CompilerM Unit := do
goLambda e
where
goLetValue (value : Expr) : StateRefT (Array Expr) CompilerM Unit := do
goLambda (e : Expr) : CompilerM Unit := do
withNewScope do
let (_, body) ← visitLambda e
go body
goLetValue (value : Expr) : CompilerM Unit := do
let value := value.consumeMData
match value with
| .lam .. =>
let (_, body) ← visitLambda value
go body
| _ =>
containsNoJp value
return ()
| .lam .. => goLambda value
| _ => containsNoJp value
go (e : Expr) : StateRefT (Array Expr) CompilerM Unit := do
go (e : Expr) : CompilerM Unit := do
let e := e.consumeMData
match e with
| .letE .. =>
let body ← visitLet e (fun value => do goLetValue value; pure value)
go body
withNewScope do
let body ← visitLet e (fun value => do goLetValue value; pure value)
go body
| .app (.fvar fvarId) arg =>
let some decl ← findDecl? fvarId | unreachable!
if decl.isJp then
let args := e.getAppNumArgs
let jpArity := jpArity decl
if args != jpArity then
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
-- Make sure no join point is inside the arguments since that would not be in tail position
containsNoJp arg
| .app .. =>
if let some casesInfo ←isCasesApp? e then
let (motive, discrs, arms) ← visitMatch e casesInfo
containsNoJp motive
discard <| discrs.mapM (liftM ∘ containsNoJp)
discard <| arms.mapM go
withNewScope do
let (motive, discrs, arms) ← visitMatch e casesInfo
containsNoJp motive
discrs.forM containsNoJp
arms.forM go
else
let tails ← get
set <| tails.push e
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. | .proj .. | .lam .. =>
let tails ← get
set <| tails.push e
containsNoJp e
| .proj .. | .lam .. => containsNoJp e
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
/--
Checks that a tail is valid, that is either a fully applied join point
or doesn't contain a join point at all.
-/
def checkTail (e : Expr) : CompilerM Unit := do
match e with
| .app (.fvar fvarId) arg =>
let some decl ← findDecl? fvarId | unreachable!
if decl.userName.getPrefix == `_jp then
let args := e.getAppNumArgs
let jpArity := jpArity decl
if args != jpArity then
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
-- Make sure no join point is inside the arguments since that would not be in tail position
containsNoJp arg
/-
If these are in tail position they may not contain any join point
since that would mean it is not in tail position.
Note: the special case where the app function is one is handled above
-/
| .app .. | .proj .. | .lam .. => containsNoJp e
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .letE .. | .mdata .. | .bvar .. | .mvar .. => unreachable!
end JoinPointChecker
end Lean.Compiler