lean4-htt/src/Lean/Compiler/LCNF/CompilerM.lean
Leonardo de Moura e0197b4e09 feat: add bindCases
It is similar to `Code.bind` but has special support for `inlineMatcher`
2022-09-04 19:04:21 -07:00

325 lines
12 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.CoreM
import Lean.Compiler.LCNF.Basic
import Lean.Compiler.LCNF.LCtx
namespace Lean.Compiler.LCNF
/--
The state managed by the `CompilerM` monad.
-/
structure CompilerM.State where
  /--
  The LCNF local context (`LCtx`). It stores the local declarations
  (e.g. from `let` binders and other constructs) registered while code is processed.
  It starts out empty.
  -/
  lctx : LCtx := {}
  /-- Next auxiliary variable suffix. It starts at `1`. -/
  nextIdx : Nat := 1
  deriving Inhabited
/-- The LCNF compiler monad: `CoreM` extended with a mutable `CompilerM.State`. -/
abbrev CompilerM := StateRefT CompilerM.State CoreM
/--
Messages produced in `CompilerM` are wrapped with the current environment, options,
and the `LocalContext` view of the LCNF local context. The metavariable context is empty.
-/
instance : AddMessageContext CompilerM where
  addMessageContext msgData := do
    let localCtx := (← get).lctx.toLocalContext
    let environment ← getEnv
    let options ← getOptions
    return MessageData.withContext { env := environment, mctx := {}, lctx := localCtx, opts := options } msgData
/--
Look up `fvarId` among the local declarations of the current local context.
Throws an error if there is no such declaration.
-/
def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
  match (← get).lctx.localDecls.find? fvarId with
  | some localDecl => return localDecl
  | none => throwError "unknown free variable {fvarId.name}"
/--
Look up `fvarId` among the function declarations of the current local context.
Returns `none` if there is no such declaration.
-/
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
  let st ← get
  return st.lctx.funDecls.find? fvarId
/--
Same as `findFunDecl?`, but throws an error when `fvarId` is not a known local function.
-/
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
  match (← findFunDecl? fvarId) with
  | some funDecl => return funDecl
  | none => throwError "unknown local function {fvarId.name}"
/-- Apply `f` to the `lctx` field of the compiler state; all other fields are left untouched. -/
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit :=
  modify fun st => { st with lctx := f st.lctx }
/--
Erase `fvarId` from the local context using `LCtx.erase`.
The `recursive` flag is forwarded unchanged to `LCtx.erase`.
-/
def eraseFVar (fvarId : FVarId) (recursive := true) : CompilerM Unit :=
  modifyLCtx (·.erase fvarId recursive)
/-- Update the local context with `LCtx.eraseFVarsAt code`. -/
def eraseFVarsAt (code : Code) : CompilerM Unit :=
  modifyLCtx (·.eraseFVarsAt code)
/-- Call `eraseFVar` (with its default `recursive` flag) on the free variable of every parameter. -/
def eraseParams (params : Array Param) : CompilerM Unit := do
  for param in params do
    eraseFVar param.fvarId
/--
A free variable substitution: a map from `FVarId`s to the expressions replacing them.

We use these substitutions when inlining definitions and when "internalizing" LCNF code into `CompilerM`.
During the internalization process, we ensure that the free variables in the LCNF code do not collide
with the ones already present in the `CompilerM` local context.

Remark: in LCNF, (computationally relevant) data is in A-normal form, but this is not the case for types and type formers.
So, when inlining, we often want to replace a free variable with a type or type former,
which is why the codomain is `Expr` rather than `FVarId`.
-/
abbrev FVarSubst := Std.HashMap FVarId Expr
/--
Apply the substitution `s` to `e`: a free variable with an entry in `s` is replaced by
the mapped expression, any other free variable is kept.
Subterms that contain no free variables are returned as they are.
-/
private partial def normExprImp (s : FVarSubst) (e : Expr) : Expr :=
  go e
where
  go (e : Expr) : Expr :=
    if !e.hasFVar then
      -- Nothing to substitute below this node.
      e
    else
      match e with
      | .fvar fvarId => (s.find? fvarId).getD e
      | .app f a => e.updateApp! (go f) (go a)
      | .lam _ d b _ => e.updateLambdaE! (go d) (go b)
      | .forallE _ d b _ => e.updateForallE! (go d) (go b)
      | .proj _ _ b => e.updateProj! (go b)
      | .mdata _ b => e.updateMData! (go b)
      | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
      | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
/--
Map `fvarId` through the substitution `s`.
An unmapped variable is returned unchanged. A variable mapped to another free variable
yields that variable. A variable mapped to any other expression triggers a panic.
-/
private def normFVarImp (s : FVarSubst) (fvarId : FVarId) : FVarId :=
  match s.find? fvarId with
  | none => fvarId
  | some (.fvar fvarId') => fvarId'
  | some _ => panic! "invalid LCNF substitution of free variable with expression"
/-- Monads that give access to a free variable substitution (`FVarSubst`). -/
class MonadFVarSubst (m : Type → Type) where
  /-- Retrieve the substitution. -/
  getSubst : m FVarSubst

export MonadFVarSubst (getSubst)
/-- Any monad `n` that `m` lifts into inherits the substitution of `m`. -/
instance (m n) [MonadLift m n] [MonadFVarSubst m] : MonadFVarSubst n where
  getSubst := monadLift (getSubst : m FVarSubst)
/-- Map `fvarId` through the substitution provided by the monad (see `normFVarImp`). -/
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId := do
  let subst ← getSubst
  return normFVarImp subst fvarId
/-- Apply the substitution provided by the monad to `e` (see `normExprImp`). -/
@[inline] def normExpr [MonadFVarSubst m] [Monad m] (e : Expr) : m Expr := do
  let subst ← getSubst
  return normExprImp subst e
/-- Apply `normExpr` to every element of `es` using `Array.mapMonoM`. -/
def normExprs [MonadFVarSubst m] [Monad m] (es : Array Expr) : m (Array Expr) :=
  es.mapMonoM fun e => normExpr e
/--
Return `binderName` extended with the current `nextIdx` as a numeric component,
and increment `nextIdx` in the compiler state.
-/
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
  let idx := (← get).nextIdx
  modify fun st => { st with nextIdx := st.nextIdx + 1 }
  return Name.num binderName idx
/--
If `binderName` ends with a numeric component, replace that component with a fresh index
(consuming `nextIdx`). Any other name is returned unchanged and `nextIdx` is not touched.
-/
private def refreshBinderName (binderName : Name) : CompilerM Name :=
  match binderName with
  -- `mkFreshBinderName pre` builds `.num pre nextIdx` and bumps the counter.
  | .num pre _ => mkFreshBinderName pre
  | _ => pure binderName
namespace Internalize
/-- The internalization monad: `CompilerM` extended with a mutable `FVarSubst`. -/
abbrev M := StateRefT FVarSubst CompilerM
/-- In `M`, the substitution is the `FVarSubst` state itself. -/
instance : MonadFVarSubst M where
  getSubst := get
/--
Create a fresh `FVarId`, record the mapping from `fvarId` to it in the substitution,
and return the fresh id.
-/
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
  let freshId ← Lean.mkFreshFVarId
  modify fun subst => subst.insert fvarId (Expr.fvar freshId)
  return freshId
/--
Internalize the parameter `p`: normalize its type with the current substitution, give it a
fresh `FVarId`, and register it in the local context. The binder name is kept.
-/
private def addParam (p : Param) : M Param := do
  -- The type is normalized before the substitution is extended with the new id.
  let type ← normExpr p.type
  let fvarId ← mkNewFVarId p.fvarId
  modifyLCtx (·.addLocalDecl fvarId p.binderName type)
  return { p with fvarId := fvarId, type := type }
mutual

/--
Internalize a local function or join point declaration: normalize its type, refresh its
binder name, internalize its parameters and body, give the declaration a fresh `FVarId`,
and register the result in the local context.
-/
partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
  let type ← normExpr decl.type
  let binderName ← refreshBinderName decl.binderName
  let params ← decl.params.mapM addParam
  let value ← internalizeCode decl.value
  -- The declaration's own id is refreshed only after its body has been processed.
  let fvarId ← mkNewFVarId decl.fvarId
  let decl' : FunDecl :=
    { decl with binderName := binderName, fvarId := fvarId, params := params, type := type, value := value }
  modifyLCtx (·.addFunDecl decl')
  return decl'

/--
Internalize `code`: binders receive fresh `FVarId`s and are added to the local context,
and every reference is rewritten through the substitution accumulated so far.
-/
partial def internalizeCode (code : Code) : M Code := do
  match code with
  | .return fvarId =>
    let fvarId ← normFVar fvarId
    return .return fvarId
  | .unreach type =>
    let type ← normExpr type
    return .unreach type
  | .jmp fvarId args =>
    let fvarId ← normFVar fvarId
    let args ← args.mapM normExpr
    return .jmp fvarId args
  | .let decl k =>
    let binderName ← refreshBinderName decl.binderName
    let type ← normExpr decl.type
    let value ← normExpr decl.value
    let fvarId ← mkNewFVarId decl.fvarId
    modifyLCtx (·.addLetDecl fvarId binderName type value)
    -- The continuation is processed after the new variable is in scope.
    let k ← internalizeCode k
    return .let { decl with binderName := binderName, fvarId := fvarId, type := type, value := value } k
  | .fun decl k =>
    let decl ← internalizeFunDecl decl
    let k ← internalizeCode k
    return .fun decl k
  | .jp decl k =>
    let decl ← internalizeFunDecl decl
    let k ← internalizeCode k
    return .jp decl k
  | .cases c =>
    let resultType ← normExpr c.resultType
    let discr ← normFVar c.discr
    let alts ← c.alts.mapM fun alt => do
      match alt with
      | .alt ctorName params k =>
        let params ← params.mapM addParam
        let k ← internalizeCode k
        return .alt ctorName params k
      | .default k =>
        let k ← internalizeCode k
        return .default k
    return .cases { c with discr := discr, alts := alts, resultType := resultType }

end
end Internalize
/--
Refresh the free variable ids in `code`, and store their declarations in the local context.

`s` is the initial substitution; it is empty by default. The substitution accumulated during
the traversal is discarded (`run'`).
-/
-- This definition is not recursive, so it does not need the `partial` modifier.
def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
  (Internalize.internalizeCode code).run' s
open Internalize in
/--
Internalize the declaration `decl`: normalize its type, internalize its parameters and its
body. `s` is the initial substitution (empty by default); the final substitution is discarded.
-/
def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
  (go decl).run' s
where
  go (decl : Decl) : M Decl := do
    let type ← normExpr decl.type
    let params ← decl.params.mapM addParam
    let value ← internalizeCode decl.value
    return { decl with type := type, params := params, value := value }
/-!
Helper functions for creating LCNF local declarations.
-/
/-- Create a parameter with a fresh `FVarId` and register it in the local context. -/
def mkParam (binderName : Name) (type : Expr) : CompilerM Param := do
  let fvarId ← mkFreshFVarId
  modifyLCtx (·.addLocalDecl fvarId binderName type)
  return { fvarId := fvarId, binderName := binderName, type := type }
/--
Create a `let`-declaration with a fresh `FVarId` and register it in the local context.
The `pure` flag is stored in the resulting declaration as given.
-/
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (pure := true) : CompilerM LetDecl := do
  let fvarId ← mkFreshFVarId
  modifyLCtx (·.addLetDecl fvarId binderName type value)
  -- The parameter `pure` shadows the monadic `pure` here, so `return` is used.
  return { fvarId := fvarId, binderName := binderName, type := type, value := value, pure := pure }
/-- Create a function declaration with a fresh `FVarId` and register it in the local context. -/
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  let fvarId ← mkFreshFVarId
  let newDecl : FunDecl :=
    { fvarId := fvarId, binderName := binderName, type := type, params := params, value := value }
  modifyLCtx (·.addFunDecl newDecl)
  return newDecl
/--
Replace the type of `p` with `type`.
If `type` is pointer-equal to the current type, `p` is returned as is and the local context
is not touched. Otherwise the updated parameter is registered in the local context again.
-/
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
  if ptrEq type p.type then
    return p
  let p' := { p with type := type }
  modifyLCtx (·.addLocalDecl p'.fvarId p'.binderName p'.type)
  return p'
/-- Replace the type of `p` with `type`. The compiled implementation is `updateParamImp`. -/
@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
/--
Replace the type and value of `decl`.
If both are pointer-equal to the current ones, `decl` is returned as is and the local context
is not touched. Otherwise the updated declaration is registered in the local context again.
-/
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
  if !(ptrEq type decl.type) || !(ptrEq value decl.value) then
    let decl' := { decl with type := type, value := value }
    modifyLCtx (·.addLetDecl decl'.fvarId decl'.binderName decl'.type decl'.value)
    return decl'
  else
    return decl
/-- Replace the type and value of `decl`. The compiled implementation is `updateLetDeclImp`. -/
@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
/-- Replace only the value of `decl`, keeping its type (see `LetDecl.update`). -/
def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl :=
  LetDecl.update decl decl.type value
/--
Replace the type, parameters and body of `decl`.
If all three are pointer-equal to the current ones, `decl` is returned as is and the local
context is not touched. Otherwise the updated declaration is registered in the local context again.
-/
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  let unchanged := ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value
  if unchanged then
    return decl
  let decl' : FunDecl := { decl with type := type, params := params, value := value }
  modifyLCtx (·.addFunDecl decl')
  return decl'
/-- Replace the type, parameters and body of `decl`. The compiled implementation is `updateFunDeclImp`. -/
@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
/-- Replace the type and body of `decl`, keeping its parameters (see `FunDeclCore.update`). -/
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl type decl.params value
/-- Replace only the body of `decl`, keeping its type and parameters (see `FunDeclCore.update`). -/
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl decl.type decl.params value
/-- Apply the current substitution to the type of `p` and update `p` accordingly. -/
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (p : Param) : m Param := do
  let type ← normExpr p.type
  p.update type
/-- Apply `normParam` to every element of `ps` using `Array.mapMonoM`. -/
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (ps : Array Param) : m (Array Param) :=
  ps.mapMonoM fun p => normParam p
/-- Apply the current substitution to the type and value of `decl` and update it accordingly. -/
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : LetDecl) : m LetDecl := do
  let type ← normExpr decl.type
  let value ← normExpr decl.value
  decl.update type value
/-- In `ReaderT FVarSubst CompilerM`, the substitution is the reader context. -/
instance : MonadFVarSubst (ReaderT FVarSubst CompilerM) where
  getSubst := read
mutual

/--
Apply the substitution in the reader context to the type, parameters and body of `decl`,
and update the declaration with the results.
-/
partial def normFunDeclImp (decl : FunDecl) : ReaderT FVarSubst CompilerM FunDecl := do
  let newType ← normExpr decl.type
  let newParams ← normParams decl.params
  let newValue ← normCodeImp decl.value
  decl.update newType newParams newValue

/--
Apply the substitution in the reader context to `code`.
`FVarId`s of binders are kept; the `update*` helpers rebuild each node from the
normalized components.
-/
partial def normCodeImp (code : Code) : ReaderT FVarSubst CompilerM Code := do
  match code with
  | .return fvarId =>
    let fvarId ← normFVar fvarId
    return code.updateReturn! fvarId
  | .unreach type =>
    let type ← normExpr type
    return code.updateUnreach! type
  | .jmp fvarId args =>
    let fvarId ← normFVar fvarId
    let args ← normExprs args
    return code.updateJmp! fvarId args
  | .let decl k =>
    let decl ← normLetDecl decl
    let k ← normCodeImp k
    return code.updateLet! decl k
  | .fun decl k =>
    let decl ← normFunDeclImp decl
    let k ← normCodeImp k
    return code.updateFun! decl k
  | .jp decl k =>
    -- Join points are rebuilt with `updateFun!`, exactly like local functions.
    let decl ← normFunDeclImp decl
    let k ← normCodeImp k
    return code.updateFun! decl k
  | .cases c =>
    let resultType ← normExpr c.resultType
    let discr ← normFVar c.discr
    let alts ← c.alts.mapMonoM fun alt => do
      match alt with
      | .alt _ params k =>
        let params ← normParams params
        let k ← normCodeImp k
        return alt.updateAlt! params k
      | .default k =>
        let k ← normCodeImp k
        return alt.updateCode k
    return code.updateCases! resultType discr alts

end
/-- Apply the substitution provided by the monad to `decl` (see `normFunDeclImp`). -/
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : FunDecl) : m FunDecl := do
  let subst ← getSubst
  normFunDeclImp decl subst
/--
Similar to `internalize`, but does not refresh `FVarId`s:
the substitution provided by the monad is applied to `code` (see `normCodeImp`).
-/
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (code : Code) : m Code := do
  let subst ← getSubst
  normCodeImp code subst
/-- Apply the substitution `s` to the expression `e`. -/
def replaceExprFVars (e : Expr) (s : FVarSubst) : CompilerM Expr :=
  -- `normExpr` run in `ReaderT FVarSubst CompilerM` with context `s` amounts to this call.
  pure (normExprImp s e)
/-- Apply the substitution `s` to `code`, without refreshing any `FVarId`. -/
def replaceFVars (code : Code) (s : FVarSubst) : CompilerM Code :=
  -- `normCode` run in `ReaderT FVarSubst CompilerM` with context `s` amounts to this call.
  (normCodeImp code).run s
/-- Replace the free variable `fvarId` with `fvarId'` in `code`. -/
def replaceFVar (code : Code) (fvarId fvarId' : FVarId) : CompilerM Code :=
  replaceFVars code <| ({} : FVarSubst).insert fvarId (.fvar fvarId')
/-- Create a fresh binder name with prefix `_jp` (see `mkFreshBinderName`). -/
def mkFreshJpName : CompilerM Name :=
  mkFreshBinderName `_jp
/-- Create a parameter of the given type whose binder name is fresh and has prefix `_y`. -/
def mkAuxParam (type : Expr) : CompilerM Param := do
  let binderName ← mkFreshBinderName `_y
  mkParam binderName type
/--
Reset the compiler state to its default value (fresh local context), and internalize the
given declarations. The name index `nextIdx` is reset to `1` before each declaration.
-/
def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do
  set ({} : CompilerM.State)
  decl.mapM fun d => do
    modify fun st => { st with nextIdx := 1 }
    d.internalize
/--
Run `x` in `CoreM` starting from the state `s` (the default state if omitted).
The final state is discarded.
-/
def CompilerM.run (x : CompilerM α) (s : State := {}) : CoreM α :=
  x.run' s
end Lean.Compiler.LCNF