lean4-htt/src/Lean/Meta/Tactic/FunIndCollect.lean
Joachim Breitner e575736cae
feat: fun_induction to unfold function application in the goal (#8104)
This PR makes `fun_induction` and `fun_cases` (try to) unfold the
function application of interest in the goal. The old behavior can be
enabled with `set_option tactic.fun_induction.unfolding false`. For
`fun_cases` this does not work yet when the function’s result type
depends on one of the arguments, see issue #8296.
2025-05-13 09:37:39 +00:00

115 lines
3.7 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
prelude
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.FunIndInfo
/-!
Support for collecting function calls that could be used for `fun_induction` or `fun_cases`.
Used by `fun_induction foo`, and the `SeenCalls` structure is also used by `try?`.
-/
namespace Lean.Meta.FunInd
/--
A function call paired with the data used to tell apart calls that differ only in
dropped arguments.
-/
structure Call where
  /-- The full function call -/
  expr : Expr
  /-- Used to avoid adding calls that differ only in dropped arguments -/
  relevantArgs : Expr
deriving Hashable, BEq
/--
Accumulator for the calls collected so far.

`calls` holds the call expressions in the order they were added, while `seen` holds the
de-duplication keys used by `SeenCalls.push` to skip calls that were already recorded.
-/
structure SeenCalls where
  /-- The full call expressions, in insertion order. -/
  calls : Array Expr
  /-- De-duplication keys: the function name together with the relevant arguments only. -/
  seen : Std.HashSet (Name × Array Expr)
/-- The empty collection `{}`: no recorded calls and no de-duplication keys. -/
instance : EmptyCollection SeenCalls :=
  ⟨{ calls := #[], seen := {} }⟩
/-- Returns `true` when no call has been recorded yet. -/
def SeenCalls.isEmpty (sc : SeenCalls) : Bool :=
  Array.isEmpty sc.calls
/--
Records the call `e` (an application of `funIndInfo.funName` to `args`) in `calls`, unless
it is not a usable candidate or an equivalent call was recorded before.

The collection is returned unchanged when
* the number of arguments differs from the number of parameters in `funIndInfo`,
* an argument in a `.target` position is not a free variable, or
* a call with the same function name and the same non-`.dropped` arguments is already present.
-/
def SeenCalls.push (e : Expr) (funIndInfo : FunIndInfo) (args : Array Expr) (calls : SeenCalls) :
    MetaM SeenCalls := do
  -- The parameter kinds can only be matched up with an application of the exact arity.
  unless args.size == funIndInfo.params.size do
    return calls
  -- Arguments that take part in the de-duplication key (everything except `.dropped`).
  let mut relevant : Array Expr := #[]
  for actual in args, paramKind in funIndInfo.params do
    match paramKind with
    | .target =>
      -- Targets have to be free variables, otherwise the call is not a candidate.
      unless actual.isFVar do
        return calls
      relevant := relevant.push actual
    | .dropped =>
      pure ()
    | _ =>
      relevant := relevant.push actual
  let dedupKey := (funIndInfo.funName, relevant)
  if calls.seen.contains dedupKey then
    return calls
  else
    return { calls := calls.calls.push e, seen := calls.seen.insert dedupKey }
/--
The names of the functions that have exactly one candidate application, i.e. that occur in
exactly one de-duplication key of `calls.seen`. Used by `try?` to determine whether
we can use `fun_induction foo` or need `fun_induction foo x y z`.
-/
def SeenCalls.uniques (calls : SeenCalls) : NameSet := Id.run do
  -- Names encountered at least once, and names encountered at least twice.
  let mut atLeastOnce : NameSet := {}
  let mut repeated : NameSet := {}
  for (fnName, _) in calls.seen do
    -- Already known to be ambiguous, nothing more to learn about this name.
    if repeated.contains fnName then
      continue
    if atLeastOnce.contains fnName then
      repeated := repeated.insert fnName
    else
      atLeastOnce := atLeastOnce.insert fnName
  return atLeastOnce.filter fun fnName => !repeated.contains fnName
namespace Collector
/--
The collector monad: read-only access to the `FunIndInfo` of the function being searched for,
and a mutable `SeenCalls` accumulator, on top of `MetaM`.
-/
abbrev M := ReaderT FunIndInfo <| StateRefT SeenCalls MetaM
/--
Offers the call `e` (with arguments `args`) to the `SeenCalls` state via `SeenCalls.push`
and stores the resulting collection back into the state.
-/
def saveFunInd (e : Expr) (funIndInfo : FunIndInfo) (args : Array Expr) : M Unit := do
  let seenSoFar ← get
  let updated ← seenSoFar.push e funIndInfo args
  set updated
/-- Handles an application of the function of interest: simply forwards to `saveFunInd`. -/
def visitApp (e : Expr) (funIndInfo : FunIndInfo) (args : Array Expr) : M Unit :=
  saveFunInd e funIndInfo args
/-- The set of subexpressions that `visit` has already traversed. -/
unsafe abbrev Cache := PtrSet Expr
/--
Traverses `e` and records every application of the function of interest (the `funName` of the
`FunIndInfo` in the reader context) through `visitApp`.

Subterms already contained in the `Cache` are skipped. Applications that contain loose bound
variables are not recorded, but their arguments are still traversed.
-/
unsafe def visit (e : Expr) : StateRefT Cache M Unit := do
  -- Each subterm is processed at most once.
  if (← get).contains e then
    return
  modify fun cache => cache.insert e
  match e with
  | .app .. =>
    let fn := e.getAppFn
    let args := e.getAppArgs
    if let .const declName _ := fn then
      let needle ← read
      -- TODO: We can allow them in `.dropped` arguments
      if declName == needle.funName && !e.hasLooseBVars then
        visitApp e needle args
    else
      visit fn
    -- The arguments are always traversed, as they may contain further calls.
    args.forM visit
  | .forallE _ dom body _ =>
    visit dom
    visit body
  | .lam _ dom body _ =>
    visit dom
    visit body
  | .letE _ type val body _ =>
    visit type
    visit val
    visit body
  | .mdata _ inner =>
    visit inner
  | .proj _ _ struct =>
    visit struct
  | .const _ _ =>
    pure ()
  | _ =>
    return ()
/--
Collects, in the context of the goal `mvarId`, the applications of the function described by
`needle` that `visit` records, and returns them in the order in which they were found.

The local context is traversed first (skipping auxiliary declarations; the value of a
declaration that has one, the type otherwise), then the goal type.
-/
unsafe def main (needle : FunIndInfo) (mvarId : MVarId) : MetaM (Array Expr) := mvarId.withContext do
  -- Run the traversal with an empty visited-cache, the needle as reader context,
  -- and an empty `SeenCalls` accumulator.
  let (_, found) ← go |>.run mkPtrSet |>.run needle |>.run {}
  return found.calls
where
  go : StateRefT Cache M Unit := do
    for localDecl in (← getLCtx) do
      if localDecl.isAuxDecl then
        continue
      match localDecl.value? with
      | some val => visit val
      | none => visit localDecl.type
    visit (← mvarId.getType)
end Collector
/--
Returns the applications of the function described by `needle` found by `Collector.main`
for the goal `mvarId`. This is the safe entry point wrapping the `unsafe` collector.
-/
def collect (needle : FunIndInfo) (mvarId : MVarId) : MetaM (Array Expr) :=
  unsafe Collector.main needle mvarId
end Lean.Meta.FunInd