diff --git a/src/Lean/Elab/Tactic/Induction.lean b/src/Lean/Elab/Tactic/Induction.lean index 960485c90d..9448b93e8f 100644 --- a/src/Lean/Elab/Tactic/Induction.lean +++ b/src/Lean/Elab/Tactic/Induction.lean @@ -10,6 +10,7 @@ import Lean.Meta.CollectMVars import Lean.Meta.Tactic.ElimInfo import Lean.Meta.Tactic.Induction import Lean.Meta.Tactic.Cases +import Lean.Meta.GeneralizeVars import Lean.Elab.App import Lean.Elab.Tactic.ElabTerm import Lean.Elab.Tactic.Generalize @@ -243,56 +244,6 @@ where end ElimApp -/-- - Return a set of `FVarId`s containing `targets` and all variables they depend on. - - Remark: this method assumes `targets` are free variables. --/ -private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do - loop (targets.toList.map Expr.fvarId!) {} -where - visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do - let localDecl ← getLocalDecl fvarId - let mut s' := collectFVars {} (← instantiateMVars localDecl.type) - if let some val := localDecl.value? then - s' := collectFVars s' (← instantiateMVars val) - let mut todo := todo - let mut s := s - for fvarId in s'.fvarSet do - unless s.contains fvarId do - todo := fvarId :: todo - s := s.insert fvarId - return (todo, s) - - loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do - match todo with - | [] => return s - | fvarId::todo => - if s.contains fvarId then - loop todo s - else - let (todo, s) ← visit fvarId todo <| s.insert fvarId - loop todo s - -/-- - Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`. - - Remark: this method assumes `targets` are free variables. - - Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling - recursive declarations. --/ -private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do - let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId! - let mut r : NameSet := {} - for localDecl in (← getLCtx) do - unless forbidden.contains localDecl.fvarId do - unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do - if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then - r := r.insert localDecl.fvarId - s := s.insert localDecl.fvarId - return r - /- Recall that ``` @@ -300,7 +251,7 @@ private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : Me «induction» := leading_parser nonReservedSymbol "induction " >> majorPremise >> usingRec >> generalizingVars >> optional inductionAlts ``` `stx` is syntax for `induction`. -/ -private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) := +private def getUserGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) := withRef stx do let generalizingStx := stx[3] if generalizingStx.isNone then @@ -313,18 +264,16 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) := -- process `generalizingVars` subterm of induction Syntax `stx`. private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) := withMVarContext mvarId do - let userFVarIds ← getGeneralizingFVarIds stx - let forbidden ← mkForbiddenSet targets - let mut s ← collectForwardDeps targets forbidden + let userFVarIds ← getUserGeneralizingFVarIds stx + let forbidden ← mkGeneralizationForbiddenSet targets + let mut s ← getFVarSetToGeneralize targets forbidden for userFVarId in userFVarIds do if forbidden.contains userFVarId then throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}" if s.contains userFVarId then throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically" s := s.insert userFVarId - let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId - let lctx ← getLCtx - let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index + let fvarIds ← sortFVars s let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds return (fvarIds.size, mvarId') diff --git a/src/Lean/Meta.lean b/src/Lean/Meta.lean index 1045c27416..5d453b3a10 100644 --- a/src/Lean/Meta.lean +++ b/src/Lean/Meta.lean @@ -32,3 +32,4 @@ import Lean.Meta.SizeOf import Lean.Meta.Coe import Lean.Meta.SortLocalDecls import Lean.Meta.CollectFVars +import Lean.Meta.GeneralizeVars diff --git a/src/Lean/Meta/GeneralizeVars.lean b/src/Lean/Meta/GeneralizeVars.lean new file mode 100644 index 0000000000..aa21f93482 --- /dev/null +++ b/src/Lean/Meta/GeneralizeVars.lean @@ -0,0 +1,75 @@ +/- +Copyright (c) 2021 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Basic +import Lean.Util.CollectFVars + +namespace Lean.Meta + +/-- + Return a set of `FVarId`s containing `targets` and all variables they depend on. + + Remark: this method assumes `targets` are free variables. +-/ +partial def mkGeneralizationForbiddenSet (targets : Array Expr) : MetaM NameSet := do + loop (targets.toList.map Expr.fvarId!) {} +where + visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do + let localDecl ← getLocalDecl fvarId + let mut s' := collectFVars {} (← instantiateMVars localDecl.type) + if let some val := localDecl.value? then + s' := collectFVars s' (← instantiateMVars val) + let mut todo := todo + let mut s := s + for fvarId in s'.fvarSet do + unless s.contains fvarId do + todo := fvarId :: todo + s := s.insert fvarId + return (todo, s) + + loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do + match todo with + | [] => return s + | fvarId::todo => + if s.contains fvarId then + loop todo s + else + let (todo, s) ← visit fvarId todo <| s.insert fvarId + loop todo s + +/-- + Collect variables to be generalized. + It uses the following heuristic + - Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`. + + - We use `mkForbiddenSet` to compute `forbidden`. + + Remark: this method assumes `targets` are free variables. + + Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling + recursive declarations. +-/ +def getFVarSetToGeneralize (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do + let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId! + let mut r : NameSet := {} + for localDecl in (← getLCtx) do + unless forbidden.contains localDecl.fvarId do + unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do + if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then + r := r.insert localDecl.fvarId + s := s.insert localDecl.fvarId + return r + +def sortFVars (fvars : NameSet) : MetaM (Array FVarId) := do + let fvarIds := fvars.fold (init := #[]) fun s fvarId => s.push fvarId + let lctx ← getLCtx + return fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index + +def getFVarsToGeneralize (targets : Array Expr) : MetaM (Array FVarId) := do + let forbidden ← mkGeneralizationForbiddenSet targets + let s ← getFVarSetToGeneralize targets forbidden + sortFVars s + +end Lean.Meta \ No newline at end of file