feat: use DiscrM to implement simpJpCases?
This commit is contained in:
parent
ddbf4c01eb
commit
da4812659c
3 changed files with 107 additions and 11 deletions
|
|
@ -7,6 +7,7 @@ import Lean.Compiler.LCNF.DependsOn
|
|||
import Lean.Compiler.LCNF.InferType
|
||||
import Lean.Compiler.LCNF.Internalize
|
||||
import Lean.Compiler.LCNF.Simp.Basic
|
||||
import Lean.Compiler.LCNF.Simp.DiscrM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
|
@ -61,10 +62,10 @@ there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructo
|
|||
`paramIdx` is the index of the parameter
|
||||
-/
|
||||
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do
|
||||
let (_, s) ← go code |>.run {}
|
||||
let (_, s) ← go code |>.run {} |>.run {}
|
||||
return s
|
||||
where
|
||||
go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do
|
||||
go (code : Code) : StateRefT JpCasesInfoMap DiscrM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
|
|
@ -72,11 +73,14 @@ where
|
|||
if let some paramIdx ← isJpCases? decl then
|
||||
modify fun s => s.insert decl.fvarId { paramIdx }
|
||||
go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .cases c => c.alts.forM fun alt =>
|
||||
match alt with
|
||||
| .default k => go k
|
||||
| .alt ctorName ps k => withDiscrCtor c.discr ctorName ps <| go k
|
||||
| .return .. | .unreach .. => return ()
|
||||
| .jmp fvarId args =>
|
||||
if let some info := (← get).find? fvarId then
|
||||
let arg ← findExpr args[info.paramIdx]!
|
||||
let arg ← findCtor args[info.paramIdx]!
|
||||
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
|
||||
modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name }
|
||||
|
||||
|
|
@ -195,9 +199,9 @@ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
|
|||
for (fvarId, info) in map.toList do
|
||||
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}"
|
||||
return msg
|
||||
visit code map |>.run' {}
|
||||
visit code map |>.run' {} |>.run {}
|
||||
where
|
||||
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
|
||||
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← visit k)
|
||||
|
|
@ -213,14 +217,19 @@ where
|
|||
let decl ← decl.updateValue value
|
||||
return code.updateFun! decl (← visit k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode)
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .alt ctorName ps k =>
|
||||
withDiscrCtor c.discr ctorName ps do
|
||||
return alt.updateCode (← visit k)
|
||||
| .default k => return alt.updateCode (← visit k)
|
||||
return code.updateAlts! alts
|
||||
| .return _ | .unreach _ => return code
|
||||
| .jmp fvarId args =>
|
||||
let some code ← visitJmp? fvarId args | return code
|
||||
return code
|
||||
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
|
||||
let some info := (← read).find? decl.fvarId | return none
|
||||
if info.ctorNames.isEmpty then return none
|
||||
-- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it.
|
||||
|
|
@ -245,7 +254,7 @@ where
|
|||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
| .alt ctorName fields k =>
|
||||
let k ← visit k
|
||||
let k ← withDiscrCtor cases.discr ctorName fields <| visit k
|
||||
if info.ctorNames.contains ctorName then
|
||||
let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
|
|
@ -261,10 +270,10 @@ where
|
|||
let code := .jp decl (← visit k)
|
||||
return LCNF.attachCodeDecls jpAltDecls code
|
||||
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
|
||||
let some ctorJpAltMap := (← get).find? fvarId | return none
|
||||
let some info := (← read).find? fvarId | return none
|
||||
let arg ← findExpr args[info.paramIdx]!
|
||||
let arg ← findCtor args[info.paramIdx]!
|
||||
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
|
||||
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
|
|
|
|||
35
tests/lean/jpCasesDiscrM.lean
Normal file
35
tests/lean/jpCasesDiscrM.lean
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
set_option trace.Compiler.saveBase true in
|
||||
def f1 (c : Bool) (a b : Nat) :=
|
||||
let k d y z :=
|
||||
match d with
|
||||
| true => y + z + z*y
|
||||
| false => z + y
|
||||
match c with
|
||||
| true => k true a b
|
||||
| false => k false b a
|
||||
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def f2 (c : Bool) (a b : Nat) :=
|
||||
let k d y z :=
|
||||
match d with
|
||||
| true => y + z + z*y
|
||||
| false => z + y
|
||||
match c with
|
||||
| true => k c a b
|
||||
| false => k c b a
|
||||
|
||||
inductive C where
|
||||
| c1 | c2 | c3 | c4
|
||||
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def f3 (c c' : C) (a b : Nat) :=
|
||||
let k y z (d : C) :=
|
||||
match d with
|
||||
| C.c1 => y + z + z*y
|
||||
| C.c3 => y*y+a
|
||||
| _ => z + y + y
|
||||
match c with
|
||||
| .c1 => k a b c
|
||||
| .c2 => k b b c
|
||||
| .c3 => k b a c'
|
||||
| .c4 => k a a c'
|
||||
52
tests/lean/jpCasesDiscrM.lean.expected.out
Normal file
52
tests/lean/jpCasesDiscrM.lean.expected.out
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
[Compiler.saveBase] size: 7
|
||||
def f1 c a b :=
|
||||
cases c
|
||||
| Bool.false =>
|
||||
let _x.1 := Nat.add a b
|
||||
_x.1
|
||||
| Bool.true =>
|
||||
let _x.2 := Nat.add a b
|
||||
let _x.3 := Nat.mul b a
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
[Compiler.saveBase] size: 7
|
||||
def f2 c a b :=
|
||||
cases c
|
||||
| Bool.false =>
|
||||
let _x.1 := Nat.add a b
|
||||
_x.1
|
||||
| Bool.true =>
|
||||
let _x.2 := Nat.add a b
|
||||
let _x.3 := Nat.mul b a
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
[Compiler.saveBase] size: 22
|
||||
def f3 c c' a b :=
|
||||
jp _jp.1 y z :=
|
||||
let _x.2 := Nat.add y z
|
||||
let _x.3 := Nat.mul z y
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
jp _jp.5 y z :=
|
||||
let _x.6 := Nat.add z y
|
||||
let _x.7 := Nat.add _x.6 y
|
||||
_x.7
|
||||
jp _jp.8 y z d :=
|
||||
cases d
|
||||
| C.c1 =>
|
||||
goto _jp.1 y z
|
||||
| C.c3 =>
|
||||
let _x.9 := Nat.mul y y
|
||||
let _x.10 := Nat.add _x.9 a
|
||||
_x.10
|
||||
| _ =>
|
||||
goto _jp.5 y z
|
||||
cases c
|
||||
| C.c1 =>
|
||||
goto _jp.1 a b
|
||||
| C.c2 =>
|
||||
goto _jp.5 b b
|
||||
| C.c3 =>
|
||||
goto _jp.8 b a c'
|
||||
| C.c4 =>
|
||||
goto _jp.8 a a c'
|
||||
Loading…
Add table
Reference in a new issue