From b2d6caca0a2ce02ec210f6f47c9ff1e8d6b31c1e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 12 Sep 2022 18:27:14 -0700 Subject: [PATCH] fix: `inferProjType` at LCNF --- src/Lean/Compiler/LCNF/InferType.lean | 44 ++++++++------- tests/lean/run/lcnfInferProjTypeIssue.lean | 62 ++++++++++++++++++++++ 2 files changed, 86 insertions(+), 20 deletions(-) create mode 100644 tests/lean/run/lcnfInferProjTypeIssue.lean diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 2d018d4147..8005a8a51b 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -91,30 +91,34 @@ mutual partial def inferProjType (structName : Name) (idx : Nat) (s : Expr) : InferTypeM Expr := do let failed {α} : Unit → InferTypeM α := fun _ => throwError "invalid projection{indentExpr (mkProj structName idx s)}" - let structType ← inferType s - matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal => - let n := structVal.numParams - let structParams := structType.getAppArgs - if n != structParams.size then - failed () - else do - let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams) - for _ in [:idx] do + let structType := (← inferType s).headBeta + if structType.isAnyType then + /- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/ + return anyTypeExpr + else + matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal => + let n := structVal.numParams + let structParams := structType.getAppArgs + if n != structParams.size then + failed () + else do + let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams) + for _ in [:idx] do + match ctorType with + | .forallE _ _ body _ => + if body.hasLooseBVars then + -- This can happen when one of the fields is a type or type former. + ctorType := body.instantiate1 anyTypeExpr + else + ctorType := body + | _ => + if ctorType.isAnyType then return anyTypeExpr + failed () match ctorType with - | .forallE _ _ body _ => - if body.hasLooseBVars then - -- This can happen when one of the fields is a type or type former. - ctorType := body.instantiate1 anyTypeExpr - else - ctorType := body + | .forallE _ d _ _ => return d | _ => if ctorType.isAnyType then return anyTypeExpr failed () - match ctorType with - | .forallE _ d _ _ => return d - | _ => - if ctorType.isAnyType then return anyTypeExpr - failed () partial def getLevel? (type : Expr) : InferTypeM (Option Level) := do match (← inferType type) with diff --git a/tests/lean/run/lcnfInferProjTypeIssue.lean b/tests/lean/run/lcnfInferProjTypeIssue.lean new file mode 100644 index 0000000000..4f070a5248 --- /dev/null +++ b/tests/lean/run/lcnfInferProjTypeIssue.lean @@ -0,0 +1,62 @@ +import Lean + +structure Vec2 where + (x y : Float) + +instance : Add Vec2 := + ⟨λ ⟨x₁,x₂⟩ ⟨y₁, y₂⟩ => ⟨x₁+y₁, x₂+y₂⟩⟩ + +instance : HMul Float Vec2 Vec2 := + ⟨λ s ⟨x₁,x₂⟩ => ⟨s*x₁, s*x₂⟩⟩ + +def NFloatArray (n : Nat) := {a : FloatArray // a.size = n} + +instance {n} : Add (NFloatArray n) := + ⟨λ x y => Id.run do + let mut x := x.1 + for i in [0:n] do + x := x.set ⟨i,sorry⟩ (x[i]'sorry+y.1[i]'sorry) + ⟨x,sorry⟩⟩ + +instance {n} : HMul Float (NFloatArray n) (NFloatArray n) := + ⟨λ s x => Id.run do + let mut x := x.1 + for i in [0:n] do + x := x.set ⟨i,sorry⟩ (s*x[i]'sorry) + ⟨x,sorry⟩⟩ + +def FloatVector : Nat → Type + | 0 => Unit + | 1 => Float + | 2 => Vec2 + | (n+3) => NFloatArray (n+3) + +@[inline] def FloatVector.add {n : Nat} (x y : FloatVector n) : FloatVector n := + match n with + | 0 => Unit.unit + | 1 => by unfold FloatVector at x y; apply x + y + | 2 => by unfold FloatVector at x y; apply x + y + | (_+3) => by unfold FloatVector at x y; apply x + y + +def FloatVector.smul {n : Nat} (s : Float) (x : FloatVector n) : FloatVector n := + match n with + | 0 => Unit.unit + | 1 => by unfold FloatVector at x; apply s*x + | 2 => by unfold FloatVector at x; apply s*x + | (_+3) => by unfold FloatVector at x; apply s*x + + +instance : Add (FloatVector n) := ⟨λ x y => x.add y⟩ +instance : HMul Float (FloatVector n) (FloatVector n) := ⟨λ s x => x.smul s⟩ + +def foo1 := λ (x y : FloatVector 2) => x + y + +def foo2 := λ {n} (s : Float) (x y : FloatVector n) => s * (x + y) + +#eval Lean.Compiler.compile #[``foo1] +#eval Lean.Compiler.compile #[``foo2] + +set_option trace.Compiler.result true +set_option pp.funBinderTypes true +#eval Lean.Compiler.compile #[``foo1] +#eval Lean.Compiler.compile #[``foo2]