feat: elaborate binrel! macro
This commit is contained in:
parent
fcd155931b
commit
51e2db9850
4 changed files with 71 additions and 17 deletions
|
|
@ -925,6 +925,30 @@ private def elabAtom : TermElab := fun stx expectedType? =>
|
|||
@[builtinTermElab proj] def elabProj : TermElab := elabAtom
|
||||
@[builtinTermElab arrayRef] def elabArrayRef : TermElab := elabAtom
|
||||
|
||||
@[builtinTermElab binrel] def elabBinRel : TermElab := fun stx expectedType? => do
|
||||
match (← resolveId? stx[1]) with
|
||||
| some f =>
|
||||
let (lhs, rhs) ← withSynthesize (mayPostpone := true) do
|
||||
let mut lhs ← elabTerm stx[2] none
|
||||
let mut rhs ← elabTerm stx[3] none
|
||||
if lhs.isAppOfArity `OfNat.ofNat 3 then
|
||||
lhs ← ensureHasType (← inferType rhs) lhs
|
||||
else if rhs.isAppOfArity `OfNat.ofNat 3 then
|
||||
rhs ← ensureHasType (← inferType lhs) rhs
|
||||
return (lhs, rhs)
|
||||
let lhsType ← inferType lhs
|
||||
let rhsType ← inferType rhs
|
||||
let (lhs, rhs) ←
|
||||
try
|
||||
pure (lhs, ← withRef stx[3] do ensureHasType lhsType rhs)
|
||||
catch ex =>
|
||||
try
|
||||
pure (← withRef stx[2] do ensureHasType rhsType lhs, rhs)
|
||||
catch _ =>
|
||||
throw ex
|
||||
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
|
||||
| none => throwUnknownConstant stx[1].getId
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.app
|
||||
|
||||
|
|
|
|||
|
|
@ -210,21 +210,6 @@ private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat
|
|||
result := result+1
|
||||
pure result
|
||||
|
||||
private def throwAmbiguous {α} (fs : List Expr) : M α :=
|
||||
throwError! "ambiguous pattern, use fully qualified name, possible interpretations {fs}"
|
||||
|
||||
def resolveId? (stx : Syntax) : M (Option Expr) :=
|
||||
match stx with
|
||||
| Syntax.ident _ _ val preresolved => do
|
||||
let rs ← try resolveName val preresolved [] catch _ => pure []
|
||||
let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty
|
||||
let fs := rs.map fun (f, _) => f
|
||||
match fs with
|
||||
| [] => pure none
|
||||
| [f] => pure (some f)
|
||||
| _ => throwAmbiguous fs
|
||||
| _ => throwError "identifier expected"
|
||||
|
||||
private def throwInvalidPattern {α} : M α :=
|
||||
throwError "invalid pattern"
|
||||
|
||||
|
|
@ -330,7 +315,7 @@ def processCtorApp (collect : Syntax → M Syntax) (f : Syntax) (namedArgs : Arr
|
|||
| `($fId:ident) => pure (fId, false)
|
||||
| `(@$fId:ident) => pure (fId, true)
|
||||
| _ => throwError "identifier expected"
|
||||
let some (Expr.const fName _ _) ← resolveId? fId | throwCtorExpected
|
||||
let some (Expr.const fName _ _) ← resolveId? fId "pattern" | throwCtorExpected
|
||||
let fInfo ← getConstInfo fName
|
||||
forallTelescopeReducing fInfo.type fun xs _ => do
|
||||
let paramDecls ← xs.mapM (getFVarLocalDecl ·)
|
||||
|
|
@ -368,7 +353,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
|
|||
/- Check whether `stx` is a pattern variable or constructor-like (i.e., constructor or constant tagged with `[matchPattern]` attribute) -/
|
||||
private def processId (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do
|
||||
let env ← getEnv
|
||||
match (← resolveId? stx) with
|
||||
match (← resolveId? stx "pattern") with
|
||||
| none => processVar stx
|
||||
| some f => match f with
|
||||
| Expr.const fName _ _ =>
|
||||
|
|
|
|||
|
|
@ -1243,6 +1243,18 @@ def resolveName (n : Name) (preresolved : List (Name × List String)) (explicitL
|
|||
else
|
||||
process preresolved
|
||||
|
||||
def resolveId? (stx : Syntax) (kind := "term") : TermElabM (Option Expr) :=
|
||||
match stx with
|
||||
| Syntax.ident _ _ val preresolved => do
|
||||
let rs ← try resolveName val preresolved [] catch _ => pure []
|
||||
let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty
|
||||
let fs := rs.map fun (f, _) => f
|
||||
match fs with
|
||||
| [] => pure none
|
||||
| [f] => pure (some f)
|
||||
| _ => throwError! "ambiguous {kind}, use fully qualified name, possible interpretations {fs}"
|
||||
| _ => throwError "identifier expected"
|
||||
|
||||
@[builtinTermElab cdot] def elabBadCDot : TermElab := fun stx _ =>
|
||||
throwError "invalid occurrence of `·` notation, it must be surrounded by parentheses (e.g. `(· + 1)`)"
|
||||
|
||||
|
|
|
|||
33
tests/lean/run/binrel.lean
Normal file
33
tests/lean/run/binrel.lean
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
def ex1 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less x i
|
||||
|
||||
def ex2 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less i x
|
||||
|
||||
def ex3 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less (i + 1) x
|
||||
|
||||
def ex4 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less i (x + 1)
|
||||
|
||||
def ex5 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less i (x + y)
|
||||
|
||||
def ex6 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less (i + j) (x + 0)
|
||||
|
||||
def ex7 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less (i + j) (x + i)
|
||||
|
||||
def ex8 (x y : Nat) (i j : Int) :=
|
||||
binrel! Less (i + 0) (x + i)
|
||||
|
||||
def ex9 (n : UInt32) :=
|
||||
binrel! Less n 0xd800
|
||||
|
||||
def ex10 (x : Lean.Syntax) : Bool :=
|
||||
x.getArgs.all (binrel! BEq.beq ·.getKind `foo)
|
||||
|
||||
def ex11 (xs : Array (Nat × Nat)) :=
|
||||
let f a b := binrel! Less a.1 b.1
|
||||
f xs[1] xs[2]
|
||||
Loading…
Add table
Reference in a new issue