From 4a0f8bf21ab3e6a55369dc658e4736a46ab377fa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 27 Mar 2021 14:28:03 -0700 Subject: [PATCH] feat: improve `generalizing` at `induction` --- src/Lean/Elab/Tactic/Induction.lean | 87 ++++++++++++++++--- tests/lean/inductionGen.lean | 68 +++++++++++++++ tests/lean/inductionGen.lean.expected.out | 55 ++++++++++++ tests/lean/run/casesUsing.lean | 2 +- tests/lean/run/do_eqv.lean | 3 +- tests/lean/run/induction1.lean | 6 +- tests/lean/run/inductionAltExplicit.lean | 12 +-- tests/lean/unsolvedIndCases.lean | 2 +- tests/lean/unsolvedIndCases.lean.expected.out | 4 + 9 files changed, 217 insertions(+), 22 deletions(-) create mode 100644 tests/lean/inductionGen.lean create mode 100644 tests/lean/inductionGen.lean.expected.out diff --git a/src/Lean/Elab/Tactic/Induction.lean b/src/Lean/Elab/Tactic/Induction.lean index a137891b68..f7f50dade5 100644 --- a/src/Lean/Elab/Tactic/Induction.lean +++ b/src/Lean/Elab/Tactic/Induction.lean @@ -3,6 +3,7 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura, Sebastian Ullrich -/ +import Lean.Util.CollectFVars import Lean.Parser.Term import Lean.Meta.RecursorInfo import Lean.Meta.CollectMVars @@ -142,7 +143,7 @@ partial def mkElimApp (elimName : Name) (elimInfo : ElimInfo) (targets : Array E catch _ => setMVarKind mvarId MetavarKind.syntheticOpaque others := others.push mvarId - pure { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others } + return { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others } /- Given a goal `... targets ... |- C[targets]` associated with `mvarId`, assign `motiveArg := fun targets => C[targets]` -/ @@ -239,6 +240,53 @@ where end ElimApp +/-- + Return a set of `FVarId`s containing `targets` and all variables they depend on. + + Remark: this method assumes `targets` are free variables. +-/ +private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do + loop (targets.toList.map Expr.fvarId!) {} +where + visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do + let localDecl ← getLocalDecl fvarId + let mut s' := collectFVars {} (← instantiateMVars localDecl.type) + if let some val := localDecl.value? then + s' := collectFVars s' (← instantiateMVars val) + let mut todo := todo + let mut s := s + for fvarId in s'.fvarSet do + unless s.contains fvarId do + todo := fvarId :: todo + s := s.insert fvarId + return (todo, s) + + loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do + match todo with + | [] => return s + | fvarId::todo => + if s.contains fvarId then + return s + else + let (todo, s) ← visit fvarId todo <| s.insert fvarId + loop todo s + +/-- + Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`. + + Remark: this method assumes `targets` are free variables. +-/ +private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do + let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId! + let mut r : NameSet := {} + for localDecl in (← getLCtx) do + unless forbidden.contains localDecl.fvarId do + unless localDecl.isAuxDecl do + if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then + r := r.insert localDecl.fvarId + s := s.insert localDecl.fvarId + return r + /- Recall that ``` @@ -251,19 +299,28 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) := let generalizingStx := stx[3] if generalizingStx.isNone then pure #[] - else withMainContext do + else trace[Elab.induction] "{generalizingStx}" let vars := generalizingStx[1].getArgs getFVarIds vars -- process `generalizingVars` subterm of induction Syntax `stx`. -private def generalizeVars (stx : Syntax) (targets : Array Expr) : TacticM Nat := do - let fvarIds ← getGeneralizingFVarIds stx - liftMetaTacticAux fun mvarId => do +private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) := + withMVarContext mvarId do + let userFVarIds ← getGeneralizingFVarIds stx + let forbidden ← mkForbiddenSet targets + let mut s ← collectForwardDeps targets forbidden + for userFVarId in userFVarIds do + if forbidden.contains userFVarId then + throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}" + if s.contains userFVarId then + throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically" + s := s.insert userFVarId + let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId + let lctx ← getLCtx + let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds - if targets.any fun target => fvarIds.contains target.fvarId! then - Meta.throwTacticEx `induction mvarId "major premise depends on variable being generalized" - pure (fvarIds.size, [mvarId']) + return (fvarIds.size, mvarId') -- syntax inductionAlts := "with " (tactic)? withPosition( (colGe inductionAlt)+) private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax := @@ -324,23 +381,33 @@ private def getElimNameInfo (optElimId : Syntax) (targets : Array Expr) (inducti let targets ← stx[1].getSepArgs.mapM fun target => do let target ← withMainContext <| elabTerm target none generalizeTerm target - let n ← generalizeVars stx targets let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := true) let mvarId ← getMainGoal let tag ← getMVarTag mvarId withMVarContext mvarId do let result ← withRef stx[1] do -- use target position as reference ElimApp.mkElimApp elimName elimInfo targets tag - assignExprMVar mvarId result.elimApp let elimArgs := result.elimApp.getAppArgs let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i] + checkTargets targets + let motiveType ← inferType elimArgs[elimInfo.motivePos] + let (n, mvarId) ← generalizeVars mvarId stx targets let targetFVarIds := targets.map (·.fvarId!) ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds let optInductionAlts := stx[4] let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts let alts := getAltsOfOptInductionAlts optInductionAlts + assignExprMVar mvarId result.elimApp ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds) appendGoals result.others.toList +where + checkTargets (targets : Array Expr) : MetaM Unit := do + let mut foundFVars : NameSet := {} + for target in targets do + unless target.isFVar do + throwError "index in target's type is not a variable (consider using the `cases` tactic instead){indentExpr target}" + if foundFVars.contains target.fvarId! then + throwError "target (or one of its indices) occurs more than once{indentExpr target}" -- Recall that -- majorPremise := leading_parser optional (try (ident >> " : ")) >> termParser diff --git a/tests/lean/inductionGen.lean b/tests/lean/inductionGen.lean new file mode 100644 index 0000000000..37a511243e --- /dev/null +++ b/tests/lean/inductionGen.lean @@ -0,0 +1,68 @@ +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → Vec α n → Vec α (n+1) + +def Vec.map (xs : Vec α n) (f : α → β) : Vec β n := + match xs with + | nil => nil + | cons a as => cons (f a) (map as f) + +def Vec.map' (f : α → β) : Vec α n → Vec β n + | nil => nil + | cons a as => cons (f a) (map' f as) + +def Vec.map2 (f : α → α → β) : Vec α n → Vec α n → Vec β n + | nil, nil => nil + | cons a as, cons b bs => cons (f a b) (map2 f as bs) + +def Vec.head (xs : Vec α (n+1)) : α := + match xs with + | cons x _ => x + +theorem ex1 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by + induction xs -- error, use cases + +theorem ex2 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by + cases xs with + | cons x xs => + traceState -- `h` has been refined + repeat admit + +inductive ExprType where + | bool : ExprType + | nat : ExprType + +inductive Expr : ExprType → Type + | natVal : Nat → Expr ExprType.nat + | boolVal : Bool → Expr ExprType.bool + | eq : Expr α → Expr α → Expr ExprType.bool + | add : Expr ExprType.nat → Expr ExprType.nat → Expr ExprType.nat + +def constProp : Expr α → Expr α + | Expr.add a b => + match constProp a, constProp b with + | Expr.natVal v, Expr.natVal w => Expr.natVal (v + w) + | a, b => Expr.add a b + | e => e + +abbrev denoteType : ExprType → Type + | ExprType.bool => Bool + | ExprType.nat => Nat + +instance : BEq (denoteType α) where + beq a b := + match α, a, b with + | ExprType.bool, a, b => a == b + | ExprType.nat, a, b => a == b + +def eval : Expr α → denoteType α + | Expr.natVal v => v + | Expr.boolVal b => b + | Expr.eq a b => eval a == eval b + | Expr.add a b => eval a + eval b + +theorem ex3 (a b : Expr α) (h : a = b) : eval (constProp a) = eval b := by + set_option trace.Meta.debug true in + induction a + traceState -- b's type must have been refined, `h` too + repeat admit diff --git a/tests/lean/inductionGen.lean.expected.out b/tests/lean/inductionGen.lean.expected.out new file mode 100644 index 0000000000..faa67e1337 --- /dev/null +++ b/tests/lean/inductionGen.lean.expected.out @@ -0,0 +1,55 @@ +inductionGen.lean:23:2-23:14: error: index in target's type is not a variable (consider using the `cases` tactic instead) + n + 1 +case cons +α : Type u_1 +n : Nat +ys : Vec α (n + 1) +x : α +xs : Vec α n +h : Vec.cons x xs = ys +⊢ Vec.head (Vec.cons x xs) = Vec.head ys +inductionGen.lean:29:11-29:16: warning: declaration uses 'sorry' +case natVal +α : ExprType +a b✝ : Expr α + : a = b✝ +a✝ : Nat +b : Expr ExprType.nat +h : Expr.natVal a✝ = b +⊢ eval (constProp (Expr.natVal a✝)) = eval b + +case boolVal +α : ExprType +a b✝ : Expr α + : a = b✝ +a✝ : Bool +b : Expr ExprType.bool +h : Expr.boolVal a✝ = b +⊢ eval (constProp (Expr.boolVal a✝)) = eval b + +case eq +α : ExprType +a b✝ : Expr α + : a = b✝ +α✝ : ExprType +a✝¹ a✝ : Expr α✝ + : ∀ (b : Expr α✝), a✝¹ = b → eval (constProp a✝¹) = eval b + : ∀ (b : Expr α✝), a✝ = b → eval (constProp a✝) = eval b +b : Expr ExprType.bool +h : Expr.eq a✝¹ a✝ = b +⊢ eval (constProp (Expr.eq a✝¹ a✝)) = eval b + +case add +α : ExprType +a b✝ : Expr α + : a = b✝ +a✝¹ a✝ : Expr ExprType.nat + : ∀ (b : Expr ExprType.nat), a✝¹ = b → eval (constProp a✝¹) = eval b + : ∀ (b : Expr ExprType.nat), a✝ = b → eval (constProp a✝) = eval b +b : Expr ExprType.nat +h : Expr.add a✝¹ a✝ = b +⊢ eval (constProp (Expr.add a✝¹ a✝)) = eval b +inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry' +inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry' +inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry' +inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry' diff --git a/tests/lean/run/casesUsing.lean b/tests/lean/run/casesUsing.lean index 5d3837725d..3e9232fbf4 100644 --- a/tests/lean/run/casesUsing.lean +++ b/tests/lean/run/casesUsing.lean @@ -88,7 +88,7 @@ theorem ex9 (xs : List α) (h : xs = [] → False) : Nonempty α := by | cons x _ => apply Nonempty.intro; assumption theorem modLt (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by - induction x, y using Nat.mod.inductionOn generalizing h with + induction x, y using Nat.mod.inductionOn with | ind x y h₁ ih => rw [Nat.mod_eq_sub_mod h₁.2] exact ih h diff --git a/tests/lean/run/do_eqv.lean b/tests/lean/run/do_eqv.lean index 7773000d75..8d294de41b 100644 --- a/tests/lean/run/do_eqv.lean +++ b/tests/lean/run/do_eqv.lean @@ -29,7 +29,8 @@ theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : | cons xs xss ih => rw [← ih, ← eq_findM] induction xs with simp - | cons x xs ih => apply byCases_Bool_bind <;> simp [ih] + | cons x xs ih => + apply byCases_Bool_bind <;> simp [ih] theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) : (do for xs in xss do diff --git a/tests/lean/run/induction1.lean b/tests/lean/run/induction1.lean index 2bbeb38baa..e38f2a5aa6 100644 --- a/tests/lean/run/induction1.lean +++ b/tests/lean/run/induction1.lean @@ -28,13 +28,13 @@ by { theorem tst7 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := by { - induction xs generalizing h with + induction xs with | nil => exact rfl | cons z zs ih => exact absurd rfl (h z zs) } theorem tst8 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := by { - induction xs generalizing h; + induction xs; exact rfl; exact absurd rfl $ h _ _ } @@ -75,7 +75,7 @@ theorem tst13 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by | _ => injection h theorem tst14 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by - induction x generalizing h with + induction x with | leaf₁ => rfl | _ => injection h diff --git a/tests/lean/run/inductionAltExplicit.lean b/tests/lean/run/inductionAltExplicit.lean index 027cc55209..36faa871ee 100644 --- a/tests/lean/run/inductionAltExplicit.lean +++ b/tests/lean/run/inductionAltExplicit.lean @@ -4,15 +4,15 @@ inductive Lex (ra : α → α → Prop) (rb : β → β → Prop) : α × β → def lexAccessible1 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by - induction (aca a) generalizing b with + induction aca a generalizing b with | intro xa aca iha => - induction (acb b) with + induction acb b with | intro xb acb ihb => apply Acc.intro (xa, xb) intro p lt cases lt with - | left b1 b2 h => apply iha _ h -- only explicit fields are provided by default - | @right a b1 b2 h => apply ihb b1 h -- `@` allows us to provide names to implicit fields too + | left b1 b2 h => apply iha _ h _ (aca _ h) + | @right a b1 b2 h => apply ihb _ h (acb _ h) def lexAccessible2 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by induction (aca a) generalizing b with @@ -22,5 +22,5 @@ def lexAccessible2 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a apply Acc.intro (xa, xb) intro p lt cases lt with - | @left a1 b1 a2 b2 h => apply iha a1 h - | right _ h => apply ihb _ h + | @left a1 b1 a2 b2 h => apply iha _ h _ (aca _ h) + | right _ h => apply ihb _ h (acb _ h) diff --git a/tests/lean/unsolvedIndCases.lean b/tests/lean/unsolvedIndCases.lean index c0e55ce21e..b2d7f6d239 100644 --- a/tests/lean/unsolvedIndCases.lean +++ b/tests/lean/unsolvedIndCases.lean @@ -14,7 +14,7 @@ theorem ex3 (x : Nat) : 0 + x = x := by | succ y => skip -- Error: unsolved goals theorem ex4 (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by - induction x, y using Nat.mod.inductionOn generalizing h with + induction x, y using Nat.mod.inductionOn with | ind x y h₁ ih => skip -- Error: unsolved goals | base x y h₁ => skip -- Error: unsolved goals diff --git a/tests/lean/unsolvedIndCases.lean.expected.out b/tests/lean/unsolvedIndCases.lean.expected.out index bd3736bee1..252bd451a4 100644 --- a/tests/lean/unsolvedIndCases.lean.expected.out +++ b/tests/lean/unsolvedIndCases.lean.expected.out @@ -19,6 +19,8 @@ y : Nat ⊢ 0 + Nat.succ y = Nat.succ y unsolvedIndCases.lean:18:18-18:25: error: unsolved goals case ind +y✝ : Nat +h✝ : y✝ > 0 x y : Nat h₁ : 0 < y ∧ y ≤ x ih : y > 0 → (x - y) % y < y @@ -26,6 +28,8 @@ h : y > 0 ⊢ x % y < y unsolvedIndCases.lean:19:18-19:25: error: unsolved goals case base +y✝ : Nat +h✝ : y✝ > 0 x y : Nat h₁ : ¬(0 < y ∧ y ≤ x) h : y > 0