diff --git a/src/Lean/Meta/KAbstract.lean b/src/Lean/Meta/KAbstract.lean index 5a2a07ec96..ee534ac938 100644 --- a/src/Lean/Meta/KAbstract.lean +++ b/src/Lean/Meta/KAbstract.lean @@ -34,11 +34,42 @@ def kabstract (e : Expr) (p : Expr) (occs : Occurrences := Occurrences.all) : Me let i ← get set (i+1) if occs.contains i then - pure (mkBVar offset) + return mkBVar offset else visitChildren () else visitChildren () visit e 0 |>.run' 1 +/-- + Similar to `kabstract`, but only abstracts occurrences of `p` s.t. `pred parent? p` is true where `parent?` + is the parent expression for `p` if any. +-/ +partial def kabstractWithPred (e : Expr) (p : Expr) (pred : (parent? : Option Expr) → (e : Expr) → MetaM Bool) : MetaM Expr := do + let e ← instantiateMVars e + let pHeadIdx := p.toHeadIndex + let pNumArgs := p.headNumArgs + let rec visit (parent? : Option Expr) (e : Expr) (offset : Nat) : MetaM Expr := do + let visitChildren : Unit → MetaM Expr := fun _ => do + match e with + | Expr.app .. => e.withApp fun f args => return mkAppN (← visit e f offset) (← args.mapM (visit e . offset)) + | Expr.mdata _ b _ => return e.updateMData! (← visit e b offset) + | Expr.proj _ _ b _ => return e.updateProj! (← visit e b offset) + | Expr.letE _ t v b _ => return e.updateLet! (← visit e t offset) (← visit e v offset) (← visit e b (offset+1)) + | Expr.lam _ d b _ => return e.updateLambdaE! (← visit e d offset) (← visit e b (offset+1)) + | Expr.forallE _ d b _ => return e.updateForallE! (← visit e d offset) (← visit e b (offset+1)) + | e => return e + if e.hasLooseBVars then + visitChildren () + else if e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then + visitChildren () + else if (← isDefEq e p) then + if (← pred parent? e) then + return mkBVar offset + else + visitChildren () + else + visitChildren () + visit none e 0 + end Lean.Meta diff --git a/src/Lean/Meta/Tactic/Generalize.lean b/src/Lean/Meta/Tactic/Generalize.lean index c972bc9db9..1e0ecbef09 100644 --- a/src/Lean/Meta/Tactic/Generalize.lean +++ b/src/Lean/Meta/Tactic/Generalize.lean @@ -16,7 +16,10 @@ structure GeneralizeArg where hName? : Option Name := none deriving Inhabited -partial def generalize (mvarId : MVarId) (args : Array GeneralizeArg) : MetaM (Array FVarId × MVarId) := +partial def generalize + (mvarId : MVarId) (args : Array GeneralizeArg) + (pred : (parent? : Option Expr) → (e : Expr) → MetaM Bool := fun _ _ => return true) + : MetaM (Array FVarId × MVarId) := withMVarContext mvarId do checkNotAssigned mvarId `generalize let tag ← getMVarTag mvarId @@ -28,7 +31,7 @@ partial def generalize (mvarId : MVarId) (args : Array GeneralizeArg) : MetaM (A let eType ← instantiateMVars (← inferType e) let type ← go (i+1) let xName ← if let some xName := arg.xName? then pure xName else mkFreshUserName `x - return Lean.mkForall xName BinderInfo.default eType (← kabstract type e) + return Lean.mkForall xName BinderInfo.default eType (← kabstractWithPred type e pred) else return target let targetNew ← go 0