fix: type error introducing when inlining LCNF functions

This issue has been reported at
https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/303142516
This commit is contained in:
Leonardo de Moura 2022-10-09 12:09:01 -07:00
parent f61ec4929f
commit cc09afc5e1
5 changed files with 76 additions and 24 deletions

View file

@ -72,14 +72,14 @@ instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) w
codeBind c f sref := c.bind fun fvarId => f fvarId sref
/--
Ensure resulting code has type `◾`.
Ensure resulting code type is equivalent to `type`.
-/
def Code.ensureAnyType (c : Code) : CompilerM Code := do
if (← c.inferType).isErased then
def Code.ensureResultType (c : Code) (type : Expr) : CompilerM Code := do
if eqvTypes (← c.inferType) type then
return c
else
c.bind fun fvarId => do
let cast ← mkLcCast (.fvar fvarId) erasedExpr
let cast ← mkLcCast (.fvar fvarId) type
let decl ← LCNF.mkAuxLetDecl cast
return .let decl (.return decl.fvarId)

View file

@ -76,10 +76,11 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
| .unreach type => return .unreach (← normExpr type)
| .cases c =>
let resultType ← normExpr c.resultType
let ensureAny := resultType != c.resultType && resultType.isErased
let ensureResultType := !eqvTypes resultType c.resultType
/-
Note:
If the new result type for the cases is `◾`, we must add a cast to `◾` (aka the any type)
If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correc.
For result, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type)
to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`.
Here is an example to illustrate this issue.
Suppose we have
@ -124,8 +125,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
-/
let internalizeAltCode (k : Code) : InternalizeM Code := do
let k ← internalizeCode k
if ensureAny then
k.ensureAnyType
if ensureResultType then
k.ensureResultType resultType
else
return k
let discr ← normFVar c.discr

View file

@ -219,12 +219,29 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl
...
```
We must introduce a cast around `a` to make sure the resulting expression is type correct.
Note: this issue is not restricted to situations where `param.type` is `◾`. It may also happen if
the `◾` is nested at `param.type`. For example, consider the following variant of the example above.
It is also type correct before inlining `f`, but we must introduce a cast to ensure it is still type
correct after.
```
def foo (g : List A → List A) (a : List B) :=
fun f (x : List ◾) :=
let _x.1 := g x
...
let _x.2 := f a
...
```
-/
if param.type.isErased && !(← inferType arg).isErased then
let castArg ← mkLcCast arg erasedExpr
let castDecl ← mkAuxLetDecl castArg
castDecls := castDecls.push (CodeDecl.let castDecl)
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
let argType ← inferType arg
if !argType.isErased && !eqvTypes argType (normExprCore subst param.type (translator := true)) then
if !arg.isFVar || isTypeFormerType argType then
subst := subst.insert param.fvarId erasedExpr
else
let castArg ← mkLcCast arg erasedExpr
let castDecl ← mkAuxLetDecl castArg
castDecls := castDecls.push (CodeDecl.let castDecl)
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
else
subst := subst.insert param.fvarId arg
let code ← code.internalize subst

View file

@ -542,17 +542,17 @@ where
unless (← compatibleTypes altType resultType) do
resultType := erasedExpr
alts := alts.push alt
if resultType.isErased || resultType.isErased then
/-
If the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
at every alternative that does not have `◾` type.
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
-/
alts ← alts.mapM fun alt =>
return alt.updateCode (← alt.getCode.ensureAnyType)
/-
We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible.
For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
at every alternative that does not have `◾` type.
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
-/
alts ← alts.mapM fun alt =>
return alt.updateCode (← alt.getCode.ensureResultType resultType)
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl cases)

View file

@ -0,0 +1,34 @@
namespace MWE
inductive Id {A : Type u} : A → A → Type u
| refl {a : A} : Id a a
attribute [eliminator] Id.casesOn
infix:50 (priority := high) " = " => Id
def symm {A : Type u} {a b : A} (p : a = b) : b = a :=
by { induction p; exact Id.refl }
def map {A : Type u} {B : Type v} {a b : A} (f : A → B) (p : a = b) : f a = f b :=
by { induction p; apply Id.refl }
def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b :=
by { induction p; exact id }
def boolToUniverse : Bool → Type
| true => Unit
| false => Empty
def ffNeqTt : false = true → Empty :=
λ p => transport boolToUniverse (symm p) ()
def isZero : Nat → Bool
| Nat.zero => true
| Nat.succ _ => false
set_option pp.funBinderTypes true
set_option pp.letVarTypes true
set_option trace.Compiler.result true
def succNeqZero (n : Nat) : Nat.succ n = 0 → Empty :=
λ h => ffNeqTt (map isZero h)