feat: more LCNF update functions
and bug fixes at CSE
This commit is contained in:
parent
5552d610e8
commit
062d4728a1
4 changed files with 50 additions and 21 deletions
|
|
@ -126,13 +126,19 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
|
|||
|
||||
@[implementedBy updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code
|
||||
|
||||
/-
|
||||
@[inline] private unsafe def updateCases (cases : Cases) (decl' : FunDecl) (k' : Code) : Cases :=
|
||||
@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code :=
|
||||
match c with
|
||||
| .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
|
||||
| .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k'
|
||||
| .return fvarId => if fvarId == fvarId' then c else .return fvarId'
|
||||
| _ => unreachable!
|
||||
-/
|
||||
|
||||
@[implementedBy updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
|
||||
|
||||
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code :=
|
||||
match c with
|
||||
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implementedBy updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code
|
||||
|
||||
def Code.isDecl : Code → Bool
|
||||
| .let .. | .fun .. | .jp .. => true
|
||||
|
|
|
|||
|
|
@ -89,7 +89,9 @@ where
|
|||
return alt.updateAlt! ps (← go k)
|
||||
| .default k => withNewScope do return alt.updateCode (← go k)
|
||||
return code.updateCases! resultType discr alts
|
||||
| .return .. | .jmp .. | .unreach .. => return code
|
||||
| .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId)
|
||||
| .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg)
|
||||
| .unreach .. => return code
|
||||
|
||||
/--
|
||||
Common sub-expression elimination
|
||||
|
|
|
|||
|
|
@ -6,29 +6,33 @@ Authors: Leonardo de Moura
|
|||
import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace ElimDead
|
||||
|
||||
abbrev UsedSet := FVarIdHashSet
|
||||
abbrev UsedLocalDecls := FVarIdHashSet
|
||||
|
||||
/--
|
||||
Collect set of (let) free variables in a LCNF value.
|
||||
This code exploits the LCNF property that local declarations do not occur in types.
|
||||
-/
|
||||
def collectExpr (s : UsedSet) (e : Expr) : UsedSet :=
|
||||
match e with
|
||||
| .proj _ _ e => collectExpr s e
|
||||
| .forallE .. => s
|
||||
| .lam _ _ b _ => collectExpr s b
|
||||
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
|
||||
| .app f a => collectExpr (collectExpr s a) f
|
||||
| .mdata _ b => collectExpr s b
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| _ => s
|
||||
def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
|
||||
go s e
|
||||
where
|
||||
go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
|
||||
match e with
|
||||
| .proj _ _ e => go s e
|
||||
| .forallE .. => s
|
||||
| .lam _ _ b _ => go s b
|
||||
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
|
||||
| .app f a => go (go s a) f
|
||||
| .mdata _ b => go s b
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| _ => s
|
||||
|
||||
abbrev M := StateRefT UsedSet CompilerM
|
||||
namespace ElimDead
|
||||
|
||||
abbrev M := StateRefT UsedLocalDecls CompilerM
|
||||
|
||||
private abbrev collectExprM (e : Expr) : M Unit :=
|
||||
modify (collectExpr · e)
|
||||
modify (collectLocalDecls · e)
|
||||
|
||||
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
|
||||
modify (·.insert fvarId)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
import Lean.Util.Recognizers
|
||||
import Lean.Compiler.InlineAttrs
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.ElimDead
|
||||
import Lean.Compiler.LCNF.Stage1
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
|
@ -102,6 +103,10 @@ structure State where
|
|||
-/
|
||||
subst : FVarSubst := {}
|
||||
/--
|
||||
Track used local declarations to be able to eliminate dead variables.
|
||||
-/
|
||||
used : UsedLocalDecls := {}
|
||||
/--
|
||||
Mapping used to decide whether a local function declaration must be inlined or not.
|
||||
-/
|
||||
funDeclInfoMap : FunDeclInfoMap := {}
|
||||
|
|
@ -184,10 +189,22 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
|||
markSimplified
|
||||
return mkAppN f e.getAppArgs
|
||||
|
||||
def incVisited : SimpM Unit :=
|
||||
modify fun s => { s with visited := s.visited + 1 }
|
||||
|
||||
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with used := s.used.insert fvarId }
|
||||
|
||||
mutual
|
||||
|
||||
partial def simp (code : Code) : SimpM Code :=
|
||||
partial def simp (code : Code) : SimpM Code := do
|
||||
-- TODO
|
||||
incVisited
|
||||
match code with
|
||||
| .return fvarId =>
|
||||
markUsedFVar fvarId
|
||||
return code
|
||||
| .unreach .. => return code
|
||||
| _ => return code
|
||||
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue