feat: optional preprocessing tactic at induction and cases

cc @Kha
This commit is contained in:
Leonardo de Moura 2021-03-07 14:48:24 -08:00
parent 5ba171d946
commit 0a881315ba
3 changed files with 38 additions and 25 deletions

View file

@ -168,7 +168,7 @@ private def checkAltNames (alts : Array (Name × MVarId)) (altsSyntax : Array Sy
unless alts.any fun (n, _) => n == altName do
throwErrorAt! altStx "invalid alternative name '{altName}'"
def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax : Array Syntax)
def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (optPreTac : Syntax) (altsSyntax : Array Syntax)
(numEqs : Nat := 0) (numGeneralized : Nat := 0) (toClear : Array FVarId := #[]) : TacticM Unit := do
checkAltNames alts altsSyntax
let mut usedWildcard := false
@ -195,13 +195,16 @@ def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax :
match (← Cases.unifyEqs numEqs altMVarId {}) with
| none => pure () -- alternative is not reachable
| some (altMVarId, _) =>
let (_, altMVarId) ← introNP altMVarId numGeneralized
for fvarId in toClear do
altMVarId ← tryClear altMVarId fvarId
let altMVarIds ← applyPreTac altMVarId
if !hasAlts then
-- User did not provide alternatives using `|`
let (_, altMVarId) ← introNP altMVarId numGeneralized
for fvarId in toClear do
altMVarId ← tryClear altMVarId fvarId
trace[Meta.debug]! "new subgoal {MessageData.ofGoal altMVarId}"
subgoals := subgoals.push altMVarId
subgoals := subgoals ++ altMVarIds.toArray
else if altMVarIds.isEmpty then
pure ()
else
throwError! "alternative '{altName}' has not been provided"
| some altStx =>
@ -214,12 +217,23 @@ def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax :
let (_, altMVarId) ← introNP altMVarId numGeneralized
for fvarId in toClear do
altMVarId ← tryClear altMVarId fvarId
evalAlt altMVarId altStx subgoals
let altMVarIds ← applyPreTac altMVarId
if altMVarIds.isEmpty then
throwError! "alternative '{altName}' is not needed"
else
altMVarIds.foldlM (init := subgoals) fun subgoal altMVarId =>
evalAlt altMVarId altStx subgoals
if usedWildcard then
altsSyntax := altsSyntax.filter fun alt => getAltName alt != `_
unless altsSyntax.isEmpty do
throwErrorAt altsSyntax[0] "unused alternative"
setGoals subgoals.toList
where
applyPreTac (mvarId : MVarId) : TacticM (List MVarId) :=
if optPreTac.isNone then
return [mvarId]
else
evalTacticAt optPreTac[0] mvarId
end ElimApp
@ -256,6 +270,9 @@ private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax :=
private def getAltsOfOptInductionAlts (optInductionAlts : Syntax) : Array Syntax :=
if optInductionAlts.isNone then #[] else getAltsOfInductionAlts optInductionAlts[0]
private def getOptPreTacOfOptInductionAlts (optInductionAlts : Syntax) : Syntax :=
if optInductionAlts.isNone then mkNullNode else optInductionAlts[0][1]
/-
We may have at most one `| _ => ...` (wildcard alternative), and it must not set variable names.
The idea is to make sure users do not write unstructured tactics. -/
@ -318,7 +335,9 @@ private def getElimNameInfo (optElimId : Syntax) (targets : Array Expr) (inducti
let targetFVarIds := targets.map (·.fvarId!)
ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds
let optInductionAlts := stx[4]
ElimApp.evalAlts elimInfo result.alts (getAltsOfOptInductionAlts optInductionAlts) (numGeneralized := n) (toClear := targetFVarIds)
let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts
let alts := getAltsOfOptInductionAlts optInductionAlts
ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds)
appendGoals result.others.toList
-- Recall that
@ -359,6 +378,8 @@ builtin_initialize registerTraceClass `Elab.cases
-- parser! nonReservedSymbol "cases " >> sepBy1 (group majorPremise) ", " >> usingRec >> optInductionAlts
let targets ← elabTargets stx[1].getSepArgs
let optInductionAlts := stx[3]
let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts
let alts := getAltsOfOptInductionAlts optInductionAlts
let targetRef := stx[1]
let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := false)
let (mvarId, _) ← getMainGoal
@ -373,6 +394,6 @@ builtin_initialize registerTraceClass `Elab.cases
withMVarContext mvarId do
ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetsNew
assignExprMVar mvarId result.elimApp
ElimApp.evalAlts elimInfo result.alts (getAltsOfOptInductionAlts optInductionAlts) (numEqs := targets.size) (toClear := targetsNew)
ElimApp.evalAlts elimInfo result.alts optPreTac alts (numEqs := targets.size) (toClear := targetsNew)
end Lean.Elab.Tactic

View file

@ -14,10 +14,9 @@ theorem eq_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xs : List α) :
return none)
=
xs.findM? p := by
induction xs with
| nil => simp [List.findM?]
induction xs with simp [List.findM?]
| cons x xs ih =>
simp [List.findM?]; rw[← ih]; simp
rw[← ih]; simp
apply byCases_Bool_bind <;> simp
theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) :
@ -29,14 +28,11 @@ theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss :
return none)
=
xss.findSomeM? (fun xs => xs.findM? p) := by
induction xss with
| nil => simp [List.findSomeM?]
induction xss with simp [List.findSomeM?]
| cons xs xss ih =>
simp [List.findSomeM?]
rw [← ih, ← eq_findM]
induction xs with
| nil => simp
| cons x xs ih => simp; apply byCases_Bool_bind <;> simp [ih]
induction xs with simp
| cons x xs ih => apply byCases_Bool_bind <;> simp [ih]
theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) :
(do for xs in xss do

View file

@ -15,10 +15,9 @@ def Expr.times : Nat → Expr → Expr
| k, mul e₁ e₂ => mul (times k e₁) e₂
theorem eval_times (k : Nat) (e : Expr) : e.times k |>.eval = k * e.eval := by
induction e with
| const => simp [Expr.times, Expr.eval]
| plus e₁ e₂ ih₁ ih₂ => simp [Expr.times, Expr.eval, ih₁, ih₂, Nat.left_distrib]
| mul _ _ ih₁ ih₂ => simp [Expr.times, Expr.eval, ih₁, Nat.mul_assoc]
induction e with simp [Expr.times, Expr.eval]
| plus e₁ e₂ ih₁ ih₂ => simp [ih₁, ih₂, Nat.left_distrib]
| mul _ _ ih₁ ih₂ => simp [ih₁, Nat.mul_assoc]
def Expr.reassoc : Expr → Expr
| const n => const n
@ -36,13 +35,10 @@ def Expr.reassoc : Expr → Expr
| _ => mul e₁' e₂'
theorem eval_reassoc (e : Expr) : e.reassoc.eval = e.eval := by
induction e with
| const => rfl
induction e with simp [Expr.reassoc]
| plus e₁ e₂ ih₁ ih₂ =>
simp [Expr.reassoc]
generalize h : (Expr.reassoc e₂) = e₂'
cases e₂' <;> rw [h] at ih₂ <;> simp [Expr.eval] at * <;> rw [← ih₂, ih₁]; rw [Nat.add_assoc]
| mul e₁ e₂ ih₁ ih₂ =>
simp [Expr.reassoc]
generalize h : (Expr.reassoc e₂) = e₂'
cases e₂' <;> rw [h] at ih₂ <;> simp [Expr.eval] at * <;> rw [← ih₂, ih₁]; rw [Nat.mul_assoc]