From 96c6f9dc96218ac06d14965dd81d2cc8a7752e3b Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Sun, 16 Feb 2025 11:59:56 +0100 Subject: [PATCH] feat: fun_induction and fun_cases tactics (#7069) This PR adds the `fun_induction` and `fun_cases` tactics, which add convenience around using functional induction and functional cases principles. ``` fun_induction foo x y z ``` elaborates `foo x y z`, then looks up `foo.induct`, and then essentially does ``` induction z using foo.induct y ``` including and in particular figuring out which arguments are parameters, targets or dropped. This only works for non-mutual functions so far. Likewise there is the `fun_cases` tactic using `foo.fun_cases`. --- src/Init/Tactics.lean | 40 ++ .../Elab/PreDefinition/Structural/Basic.lean | 10 + src/Lean/Elab/Tactic/Basic.lean | 8 + src/Lean/Elab/Tactic/Induction.lean | 223 ++++++--- src/Lean/Meta/Tactic/FunInd.lean | 103 +++-- src/Lean/Meta/Tactic/FunIndInfo.lean | 76 ++++ tests/lean/run/funInduction.lean | 430 ++++++++++++++++++ 7 files changed, 788 insertions(+), 102 deletions(-) create mode 100644 src/Lean/Meta/Tactic/FunIndInfo.lean create mode 100644 tests/lean/run/funInduction.lean diff --git a/src/Init/Tactics.lean b/src/Init/Tactics.lean index bccc46a86e..cfaee92934 100644 --- a/src/Init/Tactics.lean +++ b/src/Init/Tactics.lean @@ -899,6 +899,46 @@ You can use `with` to provide the variables names for each constructor. -/ syntax (name := cases) "cases " casesTarget,+ (" using " term)? (inductionAlts)? : tactic +/-- +The `fun_induction` tactic is a convenience wrapper of the `induction` tactic when using a functional +induction principle. + +The tactic invocation +``` +fun_induction f x₁ ... xₙ y₁ ... yₘ +``` +where `f` is a function defined by non-mutual structural or well-founded recursion, is equivalent to +``` +induction y₁, ... yₘ using f.induct x₁ ... xₙ +``` +where the arguments of `f` are used as arguments to `f.induct` or targets of the induction, as +appropriate. + +The forms `fun_induction f x y generalizing z₁ ... zₙ` and +`fun_induction f x y with | case1 => tac₁ | case2 x' ih => tac₂` work like with `induction.` +-/ +syntax (name := funInduction) "fun_induction " term + (" generalizing" (ppSpace colGt term:max)+)? (inductionAlts)? : tactic + +/-- +The `fun_cass` tactic is a convenience wrapper of the `cases` tactic when using a functional +cases principle. + +The tactic invocation +``` +fun_cases f x ... y ...` +``` +is equivalent to +``` +cases y, ... using f.fun_cases x ... +``` +where the arguments of `f` are used as arguments to `f.fun_cases` or targets of the case analysis, as +appropriate. + +The form `fun_cases f x y with | case1 => tac₁ | case2 x' ih => tac₂` works like with `cases`. +-/ +syntax (name := funCases) "fun_cases " term (inductionAlts)? : tactic + /-- `rename_i x_1 ... x_n` renames the last `n` inaccessible names using the given names. -/ syntax (name := renameI) "rename_i" (ppSpace colGt binderIdent)+ : tactic diff --git a/src/Lean/Elab/PreDefinition/Structural/Basic.lean b/src/Lean/Elab/PreDefinition/Structural/Basic.lean index 73b8a605e6..36ad05395f 100644 --- a/src/Lean/Elab/PreDefinition/Structural/Basic.lean +++ b/src/Lean/Elab/PreDefinition/Structural/Basic.lean @@ -66,6 +66,16 @@ The number of indices in the array. def Positions.numIndices (positions : Positions) : Nat := positions.foldl (fun s poss => s + poss.size) 0 +/-- +`positions.inverse[k] = i` means that function `i` has type k +-/ +def Positions.inverse (positions : Positions) : Array Nat := Id.run do + let mut r := mkArray positions.numIndices 0 + for _h : i in [:positions.size] do + for k in positions[i] do + r := r.set! k i + return r + /-- Groups the `xs` by their `f` value, and puts these groups into the order given by `ys`. -/ diff --git a/src/Lean/Elab/Tactic/Basic.lean b/src/Lean/Elab/Tactic/Basic.lean index 2c99643403..eac29e0e8b 100644 --- a/src/Lean/Elab/Tactic/Basic.lean +++ b/src/Lean/Elab/Tactic/Basic.lean @@ -269,6 +269,10 @@ def done : TacticM Unit := do Term.reportUnsolvedGoals gs throwAbortTactic +/-- +Runs `x` with only the first unsolved goal as the goal. +Fails if there are no goal to be solved. +-/ def focus (x : TacticM α) : TacticM α := do let mvarId :: mvarIds ← getUnsolvedGoals | throwNoGoalsToBeSolved setGoals [mvarId] @@ -277,6 +281,10 @@ def focus (x : TacticM α) : TacticM α := do setGoals (mvarIds' ++ mvarIds) pure a +/-- +Runs `tactic` with only the first unsolved goal as the goal, and expects it leave no goals. +Fails if there are no goal to be solved. +-/ def focusAndDone (tactic : TacticM α) : TacticM α := focus do let a ← tactic diff --git a/src/Lean/Elab/Tactic/Induction.lean b/src/Lean/Elab/Tactic/Induction.lean index 28529317f2..b2d690ed72 100644 --- a/src/Lean/Elab/Tactic/Induction.lean +++ b/src/Lean/Elab/Tactic/Induction.lean @@ -10,6 +10,7 @@ import Lean.Parser.Term import Lean.Meta.RecursorInfo import Lean.Meta.CollectMVars import Lean.Meta.Tactic.ElimInfo +import Lean.Meta.Tactic.FunIndInfo import Lean.Meta.Tactic.Induction import Lean.Meta.Tactic.Cases import Lean.Meta.GeneralizeVars @@ -547,31 +548,32 @@ private def expandInductionAlts? (inductionAlts : Syntax) : Option Syntax := Id. else none +private def inductionAltsPos (stx : Syntax) : Nat := + if stx.getKind == ``Lean.Parser.Tactic.induction then + 4 + else if stx.getKind == ``Lean.Parser.Tactic.cases then + 3 + else if stx.getKind == ``Lean.Parser.Tactic.funInduction then + 3 + else if stx.getKind == ``Lean.Parser.Tactic.funCases then + 2 + else + panic! "inductionAltsSyntaxPos: Unexpected syntax kind {stx.getKind}" + /-- Expand ``` syntax "induction " term,+ (" using " ident)? ("generalizing " (colGt term:max)+)? (inductionAlts)? : tactic ``` -if `inductionAlts` has an alternative with multiple LHSs. +if `inductionAlts` has an alternative with multiple LHSs, and likewise for +`cases`, `fun_induction`, `fun_cases`. -/ private def expandInduction? (induction : Syntax) : Option Syntax := do - let optInductionAlts := induction[4] + let inductionAltsPos := inductionAltsPos induction + let optInductionAlts := induction[inductionAltsPos] guard <| !optInductionAlts.isNone let inductionAlts' ← expandInductionAlts? optInductionAlts[0] - return induction.setArg 4 (mkNullNode #[inductionAlts']) - -/-- -Expand -``` -syntax "cases " casesTarget,+ (" using " ident)? (inductionAlts)? : tactic -``` -if `inductionAlts` has an alternative with multiple LHSs. --/ -private def expandCases? (induction : Syntax) : Option Syntax := do - let optInductionAlts := induction[3] - guard <| !optInductionAlts.isNone - let inductionAlts' ← expandInductionAlts? optInductionAlts[0] - return induction.setArg 3 (mkNullNode #[inductionAlts']) + return induction.setArg inductionAltsPos (mkNullNode #[inductionAlts']) /-- We may have at most one `| _ => ...` (wildcard alternative), and it must not set variable names. @@ -683,6 +685,43 @@ private def generalizeTargets (exprs : Array Expr) : TacticM (Array Expr) := do else return exprs +def checkInductionTargets (targets : Array Expr) : MetaM Unit := do + let mut foundFVars : FVarIdSet := {} + for target in targets do + unless target.isFVar do + throwError "index in target's type is not a variable (consider using the `cases` tactic instead){indentExpr target}" + if foundFVars.contains target.fvarId! then + throwError "target (or one of its indices) occurs more than once{indentExpr target}" + foundFVars := foundFVars.insert target.fvarId! + +/-- +The code path shared between `induction` and `fun_induct`; when we already have an `elimInfo` +and the `targets` contains the implicit targets +-/ +private def evalInductionCore (stx : Syntax) (elimInfo : ElimInfo) (targets : Array Expr) : TacticM Unit := do + let mvarId ← getMainGoal + -- save initial info before main goal is reassigned + let initInfo ← mkTacticInfo (← getMCtx) (← getUnsolvedGoals) (← getRef) + let tag ← mvarId.getTag + mvarId.withContext do + checkInductionTargets targets + let targetFVarIds := targets.map (·.fvarId!) + let (n, mvarId) ← generalizeVars mvarId stx targets + mvarId.withContext do + let result ← withRef stx[1] do -- use target position as reference + ElimApp.mkElimApp elimInfo targets tag + trace[Elab.induction] "elimApp: {result.elimApp}" + ElimApp.setMotiveArg mvarId result.motive targetFVarIds + -- drill down into old and new syntax: allow reuse of an rhs only if everything before it is + -- unchanged + -- everything up to the alternatives must be unchanged for reuse + Term.withNarrowedArgTacticReuse (stx := stx) (argIdx := inductionAltsPos stx) fun optInductionAlts => do + withAltsOfOptInductionAlts optInductionAlts fun alts? => do + let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts + mvarId.assign result.elimApp + ElimApp.evalAlts elimInfo result.alts optPreTac alts? initInfo (numGeneralized := n) (toClear := targetFVarIds) + appendGoals result.others.toList + @[builtin_tactic Lean.Parser.Tactic.induction, builtin_incremental] def evalInduction : Tactic := fun stx => match expandInduction? stx with @@ -691,38 +730,57 @@ def evalInduction : Tactic := fun stx => let targets ← withMainContext <| stx[1].getSepArgs.mapM (elabTerm · none) let targets ← generalizeTargets targets let elimInfo ← withMainContext <| getElimNameInfo stx[2] targets (induction := true) - let mvarId ← getMainGoal - -- save initial info before main goal is reassigned - let initInfo ← mkTacticInfo (← getMCtx) (← getUnsolvedGoals) (← getRef) - let tag ← mvarId.getTag - mvarId.withContext do - let targets ← addImplicitTargets elimInfo targets - checkTargets targets - let targetFVarIds := targets.map (·.fvarId!) - let (n, mvarId) ← generalizeVars mvarId stx targets - mvarId.withContext do - let result ← withRef stx[1] do -- use target position as reference - ElimApp.mkElimApp elimInfo targets tag - trace[Elab.induction] "elimApp: {result.elimApp}" - ElimApp.setMotiveArg mvarId result.motive targetFVarIds - -- drill down into old and new syntax: allow reuse of an rhs only if everything before it is - -- unchanged - -- everything up to the alternatives must be unchanged for reuse - Term.withNarrowedArgTacticReuse (stx := stx) (argIdx := 4) fun optInductionAlts => do - withAltsOfOptInductionAlts optInductionAlts fun alts? => do - let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts - mvarId.assign result.elimApp - ElimApp.evalAlts elimInfo result.alts optPreTac alts? initInfo (numGeneralized := n) (toClear := targetFVarIds) - appendGoals result.others.toList -where - checkTargets (targets : Array Expr) : MetaM Unit := do - let mut foundFVars : FVarIdSet := {} - for target in targets do - unless target.isFVar do - throwError "index in target's type is not a variable (consider using the `cases` tactic instead){indentExpr target}" - if foundFVars.contains target.fvarId! then - throwError "target (or one of its indices) occurs more than once{indentExpr target}" - foundFVars := foundFVars.insert target.fvarId! + let targets ← withMainContext <| addImplicitTargets elimInfo targets + evalInductionCore stx elimInfo targets + +/-- +Elaborates the `foo args` of `fun_induction` or `fun_cases`, returning the `ElabInfo` and targets. +-/ +private def elabFunTarget (cases : Bool) (stx : Syntax) : TacticM (ElimInfo × Array Expr) := do + withRef stx <| withMainContext do + let funCall ← elabTerm stx none + funCall.withApp fun fn funArgs => do + let .const fnName fnUs := fn | + throwError "expected application headed by a function constant" + let some funIndInfo ← getFunIndInfo? cases fnName | + let theoremKind := if cases then "induction" else "cases" + throwError "no functional {theoremKind} theorem for '{.ofConstName fnName}', or function is mutually recursive " + if funArgs.size != funIndInfo.params.size then + throwError "Expected fully applied application of '{.ofConstName fnName}' with \ + {funIndInfo.params.size} arguments, but found {funArgs.size} arguments" + let mut params := #[] + let mut targets := #[] + let mut us := #[] + for u in fnUs, b in funIndInfo.levelMask do + if b then + us := us.push u + for a in funArgs, kind in funIndInfo.params do + match kind with + | .dropped => pure () + | .param => params := params.push a + | .target => targets := targets.push a + if cases then + trace[Elab.cases] "us: {us}\nparams: {params}\ntargets: {targets}" + else + trace[Elab.induction] "us: {us}\nparams: {params}\ntargets: {targets}" + + let elimExpr := mkAppN (.const funIndInfo.funIndName us.toList) params + let elimInfo ← getElimExprInfo elimExpr + unless targets.size = elimInfo.targetsPos.size do + let tacName := if cases then "fun_cases" else "fun_induction" + throwError "{tacName} got confused trying to use \ + {.ofConstName funIndInfo.funIndName}. Does it take {targets.size} or \ + {elimInfo.targetsPos.size} targets?" + return (elimInfo, targets) + +@[builtin_tactic Lean.Parser.Tactic.funInduction, builtin_incremental] +def evalFunInduction : Tactic := fun stx => + match expandInduction? stx with + | some stxNew => withMacroExpansion stx stxNew <| evalTactic stxNew + | _ => focus do + let (elimInfo, targets) ← elabFunTarget (cases := false) stx[1] + let targets ← generalizeTargets targets + evalInductionCore stx elimInfo targets def elabCasesTargets (targets : Array Syntax) : TacticM (Array Expr × Array (Ident × FVarId)) := withMainContext do @@ -736,7 +794,7 @@ def elabCasesTargets (targets : Array Syntax) : TacticM (Array Expr × Array (Id pure (some target[0][0].getId) let expr ← elabTerm target[1] none args := args.push { expr, hName? : GeneralizeArg } - if (← withMainContext <| args.anyM fun arg => shouldGeneralizeTarget arg.expr <||> pure arg.hName?.isSome) then + if (← args.anyM fun arg => shouldGeneralizeTarget arg.expr <||> pure arg.hName?.isSome) then liftMetaTacticAux fun mvarId => do let argsToGeneralize ← args.filterM fun arg => shouldGeneralizeTarget arg.expr <||> pure arg.hName?.isSome let (fvarIdsNew, mvarId) ← mvarId.generalize argsToGeneralize @@ -755,38 +813,55 @@ def elabCasesTargets (targets : Array Syntax) : TacticM (Array Expr × Array (Id else return (args.map (·.expr), #[]) +/-- +The code path shared between `cases` and `fun_cases`; when we already have an `elimInfo` +and the `targets` contains the implicit targets +-/ +def evalCasesCore (stx : Syntax) (elimInfo : ElimInfo) (targets : Array Expr) + (toTag : Array (Ident × FVarId) := #[]) : TacticM Unit := do + let targetRef := stx[1] + let mvarId ← getMainGoal + -- save initial info before main goal is reassigned + let initInfo ← mkTacticInfo (← getMCtx) (← getUnsolvedGoals) (← getRef) + let tag ← mvarId.getTag + mvarId.withContext do + let result ← withRef targetRef <| ElimApp.mkElimApp elimInfo targets tag + let elimArgs := result.elimApp.getAppArgs + let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i]! + let motiveType ← inferType elimArgs[elimInfo.motivePos]! + let mvarId ← generalizeTargetsEq mvarId motiveType targets + let (targetsNew, mvarId) ← mvarId.introN targets.size + mvarId.withContext do + ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos]!.mvarId! targetsNew + mvarId.assign result.elimApp + -- drill down into old and new syntax: allow reuse of an rhs only if everything before it is + -- unchanged + -- everything up to the alternatives must be unchanged for reuse + Term.withNarrowedArgTacticReuse (stx := stx) (argIdx := inductionAltsPos stx) fun optInductionAlts => do + withAltsOfOptInductionAlts optInductionAlts fun alts => do + let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts + ElimApp.evalAlts elimInfo result.alts optPreTac alts initInfo + (numEqs := targets.size) (toClear := targetsNew) (toTag := toTag) + @[builtin_tactic Lean.Parser.Tactic.cases, builtin_incremental] def evalCases : Tactic := fun stx => - match expandCases? stx with + match expandInduction? stx with | some stxNew => withMacroExpansion stx stxNew <| evalTactic stxNew | _ => focus do -- leading_parser nonReservedSymbol "cases " >> sepBy1 (group majorPremise) ", " >> usingRec >> optInductionAlts let (targets, toTag) ← elabCasesTargets stx[1].getSepArgs - let targetRef := stx[1] let elimInfo ← withMainContext <| getElimNameInfo stx[2] targets (induction := false) - let mvarId ← getMainGoal - -- save initial info before main goal is reassigned - let initInfo ← mkTacticInfo (← getMCtx) (← getUnsolvedGoals) (← getRef) - let tag ← mvarId.getTag - mvarId.withContext do - let targets ← addImplicitTargets elimInfo targets - let result ← withRef targetRef <| ElimApp.mkElimApp elimInfo targets tag - let elimArgs := result.elimApp.getAppArgs - let targets ← elimInfo.targetsPos.mapM fun i => instantiateMVars elimArgs[i]! - let motiveType ← inferType elimArgs[elimInfo.motivePos]! - let mvarId ← generalizeTargetsEq mvarId motiveType targets - let (targetsNew, mvarId) ← mvarId.introN targets.size - mvarId.withContext do - ElimApp.setMotiveArg mvarId elimArgs[elimInfo.motivePos]!.mvarId! targetsNew - mvarId.assign result.elimApp - -- drill down into old and new syntax: allow reuse of an rhs only if everything before it is - -- unchanged - -- everything up to the alternatives must be unchanged for reuse - Term.withNarrowedArgTacticReuse (stx := stx) (argIdx := 3) fun optInductionAlts => do - withAltsOfOptInductionAlts optInductionAlts fun alts => do - let optPreTac := getOptPreTacOfOptInductionAlts optInductionAlts - ElimApp.evalAlts elimInfo result.alts optPreTac alts initInfo - (numEqs := targets.size) (toClear := targetsNew) (toTag := toTag) + let targets ← withMainContext <| addImplicitTargets elimInfo targets + evalCasesCore stx elimInfo targets toTag + +@[builtin_tactic Lean.Parser.Tactic.funCases, builtin_incremental] +def evalFunCases : Tactic := fun stx => + match expandInduction? stx with + | some stxNew => withMacroExpansion stx stxNew <| evalTactic stxNew + | _ => focus do + let (elimInfo, targets) ← elabFunTarget (cases := true) stx[1] + let targets ← generalizeTargets targets + evalCasesCore stx elimInfo targets builtin_initialize registerTraceClass `Elab.cases diff --git a/src/Lean/Meta/Tactic/FunInd.lean b/src/Lean/Meta/Tactic/FunInd.lean index ace83b9e89..6754a1ce7f 100644 --- a/src/Lean/Meta/Tactic/FunInd.lean +++ b/src/Lean/Meta/Tactic/FunInd.lean @@ -18,6 +18,7 @@ import Lean.Elab.PreDefinition.Structural.IndGroupInfo import Lean.Elab.PreDefinition.Structural.FindRecArg import Lean.Elab.Command import Lean.Meta.Tactic.ElimInfo +import Lean.Meta.Tactic.FunIndInfo /-! This module contains code to derive, from the definition of a recursive function (structural or @@ -659,7 +660,7 @@ Given a unary definition `foo` defined via `WellFounded.fixF`, derive a suitable `foo.induct` for it. See module doc for details. -/ def deriveUnaryInduction (name : Name) : MetaM Name := do - let inductName := .append name `induct + let inductName := getFunInductName name if ← hasConst inductName then return inductName let info ← getConstInfoDefn name @@ -677,7 +678,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do mkLambdaFVars (params ++ xs) (mkAppN body xs) else pure e - let e' ← lambdaTelescope e fun params funBody => MatcherApp.withUserNames params varNames do + let (e', paramMask) ← lambdaTelescope e fun params funBody => MatcherApp.withUserNames params varNames do match_expr funBody with | fix@WellFounded.fix α _motive rel wf body target => unless params.back! == target do @@ -719,8 +720,9 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do -- induction principle match the type of the function better. -- But this leads to avoidable parameters that make functional induction strictly less -- useful (e.g. when the unsued parameter mentions bound variables in the users' goal) - let e' ← mkLambdaFVars (binderInfoForMVars := .default) (usedOnly := true) fixedParams e' - instantiateMVars e' + let (paramMask, e') ← mkLambdaFVarsMasked fixedParams e' + let e' ← instantiateMVars e' + return (e', paramMask) | _ => if funBody.isAppOf ``WellFounded.fix then throwError "Function {name} defined via WellFounded.fix with unexpected arity {funBody.getAppNumArgs}:{indentExpr funBody}" @@ -734,12 +736,20 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do let eTyp ← inferType e' let eTyp ← elimOptParam eTyp -- logInfo m!"eTyp: {eTyp}" - let params := (collectLevelParams {} eTyp).params + let levelParams := (collectLevelParams {} eTyp).params -- Prune unused level parameters, preserving the original order - let us := info.levelParams.filter (params.contains ·) + let funUs := info.levelParams.toArray + let usMask := funUs.map (levelParams.contains ·) + let us := maskArray usMask funUs |>.toList addDecl <| Declaration.thmDecl { name := inductName, levelParams := us, type := eTyp, value := e' } + + setFunIndInfo { + funIndName := inductName + levelMask := usMask + params := paramMask.map (cond · .param .dropped) ++ #[.target] + } return inductName /-- @@ -751,7 +761,7 @@ def projectMutualInduct (names : Array Name) (mutualInduct : Name) : MetaM Unit let levelParams := ci.levelParams for name in names, idx in [:names.size] do - let inductName := .append name `induct + let inductName := getFunInductName name unless ← hasConst inductName do let value ← forallTelescope ci.type fun xs _body => do let value := .const ci.name (levelParams.map mkLevelParam) @@ -761,6 +771,21 @@ def projectMutualInduct (names : Array Name) (mutualInduct : Name) : MetaM Unit let type ← inferType value addDecl <| Declaration.thmDecl { name := inductName, levelParams, type, value } +/-- +For a (non-mutual!) definition of `name`, uses the `FunIndInfo` associated with the `unaryInduct` and +derives the one for the n-ary function. +-/ +def setNaryFunIndInfo (name : Name) (arity : Nat) (unaryInduct : Name) : MetaM Unit := do + let inductName := getFunInductName name + unless inductName = unaryInduct do + let some unaryFunIndInfo ← getFunIndInfoForInduct? unaryInduct + | throwError "Expected {unaryInduct} to have FunIndInfo" + setFunIndInfo { + unaryFunIndInfo with + funIndName := inductName + params := unaryFunIndInfo.params.filter (· != .target) ++ mkArray arity .target + } + /-- In the type of `value`, reduces * Beta-redexes @@ -823,10 +848,10 @@ unpacks it into a n-ary and (possibly) joint induction principle. -/ def unpackMutualInduction (eqnInfo : WF.EqnInfo) (unaryInductName : Name) : MetaM Name := do let inductName := if eqnInfo.declNames.size > 1 then - .append eqnInfo.declNames[0]! `mutual_induct + getMutualInductName eqnInfo.declNames[0]! else -- If there is no mutual recursion, we generate the `foo.induct` directly. - .append eqnInfo.declNames[0]! `induct + getFunInductName eqnInfo.declNames[0]! if ← hasConst inductName then return inductName let ci ← getConstInfo unaryInductName @@ -867,11 +892,6 @@ def unpackMutualInduction (eqnInfo : WF.EqnInfo) (unaryInductName : Name) : Meta return inductName -/-- Given `foo._unary.induct`, define `foo.mutual_induct` and then `foo.induct`, `bar.induct`, … -/ -def deriveUnpackedInduction (eqnInfo : WF.EqnInfo) (unaryInductName : Name): MetaM Unit := do - let unpackedInductName ← unpackMutualInduction eqnInfo unaryInductName - projectMutualInduct eqnInfo.declNames unpackedInductName - def withLetDecls {α} (name : Name) (ts : Array Expr) (es : Array Expr) (k : Array Expr → MetaM α) : MetaM α := do assert! es.size = ts.size go 0 #[] @@ -891,7 +911,7 @@ See module doc for details. def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit := do let infos ← names.mapM getConstInfoDefn -- First open up the fixed parameters everywhere - let e' ← lambdaBoundedTelescope infos[0]!.value numFixed fun xs _ => do + let (e', paramMask, motiveArities) ← lambdaBoundedTelescope infos[0]!.value numFixed fun xs _ => do -- Now look at the body of an arbitrary of the functions (they are essentially the same -- up to the final projections) let body ← instantiateLambda infos[0]!.value xs @@ -937,12 +957,13 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit -- We also need to know the number of indices of each type former, including the auxiliary -- type formers that do not have IndInfo. We can read it off the motives types of the recursor. - let numTargetss ← do + let numTypeFormerTargetss ← do let aux := mkAppN (.const recInfo.name (0 :: group.levels)) group.params let motives ← inferArgumentTypesN recInfo.numMotives aux motives.mapM fun motive => forallTelescopeReducing motive fun xs _ => pure xs.size + let recArgInfos ← infos.mapM fun info => do let some eqnInfo := Structural.eqnInfoExt.find? (← getEnv) info.name | throwError "{info.name} missing eqnInfo" let value ← instantiateLambda info.value xs @@ -972,6 +993,7 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit lambdaTelescope (← instantiateLambda info.value xs) fun ys _ => pure ys.size let motiveNames := Array.ofFn (n := infos.size) fun ⟨i, _⟩ => if infos.size = 1 then .mkSimple "motive" else .mkSimple s!"motive_{i+1}" + withLocalDeclsDND (motiveNames.zip motiveTypes) fun motives => do -- Prepare the `isRecCall` that recognizes recursive calls @@ -1000,7 +1022,7 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit -- So that we can transform them let (minors', mvars) ← M2.run do let mut minors' := #[] - for brecOnMinor in brecOnMinors, goal in minorTypes, numTargets in numTargetss do + for brecOnMinor in brecOnMinors, goal in minorTypes, numTargets in numTypeFormerTargetss do let minor' ← forallTelescope goal fun xs goal => do unless xs.size ≥ numTargets do throwError ".brecOn argument has too few parameters, expected at least {numTargets}: {xs}" @@ -1053,10 +1075,10 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit -- induction principle match the type of the function better. -- But this leads to avoidable parameters that make functional induction strictly less -- useful (e.g. when the unsued parameter mentions bound variables in the users' goal) - let e' ← mkLambdaFVars (binderInfoForMVars := .default) (usedOnly := true) xs e' + let (paramMask, e') ← mkLambdaFVarsMasked xs e' let e' ← instantiateMVars e' trace[Meta.FunInd] "complete body of mutual induction principle:{indentExpr e'}" - pure e' + pure (e', paramMask, motiveArities) unless (← isTypeCorrect e') do logError m!"constructed induction principle is not type correct:{indentExpr e'}" @@ -1065,22 +1087,33 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit let eTyp ← inferType e' let eTyp ← elimOptParam eTyp -- logInfo m!"eTyp: {eTyp}" - let params := (collectLevelParams {} eTyp).params + let levelParams := (collectLevelParams {} eTyp).params -- Prune unused level parameters, preserving the original order - let us := infos[0]!.levelParams.filter (params.contains ·) + let funUs := infos[0]!.levelParams.toArray + let usMask := funUs.map (levelParams.contains ·) + let us := maskArray usMask funUs |>.toList let inductName := if names.size = 1 then - names[0]! ++ `induct + getFunInductName names[0]! else - names[0]! ++ `mutual_induct + getMutualInductName names[0]! addDecl <| Declaration.thmDecl { name := inductName, levelParams := us, type := eTyp, value := e' } + if names.size > 1 then projectMutualInduct names inductName + if names.size = 1 then + setFunIndInfo { + funIndName := inductName + levelMask := usMask + params := paramMask.map (cond · .param .dropped) ++ + mkArray motiveArities[0]! .target + } + /-- For non-recursive (and recursive functions) functions we derive a “functional case splitting theorem”. This is very similar @@ -1110,6 +1143,8 @@ def deriveCases (name : Name) : MetaM Unit := do throwError "'{name}' does not have an unfold theorem nor a value" let motiveType ← lambdaTelescope value fun xs _body => do mkForallFVars xs (.sort 0) + let motiveArity ← lambdaTelescope value fun xs _body => do + pure xs.size let e' ← withLocalDeclD `motive motiveType fun motive => do lambdaTelescope value fun xs body => do let (e',mvars) ← M2.run do @@ -1131,12 +1166,22 @@ def deriveCases (name : Name) : MetaM Unit := do let eTyp ← inferType e' let eTyp ← elimOptParam eTyp -- logInfo m!"eTyp: {eTyp}" - let params := (collectLevelParams {} eTyp).params + let levelParams := (collectLevelParams {} eTyp).params -- Prune unused level parameters, preserving the original order - let us := info.levelParams.filter (params.contains ·) + let funUs := info.levelParams.toArray + let usMask := funUs.map (levelParams.contains ·) + let us := maskArray usMask funUs |>.toList + let casesName := getFunCasesName info.name addDecl <| Declaration.thmDecl - { name := info.name ++ `fun_cases, levelParams := us, type := eTyp, value := e' } + { name := casesName, levelParams := us, type := eTyp, value := e' } + + setFunIndInfo { + funIndName := casesName + levelMask := usMask + params := mkArray motiveArity .target + } + /-- Given a recursively defined function `foo`, derives `foo.induct`. See the module doc for details. @@ -1145,8 +1190,10 @@ def deriveInduction (name : Name) : MetaM Unit := do mapError (f := (m!"Cannot derive functional induction principle (please report this issue)\n{indentD ·}")) do if let some eqnInfo := WF.eqnInfoExt.find? (← getEnv) name then let unaryInductName ← deriveUnaryInduction eqnInfo.declNameNonRec - unless eqnInfo.declNameNonRec = name do - deriveUnpackedInduction eqnInfo unaryInductName + let unpackedInductName ← unpackMutualInduction eqnInfo unaryInductName + projectMutualInduct eqnInfo.declNames unpackedInductName + if eqnInfo.argsPacker.numFuncs = 1 then + setNaryFunIndInfo eqnInfo.declNames[0]! eqnInfo.argsPacker.arities[0]! unaryInductName else if let some eqnInfo := Structural.eqnInfoExt.find? (← getEnv) name then deriveInductionStructural eqnInfo.declNames eqnInfo.numFixed else diff --git a/src/Lean/Meta/Tactic/FunIndInfo.lean b/src/Lean/Meta/Tactic/FunIndInfo.lean new file mode 100644 index 0000000000..68b3f93fda --- /dev/null +++ b/src/Lean/Meta/Tactic/FunIndInfo.lean @@ -0,0 +1,76 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ + +prelude +import Lean.Meta.Basic +import Lean.ScopedEnvExtension +import Lean.ReservedNameAction + +/-! +This module defines the data structure and environment extension to remember how to map the +function's arguments to the functional induction principle's arguments. +Also used for functional cases. +-/ + +namespace Lean.Meta + +inductive FunIndParamKind where + | dropped + | param + | target +deriving BEq, Repr + +/-- +A `FunIndInfo` indicates how a function's arguments map to the arguments of the functional induction +(resp. cases) theorem. + +The size of `params` also indicates the arity of the function. +-/ +structure FunIndInfo where + funIndName : Name + /-- + `true` means that the corresponding level parameter of the function is also a level param + of the induction principle. + -/ + levelMask : Array Bool + params : Array FunIndParamKind +deriving Inhabited, Repr + +builtin_initialize funIndInfoExt : MapDeclarationExtension FunIndInfo ← mkMapDeclarationExtension + +def getFunInductName (declName : Name) : Name := + declName ++ `induct + +def getFunCasesName (declName : Name) : Name := + declName ++ `fun_cases + +def getMutualInductName (declName : Name) : Name := + declName ++ `mutual_induct + +def getFunInduct? (cases : Bool) (declName : Name) : CoreM (Option Name) := do + let .defnInfo _ ← getConstInfo declName | return none + try + let thmName := if cases then + getFunCasesName declName + else + getFunInductName declName + let result ← realizeGlobalConstNoOverloadCore thmName + return some result + catch _ => + return none + +def setFunIndInfo (funIndInfo : FunIndInfo) : CoreM Unit := do + assert! !(funIndInfoExt.contains (← getEnv) funIndInfo.funIndName) + modifyEnv fun env => funIndInfoExt.insert env funIndInfo.funIndName funIndInfo + +def getFunIndInfoForInduct? (inductName : Name) : CoreM (Option FunIndInfo) := do + return funIndInfoExt.find? (← getEnv) inductName + +def getFunIndInfo? (cases : Bool) (funName : Name) : CoreM (Option FunIndInfo) := do + let some inductName ← getFunInduct? cases funName | return none + getFunIndInfoForInduct? inductName + +end Lean.Meta diff --git a/tests/lean/run/funInduction.lean b/tests/lean/run/funInduction.lean new file mode 100644 index 0000000000..85838a6ba4 --- /dev/null +++ b/tests/lean/run/funInduction.lean @@ -0,0 +1,430 @@ +import Lean + +namespace Ex1 + +variable (P : Nat → Prop) + +def ackermann : (Nat × Nat) → Nat + | (0, m) => m + 1 + | (n+1, 0) => ackermann (n, 1) + | (n+1, m+1) => ackermann (n, ackermann (n + 1, m)) +termination_by p => p + +/-- +error: tactic 'fail' failed +case case1 +P : Nat → Prop +m✝ : Nat +⊢ P (ackermann (0, m✝)) + +case case2 +P : Nat → Prop +n✝ : Nat +ih1✝ : P (ackermann (n✝, 1)) +⊢ P (ackermann (n✝.succ, 0)) + +case case3 +P : Nat → Prop +n✝ m✝ : Nat +ih2✝ : P (ackermann (n✝ + 1, m✝)) +ih1✝ : P (ackermann (n✝, ackermann (n✝ + 1, m✝))) +⊢ P (ackermann (n✝.succ, m✝.succ)) +-/ +#guard_msgs in +example : P (ackermann p) := by + fun_induction ackermann p + fail + +/-- +error: tactic 'fail' failed +case case1 +P : Nat → Prop +m✝ : Nat +⊢ P (ackermann (0, m✝)) + +case case2 +P : Nat → Prop +n✝ : Nat +⊢ P (ackermann (n✝.succ, 0)) + +case case3 +P : Nat → Prop +n✝ m✝ : Nat +⊢ P (ackermann (n✝.succ, m✝.succ)) +-/ +#guard_msgs in +example : P (ackermann p) := by + fun_cases ackermann p + fail + +/-- +error: unsolved goals +case case1 +P : Nat → Prop +n m m✝ : Nat +⊢ P (ackermann (0, m✝)) + +case case2 +P : Nat → Prop +n m n✝ : Nat +ih1✝ : P (ackermann (n✝, 1)) +⊢ P (ackermann (n✝.succ, 0)) + +case case3 +P : Nat → Prop +n m n✝ m✝ : Nat +ih2✝ : P (ackermann (n✝ + 1, m✝)) +ih1✝ : P (ackermann (n✝, ackermann (n✝ + 1, m✝))) +⊢ P (ackermann (n✝.succ, m✝.succ)) +-/ +#guard_msgs in +example : P (ackermann (n, m)) := by + fun_induction ackermann (n, m) + +/-- +error: unsolved goals +case case1 +P : Nat → Prop +n m m✝ : Nat +⊢ P (ackermann (0, m✝)) + +case case2 +P : Nat → Prop +n m n✝ : Nat +⊢ P (ackermann (n✝.succ, 0)) + +case case3 +P : Nat → Prop +n m n✝ m✝ : Nat +⊢ P (ackermann (n✝.succ, m✝.succ)) +-/ +#guard_msgs in +example : P (ackermann (n, m)) := by + fun_cases ackermann (n, m) + +-- Testing Generalization: + +/-- +error: unsolved goals +case case1 +P : Nat → Prop +n m m✝ : Nat +⊢ P (ackermann (n, m)) + +case case2 +P : Nat → Prop +n m n✝ : Nat +⊢ P (ackermann (n, m)) + +case case3 +P : Nat → Prop +n m n✝ m✝ : Nat +⊢ P (ackermann (n, m)) +-/ +#guard_msgs in +example : P (ackermann (n, m)) := by + fun_cases ackermann (n+n, m) + +end Ex1 + +namespace Ex2 + +variable (P : Nat → Prop) + +def ackermann : Nat → Nat → Nat + | 0, m => m + 1 + | n+1, 0 => ackermann n 1 + | n+1, m+1 => ackermann n (ackermann (n + 1) m) +termination_by n m => (n, m) + +/-- +error: unsolved goals +case case1 +P : Nat → Prop +m✝ : Nat +⊢ P (ackermann 0 m✝) + +case case2 +P : Nat → Prop +n✝ : Nat +ih1✝ : P (ackermann n✝ 1) +⊢ P (ackermann n✝.succ 0) + +case case3 +P : Nat → Prop +n✝ m✝ : Nat +ih2✝ : P (ackermann (n✝ + 1) m✝) +ih1✝ : P (ackermann n✝ (ackermann (n✝ + 1) m✝)) +⊢ P (ackermann n✝.succ m✝.succ) +-/ +#guard_msgs in +example : P (ackermann n m) := by + fun_induction ackermann n m + +/-- +error: Expected fully applied application of 'ackermann' with 2 arguments, but found 1 arguments +-/ +#guard_msgs in +example : P (ackermann n m) := by + fun_induction ackermann n + +/-- +error: Expected fully applied application of 'ackermann' with 2 arguments, but found 0 arguments +-/ +#guard_msgs in +example : P (ackermann n m) := by + fun_induction ackermann + +end Ex2 + +namespace Ex3 + +variable (P : List α → Prop) + +def ackermann {α} (inc : List α) : List α → List α → List α + | [], ms => ms ++ inc + | _::ns, [] => ackermann inc ns inc + | n::ns, _::ms => ackermann inc ns (ackermann inc (n::ns) ms) +termination_by ns ms => (ns, ms) + +/-- +error: unsolved goals +case case1 +α : Type u_1 +P : List α → Prop +inc ms✝ : List α +⊢ P (ackermann inc [] ms✝) + +case case2 +α : Type u_1 +P : List α → Prop +inc : List α +head✝ : α +ns✝ : List α +ih1✝ : P (ackermann inc ns✝ inc) +⊢ P (ackermann inc (head✝ :: ns✝) []) + +case case3 +α : Type u_1 +P : List α → Prop +inc : List α +n✝ : α +ns✝ : List α +head✝ : α +ms✝ : List α +ih2✝ : P (ackermann inc (n✝ :: ns✝) ms✝) +ih1✝ : P (ackermann inc ns✝ (ackermann inc (n✝ :: ns✝) ms✝)) +⊢ P (ackermann inc (n✝ :: ns✝) (head✝ :: ms✝)) +-/ +#guard_msgs in +example : P (ackermann inc n m) := by + fun_induction ackermann inc n m + +/-- +error: Expected fully applied application of 'ackermann' with 4 arguments, but found 3 arguments +-/ +#guard_msgs in +example : P (ackermann inc n m) := by + fun_induction ackermann inc n + +/-- +error: Expected fully applied application of 'ackermann' with 4 arguments, but found 2 arguments +-/ +#guard_msgs in +example : P (ackermann inc n m) := by + fun_induction ackermann inc + +end Ex3 + +namespace Structural + +variable (P : Nat → Prop) + +def fib : Nat → Nat + | 0 => 0 + | 1 => 1 + | n+2 => fib n + fib (n+1) +termination_by structural x => x + +/-- +error: tactic 'fail' failed +case case1 +P : Nat → Prop +⊢ P (fib 0) + +case case2 +P : Nat → Prop +⊢ P (fib 1) + +case case3 +P : Nat → Prop +n✝ : Nat +ih2✝ : P (fib n✝) +ih1✝ : P (fib (n✝ + 1)) +⊢ P (fib n✝.succ.succ) +-/ +#guard_msgs in +example : P (fib n) := by + fun_induction fib n + fail + +example : n ≤ fib (n + 2) := by + fun_induction fib n + case case1 => simp [fib] + case case2 => simp [fib] + case case3 n ih1 ih2 => simp_all [fib]; omega + +example : n ≤ fib (n + 2) := by + fun_induction fib n with + | case1 | case2 => simp [fib] + | case3 => simp_all [fib]; omega + + +end Structural + +namespace StructuralWithOmittedParam + +variable (P : Nat → Prop) + +variable (inc : Nat) +def fib : Nat → Nat + | 0 => 0 + | 1 => inc + | n+2 => fib n + fib (n+1) +termination_by structural x => x + +/-- +info: StructuralWithOmittedParam.fib.induct (motive : Nat → Prop) (case1 : motive 0) (case2 : motive 1) + (case3 : ∀ (n : Nat), motive n → motive (n + 1) → motive n.succ.succ) (a✝ : Nat) : motive a✝ +-/ +#guard_msgs in +#check fib.induct -- NB: No inc showing up + +/-- +error: tactic 'fail' failed +case case1 +P : Nat → Prop +inc : Nat +⊢ P (fib 2 0) + +case case2 +P : Nat → Prop +inc : Nat +⊢ P (fib 2 1) + +case case3 +P : Nat → Prop +inc n✝ : Nat +ih2✝ : P (fib 2 n✝) +ih1✝ : P (fib 2 (n✝ + 1)) +⊢ P (fib 2 n✝.succ.succ) +-/ +#guard_msgs in +example : P (fib 2 n) := by + fun_induction fib 3 n + fail + +/-- +error: tactic 'fail' failed +case case1 +P : Nat → Prop +inc : Nat +⊢ P (fib 2 0) + +case case2 +P : Nat → Prop +inc : Nat +⊢ P (fib 2 1) + +case case3 +P : Nat → Prop +inc n✝ : Nat +ih2✝ : P (fib 2 n✝) +ih1✝ : P (fib 2 (n✝ + 1)) +⊢ P (fib 2 n✝.succ.succ) +-/ +#guard_msgs in +example : P (fib 2 n) := by + fun_induction fib _ n + fail + +end StructuralWithOmittedParam + +namespace StructuralIndices + +-- Testing recursion on an indexed data type +inductive Finn : Nat → Type where + | fzero : {n : Nat} → Finn n + | fsucc : {n : Nat} → Finn n → Finn (n+1) + +def Finn.min (x : Bool) {n : Nat} (m : Nat) : Finn n → (f : Finn n) → Finn n + | fzero, _ => fzero + | _, fzero => fzero + | fsucc i, fsucc j => fsucc (Finn.min (not x) (m + 1) i j) +termination_by structural i => i + +def Finn.min' (x : Bool) {n : Nat} (m : Nat) : Finn n → (f : Finn n) → Finn n + | fzero, _ => fzero + | _, fzero => fzero + | fsucc i, fsucc j => fsucc (Finn.min' (not x) (m + 1) i j) +termination_by structural _ j => j + +def Finn.min'' (x : Bool) {n : Nat} (m : Nat) : Finn n → (f : Finn n) → Finn n + | fzero, _ => fzero + | _, fzero => fzero + | fsucc i, fsucc j => fsucc (Finn.min'' (not x) (m + 1) i j) +termination_by structural n + +def Finn.le : Finn n → Finn n → Bool + | fzero, _ => true + | _, fzero => false + | fsucc i, fsucc j => Finn.le i j + +theorem Finn.min_le_right₀ : (Finn.min x m i j).le j := by + induction x, m, i, j using @Finn.min.induct <;> simp_all [Finn.min, Finn.le] + +theorem Finn.min_le_right : (Finn.min x m i j).le j := by + fun_induction Finn.min x m i j <;> simp_all [Finn.min, Finn.le] + +theorem Finn.min_le_right' : (Finn.min' x m i j).le j := by + fun_induction Finn.min' x m i j <;> simp_all [Finn.min', Finn.le] + +theorem Finn.min_le_right'' : (Finn.min'' x m i j).le j := by + fun_induction Finn.min'' x m i j <;> simp_all [Finn.min'', Finn.le] + +end StructuralIndices + +namespace Nonrec + +def foo := 1 + +/-- error: no functional cases theorem for 'foo', or function is mutually recursive -/ +#guard_msgs in +example : True := by + fun_induction foo + + +end Nonrec + +namespace Mutual + +inductive Tree (α : Type u) : Type u where + | node : α → (Bool → List (Tree α)) → Tree α + +-- Recursion over nested inductive + +mutual +def Tree.size : Tree α → Nat + | .node _ tsf => 1 + size_aux (tsf true) + size_aux (tsf false) +termination_by structural t => t +def Tree.size_aux : List (Tree α) → Nat + | [] => 0 + | t :: ts => size t + size_aux ts +end + +/-- error: no functional cases theorem for 'Tree.size', or function is mutually recursive -/ +#guard_msgs in +example (t : Tree α) : True := by + fun_induction Tree.size t + +end Mutual