refactor: lambdaBoundedTelescope (#4642)

we have a `forallBoundedTelescope`, and for a long while I was
wondering why we also don't have `lambdaBoundedTelescope`, and every now
and then felt the need for it. So let's just add it.
This commit is contained in:
Joachim Breitner 2024-07-03 17:57:12 +02:00 committed by GitHub
parent 3fb7f632a5
commit 0594bc4e5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 41 additions and 37 deletions

View file

@ -153,9 +153,9 @@ private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo)
trace[Elab.definition.structural] "below before matcherApp.addArg: {below} : {← inferType below}"
if let some matcherApp ← matcherApp.addArg? below then
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
lambdaTelescope alt fun xs altBody => do
lambdaBoundedTelescope alt numParams fun xs altBody => do
trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}"
unless xs.size >= numParams do
unless xs.size = numParams do
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let belowForAlt := xs[numParams - 1]!
mkLambdaFVars xs (← loop belowForAlt altBody)

View file

@ -47,17 +47,13 @@ where
else
let mut altsNew := #[]
for alt in matcherApp.alts, numParams in matcherApp.altNumParams do
let altNew ← lambdaTelescope alt fun xs altBody => do
unless xs.size >= numParams do
let altNew ← lambdaBoundedTelescope alt numParams fun xs altBody => do
unless xs.size = numParams do
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let altBody ← visit altBody
let containsSUnfoldMatch := Option.isSome <| altBody.find? fun e => smartUnfoldingMatch? e |>.isSome
if !containsSUnfoldMatch then
let altBody ← mkLambdaFVars xs[numParams:xs.size] altBody
let altBody := markSmartUnfoldingMatchAlt altBody
mkLambdaFVars xs[0:numParams] altBody
else
mkLambdaFVars xs altBody
let altBody := if !containsSUnfoldMatch then markSmartUnfoldingMatchAlt altBody else altBody
mkLambdaFVars xs altBody
altsNew := altsNew.push altNew
return markSmartUnfoldingMatch { matcherApp with alts := altsNew }.toExpr
| _ => processApp e

View file

