diff --git a/src/Lean/Compiler/LCNF/Main.lean b/src/Lean/Compiler/LCNF/Main.lean index 45e217882e..2753b39933 100644 --- a/src/Lean/Compiler/LCNF/Main.lean +++ b/src/Lean/Compiler/LCNF/Main.lean @@ -50,7 +50,7 @@ def checkpoint (stepName : Name) (decls : Array Decl) : CompilerM Unit := do withOptions (fun opts => opts.setBool `pp.motives.pi false) do let clsName := `Compiler ++ stepName if (← Lean.isTracingEnabledFor clsName) then - Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl decl}" + Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl' decl}" if compiler.check.get (← getOptions) then decl.check if compiler.check.get (← getOptions) then diff --git a/src/Lean/Compiler/LCNF/Simp/JpCases.lean b/src/Lean/Compiler/LCNF/Simp/JpCases.lean index ee4c62a2d2..f24f671934 100644 --- a/src/Lean/Compiler/LCNF/Simp/JpCases.lean +++ b/src/Lean/Compiler/LCNF/Simp/JpCases.lean @@ -12,7 +12,7 @@ namespace Lean.Compiler.LCNF namespace Simp /-- -Given the function declaration `decl`, return `true` if it is of the form +Given the function declaration `decl`, return `some idx` if it is of the form ``` f y := ... /- This part is not bigger than smallThreshold. -/ @@ -20,55 +20,68 @@ f y := | ... => ... ... ``` +`idx` is the index of the parameter used in the `cases` statement. -/ -def isJpCases (decl : FunDecl) : CompilerM Bool := do - if decl.params.size != 1 then - return false +def isJpCases? (decl : FunDecl) : CompilerM (Option Nat) := do + if decl.params.size == 0 then + return none else - let param := decl.params[0]! let small := (← getConfig).smallThreshold - let rec go (code : Code) (prefixSize : Nat) : Bool := - prefixSize <= small && + let rec go (code : Code) (prefixSize : Nat) : Option Nat := + if prefixSize > small then none else match code with | .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/ - | .cases c => c.discr == param.fvarId - | _ => false + | .cases c => decl.params.findIdx? 
fun param => c.discr == param.fvarId + | _ => none return go decl.value 0 -abbrev JpCasesInfo := FVarIdMap NameSet +/-- +Information for join points that satisfy `isJpCases?` +-/ +structure JpCasesInfo where + /-- Parameter index returned by `isJpCases?`. This is the parameter on which the join point performs the case split. -/ + paramIdx : Nat + /-- + Set of constructor names s.t. `ctorName` is in the set if there is a jump to the join point where the parameter + `paramIdx` is a constructor application. + -/ + ctorNames : NameSet := {} + deriving Inhabited + +abbrev JpCasesInfoMap := FVarIdMap JpCasesInfo /-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/ -def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool := - info.any fun _ s => !s.isEmpty +def JpCasesInfoMap.isCandidate (info : JpCasesInfoMap) : Bool := + info.any fun _ s => !s.ctorNames.isEmpty /-- -Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point +Return a map containing entries `jpFVarId ↦ { paramIdx, ctorNames }` where `jpFVarId` is the id of join point in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that -there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application. +there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructor application. +`paramIdx` is the index of the parameter that the join point's `cases` statement discriminates on, i.e. the position of `x` in the jump arguments. -/ -partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfo := do +partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do let (_, s) ← go code |>.run {} return s where - go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do + go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do match code with | .let _ k => go k | .fun decl k => go decl.value; go k | .jp decl k => - if (← isJpCases decl) then - modify fun s => s.insert decl.fvarId {} + if let some paramIdx ← isJpCases? 
decl then + modify fun s => s.insert decl.fvarId { paramIdx } go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode | .return .. | .unreach .. => return () | .jmp fvarId args => - if args.size == 1 then - if let some ctorNames := (← get).find? fvarId then - let arg ← findExpr args[0]! + if let some info := (← get).find? fvarId then + let arg ← findExpr args[info.paramIdx]! let some (cval, _) := arg.constructorApp? (← getEnv) | return () - modify fun s => s.insert fvarId <| ctorNames.insert cval.name + modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name } /-- -Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`. +Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases?`. -/ private def extractJpCases (code : Code) : Array CodeDecl × Cases := go code #[] @@ -87,21 +100,49 @@ structure JpCasesAlt where abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt) open Internalize in -private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do +/-- +Construct an auxiliary join point for a particular alternative in a join-point that satisfies `isJpCases?`. +- `decls` is the prefix (before the `cases`). See `isJpCases?`. +- `params` are the parameters of the main join point that satisfies `isJpCases?`. +- `targetParamIdx` is the index of the parameter that we are expanding to `fields`. +- `fields` are the fields/parameters of the alternative. +- `k` is the body of the alternative. +- `default` is true if it is a default alternative. 
+-/ +private def mkJpAlt (decls : Array CodeDecl) (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do go |>.run' {} where go : InternalizeM JpCasesAlt := do let s : FVarIdSet := {} let mut paramsNew := #[] - let dependsOnDiscr := k.dependsOn (s.insert discr.fvarId) - if dependsOnDiscr then - paramsNew := paramsNew.push (← internalizeParam discr) - paramsNew := paramsNew ++ (← fields.mapM internalizeParam) + let dependsOnDiscr := k.dependsOn (s.insert params[targetParamIdx]!.fvarId) + for i in [:params.size] do + let param := params[i]! + if targetParamIdx == i then + if dependsOnDiscr then + paramsNew := paramsNew.push (← internalizeParam param) + paramsNew := paramsNew ++ (← fields.mapM internalizeParam) + else + paramsNew := paramsNew.push (← internalizeParam param) let decls ← decls.mapM internalizeCodeDecl let k ← internalizeCode k let value := LCNF.attachCodeDecls decls k return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr } +/-- Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. -/ +private def mkJmpNewArgs (args : Array Expr) (targetParamIdx : Nat) (fields : Array Expr) (dependsOnTarget : Bool) : Array Expr := + if dependsOnTarget then + args[:targetParamIdx+1] ++ fields ++ args[targetParamIdx+1:] + else + args[:targetParamIdx] ++ fields ++ args[targetParamIdx+1:] + +/-- +Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. +This function is used to create jumps from the join point satisfying `isJpCases?` to the new auxiliary join points created using `mkJpAlt`. +-/ +private def mkJmpArgsAtJp (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (dependsOnTarget : Bool) : Array Expr := Id.run do + mkJmpNewArgs (params.map (mkFVar ·.fvarId)) targetParamIdx (fields.map (mkFVar ·.fvarId)) dependsOnTarget + /-- Try to optimize `jpCases` join points. 
We say a join point is a `jpCases` when it satifies the predicate `isJpCases`. @@ -147,16 +188,16 @@ Note that if all jumps to the join point are with constructors, then the join point is eliminated as dead code. -/ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do - let info ← collectJpCasesInfo code - unless info.isCandidate do return none + let map ← collectJpCasesInfo code + unless map.isCandidate do return none traceM `Compiler.simp.jpCases do let mut msg : MessageData := "candidates" - for (fvarId, ctorName) in info.toList do - msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}" + for (fvarId, info) in map.toList do + msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}" return msg - visit code info |>.run' {} + visit code map |>.run' {} where - visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do + visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do match code with | .let decl k => return code.updateLet! decl (← visit k) @@ -179,11 +220,10 @@ where let some code ← visitJmp? fvarId args | return code return code - visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do - let some s := (← read).find? decl.fvarId | return none - if s.isEmpty then return none - -- This join point satisfies `isJp` and there jumps with constructors in `s` to it. - let p := decl.params[0]! + visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + let some info := (← read).find? decl.fvarId | return none + if info.ctorNames.isEmpty then return none + -- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it. 
let (decls, cases) := extractJpCases decl.value let mut jpAltMap := {} let mut jpAltDecls := #[] @@ -193,26 +233,24 @@ where | .default k => let k ← visit k let explicitCtorNames := cases.getCtorNames - if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then - let jpAlt ← mkJpAlt decls p #[] k (default := true) + if info.ctorNames.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then + let jpAlt ← mkJpAlt decls decl.params info.paramIdx #[] k (default := true) jpAltDecls := jpAltDecls.push (.jp jpAlt.decl) eraseCode k - for ctorNameInJmp in s do + for ctorNameInJmp in info.ctorNames do unless explicitCtorNames.contains ctorNameInJmp do jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt - let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[] + let args := mkJmpArgsAtJp decl.params info.paramIdx #[] jpAlt.dependsOnDiscr altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args)) else altsNew := altsNew.push (alt.updateCode k) | .alt ctorName fields k => let k ← visit k - if s.contains ctorName then - let jpAlt ← mkJpAlt decls p fields k (default := false) + if info.ctorNames.contains ctorName then + let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false) jpAltDecls := jpAltDecls.push (.jp jpAlt.decl) jpAltMap := jpAltMap.insert ctorName jpAlt - let mut args := fields.map (mkFVar ·.fvarId) - if jpAlt.dependsOnDiscr then - args := #[mkFVar p.fvarId] ++ args + let args := mkJmpArgsAtJp decl.params info.paramIdx fields jpAlt.dependsOnDiscr eraseCode k altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args)) else @@ -223,23 +261,18 @@ where let code := .jp decl (← visit k) return LCNF.attachCodeDecls jpAltDecls code - visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + visitJmp? 
(fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do let some ctorJpAltMap := (← get).find? fvarId | return none - assert! args.size == 1 - let arg ← findExpr args[0]! + let some info := (← read).find? fvarId | return none + let arg ← findExpr args[info.paramIdx]! let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none - if jpAlt.default then - if jpAlt.dependsOnDiscr then - return some <| .jmp jpAlt.decl.fvarId args - else - return some <| .jmp jpAlt.decl.fvarId #[] - else - let fields := ctorArgs[ctorVal.numParams:] - -- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too. - -- We use a for-loop because we may have other special cases in the future. - let mut auxDecls := #[] - let mut fieldsNew := #[] + let fields := ctorArgs[ctorVal.numParams:] + -- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too. + -- We use a for-loop because we may have other special cases in the future. 
+ let mut auxDecls := #[] + let mut fieldsNew := #[] + unless jpAlt.default do for field in fields do if field.isFVar then fieldsNew := fieldsNew.push field @@ -247,11 +280,8 @@ where let letDecl ← mkAuxLetDecl field auxDecls := auxDecls.push (CodeDecl.let letDecl) fieldsNew := fieldsNew.push (.fvar letDecl.fvarId) - let code ← if jpAlt.dependsOnDiscr then - pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew) - else - pure <| .jmp jpAlt.decl.fvarId fieldsNew - return some <| LCNF.attachCodeDecls auxDecls code + let argsNew := mkJmpNewArgs args info.paramIdx fieldsNew jpAlt.dependsOnDiscr + return some <| LCNF.attachCodeDecls auxDecls (.jmp jpAlt.decl.fvarId argsNew) end Simp diff --git a/tests/lean/jpCasesNary.lean b/tests/lean/jpCasesNary.lean new file mode 100644 index 0000000000..7406761dd2 --- /dev/null +++ b/tests/lean/jpCasesNary.lean @@ -0,0 +1,47 @@ +set_option trace.Compiler.saveBase true in +def f1 (c : Bool) (a b : Nat) := + let k d y z := + match d with + | true => y + z + z*y + | false => z + y + match c with + | true => k false a b + | false => k true b a + +set_option trace.Compiler.saveBase true in +def f2 (c : Bool) (a b : Nat) := + let k y d z := + match d with + | true => y + z + z*y + | false => z + y + match c with + | true => k a false b + | false => k b true a + +inductive C where + | c1 | c2 | c3 | c4 + +set_option trace.Compiler.saveBase true in +def f3 (c c' : C) (a b : Nat) := + let k y (d : C) z := + match d with + | C.c1 => y + z + z*y + | _ => z + y + y + match c with + | .c1 => k a .c2 b + | .c2 => k b .c1 a + | .c3 => k b c' a + | .c4 => k a c' a + +set_option trace.Compiler.saveBase true in +def f4 (c c' : C) (a b : Nat) := + let k y z (d : C) := + match d with + | C.c1 => y + z + z*y + | C.c3 => y*y+a + | _ => z + y + y + match c with + | .c1 => k a b .c2 + | .c2 => k b b .c1 + | .c3 => k b a c' + | .c4 => k a a c' diff --git a/tests/lean/jpCasesNary.lean.expected.out b/tests/lean/jpCasesNary.lean.expected.out new file mode 100644 
index 0000000000..f8f074d824 --- /dev/null +++ b/tests/lean/jpCasesNary.lean.expected.out @@ -0,0 +1,78 @@ +[Compiler.saveBase] size: 7 + def f1 c a b := + cases c + | Bool.false => + let _x.1 := Nat.add b a + let _x.2 := Nat.mul a b + let _x.3 := Nat.add _x.1 _x.2 + _x.3 + | Bool.true => + let _x.4 := Nat.add b a + _x.4 +[Compiler.saveBase] size: 7 + def f2 c a b := + cases c + | Bool.false => + let _x.1 := Nat.add b a + let _x.2 := Nat.mul a b + let _x.3 := Nat.add _x.1 _x.2 + _x.3 + | Bool.true => + let _x.4 := Nat.add b a + _x.4 +[Compiler.saveBase] size: 19 + def f3 c c' a b := + jp _jp.1 y z := + let _x.2 := Nat.add y z + let _x.3 := Nat.mul z y + let _x.4 := Nat.add _x.2 _x.3 + _x.4 + jp _jp.5 y z := + let _x.6 := Nat.add z y + let _x.7 := Nat.add _x.6 y + _x.7 + jp _jp.8 y d z := + cases d + | C.c1 => + goto _jp.1 y z + | _ => + goto _jp.5 y z + cases c + | C.c1 => + goto _jp.5 a b + | C.c2 => + goto _jp.1 b a + | C.c3 => + goto _jp.8 b c' a + | C.c4 => + goto _jp.8 a c' a +[Compiler.saveBase] size: 22 + def f4 c c' a b := + jp _jp.1 y z := + let _x.2 := Nat.add y z + let _x.3 := Nat.mul z y + let _x.4 := Nat.add _x.2 _x.3 + _x.4 + jp _jp.5 y z := + let _x.6 := Nat.add z y + let _x.7 := Nat.add _x.6 y + _x.7 + jp _jp.8 y z d := + cases d + | C.c1 => + goto _jp.1 y z + | C.c3 => + let _x.9 := Nat.mul y y + let _x.10 := Nat.add _x.9 a + _x.10 + | _ => + goto _jp.5 y z + cases c + | C.c1 => + goto _jp.5 a b + | C.c2 => + goto _jp.1 b b + | C.c3 => + goto _jp.8 b a c' + | C.c4 => + goto _jp.8 a a c'