From da4812659ce45658e8bad6cf591620a5c0933ea2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 3 Oct 2022 19:13:31 -0700 Subject: [PATCH] feat: use `DiscrM` to implement `simpJpCases?` --- src/Lean/Compiler/LCNF/Simp/JpCases.lean | 31 ++++++++----- tests/lean/jpCasesDiscrM.lean | 35 +++++++++++++++ tests/lean/jpCasesDiscrM.lean.expected.out | 52 ++++++++++++++++++++++ 3 files changed, 107 insertions(+), 11 deletions(-) create mode 100644 tests/lean/jpCasesDiscrM.lean create mode 100644 tests/lean/jpCasesDiscrM.lean.expected.out diff --git a/src/Lean/Compiler/LCNF/Simp/JpCases.lean b/src/Lean/Compiler/LCNF/Simp/JpCases.lean index f24f671934..8cebb745ee 100644 --- a/src/Lean/Compiler/LCNF/Simp/JpCases.lean +++ b/src/Lean/Compiler/LCNF/Simp/JpCases.lean @@ -7,6 +7,7 @@ import Lean.Compiler.LCNF.DependsOn import Lean.Compiler.LCNF.InferType import Lean.Compiler.LCNF.Internalize import Lean.Compiler.LCNF.Simp.Basic +import Lean.Compiler.LCNF.Simp.DiscrM namespace Lean.Compiler.LCNF namespace Simp @@ -61,10 +62,10 @@ there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructo `paramIdx` is the index of the parameter -/ partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do - let (_, s) ← go code |>.run {} + let (_, s) ← go code |>.run {} |>.run {} return s where - go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do + go (code : Code) : StateRefT JpCasesInfoMap DiscrM Unit := do match code with | .let _ k => go k | .fun decl k => go decl.value; go k @@ -72,11 +73,14 @@ where if let some paramIdx ← isJpCases? decl then modify fun s => s.insert decl.fvarId { paramIdx } go decl.value; go k - | .cases c => c.alts.forM fun alt => go alt.getCode + | .cases c => c.alts.forM fun alt => + match alt with + | .default k => go k + | .alt ctorName ps k => withDiscrCtor c.discr ctorName ps <| go k | .return .. | .unreach .. => return () | .jmp fvarId args => if let some info := (← get).find? fvarId then - let arg ← findExpr args[info.paramIdx]! + let arg ← findCtor args[info.paramIdx]! let some (cval, _) := arg.constructorApp? (← getEnv) | return () modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name } @@ -195,9 +199,9 @@ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do for (fvarId, info) in map.toList do msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}" return msg - visit code map |>.run' {} + visit code map |>.run' {} |>.run {} where - visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do + visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) Code := do match code with | .let decl k => return code.updateLet! decl (← visit k) @@ -213,14 +217,19 @@ where let decl ← decl.updateValue value return code.updateFun! decl (← visit k) | .cases c => - let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode) + let alts ← c.alts.mapMonoM fun alt => + match alt with + | .alt ctorName ps k => + withDiscrCtor c.discr ctorName ps do + return alt.updateCode (← visit k) + | .default k => return alt.updateCode (← visit k) return code.updateAlts! alts | .return _ | .unreach _ => return code | .jmp fvarId args => let some code ← visitJmp? fvarId args | return code return code - visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do let some info := (← read).find? decl.fvarId | return none if info.ctorNames.isEmpty then return none -- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it. @@ -245,7 +254,7 @@ where else altsNew := altsNew.push (alt.updateCode k) | .alt ctorName fields k => - let k ← visit k + let k ← withDiscrCtor cases.discr ctorName fields <| visit k if info.ctorNames.contains ctorName then let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false) jpAltDecls := jpAltDecls.push (.jp jpAlt.decl) @@ -261,10 +270,10 @@ where let code := .jp decl (← visit k) return LCNF.attachCodeDecls jpAltDecls code - visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do let some ctorJpAltMap := (← get).find? fvarId | return none let some info := (← read).find? fvarId | return none - let arg ← findExpr args[info.paramIdx]! + let arg ← findCtor args[info.paramIdx]! let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none let fields := ctorArgs[ctorVal.numParams:] diff --git a/tests/lean/jpCasesDiscrM.lean b/tests/lean/jpCasesDiscrM.lean new file mode 100644 index 0000000000..6bdced3a7b --- /dev/null +++ b/tests/lean/jpCasesDiscrM.lean @@ -0,0 +1,35 @@ +set_option trace.Compiler.saveBase true in +def f1 (c : Bool) (a b : Nat) := + let k d y z := + match d with + | true => y + z + z*y + | false => z + y + match c with + | true => k true a b + | false => k false b a + +set_option trace.Compiler.saveBase true in +def f2 (c : Bool) (a b : Nat) := + let k d y z := + match d with + | true => y + z + z*y + | false => z + y + match c with + | true => k c a b + | false => k c b a + +inductive C where + | c1 | c2 | c3 | c4 + +set_option trace.Compiler.saveBase true in +def f3 (c c' : C) (a b : Nat) := + let k y z (d : C) := + match d with + | C.c1 => y + z + z*y + | C.c3 => y*y+a + | _ => z + y + y + match c with + | .c1 => k a b c + | .c2 => k b b c + | .c3 => k b a c' + | .c4 => k a a c' diff --git a/tests/lean/jpCasesDiscrM.lean.expected.out b/tests/lean/jpCasesDiscrM.lean.expected.out new file mode 100644 index 0000000000..12852f2317 --- /dev/null +++ b/tests/lean/jpCasesDiscrM.lean.expected.out @@ -0,0 +1,52 @@ +[Compiler.saveBase] size: 7 + def f1 c a b := + cases c + | Bool.false => + let _x.1 := Nat.add a b + _x.1 + | Bool.true => + let _x.2 := Nat.add a b + let _x.3 := Nat.mul b a + let _x.4 := Nat.add _x.2 _x.3 + _x.4 +[Compiler.saveBase] size: 7 + def f2 c a b := + cases c + | Bool.false => + let _x.1 := Nat.add a b + _x.1 + | Bool.true => + let _x.2 := Nat.add a b + let _x.3 := Nat.mul b a + let _x.4 := Nat.add _x.2 _x.3 + _x.4 +[Compiler.saveBase] size: 22 + def f3 c c' a b := + jp _jp.1 y z := + let _x.2 := Nat.add y z + let _x.3 := Nat.mul z y + let _x.4 := Nat.add _x.2 _x.3 + _x.4 + jp _jp.5 y z := + let _x.6 := Nat.add z y + let _x.7 := Nat.add _x.6 y + _x.7 + jp _jp.8 y z d := + cases d + | C.c1 => + goto _jp.1 y z + | C.c3 => + let _x.9 := Nat.mul y y + let _x.10 := Nat.add _x.9 a + _x.10 + | _ => + goto _jp.5 y z + cases c + | C.c1 => + goto _jp.1 a b + | C.c2 => + goto _jp.5 b b + | C.c3 => + goto _jp.8 b a c' + | C.c4 => + goto _jp.8 a a c'