@ -104,8 +104,7 @@ This needs extra information:
* `extraParams` indicates how many of the functions arguments are bound “after the colon”.
-/
def TerminationArgument.delab (arity : Nat) (extraParams : Nat) (termArg : TerminationArgument) : MetaM (TSyntax ``terminationBy) := do
lambdaTelescope termArg.fn fun ys e => do
let e ← mkLambdaFVars ys[arity - extraParams:] e -- undo overshooting by lambdaTelescope
lambdaBoundedTelescope termArg.fn (arity - extraParams) fun _ys e => do
pure (← delabCore e (delab := go extraParams #[])).1
where
go : Nat → TSyntaxArray `ident → DelabM (TSyntax ``terminationBy)

View file

@ -81,8 +81,8 @@ where
| some matcherApp =>
if let some matcherApp ← matcherApp.addArg? F then
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
lambdaTelescope alt fun xs altBody => do
unless xs.size >= numParams do
lambdaBoundedTelescope alt numParams fun xs altBody => do
unless xs.size = numParams do
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let FAlt := xs[numParams - 1]!
mkLambdaFVars xs (← loop FAlt altBody)
@ -103,12 +103,11 @@ private partial def processSumCasesOn (x F val : Expr) (k : (x : Expr) → (F :
let type ← mkArrow (FDecl.type.replaceFVar x xs[0]!) type
return (← mkLambdaFVars xs type, ← getLevel type)
let mkMinorNew (ctorName : Name) (minor : Expr) : TermElabM Expr :=
lambdaTelescope minor fun xs body => do
lambdaBoundedTelescope minor 1 fun xs body => do
let xNew := xs[0]!
let valNew ← mkLambdaFVars xs[1:] body
let FTypeNew := FDecl.type.replaceFVar x (← mkAppOptM ctorName #[α, β, xNew])
withLocalDeclD FDecl.userName FTypeNew fun FNew => do
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew valNew k)
mkLambdaFVars #[xNew, FNew] (← processSumCasesOn xNew FNew body k)
let minorLeft ← mkMinorNew ``PSum.inl args[4]!
let minorRight ← mkMinorNew ``PSum.inr args[5]!
let result := mkAppN (mkConst ``PSum.casesOn [u, (← getLevel α), (← getLevel β)]) #[α, β, motiveNew, x, minorLeft, minorRight, F]

View file

@ -256,8 +256,7 @@ where
matcherApp.discrs.forM (loop param)
(Array.zip matcherApp.alts (Array.zip matcherApp.altNumParams altParams)).forM
fun (alt, altNumParam, altParam) =>
lambdaTelescope altParam fun xs altParam => do
-- TODO: Use boundedLambdaTelescope
lambdaBoundedTelescope altParam altNumParam fun xs altParam => do
unless altNumParam = xs.size do
throwError "unexpected `casesOn` application alternative{indentExpr alt}\nat application{indentExpr e}"
let altBody := alt.beta xs
@ -342,9 +341,8 @@ call site.
def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
(argsPacker : ArgsPacker) : MetaM (Array RecCallWithContext) := withoutModifyingState do
addAsAxiom unaryPreDef
lambdaTelescope unaryPreDef.value fun xs body => do
lambdaBoundedTelescope unaryPreDef.value (fixedPrefixSize + 1) fun xs body => do
unless xs.size == fixedPrefixSize + 1 do
-- Maybe cleaner to have lambdaBoundedTelescope?
throwError "Unexpected number of lambdas in unary pre-definition"
let ys := xs[:fixedPrefixSize]
let param := xs[fixedPrefixSize]!

View file

@ -1201,30 +1201,31 @@ private def forallBoundedTelescopeImp (type : Expr) (maxFVars? : Option Nat) (k
def forallBoundedTelescope (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
map2MetaM (fun k => forallBoundedTelescopeImp type maxFVars? k cleanupAnnotations) k
private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array Expr → Expr → MetaM α) (cleanupAnnotations := false) : MetaM α := do
process consumeLet (← getLCtx) #[] 0 e
private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (maxFVars? : Option Nat)
(k : Array Expr → Expr → MetaM α) (cleanupAnnotations := false) : MetaM α := do
process consumeLet (← getLCtx) #[] e
where
process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do
match consumeLet, e with
| _, .lam n d b bi =>
let d := d.instantiateRevRange j fvars.size fvars
process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (e : Expr) : MetaM α := do
match fvarsSizeLtMaxFVars fvars maxFVars?, consumeLet, e with
| true, _, .lam n d b bi =>
let d := d.instantiateRevRange 0 fvars.size fvars
let d := if cleanupAnnotations then d.cleanupAnnotations else d
let fvarId ← mkFreshFVarId
let lctx := lctx.mkLocalDecl fvarId n d bi
let fvar := mkFVar fvarId
process consumeLet lctx (fvars.push fvar) j b
| true, .letE n t v b _ => do
let t := t.instantiateRevRange j fvars.size fvars
process consumeLet lctx (fvars.push fvar) b
| true, true, .letE n t v b _ => do
let t := t.instantiateRevRange 0 fvars.size fvars
let t := if cleanupAnnotations then t.cleanupAnnotations else t
let v := v.instantiateRevRange j fvars.size fvars
let v := v.instantiateRevRange 0 fvars.size fvars
let fvarId ← mkFreshFVarId
let lctx := lctx.mkLetDecl fvarId n t v
let fvar := mkFVar fvarId
process true lctx (fvars.push fvar) j b
| _, e =>
let e := e.instantiateRevRange j fvars.size fvars
process true lctx (fvars.push fvar) b
| _, _, e =>
let e := e.instantiateRevRange 0 fvars.size fvars
withReader (fun ctx => { ctx with lctx := lctx }) do
withNewLocalInstancesImp fvars j do
withNewLocalInstancesImp fvars 0 do
k fvars e
/--
@ -1233,7 +1234,7 @@ Similar to `lambdaTelescope` but for lambda and let expressions.
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
-/
def lambdaLetTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
map2MetaM (fun k => lambdaTelescopeImp e true k (cleanupAnnotations := cleanupAnnotations)) k
map2MetaM (fun k => lambdaTelescopeImp e true .none k (cleanupAnnotations := cleanupAnnotations)) k
/--
Given `e` of the form `fun ..xs => A`, execute `k xs A`.
@ -1243,7 +1244,18 @@ def lambdaLetTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnn
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
-/
def lambdaTelescope (e : Expr) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
map2MetaM (fun k => lambdaTelescopeImp e false k (cleanupAnnotations := cleanupAnnotations)) k
map2MetaM (fun k => lambdaTelescopeImp e false none k (cleanupAnnotations := cleanupAnnotations)) k
/--
Given `e` of the form `fun ..xs ..ys => A`, execute `k xs (fun ..ys => A)` where
`xs.size ≤ maxFVars`.
This combinator will declare local declarations, create free variables for them,
execute `k` with updated local context, and make sure the cache is restored after executing `k`.
If `cleanupAnnotations` is `true`, we apply `Expr.cleanupAnnotations` to each type in the telescope.
-/
def lambdaBoundedTelescope (e : Expr) (maxFVars : Nat) (k : Array Expr → Expr → n α) (cleanupAnnotations := false) : n α :=
map2MetaM (fun k => lambdaTelescopeImp e false (.some maxFVars) k (cleanupAnnotations := cleanupAnnotations)) k
/-- Return the parameter names for the given global declaration. -/
def getParamNames (declName : Name) : MetaM (Array Name) := do