diff --git a/library/init/lean/elaborator/basic.lean b/library/init/lean/elaborator/basic.lean index 1c9c4ac8ca..4b5f10ce7a 100644 --- a/library/init/lean/elaborator/basic.lean +++ b/library/init/lean/elaborator/basic.lean @@ -16,6 +16,8 @@ structure ElabContext := (fileMap : FileMap) structure ElabScope := +(cmd : String) +(header : Name) (options : Options := {}) structure ElabState := @@ -38,8 +40,8 @@ end ElabException abbrev Elab := ReaderT ElabContext (EState ElabException ElabState) -abbrev TermElab := Syntax → Elab Expr -abbrev CommandElab := Syntax → Elab Unit +abbrev TermElab := SyntaxNode → Elab Expr +abbrev CommandElab := SyntaxNode → Elab Unit abbrev TermElabTable : Type := SMap SyntaxNodeKind TermElab Name.quickLt abbrev CommandElabTable : Type := SMap SyntaxNodeKind CommandElab Name.quickLt @@ -208,24 +210,26 @@ do logError stx errorMsg; throw (ElabException.other errorMsg) def elabTerm (stx : Syntax) : Elab Expr := -match stx with -| Syntax.node k _ => do - s ← get; - let tables := termElabAttribute.ext.getState s.env; - match tables.find k with - | some elab => elab stx - | none => logErrorAndThrow stx ("term elaborator failed, no support for syntax '" ++ toString k ++ "'") -| _ => throw (ElabException.other "term elaborator failed, unexpected syntax") +stx.ifNode + (fun n => do + s ← get; + let tables := termElabAttribute.ext.getState s.env; + let k := n.getKind; + match tables.find k with + | some elab => elab n + | none => logErrorAndThrow stx ("term elaborator failed, no support for syntax '" ++ toString k ++ "'")) + (fun _ => throw $ ElabException.other "term elaborator failed, unexpected syntax") def elabCommand (stx : Syntax) : Elab Unit := -match stx with -| Syntax.node k _ => do - s ← get; - let tables := commandElabAttribute.ext.getState s.env; - match tables.find k with - | some elab => elab stx - | none => logError stx ("command elaborator failed, no support for syntax '" ++ toString k ++ "'") -| _ => logErrorUsingCmdPos ("command elaborator failed, unexpected syntax") +stx.ifNode + (fun n => do + s ← get; + let tables := commandElabAttribute.ext.getState s.env; + let k := n.getKind; + match tables.find k with + | some elab => elab n + | none => logError stx ("command elaborator failed, no support for syntax '" ++ toString k ++ "'")) + (fun _ => logErrorUsingCmdPos ("command elaborator failed, unexpected syntax")) structure FrontendState := (elabState : ElabState) diff --git a/library/init/lean/elaborator/default.lean b/library/init/lean/elaborator/default.lean index c37318fad5..59514011e1 100644 --- a/library/init/lean/elaborator/default.lean +++ b/library/init/lean/elaborator/default.lean @@ -6,3 +6,4 @@ Authors: Leonardo de Moura prelude import init.lean.elaborator.basic import init.lean.elaborator.elabstrategyattrs +import init.lean.elaborator.command diff --git a/library/init/lean/syntax.lean b/library/init/lean/syntax.lean index bdc82b589b..8cb5cb3f63 100644 --- a/library/init/lean/syntax.lean +++ b/library/init/lean/syntax.lean @@ -47,13 +47,23 @@ def Syntax.isMissing {α} : Syntax α → Bool inductive IsNode {α} : Syntax α → Prop | mk (kind : SyntaxNodeKind) (args : Array (Syntax α)) : IsNode (Syntax.node kind args) -def SyntaxNode (α : Type) : Type := {s : Syntax α // IsNode s } +def SyntaxNode (α : Type := Empty) : Type := {s : Syntax α // IsNode s } def unreachIsNodeMissing {α β} (h : IsNode (@Syntax.missing α)) : β := False.elim (nomatch h) def unreachIsNodeAtom {α β} {info val} (h : IsNode (@Syntax.atom α info val)) : β := False.elim (nomatch h) def unreachIsNodeIdent {α β info rawVal val preresolved} (h : IsNode (@Syntax.ident α info rawVal val preresolved)) : β := False.elim (nomatch h) def unreachIsNodeOther {α β} {a : α} (h : IsNode (Syntax.other a)) : β := False.elim (nomatch h) +namespace SyntaxNode + +@[inline] def getKind {α} (n : SyntaxNode α) : SyntaxNodeKind := +match n with +| ⟨Syntax.node k args, _⟩ => k +| ⟨Syntax.missing, h⟩ => unreachIsNodeMissing h +| ⟨Syntax.atom _ _, h⟩ => unreachIsNodeAtom h +| ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h +| ⟨Syntax.other _ , h⟩ => unreachIsNodeOther h + @[inline] def withArgs {α β} (n : SyntaxNode α) (fn : Array (Syntax α) → β) : β := match n with | ⟨Syntax.node _ args, _⟩ => fn args @@ -62,6 +72,12 @@ match n with | ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h | ⟨Syntax.other _ , h⟩ => unreachIsNodeOther h +@[inline] def getNumArgs {α} (n : SyntaxNode α) : Nat := +withArgs n $ fun args => args.size + +@[inline] def getArg {α} (n : SyntaxNode α) (i : Nat) : Syntax α := +withArgs n $ fun args => args.get i + @[inline] def updateArgs {α} (n : SyntaxNode α) (fn : Array (Syntax α) → Array (Syntax α)) : Syntax α := match n with | ⟨Syntax.node kind args, _⟩ => Syntax.node kind (fn args) @@ -70,11 +86,23 @@ match n with | ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h | ⟨Syntax.other _, h⟩ => unreachIsNodeOther h +end SyntaxNode + namespace Syntax + +@[inline] def ifNode {α β} (stx : Syntax α) (hyes : SyntaxNode α → β) (hno : Unit → β) : β := +match stx with +| Syntax.node k args => hyes ⟨Syntax.node k args, IsNode.mk k args⟩ +| _ => hno () + def isIdent {α} : Syntax α → Bool | (ident _ _ _ _) := true | _ := false +def getIdentVal {α} : Syntax α → Option Name +| (ident _ _ val _) := val +| _ := none + def isOfKind {α} : Syntax α → SyntaxNodeKind → Bool | (node kind _) k := k == kind | _ _ := false