feat: pattern (occs := ...) conv

This commit is contained in:
Mario Carneiro 2022-09-21 01:31:31 -04:00 committed by Leonardo de Moura
parent dadfe84c15
commit 9b9998f5c8
5 changed files with 169 additions and 31 deletions

View file

@ -23,6 +23,18 @@ syntax convSeqBracketed := "{" sepByIndentSemicolon(conv) "}"
-- automatically closing goals
syntax convSeq := convSeqBracketed <|> convSeq1Indented
/-- The `*` occurrence list means to apply to all occurrences of the pattern. -/
syntax occsWildcard := "*"
/--
A list `1 2 4` of occurrences means to apply to the first, second, and fourth
occurrence of the pattern.
-/
syntax occsIndexed := num+
/-- An occurrence specification, either `*` or a list of numbers. The default is `[1]`. -/
syntax occs := atomic(" (" &"occs") " := " (occsWildcard <|> occsIndexed) ")"
/--
`conv => ...` allows the user to perform targeted rewriting on a goal or hypothesis,
by focusing on particular subexpressions.
@ -32,9 +44,9 @@ See <https://leanprover.github.io/theorem_proving_in_lean4/conv.html> for more d
Basic forms:
* `conv => cs` will rewrite the goal with conv tactics `cs`.
* `conv at h => cs` will rewrite hypothesis `h`.
* `conv in pat => cs` will rewrite the first subexpression matching `pat`.
* `conv in pat => cs` will rewrite the first subexpression matching `pat` (see `pattern`).
-/
syntax (name := conv) "conv " (" at " ident)? (" in " term)? " => " convSeq : tactic
syntax (name := conv) "conv " (" at " ident)? (" in " (occs)? term)? " => " convSeq : tactic
/-- `skip` does nothing. -/
syntax (name := skip) "skip" : conv
@ -92,8 +104,23 @@ to rewrite the target. For recursive definitions,
only one layer of unfolding is performed. -/
syntax (name := unfold) "unfold " (colGt ident)+ : conv
/-- `pattern pat` traverses to the first subterm of the target that matches `pat`. -/
syntax (name := pattern) "pattern " term : conv
/--
* `pattern pat` traverses to the first subterm of the target that matches `pat`.
* `pattern (occs := *) pat` traverses to the every subterm of the target that matches `pat`
which is not contained in another match of `pat`. It generates one subgoal for each matching
subterm.
* `pattern (occs := 1 2 4) pat` matches occurrences `1, 2, 4` of `pat` and produces three subgoals.
Occurrences are numbered left to right from the outside in.
Note that skipping an occurrence of `pat` will traverse inside that subexpression, which means
it may find more matches and this can affect the numbering of subsequent pattern matches.
For example, if we are searching for `f _` in `f (f a) = f b`:
* `occs := 1 2` (and `occs := *`) returns `| f (f a)` and `| f b`
* `occs := 2` returns `| f a`
* `occs := 2 3` returns `| f a` and `| f b`
* `occs := 1 3` is an error, because after skipping `f b` there is no third match.
-/
syntax (name := pattern) "pattern " (occs)? term : conv
/-- `rw [thm]` rewrites the target using `thm`. See the `rw` tactic for more information. -/
syntax (name := rewrite) "rewrite" (config)? rwRuleSeq : conv

View file

