From e56351da7ae15673a5bc2b83e7fb70e1fbd988ca Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 11 Jan 2026 18:25:26 -0800 Subject: [PATCH] fix: pattern unification/matching in `Sym` (#11976) This PR adds missing type checking for pattern variables during pattern matching/unification to prevent incorrect matches. Previously, the pattern matcher could incorrectly match expressions even when pattern variable types were incompatible with the matched subterm types. For example, a pattern like `x` where `x : BitVec 0` could match any term, ignoring the specific type constraint on `x`. This PR introduces a two-phase type checking approach: 1. **Static analysis** (`mkCheckTypeMask`): Identifies which pattern variables require type checking based on their syntactic position. Variables that appear only as arguments to function applications skip checking (the application structure already constrains their types), while variables in function position, binder contexts, or standalone positions must be checked. 2. **Runtime validation**: During matching, when a pattern variable is assigned, its type is checked against the matched subterm's type if flagged by the mask. Checking uses `withReducible` to balance soundness and performance. 
The PR also adds helper functions for debugging (`Sym.mkMethods`, `Sym.simpWith`, `Sym.simpGoal`) and fixes a minor issue where `Theorem.rewrite` could return `.step` with identical expressions instead of `.rfl`. Body: --- src/Lean/Meta/Sym/Pattern.lean | 191 +++++++++++++++++++++++----- src/Lean/Meta/Sym/Simp.lean | 1 + src/Lean/Meta/Sym/Simp/Debug.lean | 46 +++++++ src/Lean/Meta/Sym/Simp/Rewrite.lean | 5 +- src/Lean/Meta/Sym/Simp/SimpM.lean | 2 +- tests/lean/run/sym_simp_1.lean | 33 +++++ 6 files changed, 242 insertions(+), 36 deletions(-) create mode 100644 src/Lean/Meta/Sym/Simp/Debug.lean create mode 100644 tests/lean/run/sym_simp_1.lean diff --git a/src/Lean/Meta/Sym/Pattern.lean b/src/Lean/Meta/Sym/Pattern.lean index 7ef625ca7a..ee47849805 100644 --- a/src/Lean/Meta/Sym/Pattern.lean +++ b/src/Lean/Meta/Sym/Pattern.lean @@ -42,6 +42,10 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro - `instantiateRevS` ensures maximal sharing of result expressions -/ +/-- Helper function for checking whether types `α` and `β` are definitionally equal during unification/matching. -/ +def isDefEqTypes (α β : Expr) : MetaM Bool := do + withReducible <| isDefEq α β + /-- Collects `ProofInstInfo` for all function symbols occurring in `pattern`. @@ -56,11 +60,18 @@ def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo return fnInfos public structure Pattern where - levelParams : List Name - varTypes : Array Expr - isInstance : Array Bool - pattern : Expr - fnInfos : AssocList Name ProofInstInfo + levelParams : List Name + varTypes : Array Expr + isInstance : Array Bool + pattern : Expr + fnInfos : AssocList Name ProofInstInfo + /-- + If `checkTypeMask? = some mask`, then we must check the type of pattern variable `i` + if `mask[i]` is true. + Moreover `mask.size == varTypes.size`. + See `mkCheckTypeMask` + -/ + checkTypeMask? 
: Option (Array Bool) deriving Inhabited def uvarPrefix : Name := `_uvar @@ -79,6 +90,65 @@ def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do let type ← preprocessType type return (levelParams, type) +/-- +Creates a mask indicating which pattern variables require type checking during matching. + +When matching a pattern against a target expression, we must ensure that pattern variable +assignments are type-correct. However, checking types for every variable is expensive. +This function identifies which variables actually need type checking. + +**Key insight**: A pattern variable appearing as an argument to a function application +does not need its type checked separately, because the type information is already +encoded in the application structure, and we assume the input is type correct. + +**Variables that need type checking**: +- Variables in function position: `f x` where `f` is a pattern variable +- Variables in binder domains or bodies: `∀ x : α, β` or `fun x : α => b` +- Variables appearing alone (not as part of any application) + +**Variables that skip type checking**: +- Variables appearing only as arguments to applications: in `f x`, the variable `x` + does not need checking because the type of `f` constrains the type of `x` + +**Examples**: +- `bv0_eq (x : BitVec 0) : x = 0`: pattern is just `x`, must check type to ensure `BitVec 0` +- `forall_true : (∀ _ : α, True) = True`: `α` appears in binder domain, must check +- `Nat.add_zero (x : Nat) : x + 0 = x`: `x` is argument to `HAdd.hAdd`, no check needed + +**Note**: This analysis is conservative. It may mark some variables for checking even when +the type information is redundant (already determined by other constraints). This is +harmless—just extra work, not incorrect behavior. + +Returns an array of booleans parallel to the pattern's `varTypes`, where `true` indicates +the variable's type must be checked against the matched subterm's type. 
+-/ +def mkCheckTypeMask (pattern : Expr) (numPatternVars : Nat) : Array Bool := + let mask := Array.replicate numPatternVars false + go pattern 0 false mask +where + go (e : Expr) (offset : Nat) (isArg : Bool) : Array Bool → Array Bool := + match e with + | .app f a => go f offset isArg ∘ go a offset true + | .letE .. => unreachable! -- We zeta-reduce at `preprocessType` + | .const .. | .fvar _ | .sort _ | .mvar _ | .lit _ => id + | .mdata _ b => go b offset isArg + | .proj .. => id -- Should not occur in patterns + | .forallE _ d b _ + | .lam _ d b _ => go d offset false ∘ go b (offset+1) false + | .bvar idx => fun mask => + if idx >= offset && !isArg then + let idx := idx - offset + mask.set! (mask.size - idx - 1) true + else + mask + +def mkPatternCore (levelParams : List Name) (varTypes : Array Expr) (isInstance : Array Bool) + (pattern : Expr) : MetaM Pattern := do + let fnInfos ← mkProofInstInfoMapFor pattern + let checkTypeMask := mkCheckTypeMask pattern varTypes.size + let checkTypeMask? := if checkTypeMask.all (· == false) then none else some checkTypeMask + return { levelParams, varTypes, isInstance, pattern, fnInfos, checkTypeMask? } + /-- Creates a `Pattern` from the type of a theorem. @@ -100,9 +170,7 @@ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : Met if i < num then if let .forallE _ d b _ := type then return (← go (i+1) b (varTypes.push d) (isInstance.push (isClass? (← getEnv) d).isSome)) - let pattern := type - let fnInfos ← mkProofInstInfoMapFor pattern - return { levelParams, varTypes, isInstance, pattern, fnInfos } + mkPatternCore levelParams varTypes isInstance type go 0 type #[] #[] /-- @@ -123,9 +191,8 @@ public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do return (← go b (varTypes.push d) (isInstance.push (isClass? 
(← getEnv) d).isSome)) else let_expr Eq _ lhs rhs := type | throwError "resulting type for `{.ofConstName declName}` is not an equality" - let pattern := lhs - let fnInfos ← mkProofInstInfoMapFor pattern - return ({ levelParams, varTypes, isInstance, pattern, fnInfos }, rhs) + let pattern ← mkPatternCore levelParams varTypes isInstance lhs + return (pattern, rhs) go type #[] #[] structure UnifyM.Context where @@ -139,6 +206,11 @@ structure UnifyM.State where ePending : Array (Expr × Expr) := #[] uPending : Array (Level × Level) := #[] iPending : Array (Expr × Expr) := #[] + /-- + Contains the index of the pattern variables that we must check whether its type + matches the type of the value assigned to it. + -/ + tPending : Array Nat := #[] us : List Level := [] args : Array Expr := #[] @@ -153,6 +225,14 @@ def pushLevelPending (u : Level) (v : Level) : UnifyM Unit := def pushInstPending (p : Expr) (e : Expr) : UnifyM Unit := modify fun s => { s with iPending := s.iPending.push (p, e) } +/-- +Mark pattern variable `i` for type checking. That is, at the end of phase 1 +we must check whether the type of this pattern variable is compatible with the type of +the value assigned to it. +-/ +def pushCheckTypePending (i : Nat) : UnifyM Unit := + modify fun s => { s with tPending := s.tPending.push i } + def assignExprIfUnassigned (bidx : Nat) (e : Expr) : UnifyM Unit := do let s ← get let i := s.eAssignment.size - bidx - 1 @@ -169,6 +249,8 @@ def assignExpr (bidx : Nat) (e : Expr) : UnifyM Bool := do return true else modify fun s => { s with eAssignment := s.eAssignment.set! i (some e) } + if (← read).pattern.checkTypeMask?.isSome then + pushCheckTypePending i return true def assignLevel (uidx : Nat) (u : Level) : UnifyM Bool := do @@ -369,6 +451,11 @@ structure DefEqM.Context where If `unify` is `false`, it contains which variables can be assigned. 
-/ mvarsNew : Array MVarId := #[] + /-- + If a metavariable is in this collection, when we perform the assignment `?m := v`, + we must check whether their types are compatible. + -/ + mvarsToCheckType : Array MVarId := #[] abbrev DefEqM := ReaderT DefEqM.Context SymM @@ -481,6 +568,12 @@ def mayAssign (t s : Expr) : SymM Bool := do let tMaxFVarDecl ← tMaxFVarId.getDecl return tMaxFVarDecl.index ≥ sMaxFVarDecl.index +@[inline] def whenUndefDo (x : DefEqM LBool) (k : DefEqM Bool) : DefEqM Bool := do + match (← x) with + | .true => return true + | .false => return false + | .undef => k + /-- Attempts to solve a unification constraint `t =?= s` where `t` has the form `?m a₁ ... aₙ` and satisfies the Miller pattern condition (all `aᵢ` are distinct, newly-introduced free variables). @@ -495,17 +588,20 @@ The `tFn` parameter must equal `t.getAppFn` (enforced by the proof argument). Remark: `t` may be of the form `?m`. -/ -def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM Bool := do - let .mvar mvarId := tFn | return false - if !(← isAssignableMVar mvarId) then return false - if !(← isMillerPatternArgs t) then return false +def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM LBool := do + let .mvar mvarId := tFn | return .undef + if !(← isAssignableMVar mvarId) then return .undef + if !(← isMillerPatternArgs t) then return .undef let s ← if t.isApp then mkLambdaFVarsS t.getAppArgs s else pure s - if !(← mayAssign tFn s) then return false + if !(← mayAssign tFn s) then return .undef + if (← read).mvarsToCheckType.contains mvarId then + unless (← Sym.isDefEqTypes (← mvarId.getDecl).type (← inferType s)) do + return .false mvarId.assign s - return true + return .true /-- Structural definitional equality for applications without `ProofInstInfo`. 
@@ -531,6 +627,11 @@ where if (← mvarId.isAssigned) then return false if !(← isAssignableMVar mvarId) then return false if !(← mayAssign t s) then return false + /- + **Note**: we don't need to check the type of `mvarId` here even if the variable is marked for + checking. This is the case because `tryAssignUnassigned` is invoked only from a context where `t` and `s` are the arguments + of function applications. + -/ mvarId.assign s return true @@ -619,11 +720,10 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do isDefEqMain (← instantiateMVarsS t) s else if (← isAssignedMVar sFn) then isDefEqMain t (← instantiateMVarsS s) - else if (← tryAssignMillerPattern tFn t s rfl) then - return true - else if (← tryAssignMillerPattern sFn s t rfl) then - return true - else if let .fvar fvarId₁ := t then + else + whenUndefDo (tryAssignMillerPattern tFn t s rfl) do + whenUndefDo (tryAssignMillerPattern sFn s t rfl) do + if let .fvar fvarId₁ := t then unless (← read).zetaDelta do return false let some val₁ ← fvarId₁.getValue? | return false isDefEqMain val₁ s @@ -634,17 +734,19 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do else isDefEqApp tFn t s rfl -abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) (x : DefEqM α) : SymM α := do +abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) + (mvarsToCheckType : Array MVarId := #[]) (x : DefEqM α) : SymM α := do let lctx ← getLCtx let lctxInitialNextIndex := lctx.decls.size - x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew } + x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew, mvarsToCheckType } /-- A lightweight structural definitional equality for the symbolic simulation framework. Unlike the full `isDefEq`, it avoids expensive operations while still supporting Miller pattern unification. 
-/ -public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) : SymM Bool := do - DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) do +public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) + (mvarsNew : Array MVarId := #[]) (mvarsToCheckType : Array MVarId := #[]): SymM Bool := do + DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) (mvarsToCheckType := mvarsToCheckType) do isDefEqMain t s def noPending : UnifyM Bool := do @@ -655,7 +757,11 @@ def instantiateLevelParamsS (e : Expr) (paramNames : List Name) (us : List Level -- We do not assume `e` is maximally shared shareCommon (e.instantiateLevelParams paramNames us) -def mkPreResult : UnifyM Unit := do +inductive MkPreResultResult where + | failed + | success (mvarsToCheckType : Array MVarId) + +def mkPreResult : UnifyM MkPreResultResult := do let us ← (← get).uAssignment.toList.mapM fun | some val => pure val | none => mkFreshLevelMVar @@ -663,9 +769,20 @@ def mkPreResult : UnifyM Unit := do let varTypes := pattern.varTypes let isInstance := pattern.isInstance let eAssignment := (← get).eAssignment + let tPending := (← get).tPending let mut args := #[] + let mut mvarsToCheckType := #[] for h : i in *...eAssignment.size do if let .some val := eAssignment[i] then + if tPending.contains i then + let type := varTypes[i]! + let type ← instantiateLevelParamsS type pattern.levelParams us + let type ← instantiateRevBetaS type args + let valType ← inferType val + -- **Note**: we have to use the default `isDefEq` because the type of `val` + -- is not necessarily normalized. + unless (← isDefEqTypes type valType) do + return .failed args := args.push val else let type := varTypes[i]! @@ -677,8 +794,12 @@ def mkPreResult : UnifyM Unit := do continue let mvar ← mkFreshExprMVar type let mvar ← shareCommon mvar + if let some mask := (← read).pattern.checkTypeMask? then + if mask[i]! 
then + mvarsToCheckType := mvarsToCheckType.push mvar.mvarId! args := args.push mvar modify fun s => { s with args, us } + return .success mvarsToCheckType def processPendingLevel : UnifyM Bool := do let uPending := (← get).uPending @@ -704,7 +825,7 @@ def processPendingInst : UnifyM Bool := do return false return true -def processPendingExpr : UnifyM Bool := do +def processPendingExpr (mvarsToCheckType : Array MVarId) : UnifyM Bool := do let ePending := (← get).ePending if ePending.isEmpty then return true let pattern := (← read).pattern @@ -715,7 +836,7 @@ def processPendingExpr : UnifyM Bool := do let mvarsNew := if unify then #[] else args.filterMap fun | .mvar mvarId => some mvarId | _ => none - DefEqM.run unify zetaDelta mvarsNew do + DefEqM.run unify zetaDelta mvarsNew mvarsToCheckType do for (t, s) in ePending do let t ← instantiateLevelParamsS t pattern.levelParams us let t ← instantiateRevBetaS t args @@ -723,11 +844,11 @@ def processPendingExpr : UnifyM Bool := do return false return true -def processPending : UnifyM Bool := do +def processPending (mvarsToCheckType : Array MVarId) : UnifyM Bool := do if (← noPending) then return true else - processPendingLevel <&&> processPendingInst <&&> processPendingExpr + processPendingLevel <&&> processPendingInst <&&> processPendingExpr mvarsToCheckType abbrev UnifyM.run (pattern : Pattern) (unify : Bool) (zetaDelta : Bool) (k : UnifyM α) : SymM α := do let eAssignment := pattern.varTypes.map fun _ => none @@ -745,9 +866,11 @@ def mkResult : UnifyM MatchUnifyResult := do def main (p : Pattern) (e : Expr) (unify : Bool) (zetaDelta : Bool) : SymM (Option (MatchUnifyResult)) := UnifyM.run p unify zetaDelta do unless (← process p.pattern e) do return none - mkPreResult - unless (← processPending) do return none - return some (← mkResult) + match (← mkPreResult) with + | .failed => return none + | .success mvarsToCheckType => + unless (← processPending mvarsToCheckType) do return none + return some (← mkResult) /-- 
Attempts to match expression `e` against pattern `p` using purely syntactic matching. diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index 87fe244778..281ab90381 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -17,3 +17,4 @@ public import Lean.Meta.Sym.Simp.Theorems public import Lean.Meta.Sym.Simp.Have public import Lean.Meta.Sym.Simp.Lambda public import Lean.Meta.Sym.Simp.Forall +public import Lean.Meta.Sym.Simp.Debug diff --git a/src/Lean/Meta/Sym/Simp/Debug.lean b/src/Lean/Meta/Sym/Simp/Debug.lean new file mode 100644 index 0000000000..37a95d14cc --- /dev/null +++ b/src/Lean/Meta/Sym/Simp/Debug.lean @@ -0,0 +1,46 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Sym.Simp.Theorems +import Lean.Meta.Sym.Simp.Rewrite +import Lean.Meta.Sym.Util +import Lean.Meta.Tactic.Util +import Lean.Meta.AppBuilder +namespace Lean.Meta.Sym +open Simp +/-! +Helper functions for debugging purposes and creating tests. +-/ + +public def mkMethods (declNames : Array Name) : MetaM Methods := do + let mut thms : Theorems := {} + for declName in declNames do + thms := thms.insert (← mkTheoremFromDecl declName) + return { post := thms.rewrite } + +public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do + let mvarId ← preprocessMVar mvarId + let decl ← mvarId.getDecl + let target := decl.type + match (← k target) with + | .rfl _ => throwError "`Sym.simp` made no progress " + | .step target' h _ => + let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName + let h ← mkAppM ``Eq.mpr #[h, mvarNew] + mvarId.assign h + if target'.isTrue then + mvarNew.mvarId!.assign (mkConst ``True.intro) + return none + else + return some mvarNew.mvarId! 
+ +public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do + let methods ← mkMethods declNames + simpWith (simp · methods) mvarId + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp/Rewrite.lean b/src/Lean/Meta/Sym/Simp/Rewrite.lean index 58655db04f..dbe9a963cd 100644 --- a/src/Lean/Meta/Sym/Simp/Rewrite.lean +++ b/src/Lean/Meta/Sym/Simp/Rewrite.lean @@ -33,7 +33,10 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) : SimpM Result := do let rhs := thm.rhs.instantiateLevelParams thm.pattern.levelParams result.us let rhs ← shareCommonInc rhs let expr ← instantiateRevBetaS rhs result.args - return .step expr proof + if isSameExpr e expr then + return .rfl + else + return .step expr proof else return .rfl diff --git a/src/Lean/Meta/Sym/Simp/SimpM.lean b/src/Lean/Meta/Sym/Simp/SimpM.lean index a643ab6fc8..f44b2dfbf9 100644 --- a/src/Lean/Meta/Sym/Simp/SimpM.lean +++ b/src/Lean/Meta/Sym/Simp/SimpM.lean @@ -101,7 +101,7 @@ invalidating the cache and causing O(2^n) behavior on conditional trees. /-- Configuration options for the structural simplifier. -/ structure Config where /-- Maximum number of steps that can be performed by the simplifier. 
-/ - maxSteps : Nat := 0 + maxSteps : Nat := 1000 -- **TODO**: many are still missing /-- diff --git a/tests/lean/run/sym_simp_1.lean b/tests/lean/run/sym_simp_1.lean new file mode 100644 index 0000000000..82a6ace8e1 --- /dev/null +++ b/tests/lean/run/sym_simp_1.lean @@ -0,0 +1,33 @@ +import Lean +open Lean Meta Elab Tactic + +theorem bv0_eq (x : BitVec 0) : x = 0 := BitVec.of_length_zero + +set_option warn.sorry false + +elab "sym_simp" "[" declNames:ident,* "]" : tactic => do + let declNames ← declNames.getElems.mapM resolveGlobalConstNoOverload + liftMetaTactic1 <| Sym.simpGoal declNames + +theorem heq_self : (x ≍ x) = True := by simp +theorem forall_true {α : Sort u} : (∀ _ : α, True) = True := by simp + +example : x + 0 ≍ x := by + fail_if_success sym_simp [] + sym_simp [Nat.add_zero, heq_self] + +example : 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self] + +example : x = x := by + sym_simp [bv0_eq, eq_self] + +example (x y : BitVec 0) : x = y := by + sym_simp [bv0_eq, eq_self] + +example : ∀ x, 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self] + sym_simp [forall_true] + +example : ∀ x, 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true]