diff --git a/src/Lean/Compiler/IR.lean b/src/Lean/Compiler/IR.lean index 7c4a0e21d7..0c2fb41768 100644 --- a/src/Lean/Compiler/IR.lean +++ b/src/Lean/Compiler/IR.lean @@ -11,7 +11,6 @@ public import Lean.Compiler.IR.Basic public import Lean.Compiler.IR.Format public import Lean.Compiler.IR.CompilerM public import Lean.Compiler.IR.PushProj -public import Lean.Compiler.IR.ElimDeadVars public import Lean.Compiler.IR.SimpCase public import Lean.Compiler.IR.NormIds public import Lean.Compiler.IR.Checker @@ -41,8 +40,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do logDecls `init decls checkDecls decls let mut decls := decls - decls := decls.map Decl.elimDead - logDecls `elim_dead decls decls := decls.map Decl.simpCase logDecls `simp_case decls decls := decls.map Decl.normalizeIds diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index 393293de4c..ff44cdba84 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -6,67 +6,82 @@ Authors: Leonardo de Moura module prelude -public import Lean.Compiler.LCNF.CompilerM +public import Lean.Compiler.LCNF.PassManager -public section +/-! +This module implements a pass that does a syntactic use-def check for all let/fun/jp bindings and +removes them if they are unused. Note that in impure mode not all unused let bindings can be removed safely +so we opt for a safe subset. +-/ namespace Lean.Compiler.LCNF -abbrev UsedLocalDecls := FVarIdHashSet +public abbrev UsedLocalDecls := FVarIdHashSet /-- Collect set of (let) free variables in a LCNF value. This code exploits the LCNF property that local declarations do not occur in types. -/ - -def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls := +def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg pu) : UsedLocalDecls := match arg with | .fvar fvarId => s.insert fvarId -- Locally declared variables do not occur in types. - | .type _ | .erased => s + | .type _ _ | .erased => s -def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls := +def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg pu)) : UsedLocalDecls := args.foldl (init := s) collectLocalDeclsArg -def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls := +def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue pu) : UsedLocalDecls := match e with | .erased | .lit .. => s - | .proj _ _ fvarId => s.insert fvarId - | .const _ _ args => collectLocalDeclsArgs s args - | .fvar fvarId args => collectLocalDeclsArgs (s.insert fvarId) args - -namespace ElimDead + | .proj _ _ fvarId _ | .reset _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ + | .oproj _ fvarId _ => s.insert fvarId + | .const _ _ args _ => collectLocalDeclsArgs s args + | .fvar fvarId args | .reuse fvarId _ _ args _ => collectLocalDeclsArgs (s.insert fvarId) args + | .fap _ args _ | .pap _ args _ | .ctor _ args _ => collectLocalDeclsArgs s args abbrev M := StateRefT UsedLocalDecls CompilerM -private abbrev collectArgM (arg : Arg .pure) : M Unit := +abbrev collectArgM (arg : Arg pu) : M Unit := modify (collectLocalDeclsArg · arg) -private abbrev collectLetValueM (e : LetValue .pure) : M Unit := +abbrev collectLetValueM (e : LetValue pu) : M Unit := modify (collectLocalDeclsLetValue · e) -private abbrev collectFVarM (fvarId : FVarId) : M Unit := +abbrev collectFVarM (fvarId : FVarId) : M Unit := modify (·.insert fvarId) +def LetValue.safeToElim (val : LetValue pu) : Bool := + match pu with + | .pure => true + | .impure => + match val with + | .ctor .. | .reset .. | .reuse .. | .oproj .. | .uproj .. | .sproj .. | .lit .. | .pap .. + -- TODO | .box .. | .unbox .. | .isShared .. + | .erased .. => true + -- 0-ary full applications are considered constants + | .fap _ args => args.isEmpty + | .fvar .. => false + mutual -partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do - let value ← elimDead funDecl.value +partial def visitFunDecl (funDecl : FunDecl pu) : M (FunDecl pu) := do + let value ← funDecl.value.elimDead funDecl.updateValue value -partial def elimDead (code : Code .pure) : M (Code .pure) := do +partial def Code.elimDead (code : Code pu) : M (Code pu) := do match code with | .let decl k => - let k ← elimDead k - if (← get).contains decl.fvarId then + let k ← k.elimDead + if (← get).contains decl.fvarId || !decl.value.safeToElim then /- Remark: we don't need to collect `decl.type` because LCNF local declarations do not occur in types. -/ collectLetValueM decl.value return code.updateCont! k else eraseLetDecl decl return k - | .fun decl k | .jp decl k => - let k ← elimDead k + | .fun decl k _ | .jp decl k => + let k ← k.elimDead if (← get).contains decl.fvarId then let decl ← visitFunDecl decl return code.updateFun! decl k @@ -74,22 +89,26 @@ partial def elimDead (code : Code .pure) : M (Code .pure) := do eraseFunDecl decl return k | .cases c => - let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← elimDead alt.getCode) + let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← alt.getCode.elimDead) collectFVarM c.discr return code.updateAlts! alts | .return fvarId => collectFVarM fvarId; return code | .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code | .unreach .. => return code + | .uset var _ y k _ | .sset var _ _ y _ k _ => + collectFVarM var + collectFVarM y + return code.updateCont! (← k.elimDead) end -end ElimDead +def Decl.elimDead (decl : Decl pu) : CompilerM (Decl pu) := do + return { decl with value := (← decl.value.mapCodeM fun code => code.elimDead.run' {}) } --- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though -def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) := - ElimDead.elimDead code |>.run' {} +public def elimDeadVars (phase : Phase) (occurrence : Nat) : Pass := + Pass.mkPerDeclaration `elimDeadVars phase Decl.elimDead occurrence -def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do - return { decl with value := (← decl.value.mapCodeM Code.elimDead) } +builtin_initialize + registerTraceClass `Compiler.elimDeadVars (inherited := true) end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 652d78a914..20dfe94ae5 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -145,6 +145,7 @@ def builtinPassManager : PassManager := { saveImpure, -- End of impure phase pushProj (occurrence := 0), insertResetReuse, + elimDeadVars (phase := .impure) (occurrence := 0), inferVisibility (phase := .impure), ] } diff --git a/tests/lean/run/elim_dead_vars.lean b/tests/lean/run/elim_dead_vars.lean new file mode 100644 index 0000000000..674481854d --- /dev/null +++ b/tests/lean/run/elim_dead_vars.lean @@ -0,0 +1,44 @@ +prelude +import Init.Data.Option + +/-! This test asserts that the `elimDeadVars` pass is able to eliminate dead projections in the + impure phase correctly. -/ + +/-- +trace: [Compiler.saveMono] size: 5 + def isNone x : Bool := + cases x : Bool + | Option.none => + let _x.1 := true; + return _x.1 + | Option.some val.2 => + let _x.3 := false; + return _x.3 +[Compiler.pushProj] size: 6 + def isNone x : UInt8 := + cases x : UInt8 + | Option.none => + let _x.1 := 1; + return _x.1 + | Option.some => + let val.2 := proj[0] x; + let _x.3 := 0; + return _x.3 +[Compiler.elimDeadVars] size: 5 + def isNone x : UInt8 := + cases x : UInt8 + | Option.none => + let _x.1 := 1; + return _x.1 + | Option.some => + let _x.2 := 0; + return _x.2 +-/ +#guard_msgs in +set_option trace.Compiler.saveMono true in +set_option trace.Compiler.pushProj true in +set_option trace.Compiler.elimDeadVars true in +def isNone (x : Option Nat) : Bool := + match x with + | some _ => false + | none => true diff --git a/tests/lean/updateExprIssue.lean b/tests/lean/updateExprIssue.lean index ea9cced25c..012f9a0c9b 100644 --- a/tests/lean/updateExprIssue.lean +++ b/tests/lean/updateExprIssue.lean @@ -11,7 +11,7 @@ unsafe def tst1 : MetaM Unit := do #eval tst1 -set_option trace.compiler.ir.init true +set_option trace.Compiler.saveMono true def sefFn (e : Expr) (f : Expr) : Expr := match e with | .app _ a => e.updateApp! f a diff --git a/tests/lean/updateExprIssue.lean.expected.out b/tests/lean/updateExprIssue.lean.expected.out index 6bc808fe89..5a80f1bd0c 100644 --- a/tests/lean/updateExprIssue.lean.expected.out +++ b/tests/lean/updateExprIssue.lean.expected.out @@ -1,26 +1,23 @@ -[Compiler.IR] [init] - def sefFn (x_1 : obj) (x_2 : obj) : obj := - case x_1 : obj of - Lean.Expr.app._impl → - let x_3 : u64 := sproj[2, 0] x_1; - let x_4 : obj := proj[0] x_1; - let x_5 : obj := proj[1] x_1; - block_6 (x_7 : u8) := - case x_7 : u8 of - Bool.false → - let x_8 : obj := Lean.Expr.app._override x_2 x_5; - ret x_8 - Bool.true → - ret x_1; - let x_9 : usize := ptrAddrUnsafe ◾ x_4; - let x_10 : usize := ptrAddrUnsafe ◾ x_2; - let x_11 : u8 := USize.decEq x_9 x_10; - case x_11 : u8 of - Bool.false → - jmp block_6 x_11 - Bool.true → - let x_12 : usize := ptrAddrUnsafe ◾ x_5; - let x_13 : u8 := USize.decEq x_12 x_12; - jmp block_6 x_13 - default → - ret x_1 +[Compiler.saveMono] size: 17 + def sefFn e f : Expr := + cases e : Expr + | Lean.Expr.app._impl data fn.1 arg.2 => + jp _jp.3 _y.4 : Expr := + cases _y.4 : Expr + | Bool.false => + let _x.5 := Expr.app._override f arg.2; + return _x.5 + | Bool.true => + return e; + let _x.6 := ptrAddrUnsafe ◾ fn.1; + let _x.7 := ptrAddrUnsafe ◾ f; + let _x.8 := USize.decEq _x.6 _x.7; + cases _x.8 : Expr + | Bool.false => + goto _jp.3 _x.8 + | Bool.true => + let _x.9 := ptrAddrUnsafe ◾ arg.2; + let _x.10 := USize.decEq _x.9 _x.9; + goto _jp.3 _x.10 + | _ => + return e