feat: linear-size Ord instance (#10270)

This PR adds an alternative implementation of `deriving Ord` that first
compares the `.ctorIdx` of the two values and then uses a dedicated matcher
(added in #10152) to compare values built with the same constructor. The new
option `deriving.ord.linear_construction_threshold` sets the constructor count
(10 by default) at or above which the new construction is used.

It also (unconditionally) changes the implementation for enumeration
types to simply compare the `ctorIdx`.
This commit is contained in:
Joachim Breitner 2025-09-19 16:13:57 +02:00 committed by GitHub
parent ae8dc414c3
commit 38b4062edb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 90 additions and 9 deletions

View file

@ -10,6 +10,7 @@ prelude
public import Init.Data.Repr
public import Init.Data.Ord
import Init.Data.Nat.Compare
set_option linter.missingDocs true

View file

@ -6,12 +6,21 @@ Authors: Dany Fabian
module
prelude
public import Lean.Meta.Transform
public import Lean.Elab.Deriving.Basic
public import Lean.Elab.Deriving.Util
public import Lean.Data.Options
import Lean.Meta.Transform
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.Constructions.CasesOnSameCtor
import Lean.Meta.SameCtorUtils
public section
-- Threshold option for the `Ord` deriving handler: `mkMatch` compares the number of
-- constructors of the inductive type against this value (using `≥`) and picks the
-- linear-size construction (`mkMatchNew`) when the threshold is reached, and the
-- original construction (`mkMatchOld`) otherwise.
register_builtin_option deriving.ord.linear_construction_threshold : Nat := {
defValue := 10
descr := "If the inductive data type has this many or more constructors, use a different \
implementation for implementing `Ord` that avoids the quadratic code size produced by the \
default implementation.\n\n\
The alternative construction compiles to less efficient code in some cases, so by default \
it is only used for inductive types with 10 or more constructors." }
namespace Lean.Elab.Deriving.Ord
open Lean.Parser.Term
@ -20,7 +29,7 @@ open Meta
/--
Builds the deriving `Header` for an `Ord` instance of `indVal` by delegating to `mkHeader`
with the class name `Ord`.
-/
def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do
-- NOTE(review): the `2` is presumably the number of target arguments (the two values
-- being compared); `mkMatchNew` asserts `header.targetNames.size == 2` -- confirm against `mkHeader`.
mkHeader `Ord 2 indVal
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
def mkMatchOld (header : Header) (indVal : InductiveVal) : TermElabM Term := do
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
@ -74,6 +83,59 @@ where
alts := alts ++ (alt : Array (TSyntax ``matchAlt))
return alts.pop.pop
/--
Builds the body of the derived comparison function using the linear-size construction:
the constructor indices (`ctorIdx`) of the two targets are compared first, and only when
they are equal is a same-constructor matcher applied, with one alternative per constructor
that compares the fields pairwise and chains the results with `Ordering.then`.

Expects `header` to carry exactly two target names (checked with `assert!`).
-/
def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do
assert! header.targetNames.size == 2
let x1 := mkIdent header.targetNames[0]!
let x2 := mkIdent header.targetNames[1]!
let ctorIdxName := mkCtorIdxName indVal.name
-- NB: the getMatcherInfo? assumes all matchers are called `match_`
-- A fresh name is generated for the same-constructor matcher and handed to
-- `mkCasesOnSameCtor`; NOTE(review): presumably this adds the matcher declaration
-- to the environment -- confirm in `Lean.Meta.Constructions.CasesOnSameCtor`.
let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor)
mkCasesOnSameCtor casesOnSameCtorName indVal.name
-- One alternative (a lambda over the constructor fields of both values) per constructor.
let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do
let ctorName := indVal.ctors[ctorIdx]!
let ctorInfo ← getConstInfoCtor ctorName
forallTelescopeReducing ctorInfo.type fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
-- Binders for the fields of the first and of the second value, respectively.
let mut ctorArgs1 : Array Term := #[]
let mut ctorArgs2 : Array Term := #[]
-- Continuation that wraps a right-hand side in the comparisons collected so far.
let mut rhsCont : Term → TermElabM Term := fun rhs => pure rhs
for i in *...ctorInfo.numFields do
-- `xs` starts with the inductive type's parameters, so fields are offset by `numParams`.
let x := xs[indVal.numParams + i]!
if occursOrInType (← getLCtx) x type then
-- If resulting type depends on this field, we don't need to compare
-- and the casesOnSameCtor only has a parameter for it once
ctorArgs1 := ctorArgs1.push (← `(_))
else
let userName ← x.fvarId!.getUserName
let a := mkIdent (← mkFreshUserName userName)
let b := mkIdent (← mkFreshUserName (userName.appendAfter "'"))
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let xType ← inferType x
-- Fields whose type is a proposition are bound but not compared.
if (← isProp xType) then
continue
else
-- The new comparison is built first and then passed to the previous continuation,
-- so comparisons of earlier fields end up outermost in the `Ordering.then` chain.
rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont
-- The innermost result, reached when all compared fields are equal, is `Ordering.eq`.
let rhs ← rhsCont (← `(Ordering.eq))
`(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
-- With a single constructor no `ctorIdx` comparison is emitted and the equality
-- proof passed to the matcher is `rfl`.
if indVal.numCtors == 1 then
`( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
else
-- Otherwise compare the constructor indices; `lt`/`gt` are returned as-is, and in the
-- `eq` case the proof `h` is converted with `Nat.compare_eq_eq.mp` for the matcher.
`( match h : compare ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with
| Ordering.lt => Ordering.lt
| Ordering.gt => Ordering.gt
| Ordering.eq =>
$(mkCIdent casesOnSameCtorName) $x1:term $x2:term (Nat.compare_eq_eq.mp h) $alts:term*
)
/--
Selects the construction used for the body of the derived comparison function:
`mkMatchOld` while the number of constructors is below the
`deriving.ord.linear_construction_threshold` option, `mkMatchNew` once it is reached.
-/
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  let threshold := deriving.ord.linear_construction_threshold.get (← getOptions)
  if indVal.numCtors < threshold then
    mkMatchOld header indVal
  else
    mkMatchNew header indVal
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
@ -105,13 +167,31 @@ private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
trace[Elab.Deriving.ord] "\n{cmds}"
return cmds
/--
Produces the auxiliary comparison function for an enumeration type `name`: a definition
named after the first auxiliary function name in `ctx` that compares the `ctorIdx` of
its two arguments.
-/
private def mkOrdEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
  let funId := mkIdent ctx.auxFunNames[0]!
  let typeId := mkCIdent name
  `(def $funId:ident (x y : $typeId) : Ordering := compare x.ctorIdx y.ctorIdx)
/--
Assembles the commands deriving `Ord` for the enumeration type `name`: the
`ctorIdx`-based comparison function followed by the instance commands from
`mkInstanceCmds`. The resulting commands are also emitted on the
`Elab.Deriving.ord` trace class.
-/
private def mkOrdEnumCmd (name : Name): TermElabM (Array Syntax) := do
  let ctx ← mkContext ``Ord "ord" name
  -- Kept in this order so the function definition precedes the instance commands.
  let enumFun ← mkOrdEnumFun ctx name
  let instCmds ← mkInstanceCmds ctx `Ord #[name]
  let cmds := #[enumFun] ++ instCmds
  trace[Elab.Deriving.ord] "\n{cmds}"
  return cmds
open Command
/--
Derives and elaborates the `Ord` instance for a single declaration. Enumeration types
(as reported by `isEnumType`) go through `mkOrdEnumCmd`; every other type goes through
`mkOrdInstanceCmds`. Everything runs inside `withoutExposeFromCtors declName`.
-/
def mkOrdInstance (declName : Name) : CommandElabM Unit := do
  withoutExposeFromCtors declName do
    let isEnum ← isEnumType declName
    -- Pick the command generator first, then run it in `TermElabM`.
    let build : TermElabM (Array Syntax) :=
      if isEnum then mkOrdEnumCmd declName else mkOrdInstanceCmds declName
    let cmds ← liftTermElabM build
    cmds.forM elabCommand
/--
Deriving handler for `Ord`.

If every name in `declNames` is an inductive type, derives an `Ord` instance for each of
them via `mkOrdInstance` and returns `true`; otherwise does nothing and returns `false`.
-/
def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
  if (← declNames.allM isInductive) then
    for declName in declNames do
      -- Delegate entirely to `mkOrdInstance`: it already wraps the work in
      -- `withoutExposeFromCtors`, selects the enumeration fast path, and elaborates the
      -- generated commands. Running the inline derivation here as well would elaborate
      -- the same commands a second time.
      mkOrdInstance declName
    return true
  else
    return false

View file

@ -19,7 +19,7 @@ inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L
| M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
deriving Ord
structure Person :=
structure Person where
firstName : String
lastName : String
age : Nat
@ -27,7 +27,7 @@ deriving Ord
-- Checks by `rfl` that `compare` (presumably from the `Ord` instance derived for `Person`) orders these two values as `Ordering.lt`.
example : compare { firstName := "A", lastName := "B", age := 10 : Person } ⟨"B", "A", 9⟩ = Ordering.lt := rfl
structure Company :=
structure Company where
name : String
ceo : Person
numberOfEmployees : Nat