feat: add cast at exit points if necessary when inlining code

This commit is contained in:
Leonardo de Moura 2022-10-09 13:00:30 -07:00
parent ef2d17120c
commit 11fcdb7bf4
4 changed files with 68 additions and 71 deletions

View file

@ -76,62 +76,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
| .unreach type => return .unreach (← normExpr type)
| .cases c =>
let resultType ← normExpr c.resultType
let ensureResultType := !eqvTypes resultType c.resultType
/-
Note:
If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correct.
For example, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type)
to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`.
Here is an example to illustrate this issue.
Suppose we have
```
inductive Id {A : Type u} : A → A → Type u
| refl {a : A} : Id a a
def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b :=
```
Its LCNF type is
```
{A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾
```
and base phase code is
```
cases p : B ◾
| Id.refl =>
a.1
```
Now suppose we define
```
def transportconst {A B : Type u} : A = B → A → B :=
transport id
```
By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is
```
cases p : ◾
| Id.refl =>
a.1
```
Which can be checked by `Check.lean` because it assumes `◾` is compatible with anything and `a.1 : A`.
However, if we inline `transportconst`, we can hit type error since the continuation for transportconst is
expecting a `B` instead of an `A`. We avoid this problem by adding a cast to `◾`. See `ToLCNF.visitCases` for
another place where we use this approach.
Thus, the resulting code for `transportconst` is
```
def MWE.transportconst (A : Type u) (B : Type u) (p : Id A B) (a.1 : A) :=
cases p
| Id.refl =>
let _x.2 := @lcCast A ◾ a.1
_x.2
```
TODO: consider removing this adjustment code from here, and handling it in the inlining procedure.
That is, adding casts at the exit points when inlining.
-/
let internalizeAltCode (k : Code) : InternalizeM Code := do
let k ← internalizeCode k
if ensureResultType then
k.ensureResultType resultType
else
return k
let internalizeAltCode (k : Code) : InternalizeM Code :=
internalizeCode k
let discr ← normFVar c.discr
let alts ← c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)

View file

@ -34,7 +34,7 @@ and the free variable containing the result (`FVarId`). The resulting `FVarId` o
subset of `Array CodeDecl`. However, this method does try to filter the relevant ones.
We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`.
-/
partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
partial def inlineProjInst? (e : Expr) (expectedType : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
let .proj _ i s := e | return none
let sType ← inferType s
unless (← isClass? sType).isSome do return none
@ -42,7 +42,13 @@ partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId
unless (← isClass? eType).isNone do return none
let (fvarId?, decls) ← visit s [i] |>.run |>.run #[]
if let some fvarId := fvarId? then
return some (decls, fvarId)
let type ← getType fvarId
if type.isErased || eqvTypes expectedType type then
return some (decls, fvarId)
else
let cast ← mkLcCast (.fvar fvarId) expectedType
let decl ← LCNF.mkAuxLetDecl cast
return some (decls.push (.let decl), decl.fvarId)
else
eraseCodeDecls decls
return none

View file

@ -53,6 +53,7 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
paramsNew := paramsNew.push paramNew
subst := subst.insert param.fvarId (.fvar paramNew.fvarId)
let code ← info.value.internalize subst
-- TODO: check resulting type
updateFunDeclInfo code
mkAuxFunDecl paramsNew code
@ -96,6 +97,44 @@ def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
| .return fvarId' => return (← normFVar fvarId') == fvarId
| _ => return false
/-
Note: function inlining and result type.
The function betaReduce has support for adding cast operations to arguments when inlining a definition.
We may also need cast operations at the exit points of a function.
Here is a concrete example.
Suppose we have
```
inductive Id {A : Type u} : A → A → Type u
| refl {a : A} : Id a a
def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b :=
```
Its LCNF type is
```
{A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾
```
and base phase code is
```
cases p : B ◾
| Id.refl =>
a.1
```
Now suppose we define
```
def transportconst {A B : Type u} : A = B → A → B :=
transport id
```
By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is
```
cases p : ◾
| Id.refl =>
a.1
```
Now, suppose we inline `transportconst` in a place where the continuation is expecting a value of
type `B`, but we are providing a value of type `A`. We must insert a cast to ensure the result is type
correct.
-/
mutual
/--
If the value of the given let-declaration is an application that can be inlined,
@ -124,7 +163,10 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
markSimplified
simp (.fun funDecl k)
else
let expectedType ← inferType (mkAppN info.f info.args[:info.arity])
let code ← betaReduce info.params info.value info.args[:info.arity]
/- See note above: function inlining and result type. -/
let code ← code.ensureResultType expectedType
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
markSimplified
@ -149,7 +191,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
-- return none
else
markSimplified
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
let jpParam ← mkAuxParam expectedType
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
@ -198,6 +240,20 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
addFVarSubst param.fvarId auxDecl.fvarId
let k ← simp k
/-
We must ensure the result type is equivalent to `cases.resultType` here, otherwise the result may be type incorrect.
For example, the following LCNF code is correct before applying this transformation, but requires a cast after.
```
def f (a : A) (b : B) : B :=
let _x.1 := true
cases _x.1 :
| true => return a
| false => return b
```
This situation is similar to the one we have when inlining functions.
-/
let k ← k.ensureResultType cases.resultType
eraseParams params
attachCodeDecls auxDecls k
@ -225,7 +281,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
else if let some code ← inlineApp? decl k then
eraseLetDecl decl
return code
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
else if let some (decls, fvarId) ← inlineProjInst? decl.value decl.type then
addFVarSubst decl.fvarId fvarId
eraseLetDecl decl
let k ← simp k

View file

@ -542,17 +542,6 @@ where
unless (← compatibleTypes altType resultType) do
resultType := erasedExpr
alts := alts.push alt
/-
We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible.
For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
at every alternative that does not have `◾` type.
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
-/
alts ← alts.mapM fun alt =>
return alt.updateCode (← alt.getCode.ensureResultType resultType)
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl cases)