fix: address code review for jp checker
This commit is contained in:
parent
ff9c9032b4
commit
8e29fa88eb
3 changed files with 47 additions and 59 deletions
|
|
@ -41,6 +41,9 @@ def findDecl? (fvarId : FVarId) : CompilerM (Option LocalDecl) := do
|
|||
let lctx := (← get).lctx
|
||||
return lctx.find? fvarId
|
||||
|
||||
def _root_.Lean.LocalDecl.isJp (decl : LocalDecl) : Bool :=
|
||||
decl.userName.getPrefix == `_jp
|
||||
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
|
||||
if e.isFVar then
|
||||
return e
|
||||
|
|
|
|||
|
|
@ -87,9 +87,8 @@ in order to allow the optimizer to turn them into efficient machine code.
|
|||
This function ensures that inside the given declaration both of these
|
||||
conditions are satisfied and throws an exception otherwise.
|
||||
-/
|
||||
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit := do
|
||||
let tails ← JoinPointChecker.getTails decl.value
|
||||
discard <| tails.mapM JoinPointChecker.checkTail
|
||||
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit :=
|
||||
JoinPointChecker.checkJoinPoints decl.value
|
||||
|
||||
|
||||
/--
|
||||
|
|
@ -103,7 +102,7 @@ These invariants are:
|
|||
-/
|
||||
def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do
|
||||
Compiler.check decl.value cfg { lctx := {} }
|
||||
discard <| checkJoinPoints decl |>.run {}
|
||||
checkJoinPoints decl |>.run' {}
|
||||
let valueType ← InferType.inferType decl.value { lctx := {} }
|
||||
unless compatibleTypes decl.type valueType do
|
||||
throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}"
|
||||
|
|
|
|||
|
|
@ -23,85 +23,71 @@ partial def containsNoJp (e : Expr) : CompilerM Unit := do
|
|||
match e.consumeMData with
|
||||
| .proj _ _ struct => containsNoJp struct
|
||||
| .lam .. =>
|
||||
let (_, b) ← visitLambda e
|
||||
containsNoJp b
|
||||
withNewScope do
|
||||
let (_, b) ← visitLambda e
|
||||
containsNoJp b
|
||||
| .letE .. =>
|
||||
let body ← visitLet e (fun e => do containsNoJp e; pure e)
|
||||
containsNoJp body
|
||||
withNewScope do
|
||||
let body ← visitLet e (fun e => do containsNoJp e; pure e)
|
||||
containsNoJp body
|
||||
| .app fn arg =>
|
||||
containsNoJp fn
|
||||
containsNoJp arg
|
||||
| .fvar fvarId =>
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.userName.getPrefix ==`_jp then
|
||||
if decl.isJp then
|
||||
throwError s!"Join point {decl.userName} in forbidden position"
|
||||
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
|
||||
/--
|
||||
Obtain all the tail `Expr`s of `e`. Already checking whether non
|
||||
tail values contain a join point and throwing an exception if they do.
|
||||
Check whether all join points in `e` are in a valid position that is:
|
||||
1. Fully applied
|
||||
2. In tail position inside of `e`
|
||||
-/
|
||||
partial def getTails (e : Expr) : CompilerM (Array Expr) := do
|
||||
let (_, body) ← visitLambda e
|
||||
let (_, tails) ← go body |>.run #[]
|
||||
return tails
|
||||
partial def checkJoinPoints (e : Expr) : CompilerM Unit := do
|
||||
goLambda e
|
||||
where
|
||||
goLetValue (value : Expr) : StateRefT (Array Expr) CompilerM Unit := do
|
||||
goLambda (e : Expr) : CompilerM Unit := do
|
||||
withNewScope do
|
||||
let (_, body) ← visitLambda e
|
||||
go body
|
||||
|
||||
goLetValue (value : Expr) : CompilerM Unit := do
|
||||
let value := value.consumeMData
|
||||
match value with
|
||||
| .lam .. =>
|
||||
let (_, body) ← visitLambda value
|
||||
go body
|
||||
| _ =>
|
||||
containsNoJp value
|
||||
return ()
|
||||
| .lam .. => goLambda value
|
||||
| _ => containsNoJp value
|
||||
|
||||
go (e : Expr) : StateRefT (Array Expr) CompilerM Unit := do
|
||||
go (e : Expr) : CompilerM Unit := do
|
||||
let e := e.consumeMData
|
||||
match e with
|
||||
| .letE .. =>
|
||||
let body ← visitLet e (fun value => do goLetValue value; pure value)
|
||||
go body
|
||||
withNewScope do
|
||||
let body ← visitLet e (fun value => do goLetValue value; pure value)
|
||||
go body
|
||||
| .app (.fvar fvarId) arg =>
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.isJp then
|
||||
let args := e.getAppNumArgs
|
||||
let jpArity := jpArity decl
|
||||
if args != jpArity then
|
||||
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
|
||||
-- Make sure no join point is inside the arguments since that would not be in tail position
|
||||
containsNoJp arg
|
||||
| .app .. =>
|
||||
if let some casesInfo ←isCasesApp? e then
|
||||
let (motive, discrs, arms) ← visitMatch e casesInfo
|
||||
containsNoJp motive
|
||||
discard <| discrs.mapM (liftM ∘ containsNoJp)
|
||||
discard <| arms.mapM go
|
||||
withNewScope do
|
||||
let (motive, discrs, arms) ← visitMatch e casesInfo
|
||||
containsNoJp motive
|
||||
discrs.forM containsNoJp
|
||||
arms.forM go
|
||||
else
|
||||
let tails ← get
|
||||
set <| tails.push e
|
||||
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. | .proj .. | .lam .. =>
|
||||
let tails ← get
|
||||
set <| tails.push e
|
||||
containsNoJp e
|
||||
| .proj .. | .lam .. => containsNoJp e
|
||||
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
|
||||
/--
|
||||
Checks that a tail is valid, that is either a fully applied join point
|
||||
or doesn't contain a join point at all.
|
||||
-/
|
||||
def checkTail (e : Expr) : CompilerM Unit := do
|
||||
match e with
|
||||
| .app (.fvar fvarId) arg =>
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.userName.getPrefix == `_jp then
|
||||
let args := e.getAppNumArgs
|
||||
let jpArity := jpArity decl
|
||||
if args != jpArity then
|
||||
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
|
||||
-- Make sure no join point is inside the arguments since that would not be in tail position
|
||||
containsNoJp arg
|
||||
/-
|
||||
If these are in tail position they may not contain any join point
|
||||
since that would mean it is not in tail position.
|
||||
Note: the special case where the app function is one is handled above
|
||||
-/
|
||||
| .app .. | .proj .. | .lam .. => containsNoJp e
|
||||
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .letE .. | .mdata .. | .bvar .. | .mvar .. => unreachable!
|
||||
|
||||
|
||||
end JoinPointChecker
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue