feat: more update functions for LCNF
This commit is contained in:
parent
bdf89b4d85
commit
11c8253f6c
5 changed files with 48 additions and 10 deletions
|
|
@ -73,6 +73,26 @@ def AltCore.getCode : Alt → Code
|
|||
| .default k => k
|
||||
| .alt _ _ k => k
|
||||
|
||||
private unsafe def updateAltCodeImp (alt : Alt) (c : Code) : Alt :=
|
||||
match alt with
|
||||
| .default k => if ptrEq k c then alt else AltCore.default c
|
||||
| .alt ctorName ps k => if ptrEq k c then alt else AltCore.alt ctorName ps c
|
||||
|
||||
@[implementedBy updateAltCodeImp] opaque AltCore.updateCode (alt : Alt) (c : Code) : Alt
|
||||
|
||||
@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code :=
|
||||
match c with
|
||||
| .let decl k => if ptrEq k k' then c else .let decl k'
|
||||
| .fun decl k => if ptrEq k k' then c else .fun decl k'
|
||||
| .jp decl k => if ptrEq k k' then c else .jp decl k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implementedBy updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code
|
||||
|
||||
def Code.isDecl : Code → Bool
|
||||
| .let .. | .fun .. | .jp .. => true
|
||||
| _ => false
|
||||
|
||||
partial def Code.size (c : Code) : Nat :=
|
||||
go c 0
|
||||
where
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ where
|
|||
| .jp decl k =>
|
||||
let value ← go decl.value
|
||||
let type ← value.inferParamType decl.params
|
||||
let decl := { decl with value, type }
|
||||
let decl ← decl.update' type value
|
||||
withReader (fun s => s.insert decl.fvarId) do
|
||||
return .jp decl (← go k)
|
||||
| .cases c =>
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ where
|
|||
let type := (← getSubst).applyToExpr decl.type
|
||||
let params ← (← getSubst).applyToParams decl.params
|
||||
let value ← withNewScope do go decl.value
|
||||
return { decl with type, params, value }
|
||||
decl.update type params value
|
||||
|
||||
go (code : Code) : M Code := do
|
||||
match code with
|
||||
|
|
|
|||
|
|
@ -96,8 +96,8 @@ open Internalize in
|
|||
/--
|
||||
Refresh free variables ids in `code`, and store their declarations in the local context.
|
||||
-/
|
||||
partial def Code.internalize (code : Code) : CompilerM Code :=
|
||||
go code |>.run' {}
|
||||
partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
|
||||
go code |>.run' s
|
||||
where
|
||||
goFunDecl (decl : FunDecl) : M FunDecl := do
|
||||
let type ← translateExpr decl.type
|
||||
|
|
@ -159,7 +159,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
|
|||
modifyLCtx fun lctx => lctx.addLocalDecl p.fvarId p.binderName p.type
|
||||
return p
|
||||
|
||||
@[implementedBy updateParamImp] opaque updateParam (p : Param) (type : Expr) : CompilerM Param
|
||||
@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||||
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
|
|
@ -169,16 +169,32 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr
|
|||
modifyLCtx fun lctx => lctx.addLetDecl decl.fvarId decl.binderName decl.type decl.value
|
||||
return decl
|
||||
|
||||
@[implementedBy updateLetDeclImp] opaque updateLetDecl (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
|
||||
@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
|
||||
|
||||
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
let decl := { decl with type, params, value }
|
||||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
return decl
|
||||
|
||||
@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
|
||||
|
||||
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
|
||||
decl.update type decl.params value
|
||||
|
||||
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
|
||||
decl.update decl.type decl.params value
|
||||
|
||||
def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param :=
|
||||
updateParam p (s.applyToExpr p.type)
|
||||
p.update (s.applyToExpr p.type)
|
||||
|
||||
def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) :=
|
||||
ps.mapM s.applyToParam
|
||||
|
||||
def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl :=
|
||||
updateLetDecl decl (s.applyToExpr decl.type) (s.applyToExpr decl.value)
|
||||
decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value)
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
let declName := .num binderName (← get).nextIdx
|
||||
|
|
|
|||
|
|
@ -78,11 +78,13 @@ mutual
|
|||
| .fun decl k =>
|
||||
withCheckpoint do
|
||||
let value ← withParams decl.params <| pullDecls decl.value
|
||||
withFVar decl.fvarId do return .fun { decl with value } (← pullDecls k)
|
||||
let decl ← decl.updateValue value
|
||||
withFVar decl.fvarId do return .fun decl (← pullDecls k)
|
||||
| .jp decl k =>
|
||||
withCheckpoint do
|
||||
let value ← withParams decl.params <| pullDecls decl.value
|
||||
withFVar decl.fvarId do return .jp { decl with value } (← pullDecls k)
|
||||
let decl ← decl.updateValue value
|
||||
withFVar decl.fvarId do return .jp decl (← pullDecls k)
|
||||
| _ => return code
|
||||
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue