feat: linear-size Ord instance (#10270)
This PR adds an alternative implementation of `Deriving Ord` based on comparing `.ctorIdx` and using a dedicated matcher for comparing same constructors (added in #10152). The new option `deriving.ord.linear_construction_threshold` sets the constructor count threshold (10 by default) for using the new construction. It also (unconditionally) changes the implementation for enumeration types to simply compare the `ctorIdx`.
This commit is contained in:
parent
ae8dc414c3
commit
38b4062edb
3 changed files with 90 additions and 9 deletions
|
|
@ -10,6 +10,7 @@ prelude
|
|||
|
||||
public import Init.Data.Repr
|
||||
public import Init.Data.Ord
|
||||
import Init.Data.Nat.Compare
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,21 @@ Authors: Dany Fabian
|
|||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.Transform
|
||||
public import Lean.Elab.Deriving.Basic
|
||||
public import Lean.Elab.Deriving.Util
|
||||
public import Lean.Data.Options
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
import Lean.Meta.Constructions.CtorIdx
|
||||
import Lean.Meta.Constructions.CasesOnSameCtor
|
||||
import Lean.Meta.SameCtorUtils
|
||||
|
||||
public section
|
||||
register_builtin_option deriving.ord.linear_construction_threshold : Nat := {
|
||||
defValue := 10
|
||||
descr := "If the inductive data type has this many or more constructors, use a different \
|
||||
implementation for implementing `Ord` that avoids the quadratic code size produced by the \
|
||||
default implementation.\n\n\
|
||||
The alternative construction compiles to less efficient code in some cases, so by default \
|
||||
it is only used for inductive types with 10 or more constructors." }
|
||||
|
||||
namespace Lean.Elab.Deriving.Ord
|
||||
open Lean.Parser.Term
|
||||
|
|
@ -20,7 +29,7 @@ open Meta
|
|||
def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do
|
||||
mkHeader `Ord 2 indVal
|
||||
|
||||
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
|
||||
def mkMatchOld (header : Header) (indVal : InductiveVal) : TermElabM Term := do
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with $alts:matchAlt*)
|
||||
|
|
@ -74,6 +83,59 @@ where
|
|||
alts := alts ++ (alt : Array (TSyntax ``matchAlt))
|
||||
return alts.pop.pop
|
||||
|
||||
def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do
|
||||
assert! header.targetNames.size == 2
|
||||
|
||||
let x1 := mkIdent header.targetNames[0]!
|
||||
let x2 := mkIdent header.targetNames[1]!
|
||||
let ctorIdxName := mkCtorIdxName indVal.name
|
||||
-- NB: the getMatcherInfo? assumes all mathcers are called `match_`
|
||||
let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor)
|
||||
mkCasesOnSameCtor casesOnSameCtorName indVal.name
|
||||
let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do
|
||||
let ctorName := indVal.ctors[ctorIdx]!
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
|
||||
let mut ctorArgs1 : Array Term := #[]
|
||||
let mut ctorArgs2 : Array Term := #[]
|
||||
|
||||
let mut rhsCont : Term → TermElabM Term := fun rhs => pure rhs
|
||||
for i in *...ctorInfo.numFields do
|
||||
let x := xs[indVal.numParams + i]!
|
||||
if occursOrInType (← getLCtx) x type then
|
||||
-- If resulting type depends on this field, we don't need to compare
|
||||
-- and the casesOnSameCtor only has a parameter for it once
|
||||
ctorArgs1 := ctorArgs1.push (← `(_))
|
||||
else
|
||||
let userName ← x.fvarId!.getUserName
|
||||
let a := mkIdent (← mkFreshUserName userName)
|
||||
let b := mkIdent (← mkFreshUserName (userName.appendAfter "'"))
|
||||
ctorArgs1 := ctorArgs1.push a
|
||||
ctorArgs2 := ctorArgs2.push b
|
||||
let xType ← inferType x
|
||||
if (← isProp xType) then
|
||||
continue
|
||||
else
|
||||
rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont
|
||||
let rhs ← rhsCont (← `(Ordering.eq))
|
||||
`(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
|
||||
if indVal.numCtors == 1 then
|
||||
`( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
|
||||
else
|
||||
`( match h : compare ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with
|
||||
| Ordering.lt => Ordering.lt
|
||||
| Ordering.gt => Ordering.gt
|
||||
| Ordering.eq =>
|
||||
$(mkCIdent casesOnSameCtorName) $x1:term $x2:term (Nat.compare_eq_eq.mp h) $alts:term*
|
||||
)
|
||||
|
||||
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
|
||||
if indVal.numCtors ≥ deriving.ord.linear_construction_threshold.get (← getOptions) then
|
||||
mkMatchNew header indVal
|
||||
else
|
||||
mkMatchOld header indVal
|
||||
|
||||
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
|
||||
let auxFunName := ctx.auxFunNames[i]!
|
||||
let indVal := ctx.typeInfos[i]!
|
||||
|
|
@ -105,13 +167,31 @@ private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
|
|||
trace[Elab.Deriving.ord] "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
private def mkOrdEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
|
||||
let auxFunName := ctx.auxFunNames[0]!
|
||||
`(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Ordering := compare x.ctorIdx y.ctorIdx)
|
||||
|
||||
private def mkOrdEnumCmd (name : Name): TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext ``Ord "ord" name
|
||||
let cmds := #[← mkOrdEnumFun ctx name] ++ (← mkInstanceCmds ctx `Ord #[name])
|
||||
trace[Elab.Deriving.ord] "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
open Command
|
||||
|
||||
def mkOrdInstance (declName : Name) : CommandElabM Unit := do
|
||||
withoutExposeFromCtors declName do
|
||||
let cmds ← liftTermElabM <|
|
||||
if (← isEnumType declName) then
|
||||
mkOrdEnumCmd declName
|
||||
else
|
||||
mkOrdInstanceCmds declName
|
||||
cmds.forM elabCommand
|
||||
|
||||
def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if (← declNames.allM isInductive) then
|
||||
for declName in declNames do
|
||||
let cmds ← withoutExposeFromCtors declName <| liftTermElabM <| mkOrdInstanceCmds declName
|
||||
cmds.forM elabCommand
|
||||
mkOrdInstance declName
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L
|
|||
| M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
|
||||
deriving Ord
|
||||
|
||||
structure Person :=
|
||||
structure Person where
|
||||
firstName : String
|
||||
lastName : String
|
||||
age : Nat
|
||||
|
|
@ -27,7 +27,7 @@ deriving Ord
|
|||
|
||||
example : compare { firstName := "A", lastName := "B", age := 10 : Person } ⟨"B", "A", 9⟩ = Ordering.lt := rfl
|
||||
|
||||
structure Company :=
|
||||
structure Company where
|
||||
name : String
|
||||
ceo : Person
|
||||
numberOfEmployees : Nat
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue