feat: implement cast TODO

fixes issue reported at
https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/301269857
This commit is contained in:
Leonardo de Moura 2022-09-28 15:27:01 -07:00
parent 970331de05
commit 94c2ec38d5
2 changed files with 64 additions and 3 deletions

View file

@ -228,16 +228,38 @@ def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do
isSmall decl
/--
"Beta-reduce" `(fun params => code) args`.
LCNF "Beta-reduce". The equivalent of `(fun params => code) args`.
If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`.
See comment at `updateFunDeclInfo`.
-/
def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInline := false) : SimpM Code := do
-- TODO: add necessary casts to `args`
let mut subst := {}
let mut castDecls := #[]
for param in params, arg in args do
subst := subst.insert param.fvarId arg
/-
If `param` hast type `` but `arg` does not, we must insert a cast.
Otherwise, the resulting code may be type incorrect.
For example, the following code is type correct before inlining `f`
because `x : `.
```
def foo (g : A → A) (a : B) :=
fun f (x : ) :=
let _x.1 := g x
...
let _x.2 := f a
...
```
We must introduce a cast around `a` to make sure the resulting expression is type correct.
-/
if param.type.isAnyType && !(← inferType arg).isAnyType then
let castArg ← mkLcCast arg anyTypeExpr
let castDecl ← mkAuxLetDecl castArg
castDecls := castDecls.push (CodeDecl.let castDecl)
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
else
subst := subst.insert param.fvarId arg
let code ← code.internalize subst
let code := LCNF.attachCodeDecls castDecls code
updateFunDeclInfo code mustInline
return code

View file

@ -0,0 +1,39 @@
namespace MWE
universe u v w
inductive Id {A : Type u} : A → A → Type u
| refl {a : A} : Id a a
attribute [eliminator] Id.casesOn
infix:50 (priority := high) " = " => Id
inductive Unit : Type u
| star : Unit
attribute [eliminator] Unit.casesOn
notation "𝟏" => Unit
notation "★" => Unit.star
notation "" => Nat
def vect (A : Type u) : → Type u
| Nat.zero => 𝟏
| Nat.succ n => A × vect A n
def vect.const {A : Type u} (a : A) : ∀ n, vect A n
| Nat.zero => ★
| Nat.succ n => (a, const a n)
def vect.map {A : Type u} {B : Type v} (f : A → B) :
∀ {n : }, vect A n → vect B n
| Nat.zero => λ _ => ★
| Nat.succ n => λ v => (f v.1, map f v.2)
def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b :=
by { induction p; apply id }
def vect.subst {A B : Type u} (p : A = B) (f : B → A) {n : } (v : vect A n) :
vect.map f (transport (vect · n) p v) = vect.map (f ∘ transport id p) v :=
by { induction p; apply Id.refl }