fix: inferProjType at LCNF

This commit is contained in:
Leonardo de Moura 2022-09-12 18:27:14 -07:00
parent a2631ce037
commit b2d6caca0a
2 changed files with 86 additions and 20 deletions

View file

@ -91,30 +91,34 @@ mutual
partial def inferProjType (structName : Name) (idx : Nat) (s : Expr) : InferTypeM Expr := do
let failed {α} : Unit → InferTypeM α := fun _ =>
throwError "invalid projection{indentExpr (mkProj structName idx s)}"
let structType ← inferType s
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
let n := structVal.numParams
let structParams := structType.getAppArgs
if n != structParams.size then
failed ()
else do
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
for _ in [:idx] do
let structType := (← inferType s).headBeta
if structType.isAnyType then
/- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/
return anyTypeExpr
else
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
let n := structVal.numParams
let structParams := structType.getAppArgs
if n != structParams.size then
failed ()
else do
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
for _ in [:idx] do
match ctorType with
| .forallE _ _ body _ =>
if body.hasLooseBVars then
-- This can happen when one of the fields is a type or type former.
ctorType := body.instantiate1 anyTypeExpr
else
ctorType := body
| _ =>
if ctorType.isAnyType then return anyTypeExpr
failed ()
match ctorType with
| .forallE _ _ body _ =>
if body.hasLooseBVars then
-- This can happen when one of the fields is a type or type former.
ctorType := body.instantiate1 anyTypeExpr
else
ctorType := body
| .forallE _ d _ _ => return d
| _ =>
if ctorType.isAnyType then return anyTypeExpr
failed ()
match ctorType with
| .forallE _ d _ _ => return d
| _ =>
if ctorType.isAnyType then return anyTypeExpr
failed ()
partial def getLevel? (type : Expr) : InferTypeM (Option Level) := do
match (← inferType type) with

View file

@ -0,0 +1,62 @@
import Lean
structure Vec2 where
(x y : Float)
instance : Add Vec2 :=
⟨λ ⟨x₁,x₂⟩ ⟨y₁, y₂⟩ => ⟨x₁+y₁, x₂+y₂⟩⟩
instance : HMul Float Vec2 Vec2 :=
⟨λ s ⟨x₁,x₂⟩ => ⟨s*x₁, s*x₂⟩⟩
def NFloatArray (n : Nat) := {a : FloatArray // a.size = n}
instance {n} : Add (NFloatArray n) :=
⟨λ x y => Id.run do
let mut x := x.1
for i in [0:n] do
x := x.set ⟨i,sorry⟩ (x[i]'sorry+y.1[i]'sorry)
⟨x,sorry⟩⟩
instance {n} : HMul Float (NFloatArray n) (NFloatArray n) :=
⟨λ s x => Id.run do
let mut x := x.1
for i in [0:n] do
x := x.set ⟨i,sorry⟩ (s*x[i]'sorry)
⟨x,sorry⟩⟩
def FloatVector : Nat → Type
| 0 => Unit
| 1 => Float
| 2 => Vec2
| (n+3) => NFloatArray (n+3)
@[inline] def FloatVector.add {n : Nat} (x y : FloatVector n) : FloatVector n :=
match n with
| 0 => Unit.unit
| 1 => by unfold FloatVector at x y; apply x + y
| 2 => by unfold FloatVector at x y; apply x + y
| (_+3) => by unfold FloatVector at x y; apply x + y
def FloatVector.smul {n : Nat} (s : Float) (x : FloatVector n) : FloatVector n :=
match n with
| 0 => Unit.unit
| 1 => by unfold FloatVector at x; apply s*x
| 2 => by unfold FloatVector at x; apply s*x
| (_+3) => by unfold FloatVector at x; apply s*x
instance : Add (FloatVector n) := ⟨λ x y => x.add y⟩
instance : HMul Float (FloatVector n) (FloatVector n) := ⟨λ s x => x.smul s⟩
def foo1 := λ (x y : FloatVector 2) => x + y
def foo2 := λ {n} (s : Float) (x y : FloatVector n) => s * (x + y)
#eval Lean.Compiler.compile #[``foo1]
#eval Lean.Compiler.compile #[``foo2]
set_option trace.Compiler.result true
set_option pp.funBinderTypes true
#eval Lean.Compiler.compile #[``foo1]
#eval Lean.Compiler.compile #[``foo2]