refactor: add GeneralizeVars.lean
Helper methods for performing auto generalization.
This commit is contained in:
parent
2667744092
commit
555b978d67
3 changed files with 82 additions and 57 deletions
|
|
@ -10,6 +10,7 @@ import Lean.Meta.CollectMVars
|
|||
import Lean.Meta.Tactic.ElimInfo
|
||||
import Lean.Meta.Tactic.Induction
|
||||
import Lean.Meta.Tactic.Cases
|
||||
import Lean.Meta.GeneralizeVars
|
||||
import Lean.Elab.App
|
||||
import Lean.Elab.Tactic.ElabTerm
|
||||
import Lean.Elab.Tactic.Generalize
|
||||
|
|
@ -243,56 +244,6 @@ where
|
|||
|
||||
end ElimApp
|
||||
|
||||
/--
|
||||
Return a set of `FVarId`s containing `targets` and all variables they depend on.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
-/
|
||||
private partial def mkForbiddenSet (targets : Array Expr) : MetaM NameSet := do
|
||||
loop (targets.toList.map Expr.fvarId!) {}
|
||||
where
|
||||
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
|
||||
if let some val := localDecl.value? then
|
||||
s' := collectFVars s' (← instantiateMVars val)
|
||||
let mut todo := todo
|
||||
let mut s := s
|
||||
for fvarId in s'.fvarSet do
|
||||
unless s.contains fvarId do
|
||||
todo := fvarId :: todo
|
||||
s := s.insert fvarId
|
||||
return (todo, s)
|
||||
|
||||
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
|
||||
match todo with
|
||||
| [] => return s
|
||||
| fvarId::todo =>
|
||||
if s.contains fvarId then
|
||||
loop todo s
|
||||
else
|
||||
let (todo, s) ← visit fvarId todo <| s.insert fvarId
|
||||
loop todo s
|
||||
|
||||
/--
|
||||
Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
|
||||
Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling
|
||||
recursive declarations.
|
||||
-/
|
||||
private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
|
||||
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
|
||||
let mut r : NameSet := {}
|
||||
for localDecl in (← getLCtx) do
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do
|
||||
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
|
||||
r := r.insert localDecl.fvarId
|
||||
s := s.insert localDecl.fvarId
|
||||
return r
|
||||
|
||||
/-
|
||||
Recall that
|
||||
```
|
||||
|
|
@ -300,7 +251,7 @@ private def collectForwardDeps (targets : Array Expr) (forbidden : NameSet) : Me
|
|||
«induction» := leading_parser nonReservedSymbol "induction " >> majorPremise >> usingRec >> generalizingVars >> optional inductionAlts
|
||||
```
|
||||
`stx` is syntax for `induction`. -/
|
||||
private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
|
||||
private def getUserGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
|
||||
withRef stx do
|
||||
let generalizingStx := stx[3]
|
||||
if generalizingStx.isNone then
|
||||
|
|
@ -313,18 +264,16 @@ private def getGeneralizingFVarIds (stx : Syntax) : TacticM (Array FVarId) :=
|
|||
-- process `generalizingVars` subterm of induction Syntax `stx`.
|
||||
private def generalizeVars (mvarId : MVarId) (stx : Syntax) (targets : Array Expr) : TacticM (Nat × MVarId) :=
|
||||
withMVarContext mvarId do
|
||||
let userFVarIds ← getGeneralizingFVarIds stx
|
||||
let forbidden ← mkForbiddenSet targets
|
||||
let mut s ← collectForwardDeps targets forbidden
|
||||
let userFVarIds ← getUserGeneralizingFVarIds stx
|
||||
let forbidden ← mkGeneralizationForbiddenSet targets
|
||||
let mut s ← getFVarSetToGeneralize targets forbidden
|
||||
for userFVarId in userFVarIds do
|
||||
if forbidden.contains userFVarId then
|
||||
throwError "variable cannot be generalized because target depends on it{indentExpr (mkFVar userFVarId)}"
|
||||
if s.contains userFVarId then
|
||||
throwError "unnecessary 'generalizing' argument, variable '{mkFVar userFVarId}' is generalized automatically"
|
||||
s := s.insert userFVarId
|
||||
let fvarIds := s.fold (init := #[]) fun s fvarId => s.push fvarId
|
||||
let lctx ← getLCtx
|
||||
let fvarIds ← fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
|
||||
let fvarIds ← sortFVars s
|
||||
let (fvarIds, mvarId') ← Meta.revert mvarId fvarIds
|
||||
return (fvarIds.size, mvarId')
|
||||
|
||||
|
|
|
|||
|
|
@ -32,3 +32,4 @@ import Lean.Meta.SizeOf
|
|||
import Lean.Meta.Coe
|
||||
import Lean.Meta.SortLocalDecls
|
||||
import Lean.Meta.CollectFVars
|
||||
import Lean.Meta.GeneralizeVars
|
||||
|
|
|
|||
75
src/Lean/Meta/GeneralizeVars.lean
Normal file
75
src/Lean/Meta/GeneralizeVars.lean
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
/-
|
||||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Util.CollectFVars
|
||||
|
||||
namespace Lean.Meta
|
||||
|
||||
/--
|
||||
Return a set of `FVarId`s containing `targets` and all variables they depend on.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
-/
|
||||
partial def mkGeneralizationForbiddenSet (targets : Array Expr) : MetaM NameSet := do
|
||||
loop (targets.toList.map Expr.fvarId!) {}
|
||||
where
|
||||
visit (fvarId : FVarId) (todo : List FVarId) (s : NameSet) : MetaM (List FVarId × NameSet) := do
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
let mut s' := collectFVars {} (← instantiateMVars localDecl.type)
|
||||
if let some val := localDecl.value? then
|
||||
s' := collectFVars s' (← instantiateMVars val)
|
||||
let mut todo := todo
|
||||
let mut s := s
|
||||
for fvarId in s'.fvarSet do
|
||||
unless s.contains fvarId do
|
||||
todo := fvarId :: todo
|
||||
s := s.insert fvarId
|
||||
return (todo, s)
|
||||
|
||||
loop (todo : List FVarId) (s : NameSet) : MetaM NameSet := do
|
||||
match todo with
|
||||
| [] => return s
|
||||
| fvarId::todo =>
|
||||
if s.contains fvarId then
|
||||
loop todo s
|
||||
else
|
||||
let (todo, s) ← visit fvarId todo <| s.insert fvarId
|
||||
loop todo s
|
||||
|
||||
/--
|
||||
Collect variables to be generalized.
|
||||
It uses the following heuristic
|
||||
- Collect forward dependencies that are not in the forbidden set, and depend on some variable in `targets`.
|
||||
|
||||
- We use `mkForbiddenSet` to compute `forbidden`.
|
||||
|
||||
Remark: this method assumes `targets` are free variables.
|
||||
|
||||
Remark: we *not* collect instance implicit arguments nor auxiliary declarations for compiling
|
||||
recursive declarations.
|
||||
-/
|
||||
def getFVarSetToGeneralize (targets : Array Expr) (forbidden : NameSet) : MetaM NameSet := do
|
||||
let mut s : NameSet := targets.foldl (init := {}) fun s target => s.insert target.fvarId!
|
||||
let mut r : NameSet := {}
|
||||
for localDecl in (← getLCtx) do
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit do
|
||||
if (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => s.contains fvarId then
|
||||
r := r.insert localDecl.fvarId
|
||||
s := s.insert localDecl.fvarId
|
||||
return r
|
||||
|
||||
def sortFVars (fvars : NameSet) : MetaM (Array FVarId) := do
|
||||
let fvarIds := fvars.fold (init := #[]) fun s fvarId => s.push fvarId
|
||||
let lctx ← getLCtx
|
||||
return fvarIds.qsort fun x y => (lctx.get! x).index < (lctx.get! y).index
|
||||
|
||||
def getFVarsToGeneralize (targets : Array Expr) : MetaM (Array FVarId) := do
|
||||
let forbidden ← mkGeneralizationForbiddenSet targets
|
||||
let s ← getFVarSetToGeneralize targets forbidden
|
||||
sortFVars s
|
||||
|
||||
end Lean.Meta
|
||||
Loading…
Add table
Reference in a new issue