@ -157,8 +157,8 @@ private def convLocalDecl (conv : Syntax) (hUserName : Name) : TacticM Unit := w
@[builtinTactic Lean.Parser.Tactic.Conv.conv] def evalConv : Tactic := fun stx => do
match stx with
| `(tactic| conv%$tk $[at $loc?]? in $p =>%$arr $code) =>
evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $p; ($code:convSeq)))
| `(tactic| conv%$tk $[at $loc?]? in $(occs)? $p =>%$arr $code) =>
evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $(occs)? $p; ($code:convSeq)))
| `(tactic| conv%$tk $[at $loc?]? =>%$arr $code) =>
-- show initial conv goal state between `conv` and `=>`
withRef (mkNullNode #[tk, arr]) do

View file

@ -33,42 +33,108 @@ partial def matchPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Op
return none
withReducible <| go? e
private def pre (pattern : AbstractMVarsResult) (found? : IO.Ref (Option Expr)) (e : Expr) : SimpM Simp.Step := do
if (← found?.get).isSome then
inductive PatternMatchState where
/--
The state corresponding to a `(occs := *)` pattern, which acts like `occs := 1 2 ... n` where
`n` is the total number of pattern matches.
* `subgoals` is the list of subgoals for patterns already matched
-/
| all (subgoals : Array MVarId)
/--
The state corresponding to a partially consumed `(occs := a₁ a₂ ...)` pattern.
* `subgoals` is the list of subgoals for patterns already matched,
along with their index in the original occs list
* `idx` is the number of matches that have occurred so far
* `remaining` is a list of `(i, orig)` pairs representing matches we have not yet reached.
We maintain the invariant that `idx :: remaining.map (·.1)` is sorted.
The number `i` is the value in the `occs` list and `orig` is its index in the list.
-/
| occs (subgoals : Array (Nat × MVarId)) (idx : Nat) (remaining : List (Nat × Nat))
namespace PatternMatchState
/-- Is this pattern no longer interested in accepting matches? -/
def isDone : PatternMatchState → Bool
| .all _ => false
| .occs _ _ remaining => remaining.isEmpty
/-- Is this pattern interested in accepting the next match? -/
def isReady : PatternMatchState → Bool
| .all _ => true
| .occs _ idx ((i, _) :: _) => idx == i
| _ => false
/-- Assuming `isReady` returned false, this advances to the next match. -/
def skip : PatternMatchState → PatternMatchState
| .occs subgoals idx remaining => .occs subgoals (idx + 1) remaining
| s => s
/--
Assuming `isReady` returned true, this adds the generated subgoal to the list
and advances to the next match.
-/
def accept (mvarId : MVarId) : PatternMatchState → PatternMatchState
| .all subgoals => .all (subgoals.push mvarId)
| .occs subgoals idx ((_, n) :: remaining) => .occs (subgoals.push (n, mvarId)) (idx + 1) remaining
| s => s
end PatternMatchState
private def pre (pattern : AbstractMVarsResult) (state : IO.Ref PatternMatchState) (e : Expr) : SimpM Simp.Step := do
if (← state.get).isDone then
return Simp.Step.visit { expr := e }
else if let some (e, extraArgs) ← matchPattern? pattern e then
let (rhs, newGoal) ← mkConvGoalFor e
found?.set newGoal
let mut proof := newGoal
for extraArg in extraArgs do
proof ← mkCongrFun proof extraArg
return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof }
if (← state.get).isReady then
let (rhs, newGoal) ← mkConvGoalFor e
state.modify (·.accept newGoal.mvarId!)
let mut proof := newGoal
for extraArg in extraArgs do
proof ← mkCongrFun proof extraArg
return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof }
else
state.modify (·.skip)
-- Note that because we return `visit` here and `done` in the other case,
-- it is possible for skipping an earlier match to affect what later matches
-- refer to. For example, matching `f _` in `f (f a) = f b` with occs `[1, 2]`
-- yields `[f (f a), f b]`, but `[2, 3]` yields `[f a, f b]`, and `[1, 3]` is an error.
return Simp.Step.visit { expr := e }
else
return Simp.Step.visit { expr := e }
private def findPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Option (MVarId × Simp.Result)) := do
let found? ← IO.mkRef none
let (result, _) ← Simp.main e (← getContext) (methods := { pre := pre pattern found? })
if let some newGoal ← found?.get then
return some (newGoal.mvarId!, result)
else
return none
@[builtinTactic Lean.Parser.Tactic.Conv.pattern] def evalPattern : Tactic := fun stx => withMainContext do
match stx with
| `(conv| pattern $p) =>
| `(conv| pattern $[(occs := $occs)]? $p) =>
let patternA ←
withTheReader Term.Context (fun ctx => { ctx with ignoreTCFailures := true }) <|
Term.withoutModifyingElabMetaStateWithInfo <| withRef p <|
Term.withoutErrToSorry do
abstractMVars (← Term.elabTerm p none)
let lhs ← getLhs
match (← findPattern? patternA lhs) with
| none => throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}"
| some (mvarId', result) =>
updateLhs result.expr (← result.getProof)
(← getMainGoal).refl
replaceMainGoal [mvarId']
let occs ← match occs with
| none => pure (.occs #[] 0 [(0, 0)])
| some occs => match occs with
| `(Parser.Tactic.Conv.occsWildcard| *) => pure (.all #[])
| `(Parser.Tactic.Conv.occsIndexed| $ids*) => do
let ids ← ids.mapIdxM fun i id =>
match id.getNat with
| 0 => throwErrorAt id "positive integer expected"
| n+1 => pure (n, i.1)
let ids := ids.qsort (·.1 < ·.1)
unless @Array.allDiff _ ⟨(·.1 == ·.1)⟩ ids do
throwError "occurrence list is not distinct"
pure (.occs #[] 0 ids.toList)
| _ => throwUnsupportedSyntax
let state ← IO.mkRef occs
let (result, _) ← Simp.main lhs (← getContext) (methods := { pre := pre patternA state })
let subgoals ← match ← state.get with
| .all #[] | .occs _ 0 _ =>
throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}"
| .all subgoals => pure subgoals
| .occs subgoals idx remaining =>
if let some (i, _) := remaining.getLast? then
throwError "'pattern' conv tactic failed, pattern was found only {idx} times but {i+1} expected"
pure <| (subgoals.qsort (·.1 < ·.1)).map (·.2)
(← getRhs).mvarId!.assign result.expr
(← getMainGoal).assign (← result.getProof)
replaceMainGoal subgoals.toList
| _ => throwUnsupportedSyntax
end Lean.Elab.Tactic.Conv

View file

@ -180,3 +180,17 @@ example : let a := 0; let b := a; b = 0 := by
conv =>
zeta
trace_state
example : ((x + y) + z : Nat) = x + (y + z) := by
conv in _ + _ => trace_state
conv in (occs := *) _ + _ => trace_state
conv in (occs := 1 3) _ + _ => trace_state
conv in (occs := 3 1) _ + _ => trace_state
conv in (occs := 2 3) _ + _ => trace_state
conv in (occs := 2 4) _ + _ => trace_state
apply Nat.add_assoc
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 5) _ + _
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 2 5) _ + _
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 5) _ + _
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 2 5) _ + _

View file

@ -85,3 +85,34 @@ conv1.lean:175:10-175:15: error: cannot select argument
a✝ : Nat := 0
b✝ : Nat := a✝
| 0 = 0
x y z : Nat
| x + y + z
x y z : Nat
| x + y + z
x y z : Nat
| x + (y + z)
x y z : Nat
| x + y + z
x y z : Nat
| y + z
x y z : Nat
| y + z
x y z : Nat
| x + y + z
x y z : Nat
| x + y
x y z : Nat
| x + (y + z)
x y z : Nat
| x + y
x y z : Nat
| y + z
conv1.lean:193:58-193:83: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected
conv1.lean:194:58-194:85: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected
conv1.lean:195:58-195:85: error: 'pattern' conv tactic failed, pattern was found only 3 times but 5 expected
conv1.lean:196:58-196:87: error: 'pattern' conv tactic failed, pattern was found only 2 times but 5 expected