chore: move to new frontend
This commit is contained in:
parent
2d98776632
commit
a3e803df49
1 changed files with 72 additions and 76 deletions
|
|
@ -1,3 +1,4 @@
|
|||
#lang lean4
|
||||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
|
@ -7,9 +8,7 @@ import Lean.Compiler.IR.Format
|
|||
import Lean.Compiler.IR.Basic
|
||||
import Lean.Compiler.IR.CompilerM
|
||||
|
||||
namespace Lean
|
||||
namespace IR
|
||||
namespace UnreachableBranches
|
||||
namespace Lean.IR.UnreachableBranches
|
||||
|
||||
/-- Value used in the abstract interpreter -/
|
||||
inductive Value
|
||||
|
|
@ -25,11 +24,11 @@ instance : Inhabited Value := ⟨top⟩
|
|||
protected partial def beq : Value → Value → Bool
|
||||
| bot, bot => true
|
||||
| top, top => true
|
||||
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ beq
|
||||
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq
|
||||
| choice vs₁, choice vs₂ =>
|
||||
vs₁.all (fun v₁ => vs₂.any $ fun v₂ => beq v₁ v₂)
|
||||
vs₁.all (fun v₁ => vs₂.any $ fun v₂ => Value.beq v₁ v₂)
|
||||
&&
|
||||
vs₂.all (fun v₂ => vs₁.any $ fun v₁ => beq v₁ v₂)
|
||||
vs₂.all (fun v₂ => vs₁.any $ fun v₁ => Value.beq v₁ v₂)
|
||||
| _, _ => false
|
||||
|
||||
instance : HasBeq Value := ⟨Value.beq⟩
|
||||
|
|
@ -38,7 +37,7 @@ partial def addChoice (merge : Value → Value → Value) : List Value → Value
|
|||
| [], v => [v]
|
||||
| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then merge v₁ v₂ :: cs
|
||||
else v₁ :: addChoice cs v₂
|
||||
else v₁ :: addChoice merge cs v₂
|
||||
| _, _ => panic! "invalid addChoice"
|
||||
|
||||
partial def merge : Value → Value → Value
|
||||
|
|
@ -56,8 +55,8 @@ partial def merge : Value → Value → Value
|
|||
protected partial def format : Value → Format
|
||||
| top => "top"
|
||||
| bot => "bot"
|
||||
| choice vs => fmt "@" ++ @List.format _ ⟨format⟩ vs
|
||||
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨format⟩ vs)
|
||||
| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs
|
||||
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs)
|
||||
|
||||
instance : HasFormat Value := ⟨Value.format⟩
|
||||
instance : HasToString Value := ⟨Format.pretty ∘ Value.format⟩
|
||||
|
|
@ -67,19 +66,19 @@ instance : HasToString Value := ⟨Format.pretty ∘ Value.format⟩
|
|||
interpreter. -/
|
||||
partial def truncate (env : Environment) : Value → NameSet → Value
|
||||
| ctor i vs, found =>
|
||||
let I := i.name.getPrefix;
|
||||
let I := i.name.getPrefix
|
||||
if found.contains I then
|
||||
top
|
||||
else
|
||||
let cont (found' : NameSet) : Value :=
|
||||
ctor i (vs.map $ fun v => truncate v found');
|
||||
ctor i (vs.map fun v => truncate env v found')
|
||||
match env.find? I with
|
||||
| some (ConstantInfo.inductInfo d) =>
|
||||
if d.isRec then cont (found.insert I)
|
||||
else cont found
|
||||
| _ => cont found
|
||||
| choice vs, found =>
|
||||
let newVs := vs.map $ fun v => truncate v found;
|
||||
let newVs := vs.map fun v => truncate env v found
|
||||
if newVs.elem top then top
|
||||
else choice newVs
|
||||
| v, _ => v
|
||||
|
|
@ -96,7 +95,7 @@ def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Va
|
|||
registerSimplePersistentEnvExtension {
|
||||
name := `unreachBranchesFunSummary,
|
||||
addImportedFn := fun as =>
|
||||
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as;
|
||||
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as
|
||||
cache.switch,
|
||||
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
|
||||
}
|
||||
|
|
@ -127,9 +126,9 @@ abbrev M := ReaderT InterpContext (StateM InterpState)
|
|||
open Value
|
||||
|
||||
def findVarValue (x : VarId) : M Value := do
|
||||
ctx ← read;
|
||||
s ← get;
|
||||
let assignment := s.assignments.get! ctx.currFnIdx;
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]
|
||||
pure $ assignment.findD x bot
|
||||
|
||||
def findArgValue (arg : Arg) : M Value :=
|
||||
|
|
@ -138,13 +137,13 @@ match arg with
|
|||
| _ => pure top
|
||||
|
||||
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
|
||||
v' ← findVarValue x;
|
||||
ctx ← read;
|
||||
modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x (merge v v') }
|
||||
let v' ← findVarValue x
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
|
||||
|
||||
def resetVarAssignment (x : VarId) : M Unit := do
|
||||
ctx ← read;
|
||||
modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x Value.bot }
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
|
||||
|
||||
def resetParamAssignment (y : Param) : M Unit :=
|
||||
resetVarAssignment y.x
|
||||
|
|
@ -156,15 +155,15 @@ partial def projValue : Value → Nat → Value
|
|||
|
||||
def interpExpr : Expr → M Value
|
||||
| Expr.ctor i ys => ctor i <$> ys.mapM (fun y => findArgValue y)
|
||||
| Expr.proj i x => do v ← findVarValue x; pure $ projValue v i
|
||||
| Expr.proj i x => do let v ← findVarValue x; pure $ projValue v i
|
||||
| Expr.fap fid ys => do
|
||||
ctx ← read;
|
||||
let ctx ← read
|
||||
match getFunctionSummary? ctx.env fid with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
s ← get;
|
||||
let s ← get
|
||||
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
|
||||
| some idx => pure $ s.funVals.get! idx
|
||||
| some idx => pure s.funVals[idx]
|
||||
| none => pure top
|
||||
| _ => pure top
|
||||
|
||||
|
|
@ -175,32 +174,32 @@ partial def containsCtor : Value → CtorInfo → Bool
|
|||
| _, _ => false
|
||||
|
||||
def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
ctx ← read;
|
||||
let currFnIdx := ctx.currFnIdx;
|
||||
modify $ fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
|
||||
/-- Return true if the assignment of at least one parameter has been updated. -/
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
ctx ← read;
|
||||
let currFnIdx := ctx.currFnIdx;
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM
|
||||
(fun i r => do
|
||||
let y := ys.get! i;
|
||||
let x := xs.get! i;
|
||||
yVal ← findVarValue y.x;
|
||||
xVal ← findArgValue x;
|
||||
let newVal := merge yVal xVal;
|
||||
let y := ys[i]
|
||||
let x := xs[i]
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
let newVal := merge yVal xVal
|
||||
if newVal == yVal then pure r
|
||||
else do
|
||||
modify $ fun s => { s with assignments := s.assignments.modify currFnIdx $ fun a => a.insert y.x newVal };
|
||||
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
|
||||
pure true)
|
||||
false
|
||||
|
||||
private partial def resetNestedJPParams : FnBody → M Unit
|
||||
| FnBody.jdecl _ ys b k => do
|
||||
ctx ← read;
|
||||
let currFnIdx := ctx.currFnIdx;
|
||||
ys.forM resetParamAssignment;
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.forM resetParamAssignment
|
||||
/- Remark we don't need to reset the parameters of joint-points
|
||||
nested in `b` since they will be reset if this JP is used. -/
|
||||
resetNestedJPParams k
|
||||
|
|
@ -208,50 +207,50 @@ private partial def resetNestedJPParams : FnBody → M Unit
|
|||
alts.forM fun alt => match alt with
|
||||
| Alt.ctor _ b => resetNestedJPParams b
|
||||
| Alt.default b => resetNestedJPParams b
|
||||
| e => unless (e.isTerminal) $ resetNestedJPParams e.body
|
||||
| e => do unless e.isTerminal do resetNestedJPParams e.body
|
||||
|
||||
partial def interpFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl x _ e b => do
|
||||
v ← interpExpr e;
|
||||
updateVarAssignment x v;
|
||||
let v ← interpExpr e
|
||||
updateVarAssignment x v
|
||||
interpFnBody b
|
||||
| FnBody.jdecl j ys v b =>
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) $
|
||||
interpFnBody b
|
||||
| FnBody.case _ x _ alts => do
|
||||
v ← findVarValue x;
|
||||
alts.forM $ fun alt =>
|
||||
let v ← findVarValue x
|
||||
alts.forM fun alt => do
|
||||
match alt with
|
||||
| Alt.ctor i b => when (containsCtor v i) $ interpFnBody b
|
||||
| Alt.ctor i b => if containsCtor v i then interpFnBody b
|
||||
| Alt.default b => interpFnBody b
|
||||
| FnBody.ret x => do
|
||||
v ← findArgValue x;
|
||||
let v ← findArgValue x
|
||||
-- dbgTrace ("ret " ++ toString v) $ fun _ =>
|
||||
updateCurrFnSummary v
|
||||
| FnBody.jmp j xs => do
|
||||
ctx ← read;
|
||||
let ctx ← read
|
||||
let ys := (ctx.lctx.getJPParams j).get!;
|
||||
let b := (ctx.lctx.getJPBody j).get!;
|
||||
updated ← updateJPParamsAssignment ys xs;
|
||||
let updated ← updateJPParamsAssignment ys xs;
|
||||
when updated do
|
||||
-- We must reset the value of nested join-point parameters since they depend on `ys` values
|
||||
resetNestedJPParams b;
|
||||
interpFnBody b
|
||||
| e => unless (e.isTerminal) $ interpFnBody e.body
|
||||
| e => do unless (e.isTerminal) do interpFnBody e.body
|
||||
|
||||
def inferStep : M Bool := do
|
||||
ctx ← read;
|
||||
let ctx ← read
|
||||
modify $ fun s => { s with assignments := ctx.decls.map $ fun _ => {} };
|
||||
ctx.decls.size.foldM (fun idx modified => do
|
||||
match ctx.decls.get! idx with
|
||||
| Decl.fdecl fid ys _ b => do
|
||||
s ← get;
|
||||
let s ← get;
|
||||
-- dbgTrace (">> " ++ toString fid) $ fun _ =>
|
||||
let currVals := s.funVals.get! idx;
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) $ do
|
||||
ys.forM $ fun y => updateVarAssignment y.x top;
|
||||
interpFnBody b;
|
||||
s ← get;
|
||||
let s ← get;
|
||||
let newVals := s.funVals.get! idx;
|
||||
pure (modified || currVals != newVals)
|
||||
| Decl.extern _ _ _ _ => pure modified)
|
||||
|
|
@ -259,24 +258,24 @@ ctx.decls.size.foldM (fun idx modified => do
|
|||
|
||||
partial def inferMain : Unit → M Unit
|
||||
| _ => do
|
||||
modified ← inferStep;
|
||||
let modified ← inferStep;
|
||||
if modified then inferMain () else pure ()
|
||||
|
||||
partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux v) (elimDeadAux b)
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.findD x bot;
|
||||
let alts := alts.map $ fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux b else FnBody.unreachable
|
||||
| Alt.default b => Alt.default (elimDeadAux b);
|
||||
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
| Alt.default b => Alt.default (elimDeadAux assignment b);
|
||||
FnBody.case tid x xType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split;
|
||||
let b := elimDeadAux b;
|
||||
let b := elimDeadAux assignment b;
|
||||
instr.setBody b
|
||||
|
||||
partial def elimDead (assignment : Assignment) : Decl → Decl
|
||||
|
|
@ -288,22 +287,19 @@ end UnreachableBranches
|
|||
open UnreachableBranches
|
||||
|
||||
def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
s ← get;
|
||||
let env := s.env;
|
||||
let assignments : Array Assignment := decls.map $ fun _ => {};
|
||||
let funVals := Std.mkPArray decls.size Value.bot;
|
||||
let ctx : InterpContext := { decls := decls, env := env };
|
||||
let s : InterpState := { assignments := assignments, funVals := funVals };
|
||||
let (_, s) := (inferMain () ctx).run s;
|
||||
let funVals := s.funVals;
|
||||
let assignments := s.assignments;
|
||||
modify $ fun s =>
|
||||
let env := decls.size.fold (fun i env =>
|
||||
-- dbgTrace (">> " ++ toString (decls.get! i).name ++ " " ++ toString (funVals.get! i)) $ fun _ =>
|
||||
addFunctionSummary env (decls.get! i).name (funVals.get! i))
|
||||
s.env;
|
||||
{ s with env := env };
|
||||
pure $ decls.mapIdx $ fun i decl => elimDead (assignments.get! i) decl
|
||||
let s ← get
|
||||
let env := s.env
|
||||
let assignments : Array Assignment := decls.map $ fun _ => {}
|
||||
let funVals := Std.mkPArray decls.size Value.bot
|
||||
let ctx : InterpContext := { decls := decls, env := env }
|
||||
let s : InterpState := { assignments := assignments, funVals := funVals }
|
||||
let (_, s) := (inferMain () ctx).run s
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]
|
||||
{ s with env := env }
|
||||
pure $ decls.mapIdx $ fun i decl => elimDead assignments[i] decl
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
end Lean.IR
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue