feat: add bindCases

It is similar to `Code.bind` but has special support for `inlineMatcher`
This commit is contained in:
Leonardo de Moura 2022-09-04 19:04:21 -07:00
parent d0600b3750
commit e0197b4e09
4 changed files with 123 additions and 5 deletions

View file

@ -35,9 +35,7 @@ where
| .default k => return .default (← go k)
if alts.isEmpty then
throwError "`Code.bind` failed, empty `cases` found"
let mut resultType ← alts[0]!.inferType
for alt in alts[1:] do
resultType := joinTypes resultType (← alt.inferType)
let resultType ← mkCasesResultType alts
return .cases { c with alts, resultType }
| .return fvarId => f fvarId
| .jmp fvarId .. =>

View file

@ -294,6 +294,9 @@ end
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (code : Code) : m Code := do
normCodeImp code (← getSubst)
def replaceExprFVars (e : Expr) (s : FVarSubst) : CompilerM Expr :=
(normExpr e : ReaderT FVarSubst CompilerM Expr).run s
def replaceFVars (code : Code) (s : FVarSubst) : CompilerM Code :=
(normCode code : ReaderT FVarSubst CompilerM Code).run s

View file

@ -219,4 +219,12 @@ where
return type
termination_by go i => params.size - i
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
if alts.isEmpty then
throwError "`Code.bind` failed, empty `cases` found"
let mut resultType ← alts[0]!.inferType
for alt in alts[1:] do
resultType := joinTypes resultType (← alt.inferType)
return resultType
end Lean.Compiler.LCNF

View file

@ -36,6 +36,116 @@ inductive Element where
| unreach (fvarId : FVarId)
deriving Inhabited
/-- State for `BindCasesM` monad -/
structure BindCasesM.State where
/-- New auxiliary join points. -/
jps : Array FunDecl := #[]
/-- Mapping from `_alt.<idx>` variables to new join points -/
map : FVarIdMap FVarId := {}
/-- Auxiliary monad for implementing `bindCases` -/
abbrev BindCasesM := StateRefT BindCasesM.State CompilerM
/--
This method returns code that at each exit point of `cases`, it jumps to `jpDecl`.
It is similar to `Code.bind`, but we add special support for `inlineMatcher`.
The `inlineMatcher` function inlines the auxiliary `_match_<idx>` declarations.
To make sure there is no code duplication, `inlineMatcher` creates auxiliary declarations `_alt.<idx>`.
We can say the `_alt.<idx>` declarations are pre join points. For each auxiliary declaration used at
an exit point of `cases`, this method creates an new auxiliary join point that invokes `_alt.<idx>`,
and then jumps to `jpDecl`. The goal is to make sure the auxiliary join point is the only occurrence
of `_alt.<idx>`, then `simp` will inline it.
That is, our goal is to try to promote the pre join points `_alt.<idx>` into a proper join point.
-/
partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do
let (alts, s) ← visitAlts cases.alts |>.run {}
let resultType ← mkCasesResultType alts
let mut result := .cases { cases with alts, resultType }
for decl in s.jps do
result := .jp decl result
return .jp jpDecl result
where
visitAlts (alts : Array Alt) : BindCasesM (Array Alt) :=
alts.mapM fun alt => return alt.updateCode (← go alt.getCode)
findFun? (f : FVarId) : CompilerM (Option FunDecl) := do
if let some funDecl ← findFunDecl? f then
return funDecl
else if let .ldecl (value := .fvar f') .. ← getLocalDecl f then
findFun? f'
else
return none
go (code : Code) : BindCasesM Code := do
match code with
| .let decl k =>
if let .return fvarId := k then
/-
Check whether the current let-declaration is of the form
```
let _x := _alt.<idx> args
return _x
```
where `_alt.<idx>` is an auxiliary declaration created by `inlineMatcher`
-/
if decl.fvarId == fvarId && decl.value.isApp && decl.value.getAppFn.isFVar then
let f := decl.value.getAppFn.fvarId!
let localDecl ← getLocalDecl f
if localDecl.userName.getPrefix == `_alt then
if let some funDecl ← findFun? f then
let args := decl.value.getAppArgs
eraseFVar decl.fvarId
if let some altJp := (← get).map.find? f then
/- We already have an auxiliary join point for `f`, then, we just use it. -/
return .jmp altJp args
else
/-
We have not created a join point for `f` yet.
The join point has the form
```
jp altJp jpParams :=
let _x := f jpParams
jmp jpDecl _x
```
Then, we replace the current `let`-declaration with `jmp altJp args`
-/
let mut jpParams := #[]
let mut subst := {}
let mut jpArgs := #[]
/- Remark: `funDecl.params.size` may be greater than `args.size`. -/
for param in funDecl.params[:args.size] do
let type ← replaceExprFVars param.type subst
let paramNew ← mkAuxParam type
jpParams := jpParams.push paramNew
let arg := .fvar paramNew.fvarId
subst := subst.insert param.fvarId arg
jpArgs := jpArgs.push arg
let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs)
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
let altJp ← mkAuxJpDecl jpParams jpValue
modify fun { jps, map } => {
jps := jps.push altJp
map := map.insert f altJp.fvarId
}
return .jmp altJp.fvarId args
return .let decl (← go k)
| .fun decl k => return .fun decl (← go k)
| .jp decl k =>
let value ← go decl.value
let type ← value.inferParamType decl.params
let decl ← decl.update' type value
return .jp decl (← go k)
| .cases c =>
let alts ← c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName params (← go k)
| .default k => return .default (← go k)
if alts.isEmpty then
throwError "`Code.bind` failed, empty `cases` found"
let resultType ← mkCasesResultType alts
return .cases { c with alts, resultType }
| .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
| .jmp .. | .unreach .. => return code
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
if let .fvar fvarId := e then
go seq seq.size (.return fvarId)
@ -69,8 +179,7 @@ where
else
/- Create a join point for `c` and jump to it from `cases` -/
let jpDecl ← mkAuxJpDecl' fvarId c
let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
go seq (i - 1) (.jp jpDecl cases)
go seq (i - 1) (← bindCases jpDecl cases)
else
return c