refactor: add GeneralizeVars.lean

Helper methods for performing auto generalization.
This commit is contained in:
Leonardo de Moura 2021-04-15 21:37:48 -07:00
parent 2667744092
commit 555b978d67
3 changed files with 82 additions and 57 deletions

View file

@ -10,6 +10,7 @@ import Lean.Meta.CollectMVars
import Lean.Meta.Tactic.ElimInfo
import Lean.Meta.Tactic.Induction
import Lean.Meta.Tactic.Cases
import Lean.Meta.GeneralizeVars
import Lean.Elab.App
import Lean.Elab.Tactic.ElabTerm
import Lean.Elab.Tactic.Generalize
@ -243,56 +244,6 @@ where
end ElimApp
/--
Return a set of `FVarId`s containing `targets` and all variables they depend on.
Remark: this method assumes `targets` are free variables.
-/
private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do
loop (targets.toList.map Expr.fvarId!) {}
where
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
let localDecl ← getLocalDecl fvarId
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
if let some val := localDecl.value? then
s' := collectFVars s' (← instantiateMVars val)
let mut todo := todo
let mut s := s
for fvarId in s'.fvarSet do
unless s.contains fvarId do
todo := fvarId :: todo
s := s.insert fvarId
return (todo, s)
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
match todo with
| [] => return s
| fvarId::todo =>
if s.contains fvarId then
loop todo s
else
let (todo, s) ← visit fvarId todo <| s.insert fvarId
loop todo s
/--
Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
Remark: this method assumes `targets` are free variables.
Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling
recursive declarations.
-/
private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
let mut r : NameSet := {}
for localDecl in (← getLCtx) do
unless forbidden.contains localDecl.fvarId do
unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
r := r.insert localDecl.fvarId
s := s.insert localDecl.fvarId
return r
/-
Recall that
```
@ -300,7 +251,7 @@ private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : Me
«induction» := leading_parser nonReservedSymbol "induction " >> majorPremise >> usingRec >> generalizingVars >> optional inductionAlts
```
`stx` is syntax for `induction`. -/
private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
private def getUserGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
withRef stx do
let generalizingStx := stx[3]
if generalizingStx.isNone then
@ -313,18 +264,16 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
-- process `generalizingVars` subterm of induction Syntax `stx`.
private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) :=
withMVarContext mvarId do
let userFVarIds ← getGeneralizingFVarIds stx
let forbidden ← mkForbiddenSet targets
let mut s ← collectForwardDeps targets forbidden
let userFVarIds ← getUserGeneralizingFVarIds stx
let forbidden ← mkGeneralizationForbiddenSet targets
let mut s ← getFVarSetToGeneralize targets forbidden
for userFVarId in userFVarIds do
if forbidden.contains userFVarId then
throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}"
if s.contains userFVarId then
throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically"
s := s.insert userFVarId
let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId
let lctx ← getLCtx
let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
let fvarIds ← sortFVars s
let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds
return (fvarIds.size, mvarId')

View file

@ -32,3 +32,4 @@ import Lean.Meta.SizeOf
import Lean.Meta.Coe
import Lean.Meta.SortLocalDecls
import Lean.Meta.CollectFVars
import Lean.Meta.GeneralizeVars

View file

@ -0,0 +1,75 @@
/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Basic
import Lean.Util.CollectFVars
namespace Lean.Meta
/--
Return a set of `FVarId`s containing `targets` and all variables they depend on.
Remark: this method assumes `targets` are free variables.
-/
partial def mkGeneralizationForbiddenSet (targets : Array Expr) : MetaM NameSet := do
loop (targets.toList.map Expr.fvarId!) {}
where
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
let localDecl ← getLocalDecl fvarId
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
if let some val := localDecl.value? then
s' := collectFVars s' (← instantiateMVars val)
let mut todo := todo
let mut s := s
for fvarId in s'.fvarSet do
unless s.contains fvarId do
todo := fvarId :: todo
s := s.insert fvarId
return (todo, s)
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
match todo with
| [] => return s
| fvarId::todo =>
if s.contains fvarId then
loop todo s
else
let (todo, s) ← visit fvarId todo <| s.insert fvarId
loop todo s
/--
Collect variables to be generalized.
It uses the following heuristic
- Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
- We use `mkForbiddenSet` to compute `forbidden`.
Remark: this method assumes `targets` are free variables.
Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling
recursive declarations.
-/
def getFVarSetToGeneralize (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
let mut r : NameSet := {}
for localDecl in (← getLCtx) do
unless forbidden.contains localDecl.fvarId do
unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
r := r.insert localDecl.fvarId
s := s.insert localDecl.fvarId
return r
def sortFVars (fvars : NameSet) : MetaM (Array FVarId) := do
let fvarIds := fvars.fold (init := #[]) fun s fvarId => s.push fvarId
let lctx ← getLCtx
return fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
def getFVarsToGeneralize (targets : Array Expr) : MetaM (Array FVarId) := do
let forbidden ← mkGeneralizationForbiddenSet targets
let s ← getFVarSetToGeneralize targets forbidden
sortFVars s
end Lean.Meta