From a3e803df49ff3ff280645e0fb028bc086b0957be Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 16 Oct 2020 16:57:58 -0700 Subject: [PATCH] chore: move to new frontend --- src/Lean/Compiler/IR/ElimDeadBranches.lean | 148 ++++++++++----------- 1 file changed, 72 insertions(+), 76 deletions(-) diff --git a/src/Lean/Compiler/IR/ElimDeadBranches.lean b/src/Lean/Compiler/IR/ElimDeadBranches.lean index eae63176ba..9360946bcd 100644 --- a/src/Lean/Compiler/IR/ElimDeadBranches.lean +++ b/src/Lean/Compiler/IR/ElimDeadBranches.lean @@ -1,3 +1,4 @@ +#lang lean4 /- Copyright (c) 2019 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. @@ -7,9 +8,7 @@ import Lean.Compiler.IR.Format import Lean.Compiler.IR.Basic import Lean.Compiler.IR.CompilerM -namespace Lean -namespace IR -namespace UnreachableBranches +namespace Lean.IR.UnreachableBranches /-- Value used in the abstract interpreter -/ inductive Value @@ -25,11 +24,11 @@ instance : Inhabited Value := ⟨top⟩ protected partial def beq : Value → Value → Bool | bot, bot => true | top, top => true -| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ beq +| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq | choice vs₁, choice vs₂ => - vs₁.all (fun v₁ => vs₂.any $ fun v₂ => beq v₁ v₂) + vs₁.all (fun v₁ => vs₂.any $ fun v₂ => Value.beq v₁ v₂) && - vs₂.all (fun v₂ => vs₁.any $ fun v₁ => beq v₁ v₂) + vs₂.all (fun v₂ => vs₁.any $ fun v₁ => Value.beq v₁ v₂) | _, _ => false instance : HasBeq Value := ⟨Value.beq⟩ @@ -38,7 +37,7 @@ partial def addChoice (merge : Value → Value → Value) : List Value → Value | [], v => [v] | v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) => if i₁ == i₂ then merge v₁ v₂ :: cs - else v₁ :: addChoice cs v₂ + else v₁ :: addChoice merge cs v₂ | _, _ => panic! "invalid addChoice" partial def merge : Value → Value → Value @@ -56,8 +55,8 @@ partial def merge : Value → Value → Value protected partial def format : Value → Format | top => "top" | bot => "bot" -| choice vs => fmt "@" ++ @List.format _ ⟨format⟩ vs -| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨format⟩ vs) +| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs +| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs) instance : HasFormat Value := ⟨Value.format⟩ instance : HasToString Value := ⟨Format.pretty ∘ Value.format⟩ @@ -67,19 +66,19 @@ instance : HasToString Value := ⟨Format.pretty ∘ Value.format⟩ interpreter. -/ partial def truncate (env : Environment) : Value → NameSet → Value | ctor i vs, found => - let I := i.name.getPrefix; + let I := i.name.getPrefix if found.contains I then top else let cont (found' : NameSet) : Value := - ctor i (vs.map $ fun v => truncate v found'); + ctor i (vs.map fun v => truncate env v found') match env.find? I with | some (ConstantInfo.inductInfo d) => if d.isRec then cont (found.insert I) else cont found | _ => cont found | choice vs, found => - let newVs := vs.map $ fun v => truncate v found; + let newVs := vs.map fun v => truncate env v found if newVs.elem top then top else choice newVs | v, _ => v @@ -96,7 +95,7 @@ def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Va registerSimplePersistentEnvExtension { name := `unreachBranchesFunSummary, addImportedFn := fun as => - let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as; + let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as cache.switch, addEntryFn := fun s ⟨e, n⟩ => s.insert e n } @@ -127,9 +126,9 @@ abbrev M := ReaderT InterpContext (StateM InterpState) open Value def findVarValue (x : VarId) : M Value := do -ctx ← read; -s ← get; -let assignment := s.assignments.get! ctx.currFnIdx; +let ctx ← read +let s ← get +let assignment := s.assignments[ctx.currFnIdx] pure $ assignment.findD x bot def findArgValue (arg : Arg) : M Value := @@ -138,13 +137,13 @@ match arg with | _ => pure top def updateVarAssignment (x : VarId) (v : Value) : M Unit := do -v' ← findVarValue x; -ctx ← read; -modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x (merge v v') } +let v' ← findVarValue x +let ctx ← read +modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') } def resetVarAssignment (x : VarId) : M Unit := do -ctx ← read; -modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x Value.bot } +let ctx ← read +modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot } def resetParamAssignment (y : Param) : M Unit := resetVarAssignment y.x @@ -156,15 +155,15 @@ partial def projValue : Value → Nat → Value def interpExpr : Expr → M Value | Expr.ctor i ys => ctor i <$> ys.mapM (fun y => findArgValue y) -| Expr.proj i x => do v ← findVarValue x; pure $ projValue v i +| Expr.proj i x => do let v ← findVarValue x; pure $ projValue v i | Expr.fap fid ys => do - ctx ← read; + let ctx ← read match getFunctionSummary? ctx.env fid with | some v => pure v | none => do - s ← get; + let s ← get match ctx.decls.findIdx? (fun decl => decl.name == fid) with - | some idx => pure $ s.funVals.get! idx + | some idx => pure s.funVals[idx] | none => pure top | _ => pure top @@ -175,32 +174,32 @@ partial def containsCtor : Value → CtorInfo → Bool | _, _ => false def updateCurrFnSummary (v : Value) : M Unit := do -ctx ← read; -let currFnIdx := ctx.currFnIdx; -modify $ fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') } +let ctx ← read +let currFnIdx := ctx.currFnIdx +modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') } /-- Return true if the assignment of at least one parameter has been updated. -/ def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do -ctx ← read; -let currFnIdx := ctx.currFnIdx; +let ctx ← read +let currFnIdx := ctx.currFnIdx ys.size.foldM (fun i r => do - let y := ys.get! i; - let x := xs.get! i; - yVal ← findVarValue y.x; - xVal ← findArgValue x; - let newVal := merge yVal xVal; + let y := ys[i] + let x := xs[i] + let yVal ← findVarValue y.x + let xVal ← findArgValue x + let newVal := merge yVal xVal if newVal == yVal then pure r else do - modify $ fun s => { s with assignments := s.assignments.modify currFnIdx $ fun a => a.insert y.x newVal }; + modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal } pure true) false private partial def resetNestedJPParams : FnBody → M Unit | FnBody.jdecl _ ys b k => do - ctx ← read; - let currFnIdx := ctx.currFnIdx; - ys.forM resetParamAssignment; + let ctx ← read + let currFnIdx := ctx.currFnIdx + ys.forM resetParamAssignment /- Remark we don't need to reset the parameters of joint-points nested in `b` since they will be reset if this JP is used. -/ resetNestedJPParams k @@ -208,50 +207,50 @@ private partial def resetNestedJPParams : FnBody → M Unit alts.forM fun alt => match alt with | Alt.ctor _ b => resetNestedJPParams b | Alt.default b => resetNestedJPParams b -| e => unless (e.isTerminal) $ resetNestedJPParams e.body +| e => do unless e.isTerminal do resetNestedJPParams e.body partial def interpFnBody : FnBody → M Unit | FnBody.vdecl x _ e b => do - v ← interpExpr e; - updateVarAssignment x v; + let v ← interpExpr e + updateVarAssignment x v interpFnBody b | FnBody.jdecl j ys v b => withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) $ interpFnBody b | FnBody.case _ x _ alts => do - v ← findVarValue x; - alts.forM $ fun alt => + let v ← findVarValue x + alts.forM fun alt => do match alt with - | Alt.ctor i b => when (containsCtor v i) $ interpFnBody b + | Alt.ctor i b => if containsCtor v i then interpFnBody b | Alt.default b => interpFnBody b | FnBody.ret x => do - v ← findArgValue x; + let v ← findArgValue x -- dbgTrace ("ret " ++ toString v) $ fun _ => updateCurrFnSummary v | FnBody.jmp j xs => do - ctx ← read; + let ctx ← read let ys := (ctx.lctx.getJPParams j).get!; let b := (ctx.lctx.getJPBody j).get!; - updated ← updateJPParamsAssignment ys xs; + let updated ← updateJPParamsAssignment ys xs; when updated do -- We must reset the value of nested join-point parameters since they depend on `ys` values resetNestedJPParams b; interpFnBody b -| e => unless (e.isTerminal) $ interpFnBody e.body +| e => do unless (e.isTerminal) do interpFnBody e.body def inferStep : M Bool := do -ctx ← read; +let ctx ← read modify $ fun s => { s with assignments := ctx.decls.map $ fun _ => {} }; ctx.decls.size.foldM (fun idx modified => do match ctx.decls.get! idx with | Decl.fdecl fid ys _ b => do - s ← get; + let s ← get; -- dbgTrace (">> " ++ toString fid) $ fun _ => let currVals := s.funVals.get! idx; withReader (fun ctx => { ctx with currFnIdx := idx }) $ do ys.forM $ fun y => updateVarAssignment y.x top; interpFnBody b; - s ← get; + let s ← get; let newVals := s.funVals.get! idx; pure (modified || currVals != newVals) | Decl.extern _ _ _ _ => pure modified) @@ -259,24 +258,24 @@ ctx.decls.size.foldM (fun idx modified => do partial def inferMain : Unit → M Unit | _ => do - modified ← inferStep; + let modified ← inferStep; if modified then inferMain () else pure () partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody -| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux b) -| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux v) (elimDeadAux b) +| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b) +| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b) | FnBody.case tid x xType alts => let v := assignment.findD x bot; let alts := alts.map $ fun alt => match alt with - | Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux b else FnBody.unreachable - | Alt.default b => Alt.default (elimDeadAux b); + | Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable + | Alt.default b => Alt.default (elimDeadAux assignment b); FnBody.case tid x xType alts | e => if e.isTerminal then e else let (instr, b) := e.split; - let b := elimDeadAux b; + let b := elimDeadAux assignment b; instr.setBody b partial def elimDead (assignment : Assignment) : Decl → Decl @@ -288,22 +287,19 @@ end UnreachableBranches open UnreachableBranches def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do -s ← get; -let env := s.env; -let assignments : Array Assignment := decls.map $ fun _ => {}; -let funVals := Std.mkPArray decls.size Value.bot; -let ctx : InterpContext := { decls := decls, env := env }; -let s : InterpState := { assignments := assignments, funVals := funVals }; -let (_, s) := (inferMain () ctx).run s; -let funVals := s.funVals; -let assignments := s.assignments; -modify $ fun s => - let env := decls.size.fold (fun i env => - -- dbgTrace (">> " ++ toString (decls.get! i).name ++ " " ++ toString (funVals.get! i)) $ fun _ => - addFunctionSummary env (decls.get! i).name (funVals.get! i)) - s.env; - { s with env := env }; -pure $ decls.mapIdx $ fun i decl => elimDead (assignments.get! i) decl +let s ← get +let env := s.env +let assignments : Array Assignment := decls.map $ fun _ => {} +let funVals := Std.mkPArray decls.size Value.bot +let ctx : InterpContext := { decls := decls, env := env } +let s : InterpState := { assignments := assignments, funVals := funVals } +let (_, s) := (inferMain () ctx).run s +let funVals := s.funVals +let assignments := s.assignments +modify fun s => + let env := decls.size.fold (init := s.env) fun i env => + addFunctionSummary env decls[i].name funVals[i] + { s with env := env } +pure $ decls.mapIdx $ fun i decl => elimDead assignments[i] decl -end IR -end Lean +end Lean.IR