From f90e4ae30ceeee35cc405ca4c809da93e16bbc3c Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Wed, 15 Jun 2022 15:23:59 +0200 Subject: [PATCH] feat: more TSyntax API & coercions --- src/Init/Data/Format/Syntax.lean | 3 + src/Init/Data/Option/Basic.lean | 3 - src/Init/Meta.lean | 101 ++++++++++++++++++++++++--- src/Init/Notation.lean | 10 ++- src/Init/Prelude.lean | 40 +++++++---- src/Lean/Elab/Print.lean | 2 +- src/Lean/Elab/Tactic/Conv/Congr.lean | 2 +- src/Lean/Message.lean | 1 + src/Lean/Parser/Term.lean | 6 ++ 9 files changed, 138 insertions(+), 30 deletions(-) diff --git a/src/Init/Data/Format/Syntax.lean b/src/Init/Data/Format/Syntax.lean index a970c91294..4a551e7185 100644 --- a/src/Init/Data/Format/Syntax.lean +++ b/src/Init/Data/Format/Syntax.lean @@ -44,4 +44,7 @@ def formatStx (stx : Syntax) (maxDepth : Option Nat := none) (showInfo := false) instance : ToFormat (Syntax) := ⟨formatStx⟩ instance : ToString (Syntax) := ⟨@toString Format _ ∘ format⟩ +instance : ToFormat (TSyntax k) := ⟨(format ·.raw)⟩ +instance : ToString (TSyntax k) := ⟨(toString ·.raw)⟩ + end Lean.Syntax diff --git a/src/Init/Data/Option/Basic.lean b/src/Init/Data/Option/Basic.lean index 8792e47830..895628da58 100644 --- a/src/Init/Data/Option/Basic.lean +++ b/src/Init/Data/Option/Basic.lean @@ -34,9 +34,6 @@ def toMonad [Monad m] [Alternative m] : Option α → m α | none, _ => none | some a, b => b a -@[inline] protected def map (f : α → β) (o : Option α) : Option β := - Option.bind o (some ∘ f) - @[inline] protected def mapM [Monad m] (f : α → m β) (o : Option α) : m (Option β) := do if let some a := o then return some (← f a) diff --git a/src/Init/Meta.lean b/src/Init/Meta.lean index 5d17b19af9..6e09307597 100644 --- a/src/Init/Meta.lean +++ b/src/Init/Meta.lean @@ -7,6 +7,7 @@ Additional goodies for writing macros -/ prelude import Init.Data.Array.Basic +import Init.Data.Option.BasicAux namespace Lean @@ -243,6 +244,47 @@ instance monadNameGeneratorLift (m n : Type → Type) [MonadLift m n] [MonadName setNGen := fun ngen => liftM (setNGen ngen : m _) } +namespace TSyntax + +instance : Coe (TSyntax [k]) (TSyntax (k :: ks)) where + coe stx := ⟨stx⟩ + +instance [Coe (TSyntax [k]) (TSyntax ks)] : Coe (TSyntax [k]) (TSyntax (k' :: ks)) where + coe stx := ⟨stx⟩ + +instance : Coe (TSyntax identKind) (TSyntax `term) where + coe s := ⟨s.raw⟩ + +instance : CoeDep (TSyntax `term) ⟨Syntax.ident info ss n res⟩ (TSyntax `ident) where + coe := ⟨Syntax.ident info ss n res⟩ + +instance : Coe (TSyntax strLitKind) (TSyntax `term) where + coe s := ⟨s.raw⟩ + +instance : Coe (TSyntax nameLitKind) (TSyntax `term) where + coe s := ⟨s.raw⟩ + +instance : Coe (TSyntax numLitKind) (TSyntax `term) where + coe s := ⟨s.raw⟩ + +instance : Coe (TSyntax charLitKind) (TSyntax `term) where + coe s := ⟨s.raw⟩ + +instance : Coe (TSyntax numLitKind) (TSyntax `prec) where + coe s := ⟨s.raw⟩ + +namespace Compat + +scoped instance : CoeTail Syntax (TSyntax k) where + coe s := ⟨s⟩ + +scoped instance : CoeTail (Array Syntax) (TSyntaxArray k) where + coe := .mk + +end Compat + +end TSyntax + namespace Syntax partial def structEq : Syntax → Syntax → Bool @@ -253,6 +295,7 @@ partial def structEq : Syntax → Syntax → Bool | _, _ => false instance : BEq Lean.Syntax := ⟨structEq⟩ +instance : BEq (Lean.TSyntax k) := ⟨(·.raw == ·.raw)⟩ partial def getTailInfo? : Syntax → Option SourceInfo | atom info _ => info @@ -449,9 +492,12 @@ def SepArray.ofElemsUsingRef [Monad m] [MonadRef m] {sep} (elems : Array Syntax) let ref ← getRef; return ⟨mkSepArray elems (if sep.isEmpty then mkNullNode else mkAtomFrom ref sep)⟩ -instance (sep) : Coe (Array Syntax) (SepArray sep) where +instance : Coe (Array Syntax) (SepArray sep) where coe := SepArray.ofElems +instance : Coe (TSyntaxArray k) (TSepArray k sep) where + coe a := ⟨mkSepArray a.raw (mkAtom sep)⟩ + /-- Create syntax representing a Lean term application, but avoid degenerate empty applications. -/ def mkApp (fn : Syntax) : (args : Array Syntax) → Syntax | #[] => fn @@ -759,11 +805,6 @@ def isNone (stx : Syntax) : Bool := | Syntax.missing => true | _ => false -def getOptional? (stx : Syntax) : Option Syntax := - match stx with - | Syntax.node _ k args => if k == nullKind && args.size == 1 then some (args.get! 0) else none - | _ => none - def getOptionalIdent? (stx : Syntax) : Option Name := match stx.getOptional? with | some stx => some stx.getId @@ -778,6 +819,26 @@ def find? (stx : Syntax) (p : Syntax → Bool) : Option Syntax := end Syntax +namespace TSyntax + +def getNat (s : TSyntax numLitKind) : Nat := + s.raw.isNatLit?.get! + +def getId (s : TSyntax identKind) : Name := + s.raw.getId + +def getString (s : TSyntax strLitKind) : String := + s.raw.isStrLit?.get! + +namespace Compat + +scoped instance : CoeTail (Array Syntax) (Syntax.TSepArray k sep) where + coe a := (a : TSyntaxArray k) + +end Compat + +end TSyntax + /-- Reflect a runtime datum back to surface syntax (best-effort). -/ class Quote (α : Type) where quote : α → Syntax @@ -915,20 +976,38 @@ def mapSepElems (a : Array Syntax) (f : Syntax → Syntax) : Array Syntax := end Array -namespace Lean.Syntax.SepArray +namespace Lean.Syntax -def getElems {sep} (sa : SepArray sep) : Array Syntax := +def SepArray.getElems (sa : SepArray sep) : Array Syntax := sa.elemsAndSeps.getSepElems +def TSepArray.getElems (sa : TSepArray k sep) : TSyntaxArray k := + .mk sa.elemsAndSeps.getSepElems + /- We use `CoeTail` here instead of `Coe` to avoid a "loop" when computing `CoeTC`. The "loop" is interrupted using the maximum instance size threshold, but it is a performance bottleneck. The loop occurs because the predicate `isNewAnswer` is too imprecise. -/ -instance (sep) : CoeTail (SepArray sep) (Array Syntax) where - coe := getElems +instance : CoeTail (SepArray sep) (Array Syntax) where + coe := SepArray.getElems -end Lean.Syntax.SepArray +instance : Coe (TSepArray k sep) (TSyntaxArray k) where + coe := TSepArray.getElems + +instance [Coe (TSyntax k) (TSyntax k')] : Coe (TSyntaxArray k) (TSyntaxArray k') where + coe a := .mk a.raw + +instance : Coe (TSyntaxArray k) (Array Syntax) where + coe a := a.raw + +instance : Coe (TSyntax identKind) (TSyntax `Lean.Parser.Command.declId) where + coe id := mkNode _ #[id, mkNullNode #[]] + +instance : Coe (Lean.TSyntax `term) (Lean.TSyntax `Lean.Parser.Term.funBinder) where + coe stx := ⟨stx⟩ + +end Lean.Syntax set_option linter.unusedVariables.funArgs false in /-- diff --git a/src/Init/Notation.lean b/src/Init/Notation.lean index b9c4282f5e..13933fc082 100644 --- a/src/Init/Notation.lean +++ b/src/Init/Notation.lean @@ -22,8 +22,13 @@ syntax:65 (name := subPrio) prio " - " prio:66 : prio end Lean.Parser.Syntax namespace Lean -instance : Coe (TSyntax k) Syntax where + +instance : Coe (TSyntax ks) Syntax where coe stx := stx.raw + +instance : Coe SyntaxNodeKind SyntaxNodeKinds where + coe k := List.cons k List.nil + end Lean macro "max" : prec => `(1024) -- maximum precedence used in term parsers, in particular for terms in function position (`ident`, `paren`, ...) @@ -228,3 +233,6 @@ macro tk:"this" : term => return Syntax.ident tk.getHeadInfo "this".toSubstring Category for carrying raw syntax trees between macros; any content is printed as is by the pretty printer. The only accepted parser for this category is an antiquotation. -/ declare_syntax_cat rawStx + +instance : Coe Syntax (TSyntax `rawStx) where + coe stx := ⟨stx⟩ diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 9f6b03973c..520306785a 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1065,6 +1065,10 @@ instance {α} : Inhabited (Option α) where | some x, _ => x | none, e => e +@[inline] protected def Option.map (f : α → β) : Option α → Option β + | some x => some (f x) + | none => none + inductive List (α : Type u) where | nil : List α | cons (head : α) (tail : List α) : List α @@ -1353,7 +1357,7 @@ def Array.sequenceMap {α : Type u} {β : Type v} {m : Type v → Type w} [Monad | 0 => pure bs | Nat.succ i' => Bind.bind (f (as.get ⟨j, hlt⟩)) fun b => loop i' (hAdd j 1) (bs.push b)) (fun _ => pure bs) - loop as.size 0 Array.empty + loop as.size 0 (Array.mkEmpty as.size) /-- A Function for lifting a computation from an inner Monad to an outer Monad. Like [MonadTrans](https://hackage.haskell.org/package/transformers-0.5.5.0/docs/Control-Monad-Trans-Class.html), @@ -1849,19 +1853,22 @@ structure TSyntax (ks : SyntaxNodeKinds) where instance : Inhabited Syntax where default := Syntax.missing +instance : Inhabited (TSyntax ks) where + default := ⟨default⟩ + /- Builtin kinds -/ -def choiceKind : SyntaxNodeKind := `choice -def nullKind : SyntaxNodeKind := `null -def groupKind : SyntaxNodeKind := `group -def identKind : SyntaxNodeKind := `ident -def strLitKind : SyntaxNodeKind := `str -def charLitKind : SyntaxNodeKind := `char -def numLitKind : SyntaxNodeKind := `num -def scientificLitKind : SyntaxNodeKind := `scientific -def nameLitKind : SyntaxNodeKind := `name -def fieldIdxKind : SyntaxNodeKind := `fieldIdx -def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind -def interpolatedStrKind : SyntaxNodeKind := `interpolatedStrKind +abbrev choiceKind : SyntaxNodeKind := `choice +abbrev nullKind : SyntaxNodeKind := `null +abbrev groupKind : SyntaxNodeKind := `group +abbrev identKind : SyntaxNodeKind := `ident +abbrev strLitKind : SyntaxNodeKind := `str +abbrev charLitKind : SyntaxNodeKind := `char +abbrev numLitKind : SyntaxNodeKind := `num +abbrev scientificLitKind : SyntaxNodeKind := `scientific +abbrev nameLitKind : SyntaxNodeKind := `name +abbrev fieldIdxKind : SyntaxNodeKind := `fieldIdx +abbrev interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind +abbrev interpolatedStrKind : SyntaxNodeKind := `interpolatedStrKind namespace Syntax @@ -1902,6 +1909,13 @@ def getNumArgs (stx : Syntax) : Nat := | Syntax.node _ _ args => args.size | _ => 0 +def getOptional? (stx : Syntax) : Option Syntax := + match stx with + | Syntax.node _ k args => match and (beq k nullKind) (beq args.size 1) with + | true => some (args.get! 0) + | false => none + | _ => none + def isMissing : Syntax → Bool | Syntax.missing => true | _ => false diff --git a/src/Lean/Elab/Print.lean b/src/Lean/Elab/Print.lean index 1fb84deb74..addcb93559 100644 --- a/src/Lean/Elab/Print.lean +++ b/src/Lean/Elab/Print.lean @@ -83,7 +83,7 @@ private def printId (id : Syntax) : CommandElabM Unit := do @[builtinCommandElab «print»] def elabPrint : CommandElab | `(#print%$tk $id:ident) => withRef tk <| printId id - | `(#print%$tk $s:str) => logInfoAt tk s.isStrLit?.get! + | `(#print%$tk $s:str) => logInfoAt tk s.getString | _ => throwError "invalid #print command" namespace CollectAxioms diff --git a/src/Lean/Elab/Tactic/Conv/Congr.lean b/src/Lean/Elab/Tactic/Conv/Congr.lean index 4714ee0a81..4c920dec19 100644 --- a/src/Lean/Elab/Tactic/Conv/Congr.lean +++ b/src/Lean/Elab/Tactic/Conv/Congr.lean @@ -98,7 +98,7 @@ private def selectIdx (tacticName : String) (mvarIds : List (Option MVarId)) (i @[builtinTactic Lean.Parser.Tactic.Conv.arg] def evalArg : Tactic := fun stx => do match stx with | `(conv| arg $[@%$tk?]? $i:num) => - let i := i.isNatLit?.getD 0 + let i := i.getNat if i == 0 then throwError "invalid 'arg' conv tactic, index must be greater than 0" let i := i - 1 diff --git a/src/Lean/Message.lean b/src/Lean/Message.lean index c530afaf9e..45d1542084 100644 --- a/src/Lean/Message.lean +++ b/src/Lean/Message.lean @@ -294,6 +294,7 @@ instance : ToMessageData Level := ⟨MessageData.ofLevel⟩ instance : ToMessageData Name := ⟨MessageData.ofName⟩ instance : ToMessageData String := ⟨stringToMessageData⟩ instance : ToMessageData Syntax := ⟨MessageData.ofSyntax⟩ +instance : ToMessageData (TSyntax k) := ⟨(MessageData.ofSyntax ·)⟩ instance : ToMessageData Format := ⟨MessageData.ofFormat⟩ instance : ToMessageData MVarId := ⟨MessageData.ofGoal⟩ instance : ToMessageData MessageData := ⟨id⟩ diff --git a/src/Lean/Parser/Term.lean b/src/Lean/Parser/Term.lean index 9b4a23eb4b..0fc37ce678 100644 --- a/src/Lean/Parser/Term.lean +++ b/src/Lean/Parser/Term.lean @@ -143,6 +143,9 @@ def matchAlt (rhsParser : Parser := termParser) : Parser := work with other `rhsParser`s (of arity 1). -/ def matchAltExpr := matchAlt +instance : Coe (TSyntax ``matchAltExpr) (TSyntax ``matchAlt) where + coe stx := ⟨stx.raw⟩ + def matchAlts (rhsParser : Parser := termParser) : Parser := leading_parser withPosition $ many1Indent (ppLine >> matchAlt rhsParser) @@ -206,6 +209,9 @@ def letDecl := leading_parser (withAnonymousAntiquot := false) notFollowedBy -- `let`-declaration that is only included in the elaborated term if variable is still there @[builtinTermParser] def «let_tmp» := leading_parser:leadPrec withPosition ("let_tmp " >> letDecl) >> optSemicolon termParser +instance : Coe (TSyntax ``letIdBinder) (TSyntax ``funBinder) where + coe stx := ⟨stx⟩ -- `simpleBinderWithoutType` prevents using a proper quotation for this + -- like `let_fun` but with optional name def haveIdLhs := optional (ident >> many (ppSpace >> letIdBinder)) >> optType def haveIdDecl := leading_parser (withAnonymousAntiquot := false) atomic (haveIdLhs >> " := ") >> termParser