lean4-htt/src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean
Joachim Breitner 56733b953e
refactor: TerminationArgument → TerminationMeasure (#6727)
This PR aligns the terminology of the code with the one used in the
reference manual, as developed with and refined by @david-christiansen.
2025-01-23 10:41:38 +00:00

284 lines
13 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Joachim Breitner
-/
prelude
import Lean.Elab.PreDefinition.TerminationMeasure
import Lean.Elab.PreDefinition.Structural.Basic
import Lean.Elab.PreDefinition.Structural.RecArgInfo
namespace Lean.Elab.Structural
open Meta
/--
Renders the `i`th parameter of `xs` for use in error messages.

Parameters whose user name carries macro scopes (i.e. inaccessible names) are shown by their
1-based position as `#k`; all others are shown as the expression itself. The message context is
attached so that the free variable is printed with its user-facing name.
-/
def prettyParam (xs : Array Expr) (i : Nat) : MetaM MessageData := do
  let param := xs[i]!
  let userName ← param.fvarId!.getUserName
  let rendered := if userName.hasMacroScopes then m!"#{i+1}" else m!"{param}"
  addMessageContextFull rendered
/--
Renders the recursive argument described by `recArgInfo`.

The varying parameters are opened from `value` with `lambdaTelescope` and appended to the fixed
parameters `xs`, so that `recArgInfo.recArgPos` indexes into the full parameter list.
-/
def prettyRecArg (xs : Array Expr) (value : Expr) (recArgInfo : RecArgInfo) : MetaM MessageData :=
  lambdaTelescope value fun ys _ => do
    let allParams := xs ++ ys
    prettyParam allParams recArgInfo.recArgPos
/--
Renders the set of recursive arguments chosen for a (possibly mutual) group of functions.

For a single function this is `parameter <arg>`. For several functions it is
`parameters <arg₁> of <fn₁>, …` built with `MessageData.andList`. The arrays `fnNames`, `values`
and `recArgInfos` are traversed in parallel.
-/
def prettyParameterSet (fnNames : Array Name) (xs : Array Expr) (values : Array Expr)
    (recArgInfos : Array RecArgInfo) : MetaM MessageData := do
  -- Single function: no need to say which function the parameter belongs to.
  if fnNames.size = 1 then
    return m!"parameter " ++ (← prettyRecArg xs values[0]! recArgInfos[0]!)
  let mut described : Array MessageData := #[]
  for fnName in fnNames, value in values, recArgInfo in recArgInfos do
    let argMsg ← prettyRecArg xs value recArgInfo
    described := described.push m!"{argMsg} of {fnName}"
  return m!"parameters " ++ .andList described.toList
/--
Returns the smallest position in `xs` at which one of `indices` occurs, or `xs.size` if none
of them occur in `xs`.
-/
private def getIndexMinPos (xs : Array Expr) (indices : Array Expr) : Nat :=
  indices.foldl (init := xs.size) fun best index =>
    match xs.indexOf? index with
    | some pos => min best pos.val
    | none => best
/--
Indices can only depend on other indices.

Returns the first pair `(index, y)` such that the type of `index` depends on some `y ∈ ys` that is
not itself one of `indices`, or `none` if there is no such pair.
-/
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) :=
  indices.findSomeM? fun index => do
    let indexType ← inferType index
    ys.findSomeM? fun y => do
      -- The dependency check is always performed, even when `y` is an index.
      let dep ← dependsOn indexType y.fvarId!
      return if !indices.contains y && dep then some (index, y) else none
/--
Inductive datatype parameters cannot depend on `ys`.

Returns the first pair `(p, y)` with `p ∈ indParams` and `y ∈ ys` such that `p` depends on `y`,
or `none` if there is no such pair.
-/
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) :=
  indParams.findSomeM? fun p =>
    ys.findSomeM? fun y => do
      let dep ← dependsOn p y.fvarId!
      return if dep then some (p, y) else none
/--
Assemble the `RecArgInfo` for the `i`th parameter in the parameter list `xs`. This performs
various sanity checks on the parameter (is it even of inductive type etc).

Throws an error explaining why the parameter is unsuitable. This happens if it is out of range or
in the fixed prefix, if it is a `let`, or if its type is not an inductive. It also happens for a
reflexive inductive without `binductionOn` that is not an inductive predicate, and for an inductive
family whose indices are not distinct free variables. Finally, it happens if indices or datatype
parameters depend on the varying parameters.

