diff --git a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean index 6c005a2bae..71bd812916 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean @@ -90,36 +90,52 @@ private partial def packValues (x : Expr) (codomain : Expr) (preDefValues : Arra go mvar.mvarId! x.fvarId! 0 instantiateMVars mvar +/-- + Pass the first `n` arguments of `e` to the continuation, and apply the result to the + remaining arguments. If `e` does not have enough arguments, it is eta-expanded as needed. + + Unlike `Meta.etaExpand` does not use `withDefault`. +-/ +def withAppN (n : Nat) (e : Expr) (k : Array Expr → MetaM Expr) : MetaM Expr := do + let args := e.getAppArgs + if n ≤ args.size then + let e' ← k args[:n] + return mkAppN e' args[n:] + else + let missing := n - args.size + forallBoundedTelescope (← inferType e) missing fun xs _ => do + if xs.size < missing then + throwError "Failed to eta-expand partial application" + let e' ← k (args ++ xs) + mkLambdaFVars xs e' + /-- Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`. See: `packMutual` -/ private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (domain : Expr) (newFn : Name) (e : Expr) : MetaM TransformStep := do - if e.getAppNumArgs < fixedPrefix + 1 then - return TransformStep.done e let f := e.getAppFn if !f.isConst then return TransformStep.done e let declName := f.constName! let us := f.constLevels! if let some fidx := preDefs.findIdx? (·.declName == declName) then - let args := e.getAppArgs - let fixedArgs := args[:fixedPrefix] - let arg := args[fixedPrefix]! - let remaining := args[fixedPrefix+1:] - let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do - if i == preDefs.size - 1 then - return arg - else - (← whnfD type).withApp fun f args => do - assert! args.size == 2 - if i == fidx then - return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg - else - let r ← mkNewArg (i+1) args[1]! - return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r - return TransformStep.done <| - mkAppN (mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)) remaining + let e' ← withAppN (fixedPrefix + 1) e fun args => do + let fixedArgs := args[:fixedPrefix] + let arg := args[fixedPrefix]! + let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do + if i == preDefs.size - 1 then + return arg + else + (← whnfD type).withApp fun f args => do + assert! args.size == 2 + if i == fidx then + return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg + else + let r ← mkNewArg (i+1) args[1]! + return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r + return mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain) + return TransformStep.done e' return TransformStep.done e partial def withFixedPrefix (fixedPrefix : Nat) (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → Array Expr → MetaM α) : MetaM α := @@ -185,7 +201,7 @@ def packMutual (fixedPrefix : Nat) (preDefsOriginal : Array PreDefinition) (preD let newFn := preDefs[0]!.declName ++ `_mutual let preDefNew := { preDefs[0]! with declName := newFn, type, value } addAsAxiom preDefNew - let value ← transform value (post := post fixedPrefix preDefs domain newFn) + let value ← transform value (skipConstInApp := true) (post := post fixedPrefix preDefs domain newFn) let value ← mkLambdaFVars (ys.push x) value return { preDefNew with value } diff --git a/src/Lean/Meta/Transform.lean b/src/Lean/Meta/Transform.lean index bcba08287a..1b22604300 100644 --- a/src/Lean/Meta/Transform.lean +++ b/src/Lean/Meta/Transform.lean @@ -73,12 +73,18 @@ namespace Meta /-- Similar to `Core.transform`, but terms provided to `pre` and `post` do not contain loose bound variables. - So, it is safe to use any `MetaM` method at `pre` and `post`. -/ + So, it is safe to use any `MetaM` method at `pre` and `post`. + + If `skipConstInApp := true`, then for an expression `mkAppN (.const f) args`, the subexpression + `.const f` is not visited again. Put differently: every `.const f` is visited once, with its + arguments if present, on its own otherwise. + -/ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadTrace m] [MonadRef m] [MonadOptions m] [AddMessageContext m] (input : Expr) (pre : Expr → m TransformStep := fun _ => return .continue) (post : Expr → m TransformStep := fun e => return .done e) (usedLetOnly := false) + (skipConstInApp := false) : m Expr := do let _ : STWorld IO.RealWorld m := ⟨⟩ let _ : MonadLiftT (ST IO.RealWorld) m := { monadLift := fun x => liftM (m := MetaM) (liftM (m := ST IO.RealWorld) x) } @@ -109,7 +115,10 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] | e => visitPost (← mkLetFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars))) let visitApp (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := e.withApp fun f args => do - visitPost (mkAppN (← visit f) (← args.mapM visit)) + if skipConstInApp && f.isConst then + visitPost (mkAppN f (← args.mapM visit)) + else + visitPost (mkAppN (← visit f) (← args.mapM visit)) match (← pre e) with | .done e => pure e | .visit e => visit e diff --git a/tests/lean/run/issue2628.lean b/tests/lean/run/issue2628.lean new file mode 100644 index 0000000000..d9c7e9f90d --- /dev/null +++ b/tests/lean/run/issue2628.lean @@ -0,0 +1,108 @@ +/-! +Test that PackMutual isn't confused when a recursive call has more arguments than is packed +into the unary argument, which can happen if the retturn type is a function type. +-/ + +namespace Ex1 +mutual +def foo : Nat → Nat + | .zero => 0 + | .succ n => (id bar) n +def bar : Nat → Nat + | .zero => 0 + | .succ n => foo n +end +termination_by foo n => n; bar n => n +decreasing_by sorry + +end Ex1 + +-- Same for n-ary functions + +opaque id' : ∀ {α}, α → α := id + +namespace Ex2 + +mutual +def foo : Nat → Nat → Nat + | .zero, _m => 0 + | .succ n, .zero => (id' (bar n)) .zero + | .succ n, m => (id' bar) n m +def bar : Nat → Nat → Nat + | .zero, _m => 0 + | .succ n, m => foo n m +end +termination_by foo n m => (n,m); bar n m => (n,m) +decreasing_by sorry + +end Ex2 + +-- With extra arguments + +namespace Ex3 +mutual +def foo : Nat → Nat → Nat + | .zero => fun _ => 0 + | .succ n => fun m => (id bar) n m +def bar : Nat → Nat → Nat + | .zero => fun _ => 0 + | .succ n => fun m => foo n m +end +termination_by foo n => n; bar n => n +decreasing_by sorry + +end Ex3 + +-- n-ary and with extra arguments + +namespace Ex4 + +mutual +def foo : Nat → Nat → Nat → Nat + | .zero, _m => fun _ => 0 + | .succ n, .zero => fun k => (id' (bar n)) .zero k + | .succ n, m => fun k => (id' bar) n m k +def bar : Nat → Nat → Nat → Nat + | .zero, _m => fun _ => 0 + | .succ n, m => fun k => foo n m k +end +termination_by foo n m => (n,m); bar n m => (n,m) +decreasing_by sorry + +end Ex4 + +-- Check that eta-expansion works even if the function does not +-- have a function type +namespace Ex5 +def FunType := Nat → Nat + +mutual +def foo : FunType + | .zero => 0 + | .succ n => (id bar) n +def bar : Nat → Nat + | .zero => 0 + | .succ n => foo n +end +termination_by foo n => n; bar n => n +decreasing_by sorry + +end Ex5 + + +namespace Ex6 +def Fun3Type := Nat → Nat → Nat + +mutual +def foo : Nat → Nat → Nat → Nat + | .zero, _m => fun _ => 0 + | .succ n, .zero => fun k => (id' (bar n)) .zero k + | .succ n, m => fun k => (id' bar) n m k +def bar : Nat → Fun3Type + | .zero, _m => fun _ => 0 + | .succ n, m => fun k => foo n m k +end +termination_by foo n m => (n,m); bar n m => (n,m) +decreasing_by sorry + +end Ex6 diff --git a/tests/lean/run/issue2883.lean b/tests/lean/run/issue2883.lean index 5bb2d3000f..b7541f94ec 100644 --- a/tests/lean/run/issue2883.lean +++ b/tests/lean/run/issue2883.lean @@ -1,6 +1,6 @@ /-! Test that PackMutual isn't confused when a recursive call has more arguments than is packed -into the unary argument, which can happen if the retturn type is a function type. +into the unary argument, which can happen if the return type is a function type. -/ mutual