diff --git a/src/Lean/Compiler/IR/ElimDeadBranches.lean b/src/Lean/Compiler/IR/ElimDeadBranches.lean index 3376ab4fce..7465b705c8 100644 --- a/src/Lean/Compiler/IR/ElimDeadBranches.lean +++ b/src/Lean/Compiler/IR/ElimDeadBranches.lean @@ -142,6 +142,13 @@ v' ← findVarValue x; ctx ← read; modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x (merge v v') } +def resetVarAssignment (x : VarId) : M Unit := do +ctx ← read; +modify $ fun s => { s with assignments := s.assignments.modify ctx.currFnIdx $ fun a => a.insert x Value.bot } + +def resetParamAssignment (y : Param) : M Unit := +resetVarAssignment y.x + partial def projValue : Value → Nat → Value | ctor _ vs, i => vs.getD i bot | choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot @@ -189,6 +196,19 @@ ys.size.foldM pure true) false +private partial def resetNestedJPParams : FnBody → M Unit +| FnBody.jdecl _ ys _ _ => do + ctx ← read; + let currFnIdx := ctx.currFnIdx; + ys.forM resetParamAssignment + /- Remark we don't need to reset the parameters of joint-points + nested on this one since they will be reset if this JP is used. -/ +| FnBody.case _ _ _ alts => + alts.forM fun alt => match alt with + | Alt.ctor _ b => resetNestedJPParams b + | Alt.default b => resetNestedJPParams b +| e => unless (e.isTerminal) $ resetNestedJPParams e.body + partial def interpFnBody : FnBody → M Unit | FnBody.vdecl x _ e b => do v ← interpExpr e; @@ -210,9 +230,12 @@ partial def interpFnBody : FnBody → M Unit | FnBody.jmp j xs => do ctx ← read; let ys := (ctx.lctx.getJPParams j).get!; + let b := (ctx.lctx.getJPBody j).get!; updated ← updateJPParamsAssignment ys xs; - when updated $ - interpFnBody $ (ctx.lctx.getJPBody j).get! + when updated do + -- We must reset the value of nested join-point parameters since they depend on `ys` values + resetNestedJPParams b; + interpFnBody b | e => unless (e.isTerminal) $ interpFnBody e.body def inferStep : M Bool := do