feat: implement cast TODO
fixes issue reported at https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Annoying.20LCNF.20errors/near/301269857
This commit is contained in:
parent
970331de05
commit
94c2ec38d5
2 changed files with 64 additions and 3 deletions
|
|
@ -228,16 +228,38 @@ def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do
|
|||
isSmall decl
|
||||
|
||||
/--
|
||||
"Beta-reduce" `(fun params => code) args`.
|
||||
LCNF "Beta-reduce". The equivalent of `(fun params => code) args`.
|
||||
If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`.
|
||||
See comment at `updateFunDeclInfo`.
|
||||
-/
|
||||
def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInline := false) : SimpM Code := do
|
||||
-- TODO: add necessary casts to `args`
|
||||
let mut subst := {}
|
||||
let mut castDecls := #[]
|
||||
for param in params, arg in args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
/-
|
||||
If `param` hast type `⊤` but `arg` does not, we must insert a cast.
|
||||
Otherwise, the resulting code may be type incorrect.
|
||||
For example, the following code is type correct before inlining `f`
|
||||
because `x : ⊤`.
|
||||
```
|
||||
def foo (g : A → A) (a : B) :=
|
||||
fun f (x : ⊤) :=
|
||||
let _x.1 := g x
|
||||
...
|
||||
let _x.2 := f a
|
||||
...
|
||||
```
|
||||
We must introduce a cast around `a` to make sure the resulting expression is type correct.
|
||||
-/
|
||||
if param.type.isAnyType && !(← inferType arg).isAnyType then
|
||||
let castArg ← mkLcCast arg anyTypeExpr
|
||||
let castDecl ← mkAuxLetDecl castArg
|
||||
castDecls := castDecls.push (CodeDecl.let castDecl)
|
||||
subst := subst.insert param.fvarId (.fvar castDecl.fvarId)
|
||||
else
|
||||
subst := subst.insert param.fvarId arg
|
||||
let code ← code.internalize subst
|
||||
let code := LCNF.attachCodeDecls castDecls code
|
||||
updateFunDeclInfo code mustInline
|
||||
return code
|
||||
|
||||
|
|
|
|||
39
tests/lean/run/lcnfCastIssue.lean
Normal file
39
tests/lean/run/lcnfCastIssue.lean
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
namespace MWE
|
||||
|
||||
universe u v w
|
||||
|
||||
inductive Id {A : Type u} : A → A → Type u
|
||||
| refl {a : A} : Id a a
|
||||
|
||||
attribute [eliminator] Id.casesOn
|
||||
|
||||
infix:50 (priority := high) " = " => Id
|
||||
|
||||
inductive Unit : Type u
|
||||
| star : Unit
|
||||
|
||||
attribute [eliminator] Unit.casesOn
|
||||
|
||||
notation "𝟏" => Unit
|
||||
notation "★" => Unit.star
|
||||
notation "ℕ" => Nat
|
||||
|
||||
def vect (A : Type u) : ℕ → Type u
|
||||
| Nat.zero => 𝟏
|
||||
| Nat.succ n => A × vect A n
|
||||
|
||||
def vect.const {A : Type u} (a : A) : ∀ n, vect A n
|
||||
| Nat.zero => ★
|
||||
| Nat.succ n => (a, const a n)
|
||||
|
||||
def vect.map {A : Type u} {B : Type v} (f : A → B) :
|
||||
∀ {n : ℕ}, vect A n → vect B n
|
||||
| Nat.zero => λ _ => ★
|
||||
| Nat.succ n => λ v => (f v.1, map f v.2)
|
||||
|
||||
def transport {A : Type u} (B : A → Type v) {a b : A} (p : a = b) : B a → B b :=
|
||||
by { induction p; apply id }
|
||||
|
||||
def vect.subst {A B : Type u} (p : A = B) (f : B → A) {n : ℕ} (v : vect A n) :
|
||||
vect.map f (transport (vect · n) p v) = vect.map (f ∘ transport id p) v :=
|
||||
by { induction p; apply Id.refl }
|
||||
Loading…
Add table
Reference in a new issue