fix: bug at bindCases

Many thanks to @hargoniX
This commit is contained in:
Leonardo de Moura 2022-09-13 15:36:46 -07:00
parent 7535c12bc5
commit 8f2ab82408
2 changed files with 29 additions and 16 deletions

View file

@ -37,12 +37,11 @@ inductive Element where
| unreach (fvarId : FVarId)
deriving Inhabited
/-- State for `BindCasesM` monad -/
structure BindCasesM.State where
/-- New auxiliary join points. -/
jps : Array FunDecl := #[]
/-- Mapping from `_alt.<idx>` variables to new join points -/
map : FVarIdMap FVarId := {}
/--
State for `BindCasesM` monad
Mapping from `_alt.<idx>` variables to new join points
-/
abbrev BindCasesM.State := FVarIdMap FunDecl
/-- Auxiliary monad for implementing `bindCases` -/
abbrev BindCasesM := StateRefT BindCasesM.State CompilerM
@ -61,9 +60,8 @@ That is, our goal is to try to promote the pre join points `_alt.<idx>` into a p
partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do
let (alts, s) ← visitAlts cases.alts |>.run {}
let resultType ← mkCasesResultType alts
let mut result := .cases { cases with alts, resultType }
for decl in s.jps do
result := .jp decl result
let result := .cases { cases with alts, resultType }
let result := s.fold (init := result) fun result _ altJp => .jp altJp result
return .jp jpDecl result
where
visitAlts (alts : Array Alt) : BindCasesM (Array Alt) :=
@ -96,9 +94,9 @@ where
if let some funDecl ← findFun? f then
let args := decl.value.getAppArgs
eraseFVar decl.fvarId
if let some altJp := (← get).map.find? f then
if let some altJp := (← get).find? f then
/- We already have an auxiliary join point for `f`, then, we just use it. -/
return .jmp altJp args
return .jmp altJp.fvarId args
else
/-
We have not created a join point for `f` yet.
@ -124,12 +122,15 @@ where
let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs)
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
let altJp ← mkAuxJpDecl jpParams jpValue
modify fun { jps, map } => {
jps := jps.push altJp
map := map.insert f altJp.fvarId
}
modify fun map => map.insert f altJp
return .jmp altJp.fvarId args
return .let decl (← go k)
let k ← go k
if let some altJp := (← get).find? decl.fvarId then
-- The new join point depends on this variable. Thus, we must insert it here
modify fun s => s.erase decl.fvarId
return .let decl (.jp altJp k)
else
return .let decl k
| .fun decl k => return .fun decl (← go k)
| .jp decl k =>
let value ← go decl.value

View file

@ -0,0 +1,12 @@
import Lean
def bar : ReaderM Unit Unit :=
if true then
match true with
| true => pure ()
| false => pure ()
else
pure ()
set_option trace.Compiler true
#eval Lean.Compiler.compile #[``bar]