From bc88b0307e990e09afa567ab3ba2e6567d567664 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 3 Sep 2022 07:51:36 -0700 Subject: [PATCH] feat: `cases` on `cases` for new LCNF simplifier --- src/Lean/Compiler/LCNF/CompilerM.lean | 3 + src/Lean/Compiler/LCNF/Simp.lean | 213 +++++++++++--------------- 2 files changed, 94 insertions(+), 122 deletions(-) diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index df9558f403..7d94c8a606 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -92,6 +92,9 @@ class MonadFVarSubst (m : Type → Type) where export MonadFVarSubst (getSubst) +instance (m n) [MonadLift m n] [MonadFVarSubst m] : MonadFVarSubst n where + getSubst := liftM (getSubst : m _) + @[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId := return normFVarImp (← getSubst) fvarId diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index c38b49172c..478a6d64f5 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -73,6 +73,11 @@ def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDec match s with | { map } => { map := map.insert fvarId .mustInline } +def FunDeclInfoMap.restore (s : FunDeclInfoMap) (fvarId : FVarId) (saved? : Option FunDeclInfo) : FunDeclInfoMap := + match s, saved? with + | { map }, none => { map := map.erase fvarId } + | { map }, some saved => { map := map.insert fvarId saved } + partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do match e with | .fvar fvarId => @@ -149,6 +154,12 @@ def incInline : SimpM Unit := def incInlineLocal : SimpM Unit := modify fun s => { s with inlineLocal := s.inlineLocal + 1 } +def addMustInline (fvarId : FVarId) : SimpM Unit := + modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline fvarId } + +def addFunOcc (fvarId : FVarId) : SimpM Unit := + modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add fvarId } + partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := go code where @@ -157,15 +168,30 @@ where | .let decl k => if decl.value.isApp then if let some funDecl ← findFunDecl? decl.value.getAppFn then - modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId } + addFunOcc funDecl.fvarId go k | .fun decl k => if mustInline then - modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId } + addMustInline decl.fvarId go decl.value; go k | .jp decl k => go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode - | .return .. | .jmp .. | .unreach .. => return () + | .jmp fvarId .. => + let funDecl ← getFunDecl fvarId + addFunOcc funDecl.fvarId + | .return .. | .unreach .. => return () + +/-- +Execute `x` with `fvarId` set as `mustInline`. +After execution the original setting is restored. +-/ +def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do + let saved? := (← get).funDeclInfoMap.map.find? fvarId + try + addMustInline fvarId + x + finally + modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.restore fvarId saved? } def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do match (← get).funDeclInfoMap.map.find? fvarId with @@ -298,6 +324,7 @@ Try to inline a join point. partial def inlineJp? (fvarId : FVarId) (args : Array Expr) : SimpM (Option Code) := do let some decl ← LCNF.findFunDecl? fvarId | return none unless (← shouldInlineLocal decl) do return none + markSimplified betaReduce decl.params decl.value args def markUsedFVar (fvarId : FVarId) : SimpM Unit := @@ -409,6 +436,56 @@ where | .return fvarId => visit (.fvar fvarId) projs | _ => failure +/-- +Return `some _jp.k` if the given `cases` is of the form +``` +cases _x.i + (... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁) + ... + (... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ) +``` +where `_jp.k` is a join point of the form +``` +let _jp.k y := + cases y ... +``` +The goal is to mark `_jp.k` as must inline in this scenarion. +Example: consider the following declarations +``` +@[inline] def pred? (x : Nat) : Option Nat := + match x with + | 0 => none + | x+1 => some x + +def isZero (x : Nat) := + match pred? x with + | some _ => false + | none => true +``` +After inlining `pred?` in `isZero`, this simplification is applicable, producing + +Remark: this method does not assume `cases` has already been normalized, +but returns a normalized `FVarId` in case of success. +-/ +def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do + let jpFirst ← isCtorJmp? cases.alts[0]!.getCode + let funDecl ← getFunDecl jpFirst + guard <| funDecl.value matches .cases .. + for alt in cases.alts[1:] do + let jp ← isCtorJmp? alt.getCode + guard <| jpFirst == jp + return jpFirst +where + isCtorJmp? (code : Code) : OptionT SimpM FVarId := do + match code with + | .let _ k | .jp _ k | .fun _ k => isCtorJmp? k + | .return .. | .unreach .. | .cases .. => failure + | .jmp jpFVarId args => + let #[arg] := args | failure + let arg ← findExpr (← normExpr arg) + guard <| arg.isConstructorApp (← getEnv) + normFVar jpFVarId + def findCtor (e : Expr) : SimpM Expr := do -- TODO: add support for mapping discriminants to constructors in branches findExpr e @@ -559,12 +636,16 @@ partial def simp (code : Code) : SimpM Code := do if let some k ← simpCasesOnCtor? c then return k else - -- TODO: other cases simplifications - let discr ← normFVar c.discr - let resultType ← normExpr c.resultType - markUsedFVar discr - let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode) - return code.updateCases! resultType discr alts + let simpCasesDefault := do + let discr ← normFVar c.discr + let resultType ← normExpr c.resultType + markUsedFVar discr + let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode) + return code.updateCases! resultType discr alts + if let some jpFVarId ← isCasesOnCases? c then + withAddMustInline jpFVarId simpCasesDefault + else + simpCasesDefault end @@ -600,116 +681,4 @@ builtin_initialize registerTraceClass `Compiler.simp.step.new registerTraceClass `Compiler.simp.projInst -end Lean.Compiler.LCNF - -#exit -- TODO: port rest of file - -namespace Lean.Compiler -namespace Simp - -/-- -Try "cases on cases" simplification. -If `casesFn args` is of the form -``` -casesOn _x.i - (... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁) - ... - (... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ) -``` -where `_jp.k` is a join point of the form -``` -let _jp.k := fun y => - casesOn y ... -``` -Then, inline `_jp.k`. The idea is to force the `casesOn` application in the join point to -reduce after the inlining step. -Example: consider the following declarations -``` -@[inline] def pred? (x : Nat) : Option Nat := - match x with - | 0 => none - | x+1 => some x - -def isZero (x : Nat) := - match pred? x with - | some _ => false - | none => true -``` -After inlining `pred?` in `isZero`, we have -``` -let _jp.1 := fun y : Option Nat => - casesOn y true (fun y => false) -casesOn x - (let _x.1 := none; _jp.1 _x.1) - (fun n => let _x.2 := some n; _jp.1 _x.2) -``` -and this simplification is applicable, producing -``` -casesOn x true (fun n => false) -``` --/ -def simpCasesOnCases? (casesInfo : CasesInfo) (casesFn : Expr) (args : Array Expr) : OptionT SimpM Expr := do - let mut jpFirst? := none - for i in casesInfo.altsRange do - let alt := args[i]! - let jp ← isJpCtor? alt - if let some jpFirst := jpFirst? then - guard <| jp == jpFirst - else - let some localDecl ← findDecl? jp | failure - let .lam _ _ jpBody _ := localDecl.value | failure - guard (← isCasesApp? jpBody).isSome - jpFirst? := jp - let some jpFVarId := jpFirst? | failure - let some localDecl ← findDecl? jpFVarId | failure - let .lam _ _ jpBody _ := localDecl.value | failure - let mut args := args - for i in casesInfo.altsRange do - args := args.modify i (inlineJp · jpBody) - return mkAppN casesFn args -where - isJpCtor? (alt : Expr) : OptionT SimpM FVarId := do - match alt with - | .lam _ _ b _ => isJpCtor? b - | .letE _ _ v b _ => match b with - | .letE .. => isJpCtor? b - | .app (.fvar fvarId) (.bvar 0) => - let some localDecl ← findDecl? fvarId | failure - guard localDecl.isJp - guard <| v.isConstructorApp (← getEnv) - return fvarId - | _ => failure - | _ => failure - - inlineJp (alt : Expr) (jpBody : Expr) : Expr := - match alt with - | .lam n d b bi => .lam n d (inlineJp b jpBody) bi - | .letE n t v b nd => .letE n t v (inlineJp b jpBody) nd - | _ => jpBody - -mutual - -partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do - let f := e.getAppFn - let mut args := e.getAppArgs - let major := args[casesInfo.discrsRange.stop - 1]! - let major ← findExpr major - if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then - /- Simplify `casesOn` constructor -/ - let ctorIdx := ctorVal.cidx - let alt := args[casesInfo.altsRange.start + ctorIdx]! - let ctorFields := ctorArgs[ctorVal.numParams:] - let alt := alt.beta ctorFields - assert! !alt.isLambda - markSimplified - visitLet alt - else if let some e ← simpCasesOnCases? casesInfo f args then - visitCases casesInfo e - else - for i in casesInfo.altsRange do - args ← args.modifyM i (visitLambda · (checkEmptyTypes := true)) - return mkAppN f args - -end Simp - -end Lean.Compiler +end Lean.Compiler.LCNF \ No newline at end of file