fix: use PSum instead of Sum when using well-founded recursion

See new test for example that did not work with `Sum` because type
alpha was `Sort u`.
This commit is contained in:
Leonardo de Moura 2022-02-17 16:11:46 -08:00
parent 4a0ae8326c
commit 9ee529e5ce
16 changed files with 67 additions and 48 deletions

View file

@ -68,3 +68,5 @@ example (a : α) (x : Fam α) : α :=
| Fam.any => a
| Fam.nat n => n
```
* We now use `PSum` (instead of `Sum`) when compiling mutually recursive definitions using well-founded recursion.

View file

@ -77,9 +77,9 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt
| e => ensureNoRecFn recFnName e
loop F e
/-- Refine `F` over `Sum.casesOn` -/
/-- Refine `F` over `PSum.casesOn` -/
private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F : Expr) → (val : Expr) → TermElabM Expr) : TermElabM Expr := do
if x.isFVar && val.isAppOfArity ``Sum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then
if x.isFVar && val.isAppOfArity ``PSum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then
let args := val.getAppArgs
let α := args[0]
let β := args[1]
@ -94,9 +94,9 @@ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F :
let FTypeNew := FDecl.type.replaceFVar x (← mkAppOptM ctorName #[α, β, xNew])
withLocalDeclD FDecl.userName FTypeNew fun FNew => do
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew valNew k)
let minorLeft ← mkMinorNew ``Sum.inl args[4]
let minorRight ← mkMinorNew ``Sum.inr args[5]
let result := mkAppN (mkConst ``Sum.casesOn [u, (← getDecLevel α), (← getDecLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]
let minorLeft ← mkMinorNew ``PSum.inl args[4]
let minorRight ← mkMinorNew ``PSum.inr args[5]
let result := mkAppN (mkConst ``PSum.casesOn [u, (← getLevel α), (← getLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]
return result
else
k x F val
@ -138,7 +138,8 @@ def mkFix (preDef : PreDefinition) (wfRel : Expr) (decrTactic? : Option Syntax)
let x := xs[0]
let F := xs[1]
let val := preDef.value.betaRev #[x]
let val ← processSumCasesOn x F val fun x F val => processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?)
let val ← processSumCasesOn x F val fun x F val => do
processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?)
return { preDef with value := mkApp wfFix (← mkLambdaFVars #[x, F] val) }
end Lean.Elab.WF

View file

@ -38,10 +38,10 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
(← whnfD type).withApp fun f args => do
assert! args.size == 2
if i == fidx then
return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] (← mkProd args[0])
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] (← mkProd args[0])
else
let r ← mkSum (i+1) args[1]
return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r
let arg ← mkSum 0 domain
mkLambdaFVars xs (mkApp (mkConst preDefNonRec.declName us) arg)
trace[Elab.definition.wf] "{preDef.declName} := {value}"

View file

@ -11,11 +11,11 @@ open Meta
private def getDomains (preDefs : Array PreDefinition) : Array Expr :=
preDefs.map fun preDef => preDef.type.bindingDomain!
/-- Combine different function domains `ds` using `Sum`s -/
/-- Combine different function domains `ds` using `PSum`s -/
private def mkNewDomain (ds : Array Expr) : MetaM Expr := do
let mut r := ds.back
for d in ds.pop.reverse do
r ← mkAppM ``Sum #[d, r]
r ← mkAppM ``PSum #[d, r]
return r
private def getCodomainLevel (preDef : PreDefinition) : MetaM Level :=
@ -41,9 +41,9 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M
let rec go (x : Expr) (i : Nat) : MetaM Expr := do
if i < preDefs.size - 1 then
let xType ← whnfD (← inferType x)
assert! xType.isAppOfArity ``Sum 2
assert! xType.isAppOfArity ``PSum 2
let xTypeArgs := xType.getAppArgs
let casesOn := mkConst (mkCasesOnName ``Sum) (mkLevelSucc u :: xType.getAppFn.constLevels!)
let casesOn := mkConst (mkCasesOnName ``PSum) (mkLevelSucc u :: xType.getAppFn.constLevels!)
let casesOn := mkAppN casesOn xTypeArgs -- parameters
let casesOn := mkApp casesOn (← mkLambdaFVars #[x] (mkSort u)) -- motive
let casesOn := mkApp casesOn x -- major
@ -58,7 +58,7 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M
/--
Combine/pack the values of the different definitions in a single value
`x` is `Sum`, and we use `Sum.casesOn` to select the appropriate `preDefs.value`.
`x` is `PSum`, and we use `PSum.casesOn` to select the appropriate `preDefs.value`.
See: `packMutual`.
Remark: this method does not replace the nested recursive `preDefs` applications.
This step is performed by `transform` with the following `post` method.
@ -100,15 +100,15 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn
(← whnfD type).withApp fun f args => do
assert! args.size == 2
if i == fidx then
return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] arg
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] arg
else
let r ← mkNewArg (i+1) args[1]
return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r
return TransformStep.done <| mkApp (mkConst newFn us) (← mkNewArg 0 domain)
return TransformStep.done e
/--
If `preDefs.size > 1`, combine different functions in a single one using `Sum`.
If `preDefs.size > 1`, combine different functions in a single one using `PSum`.
This method assumes all `preDefs` have arity 1, and have already been processed using `packDomain`.
Here is a small example. Suppose the input is
```
@ -128,22 +128,22 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn
this method produces the following pre definition
```
f._mutual x :=
Sum.casesOn x
PSum.casesOn x
(fun val =>
match val.2.1, val.2.2.1, val.2.2.2 with
| 0, a, b => a
| Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inl ⟨val.1, n, a, b⟩))).fst
| Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inl ⟨val.1, n, a, b⟩))).fst
fun val =>
Sum.casesOn val
PSum.casesOn val
(fun val =>
match val.2.1, val.2.2.1, val.2.2.2 with
| 0, a, b => (a, b)
| Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inr ⟨val.1, n, a, b⟩)), a)
| Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inr ⟨val.1, n, a, b⟩)), a)
fun val =>
match val.2.1, val.2.2.1, val.2.2.2 with
| 0, a, b => b
| Nat.succ n, a, b =>
f._mutual (Sum.inl ⟨val.1, n, a, b⟩)
f._mutual (PSum.inl ⟨val.1, n, a, b⟩)
```
-/
def packMutual (preDefs : Array PreDefinition) : MetaM PreDefinition := do

