390 lines
15 KiB
Text
390 lines
15 KiB
Text
/-
|
||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
import Lean.CoreM
|
||
import Lean.Compiler.LCNF.Basic
|
||
import Lean.Compiler.LCNF.LCtx
|
||
import Lean.Compiler.LCNF.ConfigOptions
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
/--
|
||
The pipeline phase a certain `Pass` is supposed to happen in.
|
||
-/
|
||
inductive Phase where
|
||
/-- Here we still carry most of the original type information, most
|
||
of the dependent portion is already (partially) erased though. -/
|
||
| base
|
||
/-- In this phase polymorphism has been eliminated. -/
|
||
| mono
|
||
/-- In this phase impure stuff such as RC or efficient BaseIO transformations happen. -/
|
||
| impure
|
||
deriving Inhabited
|
||
|
||
/--
|
||
The state managed by the `CompilerM` `Monad`.
|
||
-/
|
||
structure CompilerM.State where
|
||
/--
|
||
A `LocalContext` to store local declarations from let binders
|
||
and other constructs in as we move through `Expr`s.
|
||
-/
|
||
lctx : LCtx := {}
|
||
/-- Next auxiliary variable suffix -/
|
||
nextIdx : Nat := 1
|
||
deriving Inhabited
|
||
|
||
structure CompilerM.Context where
|
||
phase : Phase
|
||
config : ConfigOptions
|
||
deriving Inhabited
|
||
|
||
abbrev CompilerM := ReaderT CompilerM.Context $ StateRefT CompilerM.State CoreM
|
||
|
||
@[alwaysInline]
|
||
instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure := i.pure, bind := i.bind }
|
||
|
||
@[inline] def withPhase (phase : Phase) (x : CompilerM α) : CompilerM α :=
|
||
withReader (fun ctx => { ctx with phase }) x
|
||
|
||
def getPhase : CompilerM Phase :=
|
||
return (← read).phase
|
||
|
||
def inBasePhase : CompilerM Bool :=
|
||
return (← getPhase) matches .base
|
||
|
||
instance : AddMessageContext CompilerM where
|
||
addMessageContext msgData := do
|
||
let env ← getEnv
|
||
let lctx := (← get).lctx.toLocalContext
|
||
let opts ← getOptions
|
||
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
|
||
|
||
def getType (fvarId : FVarId) : CompilerM Expr := do
|
||
let lctx := (← get).lctx
|
||
if let some decl := lctx.letDecls.find? fvarId then
|
||
return decl.type
|
||
else if let some decl := lctx.params.find? fvarId then
|
||
return decl.type
|
||
else if let some decl := lctx.funDecls.find? fvarId then
|
||
return decl.type
|
||
else
|
||
throwError "unknown free variable {fvarId.name}"
|
||
|
||
def getBinderName (fvarId : FVarId) : CompilerM Name := do
|
||
let lctx := (← get).lctx
|
||
if let some decl := lctx.letDecls.find? fvarId then
|
||
return decl.binderName
|
||
else if let some decl := lctx.params.find? fvarId then
|
||
return decl.binderName
|
||
else if let some decl := lctx.funDecls.find? fvarId then
|
||
return decl.binderName
|
||
else
|
||
throwError "unknown free variable {fvarId.name}"
|
||
|
||
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
|
||
return (← get).lctx.params.find? fvarId
|
||
|
||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
|
||
return (← get).lctx.letDecls.find? fvarId
|
||
|
||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||
return (← get).lctx.funDecls.find? fvarId
|
||
|
||
def getParam (fvarId : FVarId) : CompilerM Param := do
|
||
let some param ← findParam? fvarId | throwError "unknown parameter {fvarId.name}"
|
||
return param
|
||
|
||
def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do
|
||
let some decl ← findLetDecl? fvarId | throwError "unknown let-declaration {fvarId.name}"
|
||
return decl
|
||
|
||
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
|
||
let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
|
||
return decl
|
||
|
||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||
modify fun s => { s with lctx := f s.lctx }
|
||
|
||
def eraseLetDecl (decl : LetDecl) : CompilerM Unit := do
|
||
modifyLCtx fun lctx => lctx.eraseLetDecl decl
|
||
|
||
def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit := do
|
||
modifyLCtx fun lctx => lctx.eraseFunDecl decl recursive
|
||
|
||
def eraseCode (code : Code) : CompilerM Unit := do
|
||
modifyLCtx fun lctx => lctx.eraseCode code
|
||
|
||
def eraseParam (param : Param) : CompilerM Unit :=
|
||
modifyLCtx fun lctx => lctx.eraseParam param
|
||
|
||
def eraseParams (params : Array Param) : CompilerM Unit :=
|
||
modifyLCtx fun lctx => lctx.eraseParams params
|
||
|
||
def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
|
||
match decl with
|
||
| .let decl => eraseLetDecl decl
|
||
| .jp decl | .fun decl => eraseFunDecl decl
|
||
|
||
/--
|
||
Erase all free variables occurring in `decls` from the local context.
|
||
-/
|
||
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
|
||
decls.forM fun decl => eraseCodeDecl decl
|
||
|
||
def eraseDecl (decl : Decl) : CompilerM Unit := do
|
||
eraseParams decl.params
|
||
eraseCode decl.value
|
||
|
||
abbrev Decl.erase (decl : Decl) : CompilerM Unit :=
|
||
eraseDecl decl
|
||
|
||
/--
|
||
A free variable substitution.
|
||
We use these substitutions when inlining definitions and "internalizing" LCNF code into `CompilerM`.
|
||
During the internalization process, we ensure all free variables in the LCNF code do not collide with existing ones
|
||
at the `CompilerM` local context.
|
||
Remark: in LCNF, (computationally relevant) data is in A-normal form, but this is not the case for types and type formers.
|
||
So, when inlining we often want to replace a free variable with a type or type former.
|
||
|
||
The substitution contains entries `fvarId ↦ e` s.t., `e` is a valid LCNF argument. That is,
|
||
it is a free variable, a type (or type former), or `lcErased`.
|
||
|
||
`Check.lean` contains a substitution validator.
|
||
-/
|
||
abbrev FVarSubst := HashMap FVarId Expr
|
||
|
||
/--
|
||
Replace the free variables in `e` using the given substitution.
|
||
|
||
If `translator = true`, then we assume the free variables occurring in the range of the substitution are in another
|
||
local context. For example, `translator = true` during internalization where we are making sure all free variables
|
||
in a given expression are replaced with new ones that do not collide with the ones in the current local context.
|
||
|
||
If `translator = false`, we assume the substitution contains free variable replacements in the same local context,
|
||
and given entries such as `x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ`, and the expression `f x₁ x₂`, we want the resulting
|
||
expression to be `f xₙ xₙ`. We use this setting, for example, in the simplifier.
|
||
-/
|
||
private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||
go e
|
||
where
|
||
goApp (e : Expr) : Expr :=
|
||
match e with
|
||
| .app f a => e.updateApp! (goApp f) (go a)
|
||
| _ => go e
|
||
|
||
go (e : Expr) : Expr :=
|
||
if e.hasFVar then
|
||
match e with
|
||
| .fvar fvarId => match s.find? fvarId with
|
||
| some e => if translator then e else go e
|
||
| none => e
|
||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||
| .app f a => e.updateApp! (goApp f) (go a) |>.headBeta
|
||
| .mdata _ b => e.updateMData! (go b)
|
||
| .proj _ _ b => e.updateProj! (go b)
|
||
| .forallE _ d b _ => e.updateForallE! (go d) (go b)
|
||
| .lam _ d b _ => e.updateLambdaE! (go d) (go b)
|
||
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
|
||
else
|
||
e
|
||
|
||
/--
|
||
Normalize the given free variable.
|
||
See `normExprImp` for documentation on the `translator` parameter.
|
||
This function is meant to be used in contexts where the input free-variable is computationally relevant.
|
||
This function panics if the substitution is mapping `fvarId` to an expression that is not another free variable.
|
||
That is, it is not a type (or type former), nor `lcErased`. Recall that a valid `FVarSubst` contains only
|
||
expressions that are free variables, `lcErased`, or type formers.
|
||
-/
|
||
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : FVarId :=
|
||
match s.find? fvarId with
|
||
| some (.fvar fvarId') =>
|
||
if translator then
|
||
fvarId'
|
||
else
|
||
normFVarImp s fvarId' translator
|
||
| some e => panic! s!"invalid LCNF substitution of free variable with expression {e}"
|
||
| none => fvarId
|
||
|
||
/--
|
||
Interface for monads that have a free substitutions.
|
||
-/
|
||
class MonadFVarSubst (m : Type → Type) (translator : outParam Bool) where
|
||
getSubst : m FVarSubst
|
||
|
||
export MonadFVarSubst (getSubst)
|
||
|
||
instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where
|
||
getSubst := liftM (getSubst : m _)
|
||
|
||
class MonadFVarSubstState (m : Type → Type) where
|
||
modifySubst : (FVarSubst → FVarSubst) → m Unit
|
||
|
||
export MonadFVarSubstState (modifySubst)
|
||
|
||
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
|
||
modifySubst f := liftM (modifySubst f : m _)
|
||
|
||
/--
|
||
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
|
||
-/
|
||
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||
modifySubst fun s => s.insert fvarId (.fvar fvarId')
|
||
|
||
/--
|
||
Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF argument.
|
||
That is, it must be a free variable, type (or type former), or `lcErased`.
|
||
|
||
See `Check.lean` for the free variable substitution checker.
|
||
-/
|
||
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (e : Expr) : m Unit :=
|
||
modifySubst fun s => s.insert fvarId e
|
||
|
||
@[inline, inheritDoc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m FVarId :=
|
||
return normFVarImp (← getSubst) fvarId t
|
||
|
||
@[inline, inheritDoc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
|
||
return normExprImp (← getSubst) e t
|
||
|
||
@[inheritDoc normExprImp]
|
||
abbrev normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||
normExprImp s e translator
|
||
|
||
/--
|
||
Normalize the given expressions using the current substitution.
|
||
-/
|
||
def normExprs [MonadFVarSubst m t] [Monad m] (es : Array Expr) : m (Array Expr) :=
|
||
es.mapMonoM normExpr
|
||
|
||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||
let declName := .num binderName (← get).nextIdx
|
||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||
return declName
|
||
|
||
/-!
|
||
Helper functions for creating LCNF local declarations.
|
||
-/
|
||
|
||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do
|
||
let fvarId ← mkFreshFVarId
|
||
let param := { fvarId, binderName, type, borrow }
|
||
modifyLCtx fun lctx => lctx.addParam param
|
||
return param
|
||
|
||
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||
let fvarId ← mkFreshFVarId
|
||
let decl := { fvarId, binderName, type, value }
|
||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||
return decl
|
||
|
||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||
let fvarId ← mkFreshFVarId
|
||
let funDecl := { fvarId, binderName, type, params, value }
|
||
modifyLCtx fun lctx => lctx.addFunDecl funDecl
|
||
return funDecl
|
||
|
||
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
|
||
if ptrEq type p.type then
|
||
return p
|
||
else
|
||
let p := { p with type }
|
||
modifyLCtx fun lctx => lctx.addParam p
|
||
return p
|
||
|
||
@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||
|
||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||
if ptrEq type decl.type && ptrEq value decl.value then
|
||
return decl
|
||
else
|
||
let decl := { decl with type, value }
|
||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||
return decl
|
||
|
||
@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
|
||
|
||
def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl :=
|
||
decl.update decl.type value
|
||
|
||
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||
return decl
|
||
else
|
||
let decl := { decl with type, params, value }
|
||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||
return decl
|
||
|
||
@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
|
||
|
||
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
|
||
decl.update type decl.params value
|
||
|
||
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
|
||
decl.update decl.type decl.params value
|
||
|
||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do
|
||
p.update (← normExpr p.type)
|
||
|
||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) :=
|
||
ps.mapMonoM normParam
|
||
|
||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
|
||
decl.update (← normExpr decl.type) (← normExpr decl.value)
|
||
|
||
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
|
||
|
||
instance : MonadFVarSubst (NormalizerM t) t where
|
||
getSubst := read
|
||
|
||
mutual
|
||
partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do
|
||
let type ← normExpr decl.type
|
||
let params ← normParams decl.params
|
||
let value ← normCodeImp decl.value
|
||
decl.update type params value
|
||
|
||
partial def normCodeImp (code : Code) : NormalizerM t Code := do
|
||
match code with
|
||
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
|
||
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
|
||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
|
||
| .unreach type => return code.updateUnreach! (← normExpr type)
|
||
| .cases c =>
|
||
let resultType ← normExpr c.resultType
|
||
let discr ← normFVar c.discr
|
||
let alts ← c.alts.mapMonoM fun alt =>
|
||
match alt with
|
||
| .alt _ params k => return alt.updateAlt! (← normParams params) (← normCodeImp k)
|
||
| .default k => return alt.updateCode (← normCodeImp k)
|
||
return code.updateCases! resultType discr alts
|
||
end
|
||
|
||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do
|
||
normFunDeclImp (t := t) decl (← getSubst)
|
||
|
||
/-- Similar to `internalize`, but does not refresh `FVarId`s. -/
|
||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do
|
||
normCodeImp (t := t) code (← getSubst)
|
||
|
||
def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr :=
|
||
(normExpr e : NormalizerM translator Expr).run s
|
||
|
||
def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code :=
|
||
(normCode code : NormalizerM translator Code).run s
|
||
|
||
def mkFreshJpName : CompilerM Name := do
|
||
mkFreshBinderName `_jp
|
||
|
||
def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do
|
||
mkParam (← mkFreshBinderName `_y) type borrow
|
||
|
||
def getConfig : CompilerM ConfigOptions :=
|
||
return (← read).config
|
||
|
||
def CompilerM.run (x : CompilerM α) (s : State := {}) (phase : Phase := .base) : CoreM α := do
|
||
x { phase, config := toConfigOptions (← getOptions) } |>.run' s
|
||
|
||
end Lean.Compiler.LCNF
|