feat: pattern (occs := ...) conv
This commit is contained in:
parent
dadfe84c15
commit
9b9998f5c8
5 changed files with 169 additions and 31 deletions
|
|
@ -23,6 +23,18 @@ syntax convSeqBracketed := "{" sepByIndentSemicolon(conv) "}"
|
|||
-- automatically closing goals
|
||||
syntax convSeq := convSeqBracketed <|> convSeq1Indented
|
||||
|
||||
/-- The `*` occurrence list means to apply to all occurrences of the pattern. -/
|
||||
syntax occsWildcard := "*"
|
||||
|
||||
/--
|
||||
A list `1 2 4` of occurrences means to apply to the first, second, and fourth
|
||||
occurrence of the pattern.
|
||||
-/
|
||||
syntax occsIndexed := num+
|
||||
|
||||
/-- An occurrence specification, either `*` or a list of numbers. The default is `[1]`. -/
|
||||
syntax occs := atomic(" (" &"occs") " := " (occsWildcard <|> occsIndexed) ")"
|
||||
|
||||
/--
|
||||
`conv => ...` allows the user to perform targeted rewriting on a goal or hypothesis,
|
||||
by focusing on particular subexpressions.
|
||||
|
|
@ -32,9 +44,9 @@ See <https://leanprover.github.io/theorem_proving_in_lean4/conv.html> for more d
|
|||
Basic forms:
|
||||
* `conv => cs` will rewrite the goal with conv tactics `cs`.
|
||||
* `conv at h => cs` will rewrite hypothesis `h`.
|
||||
* `conv in pat => cs` will rewrite the first subexpression matching `pat`.
|
||||
* `conv in pat => cs` will rewrite the first subexpression matching `pat` (see `pattern`).
|
||||
-/
|
||||
syntax (name := conv) "conv " (" at " ident)? (" in " term)? " => " convSeq : tactic
|
||||
syntax (name := conv) "conv " (" at " ident)? (" in " (occs)? term)? " => " convSeq : tactic
|
||||
|
||||
/-- `skip` does nothing. -/
|
||||
syntax (name := skip) "skip" : conv
|
||||
|
|
@ -92,8 +104,23 @@ to rewrite the target. For recursive definitions,
|
|||
only one layer of unfolding is performed. -/
|
||||
syntax (name := unfold) "unfold " (colGt ident)+ : conv
|
||||
|
||||
/-- `pattern pat` traverses to the first subterm of the target that matches `pat`. -/
|
||||
syntax (name := pattern) "pattern " term : conv
|
||||
/--
|
||||
* `pattern pat` traverses to the first subterm of the target that matches `pat`.
|
||||
* `pattern (occs := *) pat` traverses to the every subterm of the target that matches `pat`
|
||||
which is not contained in another match of `pat`. It generates one subgoal for each matching
|
||||
subterm.
|
||||
* `pattern (occs := 1 2 4) pat` matches occurrences `1, 2, 4` of `pat` and produces three subgoals.
|
||||
Occurrences are numbered left to right from the outside in.
|
||||
|
||||
Note that skipping an occurrence of `pat` will traverse inside that subexpression, which means
|
||||
it may find more matches and this can affect the numbering of subsequent pattern matches.
|
||||
For example, if we are searching for `f _` in `f (f a) = f b`:
|
||||
* `occs := 1 2` (and `occs := *`) returns `| f (f a)` and `| f b`
|
||||
* `occs := 2` returns `| f a`
|
||||
* `occs := 2 3` returns `| f a` and `| f b`
|
||||
* `occs := 1 3` is an error, because after skipping `f b` there is no third match.
|
||||
-/
|
||||
syntax (name := pattern) "pattern " (occs)? term : conv
|
||||
|
||||
/-- `rw [thm]` rewrites the target using `thm`. See the `rw` tactic for more information. -/
|
||||
syntax (name := rewrite) "rewrite" (config)? rwRuleSeq : conv
|
||||
|
|
|
|||
|
|
@ -157,8 +157,8 @@ private def convLocalDecl (conv : Syntax) (hUserName : Name) : TacticM Unit := w
|
|||
|
||||
@[builtinTactic Lean.Parser.Tactic.Conv.conv] def evalConv : Tactic := fun stx => do
|
||||
match stx with
|
||||
| `(tactic| conv%$tk $[at $loc?]? in $p =>%$arr $code) =>
|
||||
evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $p; ($code:convSeq)))
|
||||
| `(tactic| conv%$tk $[at $loc?]? in $(occs)? $p =>%$arr $code) =>
|
||||
evalTactic (← `(tactic| conv%$tk $[at $loc?]? =>%$arr pattern $(occs)? $p; ($code:convSeq)))
|
||||
| `(tactic| conv%$tk $[at $loc?]? =>%$arr $code) =>
|
||||
-- show initial conv goal state between `conv` and `=>`
|
||||
withRef (mkNullNode #[tk, arr]) do
|
||||
|
|
|
|||
|
|
@ -33,42 +33,108 @@ partial def matchPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Op
|
|||
return none
|
||||
withReducible <| go? e
|
||||
|
||||
private def pre (pattern : AbstractMVarsResult) (found? : IO.Ref (Option Expr)) (e : Expr) : SimpM Simp.Step := do
|
||||
if (← found?.get).isSome then
|
||||
inductive PatternMatchState where
|
||||
/--
|
||||
The state corresponding to a `(occs := *)` pattern, which acts like `occs := 1 2 ... n` where
|
||||
`n` is the total number of pattern matches.
|
||||
* `subgoals` is the list of subgoals for patterns already matched
|
||||
-/
|
||||
| all (subgoals : Array MVarId)
|
||||
/--
|
||||
The state corresponding to a partially consumed `(occs := a₁ a₂ ...)` pattern.
|
||||
* `subgoals` is the list of subgoals for patterns already matched,
|
||||
along with their index in the original occs list
|
||||
* `idx` is the number of matches that have occurred so far
|
||||
* `remaining` is a list of `(i, orig)` pairs representing matches we have not yet reached.
|
||||
We maintain the invariant that `idx :: remaining.map (·.1)` is sorted.
|
||||
The number `i` is the value in the `occs` list and `orig` is its index in the list.
|
||||
-/
|
||||
| occs (subgoals : Array (Nat × MVarId)) (idx : Nat) (remaining : List (Nat × Nat))
|
||||
|
||||
namespace PatternMatchState
|
||||
|
||||
/-- Is this pattern no longer interested in accepting matches? -/
|
||||
def isDone : PatternMatchState → Bool
|
||||
| .all _ => false
|
||||
| .occs _ _ remaining => remaining.isEmpty
|
||||
|
||||
/-- Is this pattern interested in accepting the next match? -/
|
||||
def isReady : PatternMatchState → Bool
|
||||
| .all _ => true
|
||||
| .occs _ idx ((i, _) :: _) => idx == i
|
||||
| _ => false
|
||||
|
||||
/-- Assuming `isReady` returned false, this advances to the next match. -/
|
||||
def skip : PatternMatchState → PatternMatchState
|
||||
| .occs subgoals idx remaining => .occs subgoals (idx + 1) remaining
|
||||
| s => s
|
||||
|
||||
/--
|
||||
Assuming `isReady` returned true, this adds the generated subgoal to the list
|
||||
and advances to the next match.
|
||||
-/
|
||||
def accept (mvarId : MVarId) : PatternMatchState → PatternMatchState
|
||||
| .all subgoals => .all (subgoals.push mvarId)
|
||||
| .occs subgoals idx ((_, n) :: remaining) => .occs (subgoals.push (n, mvarId)) (idx + 1) remaining
|
||||
| s => s
|
||||
|
||||
end PatternMatchState
|
||||
|
||||
private def pre (pattern : AbstractMVarsResult) (state : IO.Ref PatternMatchState) (e : Expr) : SimpM Simp.Step := do
|
||||
if (← state.get).isDone then
|
||||
return Simp.Step.visit { expr := e }
|
||||
else if let some (e, extraArgs) ← matchPattern? pattern e then
|
||||
let (rhs, newGoal) ← mkConvGoalFor e
|
||||
found?.set newGoal
|
||||
let mut proof := newGoal
|
||||
for extraArg in extraArgs do
|
||||
proof ← mkCongrFun proof extraArg
|
||||
return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof }
|
||||
if (← state.get).isReady then
|
||||
let (rhs, newGoal) ← mkConvGoalFor e
|
||||
state.modify (·.accept newGoal.mvarId!)
|
||||
let mut proof := newGoal
|
||||
for extraArg in extraArgs do
|
||||
proof ← mkCongrFun proof extraArg
|
||||
return Simp.Step.done { expr := mkAppN rhs extraArgs, proof? := proof }
|
||||
else
|
||||
state.modify (·.skip)
|
||||
-- Note that because we return `visit` here and `done` in the other case,
|
||||
-- it is possible for skipping an earlier match to affect what later matches
|
||||
-- refer to. For example, matching `f _` in `f (f a) = f b` with occs `[1, 2]`
|
||||
-- yields `[f (f a), f b]`, but `[2, 3]` yields `[f a, f b]`, and `[1, 3]` is an error.
|
||||
return Simp.Step.visit { expr := e }
|
||||
else
|
||||
return Simp.Step.visit { expr := e }
|
||||
|
||||
private def findPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Option (MVarId × Simp.Result)) := do
|
||||
let found? ← IO.mkRef none
|
||||
let (result, _) ← Simp.main e (← getContext) (methods := { pre := pre pattern found? })
|
||||
if let some newGoal ← found?.get then
|
||||
return some (newGoal.mvarId!, result)
|
||||
else
|
||||
return none
|
||||
|
||||
@[builtinTactic Lean.Parser.Tactic.Conv.pattern] def evalPattern : Tactic := fun stx => withMainContext do
|
||||
match stx with
|
||||
| `(conv| pattern $p) =>
|
||||
| `(conv| pattern $[(occs := $occs)]? $p) =>
|
||||
let patternA ←
|
||||
withTheReader Term.Context (fun ctx => { ctx with ignoreTCFailures := true }) <|
|
||||
Term.withoutModifyingElabMetaStateWithInfo <| withRef p <|
|
||||
Term.withoutErrToSorry do
|
||||
abstractMVars (← Term.elabTerm p none)
|
||||
let lhs ← getLhs
|
||||
match (← findPattern? patternA lhs) with
|
||||
| none => throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}"
|
||||
| some (mvarId', result) =>
|
||||
updateLhs result.expr (← result.getProof)
|
||||
(← getMainGoal).refl
|
||||
replaceMainGoal [mvarId']
|
||||
let occs ← match occs with
|
||||
| none => pure (.occs #[] 0 [(0, 0)])
|
||||
| some occs => match occs with
|
||||
| `(Parser.Tactic.Conv.occsWildcard| *) => pure (.all #[])
|
||||
| `(Parser.Tactic.Conv.occsIndexed| $ids*) => do
|
||||
let ids ← ids.mapIdxM fun i id =>
|
||||
match id.getNat with
|
||||
| 0 => throwErrorAt id "positive integer expected"
|
||||
| n+1 => pure (n, i.1)
|
||||
let ids := ids.qsort (·.1 < ·.1)
|
||||
unless @Array.allDiff _ ⟨(·.1 == ·.1)⟩ ids do
|
||||
throwError "occurrence list is not distinct"
|
||||
pure (.occs #[] 0 ids.toList)
|
||||
| _ => throwUnsupportedSyntax
|
||||
let state ← IO.mkRef occs
|
||||
let (result, _) ← Simp.main lhs (← getContext) (methods := { pre := pre patternA state })
|
||||
let subgoals ← match ← state.get with
|
||||
| .all #[] | .occs _ 0 _ =>
|
||||
throwError "'pattern' conv tactic failed, pattern was not found{indentExpr patternA.expr}"
|
||||
| .all subgoals => pure subgoals
|
||||
| .occs subgoals idx remaining =>
|
||||
if let some (i, _) := remaining.getLast? then
|
||||
throwError "'pattern' conv tactic failed, pattern was found only {idx} times but {i+1} expected"
|
||||
pure <| (subgoals.qsort (·.1 < ·.1)).map (·.2)
|
||||
(← getRhs).mvarId!.assign result.expr
|
||||
(← getMainGoal).assign (← result.getProof)
|
||||
replaceMainGoal subgoals.toList
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
end Lean.Elab.Tactic.Conv
|
||||
|
|
|
|||
|
|
@ -180,3 +180,17 @@ example : let a := 0; let b := a; b = 0 := by
|
|||
conv =>
|
||||
zeta
|
||||
trace_state
|
||||
|
||||
example : ((x + y) + z : Nat) = x + (y + z) := by
|
||||
conv in _ + _ => trace_state
|
||||
conv in (occs := *) _ + _ => trace_state
|
||||
conv in (occs := 1 3) _ + _ => trace_state
|
||||
conv in (occs := 3 1) _ + _ => trace_state
|
||||
conv in (occs := 2 3) _ + _ => trace_state
|
||||
conv in (occs := 2 4) _ + _ => trace_state
|
||||
apply Nat.add_assoc
|
||||
|
||||
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 5) _ + _
|
||||
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 2 5) _ + _
|
||||
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 5) _ + _
|
||||
example : ((x + y) + z : Nat) = x + (y + z) := by conv => pattern (occs := 1 2 5) _ + _
|
||||
|
|
|
|||
|
|
@ -85,3 +85,34 @@ conv1.lean:175:10-175:15: error: cannot select argument
|
|||
a✝ : Nat := 0
|
||||
b✝ : Nat := a✝
|
||||
| 0 = 0
|
||||
x y z : Nat
|
||||
| x + y + z
|
||||
x y z : Nat
|
||||
| x + y + z
|
||||
|
||||
x y z : Nat
|
||||
| x + (y + z)
|
||||
x y z : Nat
|
||||
| x + y + z
|
||||
|
||||
x y z : Nat
|
||||
| y + z
|
||||
x y z : Nat
|
||||
| y + z
|
||||
|
||||
x y z : Nat
|
||||
| x + y + z
|
||||
x y z : Nat
|
||||
| x + y
|
||||
|
||||
x y z : Nat
|
||||
| x + (y + z)
|
||||
x y z : Nat
|
||||
| x + y
|
||||
|
||||
x y z : Nat
|
||||
| y + z
|
||||
conv1.lean:193:58-193:83: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected
|
||||
conv1.lean:194:58-194:85: error: 'pattern' conv tactic failed, pattern was found only 4 times but 5 expected
|
||||
conv1.lean:195:58-195:85: error: 'pattern' conv tactic failed, pattern was found only 3 times but 5 expected
|
||||
conv1.lean:196:58-196:87: error: 'pattern' conv tactic failed, pattern was found only 2 times but 5 expected
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue