diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 374c6bffdf..4eefe379a2 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -72,14 +72,14 @@ instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) w codeBind c f sref := c.bind fun fvarId => f fvarId sref /-- -Ensure resulting code has type `◾`. +Ensure resulting code type is equivalent to `type`. -/ -def Code.ensureAnyType (c : Code) : CompilerM Code := do - if (← c.inferType).isErased then +def Code.ensureResultType (c : Code) (type : Expr) : CompilerM Code := do + if eqvTypes (← c.inferType) type then return c else c.bind fun fvarId => do - let cast ← mkLcCast (.fvar fvarId) erasedExpr + let cast ← mkLcCast (.fvar fvarId) type let decl ← LCNF.mkAuxLetDecl cast return .let decl (.return decl.fvarId) diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index e2bc7934c8..fb01a3da14 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -76,10 +76,11 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do | .unreach type => return .unreach (← normExpr type) | .cases c => let resultType ← normExpr c.resultType - let ensureAny := resultType != c.resultType && resultType.isErased + let ensureResultType := !eqvTypes resultType c.resultType /- Note: - If the new result type for the cases is `◾`, we must add a cast to `◾` (aka the any type) + If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correc. + For result, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type) to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`. Here is an example to illustrate this issue. Suppose we have @@ -124,8 +125,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do -/ let internalizeAltCode (k : Code) : InternalizeM Code := do let k ← internalizeCode k - if ensureAny then - k.ensureAnyType + if ensureResultType then + k.ensureResultType resultType else return k let discr ← normFVar c.discr diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index fea8ecd6d3..31d7aff306 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -219,12 +219,29 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl ... ``` We must introduce a cast around `a` to make sure the resulting expression is type correct. + + Note: this issue is not restricted to situations where `param.type` is `◾`. It may also happen if + the `◾` is nested at `param.type`. For example, consider the following variant of the example above. + It is also type correct before inlining `f`, but we must introduce a cast to ensure it is still type + correct after. + ``` + def foo (g : List A → List A) (a : List B) := + fun f (x : List ◾) := + let _x.1 := g x + ... + let _x.2 := f a + ... + ``` -/ - if param.type.isErased && !(← inferType arg).isErased then - let castArg ← mkLcCast arg erasedExpr - let castDecl ← mkAuxLetDecl castArg - castDecls := castDecls.push (CodeDecl.let castDecl) - subst := subst.insert param.fvarId (.fvar castDecl.fvarId) + let argType ← inferType arg + if !argType.isErased && !eqvTypes argType (normExprCore subst param.type (translator := true)) then + if !arg.isFVar || isTypeFormerType argType then + subst := subst.insert param.fvarId erasedExpr + else + let castArg ← mkLcCast arg erasedExpr + let castDecl ← mkAuxLetDecl castArg + castDecls := castDecls.push (CodeDecl.let castDecl) + subst := subst.insert param.fvarId (.fvar castDecl.fvarId) else subst := subst.insert param.fvarId arg let code ← code.internalize subst diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 84fe803763..35e02501d4 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -542,17 +542,17 @@ where unless (← compatibleTypes altType resultType) do resultType := erasedExpr alts := alts.push alt - if resultType.isErased || resultType.isErased then - /- - If the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type) - at every alternative that does not have `◾` type. - The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier - or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾` - to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes - sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. - -/ - alts ← alts.mapM fun alt => - return alt.updateCode (← alt.getCode.ensureAnyType) + /- + We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible. + For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type) + at every alternative that does not have `◾` type. + The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier + or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾` + to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes + sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. + -/ + alts ← alts.mapM fun alt => + return alt.updateCode (← alt.getCode.ensureResultType resultType) let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) diff --git a/tests/lean/run/lcnfInliningIssue.lean b/tests/lean/run/lcnfInliningIssue.lean new file mode 100644 index 0000000000..bcceecefd0 --- /dev/null +++ b/tests/lean/run/lcnfInliningIssue.lean @@ -0,0 +1,34 @@ +namespace MWE + +inductive Id {A : Type u} : A → A → Type u +| refl {a : A} : Id a a + +attribute [eliminator] Id.casesOn + +infix:50 (priority := high) " = " => Id + +def symm {A : Type u} {a b : A} (p : a = b) : b = a := +by { induction p; exact Id.refl } + +def map {A : Type u} {B : Type v} {a b : A} (f : A → B) (p : a = b) : f a = f b := +by { induction p; apply Id.refl } + +def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b := +by { induction p; exact id } + +def boolToUniverse : Bool → Type +| true => Unit +| false => Empty + +def ffNeqTt : false = true → Empty := +λ p => transport boolToUniverse (symm p) () + +def isZero : Nat → Bool +| Nat.zero => true +| Nat.succ _ => false + +set_option pp.funBinderTypes true +set_option pp.letVarTypes true +set_option trace.Compiler.result true +def succNeqZero (n : Nat) : Nat.succ n = 0 → Empty := +λ h => ffNeqTt (map isZero h)