feat: add Hashable deriving

add support for the `Hashable` deriving by combining structural
hashes over fields
This commit is contained in:
Daniel Fabian 2021-03-30 00:36:35 +00:00 committed by Leonardo de Moura
parent 687ce2fe67
commit fee3390dd1
4 changed files with 206 additions and 0 deletions

View file

@ -39,3 +39,8 @@ instance : Hashable UInt64 where
instance : Hashable USize where
hash n := n
instance : Hashable Int where
hash
| Int.ofNat n => USize.ofNat (2 * n)
| Int.negSucc n => USize.ofNat (2 * n + 1)

View file

@ -11,3 +11,4 @@ import Lean.Elab.Deriving.DecEq
import Lean.Elab.Deriving.Repr
import Lean.Elab.Deriving.FromToJson
import Lean.Elab.Deriving.SizeOf
import Lean.Elab.Deriving.Hashable

View file

@ -0,0 +1,111 @@
/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Dany Fabian
-/
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.Hashable
open Command
open Lean.Parser.Term
open Meta
/--
Computes at least the first `n` primes. Usually slightly more
-/
private def firstNPrimes (n : Nat) : Array Nat := do
if n ≥ 6 then
let n := Float.ofInt n
-- for n ≥ 6, n log n + n log log n is an upper bound for the n-th prime
let upperBound := n * (Float.log n + (Float.log $ Float.log n))
let nRoot := Float.sqrt upperBound |> Float.toUInt32 |> UInt32.toNat
let upperBound := upperBound |> Float.toUInt32 |> UInt32.toNat
let mut primes := #[false, false] ++ mkArray (upperBound - 1) true
for p in [2:nRoot] do
if primes[p] then
for i in [p*p:upperBound+1:p] do
primes := primes.set! i false
return primes.mapIdx (λ i v => if v then some i.1 else none) |> Array.filterMap id
else
return #[2,3,5,7,11,13]
def mkHashableHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
mkHeader ctx `Hashable 1 indVal
def mkMatch (offset : Nat) (primes : Array Nat) (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
where
mkAlts : TermElabM (Array Syntax) := do
let mut alts := #[]
let mut ctorIdx := 0
for ctorName in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctorName
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := #[]
-- add `_` pattern for indices
for i in [:indVal.numIndices] do
patterns := patterns.push (← `(_))
let mut ctorArgs := #[]
let mut rhs ← `($(quote primes[offset + ctorIdx]))
-- add `_` for inductive parameters, they are inaccessible
for i in [:indVal.numParams] do
ctorArgs := ctorArgs.push (← `(_))
for i in [:ctorInfo.numFields] do
let x := xs[indVal.numParams + i]
let xTy ← inferType x
let typeName := xTy.getAppFn.constName!
if indVal.all.contains typeName then
-- If the value depends of any of the mutually recursive types, ignore it → add `_`.
-- We want hash computation to be O(1).
ctorArgs := ctorArgs.push (← `(_))
else
let a := mkIdent (← mkFreshUserName `a)
ctorArgs := ctorArgs.push a
rhs ← `(mixHash $rhs (hash $a:ident))
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
alts := alts.push alt
ctorIdx := ctorIdx + 1
return alts
def mkAuxFunction (offset : Nat) (primes : Array Nat) (ctx : Context) (i : Nat) : TermElabM Syntax := do
let auxFunName ← ctx.auxFunNames[i]
let indVal ← ctx.typeInfos[i]
let header ← mkHashableHeader ctx indVal
let body ← mkMatch offset primes ctx header indVal auxFunName
let binders := header.binders
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : USize := $body:term)
def mkHashFuncs (ctx : Context) : TermElabM (Array Syntax) := do
let nCtors := ctx.typeInfos.map (·.ctors.length)
let primes := nCtors.foldl (· + ·) 0 |> firstNPrimes
let mut auxDefs := #[]
let mut offset := 0
for i in [:ctx.typeInfos.size] do
auxDefs := auxDefs.push (← mkAuxFunction offset primes ctx i)
offset := nCtors[i]
auxDefs
private def mkHashableInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "hash" declNames[0]
let cmds := (← mkHashFuncs ctx) ++ (← mkInstanceCmds ctx `Hashable declNames)
trace[Elab.Deriving.hashable] "\n{cmds}"
return cmds
def mkHashableHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkHashableInstanceCmds declNames
cmds.forM elabCommand
return true
else
return false
builtin_initialize
registerBuiltinDerivingHandler ``Hashable mkHashableHandler
registerTraceClass `Elab.Deriving.hashable

View file

@ -0,0 +1,89 @@
set_option trace.Elab.Deriving.hashable true
inductive SimpleInd
| A
| B
deriving Hashable
theorem «inductive fields have different base hashes» : ∀ x, hash x =
match x with
| SimpleInd.A => 2
| SimpleInd.B => 3 := λ x => rfl
mutual
inductive Foo : Type → Type
| A : Int → Foo Prop → String → Foo Int
| B : Bar → Foo String
deriving Hashable
inductive Bar
| C
| D : Foo String → Bar
deriving Hashable
end
theorem «mutually recursive types don't hash recursively» : ∀ x y, (hash x =
match x with
| Foo.A a _ b => mixHash (mixHash 2 (hash a)) (hash b)
| Foo.B _ => 3) ∧ (hash y =
match y with
| Bar.C => 5
| Bar.D _ => 7) := λ x y => ⟨rfl, rfl⟩
inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L
| M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
deriving Hashable
theorem «Each constructor is hashed as a different prime to make mixing better» : ∀ x, hash x =
match x with
| ManyConstructors.A => 2
| ManyConstructors.B => 3
| ManyConstructors.C => 5
| ManyConstructors.D => 7
| ManyConstructors.E => 11
| ManyConstructors.F => 13
| ManyConstructors.G => 17
| ManyConstructors.H => 19
| ManyConstructors.I => 23
| ManyConstructors.J => 29
| ManyConstructors.K => 31
| ManyConstructors.L => 37
| ManyConstructors.M => 41
| ManyConstructors.N => 43
| ManyConstructors.O => 47
| ManyConstructors.P => 53
| ManyConstructors.Q => 59
| ManyConstructors.R => 61
| ManyConstructors.S => 67
| ManyConstructors.T => 71
| ManyConstructors.U => 73
| ManyConstructors.V => 79
| ManyConstructors.W => 83
| ManyConstructors.X => 89
| ManyConstructors.Y => 97
| ManyConstructors.Z => 101 := λ x => rfl
structure Person :=
FirstName : String
LastName : String
Age : Nat
deriving Hashable
structure Company :=
Name : String
CEO : Person
NumberOfEmployees : Nat
deriving Hashable
-- structures hash just fine
#eval hash {
Name := "Microsoft"
CEO := { FirstName := "Satya", LastName := "Nadella", Age := 53 }
NumberOfEmployees := 165000 : Company }
-- 10875484723257753924
-- syntax(name := tst) "tst" : command
-- @[commandElab «tst»] def elab_tst : CommandElab := fun stx => do
-- let declNames := #[`Foo, `Bar]
-- let declNames := #[`Foo]
-- discard $ mkHashableHandler declNames
-- pure ()