/-
Copyright (c) 2019 Paul-Nicolas Madelaine. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul-Nicolas Madelaine, Robert Y. Lewis, Mario Carneiro, Gabriel Ebner
-/
|
|
prelude
|
|
import Lean.Meta.CongrTheorems
|
|
import Lean.Meta.Tactic.Simp.Attr
|
|
import Lean.Meta.CoeAttr
|
|
|
|
namespace Lean.Meta.NormCast
|
|
/--
`Label` records which of the three kinds of `norm_cast` lemma a declaration is.
The kind is determined by where coercions appear on each side of the lemma:

* `elim`: LHS has 0 head coes and ≥ 1 internal coe
* `move`: LHS has 1 head coe and 0 internal coes, RHS has 0 head coes and ≥ 1 internal coes
* `squash`: LHS has ≥ 1 head coes and 0 internal coes, RHS has fewer head coes
-/
inductive Label
  /-- An elim lemma: the LHS has 0 head coes and ≥ 1 internal coe. -/
  | elim
  /-- A move lemma: the LHS has 1 head coe and 0 internal coes,
  while the RHS has 0 head coes and ≥ 1 internal coes. -/
  | move
  /-- A squash lemma: the LHS has ≥ 1 head coes and 0 internal coes,
  and the RHS has fewer head coes than the LHS. -/
  | squash
  deriving DecidableEq, Repr, Inhabited
|
|
|
|
/--
Assuming `e` is an application, returns the arguments of `e` that `simp` will rewrite in.

If no congruence theorem is produced for the head function of `e`, every argument is returned.
Otherwise only the arguments whose congruence argument kind is `.eq` are kept.
-/
def getSimpArgs (e : Expr) : MetaM (Array Expr) := do
  let allArgs := e.getAppArgs
  -- Without a congruence theorem there is nothing to filter on.
  let some congrThm ← mkCongrSimp? e.getAppFn
    | return allArgs
  let mut rewritable : Array Expr := #[]
  -- Walk arguments and their kinds in lockstep, keeping the `.eq` ones.
  for arg in allArgs, kind in congrThm.argKinds do
    if kind matches .eq then
      rewritable := rewritable.push arg
  return rewritable
|
|
|
|
/--
Counts how many coercions are at the head of the expression.

The head of `e` counts as a coercion when it is a constant for which `getCoeFnInfo?` returns
information and `e` supplies at least `info.numArgs` arguments. In that case the count is one
more than the head-coercion count of the coerced argument (`info.coercee`); otherwise it is `0`.
-/
partial def countHeadCoes (e : Expr) : MetaM Nat := do
  let Expr.const fn _ := e.getAppFn | return 0
  let some info ← getCoeFnInfo? fn | return 0
  -- An under-applied coercion function is not counted.
  if e.getAppNumArgs < info.numArgs then
    return 0
  let inner ← countHeadCoes (e.getArg! info.coercee)
  return inner + 1
|
|
|
|
/--
Counts how many coercions are inside the expression, including the head ones.

Leading lambdas are entered with `lambdaTelescope` first. If the resulting body is a
(sufficiently applied) coercion, the result is the head-coercion count of the coerced
argument plus one, plus the counts of every argument past the coercion function's own
`info.numArgs` arguments. Otherwise the counts of the arguments returned by `getSimpArgs`
are summed.
-/
partial def countCoes (e : Expr) : MetaM Nat :=
  lambdaTelescope e fun _ body => do
    if let Expr.const fn .. := body.getAppFn then
      if let some info ← getCoeFnInfo? fn then
        let arity := body.getAppNumArgs
        if arity >= info.numArgs then
          let mut total := (← countHeadCoes (body.getArg! info.coercee)) + 1
          -- Extra arguments applied after the coercion itself are counted recursively.
          for idx in [info.numArgs:arity] do
            total := total + (← countCoes (body.getArg! idx))
          return total
    -- Not a coercion application: add up the counts of the rewritable arguments.
    let mut total : Nat := 0
    for arg in (← getSimpArgs body) do
      total := total + (← countCoes arg)
    return total
|
|
|
|
/--
Counts how many coercions are inside the expression, excluding the head ones.

Computed as `countCoes e - countHeadCoes e` using (truncating) `Nat` subtraction.
-/
def countInternalCoes (e : Expr) : MetaM Nat := do
  let allCoes ← countCoes e
  let headCoes ← countHeadCoes e
  return allCoes - headCoes
|
|
|
|
/--
Classifies a declaration of type `ty` as a `norm_cast` rule.

The leading binders of `ty` are introduced with `forallTelescopeReducing`, and the body is
put in weak head normal form. It must then be an `Eq` or an `Iff`; its two sides are the
`lhs` and `rhs` examined below.

* If the `lhs` has no head coercion, the lemma is an `elim` lemma.
* If the `lhs` has exactly one head coercion, the `rhs` must have no head coercion; the lemma
  is a `squash` lemma when the `rhs` has no internal coercions and a `move` lemma otherwise.
* If the `lhs` has more than one head coercion, the lemma is a `squash` lemma provided the
  `rhs` has strictly fewer head coercions.

Throws an error if the statement is neither `=` nor `↔`, if the `lhs` contains no coercion
at all, or if the shape fits none of the cases above.
-/
def classifyType (ty : Expr) : MetaM Label :=
  forallTelescopeReducing ty fun _ ty => do
    let ty ← whnf ty
    let (lhs, rhs) ←
      if ty.isAppOfArity ``Eq 3 then pure (ty.getArg! 1, ty.getArg! 2)
      else if ty.isAppOfArity ``Iff 2 then pure (ty.getArg! 0, ty.getArg! 1)
      else throwError "norm_cast: lemma must be = or ↔, but is{indentExpr ty}"
    -- Every kind of `norm_cast` lemma needs a coercion somewhere in its left-hand side.
    let lhsCoes ← countCoes lhs
    if lhsCoes = 0 then
      throwError "norm_cast: badly shaped lemma, lhs must contain at least one coe{indentExpr lhs}"
    let lhsHeadCoes ← countHeadCoes lhs
    let rhsHeadCoes ← countHeadCoes rhs
    let rhsInternalCoes ← countInternalCoes rhs
    if lhsHeadCoes = 0 then
      return Label.elim
    else if lhsHeadCoes = 1 then do
      unless rhsHeadCoes = 0 do
        throwError "norm_cast: badly shaped lemma, rhs can't start with coe{indentExpr rhs}"
      -- With a single head coe on the left, coes surviving inside the rhs mean it was moved.
      if rhsInternalCoes = 0 then
        return Label.squash
      else
        return Label.move
    else if rhsHeadCoes < lhsHeadCoes then do
      return Label.squash
    else do
      throwError "\
        norm_cast: badly shaped squash lemma, \
        rhs must have fewer head coes than lhs{indentExpr ty}"
|
|
|
|
/-- The simp extension backing the `push_cast` simp attribute, registered under the name
`push_cast`. -/
builtin_initialize pushCastExt : SimpExtension ←
  registerSimpAttr `push_cast "\
    The `push_cast` simp attribute uses `norm_cast` lemmas \
    to move casts toward the leaf nodes of the expression."
|
|
|
|
/-- The data stored by the `norm_cast` attribute: three simp sets, `up`, `down` and `squash`. -/
structure NormCastExtension where
  /-- Simp set used to lift coercions to the top level. -/
  up : SimpExtension
  /-- Simp set used to push coercions to the leaves. -/
  down : SimpExtension
  /-- Simp set used to simplify transitive coercions. -/
  squash : SimpExtension
  deriving Inhabited
|
|
|
|
/-- The `norm_cast` extension data: one freshly created simp extension per field, each named
after this declaration with the field name appended. -/
builtin_initialize normCastExt : NormCastExtension ← do
  let up ← mkSimpExt (decl_name% ++ `up)
  let down ← mkSimpExt (decl_name% ++ `down)
  let squash ← mkSimpExt (decl_name% ++ `squash)
  return { up, down, squash }
|
|
|
|
/--
`addElim decl` adds `decl` as an `elim` lemma to be used by `norm_cast`.

The lemma is added, in its stated direction, to the `up` simp set only.
-/
def addElim (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  let upSet := normCastExt.up
  addSimpTheorem upSet decl (post := true) (inv := false) kind prio
|
|
|
|
/--
`addMove decl` adds `decl` as a `move` lemma to be used by `norm_cast`.

The lemma is added to `push_cast` and to the `down` simp set in its stated direction,
and to the `up` simp set in the inverted direction (`inv := true`).
-/
def addMove (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  -- All three registrations differ only in the target extension and the direction.
  let register (ext : SimpExtension) (inverted : Bool) : MetaM Unit :=
    addSimpTheorem ext decl (post := true) (inv := inverted) kind prio
  register pushCastExt false
  register normCastExt.up true
  register normCastExt.down false
|
|
|
|
/--
`addSquash decl` adds `decl` as a `squash` lemma to be used by `norm_cast`.

The lemma is added, in its stated direction, to `push_cast`, to the `squash` simp set and
to the `down` simp set, in that order.
-/
def addSquash (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  for ext in #[pushCastExt, normCastExt.squash, normCastExt.down] do
    addSimpTheorem ext decl (post := true) (inv := false) kind prio
|
|
|
|
/--
`addInfer decl` classifies the type of `decl` with `classifyType` and then registers it
with the matching function (`addElim`, `addMove` or `addSquash`).

* elim lemma: LHS has 0 head coes and ≥ 1 internal coe
* move lemma: LHS has 1 head coe and 0 internal coes, RHS has 0 head coes and ≥ 1 internal coes
* squash lemma: LHS has ≥ 1 head coes and 0 internal coes, RHS has fewer head coes
-/
def addInfer (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  let info ← getConstInfo decl
  let label ← classifyType info.type
  match label with
  | .elim => addElim decl kind prio
  | .move => addMove decl kind prio
  | .squash => addSquash decl kind prio
|
|
|
|
/-
Registers the builtin `norm_cast` attribute.

The attribute syntax carries an optional label and an optional priority. When the priority
is absent (or is not a numeric literal) the default priority is used. An explicit label
selects `addElim`, `addMove` or `addSquash` directly; with no label the kind is inferred
by `addInfer`.
-/
builtin_initialize registerBuiltinAttribute {
  name := `norm_cast
  descr := "attribute for norm_cast"
  add := fun decl stx kind => MetaM.run' do
    let `(attr| norm_cast $[$labelStx:normCastLabel]? $[$prioStx]?) := stx | unreachable!
    let prio := (prioStx.bind (·.1.isNatLit?)).getD (eval_prio default)
    match labelStx.bind (·.1.isStrLit?) with
    | "elim" => addElim decl kind prio
    | "move" => addMove decl kind prio
    | "squash" => addSquash decl kind prio
    | none => addInfer decl kind prio
    | _ => unreachable!
}
|
|
|
|
end Lean.Meta.NormCast
|