chore: address PR comments

This commit is contained in:
Henrik Böving 2022-10-05 13:36:11 +02:00
parent 7b3709e28a
commit e15e6bfaee
3 changed files with 136 additions and 113 deletions

View file

@ -0,0 +1,46 @@
/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Expr
namespace Lean.Compiler.LCNF
partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do
match e with
| .proj typ idx struct => return .proj typ idx (← mapFVarM f struct)
| .app fn arg => return .app (← mapFVarM f fn) (← mapFVarM f arg)
| .fvar fvarId => return .fvar (← f fvarId)
| .lam arg ty body bi =>
return .lam arg (← mapFVarM f ty) (← mapFVarM f body) bi
| .forallE arg ty body bi =>
return .forallE arg (←mapFVarM f ty) (← mapFVarM f body) bi
| .letE var ty value body nonDep =>
return .letE var (← mapFVarM f ty) (← mapFVarM f value) (← mapFVarM f body) nonDep
| .bvar .. | .sort .. => return e
| .mdata .. | .const .. | .lit .. => return e
| .mvar .. => unreachable!
partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do
match e with
| .proj _ _ struct => forFVarM f struct
| .app fn arg =>
forFVarM f fn
forFVarM f arg
| .fvar fvarId => f fvarId
| .lam _ ty body .. =>
forFVarM f ty
forFVarM f body
| .forallE _ ty body .. =>
forFVarM f ty
forFVarM f body
| .letE _ ty value body .. =>
forFVarM f ty
forFVarM f value
forFVarM f body
| .bvar .. | .sort .. => return
| .mdata .. | .const .. | .lit .. => return
| .mvar .. => unreachable!
end Lean.Compiler.LCNF

View file

