feat: improve generalizing at induction

This commit is contained in:
Leonardo de Moura 2021-03-27 14:28:03 -07:00
parent ba3d6103fa
commit 4a0f8bf21a
9 changed files with 217 additions and 22 deletions

View file

@ -3,6 +3,7 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE. Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Sebastian Ullrich Authors: Leonardo de Moura, Sebastian Ullrich
-/ -/
import Lean.Util.CollectFVars
import Lean.Parser.Term import Lean.Parser.Term
import Lean.Meta.RecursorInfo import Lean.Meta.RecursorInfo
import Lean.Meta.CollectMVars import Lean.Meta.CollectMVars
@ -142,7 +143,7 @@ partial def mkElimApp (elimName : Name) (elimInfo : ElimInfo) (targets : Array E
catch _ => catch _ =>
setMVarKind mvarId MetavarKind.syntheticOpaque setMVarKind mvarId MetavarKind.syntheticOpaque
others := others.push mvarId others := others.push mvarId
pure { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others } return { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others }
/- Given a goal `... targets ... |- C[targets]` associated with `mvarId`, assign /- Given a goal `... targets ... |- C[targets]` associated with `mvarId`, assign
`motiveArg := fun targets => C[targets]` -/ `motiveArg := fun targets => C[targets]` -/
@ -239,6 +240,53 @@ where
end ElimApp end ElimApp
/--
Return a set of `FVarId`s containing `targets` and all variables they depend on.
Remark: this method assumes `targets` are free variables.
-/
private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do
loop (targets.toList.map Expr.fvarId!) {}
where
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
let localDecl ← getLocalDecl fvarId
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
if let some val := localDecl.value? then
s' := collectFVars s' (← instantiateMVars val)
let mut todo := todo
let mut s := s
for fvarId in s'.fvarSet do
unless s.contains fvarId do
todo := fvarId :: todo
s := s.insert fvarId
return (todo, s)
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
match todo with
| [] => return s
| fvarId::todo =>
if s.contains fvarId then
return s
else
let (todo, s) ← visit fvarId todo <| s.insert fvarId
loop todo s
/--
Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
Remark: this method assumes `targets` are free variables.
-/
private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
let mut r : NameSet := {}
for localDecl in (← getLCtx) do
unless forbidden.contains localDecl.fvarId do
unless localDecl.isAuxDecl do
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
r := r.insert localDecl.fvarId
s := s.insert localDecl.fvarId
return r
/- /-
Recall that Recall that
``` ```
@ -251,19 +299,28 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
let generalizingStx := stx[3] let generalizingStx := stx[3]
if generalizingStx.isNone then if generalizingStx.isNone then
pure #[] pure #[]
else withMainContext do else
trace[Elab.induction] "{generalizingStx}" trace[Elab.induction] "{generalizingStx}"
let vars := generalizingStx[1].getArgs let vars := generalizingStx[1].getArgs
getFVarIds vars getFVarIds vars
-- process `generalizingVars` subterm of induction Syntax `stx`. -- process `generalizingVars` subterm of induction Syntax `stx`.
private def generalizeVars (stx : Syntax) (targets : Array Expr) : TacticM Nat := do private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) :=
let fvarIds ← getGeneralizingFVarIds stx withMVarContext mvarId do
liftMetaTacticAux fun mvarId => do let userFVarIds ← getGeneralizingFVarIds stx
let forbidden ← mkForbiddenSet targets
let mut s ← collectForwardDeps targets forbidden
for userFVarId in userFVarIds do
if forbidden.contains userFVarId then
throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}"
if s.contains userFVarId then
throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically"
s := s.insert userFVarId
let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId
let lctx ← getLCtx
let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds
if targets.any fun target => fvarIds.contains target.fvarId! then return (fvarIds.size, mvarId')
Meta.throwTacticEx `induction mvarId "major premise depends on variable being generalized"
pure (fvarIds.size, [mvarId'])
-- syntax inductionAlts := "with " (tactic)? withPosition( (colGe inductionAlt)+) -- syntax inductionAlts := "with " (tactic)? withPosition( (colGe inductionAlt)+)
private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax := private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax :=
@ -324,23 +381,33 @@ private def getElimNameInfo (optElimId : Syntax) (targets : Array Expr) (inducti
let targets ← stx[1].getSepArgs.mapM fun target => do let targets ← stx[1].getSepArgs.mapM fun target => do
let target ← withMainContext <| elabTerm target none let target ← withMainContext <| elabTerm target none
generalizeTerm target generalizeTerm target
let n ← generalizeVars stx targets
let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := true) let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := true)
let mvarId ← getMainGoal let mvarId ← getMainGoal
let tag ← getMVarTag mvarId let tag ← getMVarTag mvarId
withMVarContext mvarId do withMVarContext mvarId do
let result ← withRef stx[1] do -- use target position as reference let result ← withRef stx[1] do -- use target position as reference
ElimApp.mkElimApp elimName elimInfo targets tag ElimApp.mkElimApp elimName elimInfo targets tag
assignExprMVar mvarId result.elimApp
let elimArgs := result.elimApp.getAppArgs let elimArgs := result.elimApp.getAppArgs
let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i] let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i]
checkTargets targets
let motiveType ← inferType elimArgs[elimInfo.motivePos]
let (n, mvarId) ← generalizeVars mvarId stx targets
let targetFVarIds := targets.map (·.fvarId!) let targetFVarIds := targets.map (·.fvarId!)
ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds
let optInductionAlts := stx[4] let optInductionAlts := stx[4]
let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts
let alts := getAltsOfOptInductionAlts optInductionAlts let alts := getAltsOfOptInductionAlts optInductionAlts
assignExprMVar mvarId result.elimApp
ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds) ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds)
appendGoals result.others.toList appendGoals result.others.toList
where
checkTargets (targets : Array Expr) : MetaM Unit := do
let mut foundFVars : NameSet := {}
for target in targets do
unless target.isFVar do
throwError "index in target's type is not a variable (consider using the `cases` tactic instead){indentExpr target}"
if foundFVars.contains target.fvarId! then
throwError "target (or one of its indices) occurs more than once{indentExpr target}"
-- Recall that -- Recall that
-- majorPremise := leading_parser optional (try (ident >> " : ")) >> termParser -- majorPremise := leading_parser optional (try (ident >> " : ")) >> termParser

View file

@ -0,0 +1,68 @@
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → Vec α n → Vec α (n+1)
def Vec.map (xs : Vec α n) (f : α → β) : Vec β n :=
match xs with
| nil => nil
| cons a as => cons (f a) (map as f)
def Vec.map' (f : α → β) : Vec α n → Vec β n
| nil => nil
| cons a as => cons (f a) (map' f as)
def Vec.map2 (f : αα → β) : Vec α n → Vec α n → Vec β n
| nil, nil => nil
| cons a as, cons b bs => cons (f a b) (map2 f as bs)
def Vec.head (xs : Vec α (n+1)) : α :=
match xs with
| cons x _ => x
theorem ex1 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by
induction xs -- error, use cases
theorem ex2 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by
cases xs with
| cons x xs =>
traceState -- `h` has been refined
repeat admit
inductive ExprType where
| bool : ExprType
| nat : ExprType
inductive Expr : ExprType → Type
| natVal : Nat → Expr ExprType.nat
| boolVal : Bool → Expr ExprType.bool
| eq : Expr α → Expr α → Expr ExprType.bool
| add : Expr ExprType.nat → Expr ExprType.nat → Expr ExprType.nat
def constProp : Expr α → Expr α
| Expr.add a b =>
match constProp a, constProp b with
| Expr.natVal v, Expr.natVal w => Expr.natVal (v + w)
| a, b => Expr.add a b
| e => e
abbrev denoteType : ExprType → Type
| ExprType.bool => Bool
| ExprType.nat => Nat
instance : BEq (denoteType α) where
beq a b :=
match α, a, b with
| ExprType.bool, a, b => a == b
| ExprType.nat, a, b => a == b
def eval : Expr α → denoteType α
| Expr.natVal v => v
| Expr.boolVal b => b
| Expr.eq a b => eval a == eval b
| Expr.add a b => eval a + eval b
theorem ex3 (a b : Expr α) (h : a = b) : eval (constProp a) = eval b := by
set_option trace.Meta.debug true in
induction a
traceState -- b's type must have been refined, `h` too
repeat admit

View file

@ -0,0 +1,55 @@
inductionGen.lean:23:2-23:14: error: index in target's type is not a variable (consider using the `cases` tactic instead)
n + 1
case cons
α : Type u_1
n : Nat
ys : Vec α (n + 1)
x : α
xs : Vec α n
h : Vec.cons x xs = ys
⊢ Vec.head (Vec.cons x xs) = Vec.head ys
inductionGen.lean:29:11-29:16: warning: declaration uses 'sorry'
case natVal
α : ExprType
a b✝ : Expr α
: a = b✝
a✝ : Nat
b : Expr ExprType.nat
h : Expr.natVal a✝ = b
⊢ eval (constProp (Expr.natVal a✝)) = eval b
case boolVal
α : ExprType
a b✝ : Expr α
: a = b✝
a✝ : Bool
b : Expr ExprType.bool
h : Expr.boolVal a✝ = b
⊢ eval (constProp (Expr.boolVal a✝)) = eval b
case eq
α : ExprType
a b✝ : Expr α
: a = b✝
α✝ : ExprType
a✝¹ a✝ : Expr α✝
: ∀ (b : Expr α✝), a✝¹ = b → eval (constProp a✝¹) = eval b
: ∀ (b : Expr α✝), a✝ = b → eval (constProp a✝) = eval b
b : Expr ExprType.bool
h : Expr.eq a✝¹ a✝ = b
⊢ eval (constProp (Expr.eq a✝¹ a✝)) = eval b
case add
α : ExprType
a b✝ : Expr α
: a = b✝
a✝¹ a✝ : Expr ExprType.nat
: ∀ (b : Expr ExprType.nat), a✝¹ = b → eval (constProp a✝¹) = eval b
: ∀ (b : Expr ExprType.nat), a✝ = b → eval (constProp a✝) = eval b
b : Expr ExprType.nat
h : Expr.add a✝¹ a✝ = b
⊢ eval (constProp (Expr.add a✝¹ a✝)) = eval b
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'

View file

@ -88,7 +88,7 @@ theorem ex9 (xs : List α) (h : xs = [] → False) : Nonempty α := by
| cons x _ => apply Nonempty.intro; assumption | cons x _ => apply Nonempty.intro; assumption
theorem modLt (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by theorem modLt (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by
induction x, y using Nat.mod.inductionOn generalizing h with induction x, y using Nat.mod.inductionOn with
| ind x y h₁ ih => | ind x y h₁ ih =>
rw [Nat.mod_eq_sub_mod h₁.2] rw [Nat.mod_eq_sub_mod h₁.2]
exact ih h exact ih h

View file

@ -29,7 +29,8 @@ theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss :
| cons xs xss ih => | cons xs xss ih =>
rw [← ih, ← eq_findM] rw [← ih, ← eq_findM]
induction xs with simp induction xs with simp
| cons x xs ih => apply byCases_Bool_bind <;> simp [ih] | cons x xs ih =>
apply byCases_Bool_bind <;> simp [ih]
theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) : theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) :
(do for xs in xss do (do for xs in xss do

View file

@ -28,13 +28,13 @@ by {
theorem tst7 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := theorem tst7 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] :=
by { by {
induction xs generalizing h with induction xs with
| nil => exact rfl | nil => exact rfl
| cons z zs ih => exact absurd rfl (h z zs) | cons z zs ih => exact absurd rfl (h z zs)
} }
theorem tst8 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := by { theorem tst8 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := by {
induction xs generalizing h; induction xs;
exact rfl; exact rfl;
exact absurd rfl $ h _ _ exact absurd rfl $ h _ _
} }
@ -75,7 +75,7 @@ theorem tst13 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by
| _ => injection h | _ => injection h
theorem tst14 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by theorem tst14 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by
induction x generalizing h with induction x with
| leaf₁ => rfl | leaf₁ => rfl
| _ => injection h | _ => injection h

View file

@ -4,15 +4,15 @@ inductive Lex (ra : αα → Prop) (rb : β → β → Prop) : α × β →
def lexAccessible1 {ra : αα → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by def lexAccessible1 {ra : αα → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by
induction (aca a) generalizing b with induction aca a generalizing b with
| intro xa aca iha => | intro xa aca iha =>
induction (acb b) with induction acb b with
| intro xb acb ihb => | intro xb acb ihb =>
apply Acc.intro (xa, xb) apply Acc.intro (xa, xb)
intro p lt intro p lt
cases lt with cases lt with
| left b1 b2 h => apply iha _ h -- only explicit fields are provided by default | left b1 b2 h => apply iha _ h _ (aca _ h)
| @right a b1 b2 h => apply ihb b1 h -- `@` allows us to provide names to implicit fields too | @right a b1 b2 h => apply ihb _ h (acb _ h)
def lexAccessible2 {ra : αα → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by def lexAccessible2 {ra : αα → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by
induction (aca a) generalizing b with induction (aca a) generalizing b with
@ -22,5 +22,5 @@ def lexAccessible2 {ra : αα → Prop} {rb : β → β → Prop} (aca : (a
apply Acc.intro (xa, xb) apply Acc.intro (xa, xb)
intro p lt intro p lt
cases lt with cases lt with
| @left a1 b1 a2 b2 h => apply iha a1 h | @left a1 b1 a2 b2 h => apply iha _ h _ (aca _ h)
| right _ h => apply ihb _ h | right _ h => apply ihb _ h (acb _ h)

View file

@ -14,7 +14,7 @@ theorem ex3 (x : Nat) : 0 + x = x := by
| succ y => skip -- Error: unsolved goals | succ y => skip -- Error: unsolved goals
theorem ex4 (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by theorem ex4 (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by
induction x, y using Nat.mod.inductionOn generalizing h with induction x, y using Nat.mod.inductionOn with
| ind x y h₁ ih => skip -- Error: unsolved goals | ind x y h₁ ih => skip -- Error: unsolved goals
| base x y h₁ => skip -- Error: unsolved goals | base x y h₁ => skip -- Error: unsolved goals

View file

@ -19,6 +19,8 @@ y : Nat
⊢ 0 + Nat.succ y = Nat.succ y ⊢ 0 + Nat.succ y = Nat.succ y
unsolvedIndCases.lean:18:18-18:25: error: unsolved goals unsolvedIndCases.lean:18:18-18:25: error: unsolved goals
case ind case ind
y✝ : Nat
h✝ : y✝ > 0
x y : Nat x y : Nat
h₁ : 0 < y ∧ y ≤ x h₁ : 0 < y ∧ y ≤ x
ih : y > 0 → (x - y) % y < y ih : y > 0 → (x - y) % y < y
@ -26,6 +28,8 @@ h : y > 0
⊢ x % y < y ⊢ x % y < y
unsolvedIndCases.lean:19:18-19:25: error: unsolved goals unsolvedIndCases.lean:19:18-19:25: error: unsolved goals
case base case base
y✝ : Nat
h✝ : y✝ > 0
x y : Nat x y : Nat
h₁ : ¬(0 < y ∧ y ≤ x) h₁ : ¬(0 < y ∧ y ≤ x)
h : y > 0 h : y > 0