feat: add join point detector

This commit is contained in:
Henrik Böving 2022-08-17 19:15:38 +02:00 committed by Leonardo de Moura
parent ea35f6e091
commit 70ef3875d1
4 changed files with 133 additions and 65 deletions

View file

@ -166,22 +166,23 @@ and returning the body of the final one.
class VisitLet (m : Type → Type) where
/--
Move through consecutive top level let binders in the first argument,
applying the function in the second argument to the values before the
the local declarations for the binders are created. The final return
value is the body of the last let binder in the chain.
applying the function in the second argument to the binder name
and the values before the the local declarations for the binders are
created. The final return value is the body of the last let binder in
the chain.
-/
visitLet : Expr → (Expr → m Expr) → m Expr
visitLet : Expr → (Name → Expr → m Expr) → m Expr
export VisitLet (visitLet)
def visitLetImp (e : Expr) (f : Expr → CompilerM Expr) : CompilerM Expr :=
def visitLetImp (e : Expr) (f : Name → Expr → CompilerM Expr) : CompilerM Expr :=
go e #[]
where
go (e : Expr) (fvars : Array Expr) : CompilerM Expr := do
if let .letE binderName type value body nonDep := e then
let type := type.instantiateRev fvars
let value := value.instantiateRev fvars
let value ← f value
let value ← f binderName value
let fvar ← mkLetDecl binderName type value nonDep
go body (fvars.push fvar)
else
@ -191,7 +192,7 @@ instance : VisitLet CompilerM where
visitLet := visitLetImp
instance [VisitLet m] : VisitLet (ReaderT ρ m) where
visitLet e f ctx := visitLet e (f · ctx)
visitLet e f ctx := visitLet e (f · · ctx)
instance [VisitLet m] : VisitLet (StateRefT' ω σ m) := inferInstanceAs (VisitLet (ReaderT _ _))

View file

@ -133,7 +133,7 @@ This function ensures that inside the given declaration both of these
conditions are satisfied and throws an exception otherwise.
-/
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit :=
JoinPointChecker.checkJoinPoints decl.value
JoinPoints.JoinPointChecker.checkJoinPoints decl.value
/--

View file

@ -7,34 +7,125 @@ import Lean.Compiler.CompilerM
namespace Lean.Compiler
namespace JoinPointChecker
def jpArity (jp : LocalDecl) : Nat :=
getLambdaArity jp.value
namespace JoinPoints
section Visitors
variable {m : Type → Type} [Monad m] [MonadLiftT CompilerM m] [VisitLet m] [MonadFunctorT CompilerM m]
partial def forEachFVar (e : Expr) (f : FVarId → m Unit) : m Unit := do
let e := e.consumeMData
match e with
| .proj _ _ struct => forEachFVar struct f
| .lam .. =>
withNewScope do
let (_, body) ← visitLambda e
forEachFVar body f
| .letE .. =>
withNewScope do
let body ← visitLet e (fun _ e => do forEachFVar e f; pure e)
forEachFVar body f
| .app fn arg =>
forEachFVar fn f
forEachFVar arg f
| .fvar fvarId => f fvarId
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
mutual
variable (tailAppFvarVisitor : FVarId → Expr → m Unit) (valueValidator : Expr → m Unit) (letValueVisitor : Name → Expr → m Expr)
private partial def visitLambda (e : Expr) : m Unit := do
withNewScope do
let (_, body) ← Compiler.visitLambda e
visitTails body
private partial def visitTails (e : Expr) : m Unit := do
let e := e.consumeMData
match e with
| .letE .. =>
withNewScope do
let body ← visitLet e letValueVisitor
visitTails body
| .app (.fvar fvarId) arg =>
tailAppFvarVisitor fvarId e
valueValidator arg
| .app .. =>
if let some casesInfo ← (isCasesApp? e : CompilerM (Option CasesInfo)) then
withNewScope do
let (motive, discrs, arms) ← visitMatch e casesInfo
valueValidator motive
discrs.forM valueValidator
arms.forM visitTails
else
valueValidator e
| .proj .. | .lam .. => valueValidator e
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
end
end Visitors
namespace JoinPointFinder
abbrev M := StateRefT (Std.HashMap Name Nat) CompilerM
private partial def removeCandidatesContainedIn (e : Expr) : M Unit := do
let remover := fun fvarId => do
let some decl ← findDecl? fvarId | unreachable!
modify (fun jps? => jps?.erase decl.userName)
forEachFVar e remover
/--
Return a set of let declaration names inside of `e` that qualify as a join
point that is:
1. Are always used in tail position
2. Are always fully applied
Since declaration names are unique inside of LCNF the let bindings and
call sites can be uniquely identified by this.
-/
partial def findJoinPoints (e : Expr) : CompilerM (Array Name) := do
let (_, names) ← visitLambda goTailApp removeCandidatesContainedIn goLetValue e |>.run .empty
return names.toArray.map Prod.fst
where
goLetValue (userName : Name) (value : Expr) : M Expr := do
if let .lam .. := value then
withNewScope do
let (vars, body) ← Compiler.visitLambda value
modify (fun jps? => jps?.insert userName vars.size)
visitTails goTailApp removeCandidatesContainedIn goLetValue body
else
visitTails goTailApp removeCandidatesContainedIn goLetValue value
return value
goTailApp (fvarId : FVarId) (e : Expr) : M Unit := do
let some decl ← findDecl? fvarId | unreachable!
if let some jpArity := (←get).find? decl.userName then
let args := e.getAppNumArgs
if args != jpArity then
modify (fun jps? => jps?.erase decl.userName)
end JoinPointFinder
namespace JoinPointChecker
/--
Throws an exception if `e` contains a join point.
-/
partial def containsNoJp (e : Expr) : CompilerM Unit := do
match e.consumeMData with
| .proj _ _ struct => containsNoJp struct
| .lam .. =>
withNewScope do
let (_, b) ← visitLambda e
containsNoJp b
| .letE .. =>
withNewScope do
let body ← visitLet e (fun e => do containsNoJp e; pure e)
containsNoJp body
| .app fn arg =>
containsNoJp fn
containsNoJp arg
| .fvar fvarId =>
def containsNoJp (e : Expr) : CompilerM Unit := do
trace[Compiler.step] s!"Checking whether {e} contains jp"
let checker := fun fvarId => do
let some decl ← findDecl? fvarId | unreachable!
if decl.isJp then
throwError s!"Join point {decl.userName} in forbidden position"
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
forEachFVar e checker
/--
Check whether all join points in `e` are in a valid position that is:
@ -42,48 +133,24 @@ Check whether all join points in `e` are in a valid position that is:
2. In tail position inside of `e`
-/
partial def checkJoinPoints (e : Expr) : CompilerM Unit := do
goLambda e
visitLambda goTailApp containsNoJp goLetValue e
where
goLambda (e : Expr) : CompilerM Unit := do
withNewScope do
let (_, body) ← visitLambda e
go body
goLetValue (value : Expr) : CompilerM Unit := do
let value := value.consumeMData
goLetValue (_userName : Name) (value : Expr) : CompilerM Expr := do
match value with
| .lam .. => goLambda value
| .lam .. => visitLambda goTailApp containsNoJp goLetValue value
| _ => containsNoJp value
return value
go (e : Expr) : CompilerM Unit := do
let e := e.consumeMData
match e with
| .letE .. =>
withNewScope do
let body ← visitLet e (fun value => do goLetValue value; pure value)
go body
| .app (.fvar fvarId) arg =>
let some decl ← findDecl? fvarId | unreachable!
if decl.isJp then
let args := e.getAppNumArgs
let jpArity := jpArity decl
if args != jpArity then
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
-- Make sure no join point is inside the arguments since that would not be in tail position
containsNoJp arg
| .app .. =>
if let some casesInfo ←isCasesApp? e then
withNewScope do
let (motive, discrs, arms) ← visitMatch e casesInfo
containsNoJp motive
discrs.forM containsNoJp
arms.forM go
else
containsNoJp e
| .proj .. | .lam .. => containsNoJp e
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
| .bvar .. | .mvar .. | .mdata .. => unreachable!
goTailApp (fvarId : FVarId) (e : Expr) := do
let some decl ← findDecl? fvarId | unreachable!
if decl.isJp then
let args := e.getAppNumArgs
let jpArity := jpArity decl
if args != jpArity then
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
end JoinPointChecker
end JoinPoints
end Lean.Compiler

View file

@ -91,7 +91,7 @@ where
match e with
| .letE .. =>
withNewScope do
let body ← visitLet e fun value => do goValue value; return value
let body ← visitLet e fun _ value => do goValue value; return value
go body
| e =>
if let some casesInfo ← isCasesApp? e then