141 lines
5.3 KiB
Text
141 lines
5.3 KiB
Text
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
|
|
import Lean.ScopedEnvExtension
|
|
import Lean.Util.Recognizers
|
|
import Lean.Meta.DiscrTree
|
|
import Lean.Meta.LevelDefEq
|
|
import Lean.Meta.SynthInstance
|
|
|
|
namespace Lean.Meta
|
|
|
|
/-- A single entry recorded in the unification hint environment extension. -/
structure UnificationHintEntry where
  -- Discrimination tree path under which the hint is inserted.
  keys : Array DiscrTree.Key
  -- Value stored at that path; `addUnificationHint` stores the hint's declaration name here.
  val  : Name
  deriving Inhabited
|
|
|
|
/-- State of the unification hint extension: a discrimination tree whose values are names. -/
structure UnificationHints where
  -- Starts out as the empty tree.
  discrTree : DiscrTree Name := DiscrTree.empty
  deriving Inhabited
|
|
|
|
/-- Formats a `UnificationHints` value by formatting its discrimination tree. -/
instance : ToFormat UnificationHints :=
  ⟨fun h => format h.discrTree⟩
|
|
|
|
/-- Returns `hints` extended with entry `e`, inserting `e.val` into the tree under the path `e.keys`. -/
def UnificationHints.add (hints : UnificationHints) (e : UnificationHintEntry) : UnificationHints :=
  let updatedTree := hints.discrTree.insertCore e.keys e.val
  { hints with discrTree := updatedTree }
|
|
|
|
-- Scoped environment extension that stores the registered unification hints.
-- The state starts empty and each new entry is merged in with `UnificationHints.add`.
builtin_initialize unificationHintExtension : SimpleScopedEnvExtension UnificationHintEntry UnificationHints ←
  registerSimpleScopedEnvExtension {
    name     := `unifHints
    initial  := {}
    addEntry := UnificationHints.add
  }
|
|
|
|
/-- A pair of expressions `lhs`/`rhs`, obtained by decoding an equality in a hint body. -/
structure UnificationConstraint where
  lhs : Expr
  rhs : Expr
|
|
|
|
/-- Decoded form of a unification hint: the final `pattern` equality plus the constraints that precede it. -/
structure UnificationHint where
  -- The equality found at the end of the hint body.
  pattern     : UnificationConstraint
  -- The equalities found in the arrow domains leading to the pattern, in order of appearance.
  constraints : List UnificationConstraint
|
|
|
|
/--
Decodes the body of a unification hint declaration into a `UnificationHint`.

The expression is read as a chain of non-dependent `forallE` binders whose domains are
equalities (per `Expr.eq?`), ending in a final equality. The domains become the hint's
constraints and the final equality becomes its pattern.

Fails with a `MessageData` error when a component is not an equality, or when the body of a
`forallE` has loose bound variables (i.e. it depends on the binder).
-/
private partial def decodeUnificationHint (e : Expr) : ExceptT MessageData Id UnificationHint :=
  decode e #[]
where
  -- Turns an equality into a constraint; any other term is rejected.
  decodeConstraint (e : Expr) : ExceptT MessageData Id UnificationConstraint :=
    match e.eq? with
    | none => throw m!"invalid unification hint constraint, unexpected term{indentExpr e}"
    | some (_, lhs, rhs) => pure (UnificationConstraint.mk lhs rhs)
  -- Walks the binder chain, accumulating the constraints seen so far in `cs`.
  decode (e : Expr) (cs : Array UnificationConstraint) : ExceptT MessageData Id UnificationHint := do
    match e with
    | Expr.forallE _ domain body _ =>
      let constraint ← decodeConstraint domain
      if body.hasLooseBVars then
        throw m!"invalid unification hint constraint, unexpected dependency{indentExpr e}"
      else
        decode body (cs.push constraint)
    | _ =>
      -- Anything that is not a `forallE` must be the final pattern equality.
      let pat ← decodeConstraint e
      return { pattern := pat, constraints := cs.toList }
|
|
|
|
/--
Checks that a decoded hint is coherent: every constraint's two sides must unify (`isDefEq`),
and afterwards the pattern's two sides must unify as well. Throws an error describing the
first pair that fails.

The constraints are checked before the pattern, in list order; `isDefEq` may assign
metavariables, so this order is significant.

`declName` is currently unused; it is kept so existing call sites remain valid.
-/
private def validateHint (declName : Name) (hint : UnificationHint) : MetaM Unit := do
  for c in hint.constraints do
    unless (← isDefEq c.lhs c.rhs) do
      throwError "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
  unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do
    throwError "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
|
|
|
|
/--
Registers the declaration `declName` as a unification hint with attribute kind `kind`.

The declaration must have a value. Its lambda binders are opened with `lambdaMetaTelescope`,
the body is decoded with `decodeUnificationHint`, the hint is validated, and an entry keyed
by the pattern's left-hand side is added to `unificationHintExtension`.
Everything runs inside `withNewMCtxDepth`.
-/
def addUnificationHint (declName : Name) (kind : AttributeKind) : MetaM Unit :=
  withNewMCtxDepth do
    let constInfo ← getConstInfo declName
    match constInfo.value? with
    | some value =>
      let (_, _, hintBody) ← lambdaMetaTelescope value
      match decodeUnificationHint hintBody with
      | Except.ok hint =>
        -- The path is computed before validation runs.
        -- NOTE(review): presumably because validation may assign the pattern's metavariables — confirm.
        let path ← DiscrTree.mkPath hint.pattern.lhs
        validateHint declName hint
        unificationHintExtension.add { keys := path, val := declName } kind
        let env ← getEnv
        trace[Meta.debug] "addUnificationHint: {unificationHintExtension.getState env}"
      | Except.error err => throwError err
    | none => throwError "invalid unification hint, it must be a definition"
|
|
|
|
-- Registers the builtin `unificationHint` attribute. The attribute takes no arguments;
-- applying it runs `addUnificationHint` on the tagged declaration and discards the result
-- of running that `MetaM` action.
builtin_initialize
  registerBuiltinAttribute {
    name  := `unificationHint
    descr := "unification hint"
    add   := fun declName stx kind => do
      Attribute.Builtin.ensureNoArgs stx
      discard (addUnificationHint declName kind).run
  }