View file

@ -14,9 +14,9 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨n, true⟩ => (n, 2)
| Sum.inl ⟨n, false⟩ => (n, 1)
| Sum.inr n => (n, 0))
| PSum.inl ⟨n, true⟩ => (n, 2)
| PSum.inl ⟨n, false⟩ => (n, 1)
| PSum.inr n => (n, 0))
$ Prod.lex sizeOfWFRel sizeOfWFRel
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]

View file

@ -22,5 +22,5 @@ mutual
| 0 => false
| n+1 => isEven n
end
termination_by' measure (fun n => match n with | Sum.inl n => n | Sum.inr n => n)
termination_by' measure (fun n => match n with | PSum.inl n => n | PSum.inr n => n)
end Ex3

View file

@ -7,8 +7,8 @@ mutual
| n+1 => isEven n
end
termination_by' measure fun
| Sum.inl n => n
| Sum.inr n => n
| PSum.inl n => n
| PSum.inr n => n
decreasing_by
simp [measure, invImage, InvImage, Nat.lt_wfRel]
apply Nat.lt_succ_self

View file

@ -7,8 +7,8 @@ mutual
| n+1 => isEven n
end
termination_by' measure fun
| Sum.inl n => n
| Sum.inr n => n
| PSum.inl n => n
| PSum.inr n => n
decreasing_by apply Nat.lt_succ_self
theorem isEven_double (x : Nat) : isEven (2 * x) = true := by

View file

@ -15,9 +15,9 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
#print f

View file

@ -8,8 +8,8 @@ mutual
| n+1 => isEven n
end
termination_by' measure fun
| Sum.inl n => n
| Sum.inr n => n
| PSum.inl n => n
| PSum.inr n => n
#print isEven
#print isOdd

View file

@ -15,9 +15,9 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
@ -51,9 +51,9 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
#print f._unary._mutual

View file

@ -0,0 +1,16 @@
mutual
def fn (f : αα) (a : α) (n : Nat) : α :=
match n with
| 0 => a
| n+1 => gn f (f (f a)) (f a) n
def gn (f : αα) (a b : α) (n : Nat) : α :=
match n with
| 0 => b
| n+1 => fn f (f b) n
end
termination_by
fn n => n
gn n => n

View file

@ -14,8 +14,8 @@ mutual
| n+1 => isEven n
end
termination_by' measure fun
| Sum.inl n => n
| Sum.inr n => n
| PSum.inl n => n
| PSum.inr n => n
decreasing_by
simp [measure, invImage, InvImage, Nat.lt_wfRel]
apply Nat.lt_succ_self

View file

@ -19,8 +19,8 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨_, n⟩ => (n, 0)
| Sum.inr ⟨_, n⟩ => (n, 1))
| PSum.inl ⟨_, n⟩ => (n, 0)
| PSum.inr ⟨_, n⟩ => (n, 1))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]

View file

@ -21,9 +21,9 @@ end
termination_by'
invImage
(fun
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]

View file

@ -6,7 +6,7 @@ mutual
| 0 => false
| n+1 => isEven n
end
termination_by' measure fun | Sum.inl n => n | Sum.inr n => n
termination_by' measure fun | PSum.inl n => n | PSum.inr n => n
decreasing_by apply Nat.lt_succ_self
theorem isEven_double (x : Nat) : isEven (2 * x) = true := by