diff --git a/src/Lean/Elab/Tactic/Try.lean b/src/Lean/Elab/Tactic/Try.lean index 85a2a9ae56..8ebd70b5d7 100644 --- a/src/Lean/Elab/Tactic/Try.lean +++ b/src/Lean/Elab/Tactic/Try.lean @@ -40,7 +40,10 @@ private def isAccessible (fvarId : FVarId) : MetaM Bool := do | return false return localDecl'.fvarId == localDecl.fvarId -/-- Returns `true` if all free variables occurring in `e` are accessible. -/ +/-- +Returns `true` if all free variables occurring in `e` are accessible. Over-approximation, since +the free variable may be implicit. + -/ private def isExprAccessible (e : Expr) : MetaM Bool := do let (_, s) ← e.collectFVars |>.run {} s.fvarIds.allM isAccessible @@ -598,22 +601,28 @@ private def mkSimpleTacStx : CoreM (TSyntax `tactic) := /-! Function induction generators -/ open Try.Collector in -private def mkFunIndStx (c : FunIndCandidate) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do - if (← c.majors.allM isAccessible) then - go - else withExposedNames do - `(tactic| (expose_names; $(← go):tactic)) -where - go : MetaM (TSyntax `tactic) := do - let mut terms := #[] - for major in c.majors do - let localDecl ← major.getDecl - terms := terms.push (← `(Parser.Tactic.elimTarget| $(mkIdent localDecl.userName):term)) - let indFn ← toIdent c.funIndDeclName - `(tactic| induction $terms,* using $indFn <;> $cont) +private def mkFunIndStx (uniques : NameSet) (expr : Expr) (cont : TSyntax `tactic) : + MetaM (TSyntax `tactic) := do + let fn := expr.getAppFn.constName! + if uniques.contains fn then + -- If it is unambigous, use `fun_induction foo` without arguments + `(tactic| fun_induction $(← toIdent fn):term <;> $cont) + else + let isAccessible ← isExprAccessible expr + withExposedNames do + let stx ← PrettyPrinter.delab expr + let tac₁ ← `(tactic| fun_induction $stx <;> $cont) + -- if expr has no inaccessible names, use as is + if isAccessible then + pure tac₁ + else + -- if it has inaccessible names, still try without, in case they are all implicit + let tac₂ ← `(tactic| (expose_names; $tac₁)) + mkFirstStx #[tac₁, tac₂] private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do - let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont) + let uniques := info.funIndCandidates.uniques + let tacs ← info.funIndCandidates.calls.mapM (mkFunIndStx uniques · cont) mkFirstStx tacs /-! Main code -/ diff --git a/src/Lean/Meta/Tactic/FunIndCollect.lean b/src/Lean/Meta/Tactic/FunIndCollect.lean index 2d99f0a711..74e6488a06 100644 --- a/src/Lean/Meta/Tactic/FunIndCollect.lean +++ b/src/Lean/Meta/Tactic/FunIndCollect.lean @@ -25,12 +25,15 @@ structure Call where structure SeenCalls where /-- the full calls -/ calls : Array Expr - /-- only relevant arguments -/ - seen : Std.HashSet (Array Expr) + /-- only function name and relevant arguments -/ + seen : Std.HashSet (Name × Array Expr) instance : EmptyCollection SeenCalls where emptyCollection := ⟨#[], {}⟩ +def SeenCalls.isEmpty (sc : SeenCalls) : Bool := + sc.calls.isEmpty + def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : SeenCalls) : MetaM SeenCalls := do let some funIndInfo ← getFunIndInfo? (cases := false) declName | return calls @@ -41,8 +44,24 @@ def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : See if !arg.isFVar then return calls unless kind matches .dropped do keys := keys.push arg - if calls.seen.contains keys then return calls - return { calls := calls.calls.push e, seen := calls.seen.insert keys } + let key := (declName, keys) + if calls.seen.contains key then return calls + return { calls := calls.calls.push e, seen := calls.seen.insert key } + +/-- +Which functions have exactly one candidate application. Used by `try?` to determine whether +we can use `fun_induction foo` or need `fun_induction foo x y z`. +-/ +def SeenCalls.uniques (calls : SeenCalls) : NameSet := Id.run do + let mut seen : NameSet := {} + let mut seenTwice : NameSet := {} + for (n, _) in calls.seen do + unless seenTwice.contains n do + if seen.contains n then + seenTwice := seenTwice.insert n + else + seen := seen.insert n + return seen.filter (! seenTwice.contains ·) namespace Collector diff --git a/src/Lean/Meta/Tactic/Try/Collect.lean b/src/Lean/Meta/Tactic/Try/Collect.lean index 20e9eca006..11106b6765 100644 --- a/src/Lean/Meta/Tactic/Try/Collect.lean +++ b/src/Lean/Meta/Tactic/Try/Collect.lean @@ -9,6 +9,8 @@ import Lean.Meta.Tactic.LibrarySearch import Lean.Meta.Tactic.Util import Lean.Meta.Tactic.Grind.Cases import Lean.Meta.Tactic.Grind.EMatchTheorem +import Lean.Meta.Tactic.FunIndInfo +import Lean.Meta.Tactic.FunIndCollect namespace Lean.Meta.Try.Collector @@ -16,11 +18,6 @@ structure InductionCandidate where fvarId : FVarId val : InductiveVal -structure FunIndCandidate where - funIndDeclName : Name - majors : Array FVarId - deriving Hashable, BEq - /-- `Set` with insertion order preserved. -/ structure OrdSet (α : Type) [Hashable α] [BEq α] where elems : Array α := #[] @@ -44,8 +41,8 @@ structure Result where unfoldCandidates : OrdSet Name := {} /-- Equation function candiates. -/ eqnCandidates : OrdSet Name := {} - /-- Function induction candidates. -/ - funIndCandidates : OrdSet FunIndCandidate := {} + /-- Function induction candidates -/ + funIndCandidates : FunInd.SeenCalls := {} /-- Induction candidates. -/ indCandidates : Array InductionCandidate := #[] /-- Relevant declarations by `libSearch` -/ @@ -66,17 +63,6 @@ def saveConst (declName : Name) : M Unit := do def inCurrentModule (declName : Name) : CoreM Bool := do return ((← getEnv).getModuleIdxFor? declName).isNone -def getFunInductName (declName : Name) : Name := - declName ++ `induct - -def getFunInduct? (declName : Name) : MetaM (Option Name) := do - let .defnInfo _ ← getConstInfo declName | return none - try - let result ← realizeGlobalConstNoOverloadCore (getFunInductName declName) - return some result - catch _ => - return none - def isEligible (declName : Name) : M Bool := do if declName.hasMacroScopes then return false @@ -112,49 +98,11 @@ def visitConst (declName : Name) : M Unit := do saveConst declName saveUnfoldCandidate declName --- Horrible temporary hack: compute the mask assuming parameters appear before a variable named `motive` --- It assumes major premises appear after variables with name `case?` --- It assumes if something is not a parameter, then it is major :( --- TODO: save the mask while generating the induction principle. -def getFunIndMask? (declName : Name) (indDeclName : Name) : MetaM (Option (Array Bool)) := do - let info ← getConstInfo declName - let indInfo ← getConstInfo indDeclName - let (numParams, numMajor) ← forallTelescope indInfo.type fun xs _ => do - let mut foundCase := false - let mut foundMotive := false - let mut numParams : Nat := 0 - let mut numMajor : Nat := 0 - for x in xs do - let localDecl ← x.fvarId!.getDecl - let n := localDecl.userName - if n == `motive then - foundMotive := true - else if !foundMotive then - numParams := numParams + 1 - else if n.isStr && "case".isPrefixOf n.getString! then - foundCase := true - else if foundCase then - numMajor := numMajor + 1 - return (numParams, numMajor) - if numMajor == 0 then return none - forallTelescope info.type fun xs _ => do - if xs.size != numParams + numMajor then - return none - return some (mkArray numParams false ++ mkArray numMajor true) - -def saveFunInd (_e : Expr) (declName : Name) (args : Array Expr) : M Unit := do +def saveFunInd (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do if (← isEligible declName) then - let some funIndDeclName ← getFunInduct? declName - | saveUnfoldCandidate declName; return () - let some mask ← getFunIndMask? declName funIndDeclName | return () - if mask.size != args.size then return () - let mut majors := #[] - for arg in args, isMajor in mask do - if isMajor then - if !arg.isFVar then return () - majors := majors.push arg.fvarId! - trace[try.collect.funInd] "{funIndDeclName}, {majors.map mkFVar}" - modify fun s => { s with funIndCandidates := s.funIndCandidates.insert { majors, funIndDeclName }} + let sc := (← get).funIndCandidates + let sc' ← sc.push e declName args + modify fun s => { s with funIndCandidates := sc' } open LibrarySearch in def saveLibSearchCandidates (e : Expr) : M Unit := do @@ -170,6 +118,7 @@ def saveLibSearchCandidates (e : Expr) : M Unit := do def visitApp (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do saveEqnCandidate declName saveFunInd e declName args + saveUnfoldCandidate declName saveLibSearchCandidates e def checkInductive (localDecl : LocalDecl) : M Unit := do diff --git a/tests/lean/run/grind_constProp.lean b/tests/lean/run/grind_constProp.lean index 43732f8008..f27b425d7c 100644 --- a/tests/lean/run/grind_constProp.lean +++ b/tests/lean/run/grind_constProp.lean @@ -205,7 +205,7 @@ def evalExpr (e : Expr) : EvalM Val := do @[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by grind [UnaryOp.simplify.eq_def] -/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/ +/-- info: Try this: (fun_induction Expr.simplify) <;> grind -/ #guard_msgs (info) in example (e : Expr) : e.simplify.eval σ = e.eval σ := by try? (max := 1) @@ -304,13 +304,14 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : ( @[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by grind -/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/ +/-- info: Try this: (fun_induction join) <;> grind -/ #guard_msgs (info) in +open State in example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by try? (max := 1) @[grind] theorem State.join_le_right (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by - induction σ₁, σ₂ using State.join.induct <;> grind + fun_induction join <;> grind @[grind] theorem State.join_le_right_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₃.join σ₁ ≼ σ₂ := by grind diff --git a/tests/lean/run/grind_try_trace.lean b/tests/lean/run/grind_try_trace.lean index 1f3d73bce9..e06967a586 100644 --- a/tests/lean/run/grind_try_trace.lean +++ b/tests/lean/run/grind_try_trace.lean @@ -80,24 +80,22 @@ example : app [a, b] [c] = [a, b, c] := by /-- info: Try these: -• (induction as, bs using app.induct) <;> grind [= app] -• (induction as, bs using app.induct) <;> grind only [app] +• (fun_induction app as bs) <;> grind [= app] +• (fun_induction app as bs) <;> grind only [app] -/ #guard_msgs (info) in example : app (app as bs) cs = app as (app bs cs) := by try? -/-- -info: Try this: (induction as, bs using app.induct) <;> grind [= app] --/ +/-- info: Try this: (fun_induction app as bs) <;> grind [= app] -/ #guard_msgs (info) in example : app (app as bs) cs = app as (app bs cs) := by try? (max := 1) /-- info: Try these: -• · expose_names; induction as, bs_1 using app.induct <;> grind [= app] -• · expose_names; induction as, bs_1 using app.induct <;> grind only [app] +• · expose_names; fun_induction app as bs_1 <;> grind [= app] +• · expose_names; fun_induction app as bs_1 <;> grind only [app] -/ #guard_msgs (info) in example : app (app as bs) cs = app as (app bs cs) := by @@ -106,8 +104,8 @@ example : app (app as bs) cs = app as (app bs cs) := by /-- info: Try these: -• · expose_names; induction as, bs using app.induct <;> grind [= app] -• · expose_names; induction as, bs using app.induct <;> grind only [app] +• · expose_names; fun_induction app as bs <;> grind [= app] +• · expose_names; fun_induction app as bs <;> grind only [app] -/ #guard_msgs (info) in example : app (app as bs) cs = app as (app bs cs) := by @@ -124,8 +122,8 @@ attribute [simp] concat /-- info: Try these: -• (induction as, a using concat.induct) <;> simp_all -• (induction as, a using concat.induct) <;> simp [*] +• (fun_induction concat) <;> simp_all +• (fun_induction concat) <;> simp [*] -/ #guard_msgs (info) in example (as : List α) (a : α) : concat as a = as ++ [a] := by @@ -133,9 +131,9 @@ example (as : List α) (a : α) : concat as a = as ++ [a] := by /-- info: Try these: -• (induction as, a using concat.induct) <;> simp_all +• (fun_induction concat) <;> simp_all • · - induction as, a using concat.induct + fun_induction concat · simp · simp [*] -/ @@ -143,15 +141,28 @@ info: Try these: example (as : List α) (a : α) : concat as a = as ++ [a] := by try? -only -merge +def map (f : α → β) : List α → List β + | [] => [] + | x::xs => f x :: map f xs + +/-- +info: Try these: +• (fun_induction map) <;> grind [= map] +• (fun_induction map) <;> grind only [map] +-/ +#guard_msgs (info) in +theorem map_map (f : α → β) (g : β → γ) xs : + map g (map f xs) = map (fun x => g (f x)) xs := by + try? -- NB: Multiple calls to `xs.map`, but they differ only in ignore arguments + def foo : Nat → Nat | 0 => 1 | x+1 => foo x - 1 - /-- info: Try this: · - induction x using foo.induct + fun_induction foo · grind [= foo] · sorry -/ @@ -177,11 +188,11 @@ attribute [grind] List.length_reverse bla /-- info: Try these: -• (induction xs, ys using bla.induct) <;> grind -• (induction xs, ys using bla.induct) <;> simp_all -• (induction xs, ys using bla.induct) <;> simp [*] -• (induction xs, ys using bla.induct) <;> simp only [bla, List.length_reverse, *] -• (induction xs, ys using bla.induct) <;> grind only [List.length_reverse, bla] +• (fun_induction bla) <;> grind +• (fun_induction bla) <;> simp_all +• (fun_induction bla) <;> simp [*] +• (fun_induction bla) <;> simp only [bla, List.length_reverse, *] +• (fun_induction bla) <;> grind only [List.length_reverse, bla] -/ #guard_msgs (info) in example : (bla xs ys).length = ys.length := by