|
|
|
|
/--
Tries to solve `t =?= s` using the registered unification hints.

Returns `false` immediately when unification hints are disabled in the current configuration
or when `t` is a metavariable. Otherwise the hints whose keys match `t` in the discrimination
tree are tried in order, and the result is `true` as soon as one of them succeeds.
-/
def tryUnificationHints (t s : Expr) : MetaM Bool := do
  trace[Meta.isDefEq.hint] "{t} =?= {s}"
  unless (← read).config.unificationHints do
    return false
  if t.isMVar then
    return false
  let env ← getEnv
  let hintState := unificationHintExtension.getState env
  let candidates ← hintState.discrTree.getMatch t
  -- Stops at the first candidate that succeeds.
  candidates.anyM tryCandidate
where
  -- Matches a hint pattern side against an expression under `withReducible`.
  isDefEqPattern (p e : Expr) : MetaM Bool :=
    withReducible (Meta.isExprDefEqAux p e)

  -- Tries a single hint declaration against `t =?= s`, inside `checkpointDefEq`.
  tryCandidate (candidate : Name) : MetaM Bool :=
    traceCtx `Meta.isDefEq.hint <| checkpointDefEq do
      trace[Meta.isDefEq.hint] "trying hint {candidate} at {t} =?= {s}"
      let constInfo ← getConstInfo candidate
      -- Instantiate the declaration's universe parameters with fresh level metavariables.
      let levelMVars ← constInfo.levelParams.mapM fun _ => mkFreshLevelMVar
      let hintValue := constInfo.instantiateValueLevelParams levelMVars
      let (args, binderInfos, hintBody) ← lambdaMetaTelescope hintValue
      -- Unification hints are switched off while matching the pattern itself.
      let decoded? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
        match decodeUnificationHint hintBody with
        | Except.ok hint =>
          if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
            return some hint
          else
            return none
        | Except.error _ => return none
      match decoded? with
      | some hint =>
        trace[Meta.isDefEq.hint] "{candidate} succeeded, applying constraints"
        -- Every constraint of the hint must hold.
        for constraint in hint.constraints do
          unless (← Meta.isExprDefEqAux constraint.lhs constraint.rhs) do
            return false
        -- Every instance-implicit argument must be synthesizable and unify with the result.
        for arg in args, bi in binderInfos do
          if bi == BinderInfo.instImplicit then
            match (← trySynthInstance (← inferType arg)) with
            | LOption.some inst =>
              unless (← isDefEq arg inst) do
                return false
            | _ => return false
        return true
      | none => return false
|
|
|
|
-- Registers the `Meta.isDefEq.hint` trace class used by the `trace[Meta.isDefEq.hint]` calls in this file.
builtin_initialize registerTraceClass `Meta.isDefEq.hint
|
|
|
|
end Lean.Meta
|