chore: address PR comments
This commit is contained in:
parent
7b3709e28a
commit
e15e6bfaee
3 changed files with 136 additions and 113 deletions
46
src/Lean/Compiler/LCNF/FVarUtil.lean
Normal file
46
src/Lean/Compiler/LCNF/FVarUtil.lean
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
/-
|
||||
Copyright (c) 2022 Henrik Böving. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
import Lean.Expr
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do
|
||||
match e with
|
||||
| .proj typ idx struct => return .proj typ idx (← mapFVarM f struct)
|
||||
| .app fn arg => return .app (← mapFVarM f fn) (← mapFVarM f arg)
|
||||
| .fvar fvarId => return .fvar (← f fvarId)
|
||||
| .lam arg ty body bi =>
|
||||
return .lam arg (← mapFVarM f ty) (← mapFVarM f body) bi
|
||||
| .forallE arg ty body bi =>
|
||||
return .forallE arg (←mapFVarM f ty) (← mapFVarM f body) bi
|
||||
| .letE var ty value body nonDep =>
|
||||
return .letE var (← mapFVarM f ty) (← mapFVarM f value) (← mapFVarM f body) nonDep
|
||||
| .bvar .. | .sort .. => return e
|
||||
| .mdata .. | .const .. | .lit .. => return e
|
||||
| .mvar .. => unreachable!
|
||||
|
||||
partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do
|
||||
match e with
|
||||
| .proj _ _ struct => forFVarM f struct
|
||||
| .app fn arg =>
|
||||
forFVarM f fn
|
||||
forFVarM f arg
|
||||
| .fvar fvarId => f fvarId
|
||||
| .lam _ ty body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f body
|
||||
| .forallE _ ty body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f body
|
||||
| .letE _ ty value body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f value
|
||||
forFVarM f body
|
||||
| .bvar .. | .sort .. => return
|
||||
| .mdata .. | .const .. | .lit .. => return
|
||||
| .mvar .. => unreachable!
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -6,89 +6,11 @@ Authors: Henrik Böving
|
|||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.LCNF.PullFunDecls
|
||||
import Lean.Compiler.LCNF.FVarUtil
|
||||
import Lean.Compiler.LCNF.ScopeM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
-- TODO: These can be used in a much more general context
|
||||
partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do
|
||||
match e with
|
||||
| .proj typ idx struct => return .proj typ idx (←mapFVarM f struct)
|
||||
| .app fn arg => return .app (←mapFVarM f fn) (←mapFVarM f arg)
|
||||
| .fvar fvarId => return .fvar (←f fvarId)
|
||||
| .lam arg ty body bi =>
|
||||
return .lam arg (←mapFVarM f ty) (←mapFVarM f body) bi
|
||||
| .forallE arg ty body bi =>
|
||||
return .forallE arg (←mapFVarM f ty) (←mapFVarM f body) bi
|
||||
| .letE var ty value body nonDep =>
|
||||
return .letE var (←mapFVarM f ty) (←mapFVarM f value) (←mapFVarM f body) nonDep
|
||||
| .bvar .. | .sort .. => return e
|
||||
| .mdata .. | .const .. | .lit .. => return e
|
||||
| .mvar .. => unreachable!
|
||||
|
||||
partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do
|
||||
match e with
|
||||
| .proj _ _ struct => forFVarM f struct
|
||||
| .app fn arg =>
|
||||
forFVarM f fn
|
||||
forFVarM f arg
|
||||
| .fvar fvarId => f fvarId
|
||||
| .lam _ ty body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f body
|
||||
| .forallE _ ty body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f body
|
||||
| .letE _ ty value body .. =>
|
||||
forFVarM f ty
|
||||
forFVarM f value
|
||||
forFVarM f body
|
||||
| .bvar .. | .sort .. => return
|
||||
| .mdata .. | .const .. | .lit .. => return
|
||||
| .mvar .. => unreachable!
|
||||
|
||||
/--
|
||||
A general abstraction for the idea of a scope in the compiler.
|
||||
-/
|
||||
abbrev ScopeM := StateRefT FVarIdSet CompilerM
|
||||
|
||||
namespace ScopeM
|
||||
|
||||
def getScope : ScopeM FVarIdSet := get
|
||||
def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope
|
||||
def clearScope : ScopeM Unit := setScope {}
|
||||
|
||||
/--
|
||||
Execute `x` but recover the previous scope after doing so.
|
||||
-/
|
||||
def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
|
||||
let scope ← getScope
|
||||
try x finally setScope scope
|
||||
|
||||
/--
|
||||
Clear the current scope for the monadic action `x`, afterwards continuing
|
||||
with the old one.
|
||||
-/
|
||||
def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
|
||||
withBackTrackingScope do
|
||||
clearScope
|
||||
x
|
||||
|
||||
/--
|
||||
Check whether `fvarId` is in the current scope, that is, was declared within
|
||||
the current `fun` declaration that is being processed.
|
||||
-/
|
||||
def isInScope (fvarId : FVarId) : ScopeM Bool := do
|
||||
let scope ← getScope
|
||||
return scope.contains fvarId
|
||||
|
||||
/--
|
||||
Add a new `FVarId` to the current scope.
|
||||
-/
|
||||
def addToScope (fvarId : FVarId) : ScopeM Unit :=
|
||||
modify fun scope => scope.insert fvarId
|
||||
|
||||
end ScopeM
|
||||
|
||||
namespace JoinPointFinder
|
||||
|
||||
open ScopeM
|
||||
|
|
@ -217,7 +139,7 @@ where
|
|||
eraseCandidate fvarId
|
||||
-- Out of scope join point candidate handling
|
||||
else if let some upperCandidate ← read then
|
||||
if !(←isInScope fvarId) then
|
||||
if !(← isInScope fvarId) then
|
||||
addDependency fvarId upperCandidate
|
||||
else
|
||||
eraseCandidate fvarId
|
||||
|
|
@ -246,7 +168,7 @@ Replace all join point candidate `fun` declarations with `jp` ones
|
|||
and all calls to them with `jmp`s.
|
||||
-/
|
||||
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
|
||||
let mapper := fun acc cname _ => do return acc.insert cname (←mkFreshJpName)
|
||||
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
|
||||
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) mapper
|
||||
let newValue ← go decl.value |>.run replaceCtx
|
||||
return { decl with value := newValue }
|
||||
|
|
@ -264,23 +186,23 @@ where
|
|||
return code
|
||||
else
|
||||
return code
|
||||
| _, _, _ => return Code.updateLet! code decl (←go k)
|
||||
| _, _, _ => return Code.updateLet! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
if let some replacement := (← read).find? decl.fvarId then
|
||||
let newDecl := { decl with
|
||||
binderName := replacement,
|
||||
value := (←go decl.value)
|
||||
value := (← go decl.value)
|
||||
}
|
||||
modifyLCtx fun lctx => lctx.addFunDecl newDecl
|
||||
return .jp newDecl (←go k)
|
||||
return .jp newDecl (← go k)
|
||||
else
|
||||
let newDecl ← decl.updateValue (←go decl.value)
|
||||
return Code.updateFun! code newDecl (←go k)
|
||||
let newDecl ← decl.updateValue (← go decl.value)
|
||||
return Code.updateFun! code newDecl (← go k)
|
||||
| .jp decl k =>
|
||||
let newDecl ← decl.updateValue (←go decl.value)
|
||||
return Code.updateFun! code newDecl (←go k)
|
||||
let newDecl ← decl.updateValue (← go decl.value)
|
||||
return Code.updateFun! code newDecl (← go k)
|
||||
| .cases cs =>
|
||||
return Code.updateCases! code cs.resultType cs.discr (←cs.alts.mapM (·.mapCodeM go))
|
||||
return Code.updateCases! code cs.resultType cs.discr (← cs.alts.mapM (·.mapCodeM go))
|
||||
| .jmp .. | .return .. | .unreach .. =>
|
||||
return code
|
||||
|
||||
|
|
@ -330,9 +252,9 @@ Replace a free variable if necessary, that is:
|
|||
otherwise just return `fvar`.
|
||||
-/
|
||||
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
if (←read).candidates.contains fvar then
|
||||
if let some currentJp := (←read).currentJp? then
|
||||
if let some replacement := (←get).fvarMap.find! currentJp |>.find? fvar then
|
||||
if (← read).candidates.contains fvar then
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
|
||||
return replacement.fvarId
|
||||
return fvar
|
||||
|
||||
|
|
@ -342,7 +264,7 @@ if we are currently within a join point. Then execute `x`.
|
|||
-/
|
||||
def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
|
||||
addToScope fvar
|
||||
if (←read).currentJp?.isSome then
|
||||
if (← read).currentJp?.isSome then
|
||||
withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) do
|
||||
x
|
||||
else
|
||||
|
|
@ -352,9 +274,11 @@ def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
|
|||
Same as `withNewCandidate` but with multiple `FVarId`s.
|
||||
-/
|
||||
def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do
|
||||
if (←read).currentJp?.isSome then
|
||||
let candidates := (←read).candidates
|
||||
let folder := (fun acc val => do addToScope val; return acc.insert val)
|
||||
if (← read).currentJp?.isSome then
|
||||
let candidates := (← read).candidates
|
||||
let folder (acc : FVarIdSet) (val : FVarId) := do
|
||||
addToScope val
|
||||
return acc.insert val
|
||||
let newCandidates ← fvars.foldlM (init := candidates) folder
|
||||
withReader (fun ctx => { ctx with candidates := newCandidates }) do
|
||||
x
|
||||
|
|
@ -380,10 +304,10 @@ This is necessary if:
|
|||
cannot lift a join point outside of a local function declaration.
|
||||
-/
|
||||
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
|
||||
if let some currentJp := (←read).currentJp? then
|
||||
let mut translator := (←get).fvarMap.find! currentJp
|
||||
let candidates := (←read).candidates
|
||||
if !(←isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
let mut translator := (← get).fvarMap.find! currentJp
|
||||
let candidates := (← read).candidates
|
||||
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
|
||||
let typ ← getType fvar
|
||||
let newParam ← mkAuxParam typ
|
||||
translator := translator.insert fvar newParam
|
||||
|
|
@ -404,8 +328,8 @@ as well because we need to drag these variables through at the call sites
|
|||
of `j.2` in `j.1`.
|
||||
-/
|
||||
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
|
||||
if (←read).currentJp?.isSome then
|
||||
let additionalArgs := (←get).fvarMap.find! jp |>.toArray
|
||||
if (← read).currentJp?.isSome then
|
||||
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
|
||||
for (fvar, _) in additionalArgs do
|
||||
extendByIfNecessary fvar
|
||||
|
||||
|
|
@ -469,23 +393,23 @@ where
|
|||
go (code : Code) : ExtendM Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← decl.updateValue (←goExpr decl.value)
|
||||
let decl ← decl.updateValue (← goExpr decl.value)
|
||||
withNewCandidate decl.fvarId do
|
||||
return Code.updateLet! code decl (←go k)
|
||||
return Code.updateLet! code decl (← go k)
|
||||
| .jp decl k =>
|
||||
let decl ← withNewJpScope decl do
|
||||
let value ← go decl.value
|
||||
let additionalParams := (←get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
|
||||
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
|
||||
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
|
||||
decl.update newType (additionalParams ++ decl.params) value
|
||||
mergeJpContextIfNecessary decl.fvarId
|
||||
withNewCandidate decl.fvarId do
|
||||
return Code.updateFun! code decl (←go k)
|
||||
return Code.updateFun! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
let decl ← withNewFunScope decl do
|
||||
decl.updateValue (←go decl.value)
|
||||
decl.updateValue (← go decl.value)
|
||||
withNewCandidate decl.fvarId do
|
||||
return Code.updateFun! code decl (←go k)
|
||||
return Code.updateFun! code decl (← go k)
|
||||
| .cases cs =>
|
||||
extendByIfNecessary cs.discr
|
||||
let discr ← replaceFVar cs.discr
|
||||
|
|
@ -496,9 +420,9 @@ where
|
|||
return Code.updateCases! code cs.resultType discr alts
|
||||
| .jmp fn args =>
|
||||
let mut newArgs ← args.mapM goExpr
|
||||
let additionalArgs := (←get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
if let some currentJp := (←read).currentJp? then
|
||||
let translator := (←get).fvarMap.find! currentJp
|
||||
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
let translator := (← get).fvarMap.find! currentJp
|
||||
let f := fun arg =>
|
||||
if let some translated := translator.find? arg then
|
||||
.fvar translated.fvarId
|
||||
|
|
@ -510,7 +434,7 @@ where
|
|||
return Code.updateJmp! code fn newArgs
|
||||
| .return var =>
|
||||
extendByIfNecessary var
|
||||
return Code.updateReturn! code (←replaceFVar var)
|
||||
return Code.updateReturn! code (← replaceFVar var)
|
||||
| .unreach .. => return code
|
||||
|
||||
end JoinPointContextExtender
|
||||
|
|
|
|||
53
src/Lean/Compiler/LCNF/ScopeM.lean
Normal file
53
src/Lean/Compiler/LCNF/ScopeM.lean
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/-
|
||||
Copyright (c) 2022 Henrik Böving. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
/--
|
||||
A general abstraction for the idea of a scope in the compiler.
|
||||
-/
|
||||
abbrev ScopeM := StateRefT FVarIdSet CompilerM
|
||||
|
||||
namespace ScopeM
|
||||
|
||||
def getScope : ScopeM FVarIdSet := get
|
||||
def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope
|
||||
def clearScope : ScopeM Unit := setScope {}
|
||||
|
||||
/--
|
||||
Execute `x` but recover the previous scope after doing so.
|
||||
-/
|
||||
def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
|
||||
let scope ← getScope
|
||||
try x finally setScope scope
|
||||
|
||||
/--
|
||||
Clear the current scope for the monadic action `x`, afterwards continuing
|
||||
with the old one.
|
||||
-/
|
||||
def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
|
||||
withBackTrackingScope do
|
||||
clearScope
|
||||
x
|
||||
|
||||
/--
|
||||
Check whether `fvarId` is in the current scope, that is, was declared within
|
||||
the current `fun` declaration that is being processed.
|
||||
-/
|
||||
def isInScope (fvarId : FVarId) : ScopeM Bool := do
|
||||
let scope ← getScope
|
||||
return scope.contains fvarId
|
||||
|
||||
/--
|
||||
Add a new `FVarId` to the current scope.
|
||||
-/
|
||||
def addToScope (fvarId : FVarId) : ScopeM Unit :=
|
||||
modify fun scope => scope.insert fvarId
|
||||
|
||||
end ScopeM
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
Loading…
Add table
Reference in a new issue