feat: cases on cases for new LCNF simplifier
This commit is contained in:
parent
e8335240d8
commit
bc88b0307e
2 changed files with 94 additions and 122 deletions
|
|
@ -92,6 +92,9 @@ class MonadFVarSubst (m : Type → Type) where
|
|||
|
||||
export MonadFVarSubst (getSubst)
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubst m] : MonadFVarSubst n where
|
||||
getSubst := liftM (getSubst : m _)
|
||||
|
||||
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId :=
|
||||
return normFVarImp (← getSubst) fvarId
|
||||
|
||||
|
|
|
|||
|
|
@ -73,6 +73,11 @@ def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDec
|
|||
match s with
|
||||
| { map } => { map := map.insert fvarId .mustInline }
|
||||
|
||||
def FunDeclInfoMap.restore (s : FunDeclInfoMap) (fvarId : FVarId) (saved? : Option FunDeclInfo) : FunDeclInfoMap :=
|
||||
match s, saved? with
|
||||
| { map }, none => { map := map.erase fvarId }
|
||||
| { map }, some saved => { map := map.insert fvarId saved }
|
||||
|
||||
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
|
|
@ -149,6 +154,12 @@ def incInline : SimpM Unit :=
|
|||
def incInlineLocal : SimpM Unit :=
|
||||
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
|
||||
|
||||
def addMustInline (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline fvarId }
|
||||
|
||||
def addFunOcc (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add fvarId }
|
||||
|
||||
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
|
||||
go code
|
||||
where
|
||||
|
|
@ -157,15 +168,30 @@ where
|
|||
| .let decl k =>
|
||||
if decl.value.isApp then
|
||||
if let some funDecl ← findFunDecl? decl.value.getAppFn then
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId }
|
||||
addFunOcc funDecl.fvarId
|
||||
go k
|
||||
| .fun decl k =>
|
||||
if mustInline then
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId }
|
||||
addMustInline decl.fvarId
|
||||
go decl.value; go k
|
||||
| .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .return .. | .jmp .. | .unreach .. => return ()
|
||||
| .jmp fvarId .. =>
|
||||
let funDecl ← getFunDecl fvarId
|
||||
addFunOcc funDecl.fvarId
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
||||
/--
|
||||
Execute `x` with `fvarId` set as `mustInline`.
|
||||
After execution the original setting is restored.
|
||||
-/
|
||||
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
|
||||
let saved? := (← get).funDeclInfoMap.map.find? fvarId
|
||||
try
|
||||
addMustInline fvarId
|
||||
x
|
||||
finally
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.restore fvarId saved? }
|
||||
|
||||
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
||||
match (← get).funDeclInfoMap.map.find? fvarId with
|
||||
|
|
@ -298,6 +324,7 @@ Try to inline a join point.
|
|||
partial def inlineJp? (fvarId : FVarId) (args : Array Expr) : SimpM (Option Code) := do
|
||||
let some decl ← LCNF.findFunDecl? fvarId | return none
|
||||
unless (← shouldInlineLocal decl) do return none
|
||||
markSimplified
|
||||
betaReduce decl.params decl.value args
|
||||
|
||||
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
|
||||
|
|
@ -409,6 +436,56 @@ where
|
|||
| .return fvarId => visit (.fvar fvarId) projs
|
||||
| _ => failure
|
||||
|
||||
/--
|
||||
Return `some _jp.k` if the given `cases` is of the form
|
||||
```
|
||||
cases _x.i
|
||||
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
|
||||
...
|
||||
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
|
||||
```
|
||||
where `_jp.k` is a join point of the form
|
||||
```
|
||||
let _jp.k y :=
|
||||
cases y ...
|
||||
```
|
||||
The goal is to mark `_jp.k` as must inline in this scenarion.
|
||||
Example: consider the following declarations
|
||||
```
|
||||
@[inline] def pred? (x : Nat) : Option Nat :=
|
||||
match x with
|
||||
| 0 => none
|
||||
| x+1 => some x
|
||||
|
||||
def isZero (x : Nat) :=
|
||||
match pred? x with
|
||||
| some _ => false
|
||||
| none => true
|
||||
```
|
||||
After inlining `pred?` in `isZero`, this simplification is applicable, producing
|
||||
|
||||
Remark: this method does not assume `cases` has already been normalized,
|
||||
but returns a normalized `FVarId` in case of success.
|
||||
-/
|
||||
def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do
|
||||
let jpFirst ← isCtorJmp? cases.alts[0]!.getCode
|
||||
let funDecl ← getFunDecl jpFirst
|
||||
guard <| funDecl.value matches .cases ..
|
||||
for alt in cases.alts[1:] do
|
||||
let jp ← isCtorJmp? alt.getCode
|
||||
guard <| jpFirst == jp
|
||||
return jpFirst
|
||||
where
|
||||
isCtorJmp? (code : Code) : OptionT SimpM FVarId := do
|
||||
match code with
|
||||
| .let _ k | .jp _ k | .fun _ k => isCtorJmp? k
|
||||
| .return .. | .unreach .. | .cases .. => failure
|
||||
| .jmp jpFVarId args =>
|
||||
let #[arg] := args | failure
|
||||
let arg ← findExpr (← normExpr arg)
|
||||
guard <| arg.isConstructorApp (← getEnv)
|
||||
normFVar jpFVarId
|
||||
|
||||
def findCtor (e : Expr) : SimpM Expr := do
|
||||
-- TODO: add support for mapping discriminants to constructors in branches
|
||||
findExpr e
|
||||
|
|
@ -559,12 +636,16 @@ partial def simp (code : Code) : SimpM Code := do
|
|||
if let some k ← simpCasesOnCtor? c then
|
||||
return k
|
||||
else
|
||||
-- TODO: other cases simplifications
|
||||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
|
||||
return code.updateCases! resultType discr alts
|
||||
let simpCasesDefault := do
|
||||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
|
||||
return code.updateCases! resultType discr alts
|
||||
if let some jpFVarId ← isCasesOnCases? c then
|
||||
withAddMustInline jpFVarId simpCasesDefault
|
||||
else
|
||||
simpCasesDefault
|
||||
|
||||
end
|
||||
|
||||
|
|
@ -600,116 +681,4 @@ builtin_initialize
|
|||
registerTraceClass `Compiler.simp.step.new
|
||||
registerTraceClass `Compiler.simp.projInst
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
#exit -- TODO: port rest of file
|
||||
|
||||
namespace Lean.Compiler
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Try "cases on cases" simplification.
|
||||
If `casesFn args` is of the form
|
||||
```
|
||||
casesOn _x.i
|
||||
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
|
||||
...
|
||||
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
|
||||
```
|
||||
where `_jp.k` is a join point of the form
|
||||
```
|
||||
let _jp.k := fun y =>
|
||||
casesOn y ...
|
||||
```
|
||||
Then, inline `_jp.k`. The idea is to force the `casesOn` application in the join point to
|
||||
reduce after the inlining step.
|
||||
Example: consider the following declarations
|
||||
```
|
||||
@[inline] def pred? (x : Nat) : Option Nat :=
|
||||
match x with
|
||||
| 0 => none
|
||||
| x+1 => some x
|
||||
|
||||
def isZero (x : Nat) :=
|
||||
match pred? x with
|
||||
| some _ => false
|
||||
| none => true
|
||||
```
|
||||
After inlining `pred?` in `isZero`, we have
|
||||
```
|
||||
let _jp.1 := fun y : Option Nat =>
|
||||
casesOn y true (fun y => false)
|
||||
casesOn x
|
||||
(let _x.1 := none; _jp.1 _x.1)
|
||||
(fun n => let _x.2 := some n; _jp.1 _x.2)
|
||||
```
|
||||
and this simplification is applicable, producing
|
||||
```
|
||||
casesOn x true (fun n => false)
|
||||
```
|
||||
-/
|
||||
def simpCasesOnCases? (casesInfo : CasesInfo) (casesFn : Expr) (args : Array Expr) : OptionT SimpM Expr := do
|
||||
let mut jpFirst? := none
|
||||
for i in casesInfo.altsRange do
|
||||
let alt := args[i]!
|
||||
let jp ← isJpCtor? alt
|
||||
if let some jpFirst := jpFirst? then
|
||||
guard <| jp == jpFirst
|
||||
else
|
||||
let some localDecl ← findDecl? jp | failure
|
||||
let .lam _ _ jpBody _ := localDecl.value | failure
|
||||
guard (← isCasesApp? jpBody).isSome
|
||||
jpFirst? := jp
|
||||
let some jpFVarId := jpFirst? | failure
|
||||
let some localDecl ← findDecl? jpFVarId | failure
|
||||
let .lam _ _ jpBody _ := localDecl.value | failure
|
||||
let mut args := args
|
||||
for i in casesInfo.altsRange do
|
||||
args := args.modify i (inlineJp · jpBody)
|
||||
return mkAppN casesFn args
|
||||
where
|
||||
isJpCtor? (alt : Expr) : OptionT SimpM FVarId := do
|
||||
match alt with
|
||||
| .lam _ _ b _ => isJpCtor? b
|
||||
| .letE _ _ v b _ => match b with
|
||||
| .letE .. => isJpCtor? b
|
||||
| .app (.fvar fvarId) (.bvar 0) =>
|
||||
let some localDecl ← findDecl? fvarId | failure
|
||||
guard localDecl.isJp
|
||||
guard <| v.isConstructorApp (← getEnv)
|
||||
return fvarId
|
||||
| _ => failure
|
||||
| _ => failure
|
||||
|
||||
inlineJp (alt : Expr) (jpBody : Expr) : Expr :=
|
||||
match alt with
|
||||
| .lam n d b bi => .lam n d (inlineJp b jpBody) bi
|
||||
| .letE n t v b nd => .letE n t v (inlineJp b jpBody) nd
|
||||
| _ => jpBody
|
||||
|
||||
mutual
|
||||
|
||||
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
|
||||
let f := e.getAppFn
|
||||
let mut args := e.getAppArgs
|
||||
let major := args[casesInfo.discrsRange.stop - 1]!
|
||||
let major ← findExpr major
|
||||
if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then
|
||||
/- Simplify `casesOn` constructor -/
|
||||
let ctorIdx := ctorVal.cidx
|
||||
let alt := args[casesInfo.altsRange.start + ctorIdx]!
|
||||
let ctorFields := ctorArgs[ctorVal.numParams:]
|
||||
let alt := alt.beta ctorFields
|
||||
assert! !alt.isLambda
|
||||
markSimplified
|
||||
visitLet alt
|
||||
else if let some e ← simpCasesOnCases? casesInfo f args then
|
||||
visitCases casesInfo e
|
||||
else
|
||||
for i in casesInfo.altsRange do
|
||||
args ← args.modifyM i (visitLambda · (checkEmptyTypes := true))
|
||||
return mkAppN f args
|
||||
|
||||
end Simp
|
||||
|
||||
end Lean.Compiler
|
||||
end Lean.Compiler.LCNF
|
||||
Loading…
Add table
Reference in a new issue