feat: new elaborator for binop%
cc @gebner.
This commit is contained in:
parent
24fe2875c6
commit
d3d03df83c
5 changed files with 149 additions and 58 deletions
|
|
@ -43,41 +43,6 @@ open Meta
|
|||
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
|
||||
| none => throwUnknownConstant stx[1].getId
|
||||
|
||||
-- TODO: move to another file?
|
||||
private def hasUnknownType (e : Expr) : MetaM Bool :=
|
||||
return (← inferType e).getAppFn.isMVar
|
||||
|
||||
@[builtinTermElab binop] def elabBinOp : TermElab := fun stx expectedType? => do
|
||||
match stx with
|
||||
| `(binop% $f $lhs $rhs) =>
|
||||
match expectedType? with
|
||||
| none =>
|
||||
-- We elaborate as a normal application when expected type is not available
|
||||
let stxNew ← `($f:ident $lhs $rhs)
|
||||
withMacroExpansion stx stxNew <| elabTerm stxNew none
|
||||
| some expectedType =>
|
||||
match (← resolveId? f) with
|
||||
| some f =>
|
||||
let syntheticMVarsSaved := (← get).syntheticMVars
|
||||
modify fun s => { s with syntheticMVars := [] }
|
||||
try
|
||||
let lhs ← elabTerm lhs none
|
||||
let rhs ← elabTerm rhs none
|
||||
if (← hasUnknownType lhs) && (← hasUnknownType rhs) then
|
||||
-- We want the numerals in terms such as `(1 + 1)` `(2 * 3 + 4)` to be elaborated using the expected type
|
||||
-- This is particularly important when there is no coercion from `Nat` to the expected type.
|
||||
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType (explicit := false) (ellipsis := false)
|
||||
else
|
||||
-- We force TC resolution and default instances to be used.
|
||||
-- Note that we do not provide the expected type to make sure it can be inferred by the TC procedure. See issue #382
|
||||
let r ← elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
synthesizeSyntheticMVarsUsingDefault
|
||||
return r
|
||||
finally
|
||||
modify fun s => { s with syntheticMVars := s.syntheticMVars ++ syntheticMVarsSaved }
|
||||
| none => throwUnknownConstant stx[1].getId
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@[builtinTermElab forInMacro] def elabForIn : TermElab := fun stx expectedType? => do
|
||||
match stx with
|
||||
| `(forIn% $col $init $body) =>
|
||||
|
|
@ -118,4 +83,107 @@ where
|
|||
throwFailure (forInInstance : Expr) : TermElabM Expr :=
|
||||
throwError "failed to synthesize instance for 'forIn%' notation{indentExpr forInInstance}"
|
||||
|
||||
namespace BinOp
|
||||
/- Elaborator for `binop%` -/
|
||||
|
||||
private inductive Tree where
|
||||
| term (ref : Syntax) (val : Expr)
|
||||
| op (ref : Syntax) (f : Expr) (lhs rhs : Tree)
|
||||
|
||||
private partial def toTree (s : Syntax) : TermElabM Tree :=
|
||||
withSynthesizeLight do
|
||||
go (← liftMacroM <| expandMacros s)
|
||||
where
|
||||
go (s : Syntax) := do
|
||||
match s with
|
||||
| `(binop% $f $lhs $rhs) =>
|
||||
let some f ← resolveId? f | throwUnknownConstant f.getId
|
||||
return Tree.op s f (← go lhs) (← go rhs)
|
||||
| `(($e)) => (← go e)
|
||||
| _ =>
|
||||
return Tree.term s (← elabTerm s none)
|
||||
|
||||
-- Auxiliary function used at `analyze`
|
||||
private def hasCoe (fromType toType : Expr) : TermElabM Bool := do
|
||||
let u ← getLevel fromType
|
||||
let v ← getLevel toType
|
||||
let coeInstType := mkAppN (Lean.mkConst ``CoeHTCT [u, v]) #[fromType, toType]
|
||||
match ← trySynthInstance coeInstType (some (maxCoeSize.get (← getOptions))) with
|
||||
| LOption.some _ => return true
|
||||
| LOption.none => return false
|
||||
| LOption.undef => return false -- TODO: should we do something smarter here?
|
||||
|
||||
private structure AnalyzeResult where
|
||||
max? : Option Expr := none
|
||||
hasUncomparable : Bool := false -- `true` if there are two types `α` and `β` where we don't have coercions in any direction.
|
||||
|
||||
private def isUnknow (e : Expr) : Bool :=
|
||||
e.getAppFn.isMVar
|
||||
|
||||
private def analyze (t : Tree) (expectedType? : Option Expr) : TermElabM AnalyzeResult := do
|
||||
let max? ←
|
||||
match expectedType? with
|
||||
| none => pure none
|
||||
| some expectedType =>
|
||||
let expectedType ← instantiateMVars expectedType
|
||||
if isUnknow expectedType then pure none else pure (some expectedType)
|
||||
(go t *> get).run' { max? }
|
||||
where
|
||||
go (t : Tree) : StateRefT AnalyzeResult TermElabM Unit := do
|
||||
unless (← get).hasUncomparable do
|
||||
match t with
|
||||
| Tree.op _ _ lhs rhs => go lhs; go rhs
|
||||
| Tree.term _ val =>
|
||||
let type ← instantiateMVars (← inferType val)
|
||||
unless isUnknow type do
|
||||
match (← get).max? with
|
||||
| none => modify fun s => { s with max? := type }
|
||||
| some max =>
|
||||
unless (← withNewMCtxDepth <| isDefEqGuarded max type) do
|
||||
if (← hasCoe type max) then
|
||||
return ()
|
||||
else if (← hasCoe max type) then
|
||||
modify fun s => { s with max? := type }
|
||||
else
|
||||
trace[Elab.binop] "uncomparable types: {max}, {type}"
|
||||
modify fun s => { s with hasUncomparable := true }
|
||||
|
||||
private def mkOp (f : Expr) (lhs rhs : Expr) : TermElabM Expr :=
|
||||
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
|
||||
private def toExpr (t : Tree) : TermElabM Expr := do
|
||||
match t with
|
||||
| Tree.term _ e => return e
|
||||
| Tree.op ref f lhs rhs => withRef ref <| mkOp f (← toExpr lhs) (← toExpr rhs)
|
||||
|
||||
private def applyCoe (t : Tree) (maxType : Expr) : TermElabM Tree := do
|
||||
go t
|
||||
where
|
||||
go (t : Tree) : TermElabM Tree := do
|
||||
match t with
|
||||
| Tree.op ref f lhs rhs => return Tree.op ref f (← go lhs) (← go rhs)
|
||||
| Tree.term ref e =>
|
||||
let type ← inferType e
|
||||
if (← isDefEqGuarded maxType type) then
|
||||
return t
|
||||
else
|
||||
withRef ref <| return Tree.term ref (← mkCoe maxType type e)
|
||||
|
||||
@[builtinTermElab binop]
|
||||
def elabBinOp' : TermElab := fun stx expectedType? => do
|
||||
let tree ← toTree stx
|
||||
let r ← analyze tree expectedType?
|
||||
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
|
||||
if r.hasUncomparable || r.max?.isNone then
|
||||
let result ← toExpr tree
|
||||
ensureHasType expectedType? result
|
||||
else
|
||||
let result ← toExpr (← applyCoe tree r.max?.get!)
|
||||
ensureHasType expectedType? result
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.binop
|
||||
|
||||
end BinOp
|
||||
|
||||
end Lean.Elab.Term
|
||||
|
|
|
|||
|
|
@ -309,13 +309,13 @@ def synthesizeSyntheticMVarsUsingDefault : TermElabM Unit := do
|
|||
synthesizeSyntheticMVars (mayPostpone := true)
|
||||
synthesizeUsingDefaultLoop
|
||||
|
||||
private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) : TermElabM α := do
|
||||
private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) (synthesizeDefault : Bool) : TermElabM α := do
|
||||
let syntheticMVarsSaved := (← get).syntheticMVars
|
||||
modify fun s => { s with syntheticMVars := [] }
|
||||
try
|
||||
let a ← k
|
||||
synthesizeSyntheticMVars mayPostpone
|
||||
if mayPostpone then
|
||||
if mayPostpone && synthesizeDefault then
|
||||
synthesizeUsingDefaultLoop
|
||||
return a
|
||||
finally
|
||||
|
|
@ -326,7 +326,11 @@ private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Boo
|
|||
If `mayPostpone == false`, then all of them must be synthesized.
|
||||
Remark: even if `mayPostpone == true`, the method still uses `synthesizeUsingDefault` -/
|
||||
@[inline] def withSynthesize [MonadFunctorT TermElabM m] [Monad m] (k : m α) (mayPostpone := false) : m α :=
|
||||
monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone) k
|
||||
monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone (synthesizeDefault := true)) k
|
||||
|
||||
/-- Similar to `withSynthesize`, but sets `mayPostpone` to `true`, and do not use `synthesizeUsingDefault` -/
|
||||
@[inline] def withSynthesizeLight [MonadFunctorT TermElabM m] [Monad m] (k : m α) : m α :=
|
||||
monadMap (m := TermElabM) (withSynthesizeImp . (mayPostpone := true) (synthesizeDefault := false)) k
|
||||
|
||||
/-- Elaborate `stx`, and make sure all pending synthetic metavariables created while elaborating `stx` are solved. -/
|
||||
def elabTermAndSynthesize (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
|
||||
|
|
|
|||
|
|
@ -700,6 +700,27 @@ def tryCoeThunk? (expectedType : Expr) (eType : Expr) (e : Expr) : TermElabM (Op
|
|||
| _ =>
|
||||
pure none
|
||||
|
||||
def mkCoe (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr := none) (errorMsgHeader? : Option String := none) : TermElabM Expr := do
|
||||
let u ← getLevel eType
|
||||
let v ← getLevel expectedType
|
||||
let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType]
|
||||
let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic
|
||||
let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar]
|
||||
let mvarId := mvar.mvarId!
|
||||
try
|
||||
withoutMacroStackAtErr do
|
||||
if (← synthesizeCoeInstMVarCore mvarId) then
|
||||
expandCoe eNew
|
||||
else
|
||||
-- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe`
|
||||
-- after we syntheze `mvar`
|
||||
let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
|
||||
registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?)
|
||||
return mvarAux
|
||||
catch
|
||||
| Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg
|
||||
| _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f?
|
||||
|
||||
/--
|
||||
Try to apply coercion to make sure `e` has type `expectedType`.
|
||||
Relevant definitions:
|
||||
|
|
@ -713,26 +734,7 @@ private def tryCoe (errorMsgHeader? : Option String) (expectedType : Expr) (eTyp
|
|||
return e
|
||||
else match (← tryCoeThunk? expectedType eType e) with
|
||||
| some r => return r
|
||||
| none =>
|
||||
let u ← getLevel eType
|
||||
let v ← getLevel expectedType
|
||||
let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType]
|
||||
let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic
|
||||
let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar]
|
||||
let mvarId := mvar.mvarId!
|
||||
try
|
||||
withoutMacroStackAtErr do
|
||||
if (← synthesizeCoeInstMVarCore mvarId) then
|
||||
expandCoe eNew
|
||||
else
|
||||
-- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe`
|
||||
-- after we syntheze `mvar`
|
||||
let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
|
||||
registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?)
|
||||
return mvarAux
|
||||
catch
|
||||
| Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg
|
||||
| _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f?
|
||||
| none => mkCoe expectedType eType e f? errorMsgHeader?
|
||||
|
||||
def isTypeApp? (type : Expr) : TermElabM (Option (Expr × Expr)) := do
|
||||
let type ← withReducible $ whnf type
|
||||
|
|
|
|||
13
tests/lean/binopIssues.lean
Normal file
13
tests/lean/binopIssues.lean
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
axiom Int.add_comm (i j : Int) : i + j = j + i
|
||||
|
||||
example (n : Nat) (i : Int) : n + i = i + n := by
|
||||
rw [Int.add_comm]
|
||||
|
||||
def f1 (a : Int) (b c : Nat) : Int :=
|
||||
a + (b - c)
|
||||
|
||||
def f2 (a : Int) (b c : Nat) : Int :=
|
||||
(b - c) + a
|
||||
|
||||
#print f1
|
||||
#print f2
|
||||
4
tests/lean/binopIssues.lean.expected.out
Normal file
4
tests/lean/binopIssues.lean.expected.out
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
def f1 : Int → Nat → Nat → Int :=
|
||||
fun a b c => a + (Int.ofNat b - Int.ofNat c)
|
||||
def f2 : Int → Nat → Nat → Int :=
|
||||
fun a b c => Int.ofNat b - Int.ofNat c + a
|
||||
Loading…
Add table
Reference in a new issue