From 70ef3875d10110864b8e89fa2fdb2fe627529cf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 17 Aug 2022 19:15:38 +0200 Subject: [PATCH] feat: add join point detector --- src/Lean/Compiler/CompilerM.lean | 15 +-- src/Lean/Compiler/Decl.lean | 2 +- src/Lean/Compiler/JoinPoints.lean | 179 ++++++++++++++++++++---------- src/Lean/Compiler/Simp.lean | 2 +- 4 files changed, 133 insertions(+), 65 deletions(-) diff --git a/src/Lean/Compiler/CompilerM.lean b/src/Lean/Compiler/CompilerM.lean index 0e316ff0da..ff98843a3d 100644 --- a/src/Lean/Compiler/CompilerM.lean +++ b/src/Lean/Compiler/CompilerM.lean @@ -166,22 +166,23 @@ and returning the body of the final one. class VisitLet (m : Type → Type) where /-- Move through consecutive top level let binders in the first argument, - applying the function in the second argument to the values before the - the local declarations for the binders are created. The final return - value is the body of the last let binder in the chain. + applying the function in the second argument to the binder name + and the values before the the local declarations for the binders are + created. The final return value is the body of the last let binder in + the chain. -/ - visitLet : Expr → (Expr → m Expr) → m Expr + visitLet : Expr → (Name → Expr → m Expr) → m Expr export VisitLet (visitLet) -def visitLetImp (e : Expr) (f : Expr → CompilerM Expr) : CompilerM Expr := +def visitLetImp (e : Expr) (f : Name → Expr → CompilerM Expr) : CompilerM Expr := go e #[] where go (e : Expr) (fvars : Array Expr) : CompilerM Expr := do if let .letE binderName type value body nonDep := e then let type := type.instantiateRev fvars let value := value.instantiateRev fvars - let value ← f value + let value ← f binderName value let fvar ← mkLetDecl binderName type value nonDep go body (fvars.push fvar) else @@ -191,7 +192,7 @@ instance : VisitLet CompilerM where visitLet := visitLetImp instance [VisitLet m] : VisitLet (ReaderT ρ m) where - visitLet e f ctx := visitLet e (f · ctx) + visitLet e f ctx := visitLet e (f · · ctx) instance [VisitLet m] : VisitLet (StateRefT' ω σ m) := inferInstanceAs (VisitLet (ReaderT _ _)) diff --git a/src/Lean/Compiler/Decl.lean b/src/Lean/Compiler/Decl.lean index 9f746660cc..a1e592878a 100644 --- a/src/Lean/Compiler/Decl.lean +++ b/src/Lean/Compiler/Decl.lean @@ -133,7 +133,7 @@ This function ensures that inside the given declaration both of these conditions are satisfied and throws an exception otherwise. -/ def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit := - JoinPointChecker.checkJoinPoints decl.value + JoinPoints.JoinPointChecker.checkJoinPoints decl.value /-- diff --git a/src/Lean/Compiler/JoinPoints.lean b/src/Lean/Compiler/JoinPoints.lean index d4bb102cc1..3baaad9b32 100644 --- a/src/Lean/Compiler/JoinPoints.lean +++ b/src/Lean/Compiler/JoinPoints.lean @@ -7,34 +7,125 @@ import Lean.Compiler.CompilerM namespace Lean.Compiler -namespace JoinPointChecker - def jpArity (jp : LocalDecl) : Nat := getLambdaArity jp.value + +namespace JoinPoints + +section Visitors + +variable {m : Type → Type} [Monad m] [MonadLiftT CompilerM m] [VisitLet m] [MonadFunctorT CompilerM m] + +partial def forEachFVar (e : Expr) (f : FVarId → m Unit) : m Unit := do + let e := e.consumeMData + match e with + | .proj _ _ struct => forEachFVar struct f + | .lam .. => + withNewScope do + let (_, body) ← visitLambda e + forEachFVar body f + | .letE .. => + withNewScope do + let body ← visitLet e (fun _ e => do forEachFVar e f; pure e) + forEachFVar body f + | .app fn arg => + forEachFVar fn f + forEachFVar arg f + | .fvar fvarId => f fvarId + | .sort .. | .forallE .. | .const .. | .lit .. => return () + | .bvar .. | .mvar .. | .mdata .. => unreachable! + +mutual + +variable (tailAppFvarVisitor : FVarId → Expr → m Unit) (valueValidator : Expr → m Unit) (letValueVisitor : Name → Expr → m Expr) + +private partial def visitLambda (e : Expr) : m Unit := do + withNewScope do + let (_, body) ← Compiler.visitLambda e + visitTails body + +private partial def visitTails (e : Expr) : m Unit := do + let e := e.consumeMData + match e with + | .letE .. => + withNewScope do + let body ← visitLet e letValueVisitor + visitTails body + | .app (.fvar fvarId) arg => + tailAppFvarVisitor fvarId e + valueValidator arg + | .app .. => + if let some casesInfo ← (isCasesApp? e : CompilerM (Option CasesInfo)) then + withNewScope do + let (motive, discrs, arms) ← visitMatch e casesInfo + valueValidator motive + discrs.forM valueValidator + arms.forM visitTails + else + valueValidator e + | .proj .. | .lam .. => valueValidator e + | .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return () + | .bvar .. | .mvar .. | .mdata .. => unreachable! + +end + +end Visitors + +namespace JoinPointFinder + +abbrev M := StateRefT (Std.HashMap Name Nat) CompilerM + +private partial def removeCandidatesContainedIn (e : Expr) : M Unit := do + let remover := fun fvarId => do + let some decl ← findDecl? fvarId | unreachable! + modify (fun jps? => jps?.erase decl.userName) + forEachFVar e remover + +/-- +Return a set of let declaration names inside of `e` that qualify as a join +point that is: +1. Are always used in tail position +2. Are always fully applied + +Since declaration names are unique inside of LCNF the let bindings and +call sites can be uniquely identified by this. +-/ +partial def findJoinPoints (e : Expr) : CompilerM (Array Name) := do + let (_, names) ← visitLambda goTailApp removeCandidatesContainedIn goLetValue e |>.run .empty + return names.toArray.map Prod.fst +where + goLetValue (userName : Name) (value : Expr) : M Expr := do + if let .lam .. := value then + withNewScope do + let (vars, body) ← Compiler.visitLambda value + modify (fun jps? => jps?.insert userName vars.size) + visitTails goTailApp removeCandidatesContainedIn goLetValue body + else + visitTails goTailApp removeCandidatesContainedIn goLetValue value + return value + + goTailApp (fvarId : FVarId) (e : Expr) : M Unit := do + let some decl ← findDecl? fvarId | unreachable! + if let some jpArity := (←get).find? decl.userName then + let args := e.getAppNumArgs + if args != jpArity then + modify (fun jps? => jps?.erase decl.userName) + +end JoinPointFinder + +namespace JoinPointChecker + /-- Throws an exception if `e` contains a join point. -/ -partial def containsNoJp (e : Expr) : CompilerM Unit := do - match e.consumeMData with - | .proj _ _ struct => containsNoJp struct - | .lam .. => - withNewScope do - let (_, b) ← visitLambda e - containsNoJp b - | .letE .. => - withNewScope do - let body ← visitLet e (fun e => do containsNoJp e; pure e) - containsNoJp body - | .app fn arg => - containsNoJp fn - containsNoJp arg - | .fvar fvarId => +def containsNoJp (e : Expr) : CompilerM Unit := do + trace[Compiler.step] s!"Checking whether {e} contains jp" + let checker := fun fvarId => do let some decl ← findDecl? fvarId | unreachable! if decl.isJp then throwError s!"Join point {decl.userName} in forbidden position" - | .sort .. | .forallE .. | .const .. | .lit .. => return () - | .bvar .. | .mvar .. | .mdata .. => unreachable! + forEachFVar e checker /-- Check whether all join points in `e` are in a valid position that is: @@ -42,48 +133,24 @@ Check whether all join points in `e` are in a valid position that is: 2. In tail position inside of `e` -/ partial def checkJoinPoints (e : Expr) : CompilerM Unit := do - goLambda e + visitLambda goTailApp containsNoJp goLetValue e + where - goLambda (e : Expr) : CompilerM Unit := do - withNewScope do - let (_, body) ← visitLambda e - go body - - goLetValue (value : Expr) : CompilerM Unit := do - let value := value.consumeMData + goLetValue (_userName : Name) (value : Expr) : CompilerM Expr := do match value with - | .lam .. => goLambda value + | .lam .. => visitLambda goTailApp containsNoJp goLetValue value | _ => containsNoJp value + return value - go (e : Expr) : CompilerM Unit := do - let e := e.consumeMData - match e with - | .letE .. => - withNewScope do - let body ← visitLet e (fun value => do goLetValue value; pure value) - go body - | .app (.fvar fvarId) arg => - let some decl ← findDecl? fvarId | unreachable! - if decl.isJp then - let args := e.getAppNumArgs - let jpArity := jpArity decl - if args != jpArity then - throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}" - -- Make sure no join point is inside the arguments since that would not be in tail position - containsNoJp arg - | .app .. => - if let some casesInfo ←isCasesApp? e then - withNewScope do - let (motive, discrs, arms) ← visitMatch e casesInfo - containsNoJp motive - discrs.forM containsNoJp - arms.forM go - else - containsNoJp e - | .proj .. | .lam .. => containsNoJp e - | .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return () - | .bvar .. | .mvar .. | .mdata .. => unreachable! + goTailApp (fvarId : FVarId) (e : Expr) := do + let some decl ← findDecl? fvarId | unreachable! + if decl.isJp then + let args := e.getAppNumArgs + let jpArity := jpArity decl + if args != jpArity then + throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}" end JoinPointChecker +end JoinPoints end Lean.Compiler diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index bb60bc9080..9869c1c8a4 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -91,7 +91,7 @@ where match e with | .letE .. => withNewScope do - let body ← visitLet e fun value => do goValue value; return value + let body ← visitLet e fun _ value => do goValue value; return value go body | e => if let some casesInfo ← isCasesApp? e then