From 05fc1e8bbfd56288e90bad9f078213895dcb8027 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 29 Nov 2020 07:11:59 -0800 Subject: [PATCH] feat: optional name for unification hints --- src/Init/NotationExtra.lean | 15 ++++++--- src/Lean/Meta/LevelDefEq.lean | 2 +- src/Lean/Meta/UnificationHint.lean | 54 ++++++++++++++++-------------- tests/lean/run/unifhint2.lean | 24 +++++++++++++ 4 files changed, 65 insertions(+), 30 deletions(-) create mode 100644 tests/lean/run/unifhint2.lean diff --git a/src/Init/NotationExtra.lean b/src/Init/NotationExtra.lean index c23cfc3358..4e41424674 100644 --- a/src/Init/NotationExtra.lean +++ b/src/Init/NotationExtra.lean @@ -60,14 +60,21 @@ def expandBrackedBinders (combinatorDeclName : Name) (bracketedExplicitBinders : syntax unifConstraint := term (" =?= " <|> " ≟ ") term syntax unifConstraintElem := colGe unifConstraint ", "? -syntax "unif_hint " bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command +syntax "unif_hint " (ident)? bracketedBinder* " where " withPosition(unifConstraintElem*) ("|-" <|> "⊢ ") unifConstraint : command + +private def mkHintBody (cs : Array Syntax) (p : Syntax) : MacroM Syntax := do + let mut body ← `($(p[0]) = $(p[2])) + for c in cs.reverse do + body ← `($(c[0][0]) = $(c[0][2]) → $body) + return body macro_rules | `(unif_hint $bs* where $cs* |- $p) => do - let mut body ← `($(p[0]) = $(p[2])) - for c in cs.reverse do - body ← `($(c[0][0]) = $(c[0][2]) → $body) + let body ← mkHintBody cs p `(@[unificationHint] def hint $bs:explicitBinder* : Sort _ := $body) + | `(unif_hint $n:ident $bs* where $cs* |- $p) => do + let body ← mkHintBody cs p + `(@[unificationHint] def $n:ident $bs:explicitBinder* : Sort _ := $body) end Lean diff --git a/src/Lean/Meta/LevelDefEq.lean b/src/Lean/Meta/LevelDefEq.lean index ddecf4b8c4..b692ef1aa5 100644 --- a/src/Lean/Meta/LevelDefEq.lean +++ b/src/Lean/Meta/LevelDefEq.lean @@ -255,7 +255,7 @@ def isLevelDefEq (u v : Level) : m Bool := liftMetaM do isLevelDefEqImp u v def isExprDefEqImp (t s : Expr) : MetaM Bool := - traceCtx `Meta.isDefEq $ do + traceCtx `Meta.isDefEq do let b ← commitWhen (mayPostpone := true) $ Meta.isExprDefEqAux t s trace[Meta.isDefEq]! "{t} =?= {s} ... {if b then "success" else "failure"}" pure b diff --git a/src/Lean/Meta/UnificationHint.lean b/src/Lean/Meta/UnificationHint.lean index 7703fd3f43..6ec1ca9a8c 100644 --- a/src/Lean/Meta/UnificationHint.lean +++ b/src/Lean/Meta/UnificationHint.lean @@ -58,14 +58,14 @@ where private partial def validateHint (declName : Name) (hint : UnificationHint) : MetaM Unit := do hint.constraints.forM fun c => do unless (← isDefEq c.lhs c.rhs) do - throwError! "invalid unification hint '{declName}', failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}" + throwError! "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}" unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do - throwError! "invalid unification hint '{declName}', failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}" + throwError! "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}" def addUnificationHint (declName : Name) : MetaM Unit := do let info ← getConstInfo declName match info.value? with - | none => throwError! "invalid unification hint '{declName}', it must be a definition" + | none => throwError! "invalid unification hint, it must be a definition" | some val => let (_, _, body) ← lambdaMetaTelescope val match decodeUnificationHint body with @@ -101,27 +101,31 @@ where isDefEqPattern p e := withReducible <| Meta.isExprDefEqAux p e - tryCandidate candidate : MetaM Bool := commitWhen do - trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}" - let cinfo ← getConstInfo candidate - let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do - let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar - let val := cinfo.instantiateValueLevelParams us - let (_, _, body) ← lambdaMetaTelescope val - match decodeUnificationHint body with - | Except.error _ => return none - | Except.ok hint => - if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then - return some hint - else - return none - match hint? with - | none => return false - | some hint => - trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints" - for c in hint.constraints do - unless (← Meta.isExprDefEqAux c.lhs c.rhs) do - return false - return true + tryCandidate candidate : MetaM Bool := + traceCtx `Meta.isDefEq.hint <| commitWhen do + trace[Meta.isDefEq.hint]! "trying hint {candidate} at {t} =?= {s}" + let cinfo ← getConstInfo candidate + let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do + let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar + let val := cinfo.instantiateValueLevelParams us + let (_, _, body) ← lambdaMetaTelescope val + match decodeUnificationHint body with + | Except.error _ => return none + | Except.ok hint => + if (← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s) then + return some hint + else + return none + match hint? with + | none => return false + | some hint => + trace[Meta.isDefEq.hint]! "{candidate} succeeded, applying constraints" + for c in hint.constraints do + unless (← Meta.isExprDefEqAux c.lhs c.rhs) do + return false + return true + +builtin_initialize + registerTraceClass `Meta.isDefEq.hint end Lean.Meta diff --git a/tests/lean/run/unifhint2.lean b/tests/lean/run/unifhint2.lean new file mode 100644 index 0000000000..38130749d5 --- /dev/null +++ b/tests/lean/run/unifhint2.lean @@ -0,0 +1,24 @@ +/- The following hints are too expensive, but good enough for small natural numbers -/ +unif_hint natAddBase (x y : Nat) where + y =?= 0 + |- + Nat.add (Nat.succ x) y =?= Nat.succ x + +unif_hint natAddStep (x y z w : Nat) where + y =?= Nat.succ w + z =?= Nat.add (Nat.succ x) w + |- + Nat.add (Nat.succ x) y =?= Nat.succ z + +def BV (n : Nat) := { a : Array Bool // a.size = n } + +def sext (x : BV s) (n : Nat) : BV (s+n) := + ⟨mkArray (s+n) false, Array.sizeMkArrayEq ..⟩ + +def bvmul (x y : BV w) : BV w := x + +def tst1 (x y : BV 64) : BV 128 := + bvmul (sext x 64) (sext y _) + +def tst2 (x y : BV 16) : BV 32 := + bvmul (sext x 16) (sext y _)