diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 096c207e31..e329139247 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -529,30 +529,37 @@ where let typeName := casesInfo.declName.getPrefix let discr ← visitAppArg args[casesInfo.discrPos]! let .inductInfo indVal ← getConstInfo typeName | unreachable! - for i in casesInfo.altsRange, numParams in casesInfo.altNumParams, ctorName in indVal.ctors do - let (altType, alt) ← visitAlt ctorName numParams args[i]! - unless (← compatibleTypes altType resultType) do - resultType := anyTypeExpr - alts := alts.push alt - if resultType.isAnyType || resultType.isErased then + if !discr.isFVar then /- - If the result type for a `cases` is `⊤` or `◾`, we put a cast to `⊤` - at every alternative that does not have `⊤` type. - The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier - or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `⊤` - to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes - sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. + This can happen for inductive predicates that can eliminate into type (e.g., `And`, `Iff`). + TODO: add support for them. Right now, we have hard-coded support for the ones defined at `Init`. -/ - alts ← alts.mapM fun alt => - return alt.updateCode (← alt.getCode.ensureAnyType) - let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } - let auxDecl ← mkAuxParam resultType - pushElement (.cases auxDecl cases) - let result := .fvar auxDecl.fvarId - if args.size == casesInfo.arity then - return result + throwError "unsupported `{casesInfo.declName}` application during code generation" else - mkOverApplication result args casesInfo.arity + for i in casesInfo.altsRange, numParams in casesInfo.altNumParams, ctorName in indVal.ctors do + let (altType, alt) ← visitAlt ctorName numParams args[i]! + unless (← compatibleTypes altType resultType) do + resultType := anyTypeExpr + alts := alts.push alt + if resultType.isAnyType || resultType.isErased then + /- + If the result type for a `cases` is `⊤` or `◾`, we put a cast to `⊤` + at every alternative that does not have `⊤` type. + The cast is useful to ensure the result is type correct when reducing `cases` in the simplifier + or applying `bind`. For example, suppose we are using `Code.bind` to connect a `cases` with type `⊤` + to a continuation that expects type `B`, and one of the alternatives has type `A`. The operation makes + sense, but we need a cast since we are connecting a value of type `A` to a continuation that expects `B`. + -/ + alts ← alts.mapM fun alt => + return alt.updateCode (← alt.getCode.ensureAnyType) + let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } + let auxDecl ← mkAuxParam resultType + pushElement (.cases auxDecl cases) + let result := .fvar auxDecl.fvarId + if args.size == casesInfo.arity then + return result + else + mkOverApplication result args casesInfo.arity visitCtor (arity : Nat) (e : Expr) : M Expr := etaIfUnderApplied e arity do @@ -594,13 +601,13 @@ where let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type - visitAndRec (e : Expr) : M Expr := + visitAndIffRecCore (e : Expr) (minorPos : Nat) : M Expr := let arity := 5 etaIfUnderApplied e arity do let args := e.getAppArgs let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr. let hb := mkLcProof args[1]! - let minor := if e.isAppOf ``And.rec then args[3]! else args[4]! + let minor := args[minorPos]! let minor := minor.beta #[ha, hb] visit (mkAppN minor args[arity:]) @@ -661,8 +668,10 @@ where visitCtor 3 e else if declName == ``Eq.casesOn || declName == ``Eq.rec || declName == ``Eq.ndrec then visitEqRec e - else if declName == ``And.rec || declName == ``And.casesOn then - visitAndRec e + else if declName == ``And.rec || declName == ``Iff.rec then + visitAndIffRecCore e (minorPos := 3) + else if declName == ``And.casesOn || declName == ``Iff.casesOn then + visitAndIffRecCore e (minorPos := 4) else if declName == ``False.rec || declName == ``Empty.rec || declName == ``False.casesOn || declName == ``Empty.casesOn then visitFalseRec e else if let some casesInfo ← getCasesInfo? declName then diff --git a/tests/lean/run/1684.lean b/tests/lean/run/1684.lean new file mode 100644 index 0000000000..7cf51fb171 --- /dev/null +++ b/tests/lean/run/1684.lean @@ -0,0 +1,7 @@ +set_option trace.Compiler.result true + +def Iff.elim1.{u} {a b : Prop} {motive : Sort u} (t : a ↔ b) (h : (mp : a → b) → (mpr : b → a) → motive) : motive := + match t with | ⟨hab, hba⟩ => h hab hba + +def Iff.elim2.{u} {a b : Prop} {motive : Sort u} (t : a ↔ b) (h : (mp : a → b) → (mpr : b → a) → motive) : motive := + Iff.casesOn (motive:= fun _ : a ↔ b => motive) t h