feat: more LCNF update functions

and bug fixes at CSE
This commit is contained in:
Leonardo de Moura 2022-08-28 19:00:49 -07:00
parent 5552d610e8
commit 062d4728a1
4 changed files with 50 additions and 21 deletions

View file

@ -126,13 +126,19 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
@[implementedBy updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code
/-
@[inline] private unsafe def updateCases (cases : Cases) (decl' : FunDecl) (k' : Code) : Cases :=
@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code :=
match c with
| .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
| .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k'
| .return fvarId => if fvarId == fvarId' then c else .return fvarId'
| _ => unreachable!
-/
@[implementedBy updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code :=
match c with
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
| _ => unreachable!
@[implementedBy updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code
def Code.isDecl : Code → Bool
| .let .. | .fun .. | .jp .. => true

View file

@ -89,7 +89,9 @@ where
return alt.updateAlt! ps (← go k)
| .default k => withNewScope do return alt.updateCode (← go k)
return code.updateCases! resultType discr alts
| .return .. | .jmp .. | .unreach .. => return code
| .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId)
| .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg)
| .unreach .. => return code
/--
Common sub-expression elimination

View file

@ -6,29 +6,33 @@ Authors: Leonardo de Moura
import Lean.Compiler.LCNF.CompilerM
namespace Lean.Compiler.LCNF
namespace ElimDead
abbrev UsedSet := FVarIdHashSet
abbrev UsedLocalDecls := FVarIdHashSet
/--
Collect set of (let) free variables in a LCNF value.
This code exploits the LCNF property that local declarations do not occur in types.
-/
def collectExpr (s : UsedSet) (e : Expr) : UsedSet :=
match e with
| .proj _ _ e => collectExpr s e
| .forallE .. => s
| .lam _ _ b _ => collectExpr s b
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
| .app f a => collectExpr (collectExpr s a) f
| .mdata _ b => collectExpr s b
| .fvar fvarId => s.insert fvarId
| _ => s
def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
go s e
where
go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
match e with
| .proj _ _ e => go s e
| .forallE .. => s
| .lam _ _ b _ => go s b
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
| .app f a => go (go s a) f
| .mdata _ b => go s b
| .fvar fvarId => s.insert fvarId
| _ => s
abbrev M := StateRefT UsedSet CompilerM
namespace ElimDead
abbrev M := StateRefT UsedLocalDecls CompilerM
private abbrev collectExprM (e : Expr) : M Unit :=
modify (collectExpr · e)
modify (collectLocalDecls · e)
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
modify (·.insert fvarId)

View file

@ -6,6 +6,7 @@ Authors: Leonardo de Moura
import Lean.Util.Recognizers
import Lean.Compiler.InlineAttrs
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ElimDead
import Lean.Compiler.LCNF.Stage1
namespace Lean.Compiler.LCNF
@ -102,6 +103,10 @@ structure State where
-/
subst : FVarSubst := {}
/--
Track used local declarations to be able to eliminate dead variables.
-/
used : UsedLocalDecls := {}
/--
Mapping used to decide whether a local function declaration must be inlined or not.
-/
funDeclInfoMap : FunDeclInfoMap := {}
@ -184,10 +189,22 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
markSimplified
return mkAppN f e.getAppArgs
def incVisited : SimpM Unit :=
modify fun s => { s with visited := s.visited + 1 }
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
modify fun s => { s with used := s.used.insert fvarId }
mutual
partial def simp (code : Code) : SimpM Code :=
partial def simp (code : Code) : SimpM Code := do
-- TODO
incVisited
match code with
| .return fvarId =>
markUsedFVar fvarId
return code
| .unreach .. => return code
| _ => return code
end