From 11fcdb7bf458b7af1dd8f36aef2594894f977d25 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 9 Oct 2022 13:00:30 -0700 Subject: [PATCH] feat: add cast at exit points if necessary when inlining code --- src/Lean/Compiler/LCNF/Internalize.lean | 58 +------------------- src/Lean/Compiler/LCNF/Simp/InlineProj.lean | 10 +++- src/Lean/Compiler/LCNF/Simp/Main.lean | 60 ++++++++++++++++++++- src/Lean/Compiler/LCNF/ToLCNF.lean | 11 ---- 4 files changed, 68 insertions(+), 71 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index 3ead3c832c..f826ea6ed8 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -76,62 +76,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do | .unreach type => return .unreach (← normExpr type) | .cases c => let resultType ← normExpr c.resultType - let ensureResultType := !eqvTypes resultType c.resultType - /- - Note: - If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correct. - For example, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type) - to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`. - Here is an example to illustrate this issue. - Suppose we have - ``` - inductive Id {A : Type u} : A → A → Type u - | refl {a : A} : Id a a - def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b := - ``` - Its LCNF type is - ``` - {A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾ - ``` - and base phase code is - ``` - cases p : B ◾ - | Id.refl => - a.1 - ``` - Now suppose we define - ``` - def transportconst {A B : Type u} : A = B → A → B := - transport id - ``` - By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is - ``` - cases p : ◾ - | Id.refl => - a.1 - ``` - Which can be checked by `Check.lean` because it assumes `◾` is compatible with anything and `a.1 : A`. - However, if we inline `transportconst`, we can hit type error since the continuation for transportconst is - expecting a `B` instead of an `A`. We avoid this problem by adding a cast to `◾`. See `ToLCNF.visitCases` for - another place where we use this approach. - Thus, the resulting code for `transportconst` is - ``` - def MWE.transportconst (A : Type u) (B : Type u) (p : Id A B) (a.1 : A) := - cases p - | Id.refl => - let _x.2 := @lcCast A ◾ a.1 - _x.2 - ``` - - TODO: consider removing this adjustment code from here, and handling it in the inlining procedure. - That is, adding casts at the exit points when inlining. - -/ - let internalizeAltCode (k : Code) : InternalizeM Code := do - let k ← internalizeCode k - if ensureResultType then - k.ensureResultType resultType - else - return k + let internalizeAltCode (k : Code) : InternalizeM Code := + internalizeCode k let discr ← normFVar c.discr let alts ← c.alts.mapM fun | .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k) diff --git a/src/Lean/Compiler/LCNF/Simp/InlineProj.lean b/src/Lean/Compiler/LCNF/Simp/InlineProj.lean index 248a2cb40c..5426b31423 100644 --- a/src/Lean/Compiler/LCNF/Simp/InlineProj.lean +++ b/src/Lean/Compiler/LCNF/Simp/InlineProj.lean @@ -34,7 +34,7 @@ and the free variable containing the result (`FVarId`). The resulting `FVarId` o subset of `Array CodeDecl`. However, this method does try to filter the relevant ones. We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`. -/ -partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do +partial def inlineProjInst? (e : Expr) (expectedType : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do let .proj _ i s := e | return none let sType ← inferType s unless (← isClass? sType).isSome do return none @@ -42,7 +42,13 @@ partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId unless (← isClass? eType).isNone do return none let (fvarId?, decls) ← visit s [i] |>.run |>.run #[] if let some fvarId := fvarId? then - return some (decls, fvarId) + let type ← getType fvarId + if type.isErased || eqvTypes expectedType type then + return some (decls, fvarId) + else + let cast ← mkLcCast (.fvar fvarId) expectedType + let decl ← LCNF.mkAuxLetDecl cast + return some (decls.push (.let decl), decl.fvarId) else eraseCodeDecls decls return none diff --git a/src/Lean/Compiler/LCNF/Simp/Main.lean b/src/Lean/Compiler/LCNF/Simp/Main.lean index 7db68a3e3c..9c2c27a57c 100644 --- a/src/Lean/Compiler/LCNF/Simp/Main.lean +++ b/src/Lean/Compiler/LCNF/Simp/Main.lean @@ -53,6 +53,7 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do paramsNew := paramsNew.push paramNew subst := subst.insert param.fvarId (.fvar paramNew.fvarId) let code ← info.value.internalize subst + -- TODO: check resulting type updateFunDeclInfo code mkAuxFunDecl paramsNew code @@ -96,6 +97,44 @@ def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do | .return fvarId' => return (← normFVar fvarId') == fvarId | _ => return false +/- +Note: function inlining and result type. +The function betaReduce has support for adding cast operations to arguments when inlining a definition. +We may also need cast operations at the exit points of a function. +Here is a concrete example. + +Suppose we have +``` +inductive Id {A : Type u} : A → A → Type u + | refl {a : A} : Id a a +def transport {A : Type u} (B : A → Type v) {a b : A} (p : Id a b) : B a → B b := +``` +Its LCNF type is +``` +{A : Type u} (B : A → Type v) {a b : A} (p : Id ◾ ◾) (a.1 : B ◾) : B ◾ +``` +and base phase code is +``` +cases p : B ◾ +| Id.refl => + a.1 +``` +Now suppose we define +``` +def transportconst {A B : Type u} : A = B → A → B := + transport id +``` +By setting `B` as `id`, and then inlining `transport, we would have the following code for `transportconst` is +``` +cases p : ◾ +| Id.refl => + a.1 +``` +Now, suppose we inline `transportconst` in a place where the continuation is expecting a value of +type `B`, but we are providing a value of type `A`. We must insert a cast to ensure the result is type +correct. +-/ + mutual /-- If the value of the given let-declaration is an application that can be inlined, @@ -124,7 +163,10 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d markSimplified simp (.fun funDecl k) else + let expectedType ← inferType (mkAppN info.f info.args[:info.arity]) let code ← betaReduce info.params info.value info.args[:info.arity] + /- See note above: function inlining and result type. -/ + let code ← code.ensureResultType expectedType if k.isReturnOf fvarId && numArgs == info.arity then /- Easy case, the continuation `k` is just returning the result of the application. -/ markSimplified @@ -149,7 +191,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d -- return none else markSimplified - let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity])) + let jpParam ← mkAuxParam expectedType let jpValue ← if numArgs > info.arity then let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:]) addFVarSubst fvarId decl.fvarId @@ -198,6 +240,20 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do auxDecls := auxDecls.push (CodeDecl.let auxDecl) addFVarSubst param.fvarId auxDecl.fvarId let k ← simp k + /- + We must ensure the result type is equivalent to `cases.resultType` here, otherwise the result may be type incorrect. + For example, the following LCNF code is correct before applying this transformation, but requires a cast after. + + ``` + def f (a : A) (b : B) : B := + let _x.1 := true + cases _x.1 : ⊤ + | true => return a + | false => return b + ``` + This situation is similar to the one we have when inlining functions. + -/ + let k ← k.ensureResultType cases.resultType eraseParams params attachCodeDecls auxDecls k @@ -225,7 +281,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do else if let some code ← inlineApp? decl k then eraseLetDecl decl return code - else if let some (decls, fvarId) ← inlineProjInst? decl.value then + else if let some (decls, fvarId) ← inlineProjInst? decl.value decl.type then addFVarSubst decl.fvarId fvarId eraseLetDecl decl let k ← simp k diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 35e02501d4..035deeae17 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -542,17 +542,6 @@ where unless (← compatibleTypes altType resultType) do resultType := erasedExpr alts := alts.push alt - /- - We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible. - For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type) - at every alternative that does not have `◾` type. - The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier - or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾` - to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes - sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. - -/ - alts ← alts.mapM fun alt => - return alt.updateCode (← alt.getCode.ensureResultType resultType) let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases)