fix: use maxType when building expression in expression tree elaborator (#4215)
The expression tree elaborator computes a "maxType" that every leaf term can be coerced to, but the elaborator was not ensuring that the entire expression tree would have maxType as its type. This led to unexpected errors in examples such as ```lean example (a : Nat) (b : Int) : a = id (a * b^2) := sorry ``` where it would say it could not synthesize an `HMul Int Int Nat` instance (the `Nat` would propagate from the `a` on the LHS of the equality). The issue in this case is that `HPow` uses default instances, so while the expression tree elaborator decides that `a * b^2` should be referring to an `Int`, the actual elaborated type is temporarily a metavariable. Then, when the binrel elaborator is looking at both sides of the equality, it decides that `Nat` will work and coercions don't need to be inserted. The fix is to unify the type of the resulting elaborated expression with the computed maxType. One wrinkle is that `hasUncomparable` being false is a valid test only if there are no leaf terms with unknown types (if they become known, it could change `hasUncomparable` to true), so this unification is only performed if the leaf terms all have known types. Fixes issue described by Floris van Doorn on [Zulip](https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/elaboration.20issue.20involving.20powers.20and.20sums/near/439243587).
This commit is contained in:
parent
02b6fb3f41
commit
b639d102d1
4 changed files with 58 additions and 22 deletions
|
|
@ -241,7 +241,10 @@ private def hasCoe (fromType toType : Expr) : TermElabM Bool := do
|
|||
|
||||
private structure AnalyzeResult where
|
||||
max? : Option Expr := none
|
||||
hasUncomparable : Bool := false -- `true` if there are two types `α` and `β` where we don't have coercions in any direction.
|
||||
/-- `true` if there are two types `α` and `β` where we don't have coercions in any direction. -/
|
||||
hasUncomparable : Bool := false
|
||||
/-- `true` if there are any leaf terms with an unknown type (according to `isUnknown`). -/
|
||||
hasUnknown : Bool := false
|
||||
|
||||
private def isUnknown : Expr → Bool
|
||||
| .mvar .. => true
|
||||
|
|
@ -255,7 +258,7 @@ private def analyze (t : Tree) (expectedType? : Option Expr) : TermElabM Analyze
|
|||
match expectedType? with
|
||||
| none => pure none
|
||||
| some expectedType =>
|
||||
let expectedType ← instantiateMVars expectedType
|
||||
let expectedType := (← instantiateMVars expectedType).cleanupAnnotations
|
||||
if isUnknown expectedType then pure none else pure (some expectedType)
|
||||
(go t *> get).run' { max? }
|
||||
where
|
||||
|
|
@ -268,8 +271,10 @@ where
|
|||
| .binop _ _ _ lhs rhs => go lhs; go rhs
|
||||
| .unop _ _ arg => go arg
|
||||
| .term _ _ val =>
|
||||
let type ← instantiateMVars (← inferType val)
|
||||
unless isUnknown type do
|
||||
let type := (← instantiateMVars (← inferType val)).cleanupAnnotations
|
||||
if isUnknown type then
|
||||
modify fun s => { s with hasUnknown := true }
|
||||
else
|
||||
match (← get).max? with
|
||||
| none => modify fun s => { s with max? := type }
|
||||
| some max =>
|
||||
|
|
@ -430,7 +435,7 @@ mutual
|
|||
| .unop ref f arg =>
|
||||
return .unop ref f (← go arg none false false)
|
||||
| .term ref trees e =>
|
||||
let type ← instantiateMVars (← inferType e)
|
||||
let type := (← instantiateMVars (← inferType e)).cleanupAnnotations
|
||||
trace[Elab.binop] "visiting {e} : {type} =?= {maxType}"
|
||||
if isUnknown type then
|
||||
if let some f := f? then
|
||||
|
|
@ -448,12 +453,17 @@ mutual
|
|||
|
||||
private partial def toExpr (tree : Tree) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
let r ← analyze tree expectedType?
|
||||
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
|
||||
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, hasUnknown: {r.hasUnknown}, maxType: {r.max?}"
|
||||
if r.hasUncomparable || r.max?.isNone then
|
||||
let result ← toExprCore tree
|
||||
ensureHasType expectedType? result
|
||||
else
|
||||
let result ← toExprCore (← applyCoe tree r.max?.get! (isPred := false))
|
||||
unless r.hasUnknown do
|
||||
-- Record the resulting maxType calculation.
|
||||
-- We can do this when all the types are known, since in this case `hasUncomparable` is valid.
|
||||
-- If they're not known, recording maxType like this can lead to heterogeneous operations failing to elaborate.
|
||||
discard <| isDefEqGuarded (← inferType result) r.max?.get!
|
||||
trace[Elab.binop] "result: {result}"
|
||||
ensureHasType expectedType? result
|
||||
|
||||
|
|
@ -519,7 +529,7 @@ def elabBinRelCore (noProp : Bool) (stx : Syntax) (expectedType? : Option Expr)
|
|||
let rhs ← withRef rhsStx <| toTree rhsStx
|
||||
let tree := .binop stx .regular f lhs rhs
|
||||
let r ← analyze tree none
|
||||
trace[Elab.binrel] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
|
||||
trace[Elab.binrel] "hasUncomparable: {r.hasUncomparable}, hasUnknown: {r.hasUnknown}, maxType: {r.max?}"
|
||||
if r.hasUncomparable || r.max?.isNone then
|
||||
-- Use default elaboration strategy + `toBoolIfNecessary`
|
||||
let lhs ← toExprCore lhs
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
example (n : Nat) (i : Int) : n + i = i + n := by
|
||||
rw [Int.add_comm]
|
||||
|
||||
def f1 (a : Int) (b c : Nat) : Int :=
|
||||
a + (b - c)
|
||||
|
||||
def f2 (a : Int) (b c : Nat) : Int :=
|
||||
(b - c) + a
|
||||
|
||||
#print f1
|
||||
#print f2
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
def f1 : Int → Nat → Nat → Int :=
|
||||
fun a b c => a + (↑b - ↑c)
|
||||
def f2 : Int → Nat → Nat → Int :=
|
||||
fun a b c => ↑b - ↑c + a
|
||||
41
tests/lean/run/binop.lean
Normal file
41
tests/lean/run/binop.lean
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
/-!
|
||||
# Tests for the expression tree elaborator (`binop%`, etc.)
|
||||
-/
|
||||
|
||||
/-!
|
||||
Some basic Int/Nat examples
|
||||
-/
|
||||
|
||||
example (n : Nat) (i : Int) : n + i = i + n := by
|
||||
rw [Int.add_comm]
|
||||
|
||||
def f1 (a : Int) (b c : Nat) : Int :=
|
||||
a + (b - c)
|
||||
|
||||
def f2 (a : Int) (b c : Nat) : Int :=
|
||||
(b - c) + a
|
||||
|
||||
/--
|
||||
info: def f1 : Int → Nat → Nat → Int :=
|
||||
fun a b c => a + (↑b - ↑c)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print f1
|
||||
|
||||
/--
|
||||
info: def f2 : Int → Nat → Nat → Int :=
|
||||
fun a b c => ↑b - ↑c + a
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print f2
|
||||
|
||||
|
||||
/-!
|
||||
Interaction with default instances for pow. This used to fail with not being able
|
||||
to synthesize an `HMul Int Int Nat` instance because the type of
|
||||
the result of `*` wasn't being set to `Int`.
|
||||
-/
|
||||
|
||||
/-- info: ∀ (a : Nat) (b : Int), ↑a = id (↑a * b ^ 2) : Prop -/
|
||||
#guard_msgs in
|
||||
#check ∀ (a : Nat) (b : Int), a = id (a * b^2)
|
||||
Loading…
Add table
Reference in a new issue