feat: try? to use fun_induction (#7082)
This PR makes `try?` use `fun_induction` instead of `induction … using foo.induct`. It uses the argument-free short-hand `fun_induction foo` if that is unambiguous. Avoids `expose_names` if not necessary by simply trying without first.
This commit is contained in:
parent
2d4c0017b8
commit
2fed93462d
5 changed files with 91 additions and 102 deletions
|
|
@ -40,7 +40,10 @@ private def isAccessible (fvarId : FVarId) : MetaM Bool := do
|
|||
| return false
|
||||
return localDecl'.fvarId == localDecl.fvarId
|
||||
|
||||
/-- Returns `true` if all free variables occurring in `e` are accessible. -/
|
||||
/--
|
||||
Returns `true` if all free variables occurring in `e` are accessible. Over-approximation, since
|
||||
the free variable may be implicit.
|
||||
-/
|
||||
private def isExprAccessible (e : Expr) : MetaM Bool := do
|
||||
let (_, s) ← e.collectFVars |>.run {}
|
||||
s.fvarIds.allM isAccessible
|
||||
|
|
@ -598,22 +601,28 @@ private def mkSimpleTacStx : CoreM (TSyntax `tactic) :=
|
|||
/-! Function induction generators -/
|
||||
|
||||
open Try.Collector in
|
||||
private def mkFunIndStx (c : FunIndCandidate) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
|
||||
if (← c.majors.allM isAccessible) then
|
||||
go
|
||||
else withExposedNames do
|
||||
`(tactic| (expose_names; $(← go):tactic))
|
||||
where
|
||||
go : MetaM (TSyntax `tactic) := do
|
||||
let mut terms := #[]
|
||||
for major in c.majors do
|
||||
let localDecl ← major.getDecl
|
||||
terms := terms.push (← `(Parser.Tactic.elimTarget| $(mkIdent localDecl.userName):term))
|
||||
let indFn ← toIdent c.funIndDeclName
|
||||
`(tactic| induction $terms,* using $indFn <;> $cont)
|
||||
private def mkFunIndStx (uniques : NameSet) (expr : Expr) (cont : TSyntax `tactic) :
|
||||
MetaM (TSyntax `tactic) := do
|
||||
let fn := expr.getAppFn.constName!
|
||||
if uniques.contains fn then
|
||||
-- If it is unambigous, use `fun_induction foo` without arguments
|
||||
`(tactic| fun_induction $(← toIdent fn):term <;> $cont)
|
||||
else
|
||||
let isAccessible ← isExprAccessible expr
|
||||
withExposedNames do
|
||||
let stx ← PrettyPrinter.delab expr
|
||||
let tac₁ ← `(tactic| fun_induction $stx <;> $cont)
|
||||
-- if expr has no inaccessible names, use as is
|
||||
if isAccessible then
|
||||
pure tac₁
|
||||
else
|
||||
-- if it has inaccessible names, still try without, in case they are all implicit
|
||||
let tac₂ ← `(tactic| (expose_names; $tac₁))
|
||||
mkFirstStx #[tac₁, tac₂]
|
||||
|
||||
private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
|
||||
let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont)
|
||||
let uniques := info.funIndCandidates.uniques
|
||||
let tacs ← info.funIndCandidates.calls.mapM (mkFunIndStx uniques · cont)
|
||||
mkFirstStx tacs
|
||||
|
||||
/-! Main code -/
|
||||
|
|
|
|||
|
|
@ -25,12 +25,15 @@ structure Call where
|
|||
structure SeenCalls where
|
||||
/-- the full calls -/
|
||||
calls : Array Expr
|
||||
/-- only relevant arguments -/
|
||||
seen : Std.HashSet (Array Expr)
|
||||
/-- only function name and relevant arguments -/
|
||||
seen : Std.HashSet (Name × Array Expr)
|
||||
|
||||
instance : EmptyCollection SeenCalls where
|
||||
emptyCollection := ⟨#[], {}⟩
|
||||
|
||||
def SeenCalls.isEmpty (sc : SeenCalls) : Bool :=
|
||||
sc.calls.isEmpty
|
||||
|
||||
def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : SeenCalls) :
|
||||
MetaM SeenCalls := do
|
||||
let some funIndInfo ← getFunIndInfo? (cases := false) declName | return calls
|
||||
|
|
@ -41,8 +44,24 @@ def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : See
|
|||
if !arg.isFVar then return calls
|
||||
unless kind matches .dropped do
|
||||
keys := keys.push arg
|
||||
if calls.seen.contains keys then return calls
|
||||
return { calls := calls.calls.push e, seen := calls.seen.insert keys }
|
||||
let key := (declName, keys)
|
||||
if calls.seen.contains key then return calls
|
||||
return { calls := calls.calls.push e, seen := calls.seen.insert key }
|
||||
|
||||
/--
|
||||
Which functions have exactly one candidate application. Used by `try?` to determine whether
|
||||
we can use `fun_induction foo` or need `fun_induction foo x y z`.
|
||||
-/
|
||||
def SeenCalls.uniques (calls : SeenCalls) : NameSet := Id.run do
|
||||
let mut seen : NameSet := {}
|
||||
let mut seenTwice : NameSet := {}
|
||||
for (n, _) in calls.seen do
|
||||
unless seenTwice.contains n do
|
||||
if seen.contains n then
|
||||
seenTwice := seenTwice.insert n
|
||||
else
|
||||
seen := seen.insert n
|
||||
return seen.filter (! seenTwice.contains ·)
|
||||
|
||||
namespace Collector
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import Lean.Meta.Tactic.LibrarySearch
|
|||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
import Lean.Meta.Tactic.FunIndInfo
|
||||
import Lean.Meta.Tactic.FunIndCollect
|
||||
|
||||
namespace Lean.Meta.Try.Collector
|
||||
|
||||
|
|
@ -16,11 +18,6 @@ structure InductionCandidate where
|
|||
fvarId : FVarId
|
||||
val : InductiveVal
|
||||
|
||||
structure FunIndCandidate where
|
||||
funIndDeclName : Name
|
||||
majors : Array FVarId
|
||||
deriving Hashable, BEq
|
||||
|
||||
/-- `Set` with insertion order preserved. -/
|
||||
structure OrdSet (α : Type) [Hashable α] [BEq α] where
|
||||
elems : Array α := #[]
|
||||
|
|
@ -44,8 +41,8 @@ structure Result where
|
|||
unfoldCandidates : OrdSet Name := {}
|
||||
/-- Equation function candiates. -/
|
||||
eqnCandidates : OrdSet Name := {}
|
||||
/-- Function induction candidates. -/
|
||||
funIndCandidates : OrdSet FunIndCandidate := {}
|
||||
/-- Function induction candidates -/
|
||||
funIndCandidates : FunInd.SeenCalls := {}
|
||||
/-- Induction candidates. -/
|
||||
indCandidates : Array InductionCandidate := #[]
|
||||
/-- Relevant declarations by `libSearch` -/
|
||||
|
|
@ -66,17 +63,6 @@ def saveConst (declName : Name) : M Unit := do
|
|||
def inCurrentModule (declName : Name) : CoreM Bool := do
|
||||
return ((← getEnv).getModuleIdxFor? declName).isNone
|
||||
|
||||
def getFunInductName (declName : Name) : Name :=
|
||||
declName ++ `induct
|
||||
|
||||
def getFunInduct? (declName : Name) : MetaM (Option Name) := do
|
||||
let .defnInfo _ ← getConstInfo declName | return none
|
||||
try
|
||||
let result ← realizeGlobalConstNoOverloadCore (getFunInductName declName)
|
||||
return some result
|
||||
catch _ =>
|
||||
return none
|
||||
|
||||
def isEligible (declName : Name) : M Bool := do
|
||||
if declName.hasMacroScopes then
|
||||
return false
|
||||
|
|
@ -112,49 +98,11 @@ def visitConst (declName : Name) : M Unit := do
|
|||
saveConst declName
|
||||
saveUnfoldCandidate declName
|
||||
|
||||
-- Horrible temporary hack: compute the mask assuming parameters appear before a variable named `motive`
|
||||
-- It assumes major premises appear after variables with name `case?`
|
||||
-- It assumes if something is not a parameter, then it is major :(
|
||||
-- TODO: save the mask while generating the induction principle.
|
||||
def getFunIndMask? (declName : Name) (indDeclName : Name) : MetaM (Option (Array Bool)) := do
|
||||
let info ← getConstInfo declName
|
||||
let indInfo ← getConstInfo indDeclName
|
||||
let (numParams, numMajor) ← forallTelescope indInfo.type fun xs _ => do
|
||||
let mut foundCase := false
|
||||
let mut foundMotive := false
|
||||
let mut numParams : Nat := 0
|
||||
let mut numMajor : Nat := 0
|
||||
for x in xs do
|
||||
let localDecl ← x.fvarId!.getDecl
|
||||
let n := localDecl.userName
|
||||
if n == `motive then
|
||||
foundMotive := true
|
||||
else if !foundMotive then
|
||||
numParams := numParams + 1
|
||||
else if n.isStr && "case".isPrefixOf n.getString! then
|
||||
foundCase := true
|
||||
else if foundCase then
|
||||
numMajor := numMajor + 1
|
||||
return (numParams, numMajor)
|
||||
if numMajor == 0 then return none
|
||||
forallTelescope info.type fun xs _ => do
|
||||
if xs.size != numParams + numMajor then
|
||||
return none
|
||||
return some (mkArray numParams false ++ mkArray numMajor true)
|
||||
|
||||
def saveFunInd (_e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
|
||||
def saveFunInd (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
|
||||
if (← isEligible declName) then
|
||||
let some funIndDeclName ← getFunInduct? declName
|
||||
| saveUnfoldCandidate declName; return ()
|
||||
let some mask ← getFunIndMask? declName funIndDeclName | return ()
|
||||
if mask.size != args.size then return ()
|
||||
let mut majors := #[]
|
||||
for arg in args, isMajor in mask do
|
||||
if isMajor then
|
||||
if !arg.isFVar then return ()
|
||||
majors := majors.push arg.fvarId!
|
||||
trace[try.collect.funInd] "{funIndDeclName}, {majors.map mkFVar}"
|
||||
modify fun s => { s with funIndCandidates := s.funIndCandidates.insert { majors, funIndDeclName }}
|
||||
let sc := (← get).funIndCandidates
|
||||
let sc' ← sc.push e declName args
|
||||
modify fun s => { s with funIndCandidates := sc' }
|
||||
|
||||
open LibrarySearch in
|
||||
def saveLibSearchCandidates (e : Expr) : M Unit := do
|
||||
|
|
@ -170,6 +118,7 @@ def saveLibSearchCandidates (e : Expr) : M Unit := do
|
|||
def visitApp (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
|
||||
saveEqnCandidate declName
|
||||
saveFunInd e declName args
|
||||
saveUnfoldCandidate declName
|
||||
saveLibSearchCandidates e
|
||||
|
||||
def checkInductive (localDecl : LocalDecl) : M Unit := do
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ def evalExpr (e : Expr) : EvalM Val := do
|
|||
@[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by
|
||||
grind [UnaryOp.simplify.eq_def]
|
||||
|
||||
/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/
|
||||
/-- info: Try this: (fun_induction Expr.simplify) <;> grind -/
|
||||
#guard_msgs (info) in
|
||||
example (e : Expr) : e.simplify.eval σ = e.eval σ := by
|
||||
try? (max := 1)
|
||||
|
|
@ -304,13 +304,14 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (
|
|||
@[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by
|
||||
grind
|
||||
|
||||
/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/
|
||||
/-- info: Try this: (fun_induction join) <;> grind -/
|
||||
#guard_msgs (info) in
|
||||
open State in
|
||||
example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
|
||||
try? (max := 1)
|
||||
|
||||
@[grind] theorem State.join_le_right (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
|
||||
induction σ₁, σ₂ using State.join.induct <;> grind
|
||||
fun_induction join <;> grind
|
||||
|
||||
@[grind] theorem State.join_le_right_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₃.join σ₁ ≼ σ₂ := by
|
||||
grind
|
||||
|
|
|
|||
|
|
@ -80,24 +80,22 @@ example : app [a, b] [c] = [a, b, c] := by
|
|||
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction as, bs using app.induct) <;> grind [= app]
|
||||
• (induction as, bs using app.induct) <;> grind only [app]
|
||||
• (fun_induction app as bs) <;> grind [= app]
|
||||
• (fun_induction app as bs) <;> grind only [app]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example : app (app as bs) cs = app as (app bs cs) := by
|
||||
try?
|
||||
|
||||
/--
|
||||
info: Try this: (induction as, bs using app.induct) <;> grind [= app]
|
||||
-/
|
||||
/-- info: Try this: (fun_induction app as bs) <;> grind [= app] -/
|
||||
#guard_msgs (info) in
|
||||
example : app (app as bs) cs = app as (app bs cs) := by
|
||||
try? (max := 1)
|
||||
|
||||
/--
|
||||
info: Try these:
|
||||
• · expose_names; induction as, bs_1 using app.induct <;> grind [= app]
|
||||
• · expose_names; induction as, bs_1 using app.induct <;> grind only [app]
|
||||
• · expose_names; fun_induction app as bs_1 <;> grind [= app]
|
||||
• · expose_names; fun_induction app as bs_1 <;> grind only [app]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example : app (app as bs) cs = app as (app bs cs) := by
|
||||
|
|
@ -106,8 +104,8 @@ example : app (app as bs) cs = app as (app bs cs) := by
|
|||
|
||||
/--
|
||||
info: Try these:
|
||||
• · expose_names; induction as, bs using app.induct <;> grind [= app]
|
||||
• · expose_names; induction as, bs using app.induct <;> grind only [app]
|
||||
• · expose_names; fun_induction app as bs <;> grind [= app]
|
||||
• · expose_names; fun_induction app as bs <;> grind only [app]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example : app (app as bs) cs = app as (app bs cs) := by
|
||||
|
|
@ -124,8 +122,8 @@ attribute [simp] concat
|
|||
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction as, a using concat.induct) <;> simp_all
|
||||
• (induction as, a using concat.induct) <;> simp [*]
|
||||
• (fun_induction concat) <;> simp_all
|
||||
• (fun_induction concat) <;> simp [*]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example (as : List α) (a : α) : concat as a = as ++ [a] := by
|
||||
|
|
@ -133,9 +131,9 @@ example (as : List α) (a : α) : concat as a = as ++ [a] := by
|
|||
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction as, a using concat.induct) <;> simp_all
|
||||
• (fun_induction concat) <;> simp_all
|
||||
• ·
|
||||
induction as, a using concat.induct
|
||||
fun_induction concat
|
||||
· simp
|
||||
· simp [*]
|
||||
-/
|
||||
|
|
@ -143,15 +141,28 @@ info: Try these:
|
|||
example (as : List α) (a : α) : concat as a = as ++ [a] := by
|
||||
try? -only -merge
|
||||
|
||||
def map (f : α → β) : List α → List β
|
||||
| [] => []
|
||||
| x::xs => f x :: map f xs
|
||||
|
||||
/--
|
||||
info: Try these:
|
||||
• (fun_induction map) <;> grind [= map]
|
||||
• (fun_induction map) <;> grind only [map]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
theorem map_map (f : α → β) (g : β → γ) xs :
|
||||
map g (map f xs) = map (fun x => g (f x)) xs := by
|
||||
try? -- NB: Multiple calls to `xs.map`, but they differ only in ignore arguments
|
||||
|
||||
|
||||
def foo : Nat → Nat
|
||||
| 0 => 1
|
||||
| x+1 => foo x - 1
|
||||
|
||||
|
||||
/--
|
||||
info: Try this: ·
|
||||
induction x using foo.induct
|
||||
fun_induction foo
|
||||
· grind [= foo]
|
||||
· sorry
|
||||
-/
|
||||
|
|
@ -177,11 +188,11 @@ attribute [grind] List.length_reverse bla
|
|||
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction xs, ys using bla.induct) <;> grind
|
||||
• (induction xs, ys using bla.induct) <;> simp_all
|
||||
• (induction xs, ys using bla.induct) <;> simp [*]
|
||||
• (induction xs, ys using bla.induct) <;> simp only [bla, List.length_reverse, *]
|
||||
• (induction xs, ys using bla.induct) <;> grind only [List.length_reverse, bla]
|
||||
• (fun_induction bla) <;> grind
|
||||
• (fun_induction bla) <;> simp_all
|
||||
• (fun_induction bla) <;> simp [*]
|
||||
• (fun_induction bla) <;> simp only [bla, List.length_reverse, *]
|
||||
• (fun_induction bla) <;> grind only [List.length_reverse, bla]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example : (bla xs ys).length = ys.length := by
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue