fix: inferProjType at LCNF
This commit is contained in:
parent
a2631ce037
commit
b2d6caca0a
2 changed files with 86 additions and 20 deletions
|
|
@ -91,30 +91,34 @@ mutual
|
|||
partial def inferProjType (structName : Name) (idx : Nat) (s : Expr) : InferTypeM Expr := do
|
||||
let failed {α} : Unit → InferTypeM α := fun _ =>
|
||||
throwError "invalid projection{indentExpr (mkProj structName idx s)}"
|
||||
let structType ← inferType s
|
||||
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
|
||||
let n := structVal.numParams
|
||||
let structParams := structType.getAppArgs
|
||||
if n != structParams.size then
|
||||
failed ()
|
||||
else do
|
||||
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
|
||||
for _ in [:idx] do
|
||||
let structType := (← inferType s).headBeta
|
||||
if structType.isAnyType then
|
||||
/- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/
|
||||
return anyTypeExpr
|
||||
else
|
||||
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
|
||||
let n := structVal.numParams
|
||||
let structParams := structType.getAppArgs
|
||||
if n != structParams.size then
|
||||
failed ()
|
||||
else do
|
||||
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
|
||||
for _ in [:idx] do
|
||||
match ctorType with
|
||||
| .forallE _ _ body _ =>
|
||||
if body.hasLooseBVars then
|
||||
-- This can happen when one of the fields is a type or type former.
|
||||
ctorType := body.instantiate1 anyTypeExpr
|
||||
else
|
||||
ctorType := body
|
||||
| _ =>
|
||||
if ctorType.isAnyType then return anyTypeExpr
|
||||
failed ()
|
||||
match ctorType with
|
||||
| .forallE _ _ body _ =>
|
||||
if body.hasLooseBVars then
|
||||
-- This can happen when one of the fields is a type or type former.
|
||||
ctorType := body.instantiate1 anyTypeExpr
|
||||
else
|
||||
ctorType := body
|
||||
| .forallE _ d _ _ => return d
|
||||
| _ =>
|
||||
if ctorType.isAnyType then return anyTypeExpr
|
||||
failed ()
|
||||
match ctorType with
|
||||
| .forallE _ d _ _ => return d
|
||||
| _ =>
|
||||
if ctorType.isAnyType then return anyTypeExpr
|
||||
failed ()
|
||||
|
||||
partial def getLevel? (type : Expr) : InferTypeM (Option Level) := do
|
||||
match (← inferType type) with
|
||||
|
|
|
|||
62
tests/lean/run/lcnfInferProjTypeIssue.lean
Normal file
62
tests/lean/run/lcnfInferProjTypeIssue.lean
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import Lean
|
||||
|
||||
structure Vec2 where
|
||||
(x y : Float)
|
||||
|
||||
instance : Add Vec2 :=
|
||||
⟨λ ⟨x₁,x₂⟩ ⟨y₁, y₂⟩ => ⟨x₁+y₁, x₂+y₂⟩⟩
|
||||
|
||||
instance : HMul Float Vec2 Vec2 :=
|
||||
⟨λ s ⟨x₁,x₂⟩ => ⟨s*x₁, s*x₂⟩⟩
|
||||
|
||||
def NFloatArray (n : Nat) := {a : FloatArray // a.size = n}
|
||||
|
||||
instance {n} : Add (NFloatArray n) :=
|
||||
⟨λ x y => Id.run do
|
||||
let mut x := x.1
|
||||
for i in [0:n] do
|
||||
x := x.set ⟨i,sorry⟩ (x[i]'sorry+y.1[i]'sorry)
|
||||
⟨x,sorry⟩⟩
|
||||
|
||||
instance {n} : HMul Float (NFloatArray n) (NFloatArray n) :=
|
||||
⟨λ s x => Id.run do
|
||||
let mut x := x.1
|
||||
for i in [0:n] do
|
||||
x := x.set ⟨i,sorry⟩ (s*x[i]'sorry)
|
||||
⟨x,sorry⟩⟩
|
||||
|
||||
def FloatVector : Nat → Type
|
||||
| 0 => Unit
|
||||
| 1 => Float
|
||||
| 2 => Vec2
|
||||
| (n+3) => NFloatArray (n+3)
|
||||
|
||||
@[inline] def FloatVector.add {n : Nat} (x y : FloatVector n) : FloatVector n :=
|
||||
match n with
|
||||
| 0 => Unit.unit
|
||||
| 1 => by unfold FloatVector at x y; apply x + y
|
||||
| 2 => by unfold FloatVector at x y; apply x + y
|
||||
| (_+3) => by unfold FloatVector at x y; apply x + y
|
||||
|
||||
def FloatVector.smul {n : Nat} (s : Float) (x : FloatVector n) : FloatVector n :=
|
||||
match n with
|
||||
| 0 => Unit.unit
|
||||
| 1 => by unfold FloatVector at x; apply s*x
|
||||
| 2 => by unfold FloatVector at x; apply s*x
|
||||
| (_+3) => by unfold FloatVector at x; apply s*x
|
||||
|
||||
|
||||
instance : Add (FloatVector n) := ⟨λ x y => x.add y⟩
|
||||
instance : HMul Float (FloatVector n) (FloatVector n) := ⟨λ s x => x.smul s⟩
|
||||
|
||||
def foo1 := λ (x y : FloatVector 2) => x + y
|
||||
|
||||
def foo2 := λ {n} (s : Float) (x y : FloatVector n) => s * (x + y)
|
||||
|
||||
#eval Lean.Compiler.compile #[``foo1]
|
||||
#eval Lean.Compiler.compile #[``foo2]
|
||||
|
||||
set_option trace.Compiler.result true
|
||||
set_option pp.funBinderTypes true
|
||||
#eval Lean.Compiler.compile #[``foo1]
|
||||
#eval Lean.Compiler.compile #[``foo2]
|
||||
Loading…
Add table
Reference in a new issue