fix: use PSum instead of Sum when using well-founded recursion
See new test for example that did not work with `Sum` because type alpha was `Sort u`.
This commit is contained in:
parent
4a0ae8326c
commit
9ee529e5ce
16 changed files with 67 additions and 48 deletions
|
|
@ -68,3 +68,5 @@ example (a : α) (x : Fam α) : α :=
|
|||
| Fam.any => a
|
||||
| Fam.nat n => n
|
||||
```
|
||||
|
||||
* We now use `PSum` (instead of `Sum`) when compiling mutually recursive definitions using well-founded recursion.
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt
|
|||
| e => ensureNoRecFn recFnName e
|
||||
loop F e
|
||||
|
||||
/-- Refine `F` over `Sum.casesOn` -/
|
||||
/-- Refine `F` over `PSum.casesOn` -/
|
||||
private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F : Expr) → (val : Expr) → TermElabM Expr) : TermElabM Expr := do
|
||||
if x.isFVar && val.isAppOfArity ``Sum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then
|
||||
if x.isFVar && val.isAppOfArity ``PSum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then
|
||||
let args := val.getAppArgs
|
||||
let α := args[0]
|
||||
let β := args[1]
|
||||
|
|
@ -94,9 +94,9 @@ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F :
|
|||
let FTypeNew := FDecl.type.replaceFVar x (← mkAppOptM ctorName #[α, β, xNew])
|
||||
withLocalDeclD FDecl.userName FTypeNew fun FNew => do
|
||||
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew valNew k)
|
||||
let minorLeft ← mkMinorNew ``Sum.inl args[4]
|
||||
let minorRight ← mkMinorNew ``Sum.inr args[5]
|
||||
let result := mkAppN (mkConst ``Sum.casesOn [u, (← getDecLevel α), (← getDecLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]
|
||||
let minorLeft ← mkMinorNew ``PSum.inl args[4]
|
||||
let minorRight ← mkMinorNew ``PSum.inr args[5]
|
||||
let result := mkAppN (mkConst ``PSum.casesOn [u, (← getLevel α), (← getLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]
|
||||
return result
|
||||
else
|
||||
k x F val
|
||||
|
|
@ -138,7 +138,8 @@ def mkFix (preDef : PreDefinition) (wfRel : Expr) (decrTactic? : Option Syntax)
|
|||
let x := xs[0]
|
||||
let F := xs[1]
|
||||
let val := preDef.value.betaRev #[x]
|
||||
let val ← processSumCasesOn x F val fun x F val => processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?)
|
||||
let val ← processSumCasesOn x F val fun x F val => do
|
||||
processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?)
|
||||
return { preDef with value := mkApp wfFix (← mkLambdaFVars #[x, F] val) }
|
||||
|
||||
end Lean.Elab.WF
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
|
|||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] (← mkProd args[0])
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] (← mkProd args[0])
|
||||
else
|
||||
let r ← mkSum (i+1) args[1]
|
||||
return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r
|
||||
let arg ← mkSum 0 domain
|
||||
mkLambdaFVars xs (mkApp (mkConst preDefNonRec.declName us) arg)
|
||||
trace[Elab.definition.wf] "{preDef.declName} := {value}"
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ open Meta
|
|||
private def getDomains (preDefs : Array PreDefinition) : Array Expr :=
|
||||
preDefs.map fun preDef => preDef.type.bindingDomain!
|
||||
|
||||
/-- Combine different function domains `ds` using `Sum`s -/
|
||||
/-- Combine different function domains `ds` using `PSum`s -/
|
||||
private def mkNewDomain (ds : Array Expr) : MetaM Expr := do
|
||||
let mut r := ds.back
|
||||
for d in ds.pop.reverse do
|
||||
r ← mkAppM ``Sum #[d, r]
|
||||
r ← mkAppM ``PSum #[d, r]
|
||||
return r
|
||||
|
||||
private def getCodomainLevel (preDef : PreDefinition) : MetaM Level :=
|
||||
|
|
@ -41,9 +41,9 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M
|
|||
let rec go (x : Expr) (i : Nat) : MetaM Expr := do
|
||||
if i < preDefs.size - 1 then
|
||||
let xType ← whnfD (← inferType x)
|
||||
assert! xType.isAppOfArity ``Sum 2
|
||||
assert! xType.isAppOfArity ``PSum 2
|
||||
let xTypeArgs := xType.getAppArgs
|
||||
let casesOn := mkConst (mkCasesOnName ``Sum) (mkLevelSucc u :: xType.getAppFn.constLevels!)
|
||||
let casesOn := mkConst (mkCasesOnName ``PSum) (mkLevelSucc u :: xType.getAppFn.constLevels!)
|
||||
let casesOn := mkAppN casesOn xTypeArgs -- parameters
|
||||
let casesOn := mkApp casesOn (← mkLambdaFVars #[x] (mkSort u)) -- motive
|
||||
let casesOn := mkApp casesOn x -- major
|
||||
|
|
@ -58,7 +58,7 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M
|
|||
|
||||
/--
|
||||
Combine/pack the values of the different definitions in a single value
|
||||
`x` is `Sum`, and we use `Sum.casesOn` to select the appropriate `preDefs.value`.
|
||||
`x` is `PSum`, and we use `PSum.casesOn` to select the appropriate `preDefs.value`.
|
||||
See: `packMutual`.
|
||||
Remark: this method does not replace the nested recursive `preDefs` applications.
|
||||
This step is performed by `transform` with the following `post` method.
|
||||
|
|
@ -100,15 +100,15 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn
|
|||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] arg
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] arg
|
||||
else
|
||||
let r ← mkNewArg (i+1) args[1]
|
||||
return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r
|
||||
return TransformStep.done <| mkApp (mkConst newFn us) (← mkNewArg 0 domain)
|
||||
return TransformStep.done e
|
||||
|
||||
/--
|
||||
If `preDefs.size > 1`, combine different functions in a single one using `Sum`.
|
||||
If `preDefs.size > 1`, combine different functions in a single one using `PSum`.
|
||||
This method assumes all `preDefs` have arity 1, and have already been processed using `packDomain`.
|
||||
Here is a small example. Suppose the input is
|
||||
```
|
||||
|
|
@ -128,22 +128,22 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn
|
|||
this method produces the following pre definition
|
||||
```
|
||||
f._mutual x :=
|
||||
Sum.casesOn x
|
||||
PSum.casesOn x
|
||||
(fun val =>
|
||||
match val.2.1, val.2.2.1, val.2.2.2 with
|
||||
| 0, a, b => a
|
||||
| Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inl ⟨val.1, n, a, b⟩))).fst
|
||||
| Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inl ⟨val.1, n, a, b⟩))).fst
|
||||
fun val =>
|
||||
Sum.casesOn val
|
||||
PSum.casesOn val
|
||||
(fun val =>
|
||||
match val.2.1, val.2.2.1, val.2.2.2 with
|
||||
| 0, a, b => (a, b)
|
||||
| Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inr ⟨val.1, n, a, b⟩)), a)
|
||||
| Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inr ⟨val.1, n, a, b⟩)), a)
|
||||
fun val =>
|
||||
match val.2.1, val.2.2.1, val.2.2.2 with
|
||||
| 0, a, b => b
|
||||
| Nat.succ n, a, b =>
|
||||
f._mutual (Sum.inl ⟨val.1, n, a, b⟩)
|
||||
f._mutual (PSum.inl ⟨val.1, n, a, b⟩)
|
||||
```
|
||||
-/
|
||||
def packMutual (preDefs : Array PreDefinition) : MetaM PreDefinition := do
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨n, true⟩ => (n, 2)
|
||||
| Sum.inl ⟨n, false⟩ => (n, 1)
|
||||
| Sum.inr n => (n, 0))
|
||||
| PSum.inl ⟨n, true⟩ => (n, 2)
|
||||
| PSum.inl ⟨n, false⟩ => (n, 1)
|
||||
| PSum.inr n => (n, 0))
|
||||
$ Prod.lex sizeOfWFRel sizeOfWFRel
|
||||
decreasing_by
|
||||
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
|
||||
|
|
|
|||
|
|
@ -22,5 +22,5 @@ mutual
|
|||
| 0 => false
|
||||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure (fun n => match n with | Sum.inl n => n | Sum.inr n => n)
|
||||
termination_by' measure (fun n => match n with | PSum.inl n => n | PSum.inr n => n)
|
||||
end Ex3
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ mutual
|
|||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure fun
|
||||
| Sum.inl n => n
|
||||
| Sum.inr n => n
|
||||
| PSum.inl n => n
|
||||
| PSum.inr n => n
|
||||
decreasing_by
|
||||
simp [measure, invImage, InvImage, Nat.lt_wfRel]
|
||||
apply Nat.lt_succ_self
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ mutual
|
|||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure fun
|
||||
| Sum.inl n => n
|
||||
| Sum.inr n => n
|
||||
| PSum.inl n => n
|
||||
| PSum.inr n => n
|
||||
decreasing_by apply Nat.lt_succ_self
|
||||
|
||||
theorem isEven_double (x : Nat) : isEven (2 * x) = true := by
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
(Prod.lex sizeOfWFRel sizeOfWFRel)
|
||||
|
||||
#print f
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ mutual
|
|||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure fun
|
||||
| Sum.inl n => n
|
||||
| Sum.inr n => n
|
||||
| PSum.inl n => n
|
||||
| PSum.inr n => n
|
||||
|
||||
#print isEven
|
||||
#print isOdd
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
(Prod.lex sizeOfWFRel sizeOfWFRel)
|
||||
decreasing_by
|
||||
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
|
||||
|
|
@ -51,9 +51,9 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
(Prod.lex sizeOfWFRel sizeOfWFRel)
|
||||
|
||||
#print f._unary._mutual
|
||||
|
|
|
|||
16
tests/lean/run/psumAtWF.lean
Normal file
16
tests/lean/run/psumAtWF.lean
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
mutual
|
||||
|
||||
def fn (f : α → α) (a : α) (n : Nat) : α :=
|
||||
match n with
|
||||
| 0 => a
|
||||
| n+1 => gn f (f (f a)) (f a) n
|
||||
|
||||
def gn (f : α → α) (a b : α) (n : Nat) : α :=
|
||||
match n with
|
||||
| 0 => b
|
||||
| n+1 => fn f (f b) n
|
||||
|
||||
end
|
||||
termination_by
|
||||
fn n => n
|
||||
gn n => n
|
||||
|
|
@ -14,8 +14,8 @@ mutual
|
|||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure fun
|
||||
| Sum.inl n => n
|
||||
| Sum.inr n => n
|
||||
| PSum.inl n => n
|
||||
| PSum.inr n => n
|
||||
decreasing_by
|
||||
simp [measure, invImage, InvImage, Nat.lt_wfRel]
|
||||
apply Nat.lt_succ_self
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨_, n⟩ => (n, 0)
|
||||
| Sum.inr ⟨_, n⟩ => (n, 1))
|
||||
| PSum.inl ⟨_, n⟩ => (n, 0)
|
||||
| PSum.inr ⟨_, n⟩ => (n, 1))
|
||||
(Prod.lex sizeOfWFRel sizeOfWFRel)
|
||||
decreasing_by
|
||||
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ end
|
|||
termination_by'
|
||||
invImage
|
||||
(fun
|
||||
| Sum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
|
||||
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
|
||||
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
|
||||
(Prod.lex sizeOfWFRel sizeOfWFRel)
|
||||
decreasing_by
|
||||
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ mutual
|
|||
| 0 => false
|
||||
| n+1 => isEven n
|
||||
end
|
||||
termination_by' measure fun | Sum.inl n => n | Sum.inr n => n
|
||||
termination_by' measure fun | PSum.inl n => n | PSum.inr n => n
|
||||
decreasing_by apply Nat.lt_succ_self
|
||||
|
||||
theorem isEven_double (x : Nat) : isEven (2 * x) = true := by
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue