lean4-htt/src/Lean/Elab/PreDefinition/WF/GuessLex.lean
Joachim Breitner f45c19b428
feat: identify more fixed parameters (#7166)
This PR extends the notion of “fixed parameter” of a recursive function
also to parameters that come after varying parameters. The main benefit is
that we get nicer induction principles.


Before, the definition

```lean
def app (as : List α) (bs : List α) : List α :=
  match as with
  | [] => bs
  | a::as => a :: app as bs
```

produced

```lean
app.induct.{u_1} {α : Type u_1} (motive : List α → List α → Prop) (case1 : ∀ (bs : List α), motive [] bs)
  (case2 : ∀ (bs : List α) (a : α) (as : List α), motive as bs → motive (a :: as) bs) (as bs : List α) : motive as bs
```
and now you get
```lean
app.induct.{u_1} {α : Type u_1} (motive : List α → Prop) (case1 : motive [])
  (case2 : ∀ (a : α) (as : List α), motive as → motive (a :: as)) (as : List α) : motive as
```
because `bs` is fixed throughout the recursion (and can completely be
dropped from the principle).

This is a breaking change when such an induction principle is used
explicitly. Using `fun_induction` makes proof tactics robust against
this change.
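
As a rough illustration of that last point, here is a minimal sketch of a proof about `app` that goes through `fun_induction` instead of naming `app.induct` directly. The lemma name `app_nil` and the closing `simp_all [app]` are assumptions made for this example only, and the exact goals the tactic produces may differ between Lean versions.

```lean
def app (as : List α) (bs : List α) : List α :=
  match as with
  | [] => bs
  | a::as => a :: app as bs

-- Hypothetical lemma, for illustration only.
-- The proof never spells out the motive or the cases of `app.induct`,
-- so it does not care whether `bs` appears in the principle or not.
theorem app_nil (as : List α) : app as [] = as := by
  fun_induction app as [] <;> simp_all [app]
```

A proof written as `induction as, bs using app.induct`, by contrast, names the targets and cases of the principle and so would need adjusting when `bs` disappears from it.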

The rules for when a parameter is fixed are now:

1. A parameter is fixed if it is reducibly defeq to the corresponding
argument in each recursive call, so we have to look at each such call.
2. With mutual recursion, it is not clear a-priori which arguments of
another function correspond to the parameter. This requires an analysis
with some graph algorithms to determine.
3. A parameter can only be fixed if all parameters occurring in its type
are fixed as well.
This dependency graph on parameters can be different for the different
functions in a recursive group, even leading to cycles.
4. For structural recursion, we kinda want to know the fixed parameters
before investigating which argument to actually recurse on. But once we
have that we may find that we fixed an index of the recursive
parameter’s type, and these cannot be fixed. So we have to un-fix them
5. … and all other fixed parameters that have dependencies on them.

Lean tries to identify the largest set of parameters that satisfies
these criteria.

Note that in a definition like
```lean
def app : List α → List α → List α
  | [], bs => bs
  | a::as, bs => a :: app as bs
```
the `bs` is not considered fixed, as it goes through the matcher
machinery.


Fixes #7027
Fixes #2113
2025-03-04 22:26:20 +00:00

/-
Copyright (c) 2023 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
prelude
import Lean.Util.HasConstCache
import Lean.Meta.Match.MatcherApp.Transform
import Lean.Meta.Tactic.Cleanup
import Lean.Meta.Tactic.Refl
import Lean.Meta.Tactic.TryThis
import Lean.Meta.ArgsPacker
import Lean.Elab.Quotation
import Lean.Elab.RecAppSyntax
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Mutual
import Lean.Elab.PreDefinition.Structural.Basic
import Lean.Elab.PreDefinition.TerminationMeasure
import Lean.Elab.PreDefinition.FixedParams
import Lean.Elab.PreDefinition.WF.Basic
import Lean.Data.Array
/-!
This module finds lexicographic termination measures for well-founded recursion.
Starting with basic measures (`sizeOf xᵢ` for all parameters `xᵢ`), and complex measures
(e.g. `e₂ - e₁` if `e₁ < e₂` is found in the context of a recursive call) it tries all combinations
until it finds one where all proof obligations go through with the given tactic (`decreasing_by`),
if given, or the default `decreasing_tactic`.
For mutual recursion, a single measure is not just one parameter, but one from each recursive
function. Enumerating these can lead to a combinatorial explosion, so we bound
the number of measures tried.
In addition to measures derived from `sizeOf xᵢ`, it also considers measures
that assign an order to the functions themselves. This way we can support mutual
function definitions where no arguments decrease from one function to another.
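
A small sketch of the situation just described; the functions `f` and `g` are made up for illustration and do not appear elsewhere in this module. The call from `f` to `g` passes `n` unchanged, so no argument decreases on that edge, and only an ordering on the functions themselves can make the call decreasing.

```lean
mutual
def f (n : Nat) : Nat := g n   -- argument stays the same
def g (n : Nat) : Nat :=
  match n with
  | 0 => 0
  | m + 1 => f m               -- argument gets smaller
end
-- One would expect a guessed measure along the lines of
-- `termination_by (n, 1)` for `f` and `termination_by (n, 0)` for `g`.
```
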
The result of this module is a `TerminationWF`, which is then passed on to `wfRecursion`; this
design is crucial so that whatever we infer in this module could also be written manually by the
user. It would be bad if there are function definitions that can only be processed with the
guessed lexicographic order.
The following optimizations are applied to make this feasible:
1. The crucial optimization is to look at each measure of each recursive call
_once_, try to prove `<` (and, if that fails, `≤`), and then look at that table to
pick a suitable measure.
2. The next-crucial optimization is to fill that table lazily. This way, we run the (likely
expensive) tactics as few times as possible, while still being able to consider a possibly
large number of combinations.
3. Before we even try to prove `<`, we check if the measures are equal (`=`). No well-founded
measure will relate equal terms; this check is likely faster than firing up the tactic engine,
and it adds more signal to the output.
4. Instead of traversing the whole function body over and over, we traverse it once and store
the arguments (in unpacked form) and the local `MetaM` state at each recursive call
(see `collectRecCalls`), which we then re-use for the possibly many proof attempts.
The logic here is based on “Finding Lexicographic Orders for Termination Proofs in Isabelle/HOL”
by Lukas Bulwahn, Alexander Krauss, and Tobias Nipkow, 10.1007/978-3-540-74591-4_5
<https://www21.in.tum.de/~nipkow/pubs/tphols07.pdf>.
We got the idea of considering the measure `e₂ - e₁` if we see `e₁ < e₂` from
“Termination Analysis with Calling Context Graphs” by Panagiotis Manolios &
Daron Vroon, https://doi.org/10.1007/11817963_36.
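
As a hedged example of the `e₂ - e₁` heuristic (the function `countUp` is hypothetical and only serves as illustration): the hypothesis `h : i < n` is in the local context of the recursive call, so besides `i` itself the candidate measure `n - i` is tried.

```lean
def countUp (i n : Nat) : Nat :=
  if h : i < n then countUp (i + 1) n + 1 else 0
-- `i` does not decrease, but `n - (i + 1) < n - i` follows from `h`,
-- so one would expect the inferred measure to be `termination_by n - i`.
```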
-/
set_option autoImplicit false
open Lean Meta Elab
namespace Lean.Elab.WF.GuessLex
register_builtin_option showInferredTerminationBy : Bool := {
defValue := false
descr := "In recursive definitions, show the inferred `termination_by` measure."
}
/--
Given a predefinition, return the variable names in the outermost lambdas.
Includes the “fixed prefix”.
The length of the returned array is also used to determine the arity
of the function, so it should match what `packDomain` does.
-/
def originalVarNames (preDef : PreDefinition) : MetaM (Array Name) := do
lambdaTelescope preDef.value fun xs _ => xs.mapM (·.fvarId!.getUserName)
/--
Given the original parameter names from `originalVarNames`, find
good variable names to be used when talking about termination measures:
Use user-given parameter names if present; use x1...xn otherwise.
The names ought to be accessible (no macro scopes) and fresh wrt. the current environment,
so that with `showInferredTerminationBy` we can print them to the user reliably.
We do that by appending `'` as needed.
It is possible (but unlikely without malice) that some of the user-given names
shadow each other, and the guessed relation refers to the wrong one. In that
case, the user gets to keep both pieces (and may have to rename variables).
-/
partial
def naryVarNames (xs : Array Name) : MetaM (Array Name) := do
let mut ns : Array Name := #[]
for h : i in [:xs.size] do
let n := xs[i]
let n ← if n.hasMacroScopes then
freshen ns (.mkSimple s!"x{i+1}")
else
pure n
ns := ns.push n
return ns
where
freshen (ns : Array Name) (n : Name): MetaM Name := do
if !(ns.elem n) && (← resolveGlobalName n).isEmpty then
return n
else
freshen ns (n.appendAfter "'")
/-- A termination measure with extra fields for use within GuessLex -/
structure BasicMeasure extends TerminationMeasure where
/--
Like `.fn`, but unconditionally with `sizeOf` at the right type.
We use this one in `evalRecCall`.
-/
natFn : Expr
deriving Inhabited
/-- String description of this measure -/
def BasicMeasure.toString (measure : BasicMeasure) : MetaM String := do
lambdaTelescope measure.fn fun _xs e => do
-- This is a bit sloppy if `measure.fn` takes more parameters than the `PreDefinition`
return (← ppExpr e).pretty
/--
Determine if the measure for parameter `x` should be `sizeOf x` or just `x`.
For non-mutual definitions, we omit `sizeOf` when the measure does not depend on
the other varying parameters, and its `WellFoundedRelation` instance goes via `SizeOf`.
For mutual definitions, we omit `sizeOf` only when the measure is (at reducible transparency!) of
type `Nat` (else we'd have to worry about differently-typed measures from different functions to
line up).
-/
def mayOmitSizeOf (is_mutual : Bool) (args : Array Expr) (x : Expr) : MetaM Bool := do
let t ← inferType x
if is_mutual
then
withReducible (isDefEq t (.const `Nat []))
else
try
if t.hasAnyFVar (fun fvar => args.contains (.fvar fvar)) then
pure false
else
let u ← getLevel t
let wfi ← synthInstance (.app (.const ``WellFoundedRelation [u]) t)
let soi ← synthInstance (.app (.const ``SizeOf [u]) t)
isDefEq wfi (mkApp2 (.const ``sizeOfWFRel [u]) t soi)
catch _ =>
pure false
/-- Sets the user names for the given freevars in `xs`. -/
def withUserNames {α} (xs : Array Expr) (ns : Array Name) (k : MetaM α) : MetaM α := do
let mut lctx ← getLCtx
for x in xs, n in ns do lctx := lctx.setUserName x.fvarId! n
withLCtx' lctx k
/-- Create one measure for each (eligible) parameter of the given predefinition. -/
def simpleMeasures (preDefs : Array PreDefinition) (fixedParamPerms : FixedParamPerms)
(userVarNamess : Array (Array Name)) : MetaM (Array (Array BasicMeasure)) := do
let is_mutual : Bool := preDefs.size > 1
preDefs.mapIdxM fun funIdx preDef => do
lambdaTelescope preDef.value fun params _ => do
let xs := fixedParamPerms.perms[funIdx]!.pickVarying params
withUserNames xs userVarNamess[funIdx]! do
let mut ret : Array BasicMeasure := #[]
for x in xs do
-- If the `SizeOf` instance produces a constant (e.g. because its type is a `Prop` or
-- `Type`), then ignore this parameter
let sizeOf ← whnfD (← mkAppM ``sizeOf #[x])
if sizeOf.isLit then continue
let natFn ← mkLambdaFVars params (← mkAppM ``sizeOf #[x])
-- Determine if we need to exclude `sizeOf` in the measure we show/pass on.
let fn ←
if ← mayOmitSizeOf is_mutual xs x
then mkLambdaFVars params x
else pure natFn
ret := ret.push { ref := .missing, structural := false, fn, natFn }
return ret
/-- Internal monad used by `withRecApps` -/
abbrev M (recFnName : Name) (α β : Type) : Type :=
StateRefT (Array α) (StateRefT (HasConstCache #[recFnName]) MetaM) β
/--
Traverses the given expression `e`, and invokes the continuation `k`
at every saturated call to `recFnName`.
The expression `param` is passed along, and refined when going under a matcher
or `casesOn` application.
-/
partial def withRecApps {α} (recFnName : Name) (fixedPrefixSize : Nat) (param : Expr) (e : Expr)
(k : Expr → Array Expr → MetaM α) : MetaM (Array α) := do
trace[Elab.definition.wf] "withRecApps (param {param}): {indentExpr e}"
let (_, as) ← loop param e |>.run #[] |>.run' {}
return as
where
processRec (param : Expr) (e : Expr) : M recFnName α Unit := do
if e.getAppNumArgs < fixedPrefixSize + 1 then
loop param (← etaExpand e)
else
let a ← k param e.getAppArgs
modifyThe (Array α) (·.push a)
processApp (param : Expr) (e : Expr) : M recFnName α Unit := do
e.withApp fun f args => do
args.forM (loop param)
if f.isConstOf recFnName then
processRec param e
else
loop param f
containsRecFn (e : Expr) : M recFnName α Bool := do
modifyGetThe (HasConstCache #[recFnName]) (·.contains e)
loop (param : Expr) (e : Expr) : M recFnName α Unit := do
if !(← containsRecFn e) then
return
match e with
| Expr.lam n d b c =>
loop param d
withLocalDecl n c d fun x => do
loop param (b.instantiate1 x)
| Expr.forallE n d b c =>
loop param d
withLocalDecl n c d fun x => do
loop param (b.instantiate1 x)
| Expr.letE n type val body _ =>
loop param type
loop param val
withLetDecl n type val fun x => do
loop param (body.instantiate1 x)
| Expr.mdata _d b =>
if let some stx := getRecAppSyntax? e then
withRef stx <| loop param b
else
loop param b
| Expr.proj _n _i e => loop param e
| Expr.const .. => if e.isConstOf recFnName then processRec param e
| Expr.app .. =>
match (← matchMatcherApp? (alsoCasesOn := true) e) with
| some matcherApp =>
if let some altParams ← matcherApp.refineThrough? param then
matcherApp.discrs.forM (loop param)
(Array.zip matcherApp.alts (Array.zip matcherApp.altNumParams altParams)).forM
fun (alt, altNumParam, altParam) =>
lambdaBoundedTelescope altParam altNumParam fun xs altParam => do
unless altNumParam = xs.size do
throwError "unexpected `casesOn` application alternative{indentExpr alt}\nat application{indentExpr e}"
let altBody := alt.beta xs
loop altParam altBody
matcherApp.remaining.forM (loop param)
else
processApp param e
| none => processApp param e
| e => do
ensureNoRecFn #[recFnName] e
/--
A `SavedLocalContext` captures the state and local context of a `MetaM`, to be continued later.
-/
structure SavedLocalContext where
savedLocalContext : LocalContext
savedLocalInstances : LocalInstances
savedState : Meta.SavedState
/-- Capture the `MetaM` state including local context. -/
def SavedLocalContext.create : MetaM SavedLocalContext := do
let savedLocalContext ← getLCtx
let savedLocalInstances ← getLocalInstances
let savedState ← saveState
return { savedLocalContext, savedLocalInstances, savedState }
/-- Run a `MetaM` action in the saved state. -/
def SavedLocalContext.run {α} (slc : SavedLocalContext) (k : MetaM α) :
MetaM α :=
withoutModifyingState $ do
withLCtx slc.savedLocalContext slc.savedLocalInstances do
slc.savedState.restore
k
/-- A `RecCallWithContext` focuses on a single recursive call in a unary predefinition,
and runs the given action in the context of that call. -/
structure RecCallWithContext where
/-- Syntax location of recursive call -/
ref : Syntax
/-- Function index of caller -/
caller : Nat
/-- Parameters of caller (including fixed prefix) -/
params : Array Expr
/-- Function index of callee -/
callee : Nat
/-- Arguments to callee (including fixed prefix) -/
args : Array Expr
ctxt : SavedLocalContext
/-- Store the current recursive call and its context. -/
def RecCallWithContext.create (ref : Syntax) (caller : Nat) (params : Array Expr) (callee : Nat)
(args : Array Expr) : MetaM RecCallWithContext := do
return { ref, caller, params, callee, args, ctxt := (← SavedLocalContext.create) }
/--
The elaborator is prone to duplicate terms, including recursive calls, even if the user
only wrote a single one. This duplication is wasteful if we run the tactics on duplicated
calls, and confusing in the output of GuessLex. So prune the list of recursive calls,
and remove those where another call exists that has the same goal and context that is no more
specific.
-/
def filterSubsumed (rcs : Array RecCallWithContext ) : Array RecCallWithContext := Id.run do
rcs.filterPairsM fun rci rcj => do
if rci.caller == rcj.caller && rci.callee == rcj.callee &&
rci.params == rcj.params && rci.args == rcj.args then
-- same goals; check contexts.
let lci := rci.ctxt.savedLocalContext
let lcj := rcj.ctxt.savedLocalContext
if lci.isSubPrefixOf lcj then
-- rci is better
return (true, false)
else if lcj.isSubPrefixOf lci then
-- rcj is better
return (false, true)
return (true, true)
/--
Traverses a unary `PreDefinition`, and returns a `RecCallWithContext` for each recursive
call site.
-/
def collectRecCalls (unaryPreDef : PreDefinition) (fixedParamPerms : FixedParamPerms)
(argsPacker : ArgsPacker) : MetaM (Array RecCallWithContext) := withoutModifyingState do
addAsAxiom unaryPreDef
lambdaBoundedTelescope unaryPreDef.value (fixedParamPerms.numFixed + 1) fun xs body => do
unless xs.size == fixedParamPerms.numFixed + 1 do
throwError "Unexpected number of lambdas in unary pre-definition"
let ys := xs[:fixedParamPerms.numFixed]
let param := xs[fixedParamPerms.numFixed]!
withRecApps unaryPreDef.declName fixedParamPerms.numFixed param body fun param args => do
unless args.size ≥ fixedParamPerms.numFixed + 1 do
throwError "Insufficient arguments in recursive call"
let arg := args[fixedParamPerms.numFixed]!
trace[Elab.definition.wf] "collectRecCalls: {unaryPreDef.declName} ({param}) → {unaryPreDef.declName} ({arg})"
let some (caller, params) := argsPacker.unpack param
| throwError "Cannot unpack param, unexpected expression:{indentExpr param}"
let some (callee, args) := argsPacker.unpack arg
| throwError "Cannot unpack arg, unexpected expression:{indentExpr arg}"
let callerParams := fixedParamPerms.perms[caller]!.buildArgs ys params
let calleeArgs := fixedParamPerms.perms[callee]!.buildArgs ys args
RecCallWithContext.create (← getRef) caller callerParams callee calleeArgs
/-- Is the expression a `<`-like comparison of `Nat` expressions -/
def isNatCmp (e : Expr) : Option (Expr × Expr) :=
match_expr e with
| LT.lt α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₁, e₂) else none
| LE.le α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₁, e₂) else none
| GT.gt α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₂, e₁) else none
| GE.ge α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₂, e₁) else none
| _ => none
def complexMeasures (preDefs : Array PreDefinition) (fixedParamPerms : FixedParamPerms)
(userVarNamess : Array (Array Name)) (recCalls : Array RecCallWithContext) :
MetaM (Array (Array BasicMeasure)) := do
preDefs.mapIdxM fun funIdx _preDef => do
let mut measures := #[]
for rc in recCalls do
-- Only look at calls from the current function
unless rc.caller = funIdx do continue
-- Only look at calls where the parameters have not been refined
unless rc.params.all (·.isFVar) do continue
let varyingParams := fixedParamPerms.perms[funIdx]!.pickVarying rc.params
let varyingFVars := varyingParams.map (·.fvarId!)
let params := rc.params.map (·.fvarId!)
measures ← rc.ctxt.run do
withUserNames varyingParams userVarNamess[funIdx]! do
trace[Elab.definition.wf] "rc: {rc.caller} ({rc.params}) → {rc.callee} ({rc.args})"
let mut measures := measures
for ldecl in ← getLCtx do
if let some (e₁, e₂) := isNatCmp ldecl.type then
-- We only want to consider these expressions if they depend only on the function's
-- immediate arguments, so check that
if e₁.hasAnyFVar (! params.contains ·) then continue
if e₂.hasAnyFVar (! params.contains ·) then continue
-- If e₁ does not depend on any varying parameters, simply ignore it
let e₁_is_const := ! e₁.hasAnyFVar (varyingFVars.contains ·)
let body := if e₁_is_const then e₂ else mkNatSub e₂ e₁
-- Avoid adding simple measures
unless body.isFVar do
let fn ← mkLambdaFVars rc.params body
-- Avoid duplicates
unless ← measures.anyM (isDefEq ·.fn fn) do
measures := measures.push { ref := .missing, structural := false, fn, natFn := fn }
return measures
return measures
/-- A `GuessLexRel` describes how a recursive call affects a measure; whether it
decreases strictly, non-strictly, is equal, or else. -/
inductive GuessLexRel | lt | eq | le | no_idea
deriving Repr, DecidableEq
instance : ToString GuessLexRel where
toString | .lt => "<"
| .eq => "="
| .le => "≤"
| .no_idea => "?"
instance : ToFormat GuessLexRel where
format r := toString r
/-- Given a `GuessLexRel`, produce a binary `Expr` that relates two `Nat` values accordingly. -/
def GuessLexRel.toNatRel : GuessLexRel → Expr
| lt => mkAppN (mkConst ``LT.lt [levelZero]) #[mkConst ``Nat, mkConst ``instLTNat]
| eq => mkAppN (mkConst ``Eq [levelOne]) #[mkConst ``Nat]
| le => mkAppN (mkConst ``LE.le [levelZero]) #[mkConst ``Nat, mkConst ``instLENat]
| no_idea => unreachable!
/--
For a given recursive call, and a choice of parameter and argument index,
try to prove equality, < or ≤.
-/
def evalRecCall (callerName: Name) (decrTactic? : Option DecreasingBy) (callerMeasures calleeMeasures : Array BasicMeasure)
(rcc : RecCallWithContext) (callerMeasureIdx calleeMeasureIdx : Nat) : MetaM GuessLexRel := do
rcc.ctxt.run do
let callerMeasure := callerMeasures[callerMeasureIdx]!
let calleeMeasure := calleeMeasures[calleeMeasureIdx]!
let param := callerMeasure.natFn.beta rcc.params
let arg := calleeMeasure.natFn.beta rcc.args
trace[Elab.definition.wf] "inspectRecCall: {rcc.caller} ({param}) → {rcc.callee} ({arg})"
for rel in [GuessLexRel.eq, .lt, .le] do
let goalExpr := mkAppN rel.toNatRel #[arg, param]
trace[Elab.definition.wf] "Goal for {rel}: {goalExpr}"
check goalExpr
let mvar ← mkFreshExprSyntheticOpaqueMVar goalExpr
let mvarId := mvar.mvarId!
let mvarId ← mvarId.cleanup
try
if rel = .eq then
MVarId.refl mvarId
else do
Lean.Elab.Term.TermElabM.run' do Term.withDeclName callerName do
Term.withoutErrToSorry do
let remainingGoals ← Tactic.run mvarId do Tactic.withoutRecover do
applyCleanWfTactic
let tacticStx : Syntax ←
match decrTactic? with
| none => pure (← `(tactic| decreasing_tactic)).raw
| some decrTactic =>
trace[Elab.definition.wf] "Using tactic {decrTactic.tactic.raw}"
pure decrTactic.tactic.raw
Tactic.evalTactic tacticStx
remainingGoals.forM fun _ => throwError "goal not solved"
trace[Elab.definition.wf] "inspectRecCall: success!"
return rel
catch e =>
trace[Elab.definition.wf] "Did not find {rel} proof. Goal:{goalsToMessageData [mvarId]}\nError:{indentD e.toMessageData}"
continue
return .no_idea
/-- A cache for `evalRecCall` -/
structure RecCallCache where mk'' ::
callerName : Name
decrTactic? : Option DecreasingBy
callerMeasures : Array BasicMeasure
calleeMeasures : Array BasicMeasure
rcc : RecCallWithContext
cache : IO.Ref (Array (Array (Option GuessLexRel)))
/-- Create a cache to memoize calls to `evalRecCall decrTactic? rcc` -/
def RecCallCache.mk (funNames : Array Name) (decrTactics : Array (Option DecreasingBy)) (measuress : Array (Array BasicMeasure))
(rcc : RecCallWithContext) :
BaseIO RecCallCache := do
let callerName := funNames[rcc.caller]!
let decrTactic? := decrTactics[rcc.caller]!
let callerMeasures := measuress[rcc.caller]!
let calleeMeasures := measuress[rcc.callee]!
let cache ← IO.mkRef <| Array.mkArray callerMeasures.size (Array.mkArray calleeMeasures.size Option.none)
return { callerName, decrTactic?, callerMeasures, calleeMeasures, rcc, cache }
/-- Run `evalRecCall` and cache the result -/
def RecCallCache.eval (rc: RecCallCache) (callerMeasureIdx calleeMeasureIdx : Nat) : MetaM GuessLexRel := do
-- Check the cache first
if let Option.some res := (← rc.cache.get)[callerMeasureIdx]![calleeMeasureIdx]! then
return res
else
let res ← evalRecCall rc.callerName rc.decrTactic? rc.callerMeasures rc.calleeMeasures rc.rcc callerMeasureIdx calleeMeasureIdx
rc.cache.modify (·.modify callerMeasureIdx (·.set! calleeMeasureIdx res))
return res
/-- Print a single cache entry as a string, without forcing it -/
def RecCallCache.prettyEntry (rcc : RecCallCache) (callerMeasureIdx calleeMeasureIdx : Nat) : MetaM String := do
let cachedEntries ← rcc.cache.get
return match cachedEntries[callerMeasureIdx]![calleeMeasureIdx]! with
| .some rel => toString rel
| .none => "_"
/-- The measures that we order lexicographically can be comparing basic measures,
or numbering the functions -/
inductive MutualMeasure where
/-- For every function, the given argument index -/
| args : Array Nat → MutualMeasure
/-- The given function index is assigned 1, the rest 0 -/
| func : Nat → MutualMeasure
/-- Evaluate a recursive call at a given `MutualMeasure` -/
def inspectCall (rc : RecCallCache) : MutualMeasure → MetaM GuessLexRel
| .args tmIdxs => do
let callerMeasureIdx := tmIdxs[rc.rcc.caller]!
let calleeMeasureIdx := tmIdxs[rc.rcc.callee]!
rc.eval callerMeasureIdx calleeMeasureIdx
| .func funIdx => do
if rc.rcc.caller == funIdx && rc.rcc.callee != funIdx then
return .lt
if rc.rcc.caller != funIdx && rc.rcc.callee == funIdx then
return .no_idea
else
return .eq
/--
Generate all combinations of measures. Assumes we have numbered the measures of each function,
and their counts are in `numMeasures`.
This puts the uniform combinations ([0,0,0], [1,1,1]) to the front; they are commonly most useful to
try first, when the mutually recursive functions have similar argument structures
-/
partial def generateCombinations? (numMeasures : Array Nat) (threshold : Nat := 32) :
Option (Array (Array Nat)) :=
(do goUniform 0; go 0 #[]) |>.run #[] |>.2
where
-- Enumerate all permissible uniform combinations
goUniform (idx : Nat) : OptionT (StateM (Array (Array Nat))) Unit := do
if numMeasures.all (idx < ·) then
modify (·.push (Array.mkArray numMeasures.size idx))
goUniform (idx + 1)
-- Enumerate all other permissible combinations
go (fidx : Nat) : OptionT (ReaderT (Array Nat) (StateM (Array (Array Nat)))) Unit := do
if h : fidx < numMeasures.size then
let n := numMeasures[fidx]
for idx in [:n] do withReader (·.push idx) (go (fidx + 1))
else
let comb ← read
unless comb.all (· == comb[0]!) do
modify (·.push comb)
if (← get).size > threshold then
failure
/--
Enumerate all measures we want to try.
All measures (resp. combinations thereof) and
possible orderings of functions (if more than one)
-/
def generateMeasures (numMeasures : Array Nat) : MetaM (Array MutualMeasure) := do
let some arg_measures := generateCombinations? numMeasures
| throwError "Too many combinations"
let func_measures :=
if numMeasures.size > 1 then
(List.range numMeasures.size).toArray
else
#[]
return arg_measures.map .args ++ func_measures.map .func
/--
The core logic of guessing the lexicographic order.
Given a matrix that for each call and measure indicates whether that measure is
decreasing, equal, less-or-equal or unknown, it finds a sequence of measures
that is lexicographically decreasing.
The matrix is implemented here as an array of monadic query methods only so that
we can fill it lazily. Morally, this is a pure function.
-/
partial def solve {m} {α} [Monad m] (measures : Array α)
(calls : Array (α → m GuessLexRel)) : m (Option (Array α)) := do
go measures calls #[]
where
go (measures : Array α) (calls : Array (α → m GuessLexRel)) (acc : Array α) := do
if calls.isEmpty then return .some acc
-- Find the first measure that has at least one < and otherwise only = or <=
for h : measureIdx in [:measures.size] do
let measure := measures[measureIdx]
let mut has_lt := false
let mut all_le := true
let mut todo := #[]
for call in calls do
let entry ← call measure
if entry = .lt then
has_lt := true
else
todo := todo.push call
if entry != .le && entry != .eq then
all_le := false
break
-- No progress here? Try the next measure
if !(has_lt && all_le) then continue
-- We found a suitable measure, remove it from the list (mild optimization)
let measures' := measures.eraseIdx measureIdx
return ← go measures' todo (acc.push measure)
-- None found, we have to give up
return .none
/--
Given a matrix (row-major) of strings, arranges them in tabular form.
First column is left-aligned, others right-aligned.
Single space as column separator.
-/
def formatTable : Array (Array String) → String := fun xss => Id.run do
let mut colWidths := xss[0]!.map (fun _ => 0)
for hi : i in [:xss.size] do
for hj : j in [:xss[i].size] do
if xss[i][j].length > colWidths[j]! then
colWidths := colWidths.set! j xss[i][j].length
let mut str := ""
for hi : i in [:xss.size] do
for hj : j in [:xss[i].size] do
let s := xss[i][j]
if j > 0 then -- right-align
for _ in [:colWidths[j]! - s.length] do
str := str ++ " "
str := str ++ s
if j = 0 then -- left-align
for _ in [:colWidths[j]! - s.length] do
str := str ++ " "
if j + 1 < xss[i].size then
str := str ++ " "
if i + 1 < xss.size then
str := str ++ "\n"
return str
/-- Concise textual representation of the source location of a recursive call -/
def RecCallWithContext.posString (rcc : RecCallWithContext) : MetaM String := do
let fileMap ← getFileMap
let .some pos := rcc.ref.getPos? | return ""
let position := fileMap.toPosition pos
let endPosStr := match rcc.ref.getTailPos? with
| some endPos =>
let endPosition := fileMap.toPosition endPos
if endPosition.line = position.line then
s!"-{endPosition.column}"
else
s!"-{endPosition.line}:{endPosition.column}"
| none => ""
return s!"{position.line}:{position.column}{endPosStr}"
/-- How to present the basic measure in the table header, possibly abbreviated. -/
def measureHeader (measure : BasicMeasure) : StateT (Nat × String) MetaM String := do
let s ← measure.toString
if s.length > 5 then
let (i, footer) ← get
let i := i + 1
let footer := footer ++ s!"#{i}: {s}\n"
set (i, footer)
pure s!"#{i}"
else
pure s
def collectHeaders {α} (a : StateT (Nat × String) MetaM α) : MetaM (α × String) := do
let (x, (_, footer)) ← a.run (0, "")
pure (x,footer)
/-- Explain what we found out about the recursive calls (non-mutual case) -/
def explainNonMutualFailure (measures : Array BasicMeasure) (rcs : Array RecCallCache) : MetaM Format := do
let (header, footer) ← collectHeaders (measures.mapM measureHeader)
let mut table : Array (Array String) := #[#[""] ++ header]
for i in [:rcs.size], rc in rcs do
let mut row := #[s!"{i+1}) {← rc.rcc.posString}"]
for argIdx in [:measures.size] do
row := row.push (← rc.prettyEntry argIdx argIdx)
table := table.push row
let out := formatTable table
if footer.isEmpty then
return out
else
return out ++ "\n\n" ++ footer
/-- Explain what we found out about the recursive calls (mutual case) -/
def explainMutualFailure (declNames : Array Name) (measuress : Array (Array BasicMeasure))
(rcs : Array RecCallCache) : MetaM Format := do
let (headerss, footer) ← collectHeaders (measuress.mapM (·.mapM measureHeader))
let mut r := Format.nil
for rc in rcs do
let caller := rc.rcc.caller
let callee := rc.rcc.callee
r := r ++ f!"Call from {declNames[caller]!} to {declNames[callee]!} " ++
f!"at {← rc.rcc.posString}:\n"
let mut table : Array (Array String) := #[#[""] ++ headerss[caller]!]
if caller = callee then
-- For self-calls, only the diagonal is interesting, so put it into one row
let mut row := #[""]
for argIdx in [:measuress[caller]!.size] do
row := row.push (← rc.prettyEntry argIdx argIdx)
table := table.push row
else
for argIdx in [:measuress[callee]!.size] do
let mut row := #[]
row := row.push headerss[callee]![argIdx]!
for paramIdx in [:measuress[caller]!.size] do
row := row.push (← rc.prettyEntry paramIdx argIdx)
table := table.push row
r := r ++ formatTable table ++ "\n"
unless footer.isEmpty do
r := r ++ "\n\n" ++ footer
return r
def explainFailure (declNames : Array Name) (measuress : Array (Array BasicMeasure))
(rcs : Array RecCallCache) : MetaM Format := do
let mut r : Format := "The basic measures relate at each recursive call as follows:\n" ++
"(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)\n"
if declNames.size = 1 then
r := r ++ (← explainNonMutualFailure measuress[0]! rcs)
else
r := r ++ (← explainMutualFailure declNames measuress rcs)
return r
/--
For `#[x₁, .., xₙ]` create `(x₁, .., xₙ)`.
-/
def mkProdElem (xs : Array Expr) : MetaM Expr := do
match xs.size with
| 0 => return default
| 1 => return xs[0]!
| _ =>
let n := xs.size
xs[0:n-1].foldrM (init:=xs[n-1]!) fun x p => mkAppM ``Prod.mk #[x,p]
def toTerminationMeasures (preDefs : Array PreDefinition) (fixedParamPerms : FixedParamPerms)
(userVarNamess : Array (Array Name)) (measuress : Array (Array BasicMeasure))
(solution : Array MutualMeasure) : MetaM TerminationMeasures := do
preDefs.mapIdxM fun funIdx preDef => do
let measures := measuress[funIdx]!
lambdaTelescope preDef.value fun params _ => do
let xs := fixedParamPerms.perms[funIdx]!.pickVarying params
withUserNames xs userVarNamess[funIdx]! do
let args := solution.map fun
| .args tmIdxs => measures[tmIdxs[funIdx]!]!.fn.beta params
| .func funIdx' => mkNatLit <| if funIdx' == funIdx then 1 else 0
let fn ← mkLambdaFVars params (← mkProdElem args)
return { ref := .missing, structural := false, fn}
/--
Shows the inferred termination measure to the user, and implements `termination_by?`
-/
def reportTerminationMeasures (preDefs : Array PreDefinition) (termMeasures : TerminationMeasures) : MetaM Unit := do
for preDef in preDefs, termMeasure in termMeasures do
let stx := do
let arity ← lambdaTelescope preDef.value fun xs _ => pure xs.size
termMeasure.delab arity (extraParams := preDef.termination.extraParams)
if showInferredTerminationBy.get (← getOptions) then
logInfoAt preDef.ref m!"Inferred termination measure:\n{← stx}"
if let some ref := preDef.termination.terminationBy?? then
Tactic.TryThis.addSuggestion ref (← stx)
end GuessLex
open GuessLex
/--
Main entry point of this module:
Try to find a lexicographic ordering of the basic measures for which the recursive definition
terminates. See the module doc string for a high-level overview.
The `preDefs` are used to determine arity and types of parameters; the bodies are ignored.
-/
def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
(fixedParamPerms : FixedParamPerms) (argsPacker : ArgsPacker) :
MetaM TerminationMeasures := do
let userVarNamess ← argsPacker.varNamess.mapM (naryVarNames ·)
trace[Elab.definition.wf] "varNames is: {userVarNamess}"
-- Collect all recursive calls and extract their context
let recCalls ← collectRecCalls unaryPreDef fixedParamPerms argsPacker
let recCalls := filterSubsumed recCalls
-- For every function, the measures we want to use
-- (One for each non-forbidden arg)
let basicMeassures₁ ← simpleMeasures preDefs fixedParamPerms userVarNamess
let basicMeassures₂ ← complexMeasures preDefs fixedParamPerms userVarNamess recCalls
let basicMeasures := Array.zipWith (· ++ ·) basicMeassures₁ basicMeassures₂
-- The list of measures, including the measures that order functions.
-- The function ordering measures come last
let mutualMeasures ← generateMeasures (basicMeasures.map (·.size))
-- If there is only one plausible measure, use that
if let #[solution] := mutualMeasures then
let termMeasures ← toTerminationMeasures preDefs fixedParamPerms userVarNamess basicMeasures #[solution]
reportTerminationMeasures preDefs termMeasures
return termMeasures
let rcs ← recCalls.mapM (RecCallCache.mk (preDefs.map (·.declName)) (preDefs.map (·.termination.decreasingBy?)) basicMeasures ·)
let callMatrix := rcs.map (inspectCall ·)
match ← liftMetaM <| solve mutualMeasures callMatrix with
| .some solution => do
let termMeasures ← toTerminationMeasures preDefs fixedParamPerms userVarNamess basicMeasures solution
reportTerminationMeasures preDefs termMeasures
return termMeasures
| .none =>
let explanation ← explainFailure (preDefs.map (·.declName)) basicMeasures rcs
Lean.throwError <| "Could not find a decreasing measure.\n" ++
explanation ++ "\n" ++
"Please use `termination_by` to specify a decreasing measure."