feat: optional name for unification hints

This commit is contained in:
Leonardo de Moura 2020-11-29 07:11:59 -08:00
parent f649f24014
commit 05fc1e8bbf
4 changed files with 65 additions and 30 deletions

View file

@ -60,14 +60,21 @@ def expandBrackedBinders (combinatorDeclName : Name) (bracketedExplicitBinders :
syntax unifConstraint := term (" =?= " <|> " ≟ ") term
syntax unifConstraintElem := colGe unifConstraint ", "?
syntax "unif_hint " bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command
syntax "unif_hint " (ident)? bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command
private def mkHintBody (cs : Array Syntax) (p : Syntax) : MacroM Syntax := do
let mut body ← `($(p[0]) = $(p[2]))
for c in cs.reverse do
body ← `($(c[0][0]) = $(c[0][2]) → $body)
return body
macro_rules
| `(unif_hint $bs* where $cs* |- $p) => do
let mut body ← `($(p[0]) = $(p[2]))
for c in cs.reverse do
body ← `($(c[0][0]) = $(c[0][2]) → $body)
let body ← mkHintBody cs p
`(@[unificationHint] def hint $bs:explicitBinder* : Sort _ := $body)
| `(unif_hint $n:ident $bs* where $cs* |- $p) => do
let body ← mkHintBody cs p
`(@[unificationHint] def $n:ident $bs:explicitBinder* : Sort _ := $body)
end Lean

View file

@ -255,7 +255,7 @@ def isLevelDefEq (u v : Level) : m Bool := liftMetaM do
isLevelDefEqImp u v
def isExprDefEqImp (t s : Expr) : MetaM Bool :=
traceCtx `Meta.isDefEq $ do
traceCtx `Meta.isDefEq do
let b ← commitWhen (mayPostpone := true) $ Meta.isExprDefEqAux t s
trace[Meta.isDefEq]! "{t} =?= {s} ... {if b then "success" else "failure"}"
pure b

View file

@ -58,14 +58,14 @@ where
private partial def validateHint (declName : Name) (hint : UnificationHint) : MetaM Unit := do
hint.constraints.forM fun c => do
unless (← isDefEq c.lhs c.rhs) do
throwError! "invalid unification hint '{declName}', failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
throwError! "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do
throwError! "invalid unification hint '{declName}', failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
throwError! "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
def addUnificationHint (declName : Name) : MetaM Unit := do
let info ← getConstInfo declName
match info.value? with
| none => throwError! "invalid unification hint '{declName}', it must be a definition"
| none => throwError! "invalid unification hint, it must be a definition"
| some val =>
let (_, _, body) ← lambdaMetaTelescope val
match decodeUnificationHint body with
@ -101,27 +101,31 @@ where
isDefEqPattern p e :=
withReducible <| Meta.isExprDefEqAux p e
tryCandidate candidate : MetaM Bool := commitWhen do
trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}"
let cinfo ← getConstInfo candidate
let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
let val := cinfo.instantiateValueLevelParams us
let (_, _, body) ← lambdaMetaTelescope val
match decodeUnificationHint body with
| Except.error _ => return none
| Except.ok hint =>
if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
return some hint
else
return none
match hint? with
| none => return false
| some hint =>
trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints"
for c in hint.constraints do
unless (← Meta.isExprDefEqAux c.lhs c.rhs) do
return false
return true
tryCandidate candidate : MetaM Bool :=
traceCtx `Meta.isDefEq.hint <| commitWhen do
trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}"
let cinfo ← getConstInfo candidate
let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
let val := cinfo.instantiateValueLevelParams us
let (_, _, body) ← lambdaMetaTelescope val
match decodeUnificationHint body with
| Except.error _ => return none
| Except.ok hint =>
if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
return some hint
else
return none
match hint? with
| none => return false
| some hint =>
trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints"
for c in hint.constraints do
unless (← Meta.isExprDefEqAux c.lhs c.rhs) do
return false
return true
builtin_initialize
registerTraceClass `Meta.isDefEq.hint
end Lean.Meta

View file

@ -0,0 +1,24 @@
/- The following hints are too expensive, but good enough for small natural numbers -/
unif_hint natAddBase (x y : Nat) where
y =?= 0
|-
Nat.add (Nat.succ x) y =?= Nat.succ x
unif_hint natAddStep (x y z w : Nat) where
y =?= Nat.succ w
z =?= Nat.add (Nat.succ x) w
|-
Nat.add (Nat.succ x) y =?= Nat.succ z
def BV (n : Nat) := { a : Array Bool // a.size = n }
def sext (x : BV s) (n : Nat) : BV (s+n) :=
⟨mkArray (s+n) false, Array.sizeMkArrayEq ..⟩
def bvmul (x y : BV w) : BV w := x
def tst1 (x y : BV 64) : BV 128 :=
bvmul (sext x 64) (sext y _)
def tst2 (x y : BV 16) : BV 32 :=
bvmul (sext x 16) (sext y _)