diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 5f8d49c9ec..4c42d7ebe2 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -37,12 +37,11 @@ inductive Element where | unreach (fvarId : FVarId) deriving Inhabited -/-- State for `BindCasesM` monad -/ -structure BindCasesM.State where - /-- New auxiliary join points. -/ - jps : Array FunDecl := #[] - /-- Mapping from `_alt.` variables to new join points -/ - map : FVarIdMap FVarId := {} +/-- +State for `BindCasesM` monad +Mapping from `_alt.` variables to new join points +-/ +abbrev BindCasesM.State := FVarIdMap FunDecl /-- Auxiliary monad for implementing `bindCases` -/ abbrev BindCasesM := StateRefT BindCasesM.State CompilerM @@ -61,9 +60,8 @@ That is, our goal is to try to promote the pre join points `_alt.` into a p partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do let (alts, s) ← visitAlts cases.alts |>.run {} let resultType ← mkCasesResultType alts - let mut result := .cases { cases with alts, resultType } - for decl in s.jps do - result := .jp decl result + let result := .cases { cases with alts, resultType } + let result := s.fold (init := result) fun result _ altJp => .jp altJp result return .jp jpDecl result where visitAlts (alts : Array Alt) : BindCasesM (Array Alt) := @@ -96,9 +94,9 @@ where if let some funDecl ← findFun? f then let args := decl.value.getAppArgs eraseFVar decl.fvarId - if let some altJp := (← get).map.find? f then + if let some altJp := (← get).find? f then /- We already have an auxiliary join point for `f`, then, we just use it. -/ - return .jmp altJp args + return .jmp altJp.fvarId args else /- We have not created a join point for `f` yet. @@ -124,12 +122,15 @@ where let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs) let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId]) let altJp ← mkAuxJpDecl jpParams jpValue - modify fun { jps, map } => { - jps := jps.push altJp - map := map.insert f altJp.fvarId - } + modify fun map => map.insert f altJp return .jmp altJp.fvarId args - return .let decl (← go k) + let k ← go k + if let some altJp := (← get).find? decl.fvarId then + -- The new join point depends on this variable. Thus, we must insert it here + modify fun s => s.erase decl.fvarId + return .let decl (.jp altJp k) + else + return .let decl k | .fun decl k => return .fun decl (← go k) | .jp decl k => let value ← go decl.value diff --git a/tests/lean/run/bindCasesIssue.lean b/tests/lean/run/bindCasesIssue.lean new file mode 100644 index 0000000000..00a92b7287 --- /dev/null +++ b/tests/lean/run/bindCasesIssue.lean @@ -0,0 +1,12 @@ +import Lean + +def bar : ReaderM Unit Unit := + if true then + match true with + | true => pure () + | false => pure () + else + pure () + +set_option trace.Compiler true +#eval Lean.Compiler.compile #[``bar]