feat: add Hashable deriving
add support for the `Hashable` deriving by combining structural hashes over fields
This commit is contained in:
parent
687ce2fe67
commit
fee3390dd1
4 changed files with 206 additions and 0 deletions
|
|
@ -39,3 +39,8 @@ instance : Hashable UInt64 where
|
|||
|
||||
instance : Hashable USize where
|
||||
hash n := n
|
||||
|
||||
instance : Hashable Int where
|
||||
hash
|
||||
| Int.ofNat n => USize.ofNat (2 * n)
|
||||
| Int.negSucc n => USize.ofNat (2 * n + 1)
|
||||
|
|
@ -11,3 +11,4 @@ import Lean.Elab.Deriving.DecEq
|
|||
import Lean.Elab.Deriving.Repr
|
||||
import Lean.Elab.Deriving.FromToJson
|
||||
import Lean.Elab.Deriving.SizeOf
|
||||
import Lean.Elab.Deriving.Hashable
|
||||
|
|
|
|||
111
src/Lean/Elab/Deriving/Hashable.lean
Normal file
111
src/Lean/Elab/Deriving/Hashable.lean
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
/-
|
||||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Dany Fabian
|
||||
-/
|
||||
import Lean.Meta.Inductive
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
|
||||
namespace Lean.Elab.Deriving.Hashable
|
||||
open Command
|
||||
open Lean.Parser.Term
|
||||
open Meta
|
||||
|
||||
/--
|
||||
Computes at least the first `n` primes. Usually slightly more
|
||||
-/
|
||||
private def firstNPrimes (n : Nat) : Array Nat := do
|
||||
if n ≥ 6 then
|
||||
let n := Float.ofInt n
|
||||
-- for n ≥ 6, n log n + n log log n is an upper bound for the n-th prime
|
||||
let upperBound := n * (Float.log n + (Float.log $ Float.log n))
|
||||
let nRoot := Float.sqrt upperBound |> Float.toUInt32 |> UInt32.toNat
|
||||
let upperBound := upperBound |> Float.toUInt32 |> UInt32.toNat
|
||||
let mut primes := #[false, false] ++ mkArray (upperBound - 1) true
|
||||
for p in [2:nRoot] do
|
||||
if primes[p] then
|
||||
for i in [p*p:upperBound+1:p] do
|
||||
primes := primes.set! i false
|
||||
return primes.mapIdx (λ i v => if v then some i.1 else none) |> Array.filterMap id
|
||||
else
|
||||
return #[2,3,5,7,11,13]
|
||||
|
||||
def mkHashableHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
mkHeader ctx `Hashable 1 indVal
|
||||
|
||||
def mkMatch (offset : Nat) (primes : Array Nat) (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with $alts:matchAlt*)
|
||||
where
|
||||
|
||||
mkAlts : TermElabM (Array Syntax) := do
|
||||
let mut alts := #[]
|
||||
let mut ctorIdx := 0
|
||||
for ctorName in indVal.ctors do
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.numIndices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
let mut ctorArgs := #[]
|
||||
let mut rhs ← `($(quote primes[offset + ctorIdx]))
|
||||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.numParams] do
|
||||
ctorArgs := ctorArgs.push (← `(_))
|
||||
for i in [:ctorInfo.numFields] do
|
||||
let x := xs[indVal.numParams + i]
|
||||
let xTy ← inferType x
|
||||
let typeName := xTy.getAppFn.constName!
|
||||
if indVal.all.contains typeName then
|
||||
-- If the value depends of any of the mutually recursive types, ignore it → add `_`.
|
||||
-- We want hash computation to be O(1).
|
||||
ctorArgs := ctorArgs.push (← `(_))
|
||||
else
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
ctorArgs := ctorArgs.push a
|
||||
rhs ← `(mixHash $rhs (hash $a:ident))
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
|
||||
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
|
||||
alts := alts.push alt
|
||||
ctorIdx := ctorIdx + 1
|
||||
return alts
|
||||
|
||||
def mkAuxFunction (offset : Nat) (primes : Array Nat) (ctx : Context) (i : Nat) : TermElabM Syntax := do
|
||||
let auxFunName ← ctx.auxFunNames[i]
|
||||
let indVal ← ctx.typeInfos[i]
|
||||
let header ← mkHashableHeader ctx indVal
|
||||
let body ← mkMatch offset primes ctx header indVal auxFunName
|
||||
let binders := header.binders
|
||||
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : USize := $body:term)
|
||||
|
||||
def mkHashFuncs (ctx : Context) : TermElabM (Array Syntax) := do
|
||||
let nCtors := ctx.typeInfos.map (·.ctors.length)
|
||||
let primes := nCtors.foldl (· + ·) 0 |> firstNPrimes
|
||||
let mut auxDefs := #[]
|
||||
let mut offset := 0
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
auxDefs := auxDefs.push (← mkAuxFunction offset primes ctx i)
|
||||
offset := nCtors[i]
|
||||
auxDefs
|
||||
|
||||
private def mkHashableInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext "hash" declNames[0]
|
||||
let cmds := (← mkHashFuncs ctx) ++ (← mkInstanceCmds ctx `Hashable declNames)
|
||||
trace[Elab.Deriving.hashable] "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
def mkHashableHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if (← declNames.allM isInductive) && declNames.size > 0 then
|
||||
let cmds ← liftTermElabM none <| mkHashableInstanceCmds declNames
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinDerivingHandler ``Hashable mkHashableHandler
|
||||
registerTraceClass `Elab.Deriving.hashable
|
||||
89
tests/playground/hashable.lean
Normal file
89
tests/playground/hashable.lean
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
set_option trace.Elab.Deriving.hashable true
|
||||
|
||||
inductive SimpleInd
|
||||
| A
|
||||
| B
|
||||
deriving Hashable
|
||||
|
||||
theorem «inductive fields have different base hashes» : ∀ x, hash x =
|
||||
match x with
|
||||
| SimpleInd.A => 2
|
||||
| SimpleInd.B => 3 := λ x => rfl
|
||||
|
||||
mutual
|
||||
inductive Foo : Type → Type
|
||||
| A : Int → Foo Prop → String → Foo Int
|
||||
| B : Bar → Foo String
|
||||
deriving Hashable
|
||||
inductive Bar
|
||||
| C
|
||||
| D : Foo String → Bar
|
||||
deriving Hashable
|
||||
end
|
||||
|
||||
theorem «mutually recursive types don't hash recursively» : ∀ x y, (hash x =
|
||||
match x with
|
||||
| Foo.A a _ b => mixHash (mixHash 2 (hash a)) (hash b)
|
||||
| Foo.B _ => 3) ∧ (hash y =
|
||||
match y with
|
||||
| Bar.C => 5
|
||||
| Bar.D _ => 7) := λ x y => ⟨rfl, rfl⟩
|
||||
|
||||
inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L
|
||||
| M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
|
||||
deriving Hashable
|
||||
|
||||
theorem «Each constructor is hashed as a different prime to make mixing better» : ∀ x, hash x =
|
||||
match x with
|
||||
| ManyConstructors.A => 2
|
||||
| ManyConstructors.B => 3
|
||||
| ManyConstructors.C => 5
|
||||
| ManyConstructors.D => 7
|
||||
| ManyConstructors.E => 11
|
||||
| ManyConstructors.F => 13
|
||||
| ManyConstructors.G => 17
|
||||
| ManyConstructors.H => 19
|
||||
| ManyConstructors.I => 23
|
||||
| ManyConstructors.J => 29
|
||||
| ManyConstructors.K => 31
|
||||
| ManyConstructors.L => 37
|
||||
| ManyConstructors.M => 41
|
||||
| ManyConstructors.N => 43
|
||||
| ManyConstructors.O => 47
|
||||
| ManyConstructors.P => 53
|
||||
| ManyConstructors.Q => 59
|
||||
| ManyConstructors.R => 61
|
||||
| ManyConstructors.S => 67
|
||||
| ManyConstructors.T => 71
|
||||
| ManyConstructors.U => 73
|
||||
| ManyConstructors.V => 79
|
||||
| ManyConstructors.W => 83
|
||||
| ManyConstructors.X => 89
|
||||
| ManyConstructors.Y => 97
|
||||
| ManyConstructors.Z => 101 := λ x => rfl
|
||||
|
||||
structure Person :=
|
||||
FirstName : String
|
||||
LastName : String
|
||||
Age : Nat
|
||||
deriving Hashable
|
||||
|
||||
structure Company :=
|
||||
Name : String
|
||||
CEO : Person
|
||||
NumberOfEmployees : Nat
|
||||
deriving Hashable
|
||||
|
||||
-- structures hash just fine
|
||||
#eval hash {
|
||||
Name := "Microsoft"
|
||||
CEO := { FirstName := "Satya", LastName := "Nadella", Age := 53 }
|
||||
NumberOfEmployees := 165000 : Company }
|
||||
-- 10875484723257753924
|
||||
|
||||
-- syntax(name := tst) "tst" : command
|
||||
-- @[commandElab «tst»] def elab_tst : CommandElab := fun stx => do
|
||||
-- let declNames := #[`Foo, `Bar]
|
||||
-- let declNames := #[`Foo]
|
||||
-- discard $ mkHashableHandler declNames
|
||||
-- pure ()
|
||||
Loading…
Add table
Reference in a new issue