From 94c2ec38d563f3ddc2773444532db0db9127546b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 28 Sep 2022 15:27:01 -0700 Subject: [PATCH] feat: implement `cast` TODO fixes issue reported at https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/301269857 --- src/Lean/Compiler/LCNF/Simp/SimpM.lean | 28 ++++++++++++++++-- tests/lean/run/lcnfCastIssue.lean | 39 ++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) create mode 100644 tests/lean/run/lcnfCastIssue.lean diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index 3493e47e74..fb846e8abd 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -228,16 +228,38 @@ def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do isSmall decl /-- -"Beta-reduce" `(fun params => code) args`. +LCNF "Beta-reduce". The equivalent of `(fun params => code) args`. If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`. See comment at `updateFunDeclInfo`. -/ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInline := false) : SimpM Code := do - -- TODO: add necessary casts to `args` let mut subst := {} + let mut castDecls := #[] for param in params, arg in args do - subst := subst.insert param.fvarId arg + /- + If `param` hast type `⊤` but `arg` does not, we must insert a cast. + Otherwise, the resulting code may be type incorrect. + For example, the following code is type correct before inlining `f` + because `x : ⊤`. + ``` + def foo (g : A → A) (a : B) := + fun f (x : ⊤) := + let _x.1 := g x + ... + let _x.2 := f a + ... + ``` + We must introduce a cast around `a` to make sure the resulting expression is type correct. + -/ + if param.type.isAnyType && !(← inferType arg).isAnyType then + let castArg ← mkLcCast arg anyTypeExpr + let castDecl ← mkAuxLetDecl castArg + castDecls := castDecls.push (CodeDecl.let castDecl) + subst := subst.insert param.fvarId (.fvar castDecl.fvarId) + else + subst := subst.insert param.fvarId arg let code ← code.internalize subst + let code := LCNF.attachCodeDecls castDecls code updateFunDeclInfo code mustInline return code diff --git a/tests/lean/run/lcnfCastIssue.lean b/tests/lean/run/lcnfCastIssue.lean new file mode 100644 index 0000000000..cdef50c3b1 --- /dev/null +++ b/tests/lean/run/lcnfCastIssue.lean @@ -0,0 +1,39 @@ +namespace MWE + +universe u v w + +inductive Id {A : Type u} : A → A → Type u +| refl {a : A} : Id a a + +attribute [eliminator] Id.casesOn + +infix:50 (priority := high) " = " => Id + +inductive Unit : Type u +| star : Unit + +attribute [eliminator] Unit.casesOn + +notation "𝟏" => Unit +notation "★" => Unit.star +notation "ℕ" => Nat + +def vect (A : Type u) : ℕ → Type u +| Nat.zero => 𝟏 +| Nat.succ n => A × vect A n + +def vect.const {A : Type u} (a : A) : ∀ n, vect A n +| Nat.zero => ★ +| Nat.succ n => (a, const a n) + +def vect.map {A : Type u} {B : Type v} (f : A → B) : + ∀ {n : ℕ}, vect A n → vect B n +| Nat.zero => λ _ => ★ +| Nat.succ n => λ v => (f v.1, map f v.2) + +def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b := +by { induction p; apply id } + +def vect.subst {A B : Type u} (p : A = B) (f : B → A) {n : ℕ} (v : vect A n) : + vect.map f (transport (vect · n) p v) = vect.map (f ∘ transport id p) v := +by { induction p; apply Id.refl }