refactor: lambdaBoundedTelescope (#4642)
we have a `forallBoundedTelescope`, and for a long while I was wondering why we also don't have `lambdaBoundedTelescope`, and every now and then felt the need for it. So let's just add it.
This commit is contained in:
parent
3fb7f632a5
commit
0594bc4e5a
6 changed files with 41 additions and 37 deletions
|
|
@ -153,9 +153,9 @@ private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo)
|
|||
trace[Elab.definition.structural] "below before matcherApp.addArg: {below} : {← inferType below}"
|
||||
if let some matcherApp ← matcherApp.addArg? below then
|
||||
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
|
||||
lambdaTelescope alt fun xs altBody => do
|
||||
lambdaBoundedTelescope alt numParams fun xs altBody => do
|
||||
trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}"
|
||||
unless xs.size >= numParams do
|
||||
unless xs.size = numParams do
|
||||
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let belowForAlt := xs[numParams - 1]!
|
||||
mkLambdaFVars xs (← loop belowForAlt altBody)
|
||||
|
|
|
|||
|
|
@ -47,17 +47,13 @@ where
|
|||
else
|
||||
let mut altsNew := #[]
|
||||
for alt in matcherApp.alts, numParams in matcherApp.altNumParams do
|
||||
let altNew ← lambdaTelescope alt fun xs altBody => do
|
||||
unless xs.size >= numParams do
|
||||
let altNew ← lambdaBoundedTelescope alt numParams fun xs altBody => do
|
||||
unless xs.size = numParams do
|
||||
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let altBody ← visit altBody
|
||||
let containsSUnfoldMatch := Option.isSome <| altBody.find? fun e => smartUnfoldingMatch? e |>.isSome
|
||||
if !containsSUnfoldMatch then
|
||||
let altBody ← mkLambdaFVars xs[numParams:xs.size] altBody
|
||||
let altBody := markSmartUnfoldingMatchAlt altBody
|
||||
mkLambdaFVars xs[0:numParams] altBody
|
||||
else
|
||||
mkLambdaFVars xs altBody
|
||||
let altBody := if !containsSUnfoldMatch then markSmartUnfoldingMatchAlt altBody else altBody
|
||||
mkLambdaFVars xs altBody
|
||||
altsNew := altsNew.push altNew
|
||||
return markSmartUnfoldingMatch { matcherApp with alts := altsNew }.toExpr
|
||||
| _ => processApp e
|
||||
|
|
|
|||
|
|
@ -104,8 +104,7 @@ This needs extra information:
|
|||
* `extraParams` indicates how many of the functions arguments are bound “after the colon”.
|
||||
-/
|
||||
def TerminationArgument.delab (arity : Nat) (extraParams : Nat) (termArg : TerminationArgument) : MetaM (TSyntax ``terminationBy) := do
|
||||
lambdaTelescope termArg.fn fun ys e => do
|
||||
let e ← mkLambdaFVars ys[arity - extraParams:] e -- undo overshooting by lambdaTelescope
|
||||
lambdaBoundedTelescope termArg.fn (arity - extraParams) fun _ys e => do
|
||||
pure (← delabCore e (delab := go extraParams #[])).1
|
||||
where
|
||||
go : Nat → TSyntaxArray `ident → DelabM (TSyntax ``terminationBy)
|
||||
|
|
|
|||
|
|
@ -81,8 +81,8 @@ where
|
|||
| some matcherApp =>
|
||||
if let some matcherApp ← matcherApp.addArg? F then
|
||||
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
|
||||
lambdaTelescope alt fun xs altBody => do
|
||||
unless xs.size >= numParams do
|
||||
lambdaBoundedTelescope alt numParams fun xs altBody => do
|
||||
unless xs.size = numParams do
|
||||
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let FAlt := xs[numParams - 1]!
|
||||
mkLambdaFVars xs (← loop FAlt altBody)
|
||||
|
|
@ -103,12 +103,11 @@ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F :
|
|||
let type ← mkArrow (FDecl.type.replaceFVar x xs[0]!) type
|
||||
return (← mkLambdaFVars xs type, ← getLevel type)
|
||||
let mkMinorNew (ctorName : Name) (minor : Expr) : TermElabM Expr :=
|
||||
lambdaTelescope minor fun xs body => do
|
||||
lambdaBoundedTelescope minor 1 fun xs body => do
|
||||
let xNew := xs[0]!
|
||||
let valNew ← mkLambdaFVars xs[1:] body
|
||||
let FTypeNew := FDecl.type.replaceFVar x (← mkAppOptM ctorName #[α, β, xNew])
|
||||
withLocalDeclD FDecl.userName FTypeNew fun FNew => do
|
||||
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew valNew k)
|
||||
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew body k)
|
||||
let minorLeft ← mkMinorNew ``PSum.inl args[4]!
|
||||
let minorRight ← mkMinorNew ``PSum.inr args[5]!
|
||||
let result := mkAppN (mkConst ``PSum.casesOn [u, (← getLevel α), (← getLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]
|
||||
|
|
|
|||
|
|
@ -256,8 +256,7 @@ where
|
|||
matcherApp.discrs.forM (loop param)
|
||||
(Array.zip matcherApp.alts (Array.zip matcherApp.altNumParams altParams)).forM
|
||||
fun (alt, altNumParam, altParam) =>
|
||||
lambdaTelescope altParam fun xs altParam => do
|
||||
-- TODO: Use boundedLambdaTelescope
|
||||
lambdaBoundedTelescope altParam altNumParam fun xs altParam => do
|
||||
unless altNumParam = xs.size do
|
||||
throwError "unexpected `casesOn` application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let altBody := alt.beta xs
|
||||
|
|
@ -342,9 +341,8 @@ call site.
|
|||
def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
|
||||
(argsPacker : ArgsPacker) : MetaM (Array RecCallWithContext) := withoutModifyingState do
|
||||
addAsAxiom unaryPreDef
|
||||
lambdaTelescope unaryPreDef.value fun xs body => do
|
||||
lambdaBoundedTelescope unaryPreDef.value (fixedPrefixSize + 1) fun xs body => do
|
||||
unless xs.size == fixedPrefixSize + 1 do
|
||||
-- Maybe cleaner to have lambdaBoundedTelescope?
|
||||
throwError "Unexpected number of lambdas in unary pre-definition"
|
||||
let ys := xs[:fixedPrefixSize]
|
||||
let param := xs[fixedPrefixSize]!
|
||||
|
|
|
|||
|
|
@ -1201,30 +1201,31 @@ private def forallBoundedTelescopeImp (type : Expr) (maxFVars? : Option Nat) (k
|
|||
def forallBoundedTelescope (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
|
||||
map2MetaM (fun k => forallBoundedTelescopeImp type maxFVars? k cleanupAnnotations) k
|
||||
|
||||
private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array Expr → Expr → MetaM α) (cleanupAnnotations := false) : MetaM α := do
|
||||
process consumeLet (← getLCtx) #[] 0 e
|
||||
private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (maxFVars? : Option Nat)
|
||||
(k : Array Expr → Expr → MetaM α) (cleanupAnnotations := false) : MetaM α := do
|
||||
process consumeLet (← getLCtx) #[] e
|
||||
where
|
||||
process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do
|
||||
match consumeLet, e with
|
||||
| _, .lam n d b bi =>
|
||||
let d := d.instantiateRevRange j fvars.size fvars
|
||||
process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (e : Expr) : MetaM α := do
|
||||
match fvarsSizeLtMaxFVars fvars maxFVars?, consumeLet, e with
|
||||
| true, _, .lam n d b bi =>
|
||||
let d := d.instantiateRevRange 0 fvars.size fvars
|
||||
let d := if cleanupAnnotations then d.cleanupAnnotations else d
|
||||
let fvarId ← mkFreshFVarId
|
||||
let lctx := lctx.mkLocalDecl fvarId n d bi
|
||||
let fvar := mkFVar fvarId
|
||||
process consumeLet lctx (fvars.push fvar) j b
|
||||
| true, .letE n t v b _ => do
|
||||
let t := t.instantiateRevRange j fvars.size fvars
|
||||
process consumeLet lctx (fvars.push fvar) b
|
||||
| true, true, .letE n t v b _ => do
|
||||
let t := t.instantiateRevRange 0 fvars.size fvars
|
||||
let t := if cleanupAnnotations then t.cleanupAnnotations else t
|
||||
let v := v.instantiateRevRange j fvars.size fvars
|
||||
let v := v.instantiateRevRange 0 fvars.size fvars
|
||||
let fvarId ← mkFreshFVarId
|
||||
let lctx := lctx.mkLetDecl fvarId n t v
|
||||
let fvar := mkFVar fvarId
|
||||
process true lctx (fvars.push fvar) j b
|
||||
| _, e =>
|
||||
let e := e.instantiateRevRange j fvars.size fvars
|
||||
process true lctx (fvars.push fvar) b
|
||||
| _, _, e =>
|
||||
let e := e.instantiateRevRange 0 fvars.size fvars
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) do
|
||||
withNewLocalInstancesImp fvars j do
|
||||
withNewLocalInstancesImp fvars 0 do
|
||||
k fvars e
|
||||
|
||||
/--
|
||||
|
|
@ -1233,7 +1234,7 @@ Similar to `lambdaTelescope` but for lambda and let expressions.
|
|||
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
|
||||
-/
|
||||
def lambdaLetTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
|
||||
map2MetaM (fun k => lambdaTelescopeImp e true k (cleanupAnnotations := cleanupAnnotations)) k
|
||||
map2MetaM (fun k => lambdaTelescopeImp e true .none k (cleanupAnnotations := cleanupAnnotations)) k
|
||||
|
||||
/--
|
||||
Given `e` of the form `fun ..xs => A`, execute `k xs A`.
|
||||
|
|
@ -1243,7 +1244,18 @@ def lambdaLetTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnn
|
|||
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
|
||||
-/
|
||||
def lambdaTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
|
||||
map2MetaM (fun k => lambdaTelescopeImp e false k (cleanupAnnotations := cleanupAnnotations)) k
|
||||
map2MetaM (fun k => lambdaTelescopeImp e false none k (cleanupAnnotations := cleanupAnnotations)) k
|
||||
|
||||
/--
|
||||
Given `e` of the form `fun ..xs ..ys => A`, execute `k xs (fun ..ys => A)` where
|
||||
`xs.size ≤ maxFVars`.
|
||||
This combinator will declare local declarations, create free variables for them,
|
||||
execute `k` with updated local context, and make sure the cache is restored after executing `k`.
|
||||
|
||||
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
|
||||
-/
|
||||
def lambdaBoundedTelescope (e : Expr) (maxFVars : Nat) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
|
||||
map2MetaM (fun k => lambdaTelescopeImp e false (.some maxFVars) k (cleanupAnnotations := cleanupAnnotations)) k
|
||||
|
||||
/-- Return the parameter names for the given global declaration. -/
|
||||
def getParamNames (declName : Name) : MetaM (Array Name) := do
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue