feat: improve generalizing at induction
This commit is contained in:
parent
ba3d6103fa
commit
4a0f8bf21a
9 changed files with 217 additions and 22 deletions
|
|
@ -3,6 +3,7 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
|||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura, Sebastian Ullrich
|
||||
-/
|
||||
import Lean.Util.CollectFVars
|
||||
import Lean.Parser.Term
|
||||
import Lean.Meta.RecursorInfo
|
||||
import Lean.Meta.CollectMVars
|
||||
|
|
@ -142,7 +143,7 @@ partial def mkElimApp (elimName : Name) (elimInfo : ElimInfo) (targets : Array E
|
|||
catch _ =>
|
||||
setMVarKind mvarId MetavarKind.syntheticOpaque
|
||||
others := others.push mvarId
|
||||
pure { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others }
|
||||
return { elimApp := (← instantiateMVars s.f), alts := s.alts, others := others }
|
||||
|
||||
/- Given a goal `... targets ... |- C[targets]` associated with `mvarId`, assign
|
||||
`motiveArg := fun targets => C[targets]` -/
|
||||
|
|
@ -239,6 +240,53 @@ where
|
|||
|
||||
end ElimApp
|
||||
|
||||
/--
|
||||
Return a set of `FVarId`s containing `targets` and all variables they depend on.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
-/
|
||||
private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do
|
||||
loop (targets.toList.map Expr.fvarId!) {}
|
||||
where
|
||||
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
|
||||
if let some val := localDecl.value? then
|
||||
s' := collectFVars s' (← instantiateMVars val)
|
||||
let mut todo := todo
|
||||
let mut s := s
|
||||
for fvarId in s'.fvarSet do
|
||||
unless s.contains fvarId do
|
||||
todo := fvarId :: todo
|
||||
s := s.insert fvarId
|
||||
return (todo, s)
|
||||
|
||||
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
|
||||
match todo with
|
||||
| [] => return s
|
||||
| fvarId::todo =>
|
||||
if s.contains fvarId then
|
||||
return s
|
||||
else
|
||||
let (todo, s) ← visit fvarId todo <| s.insert fvarId
|
||||
loop todo s
|
||||
|
||||
/--
|
||||
Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
-/
|
||||
private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
|
||||
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
|
||||
let mut r : NameSet := {}
|
||||
for localDecl in (← getLCtx) do
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
unless localDecl.isAuxDecl do
|
||||
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
|
||||
r := r.insert localDecl.fvarId
|
||||
s := s.insert localDecl.fvarId
|
||||
return r
|
||||
|
||||
/-
|
||||
Recall that
|
||||
```
|
||||
|
|
@ -251,19 +299,28 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
|
|||
let generalizingStx := stx[3]
|
||||
if generalizingStx.isNone then
|
||||
pure #[]
|
||||
else withMainContext do
|
||||
else
|
||||
trace[Elab.induction] "{generalizingStx}"
|
||||
let vars := generalizingStx[1].getArgs
|
||||
getFVarIds vars
|
||||
|
||||
-- process `generalizingVars` subterm of induction Syntax `stx`.
|
||||
private def generalizeVars (stx : Syntax) (targets : Array Expr) : TacticM Nat := do
|
||||
let fvarIds ← getGeneralizingFVarIds stx
|
||||
liftMetaTacticAux fun mvarId => do
|
||||
private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) :=
|
||||
withMVarContext mvarId do
|
||||
let userFVarIds ← getGeneralizingFVarIds stx
|
||||
let forbidden ← mkForbiddenSet targets
|
||||
let mut s ← collectForwardDeps targets forbidden
|
||||
for userFVarId in userFVarIds do
|
||||
if forbidden.contains userFVarId then
|
||||
throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}"
|
||||
if s.contains userFVarId then
|
||||
throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically"
|
||||
s := s.insert userFVarId
|
||||
let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId
|
||||
let lctx ← getLCtx
|
||||
let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
|
||||
let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds
|
||||
if targets.any fun target => fvarIds.contains target.fvarId! then
|
||||
Meta.throwTacticEx `induction mvarId "major premise depends on variable being generalized"
|
||||
pure (fvarIds.size, [mvarId'])
|
||||
return (fvarIds.size, mvarId')
|
||||
|
||||
-- syntax inductionAlts := "with " (tactic)? withPosition( (colGe inductionAlt)+)
|
||||
private def getAltsOfInductionAlts (inductionAlts : Syntax) : Array Syntax :=
|
||||
|
|
@ -324,23 +381,33 @@ private def getElimNameInfo (optElimId : Syntax) (targets : Array Expr) (inducti
|
|||
let targets ← stx[1].getSepArgs.mapM fun target => do
|
||||
let target ← withMainContext <| elabTerm target none
|
||||
generalizeTerm target
|
||||
let n ← generalizeVars stx targets
|
||||
let (elimName, elimInfo) ← getElimNameInfo stx[2] targets (induction := true)
|
||||
let mvarId ← getMainGoal
|
||||
let tag ← getMVarTag mvarId
|
||||
withMVarContext mvarId do
|
||||
let result ← withRef stx[1] do -- use target position as reference
|
||||
ElimApp.mkElimApp elimName elimInfo targets tag
|
||||
assignExprMVar mvarId result.elimApp
|
||||
let elimArgs := result.elimApp.getAppArgs
|
||||
let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i]
|
||||
checkTargets targets
|
||||
let motiveType ← inferType elimArgs[elimInfo.motivePos]
|
||||
let (n, mvarId) ← generalizeVars mvarId stx targets
|
||||
let targetFVarIds := targets.map (·.fvarId!)
|
||||
ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos].mvarId! targetFVarIds
|
||||
let optInductionAlts := stx[4]
|
||||
let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts
|
||||
let alts := getAltsOfOptInductionAlts optInductionAlts
|
||||
assignExprMVar mvarId result.elimApp
|
||||
ElimApp.evalAlts elimInfo result.alts optPreTac alts (numGeneralized := n) (toClear := targetFVarIds)
|
||||
appendGoals result.others.toList
|
||||
where
|
||||
checkTargets (targets : Array Expr) : MetaM Unit := do
|
||||
let mut foundFVars : NameSet := {}
|
||||
for target in targets do
|
||||
unless target.isFVar do
|
||||
throwError "index in target's type is not a variable (consider using the `cases` tactic instead){indentExpr target}"
|
||||
if foundFVars.contains target.fvarId! then
|
||||
throwError "target (or one of its indices) occurs more than once{indentExpr target}"
|
||||
|
||||
-- Recall that
|
||||
-- majorPremise := leading_parser optional (try (ident >> " : ")) >> termParser
|
||||
|
|
|
|||
68
tests/lean/inductionGen.lean
Normal file
68
tests/lean/inductionGen.lean
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → Vec α n → Vec α (n+1)
|
||||
|
||||
def Vec.map (xs : Vec α n) (f : α → β) : Vec β n :=
|
||||
match xs with
|
||||
| nil => nil
|
||||
| cons a as => cons (f a) (map as f)
|
||||
|
||||
def Vec.map' (f : α → β) : Vec α n → Vec β n
|
||||
| nil => nil
|
||||
| cons a as => cons (f a) (map' f as)
|
||||
|
||||
def Vec.map2 (f : α → α → β) : Vec α n → Vec α n → Vec β n
|
||||
| nil, nil => nil
|
||||
| cons a as, cons b bs => cons (f a b) (map2 f as bs)
|
||||
|
||||
def Vec.head (xs : Vec α (n+1)) : α :=
|
||||
match xs with
|
||||
| cons x _ => x
|
||||
|
||||
theorem ex1 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by
|
||||
induction xs -- error, use cases
|
||||
|
||||
theorem ex2 (xs ys : Vec α (n+1)) (h : xs = ys) : xs.head = ys.head := by
|
||||
cases xs with
|
||||
| cons x xs =>
|
||||
traceState -- `h` has been refined
|
||||
repeat admit
|
||||
|
||||
inductive ExprType where
|
||||
| bool : ExprType
|
||||
| nat : ExprType
|
||||
|
||||
inductive Expr : ExprType → Type
|
||||
| natVal : Nat → Expr ExprType.nat
|
||||
| boolVal : Bool → Expr ExprType.bool
|
||||
| eq : Expr α → Expr α → Expr ExprType.bool
|
||||
| add : Expr ExprType.nat → Expr ExprType.nat → Expr ExprType.nat
|
||||
|
||||
def constProp : Expr α → Expr α
|
||||
| Expr.add a b =>
|
||||
match constProp a, constProp b with
|
||||
| Expr.natVal v, Expr.natVal w => Expr.natVal (v + w)
|
||||
| a, b => Expr.add a b
|
||||
| e => e
|
||||
|
||||
abbrev denoteType : ExprType → Type
|
||||
| ExprType.bool => Bool
|
||||
| ExprType.nat => Nat
|
||||
|
||||
instance : BEq (denoteType α) where
|
||||
beq a b :=
|
||||
match α, a, b with
|
||||
| ExprType.bool, a, b => a == b
|
||||
| ExprType.nat, a, b => a == b
|
||||
|
||||
def eval : Expr α → denoteType α
|
||||
| Expr.natVal v => v
|
||||
| Expr.boolVal b => b
|
||||
| Expr.eq a b => eval a == eval b
|
||||
| Expr.add a b => eval a + eval b
|
||||
|
||||
theorem ex3 (a b : Expr α) (h : a = b) : eval (constProp a) = eval b := by
|
||||
set_option trace.Meta.debug true in
|
||||
induction a
|
||||
traceState -- b's type must have been refined, `h` too
|
||||
repeat admit
|
||||
55
tests/lean/inductionGen.lean.expected.out
Normal file
55
tests/lean/inductionGen.lean.expected.out
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
inductionGen.lean:23:2-23:14: error: index in target's type is not a variable (consider using the `cases` tactic instead)
|
||||
n + 1
|
||||
case cons
|
||||
α : Type u_1
|
||||
n : Nat
|
||||
ys : Vec α (n + 1)
|
||||
x : α
|
||||
xs : Vec α n
|
||||
h : Vec.cons x xs = ys
|
||||
⊢ Vec.head (Vec.cons x xs) = Vec.head ys
|
||||
inductionGen.lean:29:11-29:16: warning: declaration uses 'sorry'
|
||||
case natVal
|
||||
α : ExprType
|
||||
a b✝ : Expr α
|
||||
: a = b✝
|
||||
a✝ : Nat
|
||||
b : Expr ExprType.nat
|
||||
h : Expr.natVal a✝ = b
|
||||
⊢ eval (constProp (Expr.natVal a✝)) = eval b
|
||||
|
||||
case boolVal
|
||||
α : ExprType
|
||||
a b✝ : Expr α
|
||||
: a = b✝
|
||||
a✝ : Bool
|
||||
b : Expr ExprType.bool
|
||||
h : Expr.boolVal a✝ = b
|
||||
⊢ eval (constProp (Expr.boolVal a✝)) = eval b
|
||||
|
||||
case eq
|
||||
α : ExprType
|
||||
a b✝ : Expr α
|
||||
: a = b✝
|
||||
α✝ : ExprType
|
||||
a✝¹ a✝ : Expr α✝
|
||||
: ∀ (b : Expr α✝), a✝¹ = b → eval (constProp a✝¹) = eval b
|
||||
: ∀ (b : Expr α✝), a✝ = b → eval (constProp a✝) = eval b
|
||||
b : Expr ExprType.bool
|
||||
h : Expr.eq a✝¹ a✝ = b
|
||||
⊢ eval (constProp (Expr.eq a✝¹ a✝)) = eval b
|
||||
|
||||
case add
|
||||
α : ExprType
|
||||
a b✝ : Expr α
|
||||
: a = b✝
|
||||
a✝¹ a✝ : Expr ExprType.nat
|
||||
: ∀ (b : Expr ExprType.nat), a✝¹ = b → eval (constProp a✝¹) = eval b
|
||||
: ∀ (b : Expr ExprType.nat), a✝ = b → eval (constProp a✝) = eval b
|
||||
b : Expr ExprType.nat
|
||||
h : Expr.add a✝¹ a✝ = b
|
||||
⊢ eval (constProp (Expr.add a✝¹ a✝)) = eval b
|
||||
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
|
||||
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
|
||||
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
|
||||
inductionGen.lean:68:9-68:14: warning: declaration uses 'sorry'
|
||||
|
|
@ -88,7 +88,7 @@ theorem ex9 (xs : List α) (h : xs = [] → False) : Nonempty α := by
|
|||
| cons x _ => apply Nonempty.intro; assumption
|
||||
|
||||
theorem modLt (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by
|
||||
induction x, y using Nat.mod.inductionOn generalizing h with
|
||||
induction x, y using Nat.mod.inductionOn with
|
||||
| ind x y h₁ ih =>
|
||||
rw [Nat.mod_eq_sub_mod h₁.2]
|
||||
exact ih h
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ theorem eq_findSomeM_findM [Monad m] [LawfulMonad m] (p : α → m Bool) (xss :
|
|||
| cons xs xss ih =>
|
||||
rw [← ih, ← eq_findM]
|
||||
induction xs with simp
|
||||
| cons x xs ih => apply byCases_Bool_bind <;> simp [ih]
|
||||
| cons x xs ih =>
|
||||
apply byCases_Bool_bind <;> simp [ih]
|
||||
|
||||
theorem eq_findSomeM_findM' [Monad m] [LawfulMonad m] (p : α → m Bool) (xss : List (List α)) :
|
||||
(do for xs in xss do
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ by {
|
|||
|
||||
theorem tst7 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] :=
|
||||
by {
|
||||
induction xs generalizing h with
|
||||
induction xs with
|
||||
| nil => exact rfl
|
||||
| cons z zs ih => exact absurd rfl (h z zs)
|
||||
}
|
||||
|
||||
theorem tst8 {α : Type} (xs : List α) (h : (a : α) → (as : List α) → xs ≠ a :: as) : xs = [] := by {
|
||||
induction xs generalizing h;
|
||||
induction xs;
|
||||
exact rfl;
|
||||
exact absurd rfl $ h _ _
|
||||
}
|
||||
|
|
@ -75,7 +75,7 @@ theorem tst13 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by
|
|||
| _ => injection h
|
||||
|
||||
theorem tst14 (x : Tree) (h : x = Tree.leaf₁) : x.isLeaf₁ = true := by
|
||||
induction x generalizing h with
|
||||
induction x with
|
||||
| leaf₁ => rfl
|
||||
| _ => injection h
|
||||
|
||||
|
|
|
|||
|
|
@ -4,15 +4,15 @@ inductive Lex (ra : α → α → Prop) (rb : β → β → Prop) : α × β →
|
|||
|
||||
|
||||
def lexAccessible1 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by
|
||||
induction (aca a) generalizing b with
|
||||
induction aca a generalizing b with
|
||||
| intro xa aca iha =>
|
||||
induction (acb b) with
|
||||
induction acb b with
|
||||
| intro xb acb ihb =>
|
||||
apply Acc.intro (xa, xb)
|
||||
intro p lt
|
||||
cases lt with
|
||||
| left b1 b2 h => apply iha _ h -- only explicit fields are provided by default
|
||||
| @right a b1 b2 h => apply ihb b1 h -- `@` allows us to provide names to implicit fields too
|
||||
| left b1 b2 h => apply iha _ h _ (aca _ h)
|
||||
| @right a b1 b2 h => apply ihb _ h (acb _ h)
|
||||
|
||||
def lexAccessible2 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a : α) → Acc ra a) (acb : (b : β) → Acc rb b) (a : α) (b : β) : Acc (Lex ra rb) (a, b) := by
|
||||
induction (aca a) generalizing b with
|
||||
|
|
@ -22,5 +22,5 @@ def lexAccessible2 {ra : α → α → Prop} {rb : β → β → Prop} (aca : (a
|
|||
apply Acc.intro (xa, xb)
|
||||
intro p lt
|
||||
cases lt with
|
||||
| @left a1 b1 a2 b2 h => apply iha a1 h
|
||||
| right _ h => apply ihb _ h
|
||||
| @left a1 b1 a2 b2 h => apply iha _ h _ (aca _ h)
|
||||
| right _ h => apply ihb _ h (acb _ h)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ theorem ex3 (x : Nat) : 0 + x = x := by
|
|||
| succ y => skip -- Error: unsolved goals
|
||||
|
||||
theorem ex4 (x : Nat) {y : Nat} (h : y > 0) : x % y < y := by
|
||||
induction x, y using Nat.mod.inductionOn generalizing h with
|
||||
induction x, y using Nat.mod.inductionOn with
|
||||
| ind x y h₁ ih => skip -- Error: unsolved goals
|
||||
| base x y h₁ => skip -- Error: unsolved goals
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ y : Nat
|
|||
⊢ 0 + Nat.succ y = Nat.succ y
|
||||
unsolvedIndCases.lean:18:18-18:25: error: unsolved goals
|
||||
case ind
|
||||
y✝ : Nat
|
||||
h✝ : y✝ > 0
|
||||
x y : Nat
|
||||
h₁ : 0 < y ∧ y ≤ x
|
||||
ih : y > 0 → (x - y) % y < y
|
||||
|
|
@ -26,6 +28,8 @@ h : y > 0
|
|||
⊢ x % y < y
|
||||
unsolvedIndCases.lean:19:18-19:25: error: unsolved goals
|
||||
case base
|
||||
y✝ : Nat
|
||||
h✝ : y✝ > 0
|
||||
x y : Nat
|
||||
h₁ : ¬(0 < y ∧ y ≤ x)
|
||||
h : y > 0
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue