feat: new elaborator for binop%

cc @gebner.
This commit is contained in:
Leonardo de Moura 2021-08-13 15:44:04 -07:00
parent 24fe2875c6
commit d3d03df83c
5 changed files with 149 additions and 58 deletions

View file

@ -43,41 +43,6 @@ open Meta
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
| none => throwUnknownConstant stx[1].getId
-- TODO: move to another file?
private def hasUnknownType (e : Expr) : MetaM Bool :=
return (← inferType e).getAppFn.isMVar
@[builtinTermElab binop] def elabBinOp : TermElab := fun stx expectedType? => do
match stx with
| `(binop% $f $lhs $rhs) =>
match expectedType? with
| none =>
-- We elaborate as a normal application when expected type is not available
let stxNew ← `($f:ident $lhs $rhs)
withMacroExpansion stx stxNew <| elabTerm stxNew none
| some expectedType =>
match (← resolveId? f) with
| some f =>
let syntheticMVarsSaved := (← get).syntheticMVars
modify fun s => { s with syntheticMVars := [] }
try
let lhs ← elabTerm lhs none
let rhs ← elabTerm rhs none
if (← hasUnknownType lhs) && (← hasUnknownType rhs) then
-- We want the numerals in terms such as `(1 + 1)` `(2 * 3 + 4)` to be elaborated using the expected type
-- This is particularly important when there is no coercion from `Nat` to the expected type.
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType (explicit := false) (ellipsis := false)
else
-- We force TC resolution and default instances to be used.
-- Note that we do not provide the expected type to make sure it can be inferred by the TC procedure. See issue #382
let r ← elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false)
synthesizeSyntheticMVarsUsingDefault
return r
finally
modify fun s => { s with syntheticMVars := s.syntheticMVars ++ syntheticMVarsSaved }
| none => throwUnknownConstant stx[1].getId
| _ => throwUnsupportedSyntax
@[builtinTermElab forInMacro] def elabForIn : TermElab := fun stx expectedType? => do
match stx with
| `(forIn% $col $init $body) =>
@ -118,4 +83,107 @@ where
throwFailure (forInInstance : Expr) : TermElabM Expr :=
throwError "failed to synthesize instance for 'forIn%' notation{indentExpr forInInstance}"
namespace BinOp
/- Elaborator for `binop%` -/
private inductive Tree where
| term (ref : Syntax) (val : Expr)
| op (ref : Syntax) (f : Expr) (lhs rhs : Tree)
private partial def toTree (s : Syntax) : TermElabM Tree :=
withSynthesizeLight do
go (← liftMacroM <| expandMacros s)
where
go (s : Syntax) := do
match s with
| `(binop% $f $lhs $rhs) =>
let some f ← resolveId? f | throwUnknownConstant f.getId
return Tree.op s f (← go lhs) (← go rhs)
| `(($e)) => (← go e)
| _ =>
return Tree.term s (← elabTerm s none)
-- Auxiliary function used at `analyze`
private def hasCoe (fromType toType : Expr) : TermElabM Bool := do
let u ← getLevel fromType
let v ← getLevel toType
let coeInstType := mkAppN (Lean.mkConst ``CoeHTCT [u, v]) #[fromType, toType]
match ← trySynthInstance coeInstType (some (maxCoeSize.get (← getOptions))) with
| LOption.some _ => return true
| LOption.none => return false
| LOption.undef => return false -- TODO: should we do something smarter here?
private structure AnalyzeResult where
max? : Option Expr := none
hasUncomparable : Bool := false -- `true` if there are two types `α` and `β` where we don't have coercions in any direction.
private def isUnknow (e : Expr) : Bool :=
e.getAppFn.isMVar
private def analyze (t : Tree) (expectedType? : Option Expr) : TermElabM AnalyzeResult := do
let max? ←
match expectedType? with
| none => pure none
| some expectedType =>
let expectedType ← instantiateMVars expectedType
if isUnknow expectedType then pure none else pure (some expectedType)
(go t *> get).run' { max? }
where
go (t : Tree) : StateRefT AnalyzeResult TermElabM Unit := do
unless (← get).hasUncomparable do
match t with
| Tree.op _ _ lhs rhs => go lhs; go rhs
| Tree.term _ val =>
let type ← instantiateMVars (← inferType val)
unless isUnknow type do
match (← get).max? with
| none => modify fun s => { s with max? := type }
| some max =>
unless (← withNewMCtxDepth <| isDefEqGuarded max type) do
if (← hasCoe type max) then
return ()
else if (← hasCoe max type) then
modify fun s => { s with max? := type }
else
trace[Elab.binop] "uncomparable types: {max}, {type}"
modify fun s => { s with hasUncomparable := true }
private def mkOp (f : Expr) (lhs rhs : Expr) : TermElabM Expr :=
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false)
private def toExpr (t : Tree) : TermElabM Expr := do
match t with
| Tree.term _ e => return e
| Tree.op ref f lhs rhs => withRef ref <| mkOp f (← toExpr lhs) (← toExpr rhs)
private def applyCoe (t : Tree) (maxType : Expr) : TermElabM Tree := do
go t
where
go (t : Tree) : TermElabM Tree := do
match t with
| Tree.op ref f lhs rhs => return Tree.op ref f (← go lhs) (← go rhs)
| Tree.term ref e =>
let type ← inferType e
if (← isDefEqGuarded maxType type) then
return t
else
withRef ref <| return Tree.term ref (← mkCoe maxType type e)
@[builtinTermElab binop]
def elabBinOp' : TermElab := fun stx expectedType? => do
let tree ← toTree stx
let r ← analyze tree expectedType?
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
if r.hasUncomparable || r.max?.isNone then
let result ← toExpr tree
ensureHasType expectedType? result
else
let result ← toExpr (← applyCoe tree r.max?.get!)
ensureHasType expectedType? result
builtin_initialize
registerTraceClass `Elab.binop
end BinOp
end Lean.Elab.Term

View file

@ -309,13 +309,13 @@ def synthesizeSyntheticMVarsUsingDefault : TermElabM Unit := do
synthesizeSyntheticMVars (mayPostpone := true)
synthesizeUsingDefaultLoop
private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) : TermElabM α := do
private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) (synthesizeDefault : Bool) : TermElabM α := do
let syntheticMVarsSaved := (← get).syntheticMVars
modify fun s => { s with syntheticMVars := [] }
try
let a ← k
synthesizeSyntheticMVars mayPostpone
if mayPostpone then
if mayPostpone && synthesizeDefault then
synthesizeUsingDefaultLoop
return a
finally
@ -326,7 +326,11 @@ private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Boo
If `mayPostpone == false`, then all of them must be synthesized.
Remark: even if `mayPostpone == true`, the method still uses `synthesizeUsingDefault` -/
@[inline] def withSynthesize [MonadFunctorT TermElabM m] [Monad m] (k : m α) (mayPostpone := false) : m α :=
monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone) k
monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone (synthesizeDefault := true)) k
/-- Similar to `withSynthesize`, but sets `mayPostpone` to `true`, and do not use `synthesizeUsingDefault` -/
@[inline] def withSynthesizeLight [MonadFunctorT TermElabM m] [Monad m] (k : m α) : m α :=
monadMap (m := TermElabM) (withSynthesizeImp . (mayPostpone := true) (synthesizeDefault := false)) k
/-- Elaborate `stx`, and make sure all pending synthetic metavariables created while elaborating `stx` are solved. -/
def elabTermAndSynthesize (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=

View file

@ -700,6 +700,27 @@ def tryCoeThunk? (expectedType : Expr) (eType : Expr) (e : Expr) : TermElabM (Op
| _ =>
pure none
def mkCoe (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr := none) (errorMsgHeader? : Option String := none) : TermElabM Expr := do
let u ← getLevel eType
let v ← getLevel expectedType
let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType]
let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic
let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar]
let mvarId := mvar.mvarId!
try
withoutMacroStackAtErr do
if (← synthesizeCoeInstMVarCore mvarId) then
expandCoe eNew
else
-- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe`
-- after we syntheze `mvar`
let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?)
return mvarAux
catch
| Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg
| _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f?
/--
Try to apply coercion to make sure `e` has type `expectedType`.
Relevant definitions:
@ -713,26 +734,7 @@ private def tryCoe (errorMsgHeader? : Option String) (expectedType : Expr) (eTyp
return e
else match (← tryCoeThunk? expectedType eType e) with
| some r => return r
| none =>
let u ← getLevel eType
let v ← getLevel expectedType
let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType]
let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic
let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar]
let mvarId := mvar.mvarId!
try
withoutMacroStackAtErr do
if (← synthesizeCoeInstMVarCore mvarId) then
expandCoe eNew
else
-- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe`
-- after we syntheze `mvar`
let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?)
return mvarAux
catch
| Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg
| _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f?
| none => mkCoe expectedType eType e f? errorMsgHeader?
def isTypeApp? (type : Expr) : TermElabM (Option (Expr × Expr)) := do
let type ← withReducible $ whnf type

View file

@ -0,0 +1,13 @@
axiom Int.add_comm (i j : Int) : i + j = j + i
example (n : Nat) (i : Int) : n + i = i + n := by
rw [Int.add_comm]
def f1 (a : Int) (b c : Nat) : Int :=
a + (b - c)
def f2 (a : Int) (b c : Nat) : Int :=
(b - c) + a
#print f1
#print f2

View file

@ -0,0 +1,4 @@
def f1 : Int → Nat → Nat → Int :=
fun a b c => a + (Int.ofNat b - Int.ofNat c)
def f2 : Int → Nat → Nat → Int :=
fun a b c => Int.ofNat b - Int.ofNat c + a