fix: type error introducing when inlining LCNF functions
This issue has been reported at https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/303142516
This commit is contained in:
parent
f61ec4929f
commit
cc09afc5e1
5 changed files with 76 additions and 24 deletions
|
|
@ -72,14 +72,14 @@ instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) w
|
|||
codeBind c f sref := c.bind fun fvarId => f fvarId sref
|
||||
|
||||
/--
|
||||
Ensure resulting code has type `◾`.
|
||||
Ensure resulting code type is equivalent to `type`.
|
||||
-/
|
||||
def Code.ensureAnyType (c : Code) : CompilerM Code := do
|
||||
if (← c.inferType).isErased then
|
||||
def Code.ensureResultType (c : Code) (type : Expr) : CompilerM Code := do
|
||||
if eqvTypes (← c.inferType) type then
|
||||
return c
|
||||
else
|
||||
c.bind fun fvarId => do
|
||||
let cast ← mkLcCast (.fvar fvarId) erasedExpr
|
||||
let cast ← mkLcCast (.fvar fvarId) type
|
||||
let decl ← LCNF.mkAuxLetDecl cast
|
||||
return .let decl (.return decl.fvarId)
|
||||
|
||||
|
|
|
|||
|
|
@ -76,10 +76,11 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
|
|||
| .unreach type => return .unreach (← normExpr type)
|
||||
| .cases c =>
|
||||
let resultType ← normExpr c.resultType
|
||||
let ensureAny := resultType != c.resultType && resultType.isErased
|
||||
let ensureResultType := !eqvTypes resultType c.resultType
|
||||
/-
|
||||
Note:
|
||||
If the new result type for the cases is `◾`, we must add a cast to `◾` (aka the any type)
|
||||
If the new result type for the cases is not equivalent, we have to use `ensureResultType` to make sure the result is still type correc.
|
||||
For result, suppose `resultType` is `◾` but the old one was not. Then, we must add a cast to `◾` (aka the any type)
|
||||
to every alternative if their resulting type is not `◾`. This is similar to what we do at `ToLCNF.visitCases`.
|
||||
Here is an example to illustrate this issue.
|
||||
Suppose we have
|
||||
|
|
@ -124,8 +125,8 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
|
|||
-/
|
||||
let internalizeAltCode (k : Code) : InternalizeM Code := do
|
||||
let k ← internalizeCode k
|
||||
if ensureAny then
|
||||
k.ensureAnyType
|
||||
if ensureResultType then
|
||||
k.ensureResultType resultType
|
||||
else
|
||||
return k
|
||||
let discr ← normFVar c.discr
|
||||
|
|
|
|||
|
|
@ -219,12 +219,29 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl
|
|||
...
|
||||
```
|
||||
We must introduce a cast around `a` to make sure the resulting expression is type correct.
|
||||
|
||||
Note: this issue is not restricted to situations where `param.type` is `◾`. It may also happen if
|
||||
the `◾` is nested at `param.type`. For example, consider the following variant of the example above.
|
||||
It is also type correct before inlining `f`, but we must introduce a cast to ensure it is still type
|
||||
correct after.
|
||||
```
|
||||
def foo (g : List A → List A) (a : List B) :=
|
||||
fun f (x : List ◾) :=
|
||||
let _x.1 := g x
|
||||
...
|
||||
let _x.2 := f a
|
||||
...
|
||||
```
|
||||
-/
|
||||
if param.type.isErased && !(← inferType arg).isErased then
|
||||
let castArg ← mkLcCast arg erasedExpr
|
||||
let castDecl ← mkAuxLetDecl castArg
|
||||
castDecls := castDecls.push (CodeDecl.let castDecl)
|
||||
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
|
||||
let argType ← inferType arg
|
||||
if !argType.isErased && !eqvTypes argType (normExprCore subst param.type (translator := true)) then
|
||||
if !arg.isFVar || isTypeFormerType argType then
|
||||
subst := subst.insert param.fvarId erasedExpr
|
||||
else
|
||||
let castArg ← mkLcCast arg erasedExpr
|
||||
let castDecl ← mkAuxLetDecl castArg
|
||||
castDecls := castDecls.push (CodeDecl.let castDecl)
|
||||
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
|
||||
else
|
||||
subst := subst.insert param.fvarId arg
|
||||
let code ← code.internalize subst
|
||||
|
|
|
|||
|
|
@ -542,17 +542,17 @@ where
|
|||
unless (← compatibleTypes altType resultType) do
|
||||
resultType := erasedExpr
|
||||
alts := alts.push alt
|
||||
if resultType.isErased || resultType.isErased then
|
||||
/-
|
||||
If the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
|
||||
at every alternative that does not have `◾` type.
|
||||
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
|
||||
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
|
||||
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
|
||||
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
|
||||
-/
|
||||
alts ← alts.mapM fun alt =>
|
||||
return alt.updateCode (← alt.getCode.ensureAnyType)
|
||||
/-
|
||||
We must ensure the result type of each alternative is equivalent to `resultType`, and not just compatible.
|
||||
For example, if the result type for a `cases` is `◾`, we put a cast to `◾` (aka the any type)
|
||||
at every alternative that does not have `◾` type.
|
||||
The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier
|
||||
or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `◾`
|
||||
to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes
|
||||
sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`.
|
||||
-/
|
||||
alts ← alts.mapM fun alt =>
|
||||
return alt.updateCode (← alt.getCode.ensureResultType resultType)
|
||||
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
|
||||
let auxDecl ← mkAuxParam resultType
|
||||
pushElement (.cases auxDecl cases)
|
||||
|
|
|
|||
34
tests/lean/run/lcnfInliningIssue.lean
Normal file
34
tests/lean/run/lcnfInliningIssue.lean
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
namespace MWE
|
||||
|
||||
inductive Id {A : Type u} : A → A → Type u
|
||||
| refl {a : A} : Id a a
|
||||
|
||||
attribute [eliminator] Id.casesOn
|
||||
|
||||
infix:50 (priority := high) " = " => Id
|
||||
|
||||
def symm {A : Type u} {a b : A} (p : a = b) : b = a :=
|
||||
by { induction p; exact Id.refl }
|
||||
|
||||
def map {A : Type u} {B : Type v} {a b : A} (f : A → B) (p : a = b) : f a = f b :=
|
||||
by { induction p; apply Id.refl }
|
||||
|
||||
def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b :=
|
||||
by { induction p; exact id }
|
||||
|
||||
def boolToUniverse : Bool → Type
|
||||
| true => Unit
|
||||
| false => Empty
|
||||
|
||||
def ffNeqTt : false = true → Empty :=
|
||||
λ p => transport boolToUniverse (symm p) ()
|
||||
|
||||
def isZero : Nat → Bool
|
||||
| Nat.zero => true
|
||||
| Nat.succ _ => false
|
||||
|
||||
set_option pp.funBinderTypes true
|
||||
set_option pp.letVarTypes true
|
||||
set_option trace.Compiler.result true
|
||||
def succNeqZero (n : Nat) : Nat.succ n = 0 → Empty :=
|
||||
λ h => ffNeqTt (map isZero h)
|
||||
Loading…
Add table
Reference in a new issue