feat: new constraints in grind_pattern (#11391)

This PR implements new kinds of constraints for the `grind_pattern`
command. These constraints allow users to control theorem instantiation
in `grind`.
It requires a manual `update-stage0` because the change affects the
`.olean` format, and the PR fails without it.
This commit is contained in:
Leonardo de Moura 2025-11-26 21:13:14 -08:00 committed by GitHub
parent 490d714486
commit a4f9a793d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 12464 additions and 4628 deletions

View file

@ -30,6 +30,7 @@ declare_config_elab elabCutsatConfig Grind.CutsatConfig
declare_config_elab elabGrobnerConfig Grind.GrobnerConfig
open Command Term in
open Lean.Parser.Command.GrindCnstr in
@[builtin_command_elab Lean.Parser.Command.grindPattern]
def elabGrindPattern : CommandElab := fun stx => do
match stx with
@ -38,41 +39,92 @@ def elabGrindPattern : CommandElab := fun stx => do
| `(local grind_pattern $thmName:ident => $terms,* $[$cnstrs?:grindPatternCnstrs]?) => go thmName terms cnstrs? .local
| _ => throwUnsupportedSyntax
where
findLHS (xs : Array Expr) (lhs : Syntax) : TermElabM (LocalDecl × Nat) := do
let lhsId := lhs.getId
let mut i := 0
for x in xs do
let xDecl ← x.fvarId!.getDecl
if xDecl.userName == lhsId then
return (xDecl, xs.size - i - 1)
i := i + 1
throwErrorAt lhs "invalid constraint, `{lhsId}` is not local variable of the theorem"
elabCnstrRHS (xs : Array Expr) (rhs : Syntax) (expectedType : Expr) : TermElabM Grind.CnstrRHS := do
/-
**Note**: We need better sanity checking here.
We must check whether the type of `rhs` is type correct with respect to
an arbitrary instantiation of `xs`. That is, we should use meta-variables
in the check. It is incorrect to use `xDecl.type`. For example, suppose the
type of `xDecl` is `α → β` where `α` and `β` are variables in `xs` occurring before
`xDecl`, and `rhsExpr` is `some : ?m → ?m`. The types `α → β =?= ?m → ?m` are
not definitionally equal, but `?α → ?β =?= ?m → ?m` are.
-/
let rhsExpr ← Term.elabTerm rhs expectedType
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
let rhsExpr ← instantiateMVars rhsExpr
if rhsExpr.hasSyntheticSorry then
throwErrorAt rhs "invalid constraint, rhs contains a synthetic `sorry`"
let rhsExpr := rhsExpr.eta
let { paramNames := levelNames, mvars, expr := rhs } ← abstractMVars rhsExpr
let numMVars := mvars.size
let rhs := rhs.abstract xs
return { levelNames, numMVars, expr := rhs }
elabProp (xs : Array Expr) (term : Syntax) : TermElabM Expr := do
let e ← Term.elabTermAndSynthesize term (Expr.sort 0)
let e ← instantiateMVars e
if e.hasSyntheticSorry then
throwErrorAt term "invalid proposition, it contains a synthetic `sorry`"
if e.hasMVar then
throwErrorAt term "invalid proposition, it contains metavariables{indentExpr e}"
return e.abstract xs
elabNotDefEq (xs : Array Expr) (lhs rhs : Syntax) : TermElabM Grind.EMatchTheoremConstraint := do
let (localDecl, lhsBVarIdx) ← findLHS xs lhs
let rhs ← elabCnstrRHS xs rhs localDecl.type
return .notDefEq lhsBVarIdx rhs
elabDefEq (xs : Array Expr) (lhs rhs : Syntax) : TermElabM Grind.EMatchTheoremConstraint := do
let (localDecl, lhsBVarIdx) ← findLHS xs lhs
let rhs ← elabCnstrRHS xs rhs localDecl.type
return .defEq lhsBVarIdx rhs
elabCnstrs (xs : Array Expr) (cnstrs? : Option (TSyntax ``Parser.Command.grindPatternCnstrs))
: TermElabM (List (Grind.EMatchTheoremConstraint)) := do
let some cnstrs := cnstrs? | return []
let cnstrs := cnstrs.raw[1].getArgs
cnstrs.toList.mapM fun cnstr => do
-- **Note**: Hack because syntax matching is not working. Fix after another update stage0
let lhs := cnstr[0]
let rhs := cnstr[2]
let lhsId := lhs.getId
let mut i := 0
for x in xs do
let xDecl ← x.fvarId!.getDecl
if xDecl.userName == lhsId then
let bvarIdx := xs.size - i - 1
/-
**Note**: We need better sanity checking here.
We must check whether the type of `rhs` is type correct with respect to
an arbitrary instantiation of `xs`. That is, we should use meta-variables
in the check. It is incorrect to use `xDecl.type`. For example, suppose the
type of `xDecl` is `α → β` where `α` and `β` are variables in `xs` occurring before
`xDecl`, and `rhsExpr` is `some : ?m → ?m`. The types `α → β =?= ?m → ?m` are
not definitionally equal, but `?α → ?β =?= ?m → ?m` are.
-/
let rhsExpr ← Term.elabTerm rhs xDecl.type
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
let rhsExpr ← instantiateMVars rhsExpr
if rhsExpr.hasSyntheticSorry then
throwErrorAt rhs "invalid constraint, rhs contains a synthetic `sorry`"
let rhsExpr := rhsExpr.eta
let { paramNames := levelNames, mvars, expr := rhs } ← abstractMVars rhsExpr
let numMVars := mvars.size
let rhs := rhs.abstract xs
return { bvarIdx, levelNames, numMVars, rhs }
i := i + 1
throwErrorAt lhs "invalid constraint, `{lhsId}` is not local variable of the theorem"
let kind := cnstr.getKind
if kind == ``notDefEq then
elabNotDefEq xs cnstr[0] cnstr[2]
else if kind == ``defEq then
elabDefEq xs cnstr[0] cnstr[2]
else if kind == ``genLt then
let (_, lhs) ← findLHS xs cnstr[1]
return .genLt lhs cnstr[3].toNat
else if kind == ``sizeLt then
let (_, lhs) ← findLHS xs cnstr[1]
return .sizeLt lhs cnstr[3].toNat
else if kind == ``depthLt then
let (_, lhs) ← findLHS xs cnstr[1]
return .depthLt lhs cnstr[3].toNat
else if kind == ``maxInsts then
return .maxInsts cnstr[1].toNat
else if kind == ``isValue then
let (_, lhs) ← findLHS xs cnstr[1]
return .isValue lhs false
else if kind == ``isStrictValue then
let (_, lhs) ← findLHS xs cnstr[1]
return .isValue lhs true
else if kind == ``isGround then
let (_, lhs) ← findLHS xs cnstr[1]
return .isGround lhs
else if kind == ``Parser.Command.GrindCnstr.check then
return .check (← elabProp xs cnstr[1])
else if kind == ``Parser.Command.GrindCnstr.guard then
return .guard (← elabProp xs cnstr[1])
else
throwErrorAt cnstr "unexpected constraint"
go (thmName : TSyntax `ident) (terms : Syntax.TSepArray `term ",")
(cnstrs? : Option (TSyntax ``Parser.Command.grindPatternCnstrs))

View file

@ -637,6 +637,28 @@ private abbrev withFreshNGen (x : M α) : M α := do
finally
setNGen ngen
/--
Checks constraints of the form `lhs =/= rhs`.
-/
private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
unless lhs < args.size do
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
let lhs := args[args.size - lhs - 1]!
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
let rhsExpr := rhs.expr.instantiateRev args
let rhsExpr := rhsExpr.instantiateLevelParams levelParams us
withNewMCtxDepth do
/-
**Note**: Recall that we have abstracted metavariables occurring in `rhs` after we elaborated it.
So, we must "recreate" them.
-/
let us ← rhs.levelNames.mapM fun _ => mkFreshLevelMVar
let rhsExpr := rhsExpr.instantiateLevelParamsArray rhs.levelNames us
let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars)
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
let defEq ← isDefEqGuarded lhs rhsExpr
return !defEq
/--
Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`.
Example:
@ -650,29 +672,15 @@ In the example above, a `map_map` instance should be added to the logical contex
Remark: `proof` is used to extract the universe parameters in the proof.
-/
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : MetaM Bool := do
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do
if thm.cnstrs.isEmpty then return true
/- **Note**: Only top-level theorems have constraints. -/
let .const declName us := proof | return true
let info ← getConstInfo declName
thm.cnstrs.allM fun cnstr => do
unless cnstr.bvarIdx < args.size do
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
let lhs := args[args.size - cnstr.bvarIdx - 1]!
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
let rhs := cnstr.rhs.instantiateRev args
let rhs := rhs.instantiateLevelParams info.levelParams us
withNewMCtxDepth do
/-
**Note**: Recall that we have abstracted metavariables occurring in `rhs` after we elaborated it.
So, we must "recreate" them.
-/
let us ← cnstr.levelNames.mapM fun _ => mkFreshLevelMVar
let rhs := rhs.instantiateLevelParamsArray cnstr.levelNames us
let (_, _, rhs) ← lambdaMetaTelescope rhs (some cnstr.numMVars)
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
let defEq ← isDefEqGuarded lhs rhs
return !defEq
match cnstr with
| .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs
| _ => throwError "NIY"
/--
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.

View file

@ -344,20 +344,68 @@ private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String
| .default _ => "failed to find patterns"
| .user => unreachable!
/--
Grind patterns may have constraints of the form `lhs =/= rhs` associated with them.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
equal to a term `t` assigned to `lhs`.
-/
structure EMatchTheoremConstraint where
/-- `lhs` -/
bvarIdx : Nat
structure CnstrRHS where
/-- Abstracted universe level param names in the `rhs` -/
levelNames : Array Name
/-- Number of abstracted metavariable in the `rhs` -/
numMVars : Nat
/-- The actual `rhs`. -/
rhs : Expr
expr : Expr
deriving Inhabited, BEq, Repr
/--
Grind patterns may have constraints associated with them.
-/
inductive EMatchTheoremConstraint where
| /--
A constraint of the form `lhs =/= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
equal to a term `t` assigned to `lhs`. -/
notDefEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `lhs =?= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
equal to a term `t` assigned to `lhs`. -/
defEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
The size is computed ignoring implicit terms, but sharing is not taken into account.
-/
sizeLt (lhs : Nat) (n : Nat)
| /--
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
The depth is computed in constant time using the `approxDepth` field attached to expressions.
-/
depthLt (lhs : Nat) (n : Nat)
| /--
Instantiates the theorem only if its generation is less than `n`
-/
genLt (lhs : Nat) (n : Nat)
| /--
Constraints of the form `is_ground x`. Instantiates the theorem only if
`x` is ground term.
-/
isGround (bvarIdx : Nat)
| /--
Constraints of the form `is_value x` and `is_strict_value x`.
A value is defined as
- A constructor fully applied to value arguments.
- A literal: numerals, strings, etc.
- A lambda. In the strict case, lambdas are not considered.
-/
isValue (bvarIdx : Nat) (strict : Bool)
| /--
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
-/
maxInsts (n : Nat)
| /--
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
-/
guard (e : Expr)
| /--
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
-/
check (e : Expr)
deriving Inhabited, Repr, BEq
/-- A theorem for heuristic instantiation based on E-matching. -/

View file

@ -9,7 +9,26 @@ public import Lean.Parser.Command
public section
namespace Lean.Parser.Command
def grindPatternCnstr : Parser := leading_parser ident >> " =/= " >> checkColGe "irrelevant" >> termParser >> optional ";"
namespace GrindCnstr
def isValue := leading_parser nonReservedSymbol "is_value " >> ident >> optional ";"
def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> ident >> optional ";"
def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";"
def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";"
def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";"
def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";"
def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";"
def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";"
def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";"
def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";"
def defEq := leading_parser atomic (ident >> " =?= ") >> checkColGe "irrelevant" >> termParser >> optional ";"
end GrindCnstr
open GrindCnstr in
def grindPatternCnstr : Parser :=
isValue <|> isStrictValue <|> isGround <|> sizeLt <|> depthLt <|> genLt <|> maxInsts
<|> guard <|> check <|> notDefEq <|> defEq
def grindPatternCnstrs : Parser := leading_parser "where " >> many1Indent (ppLine >> grindPatternCnstr)

File diff suppressed because it is too large Load diff

View file

@ -3201,21 +3201,21 @@ goto block_34;
block_67:
{
lean_object* x_63; lean_object* x_64; lean_object* x_65; lean_object* x_66;
x_63 = l_Lean_PersistentArray_push___redArg(x_56, x_48);
x_64 = l_Lean_PersistentArray_push___redArg(x_63, x_52);
x_63 = l_Lean_PersistentArray_push___redArg(x_57, x_48);
x_64 = l_Lean_PersistentArray_push___redArg(x_63, x_51);
x_65 = lean_alloc_ctor(0, 12, 0);
lean_ctor_set(x_65, 0, x_58);
lean_ctor_set(x_65, 1, x_57);
lean_ctor_set(x_65, 0, x_60);
lean_ctor_set(x_65, 1, x_55);
lean_ctor_set(x_65, 2, x_61);
lean_ctor_set(x_65, 3, x_53);
lean_ctor_set(x_65, 4, x_55);
lean_ctor_set(x_65, 3, x_58);
lean_ctor_set(x_65, 4, x_59);
lean_ctor_set(x_65, 5, x_64);
lean_ctor_set(x_65, 6, x_59);
lean_ctor_set(x_65, 7, x_60);
lean_ctor_set(x_65, 8, x_51);
lean_ctor_set(x_65, 9, x_49);
lean_ctor_set(x_65, 10, x_50);
lean_ctor_set(x_65, 11, x_54);
lean_ctor_set(x_65, 6, x_53);
lean_ctor_set(x_65, 7, x_52);
lean_ctor_set(x_65, 8, x_56);
lean_ctor_set(x_65, 9, x_50);
lean_ctor_set(x_65, 10, x_54);
lean_ctor_set(x_65, 11, x_49);
x_66 = lean_alloc_ctor(0, 1, 0);
lean_ctor_set(x_66, 0, x_65);
return x_66;
@ -3750,18 +3750,18 @@ x_197 = lean_ctor_get(x_196, 0);
lean_inc(x_197);
lean_dec_ref(x_196);
x_48 = x_194;
x_49 = x_190;
x_50 = x_191;
x_51 = x_189;
x_52 = x_197;
x_53 = x_184;
x_54 = x_192;
x_55 = x_185;
x_56 = x_186;
x_57 = x_182;
x_58 = x_181;
x_59 = x_187;
x_60 = x_188;
x_49 = x_192;
x_50 = x_190;
x_51 = x_197;
x_52 = x_188;
x_53 = x_187;
x_54 = x_191;
x_55 = x_182;
x_56 = x_189;
x_57 = x_186;
x_58 = x_184;
x_59 = x_185;
x_60 = x_181;
x_61 = x_183;
x_62 = lean_box(0);
goto block_67;
@ -3787,18 +3787,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec(x_3);
x_48 = x_194;
x_49 = x_190;
x_50 = x_191;
x_51 = x_189;
x_52 = x_198;
x_53 = x_184;
x_54 = x_192;
x_55 = x_185;
x_56 = x_186;
x_57 = x_182;
x_58 = x_181;
x_59 = x_187;
x_60 = x_188;
x_49 = x_192;
x_50 = x_190;
x_51 = x_198;
x_52 = x_188;
x_53 = x_187;
x_54 = x_191;
x_55 = x_182;
x_56 = x_189;
x_57 = x_186;
x_58 = x_184;
x_59 = x_185;
x_60 = x_181;
x_61 = x_183;
x_62 = lean_box(0);
goto block_67;
@ -3821,18 +3821,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec(x_3);
x_48 = x_194;
x_49 = x_190;
x_50 = x_191;
x_51 = x_189;
x_52 = x_198;
x_53 = x_184;
x_54 = x_192;
x_55 = x_185;
x_56 = x_186;
x_57 = x_182;
x_58 = x_181;
x_59 = x_187;
x_60 = x_188;
x_49 = x_192;
x_50 = x_190;
x_51 = x_198;
x_52 = x_188;
x_53 = x_187;
x_54 = x_191;
x_55 = x_182;
x_56 = x_189;
x_57 = x_186;
x_58 = x_184;
x_59 = x_185;
x_60 = x_181;
x_61 = x_183;
x_62 = lean_box(0);
goto block_67;
@ -3847,18 +3847,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec_ref(x_207);
x_48 = x_194;
x_49 = x_190;
x_50 = x_191;
x_51 = x_189;
x_52 = x_198;
x_53 = x_184;
x_54 = x_192;
x_55 = x_185;
x_56 = x_186;
x_57 = x_182;
x_58 = x_181;
x_59 = x_187;
x_60 = x_188;
x_49 = x_192;
x_50 = x_190;
x_51 = x_198;
x_52 = x_188;
x_53 = x_187;
x_54 = x_191;
x_55 = x_182;
x_56 = x_189;
x_57 = x_186;
x_58 = x_184;
x_59 = x_185;
x_60 = x_181;
x_61 = x_183;
x_62 = lean_box(0);
goto block_67;
@ -4041,18 +4041,18 @@ x_236 = lean_ctor_get(x_235, 0);
lean_inc(x_236);
lean_dec_ref(x_235);
x_48 = x_233;
x_49 = x_228;
x_50 = x_229;
x_51 = x_227;
x_52 = x_236;
x_53 = x_222;
x_54 = x_230;
x_55 = x_223;
x_56 = x_224;
x_57 = x_220;
x_58 = x_219;
x_59 = x_225;
x_60 = x_226;
x_49 = x_230;
x_50 = x_228;
x_51 = x_236;
x_52 = x_226;
x_53 = x_225;
x_54 = x_229;
x_55 = x_220;
x_56 = x_227;
x_57 = x_224;
x_58 = x_222;
x_59 = x_223;
x_60 = x_219;
x_61 = x_221;
x_62 = lean_box(0);
goto block_67;
@ -4078,18 +4078,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec(x_3);
x_48 = x_233;
x_49 = x_228;
x_50 = x_229;
x_51 = x_227;
x_52 = x_237;
x_53 = x_222;
x_54 = x_230;
x_55 = x_223;
x_56 = x_224;
x_57 = x_220;
x_58 = x_219;
x_59 = x_225;
x_60 = x_226;
x_49 = x_230;
x_50 = x_228;
x_51 = x_237;
x_52 = x_226;
x_53 = x_225;
x_54 = x_229;
x_55 = x_220;
x_56 = x_227;
x_57 = x_224;
x_58 = x_222;
x_59 = x_223;
x_60 = x_219;
x_61 = x_221;
x_62 = lean_box(0);
goto block_67;
@ -4112,18 +4112,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec(x_3);
x_48 = x_233;
x_49 = x_228;
x_50 = x_229;
x_51 = x_227;
x_52 = x_237;
x_53 = x_222;
x_54 = x_230;
x_55 = x_223;
x_56 = x_224;
x_57 = x_220;
x_58 = x_219;
x_59 = x_225;
x_60 = x_226;
x_49 = x_230;
x_50 = x_228;
x_51 = x_237;
x_52 = x_226;
x_53 = x_225;
x_54 = x_229;
x_55 = x_220;
x_56 = x_227;
x_57 = x_224;
x_58 = x_222;
x_59 = x_223;
x_60 = x_219;
x_61 = x_221;
x_62 = lean_box(0);
goto block_67;
@ -4138,18 +4138,18 @@ lean_dec(x_174);
lean_dec_ref(x_173);
lean_dec_ref(x_246);
x_48 = x_233;
x_49 = x_228;
x_50 = x_229;
x_51 = x_227;
x_52 = x_237;
x_53 = x_222;
x_54 = x_230;
x_55 = x_223;
x_56 = x_224;
x_57 = x_220;
x_58 = x_219;
x_59 = x_225;
x_60 = x_226;
x_49 = x_230;
x_50 = x_228;
x_51 = x_237;
x_52 = x_226;
x_53 = x_225;
x_54 = x_229;
x_55 = x_220;
x_56 = x_227;
x_57 = x_224;
x_58 = x_222;
x_59 = x_223;
x_60 = x_219;
x_61 = x_221;
x_62 = lean_box(0);
goto block_67;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff