refactor: move unpackArg etc. to WF.PackDomain/WF.PackMutual (#3077)
extracted from #3040 to keep the diff smaller
This commit is contained in:
parent
4dd59690e0
commit
7acbee8ae4
3 changed files with 72 additions and 48 deletions
|
|
@ -14,6 +14,7 @@ import Lean.Elab.RecAppSyntax
|
|||
import Lean.Elab.PreDefinition.Basic
|
||||
import Lean.Elab.PreDefinition.Structural.Basic
|
||||
import Lean.Elab.PreDefinition.WF.TerminationHint
|
||||
import Lean.Elab.PreDefinition.WF.PackMutual
|
||||
import Lean.Data.Array
|
||||
|
||||
|
||||
|
|
@ -263,39 +264,6 @@ def filterSubsumed (rcs : Array RecCallWithContext ) : Array RecCallWithContext
|
|||
return (false, true)
|
||||
return (true, true)
|
||||
|
||||
/-- Given the packed argument of a (possibly) mutual and (possibly) nary call,
|
||||
return the function index that is called and the arguments individually.
|
||||
|
||||
We expect precisely the expressions produced by `packMutual`, with manifest
|
||||
`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart
|
||||
rather than using projectinos. -/
|
||||
def unpackArg {m} [Monad m] [MonadError m] (arities : Array Nat) (e : Expr) :
|
||||
m (Nat × Array Expr) := do
|
||||
-- count PSum injections to find out which function is doing the call
|
||||
let mut funidx := 0
|
||||
let mut e := e
|
||||
while funidx + 1 < arities.size do
|
||||
if e.isAppOfArity ``PSum.inr 3 then
|
||||
e := e.getArg! 2
|
||||
funidx := funidx + 1
|
||||
else if e.isAppOfArity ``PSum.inl 3 then
|
||||
e := e.getArg! 2
|
||||
break
|
||||
else
|
||||
throwError "Unexpected expression while unpacking mutual argument"
|
||||
|
||||
-- now unpack PSigmas
|
||||
let arity := arities[funidx]!
|
||||
let mut args := #[]
|
||||
while args.size + 1 < arity do
|
||||
if e.isAppOfArity ``PSigma.mk 4 then
|
||||
args := args.push (e.getArg! 2)
|
||||
e := e.getArg! 3
|
||||
else
|
||||
throwError "Unexpected expression while unpacking n-ary argument"
|
||||
args := args.push e
|
||||
return (funidx, args)
|
||||
|
||||
/-- Traverse a unary PreDefinition, and returns a `WithRecCall` closure for each recursive
|
||||
call site.
|
||||
-/
|
||||
|
|
|
|||
|
|
@ -40,6 +40,19 @@ where
|
|||
else
|
||||
return args[i]!
|
||||
|
||||
/-- Unpacks a unary packed argument created with `mkUnaryArg`. -/
|
||||
def unpackUnaryArg {m} [Monad m] [MonadError m] (arity : Nat) (e : Expr) : m (Array Expr) := do
|
||||
let mut e := e
|
||||
let mut args := #[]
|
||||
while args.size + 1 < arity do
|
||||
if e.isAppOfArity ``PSigma.mk 4 then
|
||||
args := args.push (e.getArg! 2)
|
||||
e := e.getArg! 3
|
||||
else
|
||||
throwError "Unexpected expression while unpacking n-ary argument"
|
||||
args := args.push e
|
||||
return args
|
||||
|
||||
private partial def mkPSigmaCasesOn (y : Expr) (codomain : Expr) (xs : Array Expr) (value : Expr) : MetaM Expr := do
|
||||
let mvar ← mkFreshExprSyntheticOpaqueMVar codomain
|
||||
let rec go (mvarId : MVarId) (y : FVarId) (ys : Array Expr) : MetaM Unit := do
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
import Lean.Meta.Tactic.Cases
|
||||
import Lean.Elab.PreDefinition.Basic
|
||||
import Lean.Elab.PreDefinition.WF.PackDomain
|
||||
|
||||
namespace Lean.Elab.WF
|
||||
open Meta
|
||||
|
|
@ -110,8 +111,60 @@ def withAppN (n : Nat) (e : Expr) (k : Array Expr → MetaM Expr) : MetaM Expr :
|
|||
mkLambdaFVars xs e'
|
||||
|
||||
/--
|
||||
Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`.
|
||||
See: `packMutual`
|
||||
If `arg` is the argument to the `fidx`th of the `numFuncs` in the recursive group,
|
||||
then `mkMutualArg` packs that argument in `PSum.inl` and `PSum.inr` constructors
|
||||
to create the mutual-packed argument of type `domain`.
|
||||
-/
|
||||
partial def mkMutualArg (numFuncs : Nat) (domain : Expr) (fidx : Nat) (arg : Expr) : MetaM Expr := do
|
||||
let rec go (i : Nat) (type : Expr) : MetaM Expr := do
|
||||
if i == numFuncs - 1 then
|
||||
return arg
|
||||
else
|
||||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
|
||||
else
|
||||
let r ← go (i+1) args[1]!
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
|
||||
go 0 domain
|
||||
|
||||
/--
|
||||
Unpacks a mutually packed argument, returning the argument and function index.
|
||||
Inverse of `mkMutualArg`. Cf. `unpackUnaryArg` and `unpackArg`, which does both
|
||||
-/
|
||||
def unpackMutualArg {m} [Monad m] [MonadError m] (numFuncs : Nat) (e : Expr) : m (Nat × Expr) := do
|
||||
let mut funidx := 0
|
||||
let mut e := e
|
||||
while funidx + 1 < numFuncs do
|
||||
if e.isAppOfArity ``PSum.inr 3 then
|
||||
e := e.getArg! 2
|
||||
funidx := funidx + 1
|
||||
else if e.isAppOfArity ``PSum.inl 3 then
|
||||
e := e.getArg! 2
|
||||
break
|
||||
else
|
||||
throwError "Unexpected expression while unpacking mutual argument"
|
||||
return (funidx, e)
|
||||
|
||||
/--
|
||||
Given the packed argument of a (possibly) mutual and (possibly) nary call,
|
||||
return the function index that is called and the arguments individually.
|
||||
|
||||
We expect precisely the expressions produced by `packMutual`, with manifest
|
||||
`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart
|
||||
rather than using projectinos.
|
||||
-/
|
||||
def unpackArg {m} [Monad m] [MonadError m] (arities : Array Nat) (e : Expr) :
|
||||
m (Nat × Array Expr) := do
|
||||
let (funidx, e) ← unpackMutualArg arities.size e
|
||||
let args ← unpackUnaryArg arities[funidx]! e
|
||||
return (funidx, args)
|
||||
|
||||
|
||||
/--
|
||||
Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`.
|
||||
See: `packMutual`
|
||||
-/
|
||||
private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (domain : Expr) (newFn : Name) (e : Expr) : MetaM TransformStep := do
|
||||
let f := e.getAppFn
|
||||
|
|
@ -122,19 +175,9 @@ private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (do
|
|||
if let some fidx := preDefs.findIdx? (·.declName == declName) then
|
||||
let e' ← withAppN (fixedPrefix + 1) e fun args => do
|
||||
let fixedArgs := args[:fixedPrefix]
|
||||
let arg := args[fixedPrefix]!
|
||||
let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do
|
||||
if i == preDefs.size - 1 then
|
||||
return arg
|
||||
else
|
||||
(← whnfD type).withApp fun f args => do
|
||||
assert! args.size == 2
|
||||
if i == fidx then
|
||||
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
|
||||
else
|
||||
let r ← mkNewArg (i+1) args[1]!
|
||||
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
|
||||
return mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)
|
||||
let arg := args[fixedPrefix]!
|
||||
let packedArg ← mkMutualArg preDefs.size domain fidx arg
|
||||
return mkApp (mkAppN (mkConst newFn us) fixedArgs) packedArg
|
||||
return TransformStep.done e'
|
||||
return TransformStep.done e
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue