feat: add bindCases
It is similar to `Code.bind` but has special support for `inlineMatcher`
This commit is contained in:
parent
d0600b3750
commit
e0197b4e09
4 changed files with 123 additions and 5 deletions
|
|
@ -35,9 +35,7 @@ where
|
|||
| .default k => return .default (← go k)
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
let mut resultType ← alts[0]!.inferType
|
||||
for alt in alts[1:] do
|
||||
resultType := joinTypes resultType (← alt.inferType)
|
||||
let resultType ← mkCasesResultType alts
|
||||
return .cases { c with alts, resultType }
|
||||
| .return fvarId => f fvarId
|
||||
| .jmp fvarId .. =>
|
||||
|
|
|
|||
|
|
@ -294,6 +294,9 @@ end
|
|||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (code : Code) : m Code := do
|
||||
normCodeImp code (← getSubst)
|
||||
|
||||
def replaceExprFVars (e : Expr) (s : FVarSubst) : CompilerM Expr :=
|
||||
(normExpr e : ReaderT FVarSubst CompilerM Expr).run s
|
||||
|
||||
def replaceFVars (code : Code) (s : FVarSubst) : CompilerM Code :=
|
||||
(normCode code : ReaderT FVarSubst CompilerM Code).run s
|
||||
|
||||
|
|
|
|||
|
|
@ -219,4 +219,12 @@ where
|
|||
return type
|
||||
termination_by go i => params.size - i
|
||||
|
||||
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
let mut resultType ← alts[0]!.inferType
|
||||
for alt in alts[1:] do
|
||||
resultType := joinTypes resultType (← alt.inferType)
|
||||
return resultType
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -36,6 +36,116 @@ inductive Element where
|
|||
| unreach (fvarId : FVarId)
|
||||
deriving Inhabited
|
||||
|
||||
/-- State for `BindCasesM` monad -/
|
||||
structure BindCasesM.State where
|
||||
/-- New auxiliary join points. -/
|
||||
jps : Array FunDecl := #[]
|
||||
/-- Mapping from `_alt.<idx>` variables to new join points -/
|
||||
map : FVarIdMap FVarId := {}
|
||||
|
||||
/-- Auxiliary monad for implementing `bindCases` -/
|
||||
abbrev BindCasesM := StateRefT BindCasesM.State CompilerM
|
||||
|
||||
/--
|
||||
This method returns code that at each exit point of `cases`, it jumps to `jpDecl`.
|
||||
It is similar to `Code.bind`, but we add special support for `inlineMatcher`.
|
||||
The `inlineMatcher` function inlines the auxiliary `_match_<idx>` declarations.
|
||||
To make sure there is no code duplication, `inlineMatcher` creates auxiliary declarations `_alt.<idx>`.
|
||||
We can say the `_alt.<idx>` declarations are pre join points. For each auxiliary declaration used at
|
||||
an exit point of `cases`, this method creates an new auxiliary join point that invokes `_alt.<idx>`,
|
||||
and then jumps to `jpDecl`. The goal is to make sure the auxiliary join point is the only occurrence
|
||||
of `_alt.<idx>`, then `simp` will inline it.
|
||||
That is, our goal is to try to promote the pre join points `_alt.<idx>` into a proper join point.
|
||||
-/
|
||||
partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do
|
||||
let (alts, s) ← visitAlts cases.alts |>.run {}
|
||||
let resultType ← mkCasesResultType alts
|
||||
let mut result := .cases { cases with alts, resultType }
|
||||
for decl in s.jps do
|
||||
result := .jp decl result
|
||||
return .jp jpDecl result
|
||||
where
|
||||
visitAlts (alts : Array Alt) : BindCasesM (Array Alt) :=
|
||||
alts.mapM fun alt => return alt.updateCode (← go alt.getCode)
|
||||
|
||||
findFun? (f : FVarId) : CompilerM (Option FunDecl) := do
|
||||
if let some funDecl ← findFunDecl? f then
|
||||
return funDecl
|
||||
else if let .ldecl (value := .fvar f') .. ← getLocalDecl f then
|
||||
findFun? f'
|
||||
else
|
||||
return none
|
||||
|
||||
go (code : Code) : BindCasesM Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if let .return fvarId := k then
|
||||
/-
|
||||
Check whether the current let-declaration is of the form
|
||||
```
|
||||
let _x := _alt.<idx> args
|
||||
return _x
|
||||
```
|
||||
where `_alt.<idx>` is an auxiliary declaration created by `inlineMatcher`
|
||||
-/
|
||||
if decl.fvarId == fvarId && decl.value.isApp && decl.value.getAppFn.isFVar then
|
||||
let f := decl.value.getAppFn.fvarId!
|
||||
let localDecl ← getLocalDecl f
|
||||
if localDecl.userName.getPrefix == `_alt then
|
||||
if let some funDecl ← findFun? f then
|
||||
let args := decl.value.getAppArgs
|
||||
eraseFVar decl.fvarId
|
||||
if let some altJp := (← get).map.find? f then
|
||||
/- We already have an auxiliary join point for `f`, then, we just use it. -/
|
||||
return .jmp altJp args
|
||||
else
|
||||
/-
|
||||
We have not created a join point for `f` yet.
|
||||
The join point has the form
|
||||
```
|
||||
jp altJp jpParams :=
|
||||
let _x := f jpParams
|
||||
jmp jpDecl _x
|
||||
```
|
||||
Then, we replace the current `let`-declaration with `jmp altJp args`
|
||||
-/
|
||||
let mut jpParams := #[]
|
||||
let mut subst := {}
|
||||
let mut jpArgs := #[]
|
||||
/- Remark: `funDecl.params.size` may be greater than `args.size`. -/
|
||||
for param in funDecl.params[:args.size] do
|
||||
let type ← replaceExprFVars param.type subst
|
||||
let paramNew ← mkAuxParam type
|
||||
jpParams := jpParams.push paramNew
|
||||
let arg := .fvar paramNew.fvarId
|
||||
subst := subst.insert param.fvarId arg
|
||||
jpArgs := jpArgs.push arg
|
||||
let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs)
|
||||
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
|
||||
let altJp ← mkAuxJpDecl jpParams jpValue
|
||||
modify fun { jps, map } => {
|
||||
jps := jps.push altJp
|
||||
map := map.insert f altJp.fvarId
|
||||
}
|
||||
return .jmp altJp.fvarId args
|
||||
return .let decl (← go k)
|
||||
| .fun decl k => return .fun decl (← go k)
|
||||
| .jp decl k =>
|
||||
let value ← go decl.value
|
||||
let type ← value.inferParamType decl.params
|
||||
let decl ← decl.update' type value
|
||||
return .jp decl (← go k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName params (← go k)
|
||||
| .default k => return .default (← go k)
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
let resultType ← mkCasesResultType alts
|
||||
return .cases { c with alts, resultType }
|
||||
| .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
| .jmp .. | .unreach .. => return code
|
||||
|
||||
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
|
||||
if let .fvar fvarId := e then
|
||||
go seq seq.size (.return fvarId)
|
||||
|
|
@ -69,8 +179,7 @@ where
|
|||
else
|
||||
/- Create a join point for `c` and jump to it from `cases` -/
|
||||
let jpDecl ← mkAuxJpDecl' fvarId c
|
||||
let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
go seq (i - 1) (.jp jpDecl cases)
|
||||
go seq (i - 1) (← bindCases jpDecl cases)
|
||||
else
|
||||
return c
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue