feat: add join point detector
This commit is contained in:
parent
ea35f6e091
commit
70ef3875d1
4 changed files with 133 additions and 65 deletions
|
|
@ -166,22 +166,23 @@ and returning the body of the final one.
|
|||
class VisitLet (m : Type → Type) where
|
||||
/--
|
||||
Move through consecutive top level let binders in the first argument,
|
||||
applying the function in the second argument to the values before the
|
||||
the local declarations for the binders are created. The final return
|
||||
value is the body of the last let binder in the chain.
|
||||
applying the function in the second argument to the binder name
|
||||
and the values before the the local declarations for the binders are
|
||||
created. The final return value is the body of the last let binder in
|
||||
the chain.
|
||||
-/
|
||||
visitLet : Expr → (Expr → m Expr) → m Expr
|
||||
visitLet : Expr → (Name → Expr → m Expr) → m Expr
|
||||
|
||||
export VisitLet (visitLet)
|
||||
|
||||
def visitLetImp (e : Expr) (f : Expr → CompilerM Expr) : CompilerM Expr :=
|
||||
def visitLetImp (e : Expr) (f : Name → Expr → CompilerM Expr) : CompilerM Expr :=
|
||||
go e #[]
|
||||
where
|
||||
go (e : Expr) (fvars : Array Expr) : CompilerM Expr := do
|
||||
if let .letE binderName type value body nonDep := e then
|
||||
let type := type.instantiateRev fvars
|
||||
let value := value.instantiateRev fvars
|
||||
let value ← f value
|
||||
let value ← f binderName value
|
||||
let fvar ← mkLetDecl binderName type value nonDep
|
||||
go body (fvars.push fvar)
|
||||
else
|
||||
|
|
@ -191,7 +192,7 @@ instance : VisitLet CompilerM where
|
|||
visitLet := visitLetImp
|
||||
|
||||
instance [VisitLet m] : VisitLet (ReaderT ρ m) where
|
||||
visitLet e f ctx := visitLet e (f · ctx)
|
||||
visitLet e f ctx := visitLet e (f · · ctx)
|
||||
|
||||
instance [VisitLet m] : VisitLet (StateRefT' ω σ m) := inferInstanceAs (VisitLet (ReaderT _ _))
|
||||
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ This function ensures that inside the given declaration both of these
|
|||
conditions are satisfied and throws an exception otherwise.
|
||||
-/
|
||||
def Decl.checkJoinPoints (decl : Decl) : CompilerM Unit :=
|
||||
JoinPointChecker.checkJoinPoints decl.value
|
||||
JoinPoints.JoinPointChecker.checkJoinPoints decl.value
|
||||
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -7,34 +7,125 @@ import Lean.Compiler.CompilerM
|
|||
|
||||
namespace Lean.Compiler
|
||||
|
||||
namespace JoinPointChecker
|
||||
|
||||
def jpArity (jp : LocalDecl) : Nat :=
|
||||
getLambdaArity jp.value
|
||||
|
||||
|
||||
namespace JoinPoints
|
||||
|
||||
section Visitors
|
||||
|
||||
variable {m : Type → Type} [Monad m] [MonadLiftT CompilerM m] [VisitLet m] [MonadFunctorT CompilerM m]
|
||||
|
||||
partial def forEachFVar (e : Expr) (f : FVarId → m Unit) : m Unit := do
|
||||
let e := e.consumeMData
|
||||
match e with
|
||||
| .proj _ _ struct => forEachFVar struct f
|
||||
| .lam .. =>
|
||||
withNewScope do
|
||||
let (_, body) ← visitLambda e
|
||||
forEachFVar body f
|
||||
| .letE .. =>
|
||||
withNewScope do
|
||||
let body ← visitLet e (fun _ e => do forEachFVar e f; pure e)
|
||||
forEachFVar body f
|
||||
| .app fn arg =>
|
||||
forEachFVar fn f
|
||||
forEachFVar arg f
|
||||
| .fvar fvarId => f fvarId
|
||||
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
|
||||
mutual
|
||||
|
||||
variable (tailAppFvarVisitor : FVarId → Expr → m Unit) (valueValidator : Expr → m Unit) (letValueVisitor : Name → Expr → m Expr)
|
||||
|
||||
private partial def visitLambda (e : Expr) : m Unit := do
|
||||
withNewScope do
|
||||
let (_, body) ← Compiler.visitLambda e
|
||||
visitTails body
|
||||
|
||||
private partial def visitTails (e : Expr) : m Unit := do
|
||||
let e := e.consumeMData
|
||||
match e with
|
||||
| .letE .. =>
|
||||
withNewScope do
|
||||
let body ← visitLet e letValueVisitor
|
||||
visitTails body
|
||||
| .app (.fvar fvarId) arg =>
|
||||
tailAppFvarVisitor fvarId e
|
||||
valueValidator arg
|
||||
| .app .. =>
|
||||
if let some casesInfo ← (isCasesApp? e : CompilerM (Option CasesInfo)) then
|
||||
withNewScope do
|
||||
let (motive, discrs, arms) ← visitMatch e casesInfo
|
||||
valueValidator motive
|
||||
discrs.forM valueValidator
|
||||
arms.forM visitTails
|
||||
else
|
||||
valueValidator e
|
||||
| .proj .. | .lam .. => valueValidator e
|
||||
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
|
||||
end
|
||||
|
||||
end Visitors
|
||||
|
||||
namespace JoinPointFinder
|
||||
|
||||
abbrev M := StateRefT (Std.HashMap Name Nat) CompilerM
|
||||
|
||||
private partial def removeCandidatesContainedIn (e : Expr) : M Unit := do
|
||||
let remover := fun fvarId => do
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
modify (fun jps? => jps?.erase decl.userName)
|
||||
forEachFVar e remover
|
||||
|
||||
/--
|
||||
Return a set of let declaration names inside of `e` that qualify as a join
|
||||
point that is:
|
||||
1. Are always used in tail position
|
||||
2. Are always fully applied
|
||||
|
||||
Since declaration names are unique inside of LCNF the let bindings and
|
||||
call sites can be uniquely identified by this.
|
||||
-/
|
||||
partial def findJoinPoints (e : Expr) : CompilerM (Array Name) := do
|
||||
let (_, names) ← visitLambda goTailApp removeCandidatesContainedIn goLetValue e |>.run .empty
|
||||
return names.toArray.map Prod.fst
|
||||
where
|
||||
goLetValue (userName : Name) (value : Expr) : M Expr := do
|
||||
if let .lam .. := value then
|
||||
withNewScope do
|
||||
let (vars, body) ← Compiler.visitLambda value
|
||||
modify (fun jps? => jps?.insert userName vars.size)
|
||||
visitTails goTailApp removeCandidatesContainedIn goLetValue body
|
||||
else
|
||||
visitTails goTailApp removeCandidatesContainedIn goLetValue value
|
||||
return value
|
||||
|
||||
goTailApp (fvarId : FVarId) (e : Expr) : M Unit := do
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if let some jpArity := (←get).find? decl.userName then
|
||||
let args := e.getAppNumArgs
|
||||
if args != jpArity then
|
||||
modify (fun jps? => jps?.erase decl.userName)
|
||||
|
||||
end JoinPointFinder
|
||||
|
||||
namespace JoinPointChecker
|
||||
|
||||
/--
|
||||
Throws an exception if `e` contains a join point.
|
||||
-/
|
||||
partial def containsNoJp (e : Expr) : CompilerM Unit := do
|
||||
match e.consumeMData with
|
||||
| .proj _ _ struct => containsNoJp struct
|
||||
| .lam .. =>
|
||||
withNewScope do
|
||||
let (_, b) ← visitLambda e
|
||||
containsNoJp b
|
||||
| .letE .. =>
|
||||
withNewScope do
|
||||
let body ← visitLet e (fun e => do containsNoJp e; pure e)
|
||||
containsNoJp body
|
||||
| .app fn arg =>
|
||||
containsNoJp fn
|
||||
containsNoJp arg
|
||||
| .fvar fvarId =>
|
||||
def containsNoJp (e : Expr) : CompilerM Unit := do
|
||||
trace[Compiler.step] s!"Checking whether {e} contains jp"
|
||||
let checker := fun fvarId => do
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.isJp then
|
||||
throwError s!"Join point {decl.userName} in forbidden position"
|
||||
| .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
forEachFVar e checker
|
||||
|
||||
/--
|
||||
Check whether all join points in `e` are in a valid position that is:
|
||||
|
|
@ -42,48 +133,24 @@ Check whether all join points in `e` are in a valid position that is:
|
|||
2. In tail position inside of `e`
|
||||
-/
|
||||
partial def checkJoinPoints (e : Expr) : CompilerM Unit := do
|
||||
goLambda e
|
||||
visitLambda goTailApp containsNoJp goLetValue e
|
||||
|
||||
where
|
||||
goLambda (e : Expr) : CompilerM Unit := do
|
||||
withNewScope do
|
||||
let (_, body) ← visitLambda e
|
||||
go body
|
||||
|
||||
goLetValue (value : Expr) : CompilerM Unit := do
|
||||
let value := value.consumeMData
|
||||
goLetValue (_userName : Name) (value : Expr) : CompilerM Expr := do
|
||||
match value with
|
||||
| .lam .. => goLambda value
|
||||
| .lam .. => visitLambda goTailApp containsNoJp goLetValue value
|
||||
| _ => containsNoJp value
|
||||
return value
|
||||
|
||||
go (e : Expr) : CompilerM Unit := do
|
||||
let e := e.consumeMData
|
||||
match e with
|
||||
| .letE .. =>
|
||||
withNewScope do
|
||||
let body ← visitLet e (fun value => do goLetValue value; pure value)
|
||||
go body
|
||||
| .app (.fvar fvarId) arg =>
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.isJp then
|
||||
let args := e.getAppNumArgs
|
||||
let jpArity := jpArity decl
|
||||
if args != jpArity then
|
||||
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
|
||||
-- Make sure no join point is inside the arguments since that would not be in tail position
|
||||
containsNoJp arg
|
||||
| .app .. =>
|
||||
if let some casesInfo ←isCasesApp? e then
|
||||
withNewScope do
|
||||
let (motive, discrs, arms) ← visitMatch e casesInfo
|
||||
containsNoJp motive
|
||||
discrs.forM containsNoJp
|
||||
arms.forM go
|
||||
else
|
||||
containsNoJp e
|
||||
| .proj .. | .lam .. => containsNoJp e
|
||||
| .fvar .. | .sort .. | .forallE .. | .const .. | .lit .. => return ()
|
||||
| .bvar .. | .mvar .. | .mdata .. => unreachable!
|
||||
goTailApp (fvarId : FVarId) (e : Expr) := do
|
||||
let some decl ← findDecl? fvarId | unreachable!
|
||||
if decl.isJp then
|
||||
let args := e.getAppNumArgs
|
||||
let jpArity := jpArity decl
|
||||
if args != jpArity then
|
||||
throwError s!"Join point {decl.userName} : {decl.type} is not fully applied in {e}"
|
||||
|
||||
end JoinPointChecker
|
||||
end JoinPoints
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ where
|
|||
match e with
|
||||
| .letE .. =>
|
||||
withNewScope do
|
||||
let body ← visitLet e fun value => do goValue value; return value
|
||||
let body ← visitLet e fun _ value => do goValue value; return value
|
||||
go body
|
||||
| e =>
|
||||
if let some casesInfo ← isCasesApp? e then
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue