feat: helper LCNF functions
This commit is contained in:
parent
cd49e564cf
commit
3e2f8c61ec
3 changed files with 72 additions and 3 deletions
|
|
@ -7,12 +7,24 @@ import Lean.Expr
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
/-!
|
||||
# Lean Compiler Normal Form (LCNF)
|
||||
|
||||
It is based on the [A-normal form](https://en.wikipedia.org/wiki/A-normal_form),
|
||||
and the approach described in the paper
|
||||
[Compiling without continuations](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/compiling-without-continuations.pdf).
|
||||
|
||||
-/
|
||||
|
||||
structure Param where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
deriving Inhabited
|
||||
|
||||
def Param.toExpr (p : Param) : Expr :=
|
||||
.fvar p.fvarId
|
||||
|
||||
inductive AltCore (Code : Type) where
|
||||
| alt (ctorName : Name) (params : Array Param) (code : Code)
|
||||
| default (code : Code)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ structure CompilerM.State where
|
|||
and other constructs in as we move through `Expr`s.
|
||||
-/
|
||||
lctx : LCtx := {}
|
||||
/-- Next auxiliary variable suffix -/
|
||||
nextIdx : Nat := 1
|
||||
deriving Inhabited
|
||||
|
||||
abbrev CompilerM := StateRefT CompilerM.State CoreM
|
||||
|
|
@ -33,6 +35,9 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
|
|||
let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}"
|
||||
return decl
|
||||
|
||||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||||
modify fun s => { s with lctx := f s.lctx }
|
||||
|
||||
namespace Internalize
|
||||
|
||||
structure State where
|
||||
|
|
@ -66,9 +71,6 @@ where
|
|||
@[inline] private def translate (e : Expr) : M Expr :=
|
||||
return translateCore (← get) e
|
||||
|
||||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||||
modify fun s => { s with lctx := f s.lctx }
|
||||
|
||||
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
|
||||
let fvarId' ← Lean.mkFreshFVarId
|
||||
modify fun s => { s with fvarIdMap := s.fvarIdMap.insert fvarId fvarId' }
|
||||
|
|
@ -121,4 +123,35 @@ where
|
|||
| .default k => return .default (← go k)
|
||||
return .cases { c with discr, alts }
|
||||
|
||||
/-!
|
||||
Helper functions for creating LCNF local declarations.
|
||||
-/
|
||||
|
||||
def mkParam (binderName : Name) (type : Expr) : CompilerM Param := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
modifyLCtx fun lctx => lctx.addLocalDecl fvarId binderName type
|
||||
return { fvarId, binderName, type }
|
||||
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
modifyLCtx fun lctx => lctx.addLetDecl fvarId binderName type value
|
||||
return { fvarId, binderName, type, value }
|
||||
|
||||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let funDecl := { fvarId, binderName, type, params, value }
|
||||
modifyLCtx fun lctx => lctx.addFunDecl funDecl
|
||||
return funDecl
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
let declName := .num binderName (← get).nextIdx
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
return declName
|
||||
|
||||
def mkFreshJpName : CompilerM Name := do
|
||||
mkFreshBinderName `_jp
|
||||
|
||||
def mkAuxParam (type : Expr) : CompilerM Param := do
|
||||
mkParam (← mkFreshBinderName `_y) type
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -184,4 +184,28 @@ def AltCore.inferType (alt : Alt) : CompilerM Expr := do
|
|||
| .default k => k.inferType
|
||||
| .alt _ params k => k.inferParamType params
|
||||
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
|
||||
if e.isFVar then
|
||||
return e
|
||||
else
|
||||
let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
|
||||
return .fvar letDecl.fvarId
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
|
||||
InferType.mkForallParams params type |>.run {}
|
||||
|
||||
def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do
|
||||
let type ← mkForallParams params (← code.inferType)
|
||||
let binderName ← mkFreshBinderName prefixName
|
||||
mkFunDecl binderName type params code
|
||||
|
||||
def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
mkAuxFunDecl params code prefixName
|
||||
|
||||
def mkAuxJpDecl' (fvarId : FVarId) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
let y ← mkFreshBinderName `_y
|
||||
let yType ← inferType (.fvar fvarId)
|
||||
let params := #[{ fvarId, binderName := y, type := yType }]
|
||||
mkAuxFunDecl params code prefixName
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue