feat: use DiscrM to implement simpJpCases?

This commit is contained in:
Leonardo de Moura 2022-10-03 19:13:31 -07:00
parent ddbf4c01eb
commit da4812659c
3 changed files with 107 additions and 11 deletions

View file

@ -7,6 +7,7 @@ import Lean.Compiler.LCNF.DependsOn
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.LCNF.Internalize
import Lean.Compiler.LCNF.Simp.Basic
import Lean.Compiler.LCNF.Simp.DiscrM
namespace Lean.Compiler.LCNF
namespace Simp
@ -61,10 +62,10 @@ there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructo
`paramIdx` is the index of the parameter
-/
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do
let (_, s) ← go code |>.run {}
let (_, s) ← go code |>.run {} |>.run {}
return s
where
go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do
go (code : Code) : StateRefT JpCasesInfoMap DiscrM Unit := do
match code with
| .let _ k => go k
| .fun decl k => go decl.value; go k
@ -72,11 +73,14 @@ where
if let some paramIdx ← isJpCases? decl then
modify fun s => s.insert decl.fvarId { paramIdx }
go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .cases c => c.alts.forM fun alt =>
match alt with
| .default k => go k
| .alt ctorName ps k => withDiscrCtor c.discr ctorName ps <| go k
| .return .. | .unreach .. => return ()
| .jmp fvarId args =>
if let some info := (← get).find? fvarId then
let arg ← findExpr args[info.paramIdx]!
let arg ← findCtor args[info.paramIdx]!
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name }
@ -195,9 +199,9 @@ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
for (fvarId, info) in map.toList do
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}"
return msg
visit code map |>.run' {}
visit code map |>.run' {} |>.run {}
where
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) Code := do
match code with
| .let decl k =>
return code.updateLet! decl (← visit k)
@ -213,14 +217,19 @@ where
let decl ← decl.updateValue value
return code.updateFun! decl (← visit k)
| .cases c =>
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode)
let alts ← c.alts.mapMonoM fun alt =>
match alt with
| .alt ctorName ps k =>
withDiscrCtor c.discr ctorName ps do
return alt.updateCode (← visit k)
| .default k => return alt.updateCode (← visit k)
return code.updateAlts! alts
| .return _ | .unreach _ => return code
| .jmp fvarId args =>
let some code ← visitJmp? fvarId args | return code
return code
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
let some info := (← read).find? decl.fvarId | return none
if info.ctorNames.isEmpty then return none
-- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it.
@ -245,7 +254,7 @@ where
else
altsNew := altsNew.push (alt.updateCode k)
| .alt ctorName fields k =>
let k ← visit k
let k ← withDiscrCtor cases.discr ctorName fields <| visit k
if info.ctorNames.contains ctorName then
let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false)
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
@ -261,10 +270,10 @@ where
let code := .jp decl (← visit k)
return LCNF.attachCodeDecls jpAltDecls code
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
let some ctorJpAltMap := (← get).find? fvarId | return none
let some info := (← read).find? fvarId | return none
let arg ← findExpr args[info.paramIdx]!
let arg ← findCtor args[info.paramIdx]!
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
let fields := ctorArgs[ctorVal.numParams:]

View file

@ -0,0 +1,35 @@
set_option trace.Compiler.saveBase true in
def f1 (c : Bool) (a b : Nat) :=
let k d y z :=
match d with
| true => y + z + z*y
| false => z + y
match c with
| true => k true a b
| false => k false b a
set_option trace.Compiler.saveBase true in
def f2 (c : Bool) (a b : Nat) :=
let k d y z :=
match d with
| true => y + z + z*y
| false => z + y
match c with
| true => k c a b
| false => k c b a
inductive C where
| c1 | c2 | c3 | c4
set_option trace.Compiler.saveBase true in
def f3 (c c' : C) (a b : Nat) :=
let k y z (d : C) :=
match d with
| C.c1 => y + z + z*y
| C.c3 => y*y+a
| _ => z + y + y
match c with
| .c1 => k a b c
| .c2 => k b b c
| .c3 => k b a c'
| .c4 => k a a c'

View file

@ -0,0 +1,52 @@
[Compiler.saveBase] size: 7
def f1 c a b :=
cases c
| Bool.false =>
let _x.1 := Nat.add a b
_x.1
| Bool.true =>
let _x.2 := Nat.add a b
let _x.3 := Nat.mul b a
let _x.4 := Nat.add _x.2 _x.3
_x.4
[Compiler.saveBase] size: 7
def f2 c a b :=
cases c
| Bool.false =>
let _x.1 := Nat.add a b
_x.1
| Bool.true =>
let _x.2 := Nat.add a b
let _x.3 := Nat.mul b a
let _x.4 := Nat.add _x.2 _x.3
_x.4
[Compiler.saveBase] size: 22
def f3 c c' a b :=
jp _jp.1 y z :=
let _x.2 := Nat.add y z
let _x.3 := Nat.mul z y
let _x.4 := Nat.add _x.2 _x.3
_x.4
jp _jp.5 y z :=
let _x.6 := Nat.add z y
let _x.7 := Nat.add _x.6 y
_x.7
jp _jp.8 y z d :=
cases d
| C.c1 =>
goto _jp.1 y z
| C.c3 =>
let _x.9 := Nat.mul y y
let _x.10 := Nat.add _x.9 a
_x.10
| _ =>
goto _jp.5 y z
cases c
| C.c1 =>
goto _jp.1 a b
| C.c2 =>
goto _jp.5 b b
| C.c3 =>
goto _jp.8 b a c'
| C.c4 =>
goto _jp.8 a a c'