From d3d03df83c2b40649bdc65917ed5f3932b01857b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 13 Aug 2021 15:44:04 -0700 Subject: [PATCH] feat: new elaborator for `binop%` cc @gebner. --- src/Lean/Elab/Extra.lean | 138 +++++++++++++++++------ src/Lean/Elab/SyntheticMVars.lean | 10 +- src/Lean/Elab/Term.lean | 42 +++---- tests/lean/binopIssues.lean | 13 +++ tests/lean/binopIssues.lean.expected.out | 4 + 5 files changed, 149 insertions(+), 58 deletions(-) create mode 100644 tests/lean/binopIssues.lean create mode 100644 tests/lean/binopIssues.lean.expected.out diff --git a/src/Lean/Elab/Extra.lean b/src/Lean/Elab/Extra.lean index e18968d9f9..15cda8c4fd 100644 --- a/src/Lean/Elab/Extra.lean +++ b/src/Lean/Elab/Extra.lean @@ -43,41 +43,6 @@ open Meta elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false) | none => throwUnknownConstant stx[1].getId --- TODO: move to another file? -private def hasUnknownType (e : Expr) : MetaM Bool := - return (← inferType e).getAppFn.isMVar - -@[builtinTermElab binop] def elabBinOp : TermElab := fun stx expectedType? => do - match stx with - | `(binop% $f $lhs $rhs) => - match expectedType? with - | none => - -- We elaborate as a normal application when expected type is not available - let stxNew ← `($f:ident $lhs $rhs) - withMacroExpansion stx stxNew <| elabTerm stxNew none - | some expectedType => - match (← resolveId? f) with - | some f => - let syntheticMVarsSaved := (← get).syntheticMVars - modify fun s => { s with syntheticMVars := [] } - try - let lhs ← elabTerm lhs none - let rhs ← elabTerm rhs none - if (← hasUnknownType lhs) && (← hasUnknownType rhs) then - -- We want the numerals in terms such as `(1 + 1)` `(2 * 3 + 4)` to be elaborated using the expected type - -- This is particularly important when there is no coercion from `Nat` to the expected type. - elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType (explicit := false) (ellipsis := false) - else - -- We force TC resolution and default instances to be used. - -- Note that we do not provide the expected type to make sure it can be inferred by the TC procedure. See issue #382 - let r ← elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false) - synthesizeSyntheticMVarsUsingDefault - return r - finally - modify fun s => { s with syntheticMVars := s.syntheticMVars ++ syntheticMVarsSaved } - | none => throwUnknownConstant stx[1].getId - | _ => throwUnsupportedSyntax - @[builtinTermElab forInMacro] def elabForIn : TermElab := fun stx expectedType? => do match stx with | `(forIn% $col $init $body) => @@ -118,4 +83,107 @@ where throwFailure (forInInstance : Expr) : TermElabM Expr := throwError "failed to synthesize instance for 'forIn%' notation{indentExpr forInInstance}" +namespace BinOp +/- Elaborator for `binop%` -/ + +private inductive Tree where + | term (ref : Syntax) (val : Expr) + | op (ref : Syntax) (f : Expr) (lhs rhs : Tree) + +private partial def toTree (s : Syntax) : TermElabM Tree := + withSynthesizeLight do + go (← liftMacroM <| expandMacros s) +where + go (s : Syntax) := do + match s with + | `(binop% $f $lhs $rhs) => + let some f ← resolveId? f | throwUnknownConstant f.getId + return Tree.op s f (← go lhs) (← go rhs) + | `(($e)) => (← go e) + | _ => + return Tree.term s (← elabTerm s none) + +-- Auxiliary function used at `analyze` +private def hasCoe (fromType toType : Expr) : TermElabM Bool := do + let u ← getLevel fromType + let v ← getLevel toType + let coeInstType := mkAppN (Lean.mkConst ``CoeHTCT [u, v]) #[fromType, toType] + match ← trySynthInstance coeInstType (some (maxCoeSize.get (← getOptions))) with + | LOption.some _ => return true + | LOption.none => return false + | LOption.undef => return false -- TODO: should we do something smarter here? + +private structure AnalyzeResult where + max? : Option Expr := none + hasUncomparable : Bool := false -- `true` if there are two types `α` and `β` where we don't have coercions in any direction. + +private def isUnknow (e : Expr) : Bool := + e.getAppFn.isMVar + +private def analyze (t : Tree) (expectedType? : Option Expr) : TermElabM AnalyzeResult := do + let max? ← + match expectedType? with + | none => pure none + | some expectedType => + let expectedType ← instantiateMVars expectedType + if isUnknow expectedType then pure none else pure (some expectedType) + (go t *> get).run' { max? } +where + go (t : Tree) : StateRefT AnalyzeResult TermElabM Unit := do + unless (← get).hasUncomparable do + match t with + | Tree.op _ _ lhs rhs => go lhs; go rhs + | Tree.term _ val => + let type ← instantiateMVars (← inferType val) + unless isUnknow type do + match (← get).max? with + | none => modify fun s => { s with max? := type } + | some max => + unless (← withNewMCtxDepth <| isDefEqGuarded max type) do + if (← hasCoe type max) then + return () + else if (← hasCoe max type) then + modify fun s => { s with max? := type } + else + trace[Elab.binop] "uncomparable types: {max}, {type}" + modify fun s => { s with hasUncomparable := true } + +private def mkOp (f : Expr) (lhs rhs : Expr) : TermElabM Expr := + elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] (expectedType? := none) (explicit := false) (ellipsis := false) + +private def toExpr (t : Tree) : TermElabM Expr := do + match t with + | Tree.term _ e => return e + | Tree.op ref f lhs rhs => withRef ref <| mkOp f (← toExpr lhs) (← toExpr rhs) + +private def applyCoe (t : Tree) (maxType : Expr) : TermElabM Tree := do + go t +where + go (t : Tree) : TermElabM Tree := do + match t with + | Tree.op ref f lhs rhs => return Tree.op ref f (← go lhs) (← go rhs) + | Tree.term ref e => + let type ← inferType e + if (← isDefEqGuarded maxType type) then + return t + else + withRef ref <| return Tree.term ref (← mkCoe maxType type e) + +@[builtinTermElab binop] +def elabBinOp' : TermElab := fun stx expectedType? => do + let tree ← toTree stx + let r ← analyze tree expectedType? + trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}" + if r.hasUncomparable || r.max?.isNone then + let result ← toExpr tree + ensureHasType expectedType? result + else + let result ← toExpr (← applyCoe tree r.max?.get!) + ensureHasType expectedType? result + +builtin_initialize + registerTraceClass `Elab.binop + +end BinOp + end Lean.Elab.Term diff --git a/src/Lean/Elab/SyntheticMVars.lean b/src/Lean/Elab/SyntheticMVars.lean index fc96c7c58e..95a5f24e54 100644 --- a/src/Lean/Elab/SyntheticMVars.lean +++ b/src/Lean/Elab/SyntheticMVars.lean @@ -309,13 +309,13 @@ def synthesizeSyntheticMVarsUsingDefault : TermElabM Unit := do synthesizeSyntheticMVars (mayPostpone := true) synthesizeUsingDefaultLoop -private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) : TermElabM α := do +private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Bool) (synthesizeDefault : Bool) : TermElabM α := do let syntheticMVarsSaved := (← get).syntheticMVars modify fun s => { s with syntheticMVars := [] } try let a ← k synthesizeSyntheticMVars mayPostpone - if mayPostpone then + if mayPostpone && synthesizeDefault then synthesizeUsingDefaultLoop return a finally @@ -326,7 +326,11 @@ private partial def withSynthesizeImp {α} (k : TermElabM α) (mayPostpone : Boo If `mayPostpone == false`, then all of them must be synthesized. Remark: even if `mayPostpone == true`, the method still uses `synthesizeUsingDefault` -/ @[inline] def withSynthesize [MonadFunctorT TermElabM m] [Monad m] (k : m α) (mayPostpone := false) : m α := - monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone) k + monadMap (m := TermElabM) (withSynthesizeImp . mayPostpone (synthesizeDefault := true)) k + +/-- Similar to `withSynthesize`, but sets `mayPostpone` to `true`, and do not use `synthesizeUsingDefault` -/ +@[inline] def withSynthesizeLight [MonadFunctorT TermElabM m] [Monad m] (k : m α) : m α := + monadMap (m := TermElabM) (withSynthesizeImp . (mayPostpone := true) (synthesizeDefault := false)) k /-- Elaborate `stx`, and make sure all pending synthetic metavariables created while elaborating `stx` are solved. -/ def elabTermAndSynthesize (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index 9cc0d302f9..95d1dcdd91 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -700,6 +700,27 @@ def tryCoeThunk? (expectedType : Expr) (eType : Expr) (e : Expr) : TermElabM (Op | _ => pure none +def mkCoe (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr := none) (errorMsgHeader? : Option String := none) : TermElabM Expr := do + let u ← getLevel eType + let v ← getLevel expectedType + let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType] + let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic + let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar] + let mvarId := mvar.mvarId! + try + withoutMacroStackAtErr do + if (← synthesizeCoeInstMVarCore mvarId) then + expandCoe eNew + else + -- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe` + -- after we syntheze `mvar` + let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque + registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?) + return mvarAux + catch + | Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg + | _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f? + /-- Try to apply coercion to make sure `e` has type `expectedType`. Relevant definitions: @@ -713,26 +734,7 @@ private def tryCoe (errorMsgHeader? : Option String) (expectedType : Expr) (eTyp return e else match (← tryCoeThunk? expectedType eType e) with | some r => return r - | none => - let u ← getLevel eType - let v ← getLevel expectedType - let coeTInstType := mkAppN (mkConst ``CoeT [u, v]) #[eType, e, expectedType] - let mvar ← mkFreshExprMVar coeTInstType MetavarKind.synthetic - let eNew := mkAppN (mkConst ``coe [u, v]) #[eType, expectedType, e, mvar] - let mvarId := mvar.mvarId! - try - withoutMacroStackAtErr do - if (← synthesizeCoeInstMVarCore mvarId) then - expandCoe eNew - else - -- We create an auxiliary metavariable to represent the result, because we need to execute `expandCoe` - -- after we syntheze `mvar` - let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque - registerSyntheticMVarWithCurrRef mvarAux.mvarId! (SyntheticMVarKind.coe errorMsgHeader? eNew expectedType eType e f?) - return mvarAux - catch - | Exception.error _ msg => throwTypeMismatchError errorMsgHeader? expectedType eType e f? msg - | _ => throwTypeMismatchError errorMsgHeader? expectedType eType e f? + | none => mkCoe expectedType eType e f? errorMsgHeader? def isTypeApp? (type : Expr) : TermElabM (Option (Expr × Expr)) := do let type ← withReducible $ whnf type diff --git a/tests/lean/binopIssues.lean b/tests/lean/binopIssues.lean new file mode 100644 index 0000000000..200dd92bc3 --- /dev/null +++ b/tests/lean/binopIssues.lean @@ -0,0 +1,13 @@ +axiom Int.add_comm (i j : Int) : i + j = j + i + +example (n : Nat) (i : Int) : n + i = i + n := by + rw [Int.add_comm] + +def f1 (a : Int) (b c : Nat) : Int := + a + (b - c) + +def f2 (a : Int) (b c : Nat) : Int := + (b - c) + a + +#print f1 +#print f2 diff --git a/tests/lean/binopIssues.lean.expected.out b/tests/lean/binopIssues.lean.expected.out new file mode 100644 index 0000000000..f46005b465 --- /dev/null +++ b/tests/lean/binopIssues.lean.expected.out @@ -0,0 +1,4 @@ +def f1 : Int → Nat → Nat → Int := +fun a b c => a + (Int.ofNat b - Int.ofNat c) +def f2 : Int → Nat → Nat → Int := +fun a b c => Int.ofNat b - Int.ofNat c + a