feat: elaborate binrel! macro

This commit is contained in:
Leonardo de Moura 2020-12-29 15:31:29 -08:00
parent fcd155931b
commit 51e2db9850
4 changed files with 71 additions and 17 deletions

View file

@ -925,6 +925,30 @@ private def elabAtom : TermElab := fun stx expectedType? =>
@[builtinTermElab proj] def elabProj : TermElab := elabAtom
@[builtinTermElab arrayRef] def elabArrayRef : TermElab := elabAtom
@[builtinTermElab binrel] def elabBinRel : TermElab := fun stx expectedType? => do
match (← resolveId? stx[1]) with
| some f =>
let (lhs, rhs) ← withSynthesize (mayPostpone := true) do
let mut lhs ← elabTerm stx[2] none
let mut rhs ← elabTerm stx[3] none
if lhs.isAppOfArity `OfNat.ofNat 3 then
lhs ← ensureHasType (← inferType rhs) lhs
else if rhs.isAppOfArity `OfNat.ofNat 3 then
rhs ← ensureHasType (← inferType lhs) rhs
return (lhs, rhs)
let lhsType ← inferType lhs
let rhsType ← inferType rhs
let (lhs, rhs) ←
try
pure (lhs, ← withRef stx[3] do ensureHasType lhsType rhs)
catch ex =>
try
pure (← withRef stx[2] do ensureHasType rhsType lhs, rhs)
catch _ =>
throw ex
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
| none => throwUnknownConstant stx[1].getId
builtin_initialize
registerTraceClass `Elab.app

View file

@ -210,21 +210,6 @@ private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat
result := result+1
pure result
private def throwAmbiguous {α} (fs : List Expr) : M α :=
throwError! "ambiguous pattern, use fully qualified name, possible interpretations {fs}"
def resolveId? (stx : Syntax) : M (Option Expr) :=
match stx with
| Syntax.ident _ _ val preresolved => do
let rs ← try resolveName val preresolved [] catch _ => pure []
let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty
let fs := rs.map fun (f, _) => f
match fs with
| [] => pure none
| [f] => pure (some f)
| _ => throwAmbiguous fs
| _ => throwError "identifier expected"
private def throwInvalidPattern {α} : M α :=
throwError "invalid pattern"
@ -330,7 +315,7 @@ def processCtorApp (collect : Syntax → M Syntax) (f : Syntax) (namedArgs : Arr
| `($fId:ident) => pure (fId, false)
| `(@$fId:ident) => pure (fId, true)
| _ => throwError "identifier expected"
let some (Expr.const fName _ _) ← resolveId? fId | throwCtorExpected
let some (Expr.const fName _ _) ← resolveId? fId "pattern" | throwCtorExpected
let fInfo ← getConstInfo fName
forallTelescopeReducing fInfo.type fun xs _ => do
let paramDecls ← xs.mapM (getFVarLocalDecl ·)
@ -368,7 +353,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
/- Check whether `stx` is a pattern variable or constructor-like (i.e., constructor or constant tagged with `[matchPattern]` attribute) -/
private def processId (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do
let env ← getEnv
match (← resolveId? stx) with
match (← resolveId? stx "pattern") with
| none => processVar stx
| some f => match f with
| Expr.const fName _ _ =>

View file

@ -1243,6 +1243,18 @@ def resolveName (n : Name) (preresolved : List (Name × List String)) (explicitL
else
process preresolved
def resolveId? (stx : Syntax) (kind := "term") : TermElabM (Option Expr) :=
match stx with
| Syntax.ident _ _ val preresolved => do
let rs ← try resolveName val preresolved [] catch _ => pure []
let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty
let fs := rs.map fun (f, _) => f
match fs with
| [] => pure none
| [f] => pure (some f)
| _ => throwError! "ambiguous {kind}, use fully qualified name, possible interpretations {fs}"
| _ => throwError "identifier expected"
@[builtinTermElab cdot] def elabBadCDot : TermElab := fun stx _ =>
throwError "invalid occurrence of `·` notation, it must be surrounded by parentheses (e.g. `(· + 1)`)"

View file

@ -0,0 +1,33 @@
def ex1 (x y : Nat) (i j : Int) :=
binrel! Less x i
def ex2 (x y : Nat) (i j : Int) :=
binrel! Less i x
def ex3 (x y : Nat) (i j : Int) :=
binrel! Less (i + 1) x
def ex4 (x y : Nat) (i j : Int) :=
binrel! Less i (x + 1)
def ex5 (x y : Nat) (i j : Int) :=
binrel! Less i (x + y)
def ex6 (x y : Nat) (i j : Int) :=
binrel! Less (i + j) (x + 0)
def ex7 (x y : Nat) (i j : Int) :=
binrel! Less (i + j) (x + i)
def ex8 (x y : Nat) (i j : Int) :=
binrel! Less (i + 0) (x + i)
def ex9 (n : UInt32) :=
binrel! Less n 0xd800
def ex10 (x : Lean.Syntax) : Bool :=
x.getArgs.all (binrel! BEq.beq ·.getKind `foo)
def ex11 (xs : Array (Nat × Nat)) :=
let f a b := binrel! Less a.1 b.1
f xs[1] xs[2]