feat: check pattern coverage in the grind_pattern command (#6474)
This PR adds pattern validation to the `grind_pattern` command. The new `checkCoverage` function will also be used to implement the attributes `@[grind_eq]`, `@[grind_fwd]`, and `@[grind_bwd]`.
This commit is contained in:
parent
3c326d771c
commit
24a8561ec4
3 changed files with 218 additions and 4 deletions
|
|
@ -9,7 +9,6 @@ import Lean.Meta.Tactic.Grind
|
|||
import Lean.Elab.Command
|
||||
import Lean.Elab.Tactic.Basic
|
||||
|
||||
|
||||
namespace Lean.Elab.Tactic
|
||||
open Meta
|
||||
|
||||
|
|
@ -20,6 +19,7 @@ def elabGrindPattern : CommandElab := fun stx => do
|
|||
| `(grind_pattern $thmName:ident => $terms,*) => do
|
||||
liftTermElabM do
|
||||
let declName ← resolveGlobalConstNoOverload thmName
|
||||
discard <| addTermInfo thmName (← mkConstWithLevelParams declName)
|
||||
let info ← getConstInfo declName
|
||||
forallTelescope info.type fun xs _ => do
|
||||
let patterns ← terms.getElems.mapM fun term => do
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
prelude
|
||||
import Lean.HeadIndex
|
||||
import Lean.Util.FoldConsts
|
||||
import Lean.Util.CollectFVars
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Meta.InferType
|
||||
|
||||
|
|
@ -153,19 +154,144 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
|
|||
args := args.set! i arg
|
||||
return mkAppN f args
|
||||
|
||||
def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex) := do
|
||||
def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do
|
||||
let (patterns, s) ← patterns.mapM go |>.run {}
|
||||
return (patterns, s.symbols.toList)
|
||||
return (patterns, s.symbols.toList, s.bvarsFound)
|
||||
|
||||
end NormalizePattern
|
||||
|
||||
/--
|
||||
Returns `true` if free variables in `type` are not in `thmVars` or are in `fvarsFound`.
|
||||
We use this function to check whether `type` is fully instantiated.
|
||||
-/
|
||||
private def checkTypeFVars (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (type : Expr) : Bool :=
|
||||
let typeFVars := (collectFVars {} type).fvarIds
|
||||
typeFVars.all fun fvarId => !thmVars.contains fvarId || fvarsFound.contains fvarId
|
||||
|
||||
/--
|
||||
Given an type class instance type `instType`, returns true if free variables in input parameters
|
||||
1- are not in `thmVars`, or
|
||||
2- are in `fvarsFound`.
|
||||
Remark: `fvarsFound` is a subset of `thmVars`
|
||||
-/
|
||||
private def canBeSynthesized (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (instType : Expr) : MetaM Bool := do
|
||||
forallTelescopeReducing instType fun xs type => type.withApp fun classFn classArgs => do
|
||||
for x in xs do
|
||||
unless checkTypeFVars thmVars fvarsFound (← inferType x) do return false
|
||||
forallBoundedTelescope (← inferType classFn) type.getAppNumArgs fun params _ => do
|
||||
for param in params, classArg in classArgs do
|
||||
let paramType ← inferType param
|
||||
if !paramType.isAppOf ``semiOutParam && !paramType.isAppOf ``outParam then
|
||||
unless checkTypeFVars thmVars fvarsFound classArg do
|
||||
return false
|
||||
return true
|
||||
|
||||
/--
|
||||
Auxiliary type for the `checkCoverage` function.
|
||||
-/
|
||||
inductive CheckCoverageResult where
|
||||
| /-- `checkCoverage` succeeded -/
|
||||
ok
|
||||
| /--
|
||||
`checkCoverage` failed because some of the theorem parameters are missing,
|
||||
`pos` contains their positions
|
||||
-/
|
||||
missing (pos : List Nat)
|
||||
|
||||
/--
|
||||
After we process a set of patterns, we obtain the set of de Bruijn indices in these patterns.
|
||||
We say they are pattern variables. This function checks whether the set of pattern variables is sufficient for
|
||||
instantiating the theorem with proof `thmProof`. The theorem has `numParams` parameters.
|
||||
The missing parameters:
|
||||
1- we may be able to infer them using type inference or type class synthesis, or
|
||||
2- they are propositions, and may become hypotheses of the instantiated theorem.
|
||||
|
||||
For type class instance parameters, we must check whether the free variables in class input parameters are available.
|
||||
-/
|
||||
private def checkCoverage (thmProof : Expr) (numParams : Nat) (bvarsFound : Std.HashSet Nat) : MetaM CheckCoverageResult := do
|
||||
if bvarsFound.size == numParams then return .ok
|
||||
forallBoundedTelescope (← inferType thmProof) numParams fun xs _ => do
|
||||
assert! numParams == xs.size
|
||||
let patternVars := bvarsFound.toList.map fun bidx => xs[numParams - bidx - 1]!.fvarId!
|
||||
-- `xs` as a `FVarIdSet`.
|
||||
let thmVars : FVarIdSet := RBTree.ofList <| xs.toList.map (·.fvarId!)
|
||||
-- Collect free variables occurring in `e`, and insert the ones that are in `thmVars` into `fvarsFound`
|
||||
let update (fvarsFound : FVarIdSet) (e : Expr) : FVarIdSet :=
|
||||
(collectFVars {} e).fvarIds.foldl (init := fvarsFound) fun s fvarId =>
|
||||
if thmVars.contains fvarId then s.insert fvarId else s
|
||||
-- Theorem variables found so far. We initialize with the variables occurring in patterns
|
||||
-- Remark: fvarsFound is a subset of thmVars
|
||||
let mut fvarsFound : FVarIdSet := RBTree.ofList patternVars
|
||||
for patternVar in patternVars do
|
||||
let type ← patternVar.getType
|
||||
fvarsFound := update fvarsFound type
|
||||
if fvarsFound.size == numParams then return .ok
|
||||
-- Now, we keep traversing remaining variables and collecting
|
||||
-- `processed` contains the variables we have already processed.
|
||||
let mut processed : FVarIdSet := RBTree.ofList patternVars
|
||||
let mut modified := false
|
||||
repeat
|
||||
modified := false
|
||||
for x in xs do
|
||||
let fvarId := x.fvarId!
|
||||
unless processed.contains fvarId do
|
||||
let xType ← inferType x
|
||||
if fvarsFound.contains fvarId then
|
||||
-- Collect free vars in `x`s type and mark as processed
|
||||
fvarsFound := update fvarsFound xType
|
||||
processed := processed.insert fvarId
|
||||
modified := true
|
||||
else if (← isProp xType) then
|
||||
-- If `x` is a proposition, and all theorem variables in `x`s type have already been found
|
||||
-- add it to `fvarsFound` and mark it as processed.
|
||||
if checkTypeFVars thmVars fvarsFound xType then
|
||||
fvarsFound := fvarsFound.insert fvarId
|
||||
processed := processed.insert fvarId
|
||||
modified := true
|
||||
else if (← fvarId.getDecl).binderInfo matches .instImplicit then
|
||||
-- If `x` is instance implicit, check whether
|
||||
-- we have found all free variables needed to synthesize instance
|
||||
if (← canBeSynthesized thmVars fvarsFound xType) then
|
||||
fvarsFound := fvarsFound.insert fvarId
|
||||
fvarsFound := update fvarsFound xType
|
||||
processed := processed.insert fvarId
|
||||
modified := true
|
||||
if fvarsFound.size == numParams then
|
||||
return .ok
|
||||
if !modified then
|
||||
break
|
||||
let mut pos := #[]
|
||||
for h : i in [:xs.size] do
|
||||
let fvarId := xs[i].fvarId!
|
||||
unless fvarsFound.contains fvarId do
|
||||
pos := pos.push i
|
||||
return .missing pos.toList
|
||||
|
||||
/--
|
||||
Given a theorem with proof `proof` and `numParams` parameters, returns a message
|
||||
containing the parameters at positions `paramPos`.
|
||||
-/
|
||||
private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : MetaM MessageData := do
|
||||
forallBoundedTelescope (← inferType proof) numParms fun xs _ => do
|
||||
let mut msg := m!""
|
||||
let mut first := true
|
||||
for h : i in [:xs.size] do
|
||||
if paramPos.contains i then
|
||||
let x := xs[i]
|
||||
if first then first := false else msg := msg ++ "\n"
|
||||
msg := msg ++ m!"{x} : {← inferType x}"
|
||||
addMessageContextFull msg
|
||||
|
||||
def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
|
||||
let .thmInfo info ← getConstInfo declName
|
||||
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
let proof := mkConst declName us
|
||||
let (patterns, symbols) ← NormalizePattern.main patterns
|
||||
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
|
||||
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
|
||||
if let .missing pos ← checkCoverage proof numParams bvarFound then
|
||||
let pats : MessageData := m!"{patterns.map ppPattern}"
|
||||
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
|
||||
theoremPatternsExt.add {
|
||||
proof, patterns, numParams, symbols
|
||||
origin := .decl declName
|
||||
|
|
|
|||
|
|
@ -26,3 +26,91 @@ error: `foo` is not a theorem, you cannot assign patterns to non-theorems for th
|
|||
-/
|
||||
#guard_msgs in
|
||||
grind_pattern foo => x + x
|
||||
|
||||
/--
|
||||
error: invalid pattern(s) for `Array.getElem_push_lt`
|
||||
[@Array.push #4 #3 #2]
|
||||
the following theorem parameters cannot be instantiated:
|
||||
i : Nat
|
||||
h : i < a.size
|
||||
---
|
||||
info: [grind.pattern] Array.getElem_push_lt: [@Array.push #4 #3 #2]
|
||||
-/
|
||||
#guard_msgs in
|
||||
grind_pattern Array.getElem_push_lt => (a.push x)
|
||||
|
||||
class Foo (α : Type) (β : outParam Type) where
|
||||
a : Unit
|
||||
|
||||
class Boo (α : Type) (β : Type) where
|
||||
b : β
|
||||
|
||||
def f [Foo α β] [Boo α β] (a : α) : (α × β) :=
|
||||
(a, Boo.b α)
|
||||
|
||||
instance [Foo α β] : Foo (List α) (Array β) where
|
||||
a := ()
|
||||
|
||||
instance [Boo α β] : Boo (List α) (Array β) where
|
||||
b := #[Boo.b α]
|
||||
|
||||
theorem fEq [Foo α β] [Boo α β] (a : List α) : (f a).1 = a := rfl
|
||||
|
||||
/-- info: [grind.pattern] fEq: [@f ? ? ? ? #0] -/
|
||||
#guard_msgs in
|
||||
grind_pattern fEq => f a
|
||||
|
||||
theorem fEq2 [Foo α β] [Boo α β] (a : List α) (_h : a.length > 5) : (f a).1 = a := rfl
|
||||
|
||||
/-- info: [grind.pattern] fEq2: [@f ? ? ? ? #1] -/
|
||||
#guard_msgs in
|
||||
grind_pattern fEq2 => f a
|
||||
|
||||
def g [Boo α β] (a : α) : (α × β) :=
|
||||
(a, Boo.b α)
|
||||
|
||||
theorem gEq [Boo α β] (a : List α) : (g (β := Array β) a).1 = a := rfl
|
||||
|
||||
/--
|
||||
error: invalid pattern(s) for `gEq`
|
||||
[@g ? ? ? #0]
|
||||
the following theorem parameters cannot be instantiated:
|
||||
β : Type
|
||||
inst✝ : Boo α β
|
||||
---
|
||||
info: [grind.pattern] gEq: [@g ? ? ? #0]
|
||||
-/
|
||||
#guard_msgs in
|
||||
grind_pattern gEq => g a
|
||||
|
||||
def plus (a : Nat) (b : Nat) := a + b
|
||||
|
||||
theorem hThm1 (h : b > 10) : plus a b + plus a c > 10 := by
|
||||
unfold plus; omega
|
||||
|
||||
/--
|
||||
error: invalid pattern(s) for `hThm1`
|
||||
[plus #2 #3]
|
||||
the following theorem parameters cannot be instantiated:
|
||||
c : Nat
|
||||
---
|
||||
info: [grind.pattern] hThm1: [plus #2 #3]
|
||||
-/
|
||||
#guard_msgs in
|
||||
grind_pattern hThm1 => plus a b
|
||||
|
||||
/--
|
||||
error: invalid pattern(s) for `hThm1`
|
||||
[plus #2 #1]
|
||||
the following theorem parameters cannot be instantiated:
|
||||
b : Nat
|
||||
h : b > 10
|
||||
---
|
||||
info: [grind.pattern] hThm1: [plus #2 #1]
|
||||
-/
|
||||
#guard_msgs in
|
||||
grind_pattern hThm1 => plus a c
|
||||
|
||||
/-- info: [grind.pattern] hThm1: [plus #2 #1, plus #2 #3] -/
|
||||
#guard_msgs in
|
||||
grind_pattern hThm1 => plus a c, plus a b
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue