diff --git a/RELEASES.md b/RELEASES.md index 29d6e754f2..2dcaa55759 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -68,3 +68,5 @@ example (a : α) (x : Fam α) : α := | Fam.any => a | Fam.nat n => n ``` + +* We now use `PSum` (instead of `Sum`) when compiling mutually recursive definitions using well-founded recursion. diff --git a/src/Lean/Elab/PreDefinition/WF/Fix.lean b/src/Lean/Elab/PreDefinition/WF/Fix.lean index 5a482d4aa2..820d3367e2 100644 --- a/src/Lean/Elab/PreDefinition/WF/Fix.lean +++ b/src/Lean/Elab/PreDefinition/WF/Fix.lean @@ -77,9 +77,9 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt | e => ensureNoRecFn recFnName e loop F e -/-- Refine `F` over `Sum.casesOn` -/ +/-- Refine `F` over `PSum.casesOn` -/ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F : Expr) → (val : Expr) → TermElabM Expr) : TermElabM Expr := do - if x.isFVar && val.isAppOfArity ``Sum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then + if x.isFVar && val.isAppOfArity ``PSum.casesOn 6 && val.getArg! 3 == x && (val.getArg! 4).isLambda && (val.getArg! 5).isLambda then let args := val.getAppArgs let α := args[0] let β := args[1] @@ -94,9 +94,9 @@ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F : let FTypeNew := FDecl.type.replaceFVar x (← mkAppOptM ctorName #[α, β, xNew]) withLocalDeclD FDecl.userName FTypeNew fun FNew => do mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew valNew k) - let minorLeft ← mkMinorNew ``Sum.inl args[4] - let minorRight ← mkMinorNew ``Sum.inr args[5] - let result := mkAppN (mkConst ``Sum.casesOn [u, (← getDecLevel α), (← getDecLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F] + let minorLeft ← mkMinorNew ``PSum.inl args[4] + let minorRight ← mkMinorNew ``PSum.inr args[5] + let result := mkAppN (mkConst ``PSum.casesOn [u, (← getLevel α), (← getLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F] return result else k x F val @@ -138,7 +138,8 @@ def mkFix (preDef : PreDefinition) (wfRel : Expr) (decrTactic? : Option Syntax) let x := xs[0] let F := xs[1] let val := preDef.value.betaRev #[x] - let val ← processSumCasesOn x F val fun x F val => processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?) + let val ← processSumCasesOn x F val fun x F val => do + processPSigmaCasesOn x F val (replaceRecApps preDef.declName decrTactic?) return { preDef with value := mkApp wfFix (← mkLambdaFVars #[x, F] val) } end Lean.Elab.WF diff --git a/src/Lean/Elab/PreDefinition/WF/Main.lean b/src/Lean/Elab/PreDefinition/WF/Main.lean index dff8fd9fff..73e67895c9 100644 --- a/src/Lean/Elab/PreDefinition/WF/Main.lean +++ b/src/Lean/Elab/PreDefinition/WF/Main.lean @@ -38,10 +38,10 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR (← whnfD type).withApp fun f args => do assert! args.size == 2 if i == fidx then - return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] (← mkProd args[0]) + return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] (← mkProd args[0]) else let r ← mkSum (i+1) args[1] - return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r + return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r let arg ← mkSum 0 domain mkLambdaFVars xs (mkApp (mkConst preDefNonRec.declName us) arg) trace[Elab.definition.wf] "{preDef.declName} := {value}" diff --git a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean index d2cc4577fe..1e1a4e1372 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean @@ -11,11 +11,11 @@ open Meta private def getDomains (preDefs : Array PreDefinition) : Array Expr := preDefs.map fun preDef => preDef.type.bindingDomain! -/-- Combine different function domains `ds` using `Sum`s -/ +/-- Combine different function domains `ds` using `PSum`s -/ private def mkNewDomain (ds : Array Expr) : MetaM Expr := do let mut r := ds.back for d in ds.pop.reverse do - r ← mkAppM ``Sum #[d, r] + r ← mkAppM ``PSum #[d, r] return r private def getCodomainLevel (preDef : PreDefinition) : MetaM Level := @@ -41,9 +41,9 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M let rec go (x : Expr) (i : Nat) : MetaM Expr := do if i < preDefs.size - 1 then let xType ← whnfD (← inferType x) - assert! xType.isAppOfArity ``Sum 2 + assert! xType.isAppOfArity ``PSum 2 let xTypeArgs := xType.getAppArgs - let casesOn := mkConst (mkCasesOnName ``Sum) (mkLevelSucc u :: xType.getAppFn.constLevels!) + let casesOn := mkConst (mkCasesOnName ``PSum) (mkLevelSucc u :: xType.getAppFn.constLevels!) let casesOn := mkAppN casesOn xTypeArgs -- parameters let casesOn := mkApp casesOn (← mkLambdaFVars #[x] (mkSort u)) -- motive let casesOn := mkApp casesOn x -- major @@ -58,7 +58,7 @@ private partial def mkNewCoDomain (x : Expr) (preDefs : Array PreDefinition) : M /-- Combine/pack the values of the different definitions in a single value - `x` is `Sum`, and we use `Sum.casesOn` to select the appropriate `preDefs.value`. + `x` is `PSum`, and we use `PSum.casesOn` to select the appropriate `preDefs.value`. See: `packMutual`. Remark: this method does not replace the nested recursive `preDefs` applications. This step is performed by `transform` with the following `post` method. @@ -100,15 +100,15 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn (← whnfD type).withApp fun f args => do assert! args.size == 2 if i == fidx then - return mkApp3 (mkConst ``Sum.inl f.constLevels!) args[0] args[1] arg + return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0] args[1] arg else let r ← mkNewArg (i+1) args[1] - return mkApp3 (mkConst ``Sum.inr f.constLevels!) args[0] args[1] r + return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r return TransformStep.done <| mkApp (mkConst newFn us) (← mkNewArg 0 domain) return TransformStep.done e /-- - If `preDefs.size > 1`, combine different functions in a single one using `Sum`. + If `preDefs.size > 1`, combine different functions in a single one using `PSum`. This method assumes all `preDefs` have arity 1, and have already been processed using `packDomain`. Here is a small example. Suppose the input is ``` @@ -128,22 +128,22 @@ private partial def post (preDefs : Array PreDefinition) (domain : Expr) (newFn this method produces the following pre definition ``` f._mutual x := - Sum.casesOn x + PSum.casesOn x (fun val => match val.2.1, val.2.2.1, val.2.2.2 with | 0, a, b => a - | Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inl ⟨val.1, n, a, b⟩))).fst + | Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inl ⟨val.1, n, a, b⟩))).fst fun val => - Sum.casesOn val + PSum.casesOn val (fun val => match val.2.1, val.2.2.1, val.2.2.2 with | 0, a, b => (a, b) - | Nat.succ n, a, b => (f._mutual (Sum.inr (Sum.inr ⟨val.1, n, a, b⟩)), a) + | Nat.succ n, a, b => (f._mutual (PSum.inr (PSum.inr ⟨val.1, n, a, b⟩)), a) fun val => match val.2.1, val.2.2.1, val.2.2.2 with | 0, a, b => b | Nat.succ n, a, b => - f._mutual (Sum.inl ⟨val.1, n, a, b⟩) + f._mutual (PSum.inl ⟨val.1, n, a, b⟩) ``` -/ def packMutual (preDefs : Array PreDefinition) : MetaM PreDefinition := do diff --git a/tests/lean/mutwf1.lean b/tests/lean/mutwf1.lean index 3453d150e8..aa7886e1cb 100644 --- a/tests/lean/mutwf1.lean +++ b/tests/lean/mutwf1.lean @@ -14,9 +14,9 @@ end termination_by' invImage (fun - | Sum.inl ⟨n, true⟩ => (n, 2) - | Sum.inl ⟨n, false⟩ => (n, 1) - | Sum.inr n => (n, 0)) + | PSum.inl ⟨n, true⟩ => (n, 2) + | PSum.inl ⟨n, false⟩ => (n, 1) + | PSum.inr n => (n, 0)) $ Prod.lex sizeOfWFRel sizeOfWFRel decreasing_by simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel] diff --git a/tests/lean/run/955.lean b/tests/lean/run/955.lean index befc5ed3e3..07d5711a9c 100644 --- a/tests/lean/run/955.lean +++ b/tests/lean/run/955.lean @@ -22,5 +22,5 @@ mutual | 0 => false | n+1 => isEven n end -termination_by' measure (fun n => match n with | Sum.inl n => n | Sum.inr n => n) +termination_by' measure (fun n => match n with | PSum.inl n => n | PSum.inr n => n) end Ex3 diff --git a/tests/lean/run/eqnsAtSimp.lean b/tests/lean/run/eqnsAtSimp.lean index 19df2bd306..ba691658cd 100644 --- a/tests/lean/run/eqnsAtSimp.lean +++ b/tests/lean/run/eqnsAtSimp.lean @@ -7,8 +7,8 @@ mutual | n+1 => isEven n end termination_by' measure fun - | Sum.inl n => n - | Sum.inr n => n + | PSum.inl n => n + | PSum.inr n => n decreasing_by simp [measure, invImage, InvImage, Nat.lt_wfRel] apply Nat.lt_succ_self diff --git a/tests/lean/run/eqnsAtSimp2.lean b/tests/lean/run/eqnsAtSimp2.lean index c7667be5bc..862b4de62b 100644 --- a/tests/lean/run/eqnsAtSimp2.lean +++ b/tests/lean/run/eqnsAtSimp2.lean @@ -7,8 +7,8 @@ mutual | n+1 => isEven n end termination_by' measure fun - | Sum.inl n => n - | Sum.inr n => n + | PSum.inl n => n + | PSum.inr n => n decreasing_by apply Nat.lt_succ_self theorem isEven_double (x : Nat) : isEven (2 * x) = true := by diff --git a/tests/lean/run/mutwf1.lean b/tests/lean/run/mutwf1.lean index 3defebb4d6..6250597763 100644 --- a/tests/lean/run/mutwf1.lean +++ b/tests/lean/run/mutwf1.lean @@ -15,9 +15,9 @@ end termination_by' invImage (fun - | Sum.inl ⟨_, n, _, _⟩ => (n, 2) - | Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1) - | Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0)) + | PSum.inl ⟨_, n, _, _⟩ => (n, 2) + | PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1) + | PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0)) (Prod.lex sizeOfWFRel sizeOfWFRel) #print f diff --git a/tests/lean/run/mutwf2.lean b/tests/lean/run/mutwf2.lean index 2488b6adfc..fa1d464595 100644 --- a/tests/lean/run/mutwf2.lean +++ b/tests/lean/run/mutwf2.lean @@ -8,8 +8,8 @@ mutual | n+1 => isEven n end termination_by' measure fun - | Sum.inl n => n - | Sum.inr n => n + | PSum.inl n => n + | PSum.inr n => n #print isEven #print isOdd diff --git a/tests/lean/run/mutwf3.lean b/tests/lean/run/mutwf3.lean index 2dccfdcda3..d245dd03a8 100644 --- a/tests/lean/run/mutwf3.lean +++ b/tests/lean/run/mutwf3.lean @@ -15,9 +15,9 @@ end termination_by' invImage (fun - | Sum.inl ⟨_, n, _, _⟩ => (n, 2) - | Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1) - | Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0)) + | PSum.inl ⟨_, n, _, _⟩ => (n, 2) + | PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1) + | PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0)) (Prod.lex sizeOfWFRel sizeOfWFRel) decreasing_by simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel] @@ -51,9 +51,9 @@ end termination_by' invImage (fun - | Sum.inl ⟨_, n, _, _⟩ => (n, 2) - | Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1) - | Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0)) + | PSum.inl ⟨_, n, _, _⟩ => (n, 2) + | PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1) + | PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0)) (Prod.lex sizeOfWFRel sizeOfWFRel) #print f._unary._mutual diff --git a/tests/lean/run/psumAtWF.lean b/tests/lean/run/psumAtWF.lean new file mode 100644 index 0000000000..398b1f03e3 --- /dev/null +++ b/tests/lean/run/psumAtWF.lean @@ -0,0 +1,16 @@ +mutual + +def fn (f : α → α) (a : α) (n : Nat) : α := + match n with + | 0 => a + | n+1 => gn f (f (f a)) (f a) n + +def gn (f : α → α) (a b : α) (n : Nat) : α := + match n with + | 0 => b + | n+1 => fn f (f b) n + +end +termination_by + fn n => n + gn n => n diff --git a/tests/lean/run/wfEqns1.lean b/tests/lean/run/wfEqns1.lean index 90d7e70c0f..59c7de89b5 100644 --- a/tests/lean/run/wfEqns1.lean +++ b/tests/lean/run/wfEqns1.lean @@ -14,8 +14,8 @@ mutual | n+1 => isEven n end termination_by' measure fun - | Sum.inl n => n - | Sum.inr n => n + | PSum.inl n => n + | PSum.inr n => n decreasing_by simp [measure, invImage, InvImage, Nat.lt_wfRel] apply Nat.lt_succ_self diff --git a/tests/lean/run/wfEqns2.lean b/tests/lean/run/wfEqns2.lean index af26806193..4c0b3ec1ee 100644 --- a/tests/lean/run/wfEqns2.lean +++ b/tests/lean/run/wfEqns2.lean @@ -19,8 +19,8 @@ end termination_by' invImage (fun - | Sum.inl ⟨_, n⟩ => (n, 0) - | Sum.inr ⟨_, n⟩ => (n, 1)) + | PSum.inl ⟨_, n⟩ => (n, 0) + | PSum.inr ⟨_, n⟩ => (n, 1)) (Prod.lex sizeOfWFRel sizeOfWFRel) decreasing_by simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel] diff --git a/tests/lean/run/wfEqns4.lean b/tests/lean/run/wfEqns4.lean index b85921ec71..873147b05d 100644 --- a/tests/lean/run/wfEqns4.lean +++ b/tests/lean/run/wfEqns4.lean @@ -21,9 +21,9 @@ end termination_by' invImage (fun - | Sum.inl ⟨_, n, _, _⟩ => (n, 2) - | Sum.inr <| Sum.inl ⟨_, _, n, _⟩ => (n, 1) - | Sum.inr <| Sum.inr ⟨_, _, _, n⟩ => (n, 0)) + | PSum.inl ⟨_, n, _, _⟩ => (n, 2) + | PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1) + | PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0)) (Prod.lex sizeOfWFRel sizeOfWFRel) decreasing_by simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel] diff --git a/tests/lean/unfold1.lean b/tests/lean/unfold1.lean index 1b7b2d35ca..8c5d4aa28a 100644 --- a/tests/lean/unfold1.lean +++ b/tests/lean/unfold1.lean @@ -6,7 +6,7 @@ mutual | 0 => false | n+1 => isEven n end -termination_by' measure fun | Sum.inl n => n | Sum.inr n => n +termination_by' measure fun | PSum.inl n => n | PSum.inr n => n decreasing_by apply Nat.lt_succ_self theorem isEven_double (x : Nat) : isEven (2 * x) = true := by