feat: more TSyntax API & coercions

This commit is contained in:
Sebastian Ullrich 2022-06-15 15:23:59 +02:00
parent a12cde41e1
commit f90e4ae30c
9 changed files with 138 additions and 30 deletions

View file

@ -44,4 +44,7 @@ def formatStx (stx : Syntax) (maxDepth : Option Nat := none) (showInfo := false)
instance : ToFormat (Syntax) := ⟨formatStx⟩
instance : ToString (Syntax) := ⟨@toString Format _ ∘ format⟩
instance : ToFormat (TSyntax k) := ⟨(format ·.raw)⟩
instance : ToString (TSyntax k) := ⟨(toString ·.raw)⟩
end Lean.Syntax

View file

@ -34,9 +34,6 @@ def toMonad [Monad m] [Alternative m] : Option α → m α
| none, _ => none
| some a, b => b a
@[inline] protected def map (f : α → β) (o : Option α) : Option β :=
Option.bind o (some ∘ f)
@[inline] protected def mapM [Monad m] (f : α → m β) (o : Option α) : m (Option β) := do
if let some a := o then
return some (← f a)

View file

@ -7,6 +7,7 @@ Additional goodies for writing macros
-/
prelude
import Init.Data.Array.Basic
import Init.Data.Option.BasicAux
namespace Lean
@ -243,6 +244,47 @@ instance monadNameGeneratorLift (m n : Type → Type) [MonadLift m n] [MonadName
setNGen := fun ngen => liftM (setNGen ngen : m _)
}
namespace TSyntax
instance : Coe (TSyntax [k]) (TSyntax (k :: ks)) where
coe stx := ⟨stx⟩
instance [Coe (TSyntax [k]) (TSyntax ks)] : Coe (TSyntax [k]) (TSyntax (k' :: ks)) where
coe stx := ⟨stx⟩
instance : Coe (TSyntax identKind) (TSyntax `term) where
coe s := ⟨s.raw⟩
instance : CoeDep (TSyntax `term) ⟨Syntax.ident info ss n res⟩ (TSyntax `ident) where
coe := ⟨Syntax.ident info ss n res⟩
instance : Coe (TSyntax strLitKind) (TSyntax `term) where
coe s := ⟨s.raw⟩
instance : Coe (TSyntax nameLitKind) (TSyntax `term) where
coe s := ⟨s.raw⟩
instance : Coe (TSyntax numLitKind) (TSyntax `term) where
coe s := ⟨s.raw⟩
instance : Coe (TSyntax charLitKind) (TSyntax `term) where
coe s := ⟨s.raw⟩
instance : Coe (TSyntax numLitKind) (TSyntax `prec) where
coe s := ⟨s.raw⟩
namespace Compat
scoped instance : CoeTail Syntax (TSyntax k) where
coe s := ⟨s⟩
scoped instance : CoeTail (Array Syntax) (TSyntaxArray k) where
coe := .mk
end Compat
end TSyntax
namespace Syntax
partial def structEq : Syntax → Syntax → Bool
@ -253,6 +295,7 @@ partial def structEq : Syntax → Syntax → Bool
| _, _ => false
instance : BEq Lean.Syntax := ⟨structEq⟩
instance : BEq (Lean.TSyntax k) := ⟨(·.raw == ·.raw)⟩
partial def getTailInfo? : Syntax → Option SourceInfo
| atom info _ => info
@ -449,9 +492,12 @@ def SepArray.ofElemsUsingRef [Monad m] [MonadRef m] {sep} (elems : Array Syntax)
let ref ← getRef;
return ⟨mkSepArray elems (if sep.isEmpty then mkNullNode else mkAtomFrom ref sep)⟩
instance (sep) : Coe (Array Syntax) (SepArray sep) where
instance : Coe (Array Syntax) (SepArray sep) where
coe := SepArray.ofElems
instance : Coe (TSyntaxArray k) (TSepArray k sep) where
coe a := ⟨mkSepArray a.raw (mkAtom sep)⟩
/-- Create syntax representing a Lean term application, but avoid degenerate empty applications. -/
def mkApp (fn : Syntax) : (args : Array Syntax) → Syntax
| #[] => fn
@ -759,11 +805,6 @@ def isNone (stx : Syntax) : Bool :=
| Syntax.missing => true
| _ => false
def getOptional? (stx : Syntax) : Option Syntax :=
match stx with
| Syntax.node _ k args => if k == nullKind && args.size == 1 then some (args.get! 0) else none
| _ => none
def getOptionalIdent? (stx : Syntax) : Option Name :=
match stx.getOptional? with
| some stx => some stx.getId
@ -778,6 +819,26 @@ def find? (stx : Syntax) (p : Syntax → Bool) : Option Syntax :=
end Syntax
namespace TSyntax
def getNat (s : TSyntax numLitKind) : Nat :=
s.raw.isNatLit?.get!
def getId (s : TSyntax identKind) : Name :=
s.raw.getId
def getString (s : TSyntax strLitKind) : String :=
s.raw.isStrLit?.get!
namespace Compat
scoped instance : CoeTail (Array Syntax) (Syntax.TSepArray k sep) where
coe a := (a : TSyntaxArray k)
end Compat
end TSyntax
/-- Reflect a runtime datum back to surface syntax (best-effort). -/
class Quote (α : Type) where
quote : α → Syntax
@ -915,20 +976,38 @@ def mapSepElems (a : Array Syntax) (f : Syntax → Syntax) : Array Syntax :=
end Array
namespace Lean.Syntax.SepArray
namespace Lean.Syntax
def getElems {sep} (sa : SepArray sep) : Array Syntax :=
def SepArray.getElems (sa : SepArray sep) : Array Syntax :=
sa.elemsAndSeps.getSepElems
def TSepArray.getElems (sa : TSepArray k sep) : TSyntaxArray k :=
.mk sa.elemsAndSeps.getSepElems
/-
We use `CoeTail` here instead of `Coe` to avoid a "loop" when computing `CoeTC`.
The "loop" is interrupted using the maximum instance size threshold, but it is a performance bottleneck.
The loop occurs because the predicate `isNewAnswer` is too imprecise.
-/
instance (sep) : CoeTail (SepArray sep) (Array Syntax) where
coe := getElems
instance : CoeTail (SepArray sep) (Array Syntax) where
coe := SepArray.getElems
end Lean.Syntax.SepArray
instance : Coe (TSepArray k sep) (TSyntaxArray k) where
coe := TSepArray.getElems
instance [Coe (TSyntax k) (TSyntax k')] : Coe (TSyntaxArray k) (TSyntaxArray k') where
coe a := .mk a.raw
instance : Coe (TSyntaxArray k) (Array Syntax) where
coe a := a.raw
instance : Coe (TSyntax identKind) (TSyntax `Lean.Parser.Command.declId) where
coe id := mkNode _ #[id, mkNullNode #[]]
instance : Coe (Lean.TSyntax `term) (Lean.TSyntax `Lean.Parser.Term.funBinder) where
coe stx := ⟨stx⟩
end Lean.Syntax
set_option linter.unusedVariables.funArgs false in
/--

View file

@ -22,8 +22,13 @@ syntax:65 (name := subPrio) prio " - " prio:66 : prio
end Lean.Parser.Syntax
namespace Lean
instance : Coe (TSyntax k) Syntax where
instance : Coe (TSyntax ks) Syntax where
coe stx := stx.raw
instance : Coe SyntaxNodeKind SyntaxNodeKinds where
coe k := List.cons k List.nil
end Lean
macro "max" : prec => `(1024) -- maximum precedence used in term parsers, in particular for terms in function position (`ident`, `paren`, ...)
@ -228,3 +233,6 @@ macro tk:"this" : term => return Syntax.ident tk.getHeadInfo "this".toSubstring
Category for carrying raw syntax trees between macros; any content is printed as is by the pretty printer.
The only accepted parser for this category is an antiquotation. -/
declare_syntax_cat rawStx
instance : Coe Syntax (TSyntax `rawStx) where
coe stx := ⟨stx⟩

View file

@ -1065,6 +1065,10 @@ instance {α} : Inhabited (Option α) where
| some x, _ => x
| none, e => e
@[inline] protected def Option.map (f : α → β) : Option α → Option β
| some x => some (f x)
| none => none
inductive List (α : Type u) where
| nil : List α
| cons (head : α) (tail : List α) : List α
@ -1353,7 +1357,7 @@ def Array.sequenceMap {α : Type u} {β : Type v} {m : Type v → Type w} [Monad
| 0 => pure bs
| Nat.succ i' => Bind.bind (f (as.get ⟨j, hlt⟩)) fun b => loop i' (hAdd j 1) (bs.push b))
(fun _ => pure bs)
loop as.size 0 Array.empty
loop as.size 0 (Array.mkEmpty as.size)
/-- A Function for lifting a computation from an inner Monad to an outer Monad.
Like [MonadTrans](https://hackage.haskell.org/package/transformers-0.5.5.0/docs/Control-Monad-Trans-Class.html),
@ -1849,19 +1853,22 @@ structure TSyntax (ks : SyntaxNodeKinds) where
instance : Inhabited Syntax where
default := Syntax.missing
instance : Inhabited (TSyntax ks) where
default := ⟨default⟩
/- Builtin kinds -/
def choiceKind : SyntaxNodeKind := `choice
def nullKind : SyntaxNodeKind := `null
def groupKind : SyntaxNodeKind := `group
def identKind : SyntaxNodeKind := `ident
def strLitKind : SyntaxNodeKind := `str
def charLitKind : SyntaxNodeKind := `char
def numLitKind : SyntaxNodeKind := `num
def scientificLitKind : SyntaxNodeKind := `scientific
def nameLitKind : SyntaxNodeKind := `name
def fieldIdxKind : SyntaxNodeKind := `fieldIdx
def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind
def interpolatedStrKind : SyntaxNodeKind := `interpolatedStrKind
abbrev choiceKind : SyntaxNodeKind := `choice
abbrev nullKind : SyntaxNodeKind := `null
abbrev groupKind : SyntaxNodeKind := `group
abbrev identKind : SyntaxNodeKind := `ident
abbrev strLitKind : SyntaxNodeKind := `str
abbrev charLitKind : SyntaxNodeKind := `char
abbrev numLitKind : SyntaxNodeKind := `num
abbrev scientificLitKind : SyntaxNodeKind := `scientific
abbrev nameLitKind : SyntaxNodeKind := `name
abbrev fieldIdxKind : SyntaxNodeKind := `fieldIdx
abbrev interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind
abbrev interpolatedStrKind : SyntaxNodeKind := `interpolatedStrKind
namespace Syntax
@ -1902,6 +1909,13 @@ def getNumArgs (stx : Syntax) : Nat :=
| Syntax.node _ _ args => args.size
| _ => 0
def getOptional? (stx : Syntax) : Option Syntax :=
match stx with
| Syntax.node _ k args => match and (beq k nullKind) (beq args.size 1) with
| true => some (args.get! 0)
| false => none
| _ => none
def isMissing : Syntax → Bool
| Syntax.missing => true
| _ => false

View file

@ -83,7 +83,7 @@ private def printId (id : Syntax) : CommandElabM Unit := do
@[builtinCommandElab «print»] def elabPrint : CommandElab
| `(#print%$tk $id:ident) => withRef tk <| printId id
| `(#print%$tk $s:str) => logInfoAt tk s.isStrLit?.get!
| `(#print%$tk $s:str) => logInfoAt tk s.getString
| _ => throwError "invalid #print command"
namespace CollectAxioms

View file

@ -98,7 +98,7 @@ private def selectIdx (tacticName : String) (mvarIds : List (Option MVarId)) (i
@[builtinTactic Lean.Parser.Tactic.Conv.arg] def evalArg : Tactic := fun stx => do
match stx with
| `(conv| arg $[@%$tk?]? $i:num) =>
let i := i.isNatLit?.getD 0
let i := i.getNat
if i == 0 then
throwError "invalid 'arg' conv tactic, index must be greater than 0"
let i := i - 1

View file

@ -294,6 +294,7 @@ instance : ToMessageData Level := ⟨MessageData.ofLevel⟩
instance : ToMessageData Name := ⟨MessageData.ofName⟩
instance : ToMessageData String := ⟨stringToMessageData⟩
instance : ToMessageData Syntax := ⟨MessageData.ofSyntax⟩
instance : ToMessageData (TSyntax k) := ⟨(MessageData.ofSyntax ·)⟩
instance : ToMessageData Format := ⟨MessageData.ofFormat⟩
instance : ToMessageData MVarId := ⟨MessageData.ofGoal⟩
instance : ToMessageData MessageData := ⟨id⟩

View file

@ -143,6 +143,9 @@ def matchAlt (rhsParser : Parser := termParser) : Parser :=
work with other `rhsParser`s (of arity 1). -/
def matchAltExpr := matchAlt
instance : Coe (TSyntax ``matchAltExpr) (TSyntax ``matchAlt) where
coe stx := ⟨stx.raw⟩
def matchAlts (rhsParser : Parser := termParser) : Parser :=
leading_parser withPosition $ many1Indent (ppLine >> matchAlt rhsParser)
@ -206,6 +209,9 @@ def letDecl := leading_parser (withAnonymousAntiquot := false) notFollowedBy
-- `let`-declaration that is only included in the elaborated term if variable is still there
@[builtinTermParser] def «let_tmp» := leading_parser:leadPrec withPosition ("let_tmp " >> letDecl) >> optSemicolon termParser
instance : Coe (TSyntax ``letIdBinder) (TSyntax ``funBinder) where
coe stx := ⟨stx⟩ -- `simpleBinderWithoutType` prevents using a proper quotation for this
-- like `let_fun` but with optional name
def haveIdLhs := optional (ident >> many (ppSpace >> letIdBinder)) >> optType
def haveIdDecl := leading_parser (withAnonymousAntiquot := false) atomic (haveIdLhs >> " := ") >> termParser