diff --git a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize/Structures.lean b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize/Structures.lean index d32114d015..e34a707c4a 100644 --- a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize/Structures.lean +++ b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize/Structures.lean @@ -6,6 +6,8 @@ Authors: Henrik Böving prelude import Lean.Elab.Tactic.BVDecide.Frontend.Normalize.Basic import Lean.Meta.Tactic.Cases +import Lean.Meta.Tactic.Simp +import Lean.Meta.Injective /-! This module contains the implementation of the pre processing pass for automatically splitting up @@ -16,6 +18,8 @@ it is a non recursive structure and at least one of the following conditions hol - it contains something of type `BitVec`/`UIntX`/`Bool` - it is parametrized by an interesting type - it contains another interesting type +Afterwards we also apply relevant `injEq` theorems to support at least equality for these types out +of the box. -/ namespace Lean.Elab.Tactic.BVDecide @@ -28,7 +32,8 @@ Contains a cache for interesting and uninteresting types such that we don't dupl structures pass. -/ structure InterestingStructures where - interesting : Std.HashMap Name Bool := {} + interesting : Std.HashSet Name := {} + uninteresting : Std.HashSet Name := {} private abbrev M := StateRefT InterestingStructures MetaM @@ -37,15 +42,20 @@ namespace M @[inline] def lookup (n : Name) : M (Option Bool) := do let s ← get - return s.interesting.get? n + if s.uninteresting.contains n then + return some false + else if s.interesting.contains n then + return some true + else + return none @[inline] def markInteresting (n : Name) : M Unit := do - modify (fun s => {s with interesting := s.interesting.insert n true}) + modify (fun s => {s with interesting := s.interesting.insert n }) @[inline] def markUninteresting (n : Name) : M Unit := do - modify (fun s => {s with interesting := s.interesting.insert n false}) + modify (fun s => {s with uninteresting := s.uninteresting.insert n }) end M @@ -59,11 +69,31 @@ partial def structuresPass : Pass where return false else let some const := decl.type.getAppFn.constName? | return false - return interesting.getD const false + return interesting.contains const match goals with - | [goal] => return goal + | [goal] => postprocess goal interesting | _ => throwError "structures preprocessor generated more than 1 goal" where + postprocess (goal : MVarId) (interesting : Std.HashSet Name) : PreProcessM (Option MVarId) := do + goal.withContext do + let mut relevantLemmas : SimpTheoremsArray := #[] + for const in interesting do + let constInfo ← getConstInfoInduct const + let ctorName := (← getConstInfoCtor constInfo.ctors.head!).name + let lemmaName := mkInjectiveEqTheoremNameFor ctorName + if (← getEnv).find? lemmaName |>.isSome then + trace[Meta.Tactic.bv] m!"Using injEq lemma: {lemmaName}" + let statement ← mkConstWithLevelParams lemmaName + relevantLemmas ← relevantLemmas.addTheorem (.decl lemmaName) statement + let cfg ← PreProcessM.getConfig + let simpCtx ← Simp.mkContext + (config := { failIfUnchanged := false, maxSteps := cfg.maxSteps }) + (simpTheorems := relevantLemmas) + (congrTheorems := ← getSimpCongrTheorems) + let ⟨result?, _⟩ ← simpGoal goal (ctx := simpCtx) (fvarIdsToSimp := ← getPropHyps) + let some (_, newGoal) := result? | return none + return newGoal + checkContext (goal : MVarId) : M Unit := do goal.withContext do for decl in ← getLCtx do @@ -86,7 +116,7 @@ where let env ← getEnv if !isStructure env n then return false - let constInfo := (← getConstInfoInduct n) + let constInfo ← getConstInfoInduct n if constInfo.isRec then return false diff --git a/tests/lean/run/bv_structures.lean b/tests/lean/run/bv_structures.lean index bb1ad54491..fd566187d6 100644 --- a/tests/lean/run/bv_structures.lean +++ b/tests/lean/run/bv_structures.lean @@ -62,3 +62,50 @@ example (x y : BitVec 32) (p : Param x y) : x + y = 0 := by bv_decide end Ex5 + +namespace Ex6 + +structure Pair where + x : BitVec 32 + y : BitVec 32 + +example (a b : Pair) (h1 : a.x = a.y) (h2 : b.x = b.y) (h3 : a.x = b.y) : a = b := by + bv_decide + +example (a b : Pair) (h : a = b) : a.x = b.x := by + bv_decide + +end Ex6 + +namespace Ex7 + +structure Single where + z : BitVec 32 + +structure Pair where + x : BitVec 32 + y : Single + +example (a b : Pair) (h1 : a.x = a.y.z) (h2 : b.x = b.y.z) (h3 : a.x = b.y.z) : a = b := by + bv_decide + +example (a b : Pair) (h : a = b) : a.x = b.x ∧ a.y.z = b.y.z := by + bv_decide + +end Ex7 + +namespace Ex8 + +structure Single where + z : BitVec 32 + +structure Pair extends Single where + x : BitVec 32 + +example (a b : Pair) (h1 : a.x = a.z) (h2 : b.x = b.z) (h3 : a.x = b.z) : a = b := by + bv_decide + +example (a b : Pair) (h : a = b) : a.x = b.x ∧ a.z = b.z := by + bv_decide + +end Ex8