feat: add cast at exit points if necessary when inlining code
This commit is contained in:
parent
ef2d17120c
commit
11fcdb7bf4
4 changed files with 68 additions and 71 deletions
|
|
@ -76,62 +76,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
|
|||
| .unreach type => return .unreach (← normExpr type)
|
||||
| .cases c =>
|
||||
let resultType ← normExpr c.resultType
|
||||
let ensureResultType := !eqvTypes resultType c.resultType
|
||||
/-
|
||||
Note:
|
||||
If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correct.
|
||||
For example, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type)
|
||||
to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`.
|
||||
Here is an example to illustrate this issue.
|
||||
Suppose we have
|
||||
```
|
||||
inductive Id {A : Type u} : A → A → Type u
|
||||
| refl {a : A} : Id a a
|
||||
def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b :=
|
||||
```
|
||||
Its LCNF type is
|
||||
```
|
||||
{A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾
|
||||
```
|
||||
and base phase code is
|
||||
```
|
||||
cases p : B ◾
|
||||
| Id.refl =>
|
||||
a.1
|
||||
```
|
||||
Now suppose we define
|
||||
```
|
||||
def transportconst {A B : Type u} : A = B → A → B :=
|
||||
transport id
|
||||
```
|
||||
By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is
|
||||
```
|
||||
cases p : ◾
|
||||
| Id.refl =>
|
||||
a.1
|
||||
```
|
||||
Which can be checked by `Check.lean` because it assumes `◾` is compatible with anything and `a.1 : A`.
|
||||
However, if we inline `transportconst`, we can hit type error since the continuation for transportconst is
|
||||
expecting a `B` instead of an `A`. We avoid this problem by adding a cast to `◾`. See `ToLCNF.visitCases` for
|
||||
another place where we use this approach.
|
||||
Thus, the resulting code for `transportconst` is
|
||||
```
|
||||
def MWE.transportconst (A : Type u) (B : Type u) (p : Id A B) (a.1 : A) :=
|
||||
cases p
|
||||
| Id.refl =>
|
||||
let _x.2 := @lcCast A ◾ a.1
|
||||
_x.2
|
||||
```
|
||||
|
||||
TODO: consider removing this adjustment code from here, and handling it in the inlining procedure.
|
||||
That is, adding casts at the exit points when inlining.
|
||||
-/
|
||||
let internalizeAltCode (k : Code) : InternalizeM Code := do
|
||||
let k ← internalizeCode k
|
||||
if ensureResultType then
|
||||
k.ensureResultType resultType
|
||||
else
|
||||
return k
|
||||
let internalizeAltCode (k : Code) : InternalizeM Code :=
|
||||
internalizeCode k
|
||||
let discr ← normFVar c.discr
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ and the free variable containing the result (`FVarId`). The resulting `FVarId` o
|
|||
subset of `Array CodeDecl`. However, this method does try to filter the relevant ones.
|
||||
We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`.
|
||||
-/
|
||||
partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
|
||||
partial def inlineProjInst? (e : Expr) (expectedType : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
|
||||
let .proj _ i s := e | return none
|
||||
let sType ← inferType s
|
||||
unless (← isClass? sType).isSome do return none
|
||||
|
|
@ -42,7 +42,13 @@ partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId
|
|||
unless (← isClass? eType).isNone do return none
|
||||
let (fvarId?, decls) ← visit s [i] |>.run |>.run #[]
|
||||
if let some fvarId := fvarId? then
|
||||
return some (decls, fvarId)
|
||||
let type ← getType fvarId
|
||||
if type.isErased || eqvTypes expectedType type then
|
||||
return some (decls, fvarId)
|
||||
else
|
||||
let cast ← mkLcCast (.fvar fvarId) expectedType
|
||||
let decl ← LCNF.mkAuxLetDecl cast
|
||||
return some (decls.push (.let decl), decl.fvarId)
|
||||
else
|
||||
eraseCodeDecls decls
|
||||
return none
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
|||
paramsNew := paramsNew.push paramNew
|
||||
subst := subst.insert param.fvarId (.fvar paramNew.fvarId)
|
||||
let code ← info.value.internalize subst
|
||||
-- TODO: check resulting type
|
||||
updateFunDeclInfo code
|
||||
mkAuxFunDecl paramsNew code
|
||||
|
||||
|
|
@ -96,6 +97,44 @@ def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
|
|||
| .return fvarId' => return (← normFVar fvarId') == fvarId
|
||||
| _ => return false
|
||||
|
||||
/-
|
||||
Note: function inlining and result type.
|
||||
The function betaReduce has support for adding cast operations to arguments when inlining a definition.
|
||||
We may also need cast operations at the exit points of a function.
|
||||
Here is a concrete example.
|
||||
|
||||
Suppose we have
|
||||
```
|
||||
inductive Id {A : Type u} : A → A → Type u
|
||||
| refl {a : A} : Id a a
|
||||
def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b :=
|
||||
```
|
||||
Its LCNF type is
|
||||
```
|
||||
{A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾
|
||||
```
|
||||
and base phase code is
|
||||
```
|
||||
cases p : B ◾
|
||||
| Id.refl =>
|
||||
a.1
|
||||
```
|
||||
Now suppose we define
|
||||
```
|
||||
def transportconst {A B : Type u} : A = B → A → B :=
|
||||
transport id
|
||||
```
|
||||
By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is
|
||||
```
|
||||
cases p : ◾
|
||||
| Id.refl =>
|
||||
a.1
|
||||
```
|
||||
Now, suppose we inline `transportconst` in a place where the continuation is expecting a value of
|
||||
type `B`, but we are providing a value of type `A`. We must insert a cast to ensure the result is type
|
||||
correct.
|
||||
-/
|
||||
|
||||
mutual
|
||||
/--
|
||||
If the value of the given let-declaration is an application that can be inlined,
|
||||
|
|
@ -124,7 +163,10 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
markSimplified
|
||||
simp (.fun funDecl k)
|
||||
else
|
||||
let expectedType ← inferType (mkAppN info.f info.args[:info.arity])
|
||||
let code ← betaReduce info.params info.value info.args[:info.arity]
|
||||
/- See note above: function inlining and result type. -/
|
||||
let code ← code.ensureResultType expectedType
|
||||
if k.isReturnOf fvarId && numArgs == info.arity then
|
||||
/- Easy case, the continuation `k` is just returning the result of the application. -/
|
||||
markSimplified
|
||||
|
|
@ -149,7 +191,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
-- return none
|
||||
else
|
||||
markSimplified
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
|
||||
let jpParam ← mkAuxParam expectedType
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
|
|
@ -198,6 +240,20 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
|||
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
|
||||
addFVarSubst param.fvarId auxDecl.fvarId
|
||||
let k ← simp k
|
||||
/-
|
||||
We must ensure the result type is equivalent to `cases.resultType` here, otherwise the result may be type incorrect.
|
||||
For example, the following LCNF code is correct before applying this transformation, but requires a cast after.
|
||||
|
||||
```
|
||||
def f (a : A) (b : B) : B :=
|
||||
let _x.1 := true
|
||||
cases _x.1 : ⊤
|
||||
| true => return a
|
||||
| false => return b
|
||||
```
|
||||
This situation is similar to the one we have when inlining functions.
|
||||
-/
|
||||
let k ← k.ensureResultType cases.resultType
|
||||
eraseParams params
|
||||
attachCodeDecls auxDecls k
|
||||
|
||||
|
|
@ -225,7 +281,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
else if let some code ← inlineApp? decl k then
|
||||
eraseLetDecl decl
|
||||
return code
|
||||
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
|
||||
else if let some (decls, fvarId) ← inlineProjInst? decl.value decl.type then
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
eraseLetDecl decl
|
||||
let k ← simp k
|
||||
|
|
|
|||
|
|
@ -542,17 +542,6 @@ where
|
|||
unless (← compatibleTypes altType resultType) do
|
||||
resultType := erasedExpr
|
||||
alts := alts.push alt
|
||||
/-
|
||||
We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible.
|
||||
For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
|
||||
at every alternative that does not have `◾` type.
|
||||
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
|
||||
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
|
||||
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
|
||||
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
|
||||
-/
|
||||
alts ← alts.mapM fun alt =>
|
||||
return alt.updateCode (← alt.getCode.ensureResultType resultType)
|
||||
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
|
||||
let auxDecl ← mkAuxParam resultType
|
||||
pushElement (.cases auxDecl cases)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue