feat: GuessLex: avoid writing sizeOf in termination argument when not needed (#3630)

this makes `termination_by?` even slicker.

The heuristics is agressive in the non-mutual case (will omit `sizeOf`
if the argument is non-dependent and the `WellFoundedRelation` relation
is via `sizeOfWFRel`.

In the mutual case we'd also have to check the arguments, as they line
up in the termination argument, have the same types. I did not bother at
this point; in the mutual case we omit `sizeOf` only if the argument
type is `Nat`.

As a drive-by fix, `termination_by?` now also works on functions that
have only one plausible measure.
This commit is contained in:
Joachim Breitner 2024-03-10 23:57:10 +01:00 committed by GitHub
parent 1d3ef577c2
commit 32dcc6eb89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 269 additions and 74 deletions

View file

@ -302,7 +302,11 @@ def GuessLexRel.toNatRel : GuessLexRel → Expr
| le => mkAppN (mkConst ``LE.le [levelZero]) #[mkConst ``Nat, mkConst ``instLENat]
| no_idea => unreachable!
/-- Given an expression `e`, produce `sizeOf e` with a suitable instance. -/
/--
Given an expression `e`, produce `sizeOf e` with a suitable instance.
NB: We must use the instance of the type of the function parameter!
The concrete argument at hand may have a different (still def-eq) typ.
-/
def mkSizeOf (e : Expr) : MetaM Expr := do
let ty ← inferType e
let lvl ← getLevel ty
@ -315,8 +319,8 @@ def mkSizeOf (e : Expr) : MetaM Expr := do
For a given recursive call, and a choice of parameter and argument index,
try to prove equality, < or ≤.
-/
def evalRecCall (decrTactic? : Option DecreasingBy) (rcc : RecCallWithContext) (paramIdx argIdx : Nat) :
MetaM GuessLexRel := do
def evalRecCall (decrTactic? : Option DecreasingBy) (rcc : RecCallWithContext)
(paramIdx argIdx : Nat) : MetaM GuessLexRel := do
rcc.ctxt.run do
let param := rcc.params[paramIdx]!
let arg := rcc.args[argIdx]!
@ -407,7 +411,7 @@ def inspectCall (rc : RecCallCache) : MutualMeasure → MetaM GuessLexRel
return .eq
/--
Given a predefinition with value `fun (x_₁ ... xₙ) (y_₁ : α₁)... (yₘ : αₘ) => ...`,
Given a predefinition with value `fun (x₁ ... xₙ) (y₁ : α₁)... (yₘ : αₘ) => ...`,
where `n = fixedPrefixSize`, return an array `A` s.t. `i ∈ A` iff `sizeOf yᵢ` reduces to a literal.
This is the case for types such as `Prop`, `Type u`, etc.
These arguments should not be considered when guessing a well-founded relation.
@ -425,6 +429,47 @@ def getForbiddenByTrivialSizeOf (fixedPrefixSize : Nat) (preDef : PreDefinition)
result := result.push i
return result
/--
Given a predefinition with value `fun (x₁ ... xₙ) (y₁ : α₁)... (yₘ : αₘ) => ...`,
where `n = fixedPrefixSize`, return an array `A` s.t. `i ∈ A` iff the
`WellFoundedRelation` of `aᵢ` goes via `SizeOf`, and `aᵢ` does not depend on `y₁`….
These are the parameters for which we omit an explicit call to `sizeOf` in the termination argument.
We only use this in the non-mutual case; in the mutual case we would have to additional check
if the parameters that line up in the actual `TerminationWF` have the same type.
-/
def getSizeOfParams (fixedPrefixSize : Nat) (preDef : PreDefinition) : MetaM (Array Nat) :=
lambdaTelescope preDef.value fun xs _ => do
let xs : Array Expr := xs[fixedPrefixSize:]
let mut result := #[]
for x in xs, i in [:xs.size] do
try
let t ← inferType x
if t.hasAnyFVar (fun fvar => xs.contains (.fvar fvar)) then continue
let u ← getLevel t
let wfi ← synthInstance (.app (.const ``WellFoundedRelation [u]) t)
let soi ← synthInstance (.app (.const ``SizeOf [u]) t)
if ← isDefEq wfi (mkApp2 (.const ``sizeOfWFRel [u]) t soi) then
result := result.push i
catch _ =>
pure ()
return result
/--
Given a predefinition with value `fun (x₁ ... xₙ) (y₁ : α₁)... (yₘ : αₘ) => ...`,
where `n = fixedPrefixSize`, return an array `A` s.t. `i ∈ A` iff `aᵢ` is `Nat`.
These are parameters where we can definitely omit the call to `sizeOf`.
-/
def getNatParams (fixedPrefixSize : Nat) (preDef : PreDefinition) : MetaM (Array Nat) :=
lambdaTelescope preDef.value fun xs _ => do
let xs : Array Expr := xs[fixedPrefixSize:]
let mut result := #[]
for x in xs, i in [:xs.size] do
let t ← inferType x
if ← withReducible (isDefEq t (.const `Nat [])) then
result := result.push i
return result
/--
Generate all combination of arguments, skipping those that are forbidden.
@ -539,23 +584,26 @@ combination of these measures. The parameters are
* `measures`: The measures to be used.
-/
def buildTermWF (originalVarNamess : Array (Array Name)) (varNamess : Array (Array Name))
(measures : Array MutualMeasure) : MetaM TerminationWF := do
(needsNoSizeOf : Array (Array Nat)) (measures : Array MutualMeasure) : MetaM TerminationWF := do
varNamess.mapIdxM fun funIdx varNames => do
let idents := varNames.map mkIdent
let measureStxs ← measures.mapM fun
| .args varIdxs => do
let varIdx := varIdxs[funIdx]!
let v := idents[varIdx]!
-- Print `sizeOf` as such, unless it is shadowed.
-- Shadowing by a `def` in the current namespace is handled by `unresolveNameGlobal`.
-- But it could also be shadowed by an earlier parameter (including the fixed prefix),
-- so look for unqualified (single tick) occurrences in `originalVarNames`
let sizeOfIdent :=
if originalVarNamess[funIdx]!.any (· = `sizeOf) then
mkIdent ``sizeOf -- fully qualified
else
mkIdent (← unresolveNameGlobal ``sizeOf)
`($sizeOfIdent $v)
if needsNoSizeOf[funIdx]!.contains varIdx then
`($v)
else
-- Print `sizeOf` as such, unless it is shadowed.
-- Shadowing by a `def` in the current namespace is handled by `unresolveNameGlobal`.
-- But it could also be shadowed by an earlier parameter (including the fixed prefix),
-- so look for unqualified (single tick) occurrences in `originalVarNames`
let sizeOfIdent :=
if originalVarNamess[funIdx]!.any (· = `sizeOf) then
mkIdent ``sizeOf -- fully qualified
else
mkIdent (← unresolveNameGlobal ``sizeOf)
`($sizeOfIdent $v)
| .func funIdx' => if funIdx' == funIdx then `(1) else `(0)
let body ← mkTupleSyntax measureStxs
return { ref := .missing, vars := idents, body, synthetic := true }
@ -668,11 +716,20 @@ def explainFailure (declNames : Array Name) (varNamess : Array (Array Name))
r := r ++ (← explainMutualFailure declNames varNamess rcs)
return r
end Lean.Elab.WF.GuessLex
/--
Shows the termination measure used to the user, and implements `termination_by?`
-/
def reportWF (preDefs : Array PreDefinition) (wf : TerminationWF) : MetaM Unit := do
let extraParamss := preDefs.map (·.termination.extraParams)
let wf' := trimTermWF extraParamss wf
for preDef in preDefs, term in wf' do
if showInferredTerminationBy.get (← getOptions) then
logInfoAt preDef.ref m!"Inferred termination argument:\n{← term.unexpand}"
if let some ref := preDef.termination.terminationBy?? then
Tactic.TryThis.addSuggestion ref (← term.unexpand)
namespace Lean.Elab.WF
open Lean.Elab.WF.GuessLex
end GuessLex
open GuessLex
/--
Main entry point of this module:
@ -683,14 +740,17 @@ terminates. See the module doc string for a high-level overview.
def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
(fixedPrefixSize : Nat) :
MetaM TerminationWF := do
let extraParamss := preDefs.map (·.termination.extraParams)
let originalVarNamess ← preDefs.mapM originalVarNames
let varNamess ← originalVarNamess.mapM (naryVarNames fixedPrefixSize ·)
let arities := varNamess.map (·.size)
trace[Elab.definition.wf] "varNames is: {varNamess}"
let forbiddenArgs ← preDefs.mapM fun preDef =>
getForbiddenByTrivialSizeOf fixedPrefixSize preDef
let forbiddenArgs ← preDefs.mapM (getForbiddenByTrivialSizeOf fixedPrefixSize)
let needsNoSizeOf ←
if preDefs.size = 1 then
preDefs.mapM (getSizeOfParams fixedPrefixSize)
else
preDefs.mapM (getNatParams fixedPrefixSize)
-- The list of measures, including the measures that order functions.
-- The function ordering measures come last
@ -698,7 +758,9 @@ def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
-- If there is only one plausible measure, use that
if let #[solution] := measures then
return ← buildTermWF originalVarNamess varNamess #[solution]
let wf ← buildTermWF originalVarNamess varNamess needsNoSizeOf #[solution]
reportWF preDefs wf
return wf
-- Collect all recursive calls and extract their context
let recCalls ← collectRecCalls unaryPreDef fixedPrefixSize arities
@ -708,15 +770,8 @@ def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
match ← liftMetaM <| solve measures callMatrix with
| .some solution => do
let wf ← buildTermWF originalVarNamess varNamess solution
let wf' := trimTermWF extraParamss wf
for preDef in preDefs, term in wf' do
if showInferredTerminationBy.get (← getOptions) then
logInfoAt preDef.ref m!"Inferred termination argument:\n{← term.unexpand}"
if let some ref := preDef.termination.terminationBy?? then
Tactic.TryThis.addSuggestion ref (← term.unexpand)
let wf ← buildTermWF originalVarNamess varNamess needsNoSizeOf solution
reportWF preDefs wf
return wf
| .none =>
let explanation ← explainFailure (preDefs.map (·.declName)) varNamess rcs

View file

@ -87,6 +87,7 @@ def confuseLex2 : @PSigma Nat (fun _ => Nat) → Nat
| ⟨0,_⟩ => 0
| ⟨.succ y,.succ n⟩ => confuseLex2 ⟨y,n⟩
-- NB: uses sizeOf to make the termination argument non-dependent
def dependent : (n : Nat) → (m : Fin n) → Nat
| 0, i => Fin.elim0 i
| .succ 0, 0 => 0
@ -94,6 +95,11 @@ def dependent : (n : Nat) → (m : Fin n) → Nat
| .succ (.succ n), ⟨.succ m, h⟩ =>
dependent (.succ (.succ n)) ⟨m, Nat.lt_of_le_of_lt (Nat.le_succ _) h⟩
-- NB: does not use sizeOf, as parameters in the fixed prefix are fine.
def dependentWithFixedPrefix (n : Nat) : (m : Fin n) → (acc : Nat) → Nat
| ⟨0, _⟩, acc => acc
| ⟨i+1, h⟩, acc => dependentWithFixedPrefix n ⟨i, Nat.lt_of_succ_lt h⟩ (acc + i)
-- An example based on a real world problem, condensed by Leo
inductive Expr where
| add (a b : Expr)
@ -110,6 +116,7 @@ def eval_add (a : Expr × Expr) : Nat :=
| (x, y) => eval x + eval y
end
namespace VarNames
/-! Test that varnames are inferred nicely. -/
@ -127,24 +134,97 @@ def shadow2 (some_n : Nat) : Nat → Nat
| .succ n => shadow2 (some_n + 1) n
decreasing_by decreasing_tactic
-- The following test whether `sizeOf` is properly printed, and possibly qualified
-- For this we need a type that needs an explicit “sizeOf”.
structure OddNat where nat : Nat
instance : WellFoundedRelation OddNat := measure (fun ⟨n⟩ => n+1)
-- Just to check that sizeOf is actually used
def oddNat : OddNat → Nat
| ⟨0⟩ => 0
| ⟨.succ n⟩ => oddNat ⟨n⟩
decreasing_by decreasing_tactic
-- Shadowing `sizeOf`, as a varying paramter
def shadowSizeOf1 (sizeOf : Nat) : Nat → Nat
| 0 => 0
| .succ n => shadowSizeOf1 (sizeOf + 1) n
def shadowSizeOf1 (sizeOf : Nat) : OddNat → Nat
| ⟨0⟩ => 0
| ⟨.succ n⟩ => shadowSizeOf1 (sizeOf + 1) ⟨n⟩
decreasing_by decreasing_tactic
-- Shadowing `sizeOf`, as a fixed paramter
def shadowSizeOf2 (sizeOf : Nat) : Nat → Nat → Nat
| 0, m => m
| .succ n, m => shadowSizeOf2 sizeOf n m
def shadowSizeOf2 (sizeOf : Nat) : OddNat → Nat → Nat
| ⟨0⟩, m => m
| ⟨.succ n⟩, m => shadowSizeOf2 sizeOf ⟨n⟩ m
decreasing_by decreasing_tactic
-- Shadowing `sizeOf`, as something in the environment
def sizeOf : Nat := 2
def qualifiedSizeOf (m : Nat) : Nat → Nat
| 0 => 0
| .succ n => qualifiedSizeOf (m + 1) n
def qualifiedSizeOf (m : Nat) : OddNat → Nat
| ⟨0⟩ => 0
| ⟨.succ n⟩ => qualifiedSizeOf (m + 1) ⟨n⟩
decreasing_by decreasing_tactic
end VarNames
namespace MutualNotNat1
-- A type that isn't Nat, checking that the inferred argument uses `sizeOf` so that
-- the types of the termination argument aligns.
structure OddNat2 where nat : Nat
instance : SizeOf OddNat2 := ⟨fun n => n.nat⟩
@[simp] theorem OddNat2.sizeOf_eq (n : OddNat2) : sizeOf n = n.nat := rfl
mutual
def foo : Nat → Nat
| 0 => 0
| n+1 => bar ⟨n⟩
def bar : OddNat2 → Nat
| ⟨0⟩ => 0
| ⟨n+1⟩ => foo n
end
end MutualNotNat1
namespace MutualNotNat2
-- A type that is defeq to Nat, but with a different `sizeOf`, checking that the
-- inferred argument uses `sizeOf` so that the types of the termination argument aligns.
def OddNat3 := Nat
instance : SizeOf OddNat3 := ⟨fun n => 42 - @id Nat n⟩
@[simp] theorem OddNat3.sizeOf_eq (n : OddNat3) : sizeOf n = 42 - @id Nat n := rfl
mutual
def foo : Nat → Nat
| 0 => 0
| n+1 =>
if h : n < 42 then bar (42 - n) else 0
-- termination_by x1 => x1
decreasing_by simp_wf; simp [OddNat3]; omega
def bar (o : OddNat3) : Nat := if h : @id Nat o < 41 then foo (41 - @id Nat o) else 0
-- termination_by sizeOf o
decreasing_by simp_wf; simp [id] at *; omega
end
namespace MutualNotNat2
namespace MutualNotNat3
-- A varant of the above, but where the type of the parameter refined to `Nat`.
-- This tests if `GuessLex` is inferring the `SizeOf` instance based on the type of the
-- concrete parameter/argument (wrong, but status quo), or based on the types in the function
-- signature (correct, todo)
def OddNat3 := Nat
instance : SizeOf OddNat3 := ⟨fun n => 42 - @id Nat n⟩
@[simp] theorem OddNat3.sizeOf_eq (n : OddNat3) : sizeOf n = 42 - @id Nat n := rfl
mutual
def foo : Nat → Nat
| 0 => 0
| n+1 =>
if h : n < 42 then bar (42 - n) else 0
-- termination_by x1 => x1
decreasing_by simp_wf; simp [OddNat3]; omega
def bar : OddNat3 → Nat
| Nat.zero => 0
| n+1 => if h : n < 41 then foo (40 - n) else 0
-- termination_by x1 => sizeOf x1
decreasing_by simp_wf; omega
end
namespace MutualNotNat3

View file

@ -1,41 +1,65 @@
Inferred termination argument:
termination_by (sizeOf n, sizeOf m)
termination_by (n, m)
Inferred termination argument:
termination_by (sizeOf m, sizeOf n)
termination_by (m, n)
Inferred termination argument:
termination_by (sizeOf n, sizeOf m)
termination_by (n, m)
Inferred termination argument:
termination_by x1 x2 => (sizeOf x2, sizeOf x1)
termination_by x1 x2 => (x2, x1)
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by x1 => (x1, 0)
Inferred termination argument:
termination_by (n, 1)
Inferred termination argument:
termination_by (m, n)
Inferred termination argument:
termination_by x1 x2 x3 x4 x5 x6 x7 x8 => (x8, x7, x6, x5, x4, x3, x2, x1)
Inferred termination argument:
termination_by x1 => sizeOf x1
Inferred termination argument:
termination_by x1 => sizeOf x1
termination_by x1 x2 => (x1, sizeOf x2)
Inferred termination argument:
termination_by x1 => sizeOf x1
Inferred termination argument:
termination_by x1 => sizeOf x1
Inferred termination argument:
termination_by x1 => (sizeOf x1, 0)
Inferred termination argument:
termination_by (sizeOf n, 1)
Inferred termination argument:
termination_by (sizeOf m, sizeOf n)
Inferred termination argument:
termination_by x1 x2 x3 x4 x5 x6 x7 x8 =>
(sizeOf x8, sizeOf x7, sizeOf x6, sizeOf x5, sizeOf x4, sizeOf x3, sizeOf x2, sizeOf x1)
Inferred termination argument:
termination_by x1 x2 => (sizeOf x1, sizeOf x2)
termination_by x1 x2 => x1
Inferred termination argument:
termination_by (sizeOf a, 1)
Inferred termination argument:
termination_by (sizeOf a, 0)
Inferred termination argument:
termination_by x2' => sizeOf x2'
termination_by x2' => x2'
Inferred termination argument:
termination_by x2 => sizeOf x2
termination_by x2 => x2
Inferred termination argument:
termination_by x1 => sizeOf x1
Inferred termination argument:
termination_by x2 => SizeOf.sizeOf x2
Inferred termination argument:
termination_by x1 x2 => SizeOf.sizeOf x1
Inferred termination argument:
termination_by x2 => SizeOf.sizeOf x2
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by x1 => sizeOf x1
Inferred termination argument:
termination_by x1 => x1
Inferred termination argument:
termination_by sizeOf o
guessLex.lean:217:0-229:3: error: Could not find a decreasing measure.
The arguments relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
Call from MutualNotNat2.MutualNotNat2.MutualNotNat3.foo to MutualNotNat2.MutualNotNat2.MutualNotNat3.bar at 221:23-35:
x1
x1 <
Call from MutualNotNat2.MutualNotNat2.MutualNotNat3.bar to MutualNotNat2.MutualNotNat2.MutualNotNat3.foo at 226:30-42:
x1
x1 ?
Please use `termination_by` to specify a decreasing measure.

View file

@ -1,6 +1,6 @@
Inferred termination argument:
termination_by (sizeOf y, 1, 0)
termination_by (y, 1, 0)
Inferred termination argument:
termination_by (sizeOf y, 0, 1)
termination_by (y, 0, 1)
Inferred termination argument:
termination_by (sizeOf y, 0, 0)
termination_by (y, 0, 0)

View file

@ -1,4 +1,4 @@
Inferred termination argument:
termination_by x1 x2 x3 => (sizeOf x1, sizeOf x2, 0)
termination_by x1 x2 x3 => (x1, x2, 0)
Inferred termination argument:
termination_by x1 x2 x3 => (sizeOf x1, sizeOf x2, 1)
termination_by x1 x2 x3 => (x1, x2, 1)

View file

@ -1,7 +1,19 @@
def ackermann (n m : Nat) := match n, m with
| 0, m => m + 1
| .succ n, 0 => ackermann n 1
| .succ n, .succ m => ackermann n (ackermann (n + 1) m)
termination_by?
--^ codeAction
-- Check hat we print this even if there is only one plausible measure
def onlyOneMeasure (n : Nat) := match n with
| 0 => 0
| .succ n => onlyOneMeasure n
termination_by?
--^ codeAction
def anonymousMeasure : Nat → Nat
| 0 => 0
| .succ n => anonymousMeasure n
termination_by?
--^ codeAction

View file

@ -1,4 +1,4 @@
{"title": "Try this: termination_by (sizeOf n, sizeOf m)",
{"title": "Try this: termination_by (n, m)",
"kind": "quickfix",
"isPreferred": true,
"edit":
@ -7,6 +7,30 @@
{"version": 1, "uri": "file:///terminationBySuggestion.lean"},
"edits":
[{"range":
{"start": {"line": 5, "character": 0},
"end": {"line": 5, "character": 15}},
"newText": "termination_by (sizeOf n, sizeOf m)"}]}]}}
{"start": {"line": 4, "character": 0},
"end": {"line": 4, "character": 15}},
"newText": "termination_by (n, m)"}]}]}}
{"title": "Try this: termination_by n",
"kind": "quickfix",
"isPreferred": true,
"edit":
{"documentChanges":
[{"textDocument":
{"version": 1, "uri": "file:///terminationBySuggestion.lean"},
"edits":
[{"range":
{"start": {"line": 11, "character": 0},
"end": {"line": 11, "character": 15}},
"newText": "termination_by n"}]}]}}
{"title": "Try this: termination_by x1 => x1",
"kind": "quickfix",
"isPreferred": true,
"edit":
{"documentChanges":
[{"textDocument":
{"version": 1, "uri": "file:///terminationBySuggestion.lean"},
"edits":
[{"range":
{"start": {"line": 17, "character": 0},
"end": {"line": 17, "character": 15}},
"newText": "termination_by x1 => x1"}]}]}}

View file

@ -3,11 +3,11 @@ Tactic is run (ideally only twice)
Tactic is run (ideally only twice)
Tactic is run (ideally only once, in most general context)
n : Nat
⊢ (invImage (fun a => sizeOf a) instWellFoundedRelation).1 n (Nat.succ n)
⊢ (invImage (fun a => a) instWellFoundedRelation).1 n (Nat.succ n)
Tactic is run (ideally only twice, in most general context)
Tactic is run (ideally only twice, in most general context)
n : Nat
⊢ sizeOf n < sizeOf (Nat.succ n)
n m : Nat
⊢ (invImage (fun a => PSigma.casesOn a fun x1 snd => sizeOf x1) instWellFoundedRelation).1 { fst := n, snd := m + 1 }
⊢ (invImage (fun a => PSigma.casesOn a fun x1 snd => x1) instWellFoundedRelation).1 { fst := n, snd := m + 1 }
{ fst := Nat.succ n, snd := m }