The `numFixed` stored in the result may be smaller than the `numFixed` argument. It is lowered so
that none of the parameter's indices lie in the fixed prefix.
-/
def getRecArgInfo (fnName : Name) (numFixed : Nat) (xs : Array Expr) (i : Nat) : MetaM RecArgInfo := do
  if h : i < xs.size then
    if i < numFixed then
      throwError "it is unchanged in the recursive calls"
    let x := xs[i]
    let localDecl ← getFVarLocalDecl x
    if localDecl.isLet then
      throwError "it is a let-binding"
    let xType ← whnfD localDecl.type
    matchConstInduct xType.getAppFn (fun _ => throwError "its type is not an inductive") fun indInfo us => do
      if indInfo.isReflexive && !(← hasConst (mkBInductionOnName indInfo.name)) && !(← isInductivePredicate indInfo.name) then
        throwError "its type {indInfo.name} is a reflexive inductive, but {mkBInductionOnName indInfo.name} does not exist and it is not an inductive predicate"
      else
        -- Split the type's arguments into datatype parameters and indices.
        let indArgs : Array Expr := xType.getAppArgs
        let indParams : Array Expr := indArgs[0:indInfo.numParams]
        let indIndices : Array Expr := indArgs[indInfo.numParams:]
        if !indIndices.all Expr.isFVar then
          throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}"
        else if !indIndices.allDiff then
          throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}"
        else
          -- Shrink the fixed prefix so that it ends before the first index.
          let indexMinPos := getIndexMinPos xs indIndices
          let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed
          let ys := xs[numFixed:]
          match (← hasBadIndexDep? ys indIndices) with
          | some (index, y) =>
            throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
          | none =>
            match (← hasBadParamDep? ys indParams) with
            | some (indParam, y) =>
              throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich does not come before the varying parameters and before the indices of the recursion parameter."
            | none =>
              let indAll := indInfo.all.toArray
              -- `s!` is required: a bare string literal would print the braces verbatim.
              let .some indIdx := indAll.indexOf? indInfo.name | panic! s!"{indInfo.name} not in {indInfo.all}"
              -- Every index is an fvar from `xs` at this point, so the lookup cannot fail.
              let indicesPos := indIndices.map fun index => match xs.indexOf? index with | some i => i.val | none => unreachable!
              let indGroupInst := {
                IndGroupInfo.ofInductiveVal indInfo with
                levels := us
                params := indParams }
              return {
                fnName := fnName
                numFixed := numFixed
                recArgPos := i
                indicesPos := indicesPos
                indGroupInst := indGroupInst
                indIdx := indIdx }
  else
    throwError "the index #{i+1} exceeds {xs.size}, the number of parameters"
/--
Collects the `RecArgInfo`s for one function and returns a report explaining why the remaining
parameters were not considered.

The `xs` are the fixed parameters and `value` is the body with the fixed prefix instantiated. The
varying parameters are obtained by opening `value` with `lambdaTelescope`.

If the user supplied a termination measure (`termMeasure?`), only the parameter it designates is
examined, and an error is thrown if that parameter is unsuitable. Otherwise every parameter is
tried, and the reason for rejecting each unsuitable one is appended to the report.
-/
def getRecArgInfos (fnName : Name) (xs : Array Expr) (value : Expr)
    (termMeasure? : Option TerminationMeasure) : MetaM (Array RecArgInfo × MessageData) := do
  lambdaTelescope value fun ys _ => do
    let params := xs ++ ys
    match termMeasure? with
    | .some termMeasure =>
      -- The user explicitly chose a measure, so errors are reported eagerly.
      let recArgInfo ← withRef termMeasure.ref <|
        mapError (f := (m!"cannot use specified measure for structural recursion:{indentD ·}")) do
          let argPos ← termMeasure.structuralArg
          getRecArgInfo fnName xs.size params argPos
      return (#[recArgInfo], m!"")
    | .none =>
      -- No `termination_by`: try every parameter and record why each rejected one failed.
      let mut found : Array RecArgInfo := #[]
      let mut report : MessageData := m!""
      for idx in [:params.size] do
        try
          found := found.push (← getRecArgInfo fnName xs.size params idx)
        catch e =>
          let shown ← prettyParam params idx
          report := report ++ (m!"Not considering parameter {shown} of {fnName}:" ++
            indentD e.toMessageData) ++ "\n"
      trace[Elab.definition.structural] "getRecArgInfos report: {report}"
      return (found, report)
/--
Reorders the `RecArgInfo`s of one function so that arguments that are indices of other arguments
come last. The relative order within each of the two groups is preserved.

See issue #837 for an example where we can show termination using the index of an inductive
family, but we don't get the desired definitional equalities.
-/
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo :=
  -- Positions that occur as an index of at least one candidate.
  let indicesPos : Std.HashSet Nat :=
    recArgInfos.foldl (init := {}) fun acc info =>
      info.indicesPos.foldl (init := acc) fun acc' pos => acc'.insert pos
  let (indices, nonIndices) := recArgInfos.partition fun info => indicesPos.contains info.recArgPos
  nonIndices ++ indices
/--
Removes duplicates from `xs` according to the monadic equivalence test `eq`.

Order is preserved and the first representative of each equivalence class is kept. An element
`x` is dropped if `eq y x` holds for some already-kept `y`.
-/
private def dedup [Monad m] (eq : α → α → m Bool) (xs : Array α) : m (Array α) := do
  let mut ret := #[]
  for x in xs do
    unless (← ret.anyM (eq · x)) do
      ret := ret.push x
  return ret
/--
Given the `RecArgInfo`s of all the recursive functions, find the inductive groups to consider.

The `indGroupInst` of every candidate is collected, and duplicates (up to `IndGroupInst.isDefEq`)
are removed, keeping the first occurrence of each group.
-/
def inductiveGroups (recArgInfos : Array RecArgInfo) : MetaM (Array IndGroupInst) := do
  let candidates : Array IndGroupInst := recArgInfos.map fun info => info.indGroupInst
  dedup IndGroupInst.isDefEq candidates
/--
Filters the `recArgInfos` by those that describe an argument that's part of the recursive inductive
group `group`.

Because of nested inductives this function has the ability to change the `recArgInfo`.
Consider
```
inductive Tree where | node : List Tree → Tree
```
then when we look for arguments whose type is part of the group `Tree`, we want to also consider
the argument of type `List Tree`, even though that argument's `RecArgInfo` initially refers to
`List`.

Candidates whose `indGroupInst` is already defeq to `group` are kept unchanged. Otherwise each of
the group's nested type formers is tried against the argument's type. On the first match whose
indices are distinct free variables without bad dependencies, a new `RecArgInfo` pointing at
`group` is returned. Candidates matching neither way are dropped.
-/
def argsInGroup (group : IndGroupInst) (xs : Array Expr) (value : Expr)
    (recArgInfos : Array RecArgInfo) : MetaM (Array RecArgInfo) := do
  let nestedTypeFormers ← group.nestedTypeFormers
  recArgInfos.filterMapM fun recArgInfo => do
    -- Is this argument from the same mutual group of inductives?
    if (← group.isDefEq recArgInfo.indGroupInst) then
      return (.some recArgInfo)
    -- Can this argument be understood as the auxiliary type former of a nested inductive?
    if nestedTypeFormers.isEmpty then return .none
    lambdaTelescope value fun ys _ => do
      let x := (xs++ys)[recArgInfo.recArgPos]!
      -- Each nested type former is paired with its motive index, which runs from
      -- `group.all.size` up to (excluding) `group.numMotives`.
      for nestedTypeFormer in nestedTypeFormers, indIdx in [group.all.size : group.numMotives] do
        let xType ← whnfD (← inferType x)
        -- Instantiate the type former's binders with fresh metavariables, which become
        -- the candidate indices once unified with `xType`.
        let (indIndices, _, type) ← forallMetaTelescope nestedTypeFormer
        if (← isDefEqGuarded type xType) then
          let indIndices ← indIndices.mapM instantiateMVars
          -- The checks below mirror those in `getRecArgInfo`, but reject silently by moving on
          -- to the next type former instead of throwing.
          if !indIndices.all Expr.isFVar then
            -- throwError "indices are not variables{indentExpr xType}"
            continue
          if !indIndices.allDiff then
            -- throwError "indices are not pairwise distinct{indentExpr xType}"
            continue
          -- TODO: Do we have to worry about the indices ending up in the fixed prefix here?
          if let some (_index, _y) ← hasBadIndexDep? ys indIndices then
            -- throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
            continue
          let indicesPos := indIndices.map fun index => match (xs++ys).indexOf? index with | some i => i.val | none => unreachable!
          return .some
            { fnName := recArgInfo.fnName
              numFixed := recArgInfo.numFixed
              recArgPos := recArgInfo.recArgPos
              indicesPos := indicesPos
              indGroupInst := group
              indIdx := indIdx }
      return .none
/-- Upper bound on the number of argument combinations that `allCombinations` will enumerate. -/
def maxCombinationSize : Nat := 10

/--
Computes the cartesian product of `xss`. Each result picks one element from every inner array, in
order, and results are listed lexicographically, with the first array varying slowest.

Returns `none` if the number of combinations (the product of the inner sizes) exceeds
`maxCombinationSize`. An empty `xss` yields the single empty combination, and any empty inner array
yields no combinations.
-/
def allCombinations (xss : Array (Array α)) : Option (Array (Array α)) :=
  let numCombinations := xss.foldl (fun n xs => n * xs.size) 1
  if numCombinations ≤ maxCombinationSize then
    -- Extend every partial combination by each element of the next array, in order.
    some <| xss.foldl (init := #[(#[] : Array α)]) fun combs xs =>
      combs.flatMap fun comb => xs.map fun x => comb.push x
  else
    none
/--
Searches for a choice of structural recursion argument for each function in the mutual group and
hands the first workable choice to the continuation `k`.

`fnNames`, `values` and `termMeasure?s` are parallel arrays (one entry per function). `xs` are the
shared fixed parameters. The candidates of each function are collected with `getRecArgInfos`, then
grouped by inductive group. For every group, every combination of compatible candidates is tried
(at most `maxCombinationSize`). The result of the first combination for which `k` succeeds is
returned. If none succeeds, an error is thrown containing a report of why each candidate or
combination was rejected.
-/
def tryAllArgs (fnNames : Array Name) (xs : Array Expr) (values : Array Expr)
    (termMeasure?s : Array (Option TerminationMeasure)) (k : Array RecArgInfo → M α) : M α := do
  let mut report := m!""
  -- Gather information on all possible recursive arguments
  let mut recArgInfoss := #[]
  for fnName in fnNames, value in values, termMeasure? in termMeasure?s do
    let (recArgInfos, thisReport) ← getRecArgInfos fnName xs value termMeasure?
    report := report ++ thisReport
    recArgInfoss := recArgInfoss.push recArgInfos
  -- Put non-indices first
  recArgInfoss := recArgInfoss.map nonIndicesFirst
  trace[Elab.definition.structural] "recArgInfos:{indentD (.joinSep (recArgInfoss.flatten.toList.map (repr ·)) Format.line)}"
  -- Inductive groups to consider
  let groups ← inductiveGroups recArgInfoss.flatten
  trace[Elab.definition.structural] "inductive groups: {groups}"
  if groups.isEmpty then
    report := report ++ "no parameters suitable for structural recursion"
  -- Consider each group
  for group in groups do
    -- Select those RecArgInfos that are compatible with this inductive group
    let mut recArgInfoss' := #[]
    for value in values, recArgInfos in recArgInfoss do
      recArgInfoss' := recArgInfoss'.push (← argsInGroup group xs value recArgInfos)
    -- Every function needs at least one candidate in this group, otherwise skip the group.
    if let some idx := recArgInfoss'.findIdx? (·.isEmpty) then
      report := report ++ m!"Skipping arguments of type {group}, as {fnNames[idx]!} has no compatible argument.\n"
      continue
    if let some combs := allCombinations recArgInfoss' then
      for comb in combs do
        try
          -- Check that the group actually has a brecOn (we used to check this in getRecArgInfo,
          -- but in the first phase we do not want to rule out non-recursive types like `Array`, which
          -- are ok in a nested group. This logic could perhaps be simplified)
          unless (← hasConst (group.brecOnName false 0)) do
            throwError "the type {group} does not have a `.brecOn` recursor"
          -- TODO: Here we used to save and restore the state. But should the `try`-`catch`
          -- not suffice?
          -- NOTE(review): whether `try`/`catch` in `M` rolls back state changes made by a failed
          -- `k comb` is not visible from this file — confirm against the definition of `M`.
          let r ← k comb
          trace[Elab.definition.structural] "tryAllArgs report:\n{report}"
          return r
        catch e =>
          -- Record why this combination failed and move on to the next one.
          let m ← prettyParameterSet fnNames xs values comb
          report := report ++ m!"Cannot use {m}:{indentD e.toMessageData}\n"
    else
      report := report ++ m!"Too many possible combinations of parameters of type {group} (or " ++
        m!"please indicate the recursive argument explicitly using `termination_by structural`).\n"
  -- Nothing worked: fail with the accumulated report.
  report := m!"failed to infer structural recursion:\n" ++ report
  trace[Elab.definition.structural] "tryAllArgs:\n{report}"
  throwError report
end Lean.Elab.Structural