diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 89618fcebb..2d9936bfbd 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -126,13 +126,19 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al @[implementedBy updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code -/- -@[inline] private unsafe def updateCases (cases : Cases) (decl' : FunDecl) (k' : Code) : Cases := +@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code := match c with - | .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k' - | .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k' + | .return fvarId => if fvarId == fvarId' then c else .return fvarId' | _ => unreachable! --/ + +@[implementedBy updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code + +@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code := + match c with + | .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args' + | _ => unreachable! + +@[implementedBy updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code def Code.isDecl : Code → Bool | .let .. | .fun .. | .jp .. => true diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index 8939cec447..b7a4d1193c 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -89,7 +89,9 @@ where return alt.updateAlt! ps (← go k) | .default k => withNewScope do return alt.updateCode (← go k) return code.updateCases! resultType discr alts - | .return .. | .jmp .. | .unreach .. => return code + | .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId) + | .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg) + | .unreach .. => return code /-- Common sub-expression elimination diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index cf77ae963d..a7343de68e 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -6,29 +6,33 @@ Authors: Leonardo de Moura import Lean.Compiler.LCNF.CompilerM namespace Lean.Compiler.LCNF -namespace ElimDead -abbrev UsedSet := FVarIdHashSet +abbrev UsedLocalDecls := FVarIdHashSet /-- Collect set of (let) free variables in a LCNF value. This code exploits the LCNF property that local declarations do not occur in types. -/ -def collectExpr (s : UsedSet) (e : Expr) : UsedSet := - match e with - | .proj _ _ e => collectExpr s e - | .forallE .. => s - | .lam _ _ b _ => collectExpr s b - | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations - | .app f a => collectExpr (collectExpr s a) f - | .mdata _ b => collectExpr s b - | .fvar fvarId => s.insert fvarId - | _ => s +def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls := + go s e +where + go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls := + match e with + | .proj _ _ e => go s e + | .forallE .. => s + | .lam _ _ b _ => go s b + | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations + | .app f a => go (go s a) f + | .mdata _ b => go s b + | .fvar fvarId => s.insert fvarId + | _ => s -abbrev M := StateRefT UsedSet CompilerM +namespace ElimDead + +abbrev M := StateRefT UsedLocalDecls CompilerM private abbrev collectExprM (e : Expr) : M Unit := - modify (collectExpr · e) + modify (collectLocalDecls · e) private abbrev collectFVarM (fvarId : FVarId) : M Unit := modify (·.insert fvarId) diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 5145ad6856..5dd00a9173 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura import Lean.Util.Recognizers import Lean.Compiler.InlineAttrs import Lean.Compiler.LCNF.CompilerM +import Lean.Compiler.LCNF.ElimDead import Lean.Compiler.LCNF.Stage1 namespace Lean.Compiler.LCNF @@ -102,6 +103,10 @@ structure State where -/ subst : FVarSubst := {} /-- + Track used local declarations to be able to eliminate dead variables. + -/ + used : UsedLocalDecls := {} + /-- Mapping used to decide whether a local function declaration must be inlined or not. -/ funDeclInfoMap : FunDeclInfoMap := {} @@ -184,10 +189,22 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do markSimplified return mkAppN f e.getAppArgs +def incVisited : SimpM Unit := + modify fun s => { s with visited := s.visited + 1 } + +def markUsedFVar (fvarId : FVarId) : SimpM Unit := + modify fun s => { s with used := s.used.insert fvarId } + mutual -partial def simp (code : Code) : SimpM Code := +partial def simp (code : Code) : SimpM Code := do + -- TODO + incVisited match code with + | .return fvarId => + markUsedFVar fvarId + return code + | .unreach .. => return code | _ => return code end