diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index b523a029dd..d86c10da15 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -348,15 +348,15 @@ def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default } return param -def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) : - M (LetDecl .pure) := do +def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) + (nondep : Bool) : M (LetDecl .pure) := do let binderName ← cleanupBinderName binderName let value' ← match arg with | .fvar fvarId => pure <| .fvar fvarId #[] | .erased | .type .. => pure .erased let letDecl ← LCNF.mkLetDecl binderName type' value' modify fun s => { s with - lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value false + lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value nondep seq := s.seq.push <| .let letDecl } return letDecl @@ -385,6 +385,38 @@ where else return (ps, e.instantiateRev xs) +/-- +Given `e` and `args` where `mkAppN e (args.map (·.toExpr))` is not necessarily well-typed +(because of dependent typing), returns `e.beta args'` where `args'` are new local declarations each +assigned to a value in `args` with adjusted type (such that the resulting expression is well-typed). +-/ +def mkTypeCorrectApp (e : Expr) (args : Array (Arg .pure)) : M Expr := do + if args.isEmpty then + return e + let type ← liftMetaM <| do + let type ← Meta.inferType e + if type.getNumHeadForalls < args.size then + -- expose foralls + Meta.forallBoundedTelescope type args.size Meta.mkForallFVars + else + return type + go type 0 #[] +where + go (type : Expr) (i : Nat) (xs : Array Expr) : M Expr := do + if h : i < args.size then + match type with + | .forallE nm t b bi => + let t := t.instantiateRev xs + let arg := args[i] + if ← liftMetaM <| Meta.isProp t then + go b (i + 1) (xs.push (mkLcProof t)) + else + let decl ← mkLetDecl nm t arg.toExpr (← arg.inferType) arg (nondep := true) + go b (i + 1) (xs.push (.fvar decl.fvarId)) + | _ => liftMetaM <| Meta.throwFunctionExpected (mkAppN e xs) + else + return e.beta xs + def mustEtaExpand (env : Environment) (e : Expr) : Bool := if let .const declName _ := e.getAppFn then match env.find? declName with @@ -526,7 +558,7 @@ where k args[arity...*] ``` -/ - mkOverApplication (app : (Arg .pure)) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do + mkOverApplication (app : Arg .pure) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do if args.size == arity then return app else @@ -541,11 +573,14 @@ where /-- Visit a `matcher`/`casesOn` alternative. -/ - visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do + visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) (overArgs : Array (Arg .pure)) : + M (Expr × (Alt .pure)) := do withNewScope do match casesAltInfo with | .default numHyps => - let c ← toCode (← visit (mkAppN e (Array.replicate numHyps erasedExpr))) + let e := mkAppN e (Array.replicate numHyps erasedExpr) + let e ← mkTypeCorrectApp e overArgs + let c ← toCode (← visit e) let altType ← c.inferType return (altType, .default c) | .ctor ctorName numParams => @@ -555,6 +590,7 @@ where let (ps', e') ← ToLCNF.visitLambda e ps := ps ++ ps' e := e' + e ← mkTypeCorrectApp e overArgs /- Insert the free variable ids of fields that are type formers into `toAny`. Recall that we do not want to have "data" occurring in types. @@ -579,7 +615,8 @@ where visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e casesInfo.arity do let args := e.getAppArgs - let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity])) + let overArgs ← (args.drop casesInfo.arity).mapM visitAppArg + let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args)) let typeName := casesInfo.indName let .inductInfo indVal ← getConstInfo typeName | unreachable! if casesInfo.numAlts == 0 then @@ -609,8 +646,7 @@ where fieldArgs := fieldArgs.push fieldArg return fieldArgs let f := args[casesInfo.altsRange.lower]! - let result ← visit (mkAppN f fieldArgs) - mkOverApplication result args casesInfo.arity + visit (mkAppN (mkAppN f fieldArgs) (overArgs.map (·.toExpr))) else let mut alts := #[] let discr ← visitAppArg args[casesInfo.discrPos]! @@ -618,14 +654,13 @@ where | .fvar discrFVarId => pure discrFVarId | .erased | .type .. => mkAuxLetDecl .erased for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do - let (altType, alt) ← visitAlt numParams args[i]! + let (altType, alt) ← visitAlt numParams args[i]! overArgs resultType := joinTypes altType resultType alts := alts.push alt let cases := ⟨typeName, resultType, discrFVarId, alts⟩ let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) - let result := .fvar auxDecl.fvarId - mkOverApplication result args casesInfo.arity + return .fvar auxDecl.fvarId visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e arity do @@ -843,14 +878,14 @@ where visitLet (e : Expr) (xs : Array Expr) : M (Arg .pure) := do match e with - | .letE binderName type value body _ => + | .letE binderName type value body nondep => let type := type.instantiateRev xs let value := value.instantiateRev xs if (← (liftMetaM <| Meta.isProp type) <||> isTypeFormerType type) then visitLet body (xs.push value) else let type' ← toLCNFType type - let letDecl ← mkLetDecl binderName type value type' (← visit value) + let letDecl ← mkLetDecl binderName type value type' (← visit value) nondep visitLet body (xs.push (.fvar letDecl.fvarId)) | _ => let e := e.instantiateRev xs diff --git a/tests/lean/run/12284.lean b/tests/lean/run/12284.lean new file mode 100644 index 0000000000..9e6392222d --- /dev/null +++ b/tests/lean/run/12284.lean @@ -0,0 +1,12 @@ +set_option trace.Compiler.init true +/-- +trace: [Compiler.init] size: 3 + def test x y : Bool := + cases x : Bool + | PUnit.unit => + let a := y; + return a +-/ +#guard_msgs in +def test (x : Unit) (y : Bool) : Bool := + x.casesOn (fun a => a) y