diff --git a/src/Lean/Elab/Tactic/Induction.lean b/src/Lean/Elab/Tactic/Induction.lean index d577ba5937..b1fb4326af 100644 --- a/src/Lean/Elab/Tactic/Induction.lean +++ b/src/Lean/Elab/Tactic/Induction.lean @@ -168,7 +168,7 @@ private def checkAltNames (alts : Array (Name × MVarId)) (altsSyntax : Array Sy unless alts.any fun (n, _) => n == altName do throwErrorAt! altStx "invalid alternative name '{altName}'" -def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax : Array Syntax) +def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (optPreTac : Syntax) (altsSyntax : Array Syntax) (numEqs : Nat := 0) (numGeneralized : Nat := 0) (toClear : Array FVarId := #[]) : TacticM Unit := do checkAltNames alts altsSyntax let mut usedWildcard := false @@ -195,13 +195,16 @@ def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax : match (← Cases.unifyEqs numEqs altMVarId {}) with | none => pure () -- alternative is not reachable | some (altMVarId, _) => + let (_, altMVarId) ← introNP altMVarId numGeneralized + for fvarId in toClear do + altMVarId ← tryClear altMVarId fvarId + let altMVarIds ← applyPreTac altMVarId if !hasAlts then -- User did not provide alternatives using `|` - let (_, altMVarId) ← introNP altMVarId numGeneralized - for fvarId in toClear do - altMVarId ← tryClear altMVarId fvarId trace[Meta.debug]! "new subgoal {MessageData.ofGoal altMVarId}" - subgoals := subgoals.push altMVarId + subgoals := subgoals ++ altMVarIds.toArray + else if altMVarIds.isEmpty then + pure () else throwError! "alternative '{altName}' has not been provided" | some altStx => @@ -214,12 +217,23 @@ def evalAlts (elimInfo : ElimInfo) (alts : Array (Name × MVarId)) (altsSyntax : let (_, altMVarId) ← introNP altMVarId numGeneralized for fvarId in toClear do altMVarId ← tryClear altMVarId fvarId - evalAlt altMVarId altStx subgoals + let altMVarIds ← applyPreTac altMVarId + if altMVarIds.isEmpty then + throwError! "alternative '{altName}' is not needed" + else + altMVarIds.foldlM (init := subgoals) fun subgoal altMVarId => + evalAlt altMVarId altStx subgoals if usedWildcard then altsSyntax := altsSyntax.filter fun alt => getAltName alt != `_ unless altsSyntax.isEmpty do throwErrorAt altsSyntax[0] "unused alternative" setGoals subgoals.toList +where + applyPreTac (mvarId : MVarId) : TacticM (List MVarId) := + if optPreTac.isNone then + return [mvarId] + else + evalTacticAt optPreTac[0] mvarId end ElimApp @@ -256,6 +270,9 @@ private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax := private def getAltsOfOptInductionAlts (optInductionAlts : Syntax) : Array Syntax := if optInductionAlts.isNone then #[] else getAltsOfInductionAlts optInductionAlts[0] +private def getOptPreTacOfOptInductionAlts (optInductionAlts : Syntax) : Syntax := + if optInductionAlts.isNone then mkNullNode else optInductionAlts[0][1] + /- We may have at most one `| _ => ...` (wildcard alternative), and it must not set variable names. The idea is to make sure users do not write unstructured tactics. -/ @@ -318,7 +335,9 @@ private def getElimNameInfo (optElimId : Syntax) (targets : Array Expr) (inducti let targetFVarIds := targets.map (·.fvarId!) ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds let optInductionAlts := stx[4] - ElimApp.evalAlts elimInfo result.alts (getAltsOfOptInductionAlts optInductionAlts) (numGeneralized := n) (toClear := targetFVarIds) + let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts + let alts := getAltsOfOptInductionAlts optInductionAlts + ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds) appendGoals result.others.toList -- Recall that @@ -359,6 +378,8 @@ builtin_initialize registerTraceClass `Elab.cases -- parser! nonReservedSymbol "cases " >> sepBy1 (group majorPremise) ", " >> usingRec >> optInductionAlts let targets ← elabTargets stx[1].getSepArgs let optInductionAlts := stx[3] + let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts + let alts := getAltsOfOptInductionAlts optInductionAlts let targetRef := stx[1] let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := false) let (mvarId, _) ← getMainGoal @@ -373,6 +394,6 @@ builtin_initialize registerTraceClass `Elab.cases withMVarContext mvarId do ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetsNew assignExprMVar mvarId result.elimApp - ElimApp.evalAlts elimInfo result.alts (getAltsOfOptInductionAlts optInductionAlts) (numEqs := targets.size) (toClear := targetsNew) + ElimApp.evalAlts elimInfo result.alts optPreTac alts (numEqs := targets.size) (toClear := targetsNew) end Lean.Elab.Tactic diff --git a/tests/lean/run/do_eqv.lean b/tests/lean/run/do_eqv.lean index 6a3a61622e..139f75aa66 100644 --- a/tests/lean/run/do_eqv.lean +++ b/tests/lean/run/do_eqv.lean @@ -14,10 +14,9 @@ theorem eq_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xs : List α) : return none) = xs.findM? p := by - induction xs with - | nil => simp [List.findM?] + induction xs with simp [List.findM?] | cons x xs ih => - simp [List.findM?]; rw[← ih]; simp + rw[← ih]; simp apply byCases_Bool_bind <;> simp theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) : @@ -29,14 +28,11 @@ theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : return none) = xss.findSomeM? (fun xs => xs.findM? p) := by - induction xss with - | nil => simp [List.findSomeM?] + induction xss with simp [List.findSomeM?] | cons xs xss ih => - simp [List.findSomeM?] rw [← ih, ← eq_findM] - induction xs with - | nil => simp - | cons x xs ih => simp; apply byCases_Bool_bind <;> simp [ih] + induction xs with simp + | cons x xs ih => apply byCases_Bool_bind <;> simp [ih] theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) : (do for xs in xss do diff --git a/tests/lean/run/exp.lean b/tests/lean/run/exp.lean index 3b139c83f4..856ad0d38f 100644 --- a/tests/lean/run/exp.lean +++ b/tests/lean/run/exp.lean @@ -15,10 +15,9 @@ def Expr.times : Nat → Expr → Expr | k, mul e₁ e₂ => mul (times k e₁) e₂ theorem eval_times (k : Nat) (e : Expr) : e.times k |>.eval = k * e.eval := by - induction e with - | const => simp [Expr.times, Expr.eval] - | plus e₁ e₂ ih₁ ih₂ => simp [Expr.times, Expr.eval, ih₁, ih₂, Nat.left_distrib] - | mul _ _ ih₁ ih₂ => simp [Expr.times, Expr.eval, ih₁, Nat.mul_assoc] + induction e with simp [Expr.times, Expr.eval] + | plus e₁ e₂ ih₁ ih₂ => simp [ih₁, ih₂, Nat.left_distrib] + | mul _ _ ih₁ ih₂ => simp [ih₁, Nat.mul_assoc] def Expr.reassoc : Expr → Expr | const n => const n @@ -36,13 +35,10 @@ def Expr.reassoc : Expr → Expr | _ => mul e₁' e₂' theorem eval_reassoc (e : Expr) : e.reassoc.eval = e.eval := by - induction e with - | const => rfl + induction e with simp [Expr.reassoc] | plus e₁ e₂ ih₁ ih₂ => - simp [Expr.reassoc] generalize h : (Expr.reassoc e₂) = e₂' cases e₂' <;> rw [h] at ih₂ <;> simp [Expr.eval] at * <;> rw [← ih₂, ih₁]; rw [Nat.add_assoc] | mul e₁ e₂ ih₁ ih₂ => - simp [Expr.reassoc] generalize h : (Expr.reassoc e₂) = e₂' cases e₂' <;> rw [h] at ih₂ <;> simp [Expr.eval] at * <;> rw [← ih₂, ih₁]; rw [Nat.mul_assoc]