diff --git a/src/Init/Conv.lean b/src/Init/Conv.lean index 11f5b47d06..c3b2a19067 100644 --- a/src/Init/Conv.lean +++ b/src/Init/Conv.lean @@ -26,7 +26,7 @@ syntax (name := whnf) "whnf" : conv syntax (name := congr) "congr" : conv syntax (name := arg) "arg " num : conv syntax (name := traceState) "traceState" : conv -syntax (name := funext) "funext" ident* : conv +syntax (name := funext) "funext" (colGt ident)* : conv syntax (name := change) "change " term : conv syntax (name := rewrite) "rewrite " rwRuleSeq : conv syntax (name := erewrite) "erewrite " rwRuleSeq : conv diff --git a/src/Lean/Elab/Tactic/Conv/Basic.lean b/src/Lean/Elab/Tactic/Conv/Basic.lean index ac73b3ed02..160cb0b67b 100644 --- a/src/Lean/Elab/Tactic/Conv/Basic.lean +++ b/src/Lean/Elab/Tactic/Conv/Basic.lean @@ -18,6 +18,12 @@ def mkConvGoalFor (lhs : Expr) : MetaM (Expr × Expr) := do let newGoal ← mkFreshExprSyntheticOpaqueMVar targetNew return (rhs, newGoal) +def markAsConvGoal (mvarId : MVarId) : MetaM MVarId := do + let target ← getMVarType mvarId + if isLHSGoal? target |>.isSome then + return mvarId -- it is already tagged as LHS goal + replaceTargetDefEq mvarId (mkLHSGoal (← getMVarType mvarId)) + def convert (lhs : Expr) (conv : TacticM Unit) : TacticM (Expr × Expr) := do let (rhs, newGoal) ← mkConvGoalFor lhs let savedGoals ← getGoals diff --git a/src/Lean/Elab/Tactic/Conv/Congr.lean b/src/Lean/Elab/Tactic/Conv/Congr.lean index afe7f29810..49cc07abfc 100644 --- a/src/Lean/Elab/Tactic/Conv/Congr.lean +++ b/src/Lean/Elab/Tactic/Conv/Congr.lean @@ -13,7 +13,7 @@ def congr (mvarId : MVarId) : MetaM (List MVarId) := withMVarContext mvarId do let (lhs, rhs) ← getLhsRhsCore mvarId unless lhs.isApp do - throwError "invalid 'congr' conv tactic, application expected{indentD lhs}" + throwError "invalid 'congr' conv tactic, application expected{indentExpr lhs}" lhs.withApp fun f args => do let infos := (← getFunInfoNArgs f args.size).paramInfo let mut r := { expr := f : Simp.Result } @@ -69,5 +69,26 @@ def congr (mvarId : MVarId) : MetaM (List MVarId) := throwError "invalid 'arg' conv tactic, application has only {mvarIds.length} (nondependent) arguments" | _ => throwUnsupportedSyntax +private def funextCore (mvarId : MVarId) (userName? : Option Name) : MetaM MVarId := + withMVarContext mvarId do + let (lhs, rhs) ← getLhsRhsCore mvarId + let lhsType ← whnfD (← inferType lhs) + unless lhsType.isForall do + throwError "invalid 'funext' conv tactic, function expected{indentD m!"{lhs} : {lhsType}"}" + let [mvarIdNew] ← apply mvarId (← mkConstWithFreshMVarLevels ``funext) | throwError "'apply funext' unexpected result" + let userNames := if let some userName := userName? then [userName] else [] + let (_, mvarId) ← introN mvarIdNew 1 userNames + markAsConvGoal mvarId + +private def funext (userName? : Option Name) : TacticM Unit := do + replaceMainGoal [← funextCore (← getMainGoal) userName?] + +@[builtinTactic Lean.Parser.Tactic.Conv.funext] def evalFunext : Tactic := fun stx => do + let ids := stx[1].getArgs + if ids.isEmpty then + funext none + else + for id in ids do + withRef id <| funext id.getId end Lean.Elab.Tactic.Conv diff --git a/tests/lean/conv1.lean b/tests/lean/conv1.lean index b4bd61e13e..9a02c5abac 100644 --- a/tests/lean/conv1.lean +++ b/tests/lean/conv1.lean @@ -52,3 +52,13 @@ example (x y : Nat) : f x (x + y + 0) y = y + x := by traceState rw [Nat.add_comm] rfl + +example : id (fun x y => 0 + x + y) = Nat.add := by + conv => + lhs + arg 1 + funext a b + traceState + rw [Nat.zero_add] + traceState + rfl diff --git a/tests/lean/conv1.lean.expected.out b/tests/lean/conv1.lean.expected.out index d92c074833..31e324f14d 100644 --- a/tests/lean/conv1.lean.expected.out +++ b/tests/lean/conv1.lean.expected.out @@ -8,3 +8,8 @@ x y : Nat ⊢ f x (Nat.add x y) y = y + x x y : Nat ⊢ x + y +case h.h +a b : Nat +⊢ 0 + a + b +a b : Nat +⊢ a + b diff --git a/tests/lean/run/conv1.lean b/tests/lean/run/conv1.lean index 48906252b1..5d8b41b6e1 100644 --- a/tests/lean/run/conv1.lean +++ b/tests/lean/run/conv1.lean @@ -62,3 +62,11 @@ example (h1 : x ≠ 0) (h2 : y = x / x) : y = 1 := by skip tactic => assumption assumption + +example : id (fun x => 0 + x) = id := by + conv => + lhs + arg 1 + funext y + rw [Nat.zero_add] + rfl