lean4-htt/src/Lean/Meta/UnificationHint.lean
2021-08-01 09:58:44 +02:00

141 lines
5.3 KiB
Text

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.ScopedEnvExtension
import Lean.Util.Recognizers
import Lean.Meta.DiscrTree
import Lean.Meta.LevelDefEq
import Lean.Meta.SynthInstance
namespace Lean.Meta
/-- A single entry of the unification-hint database: the hint declaration `val`
together with the discrimination-tree key path `keys` it is stored under. -/
structure UnificationHintEntry where
  -- Key path; `addUnificationHint` computes it from the left-hand side of the hint's pattern.
  (keys : Array DiscrTree.Key)
  -- Name of the declaration tagged as a unification hint.
  (val : Name)
  deriving Inhabited
/-- State of the unification-hint extension: a discrimination tree mapping
key paths to the names of hint declarations. Starts out empty. -/
structure UnificationHints where
  (discrTree : DiscrTree Name := DiscrTree.empty)
  deriving Inhabited
/-- Pretty-prints the hint database by formatting its underlying discrimination tree. -/
instance : ToFormat UnificationHints :=
  ⟨fun hints => format hints.discrTree⟩
/-- Returns `hints` extended with entry `e`: the declaration name `e.val` is
inserted into the discrimination tree at the precomputed key path `e.keys`. -/
def UnificationHints.add (hints : UnificationHints) (e : UnificationHintEntry) : UnificationHints :=
  let tree := hints.discrTree.insertCore e.keys e.val
  -- `discrTree` is the only field, so building a fresh record is equivalent to `{ hints with .. }`.
  { discrTree := tree }
-- Scoped environment extension that stores every registered unification hint.
-- Entries are folded into the state with `UnificationHints.add`, starting from the
-- empty database; the extension is registered under the name `unifHints`.
builtin_initialize unificationHintExtension : SimpleScopedEnvExtension UnificationHintEntry UnificationHints ←
  registerSimpleScopedEnvExtension {
    initial  := {}
    name     := `unifHints
    addEntry := fun hints entry => UnificationHints.add hints entry
  }
/-- An equation `lhs =?= rhs` decoded from an equality `lhs = rhs` appearing
in the body of a unification-hint declaration. -/
structure UnificationConstraint where
  (lhs rhs : Expr)
/-- A decoded unification hint: the `pattern` equation that selects the hint,
plus the side `constraints` (in declaration order) to solve once it matches. -/
structure UnificationHint where
  (pattern : UnificationConstraint)
  (constraints : List UnificationConstraint)
/--
Decodes the body of a unification-hint declaration.

The body must be a chain of non-dependent arrows `c₁ → c₂ → … → cₙ → p` in which
every `cᵢ` and `p` is an equality `lhs = rhs`. The final equality `p` becomes the
hint's `pattern`; the `cᵢ` become its `constraints`, kept in the order they appear.
A non-equality, or an arrow whose codomain refers to its bound variable, is
reported as an error message.
-/
private partial def decodeUnificationHint (e : Expr) : ExceptT MessageData Id UnificationHint :=
  go e #[]
where
  -- Reads an equality `lhs = rhs` (recognized with `Expr.eq?`) and fails on anything else.
  toConstraint (e : Expr) : ExceptT MessageData Id UnificationConstraint :=
    match e.eq? with
    | none => throw m!"invalid unification hint constraint, unexpected term{indentExpr e}"
    | some (_, lhs, rhs) => pure (UnificationConstraint.mk lhs rhs)
  -- Walks the arrow chain, collecting side constraints in `acc` until the final pattern.
  go (e : Expr) (acc : Array UnificationConstraint) : ExceptT MessageData Id UnificationHint := do
    match e with
    | Expr.forallE _ dom body _ =>
      -- The domain is decoded before the dependency check, so a malformed
      -- domain is reported in preference to a dependent arrow.
      let c ← toConstraint dom
      if body.hasLooseBVars then
        throw m!"invalid unification hint constraint, unexpected dependency{indentExpr e}"
      go body (acc.push c)
    | _ =>
      let pattern ← toConstraint e
      pure { pattern := pattern, constraints := acc.toList }
/--
Sanity-checks a decoded unification hint.

First, every side constraint `lhs =?= rhs` must be definitionally equal. Then the
pattern `lhs =?= rhs` itself must be, in the metavariable context left by solving
the constraints. Throws an error identifying the first side that fails.

`declName` names the hint declaration; it is currently not used by the checks.
-/
-- Not `partial`: the function is not recursive, and `partial` would only make it opaque.
private def validateHint (declName : Name) (hint : UnificationHint) : MetaM Unit := do
  for c in hint.constraints do
    unless (← isDefEq c.lhs c.rhs) do
      throwError "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
  unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do
    throwError "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
/--
Registers the declaration `declName` as a unification hint with attribute kind `kind`.

The declaration must have a value. Its leading lambdas are opened into metavariables
with `lambdaMetaTelescope`, and the body is decoded with `decodeUnificationHint`. The
hint is keyed by the discrimination-tree path of its pattern's left-hand side, checked
with `validateHint`, and then added to `unificationHintExtension`.
Everything runs under `withNewMCtxDepth`. NOTE(review): this presumably keeps the
metavariables created here from interacting with the caller's; confirm against
`withNewMCtxDepth`.
-/
def addUnificationHint (declName : Name) (kind : AttributeKind) : MetaM Unit :=
  withNewMCtxDepth do
    let cinfo ← getConstInfo declName
    match cinfo.value? with
    | some value =>
      let (_, _, hintBody) ← lambdaMetaTelescope value
      match decodeUnificationHint hintBody with
      | Except.ok hint =>
        -- The key path is computed before validation, matching the original effect order.
        let path ← DiscrTree.mkPath hint.pattern.lhs
        validateHint declName hint
        unificationHintExtension.add { keys := path, val := declName } kind
        trace[Meta.debug] "addUnificationHint: {unificationHintExtension.getState (← getEnv)}"
      | Except.error msg => throwError msg
    | none => throwError "invalid unification hint, it must be a definition"
-- Registers the argument-free `@[unificationHint]` attribute. Tagging a declaration
-- runs `addUnificationHint` on it (as a `MetaM` computation) with the attribute kind
-- from the use site; the `MetaM` result and final state are discarded.
builtin_initialize
  registerBuiltinAttribute {
    name  := `unificationHint
    descr := "unification hint"
    add   := fun decl stx attrKind => do
      Attribute.Builtin.ensureNoArgs stx
      discard <| (addUnificationHint decl attrKind).run
  }