@ -6,89 +6,11 @@ Authors: Henrik Böving
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PullFunDecls
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.ScopeM
namespace Lean.Compiler.LCNF
-- TODO: These can be used in a much more general context
partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do
match e with
| .proj typ idx struct => return .proj typ idx (←mapFVarM f struct)
| .app fn arg => return .app (←mapFVarM f fn) (←mapFVarM f arg)
| .fvar fvarId => return .fvar (←f fvarId)
| .lam arg ty body bi =>
return .lam arg (←mapFVarM f ty) (←mapFVarM f body) bi
| .forallE arg ty body bi =>
return .forallE arg (←mapFVarM f ty) (←mapFVarM f body) bi
| .letE var ty value body nonDep =>
return .letE var (←mapFVarM f ty) (←mapFVarM f value) (←mapFVarM f body) nonDep
| .bvar .. | .sort .. => return e
| .mdata .. | .const .. | .lit .. => return e
| .mvar .. => unreachable!
partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do
match e with
| .proj _ _ struct => forFVarM f struct
| .app fn arg =>
forFVarM f fn
forFVarM f arg
| .fvar fvarId => f fvarId
| .lam _ ty body .. =>
forFVarM f ty
forFVarM f body
| .forallE _ ty body .. =>
forFVarM f ty
forFVarM f body
| .letE _ ty value body .. =>
forFVarM f ty
forFVarM f value
forFVarM f body
| .bvar .. | .sort .. => return
| .mdata .. | .const .. | .lit .. => return
| .mvar .. => unreachable!
/--
A general abstraction for the idea of a scope in the compiler.
-/
abbrev ScopeM := StateRefT FVarIdSet CompilerM
namespace ScopeM
def getScope : ScopeM FVarIdSet := get
def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope
def clearScope : ScopeM Unit := setScope {}
/--
Execute `x` but recover the previous scope after doing so.
-/
def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
let scope ← getScope
try x finally setScope scope
/--
Clear the current scope for the monadic action `x`, afterwards continuing
with the old one.
-/
def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
withBackTrackingScope do
clearScope
x
/--
Check whether `fvarId` is in the current scope, that is, was declared within
the current `fun` declaration that is being processed.
-/
def isInScope (fvarId : FVarId) : ScopeM Bool := do
let scope ← getScope
return scope.contains fvarId
/--
Add a new `FVarId` to the current scope.
-/
def addToScope (fvarId : FVarId) : ScopeM Unit :=
modify fun scope => scope.insert fvarId
end ScopeM
namespace JoinPointFinder
open ScopeM
@ -217,7 +139,7 @@ where
eraseCandidate fvarId
-- Out of scope join point candidate handling
else if let some upperCandidate ← read then
if !(←isInScope fvarId) then
if !(← isInScope fvarId) then
addDependency fvarId upperCandidate
else
eraseCandidate fvarId
@ -246,7 +168,7 @@ Replace all join point candidate `fun` declarations with `jp` ones
and all calls to them with `jmp`s.
-/
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
let mapper := fun acc cname _ => do return acc.insert cname (←mkFreshJpName)
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) mapper
let newValue ← go decl.value |>.run replaceCtx
return { decl with value := newValue }
@ -264,23 +186,23 @@ where
return code
else
return code
| _, _, _ => return Code.updateLet! code decl (←go k)
| _, _, _ => return Code.updateLet! code decl (← go k)
| .fun decl k =>
if let some replacement := (← read).find? decl.fvarId then
let newDecl := { decl with
binderName := replacement,
value := (←go decl.value)
value := (← go decl.value)
}
modifyLCtx fun lctx => lctx.addFunDecl newDecl
return .jp newDecl (←go k)
return .jp newDecl (← go k)
else
let newDecl ← decl.updateValue (←go decl.value)
return Code.updateFun! code newDecl (←go k)
let newDecl ← decl.updateValue (← go decl.value)
return Code.updateFun! code newDecl (← go k)
| .jp decl k =>
let newDecl ← decl.updateValue (←go decl.value)
return Code.updateFun! code newDecl (←go k)
let newDecl ← decl.updateValue (← go decl.value)
return Code.updateFun! code newDecl (← go k)
| .cases cs =>
return Code.updateCases! code cs.resultType cs.discr (←cs.alts.mapM (·.mapCodeM go))
return Code.updateCases! code cs.resultType cs.discr (← cs.alts.mapM (·.mapCodeM go))
| .jmp .. | .return .. | .unreach .. =>
return code
@ -330,9 +252,9 @@ Replace a free variable if necessary, that is:
otherwise just return `fvar`.
-/
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
if (←read).candidates.contains fvar then
if let some currentJp := (←read).currentJp? then
if let some replacement := (←get).fvarMap.find! currentJp |>.find? fvar then
if (← read).candidates.contains fvar then
if let some currentJp := (← read).currentJp? then
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
return replacement.fvarId
return fvar
@ -342,7 +264,7 @@ if we are currently within a join point. Then execute `x`.
-/
def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
addToScope fvar
if (←read).currentJp?.isSome then
if (← read).currentJp?.isSome then
withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) do
x
else
@ -352,9 +274,11 @@ def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
Same as `withNewCandidate` but with multiple `FVarId`s.
-/
def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do
if (←read).currentJp?.isSome then
let candidates := (←read).candidates
let folder := (fun acc val => do addToScope val; return acc.insert val)
if (← read).currentJp?.isSome then
let candidates := (← read).candidates
let folder (acc : FVarIdSet) (val : FVarId) := do
addToScope val
return acc.insert val
let newCandidates ← fvars.foldlM (init := candidates) folder
withReader (fun ctx => { ctx with candidates := newCandidates }) do
x
@ -380,10 +304,10 @@ This is necessary if:
cannot lift a join point outside of a local function declaration.
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
if let some currentJp := (←read).currentJp? then
let mut translator := (←get).fvarMap.find! currentJp
let candidates := (←read).candidates
if !(←isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
if let some currentJp := (← read).currentJp? then
let mut translator := (← get).fvarMap.find! currentJp
let candidates := (← read).candidates
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
let typ ← getType fvar
let newParam ← mkAuxParam typ
translator := translator.insert fvar newParam
@ -404,8 +328,8 @@ as well because we need to drag these variables through at the call sites
of `j.2` in `j.1`.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
if (←read).currentJp?.isSome then
let additionalArgs := (←get).fvarMap.find! jp |>.toArray
if (← read).currentJp?.isSome then
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
for (fvar, _) in additionalArgs do
extendByIfNecessary fvar
@ -469,23 +393,23 @@ where
go (code : Code) : ExtendM Code := do
match code with
| .let decl k =>
let decl ← decl.updateValue (←goExpr decl.value)
let decl ← decl.updateValue (← goExpr decl.value)
withNewCandidate decl.fvarId do
return Code.updateLet! code decl (←go k)
return Code.updateLet! code decl (← go k)
| .jp decl k =>
let decl ← withNewJpScope decl do
let value ← go decl.value
let additionalParams := (←get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
decl.update newType (additionalParams ++ decl.params) value
mergeJpContextIfNecessary decl.fvarId
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (←go k)
return Code.updateFun! code decl (← go k)
| .fun decl k =>
let decl ← withNewFunScope decl do
decl.updateValue (←go decl.value)
decl.updateValue (← go decl.value)
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (←go k)
return Code.updateFun! code decl (← go k)
| .cases cs =>
extendByIfNecessary cs.discr
let discr ← replaceFVar cs.discr
@ -496,9 +420,9 @@ where
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs ← args.mapM goExpr
let additionalArgs := (←get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some currentJp := (←read).currentJp? then
let translator := (←get).fvarMap.find! currentJp
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some currentJp := (← read).currentJp? then
let translator := (← get).fvarMap.find! currentJp
let f := fun arg =>
if let some translated := translator.find? arg then
.fvar translated.fvarId
@ -510,7 +434,7 @@ where
return Code.updateJmp! code fn newArgs
| .return var =>
extendByIfNecessary var
return Code.updateReturn! code (←replaceFVar var)
return Code.updateReturn! code (← replaceFVar var)
| .unreach .. => return code
end JoinPointContextExtender

View file

@ -0,0 +1,53 @@
/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
namespace Lean.Compiler.LCNF
/--
A general abstraction for the idea of a scope in the compiler.
-/
abbrev ScopeM := StateRefT FVarIdSet CompilerM
namespace ScopeM
def getScope : ScopeM FVarIdSet := get
def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope
def clearScope : ScopeM Unit := setScope {}
/--
Execute `x` but recover the previous scope after doing so.
-/
def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
let scope ← getScope
try x finally setScope scope
/--
Clear the current scope for the monadic action `x`, afterwards continuing
with the old one.
-/
def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do
withBackTrackingScope do
clearScope
x
/--
Check whether `fvarId` is in the current scope, that is, was declared within
the current `fun` declaration that is being processed.
-/
def isInScope (fvarId : FVarId) : ScopeM Bool := do
let scope ← getScope
return scope.contains fvarId
/--
Add a new `FVarId` to the current scope.
-/
def addToScope (fvarId : FVarId) : ScopeM Unit :=
modify fun scope => scope.insert fvarId
end ScopeM
end Lean.Compiler.LCNF