fix: bug at bindCases
Many thanks to @hargoniX
This commit is contained in:
parent
7535c12bc5
commit
8f2ab82408
2 changed files with 29 additions and 16 deletions
|
|
@ -37,12 +37,11 @@ inductive Element where
|
|||
| unreach (fvarId : FVarId)
|
||||
deriving Inhabited
|
||||
|
||||
/-- State for `BindCasesM` monad -/
|
||||
structure BindCasesM.State where
|
||||
/-- New auxiliary join points. -/
|
||||
jps : Array FunDecl := #[]
|
||||
/-- Mapping from `_alt.<idx>` variables to new join points -/
|
||||
map : FVarIdMap FVarId := {}
|
||||
/--
|
||||
State for `BindCasesM` monad
|
||||
Mapping from `_alt.<idx>` variables to new join points
|
||||
-/
|
||||
abbrev BindCasesM.State := FVarIdMap FunDecl
|
||||
|
||||
/-- Auxiliary monad for implementing `bindCases` -/
|
||||
abbrev BindCasesM := StateRefT BindCasesM.State CompilerM
|
||||
|
|
@ -61,9 +60,8 @@ That is, our goal is to try to promote the pre join points `_alt.<idx>` into a p
|
|||
partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do
|
||||
let (alts, s) ← visitAlts cases.alts |>.run {}
|
||||
let resultType ← mkCasesResultType alts
|
||||
let mut result := .cases { cases with alts, resultType }
|
||||
for decl in s.jps do
|
||||
result := .jp decl result
|
||||
let result := .cases { cases with alts, resultType }
|
||||
let result := s.fold (init := result) fun result _ altJp => .jp altJp result
|
||||
return .jp jpDecl result
|
||||
where
|
||||
visitAlts (alts : Array Alt) : BindCasesM (Array Alt) :=
|
||||
|
|
@ -96,9 +94,9 @@ where
|
|||
if let some funDecl ← findFun? f then
|
||||
let args := decl.value.getAppArgs
|
||||
eraseFVar decl.fvarId
|
||||
if let some altJp := (← get).map.find? f then
|
||||
if let some altJp := (← get).find? f then
|
||||
/- We already have an auxiliary join point for `f`, then, we just use it. -/
|
||||
return .jmp altJp args
|
||||
return .jmp altJp.fvarId args
|
||||
else
|
||||
/-
|
||||
We have not created a join point for `f` yet.
|
||||
|
|
@ -124,12 +122,15 @@ where
|
|||
let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs)
|
||||
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
|
||||
let altJp ← mkAuxJpDecl jpParams jpValue
|
||||
modify fun { jps, map } => {
|
||||
jps := jps.push altJp
|
||||
map := map.insert f altJp.fvarId
|
||||
}
|
||||
modify fun map => map.insert f altJp
|
||||
return .jmp altJp.fvarId args
|
||||
return .let decl (← go k)
|
||||
let k ← go k
|
||||
if let some altJp := (← get).find? decl.fvarId then
|
||||
-- The new join point depends on this variable. Thus, we must insert it here
|
||||
modify fun s => s.erase decl.fvarId
|
||||
return .let decl (.jp altJp k)
|
||||
else
|
||||
return .let decl k
|
||||
| .fun decl k => return .fun decl (← go k)
|
||||
| .jp decl k =>
|
||||
let value ← go decl.value
|
||||
|
|
|
|||
12
tests/lean/run/bindCasesIssue.lean
Normal file
12
tests/lean/run/bindCasesIssue.lean
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import Lean
|
||||
|
||||
def bar : ReaderM Unit Unit :=
|
||||
if true then
|
||||
match true with
|
||||
| true => pure ()
|
||||
| false => pure ()
|
||||
else
|
||||
pure ()
|
||||
|
||||
set_option trace.Compiler true
|
||||
#eval Lean.Compiler.compile #[``bar]
|
||||
Loading…
Add table
Reference in a new issue