325 lines
12 KiB
Text
325 lines
12 KiB
Text
/-
|
||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
import Lean.CoreM
|
||
import Lean.Compiler.LCNF.Basic
|
||
import Lean.Compiler.LCNF.LCtx
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
|
||
/-- The state threaded through the `CompilerM` monad. -/
structure CompilerM.State where
  /--
  Local context (an `LCtx`, convertible to a `LocalContext`) that stores the local
  declarations introduced by `let` binders and other constructs as we move through `Expr`s.
  -/
  lctx : LCtx := {}
  /-- Numeric suffix that `mkFreshBinderName` will use for the next auxiliary binder name. -/
  nextIdx : Nat := 1
  deriving Inhabited
|
||
|
||
/-- The LCNF compiler monad: `CoreM` extended with a `CompilerM.State` (local context and binder-name counter). -/
abbrev CompilerM := StateRefT CompilerM.State CoreM
|
||
|
||
/--
Messages produced in `CompilerM` are rendered with the current environment, the current
options, and the compiler local context converted to a `LocalContext`. The metavariable
context is empty.
-/
instance : AddMessageContext CompilerM where
  addMessageContext msgData := do
    -- Nested actions run left to right: environment, compiler state, then options.
    let ctx : MessageDataContext :=
      { env := (← getEnv), mctx := {}, lctx := (← get).lctx.toLocalContext, opts := (← getOptions) }
    return MessageData.withContext ctx msgData
|
||
|
||
/-- Look up `fvarId` in the local declarations of the compiler local context; throws an error if absent. -/
def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
  match (← get).lctx.localDecls.find? fvarId with
  | some decl => return decl
  | none => throwError "unknown free variable {fvarId.name}"
|
||
|
||
/-- Look up `fvarId` among the local function declarations of the compiler local context. -/
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
  let state ← get
  return state.lctx.funDecls.find? fvarId
|
||
|
||
/-- Like `findFunDecl?`, but throws an error when `fvarId` is not a known local function. -/
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
  match (← findFunDecl? fvarId) with
  | some decl => return decl
  | none => throwError "unknown local function {fvarId.name}"
|
||
|
||
/-- Apply `f` to the compiler local context stored in the `CompilerM` state. -/
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit :=
  modify fun st => { st with lctx := f st.lctx }
|
||
|
||
/-- Remove `fvarId` from the compiler local context, forwarding the `recursive` flag to `LCtx.erase`. -/
def eraseFVar (fvarId : FVarId) (recursive := true) : CompilerM Unit :=
  modifyLCtx (·.erase fvarId recursive)
|
||
|
||
/-- Remove the free variables associated with `code` from the compiler local context (via `LCtx.eraseFVarsAt`). -/
def eraseFVarsAt (code : Code) : CompilerM Unit :=
  modifyLCtx (·.eraseFVarsAt code)
|
||
|
||
/-- Erase every parameter in `params` from the compiler local context, in order. -/
def eraseParams (params : Array Param) : CompilerM Unit := do
  for param in params do
    eraseFVar param.fvarId
|
||
|
||
/--
A free variable substitution: maps an `FVarId` to the `Expr` that replaces it.

We use these substitutions when inlining definitions and when "internalizing" LCNF code into
`CompilerM`. During internalization, we ensure that no free variable in the LCNF code collides
with an existing one in the `CompilerM` local context.

Remark: in LCNF, (computationally relevant) data is in A-normal form, but this is not the case
for types and type formers. So, when inlining, we often want to replace a free variable with a
type or type former; this is why the codomain is `Expr` rather than `FVarId`.
-/
abbrev FVarSubst : Type := Std.HashMap FVarId Expr
|
||
|
||
/--
Replace every free variable of `e` that has an entry in `s` by its image.
Subterms without free variables are returned as is; the others are rebuilt with the
`Expr.update*!` helpers.
-/
private partial def normExprImp (s : FVarSubst) (e : Expr) : Expr :=
  go e
