From 31d59e337bc2baf021f8025b9d18b37648359585 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 2 Oct 2022 08:09:13 -0700 Subject: [PATCH] fix: LCNF any type issue This fixes an issue reported at https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/301935406 --- src/Lean/Compiler/LCNF/Bind.lean | 12 ++++++++++++ src/Lean/Compiler/LCNF/ToLCNF.lean | 11 +++++++++++ tests/lean/run/casesAnyTypeIssue.lean | 24 ++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 tests/lean/run/casesAnyTypeIssue.lean diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 06a4b7ecf5..734f5b6fc5 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -71,6 +71,18 @@ instance [MonadCodeBind m] : MonadCodeBind (ReaderT ρ m) where instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) where codeBind c f sref := c.bind fun fvarId => f fvarId sref +/-- +Ensure resulting code has type `⊤`. +-/ +def Code.ensureAnyType (c : Code) : CompilerM Code := do + if (← c.inferType).isAnyType then + return c + else + c.bind fun fvarId => do + let cast ← mkLcCast (.fvar fvarId) anyTypeExpr + let decl ← LCNF.mkAuxLetDecl cast + return .let decl (.return decl.fvarId) + /-- Create new parameters for the given arrow type. Example: if `type` is `Nat → Bool → Int`, the result is diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 5c321fdf6c..d73029812c 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -534,6 +534,17 @@ where unless (← compatibleTypes altType resultType) do resultType := anyTypeExpr alts := alts.push alt + if resultType.isAnyType then + /- + If the result type for a `cases` is `⊤`, we put a cast to `⊤` + at every alternative that does not have `⊤` type. + The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier + or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `⊤` + to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes + sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. + -/ + alts ← alts.mapM fun alt => + return alt.updateCode (← alt.getCode.ensureAnyType) let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) diff --git a/tests/lean/run/casesAnyTypeIssue.lean b/tests/lean/run/casesAnyTypeIssue.lean new file mode 100644 index 0000000000..c0cf4d5247 --- /dev/null +++ b/tests/lean/run/casesAnyTypeIssue.lean @@ -0,0 +1,24 @@ +namespace MWE + +inductive Id {A : Type u} : A → A → Type u +| refl {a : A} : Id a a + +attribute [eliminator] Id.casesOn + +infix:50 (priority := high) " = " => Id + +def symm {A : Type u} {a b : A} (p : a = b) : b = a := +by { induction p; exact Id.refl } + +def transportconst {A B : Type u} : A = B → A → B := +by { intros p x; induction p; exact x } + +def transportconstInv {A B : Type u} (e : A = B) : B → A := +transportconst (symm e) + +def transportconstOverInv {A B : Type u} (p : A = B) : + ∀ x, transportconst (symm p) x = transportconstInv p x := +by { intro x; apply Id.refl } + +def transportconstInv' {A B : Type u} : A = B → B → A := +transportconst ∘ symm