fix: PackMutual: Eta-Expand as needed (#2902)
The `packMutual` code ought to reliably replace all recursive calls to the functions in `preDefs`, even when they are under- or over-applied. Therefore eta-expand if need rsp. keep extra arguments around. Needs a tweak to `Meta.transform` to avoid mistaking the `f` in `f x1 x2` as a zero-arity application. Includes a test case. This fixes #2628 and #2883.
This commit is contained in:
parent
dede354e77
commit
260eaebf4e
4 changed files with 156 additions and 23 deletions
|
|
@ -90,36 +90,52 @@ private partial def packValues (x : Expr) (codomain : Expr) (preDefValues : Arra
|
|||
go mvar.mvarId! x.fvarId! 0
|
||||
instantiateMVars mvar
|
||||
|
||||
/--
|
||||
Pass the first `n` arguments of `e` to the continuation, and apply the result to the
|
||||
remaining arguments. If `e` does not have enough arguments, it is eta-expanded as needed.
|
||||
|
||||
Unlike `Meta.etaExpand` does not use `withDefault`.
|
||||
-/
|
||||
def withAppN (n : Nat) (e : Expr) (k : Array Expr → MetaM Expr) : MetaM Expr := do
|
||||
let args := e.getAppArgs
|
||||
if n ≤ args.size then
|
||||
let e' ← k args[:n]
|
||||
return mkAppN e' args[n:]
|
||||
else
|
||||
let missing := n - args.size
|
||||
forallBoundedTelescope (← inferType e) missing fun xs _ => do
|
||||
if xs.size < missing then
|
||||
throwError "Failed to eta-expand partial application"
|
||||
let e' ← k (args ++ xs)
|
||||
mkLambdaFVars xs e'
|
||||
|
||||
/--
|
||||
Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`.
|
||||
See: `packMutual`
|
||||
-/
|
||||
private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (domain : Expr) (newFn : Name) (e : Expr) : MetaM TransformStep := do
|
||||
if e.getAppNumArgs < fixedPrefix + 1 then
|
||||
return TransformStep.done e
|
||||
let f := e.getAppFn
|
||||
if !f.isConst then
|
||||
return TransformStep.done e
|
||||
let declName := f.constName!
|
||||
let us := f.constLevels!
|
||||
if let some fidx := preDefs.findIdx? (·.declName == declName) then
|
||||
let args := e.getAppArgs
|
||||
let fixedArgs := args[:fixedPrefix]
|
||||
let arg := args[fixedPrefix]!
|
||||
let remaining := args[fixedPrefix+1:]
|
||||
let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do
|
||||
if i == preDefs.size - 1 then
|
||||
return arg
|
||||
else
|
||||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
|
||||
else
|
||||
let r ← mkNewArg (i+1) args[1]!
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
|
||||
return TransformStep.done <|
|
||||
mkAppN (mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)) remaining
|
||||
let e' ← withAppN (fixedPrefix + 1) e fun args => do
|
||||
let fixedArgs := args[:fixedPrefix]
|
||||
let arg := args[fixedPrefix]!
|
||||
let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do
|
||||
if i == preDefs.size - 1 then
|
||||
return arg
|
||||
else
|
||||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
|
||||
else
|
||||
let r ← mkNewArg (i+1) args[1]!
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
|
||||
return mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)
|
||||
return TransformStep.done e'
|
||||
return TransformStep.done e
|
||||
|
||||
partial def withFixedPrefix (fixedPrefix : Nat) (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → Array Expr → MetaM α) : MetaM α :=
|
||||
|
|
@ -185,7 +201,7 @@ def packMutual (fixedPrefix : Nat) (preDefsOriginal : Array PreDefinition) (preD
|
|||
let newFn := preDefs[0]!.declName ++ `_mutual
|
||||
let preDefNew := { preDefs[0]! with declName := newFn, type, value }
|
||||
addAsAxiom preDefNew
|
||||
let value ← transform value (post := post fixedPrefix preDefs domain newFn)
|
||||
let value ← transform value (skipConstInApp := true) (post := post fixedPrefix preDefs domain newFn)
|
||||
let value ← mkLambdaFVars (ys.push x) value
|
||||
return { preDefNew with value }
|
||||
|
||||
|
|
|
|||
|
|
@ -73,12 +73,18 @@ namespace Meta
|
|||
|
||||
/--
|
||||
Similar to `Core.transform`, but terms provided to `pre` and `post` do not contain loose bound variables.
|
||||
So, it is safe to use any `MetaM` method at `pre` and `post`. -/
|
||||
So, it is safe to use any `MetaM` method at `pre` and `post`.
|
||||
|
||||
If `skipConstInApp := true`, then for an expression `mkAppN (.const f) args`, the subexpression
|
||||
`.const f` is not visited again. Put differently: every `.const f` is visited once, with its
|
||||
arguments if present, on its own otherwise.
|
||||
-/
|
||||
partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadTrace m] [MonadRef m] [MonadOptions m] [AddMessageContext m]
|
||||
(input : Expr)
|
||||
(pre : Expr → m TransformStep := fun _ => return .continue)
|
||||
(post : Expr → m TransformStep := fun e => return .done e)
|
||||
(usedLetOnly := false)
|
||||
(skipConstInApp := false)
|
||||
: m Expr := do
|
||||
let _ : STWorld IO.RealWorld m := ⟨⟩
|
||||
let _ : MonadLiftT (ST IO.RealWorld) m := { monadLift := fun x => liftM (m := MetaM) (liftM (m := ST IO.RealWorld) x) }
|
||||
|
|
@ -109,7 +115,10 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
|
|||
| e => visitPost (← mkLetFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars)))
|
||||
let visitApp (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
|
||||
e.withApp fun f args => do
|
||||
visitPost (mkAppN (← visit f) (← args.mapM visit))
|
||||
if skipConstInApp && f.isConst then
|
||||
visitPost (mkAppN f (← args.mapM visit))
|
||||
else
|
||||
visitPost (mkAppN (← visit f) (← args.mapM visit))
|
||||
match (← pre e) with
|
||||
| .done e => pure e
|
||||
| .visit e => visit e
|
||||
|
|
|
|||
108
tests/lean/run/issue2628.lean
Normal file
108
tests/lean/run/issue2628.lean
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
/-!
|
||||
Test that PackMutual isn't confused when a recursive call has more arguments than is packed
|
||||
into the unary argument, which can happen if the retturn type is a function type.
|
||||
-/
|
||||
|
||||
namespace Ex1
|
||||
mutual
|
||||
def foo : Nat → Nat
|
||||
| .zero => 0
|
||||
| .succ n => (id bar) n
|
||||
def bar : Nat → Nat
|
||||
| .zero => 0
|
||||
| .succ n => foo n
|
||||
end
|
||||
termination_by foo n => n; bar n => n
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex1
|
||||
|
||||
-- Same for n-ary functions
|
||||
|
||||
opaque id' : ∀ {α}, α → α := id
|
||||
|
||||
namespace Ex2
|
||||
|
||||
mutual
|
||||
def foo : Nat → Nat → Nat
|
||||
| .zero, _m => 0
|
||||
| .succ n, .zero => (id' (bar n)) .zero
|
||||
| .succ n, m => (id' bar) n m
|
||||
def bar : Nat → Nat → Nat
|
||||
| .zero, _m => 0
|
||||
| .succ n, m => foo n m
|
||||
end
|
||||
termination_by foo n m => (n,m); bar n m => (n,m)
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex2
|
||||
|
||||
-- With extra arguments
|
||||
|
||||
namespace Ex3
|
||||
mutual
|
||||
def foo : Nat → Nat → Nat
|
||||
| .zero => fun _ => 0
|
||||
| .succ n => fun m => (id bar) n m
|
||||
def bar : Nat → Nat → Nat
|
||||
| .zero => fun _ => 0
|
||||
| .succ n => fun m => foo n m
|
||||
end
|
||||
termination_by foo n => n; bar n => n
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex3
|
||||
|
||||
-- n-ary and with extra arguments
|
||||
|
||||
namespace Ex4
|
||||
|
||||
mutual
|
||||
def foo : Nat → Nat → Nat → Nat
|
||||
| .zero, _m => fun _ => 0
|
||||
| .succ n, .zero => fun k => (id' (bar n)) .zero k
|
||||
| .succ n, m => fun k => (id' bar) n m k
|
||||
def bar : Nat → Nat → Nat → Nat
|
||||
| .zero, _m => fun _ => 0
|
||||
| .succ n, m => fun k => foo n m k
|
||||
end
|
||||
termination_by foo n m => (n,m); bar n m => (n,m)
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex4
|
||||
|
||||
-- Check that eta-expansion works even if the function does not
|
||||
-- have a function type
|
||||
namespace Ex5
|
||||
def FunType := Nat → Nat
|
||||
|
||||
mutual
|
||||
def foo : FunType
|
||||
| .zero => 0
|
||||
| .succ n => (id bar) n
|
||||
def bar : Nat → Nat
|
||||
| .zero => 0
|
||||
| .succ n => foo n
|
||||
end
|
||||
termination_by foo n => n; bar n => n
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex5
|
||||
|
||||
|
||||
namespace Ex6
|
||||
def Fun3Type := Nat → Nat → Nat
|
||||
|
||||
mutual
|
||||
def foo : Nat → Nat → Nat → Nat
|
||||
| .zero, _m => fun _ => 0
|
||||
| .succ n, .zero => fun k => (id' (bar n)) .zero k
|
||||
| .succ n, m => fun k => (id' bar) n m k
|
||||
def bar : Nat → Fun3Type
|
||||
| .zero, _m => fun _ => 0
|
||||
| .succ n, m => fun k => foo n m k
|
||||
end
|
||||
termination_by foo n m => (n,m); bar n m => (n,m)
|
||||
decreasing_by sorry
|
||||
|
||||
end Ex6
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
/-!
|
||||
Test that PackMutual isn't confused when a recursive call has more arguments than is packed
|
||||
into the unary argument, which can happen if the retturn type is a function type.
|
||||
into the unary argument, which can happen if the return type is a function type.
|
||||
-/
|
||||
|
||||
mutual
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue