From 8e29fa88eb7546f37359275ccb87e271f39443ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Sun, 14 Aug 2022 11:13:22 +0200 Subject: [PATCH] fix: address code review for jp checker --- src/Lean/Compiler/CompilerM.lean | 3 + src/Lean/Compiler/Decl.lean | 7 +-- src/Lean/Compiler/JoinPoints.lean | 96 +++++++++++++------------------ 3 files changed, 47 insertions(+), 59 deletions(-) diff --git a/src/Lean/Compiler/CompilerM.lean b/src/Lean/Compiler/CompilerM.lean index 8e7fe02e92..4a2f22be92 100644 --- a/src/Lean/Compiler/CompilerM.lean +++ b/src/Lean/Compiler/CompilerM.lean @@ -41,6 +41,9 @@ def findDecl? (fvarId : FVarId) : CompilerM (Option LocalDecl) := do let lctx := (← get).lctx return lctx.find? fvarId +def _root_.Lean.LocalDecl.isJp (decl : LocalDecl) : Bool := + decl.userName.getPrefix == `_jp + def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do if e.isFVar then return e diff --git a/src/Lean/Compiler/Decl.lean b/src/Lean/Compiler/Decl.lean index 870d4eacb8..0bac380214 100644 --- a/src/Lean/Compiler/Decl.lean +++ b/src/Lean/Compiler/Decl.lean @@ -87,9 +87,8 @@ in order to allow the optimizer to turn them into efficient machine code. This function ensures that inside the given declaration both of these conditions are satisfied and throws an exception otherwise. -/ -def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit := do - let tails ← JoinPointChecker.getTails decl.value - discard <| tails.mapM JoinPointChecker.checkTail +def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit := + JoinPointChecker.checkJoinPoints decl.value /-- @@ -103,7 +102,7 @@ These invariants are: -/ def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do Compiler.check decl.value cfg { lctx := {} } - discard <| checkJoinPoints decl |>.run {} + checkJoinPoints decl |>.run' {} let valueType ← InferType.inferType decl.value { lctx := {} } unless compatibleTypes decl.type valueType do throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}" diff --git a/src/Lean/Compiler/JoinPoints.lean b/src/Lean/Compiler/JoinPoints.lean index 93acea36cb..7437ee9604 100644 --- a/src/Lean/Compiler/JoinPoints.lean +++ b/src/Lean/Compiler/JoinPoints.lean @@ -23,85 +23,71 @@ partial def containsNoJp (e : Expr) : CompilerM Unit := do match e.consumeMData with | .proj _ _ struct => containsNoJp struct | .lam .. => - let (_, b) ← visitLambda e - containsNoJp b + withNewScope do + let (_, b) ← visitLambda e + containsNoJp b | .letE .. => - let body ← visitLet e (fun e => do containsNoJp e; pure e) - containsNoJp body + withNewScope do + let body ← visitLet e (fun e => do containsNoJp e; pure e) + containsNoJp body | .app fn arg => containsNoJp fn containsNoJp arg | .fvar fvarId => let some decl ← findDecl? fvarId | unreachable! - if decl.userName.getPrefix ==`_jp then + if decl.isJp then throwError s!"Join point {decl.userName} in forbidden position" | .sort .. | .forallE .. | .const .. | .lit .. => return () | .bvar .. | .mvar .. | .mdata .. => unreachable! /-- -Obtain all the tail `Expr`s of `e`. Already checking whether non -tail values contain a join point and throwing an exception if they do. +Check whether all join points in `e` are in a valid position that is: +1. Fully applied +2. In tail position inside of `e` -/ -partial def getTails (e : Expr) : CompilerM (Array Expr) := do - let (_, body) ← visitLambda e - let (_, tails) ← go body |>.run #[] - return tails +partial def checkJoinPoints (e : Expr) : CompilerM Unit := do + goLambda e where - goLetValue (value : Expr) : StateRefT (Array Expr) CompilerM Unit := do + goLambda (e : Expr) : CompilerM Unit := do + withNewScope do + let (_, body) ← visitLambda e + go body + + goLetValue (value : Expr) : CompilerM Unit := do let value := value.consumeMData match value with - | .lam .. => - let (_, body) ← visitLambda value - go body - | _ => - containsNoJp value - return () + | .lam .. => goLambda value + | _ => containsNoJp value - go (e : Expr) : StateRefT (Array Expr) CompilerM Unit := do + go (e : Expr) : CompilerM Unit := do let e := e.consumeMData match e with | .letE .. => - let body ← visitLet e (fun value => do goLetValue value; pure value) - go body + withNewScope do + let body ← visitLet e (fun value => do goLetValue value; pure value) + go body + | .app (.fvar fvarId) arg => + let some decl ← findDecl? fvarId | unreachable! + if decl.isJp then + let args := e.getAppNumArgs + let jpArity := jpArity decl + if args != jpArity then + throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}" + -- Make sure no join point is inside the arguments since that would not be in tail position + containsNoJp arg | .app .. => if let some casesInfo ←isCasesApp? e then - let (motive, discrs, arms) ← visitMatch e casesInfo - containsNoJp motive - discard <| discrs.mapM (liftM ∘ containsNoJp) - discard <| arms.mapM go + withNewScope do + let (motive, discrs, arms) ← visitMatch e casesInfo + containsNoJp motive + discrs.forM containsNoJp + arms.forM go else - let tails ← get - set <| tails.push e - | .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. | .proj .. | .lam .. => - let tails ← get - set <| tails.push e + containsNoJp e + | .proj .. | .lam .. => containsNoJp e + | .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return () | .bvar .. | .mvar .. | .mdata .. => unreachable! -/-- -Checks that a tail is valid, that is either a fully applied join point -or doesn't contain a join point at all. --/ -def checkTail (e : Expr) : CompilerM Unit := do - match e with - | .app (.fvar fvarId) arg => - let some decl ← findDecl? fvarId | unreachable! - if decl.userName.getPrefix == `_jp then - let args := e.getAppNumArgs - let jpArity := jpArity decl - if args != jpArity then - throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}" - -- Make sure no join point is inside the arguments since that would not be in tail position - containsNoJp arg - /- - If these are in tail position they may not contain any join point - since that would mean it is not in tail position. - Note: the special case where the app function is one is handled above - -/ - | .app .. | .proj .. | .lam .. => containsNoJp e - | .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return () - | .letE .. | .mdata .. | .bvar .. | .mvar .. => unreachable! - - end JoinPointChecker end Lean.Compiler