feat: improve performance of instantiateBetaRevRange (#13758)
This PR improves `Expr.instantiateBetaRevRange` to be more efficient in the common case where lambda functions are not being instantiated, and it increases expression sharing in applications. The motivation is that we would like to use this function more pervasively in elaboration, so that users do not need to write `dsimp only` as frequently in applications that involve higher-order functions, plus `inferType` uses it so there is a UX inconsistency when the elaborator is not using it.
This commit is contained in:
parent
ef2dc0f66a
commit
5dea2142c3
1 changed files with 42 additions and 19 deletions
|
|
@ -19,6 +19,9 @@ Auxiliary function for instantiating the loose bound variables in `e` with `args
|
||||||
This function is similar to `instantiateRevRange`, but it applies beta-reduction when
|
This function is similar to `instantiateRevRange`, but it applies beta-reduction when
|
||||||
we instantiate a bound variable with a lambda expression.
|
we instantiate a bound variable with a lambda expression.
|
||||||
|
|
||||||
|
If `args` contains no lambda expressions, it is equivalent to `instantiateRevRange`, and in fact
|
||||||
|
it will call `instantiateRevRange` for efficiency.
|
||||||
|
|
||||||
Example: Given the term `#0 a`, and `start := 0, stop := 1, args := #[fun x => x]` the result is
|
Example: Given the term `#0 a`, and `start := 0, stop := 1, args := #[fun x => x]` the result is
|
||||||
`a` instead of `(fun x => x) a`.
|
`a` instead of `(fun x => x) a`.
|
||||||
This reduction is useful when we are inferring the type of eliminator-like applications.
|
This reduction is useful when we are inferring the type of eliminator-like applications.
|
||||||
|
|
@ -32,10 +35,38 @@ We use this to implement `inferAppType`.
|
||||||
partial def Expr.instantiateBetaRevRange (e : Expr) (start : Nat) (stop : Nat) (args : Array Expr) : Expr :=
|
partial def Expr.instantiateBetaRevRange (e : Expr) (start : Nat) (stop : Nat) (args : Array Expr) : Expr :=
|
||||||
if e.hasLooseBVars && stop > start then
|
if e.hasLooseBVars && stop > start then
|
||||||
assert! stop ≤ args.size
|
assert! stop ≤ args.size
|
||||||
visit e 0 |>.run
|
if args.any (·.consumeMData.isLambda) start stop then
|
||||||
|
visit e 0 |>.run
|
||||||
|
else
|
||||||
|
-- If there are no lambdas, then `instantiateRevRange` suffices.
|
||||||
|
instantiateRevRange e start stop args
|
||||||
else
|
else
|
||||||
e
|
e
|
||||||
where
|
where
|
||||||
|
/--
|
||||||
|
Visit a bvar `e := .bvar vidx`, assuming `offset < e.looseBVarRange`.
|
||||||
|
-/
|
||||||
|
visitBVar (vidx : Nat) (offset : Nat) : Expr :=
|
||||||
|
-- Recall that `looseBVarRange` for `Expr.bvar` is `vidx+1`.
|
||||||
|
-- So, we must have `offset ≤ vidx`, since `offset < e.looseBVarRange`
|
||||||
|
let n := stop - start
|
||||||
|
if vidx < offset + n then
|
||||||
|
args[stop - (vidx - offset) - 1]!.liftLooseBVars 0 offset
|
||||||
|
else
|
||||||
|
Expr.bvar (vidx - n)
|
||||||
|
visitWithoutBeta (e : Expr) (offset : Nat) : MonadStateCacheT (ExprStructEq × Nat) Expr Id Expr := do
|
||||||
|
if offset >= e.looseBVarRange then
|
||||||
|
-- `e` doesn't have free variables
|
||||||
|
return e
|
||||||
|
else
|
||||||
|
match e with
|
||||||
|
| .app f a =>
|
||||||
|
-- Check the cache only here, since in the other alternative `visit` will check the cache.
|
||||||
|
checkCache ({ val := e : ExprStructEq }, offset) fun _ => visitApp e f a offset
|
||||||
|
| e => visit e offset
|
||||||
|
/-- Visit an application without beta reducing the head -/
|
||||||
|
visitApp (e f a : Expr) (offset : Nat) : MonadStateCacheT (ExprStructEq × Nat) Expr Id Expr :=
|
||||||
|
return e.updateApp! (← visitWithoutBeta f offset) (← visit a offset)
|
||||||
visit (e : Expr) (offset : Nat) : MonadStateCacheT (ExprStructEq × Nat) Expr Id Expr :=
|
visit (e : Expr) (offset : Nat) : MonadStateCacheT (ExprStructEq × Nat) Expr Id Expr :=
|
||||||
if offset >= e.looseBVarRange then
|
if offset >= e.looseBVarRange then
|
||||||
-- `e` doesn't have free variables
|
-- `e` doesn't have free variables
|
||||||
|
|
@ -47,23 +78,17 @@ where
|
||||||
| .letE _ t v b _ => return e.updateLetE! (← visit t offset) (← visit v offset) (← visit b (offset+1))
|
| .letE _ t v b _ => return e.updateLetE! (← visit t offset) (← visit v offset) (← visit b (offset+1))
|
||||||
| .mdata _ b => return e.updateMData! (← visit b offset)
|
| .mdata _ b => return e.updateMData! (← visit b offset)
|
||||||
| .proj _ _ b => return e.updateProj! (← visit b offset)
|
| .proj _ _ b => return e.updateProj! (← visit b offset)
|
||||||
| .app .. =>
|
| .bvar vidx => return visitBVar vidx offset
|
||||||
e.withAppRev fun f revArgs => do
|
| .app f a =>
|
||||||
let fNew ← visit f offset
|
let head := e.getAppFn
|
||||||
let revArgs ← revArgs.mapM (visit · offset)
|
-- try to beta reduce if the head is a bound variable
|
||||||
if f.isBVar then
|
if head.isBVar then
|
||||||
-- try to beta reduce if `f` was a bound variable
|
-- using `visit` instead of `visitBVar` for the `offset >= vidx` check and for caching `liftLooseBVars`
|
||||||
return fNew.betaRev revArgs
|
let head ← visit head offset
|
||||||
|
let revArgs ← e.getAppRevArgs.mapM (visit · offset)
|
||||||
|
return head.betaRev revArgs
|
||||||
else
|
else
|
||||||
return mkAppRev fNew revArgs
|
visitApp e f a offset
|
||||||
| Expr.bvar vidx =>
|
|
||||||
-- Recall that looseBVarRange for `Expr.bvar` is `vidx+1`.
|
|
||||||
-- So, we must have offset ≤ vidx, since we are in the "else" branch of `if offset >= e.looseBVarRange`
|
|
||||||
let n := stop - start
|
|
||||||
if vidx < offset + n then
|
|
||||||
return args[stop - (vidx - offset) - 1]!.liftLooseBVars 0 offset
|
|
||||||
else
|
|
||||||
return mkBVar (vidx - n)
|
|
||||||
-- The following cases are unreachable because they never contain loose bound variables
|
-- The following cases are unreachable because they never contain loose bound variables
|
||||||
| .const .. => unreachable!
|
| .const .. => unreachable!
|
||||||
| .fvar .. => unreachable!
|
| .fvar .. => unreachable!
|
||||||
|
|
@ -79,8 +104,6 @@ def throwFunctionExpected {α} (f : Expr) : MetaM α :=
|
||||||
private def inferAppType (f : Expr) (args : Array Expr) : MetaM Expr := do
|
private def inferAppType (f : Expr) (args : Array Expr) : MetaM Expr := do
|
||||||
let mut fType ← inferType f
|
let mut fType ← inferType f
|
||||||
let mut j := 0
|
let mut j := 0
|
||||||
/- TODO: check whether `instantiateBetaRevRange` is too expensive, and
|
|
||||||
use it only when `args` contains a lambda expression. -/
|
|
||||||
for i in *...args.size do
|
for i in *...args.size do
|
||||||
match fType with
|
match fType with
|
||||||
| Expr.forallE _ _ b _ => fType := b
|
| Expr.forallE _ _ b _ => fType := b
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue