From 34945dfc1cf5299aa34009038d1c36abcbaae167 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 22 Oct 2020 10:14:50 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20elaborate=20`=E2=96=B8`=20notation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Lean/Elab/BuiltinNotation.lean | 46 +++++++++++++++++++++++++++--- tests/lean/run/subst.lean | 39 +++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 tests/lean/run/subst.lean diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index d6ea331674..9c633a570a 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura -/ import Init.Data.ToString import Lean.Compiler.BorrowedAnnotation +import Lean.Meta.KAbstract import Lean.Elab.Term import Lean.Elab.Quotation import Lean.Elab.SyntheticMVars @@ -320,9 +321,46 @@ private def elabCDot (stx : Syntax) (expectedType? : Option Expr) : TermElabM Ex withMacroExpansion stx pairs (elabTerm pairs expectedType?) | _ => throwError "unexpected parentheses notation" -/- -TODO -@[builtinTermElab] def elabsubst : TermElab := expandInfixOp infixR " ▸ " 75 --/ +@[builtinTermElab subst] def elabSubst : TermElab := fun stx expectedType? => do + tryPostponeIfNoneOrMVar expectedType? + let some expectedType ← pure expectedType? | + throwError! "invalid `▸` notation, expected type must be known" + let expectedType ← instantiateMVars expectedType + if expectedType.hasExprMVar then + throwError! "invalid `▸` notation, expected type contains metavariables{indentExpr expectedType}" + match_syntax stx with + | `($heq ▸ $h) => do + let heq ← elabTerm heq none + let heqType ← inferType heq + match (← Meta.matchEq? heqType) with + | none => throwError! "invalid `▸` notation, argument{indentExpr heq}\nhas type{indentExpr heqType}\nequality expected" + | some (α, lhs, rhs) => + let mkMotive (typeWithLooseBVar : Expr) := + withLocalDeclD (← mkFreshUserName `x) α fun x => do + mkLambdaFVars #[x] $ typeWithLooseBVar.instantiate1 x + let expectedAbst ← kabstract expectedType rhs + unless expectedAbst.hasLooseBVars do + expectedAbst ← kabstract expectedType lhs + unless expectedAbst.hasLooseBVars do + throwError! "invalid `▸` notation, expected type{indentExpr expectedType}\ndoes contain equation left-hand-side nor right-hand-side{indentExpr heqType}" + heq ← mkEqSymm heq + (lhs, rhs) := (rhs, lhs) + let hExpectedType := expectedAbst.instantiate1 lhs + let h ← withRef h do + let h ← elabTerm h hExpectedType + try + ensureHasType hExpectedType h + catch ex => + -- if `rhs` occurs in `hType`, we try to apply `heq` to `h` too + let hType ← inferType h + let hTypeAbst ← kabstract hType rhs + unless hTypeAbst.hasLooseBVars do + throw ex + let hTypeNew := hTypeAbst.instantiate1 lhs + unless (← isDefEq hExpectedType hTypeNew) do + throw ex + mkEqNDRec (← mkMotive hTypeAbst) h (← mkEqSymm heq) + mkEqNDRec (← mkMotive expectedAbst) h heq + | _ => throwUnsupportedSyntax end Lean.Elab.Term diff --git a/tests/lean/run/subst.lean b/tests/lean/run/subst.lean new file mode 100644 index 0000000000..503571322c --- /dev/null +++ b/tests/lean/run/subst.lean @@ -0,0 +1,39 @@ +#lang lean4 + +universes u + +def f1 (n m : Nat) (x : Fin n) (h : n = m) : Fin m := +h ▸ x + +def f2 (n m : Nat) (x : Fin n) (h : m = n) : Fin m := +h ▸ x + +theorem ex1 {α : Sort u} {a b c : α} (h₁ : a = b) (h₂ : b = c) : a = c := +h₂ ▸ h₁ + +theorem ex2 {α : Sort u} {a b : α} (h : a = b) : b = a := +h ▸ rfl + +theorem ex3 {α : Sort u} {a b c : α} (r : α → α → Prop) (h₁ : r a b) (h₂ : b = c) : r a c := +h₂ ▸ h₁ + +theorem ex4 {α : Sort u} {a b c : α} (r : α → α → Prop) (h₁ : a = b) (h₂ : r b c) : r a c := +h₁ ▸ h₂ + +theorem ex5 {p : Prop} (h : p = True) : p := +h ▸ trivial + +theorem ex6 {p : Prop} (h : p = False) : ¬p := +fun hp => h ▸ hp + +theorem ex7 {α} {a b c d : α} (h₁ : a = c) (h₂ : b = d) (h₃ : c ≠ d) : a ≠ b := +h₁ ▸ h₂ ▸ h₃ + +theorem ex8 (n m k : Nat) (h : Nat.succ n + m = Nat.succ n + k) : Nat.succ (n + m) = Nat.succ (n + k) := +Nat.succAdd .. ▸ Nat.succAdd .. ▸ h + +theorem ex9 (a b : Nat) (h₁ : a = a + b) (h₂ : a = b) : a = b + a := +h₂ ▸ h₁ + +theorem ex10 (a b : Nat) (h : a = b) : b = a := +h ▸ rfl