lean4-htt/src/Lean/Elab/Deriving/DecEq.lean
Jason Yuen 3770b3dcb8
chore: fix spelling errors (#13274)
This PR fixed typos:

```
pip install codespell --upgrade
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames --regex '[A-Z][a-z]*'
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames --regex "\b[a-z']*"
```
2026-04-04 07:34:34 +00:00

295 lines
13 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Data.Options
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Meta.NatTable
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.Constructions.CasesOnSameCtor
import Lean.Meta.SameCtorUtils
import Init.Data.Array.OfFn
namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta
-- Constructor count from which `deriving DecidableEq` switches from `mkMatchOld`
-- (one `match` alternative per pair of constructors, i.e. quadratic in size) to
-- `mkMatchNew` (linear in size). The comparison is `numCtors ≥ threshold`, made in `mkMatch`.
register_builtin_option deriving.decEq.linear_construction_threshold : Nat := {
  defValue := 10
  descr := "If the inductive data type has this many or more constructors, use a different \
    implementation for deciding equality that avoids the quadratic code size produced by the \
    default implementation.\n\n\
    The alternative construction compiles to less efficient code in some cases, so by default \
    it is only used for inductive types with 10 or more constructors." }
/--
Header for the auxiliary `DecidableEq` function of `indVal`: built by `mkHeader` with arity 2,
one target for each of the two values being compared.
-/
def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header :=
  mkHeader `DecidableEq 2 indVal
/--
Builds the right-hand side that decides equality of two values built with the *same*
constructor, given their fields pairwise.

Each list entry `(a, b, recField, isProof)` names the corresponding fields of the two values:
* if `isProof` is set (the field type is a proposition), `a = b` is established by `rfl` and
  substituted;
* otherwise `a = b` is decided with `if h : a = b`, and the negative branch refutes equality of
  the whole values with `injection`;
* if `recField` is `some auxFunName`, a local `inst := auxFunName a b` is bound first, so the
  function being defined serves as the decision procedure for that field.

Once every field has been handled, the result is `isTrue rfl`.
Shared by `mkMatchOld` and `mkMatchNew`.
-/
def mkSameCtorDecEqRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
  | [] => ``(isTrue rfl)
  | (a, b, recField, isProof) :: todo => withFreshMacroScope do
    let rhs ← if isProof then
      `(have h : @$a = @$b := rfl; by subst h; exact $(← mkSameCtorDecEqRhs todo):term)
    else
      let sameCtor ← mkSameCtorDecEqRhs todo
      `(if h : @$a = @$b then
          by subst h; exact $sameCtor:term
        else
          isFalse (by intro n; injection n; apply h _; assumption))
    if let some auxFunName := recField then
      -- add local instance for `a = b` using the function being defined `auxFunName`
      `(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
    else
      return rhs

/--
Default (quadratic) construction of the `DecidableEq` body: a single `match` on both targets
with one alternative for each pair of constructors `(c₁, c₂)`.
* `c₁ = c₂`: the fields are compared with `mkSameCtorDecEqRhs`;
* `c₁ ≠ c₂` and `compatibleCtors c₁ c₂`: the alternative is `isFalse` (by `injection`);
* other pairs get no alternative. Presumably they cannot occur together; see `compatibleCtors`.
-/
def mkMatchOld (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  let discrs ← mkDiscrs header indVal
  let alts ← mkAlts
  `(match $[$discrs],* with $alts:matchAlt*)
where
  /-- One `match` alternative per (compatible) pair of constructors of `indVal`. -/
  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
    let mut alts := #[]
    for ctorName₁ in indVal.ctors do
      let ctorInfo ← getConstInfoCtor ctorName₁
      for ctorName₂ in indVal.ctors do
        let mut patterns := #[]
        -- add `_` pattern for indices
        for _ in *...indVal.numIndices do
          patterns := patterns.push (← `(_))
        if ctorName₁ == ctorName₂ then
          let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
            let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
            let mut patterns := patterns
            let mut ctorArgs1 := #[]
            let mut ctorArgs2 := #[]
            -- add `_` for inductive parameters, they are inaccessible
            for _ in *...indVal.numParams do
              ctorArgs1 := ctorArgs1.push (← `(_))
              ctorArgs2 := ctorArgs2.push (← `(_))
            let mut todo := #[]
            for i in *...ctorInfo.numFields do
              let x := xs[indVal.numParams + i]!
              if occursOrInType (← getLCtx) x type then
                -- If the resulting type depends on this field, we don't need to compare it.
                -- Instead, the second pattern uses an inaccessible pattern `.(a)`. Inaccessible
                -- patterns fail during pattern-match compilation if their equality does not
                -- actually follow from the equality between the types.
                let a := mkIdent (← mkFreshUserName `a)
                ctorArgs1 := ctorArgs1.push a
                ctorArgs2 := ctorArgs2.push (← `(term|.( $a:ident )))
              else
                let a := mkIdent (← mkFreshUserName `a)
                let b := mkIdent (← mkFreshUserName `b)
                ctorArgs1 := ctorArgs1.push a
                ctorArgs2 := ctorArgs2.push b
                let xType ← inferType x
                -- is the field's type one of the (mutual) types we are deriving for?
                let indValNum :=
                  ctx.typeInfos.findIdx?
                    (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
                let recField := indValNum.map (ctx.auxFunNames[·]!)
                let isProof ← isProp xType
                todo := todo.push (a, b, recField, isProof)
            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
            let rhs ← mkSameCtorDecEqRhs todo.toList
            `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
          alts := alts.push alt
        else if (← compatibleCtors ctorName₁ ctorName₂) then
          patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
          let rhs ← `(isFalse (by intro h; injection h))
          alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
    return alts

/--
Linear-size construction of the `DecidableEq` body.

When the type has more than one constructor, it first decides equality of the constructor
indices (`ctorIdx`) of the two targets. If they agree, it dispatches to a freshly generated
matcher (`mkCasesOnSameCtor`) that takes one alternative per constructor. For single-constructor
types the index test is skipped and `rfl` is passed instead.

Each alternative binds the fields of both values and compares them with `mkSameCtorDecEqRhs`.
Fields the constructor's result type depends on appear only once, as `_`. An alternative that
would bind nothing takes a unit argument `()`.
-/
def mkMatchNew (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  assert! header.targetNames.size == 2
  let x1 := mkIdent header.targetNames[0]!
  let x2 := mkIdent header.targetNames[1]!
  let ctorIdxName := mkCtorIdxName indVal.name
  -- NB: the getMatcherInfo? assumes all matchers are called `match_`
  let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor)
  mkCasesOnSameCtor casesOnSameCtorName indVal.name
  let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do
    let ctorName := indVal.ctors[ctorIdx]!
    let ctorInfo ← getConstInfoCtor ctorName
    forallTelescopeReducing ctorInfo.type fun xs type => do
      let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
      let mut ctorArgs1 : Array Term := #[]
      let mut ctorArgs2 : Array Term := #[]
      let mut todo := #[]
      for i in *...ctorInfo.numFields do
        let x := xs[indVal.numParams + i]!
        if type.containsFVar x.fvarId! then
          -- If resulting type depends on this field, we don't need to bring it into
          -- scope nor compare it
          ctorArgs1 := ctorArgs1.push (← `(_))
        else
          let a := mkIdent (← mkFreshUserName `a)
          let b := mkIdent (← mkFreshUserName `b)
          ctorArgs1 := ctorArgs1.push a
          ctorArgs2 := ctorArgs2.push b
          let xType ← inferType x
          -- is the field's type one of the (mutual) types we are deriving for?
          let indValNum :=
            ctx.typeInfos.findIdx?
              (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
          let recField := indValNum.map (ctx.auxFunNames[·]!)
          let isProof ← isProp xType
          todo := todo.push (a, b, recField, isProof)
      if ctorArgs1.isEmpty then
        -- Unit thunking argument
        ctorArgs1 := ctorArgs1.push (← `(()))
      let rhs ← mkSameCtorDecEqRhs todo.toList
      `(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
  if indVal.numCtors == 1 then
    `( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
  else
    `( match decEq ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with
      | .isTrue h => $(mkCIdent casesOnSameCtorName) $x1:term $x2:term h $alts:term*
      | .isFalse h => isFalse (fun h' => h (congrArg $(mkCIdent ctorIdxName) h')))
/--
Builds the body of the auxiliary `DecidableEq` function. It uses the quadratic `mkMatchOld`
construction below the `deriving.decEq.linear_construction_threshold` constructor count, and
the linear `mkMatchNew` construction at or above it.
-/
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  let threshold := deriving.decEq.linear_construction_threshold.get (← getOptions)
  if indVal.numCtors < threshold then
    mkMatchOld ctx header indVal
  else
    mkMatchNew ctx header indVal
/--
Builds the command `def auxFunName binders : Decidable (x₁ = x₂) := body`.

Here `x₁`/`x₂` are the two targets of the `DecidableEq` header and `body` comes from `mkMatch`.
For recursive inductive types the definition carries `termination_by structural x₁`.
-/
def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal) : TermElabM (TSyntax `command) := do
  let header ← mkDecEqHeader indVal
  let body ← mkMatch ctx header indVal
  let params := header.binders
  let lhs := mkIdent header.targetNames[0]!
  let rhs := mkIdent header.targetNames[1]!
  let type ← `(Decidable ($lhs = $rhs))
  let termSuffix ←
    if indVal.isRec then
      `(Parser.Termination.suffix|termination_by structural $lhs)
    else
      `(Parser.Termination.suffix|)
  `(def $(mkIdent auxFunName):ident $params:bracketedBinder* : $type:term := $body:term
    $termSuffix:suffix)
/--
Wraps the auxiliary functions into a single `mutual ... end` command. There is one function per
entry of `ctx.auxFunNames`, paired by index with `ctx.typeInfos`.
-/
def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
  let defs ← ctx.auxFunNames.mapIdxM fun i auxFunName =>
    mkAuxFunction ctx auxFunName ctx.typeInfos[i]!
  `(command| mutual $[$defs:command]* end)
/--
All commands deriving `DecidableEq` for `indVal`: the `mutual` block of auxiliary functions,
followed by the instance commands produced by `mkInstanceCmds`.
-/
def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
  let ctx ← mkContext ``DecidableEq "decEq" indVal.name
  let auxDefs ← mkAuxFunctions ctx
  let instCmds ← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false)
  let cmds := #[auxDefs] ++ instCmds
  trace[Elab.Deriving.decEq] "\n{cmds}"
  return cmds
open Command
/--
Derives `DecidableEq` for a (non-enumeration) inductive type `declName` and elaborates the
generated commands. Returns `false`, without generating anything, for nested inductive types.
-/
def mkDecEq (declName : Name) : CommandElabM Bool := do
  let indVal ← getConstInfoInduct declName
  if indVal.isNested then
    return false -- nested inductive types are not supported yet
  let cmds ← liftTermElabM <| mkDecEqCmds indVal
  -- `cmds` can have a number of syntax nodes quadratic in the number of constructors
  -- and thus create as many info tree nodes, which we never make use of but which can
  -- significantly slow down e.g. the unused variables linter; avoid creating them
  withEnableInfoTree false <| cmds.forM elabCommand
  return true
/--
Adds `declName.ofNat : Nat → declName` for the enumeration type `declName`.

The body is the lookup table `mkNatLookupTable` builds over the constructors, in declaration
order. It is added as an `abbrev`-reducible, safe definition.

This is not recursive, so it is a plain `def`: `partial` would only make it opaque.
-/
def mkEnumOfNat (declName : Name) : MetaM Unit := do
  let indVal ← getConstInfoInduct declName
  let levels := indVal.levelParams.map Level.param
  let enumType := mkConst declName levels
  let ctors := indVal.ctors.toArray.map (mkConst · levels)
  withLocalDeclD `n (mkConst ``Nat) fun n => do
    let value ← mkNatLookupTable n enumType ctors
    let value ← mkLambdaFVars #[n] value
    let type ← mkArrow (mkConst ``Nat) enumType
    addAndCompile <| Declaration.defnDecl {
      name := Name.mkStr declName "ofNat"
      levelParams := indVal.levelParams
      safety := DefinitionSafety.safe
      hints := ReducibilityHints.abbrev
      value, type
    }
/--
Adds the theorem `declName.ofNat_ctorIdx : ∀ x, declName.ofNat x.ctorIdx = x`.

It is proved by `casesOn` on `x`, with `Eq.refl` as the minor premise for every constructor.
-/
def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
  let indVal ← getConstInfoInduct declName
  let levels := indVal.levelParams.map Level.param
  let enumType := mkConst declName levels
  let ctorIdxFn := mkConst (mkCtorIdxName declName) levels
  let ofNatFn := mkConst (Name.mkStr declName "ofNat") levels
  let u ← getLevel enumType
  let eqEnum := mkApp (mkConst ``Eq [u]) enumType
  let rflEnum := mkApp (mkConst ``Eq.refl [u]) enumType
  withLocalDeclD `x enumType fun x => do
    let resultType := mkApp2 eqEnum (mkApp ofNatFn (mkApp ctorIdxFn x)) x
    let motive ← mkLambdaFVars #[x] resultType
    -- the motive lands in `Prop`, hence universe `0` for `casesOn`
    let casesOn := mkConst (mkCasesOnName declName) (Level.zero :: levels)
    let minors := (indVal.ctors.map fun ctor => mkApp rflEnum (mkConst ctor levels)).toArray
    let value ← mkLambdaFVars #[x] (mkAppN (mkApp2 casesOn motive x) minors)
    let type ← mkForallFVars #[x] resultType
    addAndCompile <| Declaration.thmDecl {
      name := Name.mkStr declName "ofNat_ctorIdx"
      levelParams := indVal.levelParams
      value, type
    }
/--
Derives `DecidableEq` for an enumeration type `declName`.

It first adds `declName.ofNat` (`mkEnumOfNat`) and the theorem `declName.ofNat_ctorIdx`
(`mkEnumOfNatThm`). It then elaborates an instance that compares the `ctorIdx` of both values.
When the indices agree, it maps that equality through `ofNat` and rewrites with `ofNat_ctorIdx`
to obtain `x = y`.
-/
def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
  let cmd ← liftTermElabM do
    mkEnumOfNat declName
    mkEnumOfNatThm declName
    let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
    let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_ctorIdx")
    `(instance : DecidableEq $(mkCIdent declName) :=
        fun x y =>
          if h : x.ctorIdx = y.ctorIdx then
            -- We fall back to `rfl` in the following proof because the first script fails for
            -- unit-like datatypes due to etaStruct.
            isTrue (by first | have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption | rfl)
          else
            isFalse fun h => by subst h; contradiction)
  trace[Elab.Deriving.decEq] "\n{cmd}"
  elabCommand cmd
/--
Derives `DecidableEq` for a single type. It uses the specialised `mkDecEqEnum` for enumeration
types (always succeeding) and `mkDecEq` otherwise.
-/
def mkDecEqInstance (declName : Name) : CommandElabM Bool :=
  withoutExposeFromCtors declName do
    if !(← isEnumType declName) then
      mkDecEq declName
    else
      mkDecEqEnum declName
      return true
/--
`deriving DecidableEq` handler. It derives an instance for each name in order and stops at the
first one that is not supported. Returns `true` only if every instance was derived.
-/
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool :=
  declNames.allM mkDecEqInstance
-- Registers `mkDecEqInstanceHandler` as the handler for `deriving DecidableEq`. Also registers
-- the `Elab.Deriving.decEq` trace class used by the `trace[Elab.Deriving.decEq]` calls above.
builtin_initialize
  registerDerivingHandler `DecidableEq mkDecEqInstanceHandler
  registerTraceClass `Elab.Deriving.decEq
end Lean.Elab.Deriving.DecEq