feat: cases on cases for new LCNF simplifier

This commit is contained in:
Leonardo de Moura 2022-09-03 07:51:36 -07:00
parent e8335240d8
commit bc88b0307e
2 changed files with 94 additions and 122 deletions

View file

@ -92,6 +92,9 @@ class MonadFVarSubst (m : Type → Type) where
export MonadFVarSubst (getSubst)
instance (m n) [MonadLift m n] [MonadFVarSubst m] : MonadFVarSubst n where
getSubst := liftM (getSubst : m _)
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId :=
return normFVarImp (← getSubst) fvarId

View file

@ -73,6 +73,11 @@ def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDec
match s with
| { map } => { map := map.insert fvarId .mustInline }
def FunDeclInfoMap.restore (s : FunDeclInfoMap) (fvarId : FVarId) (saved? : Option FunDeclInfo) : FunDeclInfoMap :=
match s, saved? with
| { map }, none => { map := map.erase fvarId }
| { map }, some saved => { map := map.insert fvarId saved }
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
match e with
| .fvar fvarId =>
@ -149,6 +154,12 @@ def incInline : SimpM Unit :=
def incInlineLocal : SimpM Unit :=
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
def addMustInline (fvarId : FVarId) : SimpM Unit :=
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline fvarId }
def addFunOcc (fvarId : FVarId) : SimpM Unit :=
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add fvarId }
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
go code
where
@ -157,15 +168,30 @@ where
| .let decl k =>
if decl.value.isApp then
if let some funDecl ← findFunDecl? decl.value.getAppFn then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId }
addFunOcc funDecl.fvarId
go k
| .fun decl k =>
if mustInline then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId }
addMustInline decl.fvarId
go decl.value; go k
| .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .return .. | .jmp .. | .unreach .. => return ()
| .jmp fvarId .. =>
let funDecl ← getFunDecl fvarId
addFunOcc funDecl.fvarId
| .return .. | .unreach .. => return ()
/--
Execute `x` with `fvarId` set as `mustInline`.
After execution the original setting is restored.
-/
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
let saved? := (← get).funDeclInfoMap.map.find? fvarId
try
addMustInline fvarId
x
finally
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.restore fvarId saved? }
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match (← get).funDeclInfoMap.map.find? fvarId with
@ -298,6 +324,7 @@ Try to inline a join point.
partial def inlineJp? (fvarId : FVarId) (args : Array Expr) : SimpM (Option Code) := do
let some decl ← LCNF.findFunDecl? fvarId | return none
unless (← shouldInlineLocal decl) do return none
markSimplified
betaReduce decl.params decl.value args
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
@ -409,6 +436,56 @@ where
| .return fvarId => visit (.fvar fvarId) projs
| _ => failure
/--
Return `some _jp.k` if the given `cases` is of the form
```
cases _x.i
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
...
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
```
where `_jp.k` is a join point of the form
```
let _jp.k y :=
cases y ...
```
The goal is to mark `_jp.k` as must inline in this scenarion.
Example: consider the following declarations
```
@[inline] def pred? (x : Nat) : Option Nat :=
match x with
| 0 => none
| x+1 => some x
def isZero (x : Nat) :=
match pred? x with
| some _ => false
| none => true
```
After inlining `pred?` in `isZero`, this simplification is applicable, producing
Remark: this method does not assume `cases` has already been normalized,
but returns a normalized `FVarId` in case of success.
-/
def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do
let jpFirst ← isCtorJmp? cases.alts[0]!.getCode
let funDecl ← getFunDecl jpFirst
guard <| funDecl.value matches .cases ..
for alt in cases.alts[1:] do
let jp ← isCtorJmp? alt.getCode
guard <| jpFirst == jp
return jpFirst
where
isCtorJmp? (code : Code) : OptionT SimpM FVarId := do
match code with
| .let _ k | .jp _ k | .fun _ k => isCtorJmp? k
| .return .. | .unreach .. | .cases .. => failure
| .jmp jpFVarId args =>
let #[arg] := args | failure
let arg ← findExpr (← normExpr arg)
guard <| arg.isConstructorApp (← getEnv)
normFVar jpFVarId
def findCtor (e : Expr) : SimpM Expr := do
-- TODO: add support for mapping discriminants to constructors in branches
findExpr e
@ -559,12 +636,16 @@ partial def simp (code : Code) : SimpM Code := do
if let some k ← simpCasesOnCtor? c then
return k
else
-- TODO: other cases simplifications
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
return code.updateCases! resultType discr alts
let simpCasesDefault := do
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
return code.updateCases! resultType discr alts
if let some jpFVarId ← isCasesOnCases? c then
withAddMustInline jpFVarId simpCasesDefault
else
simpCasesDefault
end
@ -600,116 +681,4 @@ builtin_initialize
registerTraceClass `Compiler.simp.step.new
registerTraceClass `Compiler.simp.projInst
end Lean.Compiler.LCNF
#exit -- TODO: port rest of file
namespace Lean.Compiler
namespace Simp
/--
Try "cases on cases" simplification.
If `casesFn args` is of the form
```
casesOn _x.i
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
...
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
```
where `_jp.k` is a join point of the form
```
let _jp.k := fun y =>
casesOn y ...
```
Then, inline `_jp.k`. The idea is to force the `casesOn` application in the join point to
reduce after the inlining step.
Example: consider the following declarations
```
@[inline] def pred? (x : Nat) : Option Nat :=
match x with
| 0 => none
| x+1 => some x
def isZero (x : Nat) :=
match pred? x with
| some _ => false
| none => true
```
After inlining `pred?` in `isZero`, we have
```
let _jp.1 := fun y : Option Nat =>
casesOn y true (fun y => false)
casesOn x
(let _x.1 := none; _jp.1 _x.1)
(fun n => let _x.2 := some n; _jp.1 _x.2)
```
and this simplification is applicable, producing
```
casesOn x true (fun n => false)
```
-/
def simpCasesOnCases? (casesInfo : CasesInfo) (casesFn : Expr) (args : Array Expr) : OptionT SimpM Expr := do
let mut jpFirst? := none
for i in casesInfo.altsRange do
let alt := args[i]!
let jp ← isJpCtor? alt
if let some jpFirst := jpFirst? then
guard <| jp == jpFirst
else
let some localDecl ← findDecl? jp | failure
let .lam _ _ jpBody _ := localDecl.value | failure
guard (← isCasesApp? jpBody).isSome
jpFirst? := jp
let some jpFVarId := jpFirst? | failure
let some localDecl ← findDecl? jpFVarId | failure
let .lam _ _ jpBody _ := localDecl.value | failure
let mut args := args
for i in casesInfo.altsRange do
args := args.modify i (inlineJp · jpBody)
return mkAppN casesFn args
where
isJpCtor? (alt : Expr) : OptionT SimpM FVarId := do
match alt with
| .lam _ _ b _ => isJpCtor? b
| .letE _ _ v b _ => match b with
| .letE .. => isJpCtor? b
| .app (.fvar fvarId) (.bvar 0) =>
let some localDecl ← findDecl? fvarId | failure
guard localDecl.isJp
guard <| v.isConstructorApp (← getEnv)
return fvarId
| _ => failure
| _ => failure
inlineJp (alt : Expr) (jpBody : Expr) : Expr :=
match alt with
| .lam n d b bi => .lam n d (inlineJp b jpBody) bi
| .letE n t v b nd => .letE n t v (inlineJp b jpBody) nd
| _ => jpBody
mutual
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
let f := e.getAppFn
let mut args := e.getAppArgs
let major := args[casesInfo.discrsRange.stop - 1]!
let major ← findExpr major
if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then
/- Simplify `casesOn` constructor -/
let ctorIdx := ctorVal.cidx
let alt := args[casesInfo.altsRange.start + ctorIdx]!
let ctorFields := ctorArgs[ctorVal.numParams:]
let alt := alt.beta ctorFields
assert! !alt.isLambda
markSimplified
visitLet alt
else if let some e ← simpCasesOnCases? casesInfo f args then
visitCases casesInfo e
else
for i in casesInfo.altsRange do
args ← args.modifyM i (visitLambda · (checkEmptyTypes := true))
return mkAppN f args
end Simp
end Lean.Compiler
end Lean.Compiler.LCNF