/--
Tries to solve the unification problem `t =?= s` with a registered unification hint.

Returns `false` without trying any hint when hints are disabled through the
`unificationHints` configuration flag or when `t` is a metavariable. Otherwise the
candidate hints whose discrimination-tree keys match `t` are tried one at a time, and
`true` is returned as soon as one succeeds.
-/
def tryUnificationHints (t s : Expr) : MetaM Bool := do
  -- Note: this trace is emitted before the configuration check below.
  trace[Meta.isDefEq.hint] "{t} =?= {s}"
  unless (← read).config.unificationHints do
    return false
  -- NOTE(review): presumably a metavariable `t` is rejected because it would not
  -- meaningfully select a hint from the discrimination tree; confirm.
  if t.isMVar then
    return false
  let hints := unificationHintExtension.getState (← getEnv)
  let candidates ← hints.discrTree.getMatch t
  -- First successful candidate wins.
  for candidate in candidates do
    if (← tryCandidate candidate) then
      return true
  return false
where
  -- Matches a pattern side against a term using reducible transparency only.
  isDefEqPattern p e :=
    withReducible <| Meta.isExprDefEqAux p e
  -- Attempts a single hint declaration `candidate`. Runs inside `checkpointDefEq`,
  -- which is expected to roll back metavariable assignments when this returns `false`
  -- (NOTE(review): confirm against `checkpointDefEq`).
  tryCandidate candidate : MetaM Bool :=
    traceCtx `Meta.isDefEq.hint <| checkpointDefEq do
      trace[Meta.isDefEq.hint] "trying hint {candidate} at {t} =?= {s}"
      let cinfo ← getConstInfo candidate
      -- Instantiate the hint's universe parameters with fresh level metavariables.
      let us ← cinfo.levelParams.mapM fun _ => mkFreshLevelMVar
      let val := cinfo.instantiateValueLevelParams us
      -- Open the hint's binders as metavariables `xs` with binder infos `bis`.
      let (xs, bis, body) ← lambdaMetaTelescope val
      -- Match the pattern with hints disabled, so matching cannot recurse into hints.
      -- A body that fails to decode is skipped rather than reported.
      let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
        match decodeUnificationHint body with
        | Except.error _ => return none
        | Except.ok hint =>
          -- `<&&>` short-circuits: the right-hand side is only tried if the left matched.
          if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
            return some hint
          else
            return none
      match hint? with
      | none => return false
      | some hint =>
        trace[Meta.isDefEq.hint] "{candidate} succeeded, applying constraints"
        -- Side constraints are solved with the ambient settings: outside `withReducible`
        -- and outside the `withConfig` above, so nested hints may apply here.
        for c in hint.constraints do
          unless (← Meta.isExprDefEqAux c.lhs c.rhs) do
            return false
        -- Instance-implicit binders of the hint must be solvable by type class
        -- synthesis, and the synthesized instance must unify with the binder's metavariable.
        for x in xs, bi in bis do
          if bi == BinderInfo.instImplicit then
            match (← trySynthInstance (← inferType x)) with
            | LOption.some val => unless (← isDefEq x val) do return false
            | _ => return false
        return true
-- Makes the `Meta.isDefEq.hint` trace class (used by `tryUnificationHints`) available.
builtin_initialize registerTraceClass `Meta.isDefEq.hint
end Lean.Meta