diff --git a/src/Lean/Meta/Sym/Pattern.lean b/src/Lean/Meta/Sym/Pattern.lean index 7ef625ca7a..ee47849805 100644 --- a/src/Lean/Meta/Sym/Pattern.lean +++ b/src/Lean/Meta/Sym/Pattern.lean @@ -42,6 +42,10 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro - `instantiateRevS` ensures maximal sharing of result expressions -/ +/-- Helper function for checking whether types `α` and `β` are definitionally equal during unification/matching. It runs `isDefEq` with reducible transparency. -/ +def isDefEqTypes (α β : Expr) : MetaM Bool := do + withReducible <| isDefEq α β + /-- Collects `ProofInstInfo` for all function symbols occurring in `pattern`. @@ -56,11 +60,18 @@ def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo return fnInfos public structure Pattern where - levelParams : List Name - varTypes : Array Expr - isInstance : Array Bool - pattern : Expr - fnInfos : AssocList Name ProofInstInfo + levelParams : List Name + varTypes : Array Expr + isInstance : Array Bool + pattern : Expr + fnInfos : AssocList Name ProofInstInfo + /-- + If `checkTypeMask? = some mask`, then we must check the type of pattern variable `i` + if `mask[i]` is true. If it is `none`, no pattern variable requires a type check. + Moreover `mask.size == varTypes.size`. + See `mkCheckTypeMask`. + -/ + checkTypeMask? : Option (Array Bool) deriving Inhabited def uvarPrefix : Name := `_uvar @@ -79,6 +90,65 @@ def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do let type ← preprocessType type return (levelParams, type) +/-- +Creates a mask indicating which pattern variables require type checking during matching. + +When matching a pattern against a target expression, we must ensure that pattern variable +assignments are type-correct. However, checking types for every variable is expensive. +This function identifies which variables actually need type checking.
+ +**Key insight**: A pattern variable appearing as an argument to a function application +does not need its type checked separately, because the type information is already +encoded in the application structure, and we assume the input is type correct. + +**Variables that need type checking**: +- Variables in function position: `f x` where `f` is a pattern variable, even when `f x` is itself an argument as in `h (f x)` +- Variables in binder domains or bodies: `∀ x : α, β` or `fun x : α => b` +- Variables appearing alone (not as part of any application) + +**Variables that skip type checking**: +- Variables appearing only as arguments to applications: in `f x`, the variable `x` + does not need checking because the type of `f` constrains the type of `x` + +**Examples**: +- `bv0_eq (x : BitVec 0) : x = 0`: pattern is just `x`, must check type to ensure `BitVec 0` +- `forall_true : (∀ _ : α, True) = True`: `α` appears in binder domain, must check +- `Nat.add_zero (x : Nat) : x + 0 = x`: `x` is argument to `HAdd.hAdd`, no check needed + +**Note**: This analysis is conservative. It may mark some variables for checking even when +the type information is redundant (already determined by other constraints). This is +harmless—just extra work, not incorrect behavior. + +Returns an array of booleans parallel to the pattern's `varTypes`, where `true` indicates +the variable's type must be checked against the matched subterm's type. +-/ +def mkCheckTypeMask (pattern : Expr) (numPatternVars : Nat) : Array Bool := + let mask := Array.replicate numPatternVars false + go pattern 0 false mask +where + go (e : Expr) (offset : Nat) (isArg : Bool) : Array Bool → Array Bool := + match e with + | .app f a => go f offset false ∘ go a offset true /- the head `f` is in function position, never in argument position, even when the whole application is an argument -/ + | .letE .. => unreachable! -- We zeta-reduce at `preprocessType` + | .const .. | .fvar _ | .sort _ | .mvar _ | .lit _ => id + | .mdata _ b => go b offset isArg + | .proj ..
=> id -- Should not occur in patterns + | .forallE _ d b _ + | .lam _ d b _ => go d offset false ∘ go b (offset+1) false + | .bvar idx => fun mask => /- `offset` counts the binders crossed so far; `idx - offset` is the index among the pattern variables and is stored at position `mask.size - idx - 1` -/ + if idx >= offset && !isArg then + let idx := idx - offset + mask.set! (mask.size - idx - 1) true + else + mask + +/-- Builds a `Pattern`: computes `fnInfos` with `mkProofInstInfoMapFor` and the type-check mask with `mkCheckTypeMask`, storing `none` in `checkTypeMask?` when no variable needs checking. -/ def mkPatternCore (levelParams : List Name) (varTypes : Array Expr) (isInstance : Array Bool) + (pattern : Expr) : MetaM Pattern := do + let fnInfos ← mkProofInstInfoMapFor pattern + let checkTypeMask := mkCheckTypeMask pattern varTypes.size + let checkTypeMask? := if checkTypeMask.all (· == false) then none else some checkTypeMask + return { levelParams, varTypes, isInstance, pattern, fnInfos, checkTypeMask? } + /-- Creates a `Pattern` from the type of a theorem. @@ -100,9 +170,7 @@ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : Met if i < num then if let .forallE _ d b _ := type then return (← go (i+1) b (varTypes.push d) (isInstance.push (isClass? (← getEnv) d).isSome)) - let pattern := type - let fnInfos ← mkProofInstInfoMapFor pattern - return { levelParams, varTypes, isInstance, pattern, fnInfos } + mkPatternCore levelParams varTypes isInstance type go 0 type #[] #[] /-- @@ -123,9 +191,8 @@ public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do return (← go b (varTypes.push d) (isInstance.push (isClass?
(← getEnv) d).isSome)) else let_expr Eq _ lhs rhs := type | throwError "resulting type for `{.ofConstName declName}` is not an equality" - let pattern := lhs - let fnInfos ← mkProofInstInfoMapFor pattern - return ({ levelParams, varTypes, isInstance, pattern, fnInfos }, rhs) + let pattern ← mkPatternCore levelParams varTypes isInstance lhs + return (pattern, rhs) go type #[] #[] structure UnifyM.Context where @@ -139,6 +206,11 @@ structure UnifyM.State where ePending : Array (Expr × Expr) := #[] uPending : Array (Level × Level) := #[] iPending : Array (Expr × Expr) := #[] + /-- + Indices of the pattern variables for which we must check that the variable's type + matches the type of the value assigned to it. The check itself is performed by `mkPreResult`. + -/ + tPending : Array Nat := #[] us : List Level := [] args : Array Expr := #[] @@ -153,6 +225,14 @@ def pushLevelPending (u : Level) (v : Level) : UnifyM Unit := def pushInstPending (p : Expr) (e : Expr) : UnifyM Unit := modify fun s => { s with iPending := s.iPending.push (p, e) } +/-- +Marks pattern variable `i` for type checking. That is, at the end of phase 1 +we must check whether the type of this pattern variable is compatible with the type of +the value assigned to it. Here `i` is a position in `eAssignment`, not a de Bruijn index. +-/ +def pushCheckTypePending (i : Nat) : UnifyM Unit := + modify fun s => { s with tPending := s.tPending.push i } + def assignExprIfUnassigned (bidx : Nat) (e : Expr) : UnifyM Unit := do let s ← get let i := s.eAssignment.size - bidx - 1 @@ -169,6 +249,8 @@ def assignExpr (bidx : Nat) (e : Expr) : UnifyM Bool := do return true else modify fun s => { s with eAssignment := s.eAssignment.set! i (some e) } + if (← read).pattern.checkTypeMask?.isSome then /- NOTE(review): this queues `i` whenever a mask exists, without consulting `mask[i]` — confirm this is intended -/ + pushCheckTypePending i return true def assignLevel (uidx : Nat) (u : Level) : UnifyM Bool := do @@ -369,6 +451,11 @@ structure DefEqM.Context where If `unify` is `false`, it contains which variables can be assigned.
-/ mvarsNew : Array MVarId := #[] + /-- + If a metavariable `?m` is in this collection, then before `tryAssignMillerPattern` performs the assignment `?m := v`, + it must check that the type of `?m` and the type of `v` are definitionally equal. + -/ + mvarsToCheckType : Array MVarId := #[] abbrev DefEqM := ReaderT DefEqM.Context SymM @@ -481,6 +568,12 @@ def mayAssign (t s : Expr) : SymM Bool := do let tMaxFVarDecl ← tMaxFVarId.getDecl return tMaxFVarDecl.index ≥ sMaxFVarDecl.index +/-- Runs `x`: returns `true`/`false` when its answer is `.true`/`.false`, and runs the fallback `k` only when the answer is `.undef`. -/ @[inline] def whenUndefDo (x : DefEqM LBool) (k : DefEqM Bool) : DefEqM Bool := do + match (← x) with + | .true => return true + | .false => return false + | .undef => k + /-- Attempts to solve a unification constraint `t =?= s` where `t` has the form `?m a₁ ... aₙ` and satisfies the Miller pattern condition (all `aᵢ` are distinct, newly-introduced free variables). @@ -495,17 +588,20 @@ The `tFn` parameter must equal `t.getAppFn` (enforced by the proof argument). Remark: `t` may be of the form `?m`. -/ -def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM Bool := do - let .mvar mvarId := tFn | return false - if !(← isAssignableMVar mvarId) then return false - if !(← isMillerPatternArgs t) then return false +def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM LBool := do /- Result: `.undef` = this rule does not apply, `.false` = the type check of a marked metavariable failed, `.true` = the metavariable was assigned. -/ + let .mvar mvarId := tFn | return .undef + if !(← isAssignableMVar mvarId) then return .undef + if !(← isMillerPatternArgs t) then return .undef let s ← if t.isApp then mkLambdaFVarsS t.getAppArgs s else pure s - if !(← mayAssign tFn s) then return false + if !(← mayAssign tFn s) then return .undef + if (← read).mvarsToCheckType.contains mvarId then + unless (← Sym.isDefEqTypes (← mvarId.getDecl).type (← inferType s)) do + return .false mvarId.assign s - return true + return .true /-- Structural definitional equality for applications without `ProofInstInfo`.
@@ -531,6 +627,11 @@ where if (← mvarId.isAssigned) then return false if !(← isAssignableMVar mvarId) then return false if !(← mayAssign t s) then return false + /- + **Note**: we don't need to check the type of `mvarId` here even if the variable is marked for + checking, because `tryAssignUnassigned` is invoked only from a context where `t` and `s` are the arguments + of function applications (argument types are constrained by the applied function; see `mkCheckTypeMask`). + -/ mvarId.assign s return true @@ -619,11 +720,10 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do isDefEqMain (← instantiateMVarsS t) s else if (← isAssignedMVar sFn) then isDefEqMain t (← instantiateMVarsS s) - else if (← tryAssignMillerPattern tFn t s rfl) then - return true - else if (← tryAssignMillerPattern sFn s t rfl) then - return true - else if let .fvar fvarId₁ := t then + else + whenUndefDo (tryAssignMillerPattern tFn t s rfl) do + whenUndefDo (tryAssignMillerPattern sFn s t rfl) do + if let .fvar fvarId₁ := t then unless (← read).zetaDelta do return false let some val₁ ← fvarId₁.getValue? | return false isDefEqMain val₁ s @@ -634,17 +734,19 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do else isDefEqApp tFn t s rfl -abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) (x : DefEqM α) : SymM α := do +/-- Runs `x` with a `DefEqM.Context` built from the given options and the current size of the local context. `mvarsToCheckType` lists the metavariables whose type must be checked before `tryAssignMillerPattern` assigns them. -/ abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) + (mvarsToCheckType : Array MVarId := #[]) (x : DefEqM α) : SymM α := do let lctx ← getLCtx let lctxInitialNextIndex := lctx.decls.size - x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew } + x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew, mvarsToCheckType } /-- A lightweight structural definitional equality for the symbolic simulation framework. Unlike the full `isDefEq`, it avoids expensive operations while still supporting Miller pattern unification.
-/ -public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) : SymM Bool := do - DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) do +public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) + (mvarsNew : Array MVarId := #[]) (mvarsToCheckType : Array MVarId := #[]): SymM Bool := do + DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) (mvarsToCheckType := mvarsToCheckType) do isDefEqMain t s def noPending : UnifyM Bool := do @@ -655,7 +757,11 @@ def instantiateLevelParamsS (e : Expr) (paramNames : List Name) (us : List Level -- We do not assume `e` is maximally shared shareCommon (e.instantiateLevelParams paramNames us) -def mkPreResult : UnifyM Unit := do +/-- Outcome of `mkPreResult`: `failed` when the type check of an assigned pattern variable fails, otherwise `success` carrying the fresh metavariables whose `checkTypeMask?` entry is true. -/ inductive MkPreResultResult where + | failed + | success (mvarsToCheckType : Array MVarId) + +def mkPreResult : UnifyM MkPreResultResult := do let us ← (← get).uAssignment.toList.mapM fun | some val => pure val | none => mkFreshLevelMVar @@ -663,9 +769,20 @@ def mkPreResult : UnifyM Unit := do let varTypes := pattern.varTypes let isInstance := pattern.isInstance let eAssignment := (← get).eAssignment + let tPending := (← get).tPending let mut args := #[] + let mut mvarsToCheckType := #[] for h : i in *...eAssignment.size do if let .some val := eAssignment[i] then + if tPending.contains i then + let type := varTypes[i]! + let type ← instantiateLevelParamsS type pattern.levelParams us + let type ← instantiateRevBetaS type args + let valType ← inferType val + -- **Note**: we use `isDefEqTypes` (the general `isDefEq`, run with reducible transparency) because the type of `val` + -- is not necessarily normalized. + unless (← isDefEqTypes type valType) do + return .failed args := args.push val else let type := varTypes[i]! @@ -677,8 +794,12 @@ def mkPreResult : UnifyM Unit := do continue let mvar ← mkFreshExprMVar type let mvar ← shareCommon mvar + if let some mask := (← read).pattern.checkTypeMask? then + if mask[i]!
then + mvarsToCheckType := mvarsToCheckType.push mvar.mvarId! args := args.push mvar modify fun s => { s with args, us } + return .success mvarsToCheckType def processPendingLevel : UnifyM Bool := do let uPending := (← get).uPending @@ -704,7 +825,7 @@ def processPendingInst : UnifyM Bool := do return false return true -def processPendingExpr : UnifyM Bool := do +/-- Processes the pending expression constraints in `ePending` inside `DefEqM.run`, forwarding `mvarsToCheckType` to it. -/ def processPendingExpr (mvarsToCheckType : Array MVarId) : UnifyM Bool := do let ePending := (← get).ePending if ePending.isEmpty then return true let pattern := (← read).pattern @@ -715,7 +836,7 @@ def processPendingExpr : UnifyM Bool := do let mvarsNew := if unify then #[] else args.filterMap fun | .mvar mvarId => some mvarId | _ => none - DefEqM.run unify zetaDelta mvarsNew do + DefEqM.run unify zetaDelta mvarsNew mvarsToCheckType do for (t, s) in ePending do let t ← instantiateLevelParamsS t pattern.levelParams us let t ← instantiateRevBetaS t args @@ -723,11 +844,11 @@ def processPendingExpr : UnifyM Bool := do return false return true -def processPending : UnifyM Bool := do +/-- Returns `true` immediately when `noPending` holds; otherwise runs the level, instance and expression phases in that order, short-circuiting on failure. -/ def processPending (mvarsToCheckType : Array MVarId) : UnifyM Bool := do if (← noPending) then return true else - processPendingLevel <&&> processPendingInst <&&> processPendingExpr + processPendingLevel <&&> processPendingInst <&&> processPendingExpr mvarsToCheckType abbrev UnifyM.run (pattern : Pattern) (unify : Bool) (zetaDelta : Bool) (k : UnifyM α) : SymM α := do let eAssignment := pattern.varTypes.map fun _ => none @@ -745,9 +866,11 @@ def mkResult : UnifyM MatchUnifyResult := do def main (p : Pattern) (e : Expr) (unify : Bool) (zetaDelta : Bool) : SymM (Option (MatchUnifyResult)) := UnifyM.run p unify zetaDelta do unless (← process p.pattern e) do return none - mkPreResult - unless (← processPending) do return none - return some (← mkResult) + match (← mkPreResult) with + | .failed => return none + | .success mvarsToCheckType => + unless (← processPending mvarsToCheckType) do return none + return some (← mkResult) /--
Attempts to match expression `e` against pattern `p` using purely syntactic matching. diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index 87fe244778..281ab90381 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -17,3 +17,4 @@ public import Lean.Meta.Sym.Simp.Theorems public import Lean.Meta.Sym.Simp.Have public import Lean.Meta.Sym.Simp.Lambda public import Lean.Meta.Sym.Simp.Forall +public import Lean.Meta.Sym.Simp.Debug diff --git a/src/Lean/Meta/Sym/Simp/Debug.lean b/src/Lean/Meta/Sym/Simp/Debug.lean new file mode 100644 index 0000000000..37a95d14cc --- /dev/null +++ b/src/Lean/Meta/Sym/Simp/Debug.lean @@ -0,0 +1,46 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Sym.Simp.Theorems +import Lean.Meta.Sym.Simp.Rewrite +import Lean.Meta.Sym.Util +import Lean.Meta.Tactic.Util +import Lean.Meta.AppBuilder +namespace Lean.Meta.Sym +open Simp +/-! +Helper functions for debugging and for writing tests of the `Sym` simplifier. +-/ + +/-- Builds `Methods` whose `post` step rewrites with the theorems created from `declNames` by `mkTheoremFromDecl`. -/ public def mkMethods (declNames : Array Name) : MetaM Methods := do + let mut thms : Theorems := {} + for declName in declNames do + thms := thms.insert (← mkTheoremFromDecl declName) + return { post := thms.rewrite } + +/-- Applies `k` to the target of `mvarId` (after `preprocessMVar`). Throws if `k` returns `.rfl`. Otherwise assigns `mvarId` with an `Eq.mpr` proof and returns the new goal, or `none` when the new target is `True` (closed with `True.intro`). NOTE(review): the error message below ends with a trailing space. -/ public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do + let mvarId ← preprocessMVar mvarId + let decl ← mvarId.getDecl + let target := decl.type + match (← k target) with + | .rfl _ => throwError "`Sym.simp` made no progress " + | .step target' h _ => + let mvarNew ← mkFreshExprSyntheticOpaqueMVar target' decl.userName + let h ← mkAppM ``Eq.mpr #[h, mvarNew] + mvarId.assign h + if target'.isTrue then + mvarNew.mvarId!.assign (mkConst ``True.intro) + return none + else + return some mvarNew.mvarId!
+ +/-- Simplifies the goal `mvarId` using the declarations `declNames` as rewrite rules, via `mkMethods` and `simpWith`. NOTE(review): `simpWith` already wraps its body in `SymM.run`, so the `SymM.run` here looks redundant — confirm. -/ public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do + let methods ← mkMethods declNames + simpWith (simp · methods) mvarId + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp/Rewrite.lean b/src/Lean/Meta/Sym/Simp/Rewrite.lean index 58655db04f..dbe9a963cd 100644 --- a/src/Lean/Meta/Sym/Simp/Rewrite.lean +++ b/src/Lean/Meta/Sym/Simp/Rewrite.lean @@ -33,7 +33,10 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) : SimpM Result := do let rhs := thm.rhs.instantiateLevelParams thm.pattern.levelParams result.us let rhs ← shareCommonInc rhs let expr ← instantiateRevBetaS rhs result.args - return .step expr proof + if isSameExpr e expr then /- the rewrite produced `e` itself: report `.rfl` instead of a step -/ + return .rfl + else + return .step expr proof else return .rfl diff --git a/src/Lean/Meta/Sym/Simp/SimpM.lean b/src/Lean/Meta/Sym/Simp/SimpM.lean index a643ab6fc8..f44b2dfbf9 100644 --- a/src/Lean/Meta/Sym/Simp/SimpM.lean +++ b/src/Lean/Meta/Sym/Simp/SimpM.lean @@ -101,7 +101,7 @@ invalidating the cache and causing O(2^n) behavior on conditional trees. /-- Configuration options for the structural simplifier. -/ structure Config where /-- Maximum number of steps that can be performed by the simplifier.
-/ - maxSteps : Nat := 0 + maxSteps : Nat := 1000 -- **TODO**: many are still missing /-- diff --git a/tests/lean/run/sym_simp_1.lean b/tests/lean/run/sym_simp_1.lean new file mode 100644 index 0000000000..82a6ace8e1 --- /dev/null +++ b/tests/lean/run/sym_simp_1.lean @@ -0,0 +1,33 @@ +import Lean +open Lean Meta Elab Tactic + +theorem bv0_eq (x : BitVec 0) : x = 0 := BitVec.of_length_zero + +set_option warn.sorry false + +/-- `sym_simp [decls]` resolves the given identifiers to global constants and runs `Sym.simpGoal` with them through `liftMetaTactic1`. -/ elab "sym_simp" "[" declNames:ident,* "]" : tactic => do + let declNames ← declNames.getElems.mapM resolveGlobalConstNoOverload + liftMetaTactic1 <| Sym.simpGoal declNames + +theorem heq_self : (x ≍ x) = True := by simp +theorem forall_true {α : Sort u} : (∀ _ : α, True) = True := by simp + +example : x + 0 ≍ x := by /- with no theorems `sym_simp` makes no progress and must fail -/ + fail_if_success sym_simp [] + sym_simp [Nat.add_zero, heq_self] + +example : 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self] + +example : x = x := by /- NOTE(review): presumably checks that `bv0_eq` is not applied when the type of `x` is not `BitVec 0` — confirm -/ + sym_simp [bv0_eq, eq_self] + +example (x y : BitVec 0) : x = y := by + sym_simp [bv0_eq, eq_self] + +example : ∀ x, 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self] + sym_simp [forall_true] + +example : ∀ x, 0 + x + 0 = x := by + sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true]