From 54dd588fc2a16009f837b88b7d514f08ddfa17ba Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Wed, 22 Nov 2023 12:31:36 +0100 Subject: [PATCH] =?UTF-8?q?fix:=20Use=20whnf=20for=20mutual=20recursion=20?= =?UTF-8?q?with=20types=20hiding=20`=E2=86=92`=20(#2926)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit the code stumbled over recursive functions whose type doesn’t have enough manifest foralls, like: ``` def FunType := Nat → Nat mutual def foo : FunType | .zero => 0 | .succ n => bar n def bar : FunType | .zero => 0 | .succ n => foo n end termination_by foo n => n; bar n => n ``` This can be fixed by using `whnf` in appropriate places, to expose the `.forall` constructor. Fixes #2925, comes with test case. --- src/Lean/Elab/PreDefinition/WF/PackDomain.lean | 3 ++- src/Lean/Elab/PreDefinition/WF/PackMutual.lean | 8 ++++---- tests/lean/run/issue2925.lean | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 tests/lean/run/issue2925.lean diff --git a/src/Lean/Elab/PreDefinition/WF/PackDomain.lean b/src/Lean/Elab/PreDefinition/WF/PackDomain.lean index bd319a788a..f2083ec7fb 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackDomain.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackDomain.lean @@ -124,7 +124,8 @@ where let args := e.getAppArgs let fNew := mkConst preDefsNew[funIdx]!.declName f.constLevels! let fNew := mkAppN fNew args[:fixedPrefix] - let Expr.forallE _ d .. ← inferType fNew | unreachable! + let Expr.forallE _ d .. ← whnf (← inferType fNew) | unreachable! + -- NB: Use whnf in case the type is not a manifest forall, but a definition around it let argNew ← mkUnaryArg d args[fixedPrefix:] return mkApp fNew argNew let rec diff --git a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean index aa0c83a0e9..6c005a2bae 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean @@ -51,13 +51,13 @@ private partial def mkNewCoDomain (preDefsOriginal : Array PreDefinition) (preDe let casesOn := mkAppN casesOn xTypeArgs -- parameters let casesOn := mkApp casesOn (← mkLambdaFVars #[x] (mkSort u)) -- motive let casesOn := mkApp casesOn x -- major - let minor1 ← withLocalDeclD (← mkFreshUserName `_x) xTypeArgs[0]! fun x => - mkLambdaFVars #[x] (preDefTypes[i]!.bindingBody!.instantiate1 x) + let minor1 ← withLocalDeclD (← mkFreshUserName `_x) xTypeArgs[0]! fun x => do + mkLambdaFVars #[x] ((← whnf preDefTypes[i]!).bindingBody!.instantiate1 x) let minor2 ← withLocalDeclD (← mkFreshUserName `_x) xTypeArgs[1]! fun x => do mkLambdaFVars #[x] (← go x (i+1)) return mkApp2 casesOn minor1 minor2 else - return preDefTypes[i]!.bindingBody!.instantiate1 x + return (← whnf preDefTypes[i]!).bindingBody!.instantiate1 x go x 0 /-- @@ -176,7 +176,7 @@ where def packMutual (fixedPrefix : Nat) (preDefsOriginal : Array PreDefinition) (preDefs : Array PreDefinition) : MetaM PreDefinition := do if preDefs.size == 1 then return preDefs[0]! withFixedPrefix fixedPrefix preDefs fun ys types vals => do - let domains := types.map fun type => type.bindingDomain! + let domains ← types.mapM fun type => do pure (← whnf type).bindingDomain! let domain ← mkNewDomain domains withLocalDeclD (← mkFreshUserName `_x) domain fun x => do let codomain ← mkNewCoDomain preDefsOriginal types x diff --git a/tests/lean/run/issue2925.lean b/tests/lean/run/issue2925.lean new file mode 100644 index 0000000000..851ad59967 --- /dev/null +++ b/tests/lean/run/issue2925.lean @@ -0,0 +1,15 @@ +def FunType := Nat → Nat +def Fun2Type := Nat → Nat → Nat + +mutual +def foo : FunType + | .zero => 0 + | .succ n => bar n +def bar : Nat → Nat + | .zero => 0 + | .succ n => baz n 0 +def baz : Fun2Type + | .zero, m => 0 + | .succ n, m => foo n +end +termination_by foo n => n; bar n => n; baz n _ => n