diff --git a/library/init/lean/compiler/ir/unreachbranches.lean b/library/init/lean/compiler/ir/unreachbranches.lean index 9a725b5308..bf52c896f3 100644 --- a/library/init/lean/compiler/ir/unreachbranches.lean +++ b/library/init/lean/compiler/ir/unreachbranches.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import init.control.reader +import init.data.option import init.lean.compiler.ir.format import init.lean.compiler.ir.basic @@ -85,26 +86,26 @@ functionSummariesExt.addEntry env (fid, v) def getFunctionSummary (env : Environment) (fid : FunId) : Option Value := (functionSummariesExt.getState env).find fid -def Assignment := PHashMap VarId Value +abbrev Assignment := HashMap VarId Value structure InterpContext := -(currFn : Name) -(env : Environment) -(lctx : LocalContext) - -abbrev FunMap := HashMap FunId Value +(currFnIdx : Nat := 0) +(decls : Array Decl) +(env : Environment) +(lctx : LocalContext := {}) structure InterpState := -(assignment : Assignment) -(funMap : FunMap) +(assignments : Array Assignment) +(funVals : PArray Value) -- we take snapshots during fixpoint computations abbrev M := ReaderT InterpContext (State InterpState) open Value def findVarValue (x : VarId) : M Value := -do s ← get; - match s.assignment.find x with +do ctx ← read; + s ← get; + match (s.assignments.get! ctx.currFnIdx).find x with | some v => pure v | none => pure top @@ -115,7 +116,8 @@ match arg with def updateVarAssignment (x : VarId) (v : Value) : M Unit := do v' ← findVarValue x; - modify $ fun s => { assignment := s.assignment.insert x (merge v v'), .. s } + ctx ← read; + modify $ fun s => { assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x (merge v v'), .. s } partial def projValue : Value → Nat → Value | ctor _ vs, i => vs.getD i bot @@ -129,7 +131,11 @@ def interpExpr : Expr → M Value ctx ← read; match getFunctionSummary ctx.env fid with | some v => pure v - | none => pure top + | none => do + s ← get; + match ctx.decls.findIdx? (fun decl => decl.name == fid) with + | some idx => pure $ s.funVals.get! idx + | none => pure top | _ => pure top partial def containsCtor : Value → CtorInfo → Bool @@ -140,10 +146,26 @@ partial def containsCtor : Value → CtorInfo → Bool def updateCurrFnSummary (v : Value) : M Unit := do ctx ← read; - s ← get; - let currFn := ctx.currFn; - let v' := (s.funMap.find currFn).getOrElse bot; - modify $ fun s => { funMap := s.funMap.insert currFn (merge v v'), .. s } + let currFnIdx := ctx.currFnIdx; + s ← get; + modify $ fun s => { funVals := s.funVals.modify currFnIdx (fun v' => merge v v'), .. s } + + +/-- Return true if the assignment of at least one parameter has been updated. -/ +def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := +do ctx ← read; + let currFnIdx := ctx.currFnIdx; + ys.size.mfold (fun i r => do + let y := ys.get! i; + let x := xs.get! i; + yVal ← findVarValue y.x; + xVal ← findArgValue x; + let newVal := merge yVal xVal; + if newVal == yVal then pure r + else do + modify $ fun s => { assignments := s.assignments.modify currFnIdx $ fun a => a.insert y.x newVal, .. s }; + pure true) + false partial def interpFnBody : FnBody → M Unit | FnBody.vdecl x _ e b => do @@ -164,10 +186,42 @@ partial def interpFnBody : FnBody → M Unit updateCurrFnSummary v | FnBody.jmp j xs => do ctx ← read; - -- TODO - pure () + let ys := (ctx.lctx.getJPParams j).get!; + updated ← updateJPParamsAssignment ys xs; + when updated $ + interpFnBody $ (ctx.lctx.getJPBody j).get! | e => unless (e.isTerminal) $ interpFnBody e.body +def inferStep : M Bool := +do ctx ← read; + ctx.decls.size.mfold (fun idx modified => do + match ctx.decls.get! idx with + | Decl.fdecl _ _ _ b => do + s ← get; + let currVals := s.funVals.get! idx; + adaptReader (fun (ctx : InterpContext) => { currFnIdx := idx, .. ctx }) $ + interpFnBody b; + s ← get; + let newVals := s.funVals.get! idx; + -- TODO: apply widening + pure (modified || currVals != newVals) + | Decl.extern _ _ _ _ => pure modified) + false + +partial def inferMain : Unit → M Unit +| _ => do + modified ← inferStep; + if modified then inferMain () else pure () + +def infer (env : Environment) (decls : Array Decl) : Environment := +let assignments : Array Assignment := decls.map $ fun _ => {}; +let funVals := mkPArray decls.size bot; +let ctx : InterpContext := { decls := decls, env := env }; +let s : InterpState := { assignments := assignments, funVals := funVals }; +let (_, s) := (inferMain () ctx).run s; +let funVals := s.funVals; +decls.size.fold (fun i env => addFunctionSummary env (decls.get! i).name (funVals.get! i)) env + end UnreachableBranches end IR end Lean