From 38b4062edb5f14185df19d702953756f48ceaba0 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Fri, 19 Sep 2025 16:13:57 +0200 Subject: [PATCH] feat: linear-size Ord instance (#10270) This PR adds an alternative implementation of `Deriving Ord` based on comparing `.ctorIdx` and using a dedicated matcher for comparing same constructors (added in #10152). The new option `deriving.ord.linear_construction_threshold` sets the constructor count threshold (10 by default) for using the new construction. It also (unconditionally) changes the implementation for enumeration types to simply compare the `ctorIdx`. --- src/Lean/DocString/Types.lean | 1 + src/Lean/Elab/Deriving/Ord.lean | 94 ++++++++++++++++++++++++++++++--- tests/lean/run/Ord.lean | 4 +- 3 files changed, 90 insertions(+), 9 deletions(-) diff --git a/src/Lean/DocString/Types.lean b/src/Lean/DocString/Types.lean index a3d4e29369..66a0606270 100644 --- a/src/Lean/DocString/Types.lean +++ b/src/Lean/DocString/Types.lean @@ -10,6 +10,7 @@ prelude public import Init.Data.Repr public import Init.Data.Ord +import Init.Data.Nat.Compare set_option linter.missingDocs true diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index e2234ee6ba..6dae2b89c4 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -6,12 +6,21 @@ Authors: Dany Fabian module prelude -public import Lean.Meta.Transform -public import Lean.Elab.Deriving.Basic -public import Lean.Elab.Deriving.Util +public import Lean.Data.Options +import Lean.Meta.Transform +import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util +import Lean.Meta.Constructions.CtorIdx +import Lean.Meta.Constructions.CasesOnSameCtor import Lean.Meta.SameCtorUtils -public section +register_builtin_option deriving.ord.linear_construction_threshold : Nat := { + defValue := 10 + descr := "If the inductive data type has this many or more constructors, use a different \ + implementation for implementing `Ord` that avoids 
the quadratic code size produced by the \ + default implementation.\n\n\ + The alternative construction compiles to less efficient code in some cases, so by default \ + it is only used for inductive types with 10 or more constructors." } namespace Lean.Elab.Deriving.Ord open Lean.Parser.Term @@ -20,7 +29,7 @@ open Meta def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do mkHeader `Ord 2 indVal -def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatchOld (header : Header) (indVal : InductiveVal) : TermElabM Term := do let discrs ← mkDiscrs header indVal let alts ← mkAlts `(match $[$discrs],* with $alts:matchAlt*) @@ -74,6 +83,59 @@ where alts := alts ++ (alt : Array (TSyntax ``matchAlt)) return alts.pop.pop +def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do + assert! header.targetNames.size == 2 + + let x1 := mkIdent header.targetNames[0]! + let x2 := mkIdent header.targetNames[1]! + let ctorIdxName := mkCtorIdxName indVal.name + -- NB: the getMatcherInfo? assumes all matchers are called `match_` + let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor) + mkCasesOnSameCtor casesOnSameCtorName indVal.name + let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do + let ctorName := indVal.ctors[ctorIdx]! + let ctorInfo ← getConstInfoCtor ctorName + forallTelescopeReducing ctorInfo.type fun xs type => do + let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies + let mut ctorArgs1 : Array Term := #[] + let mut ctorArgs2 : Array Term := #[] + + let mut rhsCont : Term → TermElabM Term := fun rhs => pure rhs + for i in *...ctorInfo.numFields do + let x := xs[indVal.numParams + i]! 
+ if occursOrInType (← getLCtx) x type then + -- If resulting type depends on this field, we don't need to compare + -- and the casesOnSameCtor only has a parameter for it once + ctorArgs1 := ctorArgs1.push (← `(_)) + else + let userName ← x.fvarId!.getUserName + let a := mkIdent (← mkFreshUserName userName) + let b := mkIdent (← mkFreshUserName (userName.appendAfter "'")) + ctorArgs1 := ctorArgs1.push a + ctorArgs2 := ctorArgs2.push b + let xType ← inferType x + if (← isProp xType) then + continue + else + rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont + let rhs ← rhsCont (← `(Ordering.eq)) + `(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term) + if indVal.numCtors == 1 then + `( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* ) + else + `( match h : compare ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with + | Ordering.lt => Ordering.lt + | Ordering.gt => Ordering.gt + | Ordering.eq => + $(mkCIdent casesOnSameCtorName) $x1:term $x2:term (Nat.compare_eq_eq.mp h) $alts:term* + ) + +def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do + if indVal.numCtors ≥ deriving.ord.linear_construction_threshold.get (← getOptions) then + mkMatchNew header indVal + else + mkMatchOld header indVal + def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! let indVal := ctx.typeInfos[i]! @@ -105,13 +167,31 @@ private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do trace[Elab.Deriving.ord] "\n{cmds}" return cmds +private def mkOrdEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do + let auxFunName := ctx.auxFunNames[0]! 
+ `(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Ordering := compare x.ctorIdx y.ctorIdx) + +private def mkOrdEnumCmd (name : Name): TermElabM (Array Syntax) := do + let ctx ← mkContext ``Ord "ord" name + let cmds := #[← mkOrdEnumFun ctx name] ++ (← mkInstanceCmds ctx `Ord #[name]) + trace[Elab.Deriving.ord] "\n{cmds}" + return cmds + open Command +def mkOrdInstance (declName : Name) : CommandElabM Unit := do + withoutExposeFromCtors declName do + let cmds ← liftTermElabM <| + if (← isEnumType declName) then + mkOrdEnumCmd declName + else + mkOrdInstanceCmds declName + cmds.forM elabCommand + def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do if (← declNames.allM isInductive) then for declName in declNames do - let cmds ← withoutExposeFromCtors declName <| liftTermElabM <| mkOrdInstanceCmds declName - cmds.forM elabCommand + mkOrdInstance declName return true else return false diff --git a/tests/lean/run/Ord.lean b/tests/lean/run/Ord.lean index d2102fce5c..5101a75848 100644 --- a/tests/lean/run/Ord.lean +++ b/tests/lean/run/Ord.lean @@ -19,7 +19,7 @@ inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z deriving Ord -structure Person := +structure Person where firstName : String lastName : String age : Nat @@ -27,7 +27,7 @@ deriving Ord example : compare { firstName := "A", lastName := "B", age := 10 : Person } ⟨"B", "A", 9⟩ = Ordering.lt := rfl -structure Company := +structure Company where name : String ceo : Person numberOfEmployees : Nat