From 9b9998f5c84dcc2a4ebe8bb8b6a146b3ada6720b Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Wed, 21 Sep 2022 01:31:31 -0400 Subject: [PATCH] feat: `pattern (occs := ...)` conv --- src/Init/Conv.lean | 35 +++++++- src/Lean/Elab/Tactic/Conv/Basic.lean | 4 +- src/Lean/Elab/Tactic/Conv/Pattern.lean | 116 +++++++++++++++++++------ tests/lean/conv1.lean | 14 +++ tests/lean/conv1.lean.expected.out | 31 +++++++ 5 files changed, 169 insertions(+), 31 deletions(-) diff --git a/src/Init/Conv.lean b/src/Init/Conv.lean index 2ad1f1c1d4..6f0a7c943a 100644 --- a/src/Init/Conv.lean +++ b/src/Init/Conv.lean @@ -23,6 +23,18 @@ syntax convSeqBracketed := "{" sepByIndentSemicolon(conv) "}" -- automatically closing goals syntax convSeq := convSeqBracketed <|> convSeq1Indented +/-- The `*` occurrence list means to apply to all occurrences of the pattern. -/ +syntax occsWildcard := "*" + +/-- +A list `1 2 4` of occurrences means to apply to the first, second, and fourth +occurrence of the pattern. +-/ +syntax occsIndexed := num+ + +/-- An occurrence specification, either `*` or a list of numbers. The default is `[1]`. -/ +syntax occs := atomic(" (" &"occs") " := " (occsWildcard <|> occsIndexed) ")" + /-- `conv => ...` allows the user to perform targeted rewriting on a goal or hypothesis, by focusing on particular subexpressions. @@ -32,9 +44,9 @@ See for more d Basic forms: * `conv => cs` will rewrite the goal with conv tactics `cs`. * `conv at h => cs` will rewrite hypothesis `h`. -* `conv in pat => cs` will rewrite the first subexpression matching `pat`. +* `conv in pat => cs` will rewrite the first subexpression matching `pat` (see `pattern`). -/ -syntax (name := conv) "conv " (" at " ident)? (" in " term)? " => " convSeq : tactic +syntax (name := conv) "conv " (" at " ident)? (" in " (occs)? term)? " => " convSeq : tactic /-- `skip` does nothing. -/ syntax (name := skip) "skip" : conv @@ -92,8 +104,23 @@ to rewrite the target. For recursive definitions, only one layer of unfolding is performed. -/ syntax (name := unfold) "unfold " (colGt ident)+ : conv -/-- `pattern pat` traverses to the first subterm of the target that matches `pat`. -/ -syntax (name := pattern) "pattern " term : conv +/-- +* `pattern pat` traverses to the first subterm of the target that matches `pat`. +* `pattern (occs := *) pat` traverses to the every subterm of the target that matches `pat` + which is not contained in another match of `pat`. It generates one subgoal for each matching + subterm. +* `pattern (occs := 1 2 4) pat` matches occurrences `1, 2, 4` of `pat` and produces three subgoals. + Occurrences are numbered left to right from the outside in. + +Note that skipping an occurrence of `pat` will traverse inside that subexpression, which means +it may find more matches and this can affect the numbering of subsequent pattern matches. +For example, if we are searching for `f _` in `f (f a) = f b`: +* `occs := 1 2` (and `occs := *`) returns `| f (f a)` and `| f b` +* `occs := 2` returns `| f a` +* `occs := 2 3` returns `| f a` and `| f b` +* `occs := 1 3` is an error, because after skipping `f b` there is no third match. +-/ +syntax (name := pattern) "pattern " (occs)? term : conv /-- `rw [thm]` rewrites the target using `thm`. See the `rw` tactic for more information. -/ syntax (name := rewrite) "rewrite" (config)? rwRuleSeq : conv diff --git a/src/Lean/Elab/Tactic/Conv/Basic.lean b/src/Lean/Elab/Tactic/Conv/Basic.lean index a9fdd80b41..30efd06af7 100644 --- a/src/Lean/Elab/Tactic/Conv/Basic.lean +++ b/src/Lean/Elab/Tactic/Conv/Basic.lean @@ -157,8 +157,8 @@ private def convLocalDecl (conv : Syntax) (hUserName : Name) : TacticM Unit := w @[builtinTactic Lean.Parser.Tactic.Conv.conv] def evalConv : Tactic := fun stx => do match stx with - | `(tactic| conv%$tk $[at $loc?]? in $p =>%$arr $code) => - evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $p; ($code:convSeq))) + | `(tactic| conv%$tk $[at $loc?]? in $(occs)? $p =>%$arr $code) => + evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $(occs)? $p; ($code:convSeq))) | `(tactic| conv%$tk $[at $loc?]? =>%$arr $code) => -- show initial conv goal state between `conv` and `=>` withRef (mkNullNode #[tk, arr]) do diff --git a/src/Lean/Elab/Tactic/Conv/Pattern.lean b/src/Lean/Elab/Tactic/Conv/Pattern.lean index e32d8cfc3f..6278c3a4a1 100644 --- a/src/Lean/Elab/Tactic/Conv/Pattern.lean +++ b/src/Lean/Elab/Tactic/Conv/Pattern.lean @@ -33,42 +33,108 @@ partial def matchPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Op return none withReducible <| go? e -private def pre (pattern : AbstractMVarsResult) (found? : IO.Ref (Option Expr)) (e : Expr) : SimpM Simp.Step := do - if (← found?.get).isSome then +inductive PatternMatchState where + /-- + The state corresponding to a `(occs := *)` pattern, which acts like `occs := 1 2 ... n` where + `n` is the total number of pattern matches. + * `subgoals` is the list of subgoals for patterns already matched + -/ + | all (subgoals : Array MVarId) + /-- + The state corresponding to a partially consumed `(occs := a₁ a₂ ...)` pattern. + * `subgoals` is the list of subgoals for patterns already matched, + along with their index in the original occs list + * `idx` is the number of matches that have occurred so far + * `remaining` is a list of `(i, orig)` pairs representing matches we have not yet reached. + We maintain the invariant that `idx :: remaining.map (·.1)` is sorted. + The number `i` is the value in the `occs` list and `orig` is its index in the list. + -/ + | occs (subgoals : Array (Nat × MVarId)) (idx : Nat) (remaining : List (Nat × Nat)) + +namespace PatternMatchState + +/-- Is this pattern no longer interested in accepting matches? -/ +def isDone : PatternMatchState → Bool + | .all _ => false + | .occs _ _ remaining => remaining.isEmpty + +/-- Is this pattern interested in accepting the next match? -/ +def isReady : PatternMatchState → Bool + | .all _ => true + | .occs _ idx ((i, _) :: _) => idx == i + | _ => false + +/-- Assuming `isReady` returned false, this advances to the next match. -/ +def skip : PatternMatchState → PatternMatchState + | .occs subgoals idx remaining => .occs subgoals (idx + 1) remaining + | s => s + +/-- +Assuming `isReady` returned true, this adds the generated subgoal to the list +and advances to the next match. +-/ +def accept (mvarId : MVarId) : PatternMatchState → PatternMatchState + | .all subgoals => .all (subgoals.push mvarId) + | .occs subgoals idx ((_, n) :: remaining) => .occs (subgoals.push (n, mvarId)) (idx + 1) remaining + | s => s + +end PatternMatchState + +private def pre (pattern : AbstractMVarsResult) (state : IO.Ref PatternMatchState) (e : Expr) : SimpM Simp.Step := do + if (← state.get).isDone then return Simp.Step.visit { expr := e } else if let some (e, extraArgs) ← matchPattern? pattern e then - let (rhs, newGoal) ← mkConvGoalFor e - found?.set newGoal - let mut proof := newGoal - for extraArg in extraArgs do - proof ← mkCongrFun proof extraArg - return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof } + if (← state.get).isReady then + let (rhs, newGoal) ← mkConvGoalFor e + state.modify (·.accept newGoal.mvarId!) + let mut proof := newGoal + for extraArg in extraArgs do + proof ← mkCongrFun proof extraArg + return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof } + else + state.modify (·.skip) + -- Note that because we return `visit` here and `done` in the other case, + -- it is possible for skipping an earlier match to affect what later matches + -- refer to. For example, matching `f _` in `f (f a) = f b` with occs `[1, 2]` + -- yields `[f (f a), f b]`, but `[2, 3]` yields `[f a, f b]`, and `[1, 3]` is an error. + return Simp.Step.visit { expr := e } else return Simp.Step.visit { expr := e } -private def findPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Option (MVarId × Simp.Result)) := do - let found? ← IO.mkRef none - let (result, _) ← Simp.main e (← getContext) (methods := { pre := pre pattern found? }) - if let some newGoal ← found?.get then - return some (newGoal.mvarId!, result) - else - return none - @[builtinTactic Lean.Parser.Tactic.Conv.pattern] def evalPattern : Tactic := fun stx => withMainContext do match stx with - | `(conv| pattern $p) => + | `(conv| pattern $[(occs := $occs)]? $p) => let patternA ← withTheReader Term.Context (fun ctx => { ctx with ignoreTCFailures := true }) <| Term.withoutModifyingElabMetaStateWithInfo <| withRef p <| Term.withoutErrToSorry do abstractMVars (← Term.elabTerm p none) let lhs ← getLhs - match (← findPattern? patternA lhs) with - | none => throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}" - | some (mvarId', result) => - updateLhs result.expr (← result.getProof) - (← getMainGoal).refl - replaceMainGoal [mvarId'] + let occs ← match occs with + | none => pure (.occs #[] 0 [(0, 0)]) + | some occs => match occs with + | `(Parser.Tactic.Conv.occsWildcard| *) => pure (.all #[]) + | `(Parser.Tactic.Conv.occsIndexed| $ids*) => do + let ids ← ids.mapIdxM fun i id => + match id.getNat with + | 0 => throwErrorAt id "positive integer expected" + | n+1 => pure (n, i.1) + let ids := ids.qsort (·.1 < ·.1) + unless @Array.allDiff _ ⟨(·.1 == ·.1)⟩ ids do + throwError "occurrence list is not distinct" + pure (.occs #[] 0 ids.toList) + | _ => throwUnsupportedSyntax + let state ← IO.mkRef occs + let (result, _) ← Simp.main lhs (← getContext) (methods := { pre := pre patternA state }) + let subgoals ← match ← state.get with + | .all #[] | .occs _ 0 _ => + throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}" + | .all subgoals => pure subgoals + | .occs subgoals idx remaining => + if let some (i, _) := remaining.getLast? then + throwError "'pattern' conv tactic failed, pattern was found only {idx} times but {i+1} expected" + pure <| (subgoals.qsort (·.1 < ·.1)).map (·.2) + (← getRhs).mvarId!.assign result.expr + (← getMainGoal).assign (← result.getProof) + replaceMainGoal subgoals.toList | _ => throwUnsupportedSyntax - -end Lean.Elab.Tactic.Conv diff --git a/tests/lean/conv1.lean b/tests/lean/conv1.lean index 1709530fad..6d12799fb6 100644 --- a/tests/lean/conv1.lean +++ b/tests/lean/conv1.lean @@ -180,3 +180,17 @@ example : let a := 0; let b := a; b = 0 := by conv => zeta trace_state + +example : ((x + y) + z : Nat) = x + (y + z) := by + conv in _ + _ => trace_state + conv in (occs := *) _ + _ => trace_state + conv in (occs := 1 3) _ + _ => trace_state + conv in (occs := 3 1) _ + _ => trace_state + conv in (occs := 2 3) _ + _ => trace_state + conv in (occs := 2 4) _ + _ => trace_state + apply Nat.add_assoc + +example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 5) _ + _ +example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 2 5) _ + _ +example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 5) _ + _ +example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 2 5) _ + _ diff --git a/tests/lean/conv1.lean.expected.out b/tests/lean/conv1.lean.expected.out index c40389ceb3..631bb14120 100644 --- a/tests/lean/conv1.lean.expected.out +++ b/tests/lean/conv1.lean.expected.out @@ -85,3 +85,34 @@ conv1.lean:175:10-175:15: error: cannot select argument a✝ : Nat := 0 b✝ : Nat := a✝ | 0 = 0 +x y z : Nat +| x + y + z +x y z : Nat +| x + y + z + +x y z : Nat +| x + (y + z) +x y z : Nat +| x + y + z + +x y z : Nat +| y + z +x y z : Nat +| y + z + +x y z : Nat +| x + y + z +x y z : Nat +| x + y + +x y z : Nat +| x + (y + z) +x y z : Nat +| x + y + +x y z : Nat +| y + z +conv1.lean:193:58-193:83: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected +conv1.lean:194:58-194:85: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected +conv1.lean:195:58-195:85: error: 'pattern' conv tactic failed, pattern was found only 3 times but 5 expected +conv1.lean:196:58-196:87: error: 'pattern' conv tactic failed, pattern was found only 2 times but 5 expected