feat: more update functions for LCNF

This commit is contained in:
Leonardo de Moura 2022-08-27 08:48:52 -07:00
parent bdf89b4d85
commit 11c8253f6c
5 changed files with 48 additions and 10 deletions

View file

@ -73,6 +73,26 @@ def AltCore.getCode : Alt → Code
| .default k => k
| .alt _ _ k => k
private unsafe def updateAltCodeImp (alt : Alt) (c : Code) : Alt :=
match alt with
| .default k => if ptrEq k c then alt else AltCore.default c
| .alt ctorName ps k => if ptrEq k c then alt else AltCore.alt ctorName ps c
@[implementedBy updateAltCodeImp] opaque AltCore.updateCode (alt : Alt) (c : Code) : Alt
@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code :=
match c with
| .let decl k => if ptrEq k k' then c else .let decl k'
| .fun decl k => if ptrEq k k' then c else .fun decl k'
| .jp decl k => if ptrEq k k' then c else .jp decl k'
| _ => unreachable!
@[implementedBy updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code
def Code.isDecl : Code → Bool
| .let .. | .fun .. | .jp .. => true
| _ => false
partial def Code.size (c : Code) : Nat :=
go c 0
where

View file

@ -26,7 +26,7 @@ where
| .jp decl k =>
let value ← go decl.value
let type ← value.inferParamType decl.params
let decl := { decl with value, type }
let decl ← decl.update' type value
withReader (fun s => s.insert decl.fvarId) do
return .jp decl (← go k)
| .cases c =>

View file

@ -45,7 +45,7 @@ where
let type := (← getSubst).applyToExpr decl.type
let params ← (← getSubst).applyToParams decl.params
let value ← withNewScope do go decl.value
return { decl with type, params, value }
decl.update type params value
go (code : Code) : M Code := do
match code with

View file

@ -96,8 +96,8 @@ open Internalize in
/--
Refresh free variables ids in `code`, and store their declarations in the local context.
-/
partial def Code.internalize (code : Code) : CompilerM Code :=
go code |>.run' {}
partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
go code |>.run' s
where
goFunDecl (decl : FunDecl) : M FunDecl := do
let type ← translateExpr decl.type
@ -159,7 +159,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
modifyLCtx fun lctx => lctx.addLocalDecl p.fvarId p.binderName p.type
return p
@[implementedBy updateParamImp] opaque updateParam (p : Param) (type : Expr) : CompilerM Param
@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
if ptrEq type decl.type && ptrEq value decl.value then
@ -169,16 +169,32 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr
modifyLCtx fun lctx => lctx.addLetDecl decl.fvarId decl.binderName decl.type decl.value
return decl
@[implementedBy updateLetDeclImp] opaque updateLetDecl (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
return decl
else
let decl := { decl with type, params, value }
modifyLCtx fun lctx => lctx.addFunDecl decl
return decl
@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
decl.update type decl.params value
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
decl.update decl.type decl.params value
def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param :=
updateParam p (s.applyToExpr p.type)
p.update (s.applyToExpr p.type)
def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) :=
ps.mapM s.applyToParam
def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl :=
updateLetDecl decl (s.applyToExpr decl.type) (s.applyToExpr decl.value)
decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value)
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
let declName := .num binderName (← get).nextIdx

View file

@ -78,11 +78,13 @@ mutual
| .fun decl k =>
withCheckpoint do
let value ← withParams decl.params <| pullDecls decl.value
withFVar decl.fvarId do return .fun { decl with value } (← pullDecls k)
let decl ← decl.updateValue value
withFVar decl.fvarId do return .fun decl (← pullDecls k)
| .jp decl k =>
withCheckpoint do
let value ← withParams decl.params <| pullDecls decl.value
withFVar decl.fvarId do return .jp { decl with value } (← pullDecls k)
let decl ← decl.updateValue value
withFVar decl.fvarId do return .jp decl (← pullDecls k)
| _ => return code
end