diff --git a/src/Lean/Meta/ACLt.lean b/src/Lean/Meta/ACLt.lean new file mode 100644 index 0000000000..7a081ea710 --- /dev/null +++ b/src/Lean/Meta/ACLt.lean @@ -0,0 +1,161 @@ +/- +Copyright (c) 2022 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Basic +import Lean.Meta.FunInfo + +namespace Lean + +def Expr.ctorWeight : Expr → UInt8 + | bvar .. => 0 + | fvar .. => 1 + | mvar .. => 2 + | sort .. => 3 + | const .. => 4 + | lit .. => 5 + | mdata .. => 6 + | proj .. => 7 + | app .. => 8 + | lam .. => 9 + | forallE .. => 10 + | letE .. => 11 + +namespace Meta +namespace ACLt + +mutual + +/-- + An AC-compatible ordering. + + Recall that an AC-compatible ordering if it is monotonic, well-founded, and total. + Both KBO and LPO are AC-compatible. KBO is faster, but we do not cache the weight of + each expression in Lean 4. Even if we did, we would need to have a weight where implicit instace arguments are ignored. + So, we use a LPO-like term ordering. + + Remark: this method is used to implement ordered rewriting. We ignore implicit instance + arguments to address an issue reported at issue #972. + + Remark: the order is not really total on terms since + - We instance implicit arguments. + - We ignore metadata. + - We ignore universe parameterst at constants. +-/ +unsafe def lt (a b : Expr) : MetaM Bool := + if ptrAddrUnsafe a == ptrAddrUnsafe b then + false + -- We ignore metadata + else if a.isMData then + lt a.mdataExpr! b + else if b.isMData then + lt a b.mdataExpr! + else + lpo a b + +where + ltPair (a₁ a₂ b₁ b₂ : Expr) : MetaM Bool := do + if (← lt a₁ b₁) then + return true + else if (← lt b₁ a₁) then + return false + else + lt a₂ b₂ + + ltApp (a b : Expr) : MetaM Bool := do + let aFn := a.getAppFn + let bFn := b.getAppFn + if (← lt aFn bFn) then + return true + else if (← lt bFn aFn) then + return false + else + let aArgs := a.getAppArgs + let bArgs := b.getAppArgs + if aArgs.size < bArgs.size then + return true + else if aArgs.size > bArgs.size then + return false + else + let infos := (← getFunInfoNArgs aFn aArgs.size).paramInfo + for i in [:infos.size] do + -- We ignore instance implicit arguments during comparison + if !infos[i].isInstImplicit then + if (← lt aArgs[i] bArgs[i]) then + return true + else if (← lt bArgs[i] aArgs[i]) then + return false + for i in [infos.size:aArgs.size] do + if (← lt aArgs[i] bArgs[i]) then + return true + else if (← lt bArgs[i] aArgs[i]) then + return false + return false + + lexSameCtor (a b : Expr) : MetaM Bool := + match a with + -- Atomic + | Expr.bvar i .. => i < b.bvarIdx! + | Expr.fvar id .. => Name.quickLt id.name b.fvarId!.name + | Expr.mvar id .. => Name.quickLt id.name b.mvarId!.name + | Expr.sort u .. => Level.normLt u b.sortLevel! + | Expr.const n .. => Name.quickLt n b.constName! -- We igore the levels + | Expr.lit v .. => Literal.lt v b.litValue! + -- Composite + | Expr.proj _ i e .. => if i != b.projIdx! then i < b.projIdx! else lt e b.projExpr! + | Expr.app .. => ltApp a b + | Expr.lam _ d e .. => ltPair d e b.bindingDomain! b.bindingBody! + | Expr.forallE _ d e .. => ltPair d e b.bindingDomain! b.bindingBody! + | Expr.letE _ _ v e .. => ltPair v e b.letValue! b.letBody! + -- See main function + | Expr.mdata .. => unreachable! + + lex (a b : Expr) : MetaM Bool := + if a.ctorWeight < b.ctorWeight then + return true + else if a.ctorWeight > b.ctorWeight then + return false + else + lexSameCtor a b + + allChildrenLt (a b : Expr) : MetaM Bool := + match a with + | Expr.proj _ _ e .. => lt e b + | Expr.app .. => + a.withApp fun f args => do + let infos := (← getFunInfoNArgs f args.size).paramInfo + for i in [:infos.size] do + -- We ignore instance implicit arguments during comparison + if !infos[i].isInstImplicit then + if !(← lt args[i] b) then + return false + for i in [infos.size:args.size] do + if !(← lt args[i] b) then + return false + return true + | Expr.lam _ d e .. => lt d b <&&> lt e b + | Expr.forallE _ d e .. => lt d b <&&> lt e b + | Expr.letE _ _ v e .. => lt v b <&&> lt e b + | _ => return true + + someChildGe (a b : Expr) : MetaM Bool := + return !(← allChildrenLt a b) + + -- lpo is only used when `a` and `b` have the same approximate depth, and it is >= 255 + lpo (a b : Expr) : MetaM Bool := do + -- Case 1: `a < b` if for some child `b_i` of `b`, we have `b_i >= a` + someChildGe b a + -- Case 2: `a < b` if `a.ctorWeight < b.ctorWeight` and for all children `a_i` of `a`, `a_i < b` + <||> (a.ctorWeight < b.ctorWeight <&&> allChildrenLt a b) + -- Case 3: `a < b` if `a` & `b` have the same ctor, and `a` is lexicographically smaller + <||> (a.ctorWeight == b.ctorWeight <&&> lexSameCtor a b) + +end + +end ACLt + +@[implementedBy ACLt.lt] +constant Expr.acLt : Expr → Expr → MetaM Bool + +end Lean.Meta diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 5a55477b38..8f6c55292b 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ -import Lean.Util.ACLt +import Lean.Meta.ACLt import Lean.Meta.AppBuilder import Lean.Meta.SynthInstance import Lean.Meta.Tactic.Simp.Types @@ -55,9 +55,10 @@ private def tryLemmaCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) let rhs ← instantiateMVars type.appArg! if e == rhs then return none - if lemma.perm && !Expr.acLt rhs e then - trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}" - return none + if lemma.perm then + if !(← Expr.acLt rhs e) then + trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}" + return none trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}" return some { expr := rhs, proof? := proof } else diff --git a/src/Lean/Util.lean b/src/Lean/Util.lean index 6744c5a5fc..efa27dd826 100644 --- a/src/Lean/Util.lean +++ b/src/Lean/Util.lean @@ -23,4 +23,3 @@ import Lean.Util.FoldConsts import Lean.Util.SCC import Lean.Util.OccursCheck import Lean.Util.Paths -import Lean.Util.ACLt diff --git a/src/Lean/Util/ACLt.lean b/src/Lean/Util/ACLt.lean deleted file mode 100644 index 1b5803219c..0000000000 --- a/src/Lean/Util/ACLt.lean +++ /dev/null @@ -1,110 +0,0 @@ -/- -Copyright (c) 2022 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura --/ -import Lean.Expr - -namespace Lean - -def Expr.ctorWeight : Expr → UInt8 - | bvar .. => 0 - | fvar .. => 1 - | mvar .. => 2 - | sort .. => 3 - | const .. => 4 - | lit .. => 5 - | mdata .. => 6 - | proj .. => 7 - | app .. => 8 - | lam .. => 9 - | forallE .. => 10 - | letE .. => 11 - -namespace ACLt - -mutual - -/-- - An AC-compatible ordering. - - Recall that an AC-compatible ordering if it is monotonic, well-founded, and total. - Both KBO and LPO are AC-compatible. KBO is faster, but we do not cache the weight of - each expression in Lean 4, only the approximated depth (it saturates at 255). - Thus, we use a hybrid of KBO and LPO. --/ -unsafe def lt (a b : Expr) : Bool := - if ptrAddrUnsafe a == ptrAddrUnsafe b then - false - -- We ignore metadata - else if a.isMData then - lt a.mdataExpr! b - else if b.isMData then - lt a b.mdataExpr! - else if a.approxDepth < b.approxDepth then - true - else if a.approxDepth > b.approxDepth then - false - else if a.approxDepth < 255 then - lex a b - else - lpo a b -where - ltPair (a₁ a₂ b₁ b₂ : Expr) : Bool := - if a₁ != b₁ then lt a₁ b₁ else lt a₂ b₂ - - lexSameCtor (a b : Expr) : Bool := - match a with - -- Atomic - | Expr.bvar i .. => i < b.bvarIdx! - | Expr.fvar id .. => Name.quickLt id.name b.fvarId!.name - | Expr.mvar id .. => Name.quickLt id.name b.mvarId!.name - | Expr.sort u .. => Level.normLt u b.sortLevel! - | Expr.const n .. => Name.quickLt n b.constName! -- We igore the levels - | Expr.lit v .. => Literal.lt v b.litValue! - -- Composite - | Expr.proj _ i e .. => if i != b.projIdx! then i < b.projIdx! else lt e b.projExpr! - | Expr.app f e .. => ltPair f e b.appFn! b.appArg! - | Expr.lam _ d e .. => ltPair d e b.bindingDomain! b.bindingBody! - | Expr.forallE _ d e .. => ltPair d e b.bindingDomain! b.bindingBody! - | Expr.letE _ _ v e .. => ltPair v e b.letValue! b.letBody! - -- See main function - | Expr.mdata .. => unreachable! - - lex (a b : Expr) : Bool := - if a.ctorWeight < b.ctorWeight then - true - else if a.ctorWeight > b.ctorWeight then - false - else - lexSameCtor a b - - allChildrenLt (a b : Expr) : Bool := - match a with - | Expr.proj _ _ e .. => lt e b - | Expr.app f e .. => lt f b && lt e b - | Expr.lam _ d e .. => lt d b && lt e b - | Expr.forallE _ d e .. => lt d b && lt e b - | Expr.letE _ _ v e .. => lt v b && lt e b - | _ => unreachable! - - someChildGe (a b : Expr) : Bool := - !allChildrenLt a b - - -- lpo is only used when `a` and `b` have the same approximate depth, and it is >= 255 - lpo (a b : Expr) : Bool := - -- Case 1: `a < b` if for some child `b_i` of `b`, we have `b_i >= a` - someChildGe b a - -- Case 2: `a < b` if `a.ctorWeight < b.ctorWeight` and for all children `a_i` of `a`, `a_i < b` - || (a.ctorWeight < b.ctorWeight && allChildrenLt a b) - -- Case 3: `a < b` if `a` & `b` have the same ctor, and `a` is lexicographically smaller - || (a.ctorWeight == b.ctorWeight && lexSameCtor a b) - -end - -end ACLt - -@[implementedBy ACLt.lt] -constant Expr.acLt : Expr → Expr → Bool - -end Lean diff --git a/tests/lean/run/972.lean b/tests/lean/run/972.lean new file mode 100644 index 0000000000..f84f6574c3 --- /dev/null +++ b/tests/lean/run/972.lean @@ -0,0 +1,22 @@ +class Semigroup (M : Type u) extends Mul M where + mul_assoc (a b c : M) : (a * b) * c = a * (b * c) +export Semigroup (mul_assoc) + +class CommSemigroup (M : Type u) extends Semigroup M where + mul_comm (a b : M) : a * b = b * a +export CommSemigroup (mul_comm) + +class Monoid (M : Type u) extends Semigroup M, OfNat M 1 where + mul_one (m : M) : m * 1 = m + one_mul (m : M) : 1 * m = m + +class CommMonoid (M : Type u) extends Monoid M, CommSemigroup M + +theorem mul_left_comm {M} [CommSemigroup M] (a b c : M) : a * (b * c) = b * (a * c) := by + rw [← mul_assoc, mul_comm a b, mul_assoc] + +example {M} [CommMonoid M] (a b c d : M) : a * (b * (c * d)) = (a * c) * (b * d) := by + simp only [mul_left_comm, mul_comm, mul_assoc] + +example {M} [CommMonoid M] (a b c d : M) : (b * (c * d)) = (c) * (b * d) := by + simp only [mul_left_comm, mul_comm, mul_assoc]