where
  go (e : Expr) : Expr :=
    if !e.hasFVar then
      e
    else
      match e with
      | .fvar fvarId => (s.find? fvarId).getD e
      | .app fn arg => e.updateApp! (go fn) (go arg)
      | .lam _ dom body _ => e.updateLambdaE! (go dom) (go body)
      | .forallE _ dom body _ => e.updateForallE! (go dom) (go body)
      | .mdata _ inner => e.updateMData! (go inner)
      | .proj _ _ target => e.updateProj! (go target)
      | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
      | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||
|
||
/--
Apply `s` to a variable position. Unmapped variables are kept. A variable mapped to a
non-`fvar` expression is invalid here and triggers a panic.
-/
private def normFVarImp (s : FVarSubst) (fvarId : FVarId) : FVarId :=
  match s.find? fvarId with
  | none => fvarId
  | some (.fvar target) => target
  | some _ => panic! "invalid LCNF substitution of free variable with expression"
|
||
|
||
/-- Monads that give access to a current free variable substitution. -/
class MonadFVarSubst (m : Type → Type) where
  /-- The free variable substitution currently in effect. -/
  getSubst : m FVarSubst

export MonadFVarSubst (getSubst)
|
||
|
||
/-- Any monad `n` that `m` lifts into (via `MonadLift`) sees the substitution of `m`. -/
instance (m n) [MonadLift m n] [MonadFVarSubst m] : MonadFVarSubst n :=
  ⟨liftM (getSubst : m FVarSubst)⟩
|
||
|
||
/-- Apply the current substitution to the variable `fvarId` (panics if it maps to a non-variable). -/
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId := do
  let subst ← getSubst
  return normFVarImp subst fvarId
|
||
|
||
/-- Apply the current substitution to every free variable of `e`. -/
@[inline] def normExpr [MonadFVarSubst m] [Monad m] (e : Expr) : m Expr := do
  let subst ← getSubst
  return normExprImp subst e
|
||
|
||
/-- `normExpr` applied to each element of `es`, using `Array.mapMonoM`. -/
def normExprs [MonadFVarSubst m] [Monad m] (es : Array Expr) : m (Array Expr) :=
  es.mapMonoM fun e => normExpr e
|
||
|
||
/--
Return `binderName` suffixed with the current `nextIdx` (i.e. `Name.num binderName nextIdx`)
and advance `nextIdx` by one.
-/
def mkFreshBinderName (binderName := `_x) : CompilerM Name :=
  modifyGet fun s => (Name.num binderName s.nextIdx, { s with nextIdx := s.nextIdx + 1 })
|
||
|
||
/--
If `binderName` is of the form `p.<n>`, replace its numeric suffix with a fresh one
(exactly what `mkFreshBinderName p` does); any other name is returned unchanged.
-/
private def refreshBinderName (binderName : Name) : CompilerM Name := do
  if let .num p _ := binderName then
    mkFreshBinderName p
  else
    return binderName
|
||
|
||
namespace Internalize

/--
Internalization monad: `CompilerM` plus a `FVarSubst` state mapping each original free
variable to the fresh free variable that replaces it.
-/
abbrev M := StateRefT FVarSubst CompilerM

/-- `normExpr`/`normFVar` in `M` use the substitution being built in `M`'s state. -/
instance : MonadFVarSubst M where
  getSubst := get

/--
Create a fresh `FVarId` to replace `fvarId`, and record `fvarId ↦ .fvar fvarId'` in the
substitution so that subsequent occurrences of `fvarId` are renamed.
-/
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
  let fvarId' ← Lean.mkFreshFVarId
  modify fun s => s.insert fvarId (.fvar fvarId')
  return fvarId'

/--
Internalize a parameter. Its type is normalized with the current substitution before
the parameter itself is renamed. It then gets a fresh `FVarId` and is added to the
local context. Unlike `let` and local function binders, the binder name is kept as is.
-/
private def addParam (p : Param) : M Param := do
  let type ← normExpr p.type
  let fvarId ← mkNewFVarId p.fvarId
  modifyLCtx fun lctx => lctx.addLocalDecl fvarId p.binderName type
  return { p with fvarId, type }

mutual

/--
Internalize a local function or join point declaration. The steps are:
1. normalize its type;
2. refresh its binder name;
3. internalize its parameters and body;
4. only then give the declaration itself a fresh `FVarId` and register it in the local context.

NOTE(review): `decl.fvarId` is added to the substitution only after the body is processed,
so occurrences of it inside the body are not renamed. Presumably this relies on local
functions being non-recursive in LCNF — confirm.
-/
partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
  let type ← normExpr decl.type
  let binderName ← refreshBinderName decl.binderName
  let params ← decl.params.mapM addParam
  let value ← internalizeCode decl.value
  let fvarId ← mkNewFVarId decl.fvarId
  let decl := { decl with binderName, fvarId, params, type, value }
  modifyLCtx fun lctx => lctx.addFunDecl decl
  return decl

/--
Internalize `code`:
- every binder (`let`, `fun`, `jp`, and parameters of `cases` alternatives) gets a fresh
  `FVarId` and is added to the local context;
- every free variable occurrence is rewritten with the current substitution.
-/
partial def internalizeCode (code : Code) : M Code := do
  match code with
  | .let decl k =>
    let binderName ← refreshBinderName decl.binderName
    let type ← normExpr decl.type
    -- The value is normalized before `decl.fvarId` is renamed.
    let value ← normExpr decl.value
    let fvarId ← mkNewFVarId decl.fvarId
    modifyLCtx fun lctx => lctx.addLetDecl fvarId binderName type value
    let k ← internalizeCode k
    return .let { decl with binderName, fvarId, type, value } k
  | .fun decl k =>
    return .fun (← internalizeFunDecl decl) (← internalizeCode k)
  | .jp decl k =>
    return .jp (← internalizeFunDecl decl) (← internalizeCode k)
  | .return fvarId => return .return (← normFVar fvarId)
  -- Use the shared `normExprs` helper, as `normCodeImp` does for the same case.
  | .jmp fvarId args => return .jmp (← normFVar fvarId) (← normExprs args)
  | .unreach type => return .unreach (← normExpr type)
  | .cases c =>
    let resultType ← normExpr c.resultType
    let discr ← normFVar c.discr
    let alts ← c.alts.mapM fun
      | .alt ctorName params k => return .alt ctorName (← params.mapM addParam) (← internalizeCode k)
      | .default k => return .default (← internalizeCode k)
    return .cases { c with discr, alts, resultType }

end

end Internalize
|
||
|
||
/--
Refresh the free variable ids in `code`, and store their declarations in the local context.

`s` is the initial substitution. Entries for variables bound inside `code` are overwritten
as fresh ids are created for them.
-/
-- Not recursive, so `partial` is unnecessary.
def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
  Internalize.internalizeCode code |>.run' s
|
||
|
||
open Internalize in
/--
Internalize a declaration, starting from substitution `s`:
- its type is normalized;
- its parameters get fresh `FVarId`s;
- its body is internalized.

All other fields of `decl` are kept.
-/
def Decl.internalize (decl : Decl) (s : FVarSubst := {}) : CompilerM Decl :=
  (go decl).run' s
where
  go (d : Decl) : M Decl := do
    -- Order matters: the type is normalized before parameters extend the substitution.
    let type ← normExpr d.type
    let params ← d.params.mapM addParam
    let value ← internalizeCode d.value
    return { d with type, params, value }
|
||
|
||
/-!
|
||
Helper functions for creating LCNF local declarations.
|
||
-/
|
||
|
||
/-- Create a parameter with a fresh `FVarId` and register it in the compiler local context. -/
def mkParam (binderName : Name) (type : Expr) : CompilerM Param := do
  let param : Param := { fvarId := (← mkFreshFVarId), binderName, type }
  modifyLCtx (·.addLocalDecl param.fvarId binderName type)
  return param
|
||
|
||
/-- Create a `let` declaration with a fresh `FVarId` and register it in the compiler local context. -/
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (pure := true) : CompilerM LetDecl := do
  -- Note: the `pure` parameter shadows the `pure` function here; only `return` is used below.
  let decl : LetDecl := { fvarId := (← mkFreshFVarId), binderName, type, value, pure }
  modifyLCtx (·.addLetDecl decl.fvarId binderName type value)
  return decl
|
||
|
||
/-- Create a local function declaration with a fresh `FVarId` and register it in the compiler local context. -/
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  let decl : FunDecl := { fvarId := (← mkFreshFVarId), binderName, type, params, value }
  modifyLCtx (·.addFunDecl decl)
  return decl
|
||
|
||
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
  -- Pointer-equal type: nothing changed, keep `p` and leave the local context alone.
  if ptrEq type p.type then
    return p
  let p' := { p with type }
  modifyLCtx (·.addLocalDecl p'.fvarId p'.binderName p'.type)
  return p'

/--
Return `p` with its type set to `type`, re-registering it in the local context.
If `type` is pointer-equal to `p.type`, `p` is returned unchanged.
-/
@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||
|
||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
  -- Both components pointer-equal: nothing changed, keep `decl`.
  if ptrEq type decl.type && ptrEq value decl.value then
    return decl
  let decl' := { decl with type, value }
  modifyLCtx (·.addLetDecl decl'.fvarId decl'.binderName decl'.type decl'.value)
  return decl'

/--
Return `decl` with its type and value replaced, re-registering it in the local context.
If both are pointer-equal to the current ones, `decl` is returned unchanged.
-/
@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
|
||
|
||
/-- `LetDecl.update` that only replaces the value, keeping the current type. -/
def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl :=
  LetDecl.update decl decl.type value
|
||
|
||
private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  -- All components pointer-equal: nothing changed, keep `decl`.
  if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
    return decl
  let decl' := { decl with type, params, value }
  modifyLCtx (·.addFunDecl decl')
  return decl'

/--
Return `decl` with type, parameters and body replaced, re-registering it in the local context.
If all three are pointer-equal to the current ones, `decl` is returned unchanged.
-/
@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
|
||
|
||
/-- `FunDeclCore.update` keeping the current parameters. -/
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl type decl.params value
|
||
|
||
/-- `FunDeclCore.update` replacing only the body, keeping the current type and parameters. -/
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl decl.type decl.params value
|
||
|
||
/-- Apply the current substitution to the type of `p` (without renaming `p`). -/
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (p : Param) : m Param := do
  let type ← normExpr p.type
  p.update type
|
||
|
||
/-- `normParam` applied to each parameter, using `Array.mapMonoM`. -/
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (ps : Array Param) : m (Array Param) :=
  ps.mapMonoM fun p => normParam p
|
||
|
||
/-- Apply the current substitution to the type and value of `decl` (without renaming it). -/
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : LetDecl) : m LetDecl := do
  let type ← normExpr decl.type
  let value ← normExpr decl.value
  decl.update type value
|
||
|
||
/-- In `ReaderT FVarSubst CompilerM`, the substitution is the reader context. -/
instance : MonadFVarSubst (ReaderT FVarSubst CompilerM) where
  getSubst := readThe FVarSubst
|
||
|
||
mutual

/--
Apply the reader substitution to a local function/join point declaration: its type, its
parameters and its body. No fresh `FVarId`s are created. The result is rebuilt with
`FunDeclCore.update`.
-/
partial def normFunDeclImp (decl : FunDecl) : ReaderT FVarSubst CompilerM FunDecl := do
  let type' ← normExpr decl.type
  let params' ← normParams decl.params
  let value' ← normCodeImp decl.value
  decl.update type' params' value'

/--
Apply the reader substitution to every free variable occurrence in `code`, rebuilding nodes
with the `Code.update*!` helpers. Binders keep their `FVarId`s.
-/
partial def normCodeImp (code : Code) : ReaderT FVarSubst CompilerM Code := do
  match code with
  | .return fvarId =>
    let fvarId' ← normFVar fvarId
    return code.updateReturn! fvarId'
  | .jmp fvarId args =>
    let fvarId' ← normFVar fvarId
    let args' ← normExprs args
    return code.updateJmp! fvarId' args'
  | .unreach type =>
    let type' ← normExpr type
    return code.updateUnreach! type'
  | .let decl k =>
    let decl' ← normLetDecl decl
    let k' ← normCodeImp k
    return code.updateLet! decl' k'
  | .fun decl k | .jp decl k =>
    let decl' ← normFunDeclImp decl
    let k' ← normCodeImp k
    return code.updateFun! decl' k'
  | .cases c =>
    let resultType ← normExpr c.resultType
    let discr ← normFVar c.discr
    let alts ← c.alts.mapMonoM fun alt => do
      match alt with
      | .alt _ params k =>
        let params' ← normParams params
        let k' ← normCodeImp k
        return alt.updateAlt! params' k'
      | .default k =>
        let k' ← normCodeImp k
        return alt.updateCode k'
    return code.updateCases! resultType discr alts

end
|
||
|
||
/-- Apply the current substitution to a local function declaration (see `normFunDeclImp`). -/
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : FunDecl) : m FunDecl := do
  let subst ← getSubst
  (normFunDeclImp decl).run subst
|
||
|
||
/--
Apply the current substitution to `code`. Similar to `internalize`, but it does not refresh
`FVarId`s: binders are kept and only occurrences are rewritten.
-/
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (code : Code) : m Code := do
  let subst ← getSubst
  (normCodeImp code).run subst
|
||
|
||
/-- Apply the substitution `s` to the free variables of `e`. -/
def replaceExprFVars (e : Expr) (s : FVarSubst) : CompilerM Expr :=
  -- `normExpr` run under `ReaderT.run s` reduces to this pure computation.
  pure (normExprImp s e)
|
||
|
||
/-- Apply the substitution `s` to the free variable occurrences of `code` (no renaming of binders). -/
def replaceFVars (code : Code) (s : FVarSubst) : CompilerM Code :=
  (normCodeImp code).run s
|
||
|
||
/-- Replace occurrences of `fvarId` with `fvarId'` in `code`. -/
def replaceFVar (code : Code) (fvarId fvarId' : FVarId) : CompilerM Code :=
  replaceFVars code (({} : FVarSubst).insert fvarId (.fvar fvarId'))
|
||
|
||
/-- Fresh binder name with prefix `_jp`, used for join points. -/
def mkFreshJpName : CompilerM Name :=
  mkFreshBinderName (binderName := `_jp)
|
||
|
||
/-- Create a parameter of type `type` named with a fresh `_y`-prefixed binder name. -/
def mkAuxParam (type : Expr) : CompilerM Param := do
  let binderName ← mkFreshBinderName `_y
  mkParam binderName type
|
||
|
||
/--
Reset the `CompilerM` state to its defaults (fresh local context), then internalize each
declaration of `decl` in order. `nextIdx` is reset to `1` before each declaration, so
binder-name numbering restarts per declaration.
-/
def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do
  set ({} : CompilerM.State)
  decl.mapM fun d => do
    modify fun s => { s with nextIdx := 1 }
    d.internalize
|
||
|
||
/-- Run `x` in `CoreM` starting from state `s` (default-initialized), discarding the final state. -/
def CompilerM.run (x : CompilerM α) (s : State := {}) : CoreM α :=
  StateRefT'.run' x s
|
||
|
||
end Lean.Compiler.LCNF
|