feat: optional name for unification hints
This commit is contained in:
parent
f649f24014
commit
05fc1e8bbf
4 changed files with 65 additions and 30 deletions
|
|
@ -60,14 +60,21 @@ def expandBrackedBinders (combinatorDeclName : Name) (bracketedExplicitBinders :
|
|||
syntax unifConstraint := term (" =?= " <|> " ≟ ") term
|
||||
syntax unifConstraintElem := colGe unifConstraint ", "?
|
||||
|
||||
syntax "unif_hint " bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command
|
||||
syntax "unif_hint " (ident)? bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command
|
||||
|
||||
private def mkHintBody (cs : Array Syntax) (p : Syntax) : MacroM Syntax := do
|
||||
let mut body ← `($(p[0]) = $(p[2]))
|
||||
for c in cs.reverse do
|
||||
body ← `($(c[0][0]) = $(c[0][2]) → $body)
|
||||
return body
|
||||
|
||||
macro_rules
|
||||
| `(unif_hint $bs* where $cs* |- $p) => do
|
||||
let mut body ← `($(p[0]) = $(p[2]))
|
||||
for c in cs.reverse do
|
||||
body ← `($(c[0][0]) = $(c[0][2]) → $body)
|
||||
let body ← mkHintBody cs p
|
||||
`(@[unificationHint] def hint $bs:explicitBinder* : Sort _ := $body)
|
||||
| `(unif_hint $n:ident $bs* where $cs* |- $p) => do
|
||||
let body ← mkHintBody cs p
|
||||
`(@[unificationHint] def $n:ident $bs:explicitBinder* : Sort _ := $body)
|
||||
|
||||
end Lean
|
||||
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ def isLevelDefEq (u v : Level) : m Bool := liftMetaM do
|
|||
isLevelDefEqImp u v
|
||||
|
||||
def isExprDefEqImp (t s : Expr) : MetaM Bool :=
|
||||
traceCtx `Meta.isDefEq $ do
|
||||
traceCtx `Meta.isDefEq do
|
||||
let b ← commitWhen (mayPostpone := true) $ Meta.isExprDefEqAux t s
|
||||
trace[Meta.isDefEq]! "{t} =?= {s} ... {if b then "success" else "failure"}"
|
||||
pure b
|
||||
|
|
|
|||
|
|
@ -58,14 +58,14 @@ where
|
|||
private partial def validateHint (declName : Name) (hint : UnificationHint) : MetaM Unit := do
|
||||
hint.constraints.forM fun c => do
|
||||
unless (← isDefEq c.lhs c.rhs) do
|
||||
throwError! "invalid unification hint '{declName}', failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
|
||||
throwError! "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
|
||||
unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do
|
||||
throwError! "invalid unification hint '{declName}', failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
|
||||
throwError! "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
|
||||
|
||||
def addUnificationHint (declName : Name) : MetaM Unit := do
|
||||
let info ← getConstInfo declName
|
||||
match info.value? with
|
||||
| none => throwError! "invalid unification hint '{declName}', it must be a definition"
|
||||
| none => throwError! "invalid unification hint, it must be a definition"
|
||||
| some val =>
|
||||
let (_, _, body) ← lambdaMetaTelescope val
|
||||
match decodeUnificationHint body with
|
||||
|
|
@ -101,27 +101,31 @@ where
|
|||
isDefEqPattern p e :=
|
||||
withReducible <| Meta.isExprDefEqAux p e
|
||||
|
||||
tryCandidate candidate : MetaM Bool := commitWhen do
|
||||
trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}"
|
||||
let cinfo ← getConstInfo candidate
|
||||
let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
|
||||
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
|
||||
let val := cinfo.instantiateValueLevelParams us
|
||||
let (_, _, body) ← lambdaMetaTelescope val
|
||||
match decodeUnificationHint body with
|
||||
| Except.error _ => return none
|
||||
| Except.ok hint =>
|
||||
if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
|
||||
return some hint
|
||||
else
|
||||
return none
|
||||
match hint? with
|
||||
| none => return false
|
||||
| some hint =>
|
||||
trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints"
|
||||
for c in hint.constraints do
|
||||
unless (← Meta.isExprDefEqAux c.lhs c.rhs) do
|
||||
return false
|
||||
return true
|
||||
tryCandidate candidate : MetaM Bool :=
|
||||
traceCtx `Meta.isDefEq.hint <| commitWhen do
|
||||
trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}"
|
||||
let cinfo ← getConstInfo candidate
|
||||
let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
|
||||
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
|
||||
let val := cinfo.instantiateValueLevelParams us
|
||||
let (_, _, body) ← lambdaMetaTelescope val
|
||||
match decodeUnificationHint body with
|
||||
| Except.error _ => return none
|
||||
| Except.ok hint =>
|
||||
if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then
|
||||
return some hint
|
||||
else
|
||||
return none
|
||||
match hint? with
|
||||
| none => return false
|
||||
| some hint =>
|
||||
trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints"
|
||||
for c in hint.constraints do
|
||||
unless (← Meta.isExprDefEqAux c.lhs c.rhs) do
|
||||
return false
|
||||
return true
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Meta.isDefEq.hint
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
24
tests/lean/run/unifhint2.lean
Normal file
24
tests/lean/run/unifhint2.lean
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
/- The following hints are too expensive, but good enough for small natural numbers -/
|
||||
unif_hint natAddBase (x y : Nat) where
|
||||
y =?= 0
|
||||
|-
|
||||
Nat.add (Nat.succ x) y =?= Nat.succ x
|
||||
|
||||
unif_hint natAddStep (x y z w : Nat) where
|
||||
y =?= Nat.succ w
|
||||
z =?= Nat.add (Nat.succ x) w
|
||||
|-
|
||||
Nat.add (Nat.succ x) y =?= Nat.succ z
|
||||
|
||||
def BV (n : Nat) := { a : Array Bool // a.size = n }
|
||||
|
||||
def sext (x : BV s) (n : Nat) : BV (s+n) :=
|
||||
⟨mkArray (s+n) false, Array.sizeMkArrayEq ..⟩
|
||||
|
||||
def bvmul (x y : BV w) : BV w := x
|
||||
|
||||
def tst1 (x y : BV 64) : BV 128 :=
|
||||
bvmul (sext x 64) (sext y _)
|
||||
|
||||
def tst2 (x y : BV 16) : BV 32 :=
|
||||
bvmul (sext x 16) (sext y _)
|
||||
Loading…
Add table
Reference in a new issue