feat(library/init/lean/elaborator): use SyntaxNode to define TermElab and CommandElab

This commit is contained in:
Leonardo de Moura 2019-07-21 07:29:41 -07:00
parent 35d841e6ea
commit a535d348de
3 changed files with 52 additions and 19 deletions

View file

@ -16,6 +16,8 @@ structure ElabContext :=
(fileMap : FileMap)
structure ElabScope :=
(cmd : String)
(header : Name)
(options : Options := {})
structure ElabState :=
@ -38,8 +40,8 @@ end ElabException
abbrev Elab := ReaderT ElabContext (EState ElabException ElabState)
abbrev TermElab := Syntax → Elab Expr
abbrev CommandElab := Syntax → Elab Unit
abbrev TermElab := SyntaxNode → Elab Expr
abbrev CommandElab := SyntaxNode → Elab Unit
abbrev TermElabTable : Type := SMap SyntaxNodeKind TermElab Name.quickLt
abbrev CommandElabTable : Type := SMap SyntaxNodeKind CommandElab Name.quickLt
@ -208,24 +210,26 @@ do logError stx errorMsg;
throw (ElabException.other errorMsg)
def elabTerm (stx : Syntax) : Elab Expr :=
match stx with
| Syntax.node k _ => do
s ← get;
let tables := termElabAttribute.ext.getState s.env;
match tables.find k with
| some elab => elab stx
| none => logErrorAndThrow stx ("term elaborator failed, no support for syntax '" ++ toString k ++ "'")
| _ => throw (ElabException.other "term elaborator failed, unexpected syntax")
stx.ifNode
(fun n => do
s ← get;
let tables := termElabAttribute.ext.getState s.env;
let k := n.getKind;
match tables.find k with
| some elab => elab n
| none => logErrorAndThrow stx ("term elaborator failed, no support for syntax '" ++ toString k ++ "'"))
(fun _ => throw $ ElabException.other "term elaborator failed, unexpected syntax")
def elabCommand (stx : Syntax) : Elab Unit :=
match stx with
| Syntax.node k _ => do
s ← get;
let tables := commandElabAttribute.ext.getState s.env;
match tables.find k with
| some elab => elab stx
| none => logError stx ("command elaborator failed, no support for syntax '" ++ toString k ++ "'")
| _ => logErrorUsingCmdPos ("command elaborator failed, unexpected syntax")
stx.ifNode
(fun n => do
s ← get;
let tables := commandElabAttribute.ext.getState s.env;
let k := n.getKind;
match tables.find k with
| some elab => elab n
| none => logError stx ("command elaborator failed, no support for syntax '" ++ toString k ++ "'"))
(fun _ => logErrorUsingCmdPos ("command elaborator failed, unexpected syntax"))
structure FrontendState :=
(elabState : ElabState)

View file

@ -6,3 +6,4 @@ Authors: Leonardo de Moura
prelude
import init.lean.elaborator.basic
import init.lean.elaborator.elabstrategyattrs
import init.lean.elaborator.command

View file

@ -47,13 +47,23 @@ def Syntax.isMissing {α} : Syntax α → Bool
inductive IsNode {α} : Syntax α → Prop
| mk (kind : SyntaxNodeKind) (args : Array (Syntax α)) : IsNode (Syntax.node kind args)
def SyntaxNode (α : Type) : Type := {s : Syntax α // IsNode s }
def SyntaxNode (α : Type := Empty) : Type := {s : Syntax α // IsNode s }
def unreachIsNodeMissing {α β} (h : IsNode (@Syntax.missing α)) : β := False.elim (nomatch h)
def unreachIsNodeAtom {α β} {info val} (h : IsNode (@Syntax.atom α info val)) : β := False.elim (nomatch h)
def unreachIsNodeIdent {α β info rawVal val preresolved} (h : IsNode (@Syntax.ident α info rawVal val preresolved)) : β := False.elim (nomatch h)
def unreachIsNodeOther {α β} {a : α} (h : IsNode (Syntax.other a)) : β := False.elim (nomatch h)
namespace SyntaxNode
@[inline] def getKind {α} (n : SyntaxNode α) : SyntaxNodeKind :=
match n with
| ⟨Syntax.node k args, _⟩ => k
| ⟨Syntax.missing, h⟩ => unreachIsNodeMissing h
| ⟨Syntax.atom _ _, h⟩ => unreachIsNodeAtom h
| ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h
| ⟨Syntax.other _ , h⟩ => unreachIsNodeOther h
@[inline] def withArgs {α β} (n : SyntaxNode α) (fn : Array (Syntax α) → β) : β :=
match n with
| ⟨Syntax.node _ args, _⟩ => fn args
@ -62,6 +72,12 @@ match n with
| ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h
| ⟨Syntax.other _ , h⟩ => unreachIsNodeOther h
@[inline] def getNumArgs {α} (n : SyntaxNode α) : Nat :=
withArgs n $ fun args => args.size
@[inline] def getArg {α} (n : SyntaxNode α) (i : Nat) : Syntax α :=
withArgs n $ fun args => args.get i
@[inline] def updateArgs {α} (n : SyntaxNode α) (fn : Array (Syntax α) → Array (Syntax α)) : Syntax α :=
match n with
| ⟨Syntax.node kind args, _⟩ => Syntax.node kind (fn args)
@ -70,11 +86,23 @@ match n with
| ⟨Syntax.ident _ _ _ _, h⟩ => unreachIsNodeIdent h
| ⟨Syntax.other _, h⟩ => unreachIsNodeOther h
end SyntaxNode
namespace Syntax
@[inline] def ifNode {α β} (stx : Syntax α) (hyes : SyntaxNode α → β) (hno : Unit → β) : β :=
match stx with
| Syntax.node k args => hyes ⟨Syntax.node k args, IsNode.mk k args⟩
| _ => hno ()
def isIdent {α} : Syntax α → Bool
| (ident _ _ _ _) := true
| _ := false
def getIdentVal {α} : Syntax α → Option Name
| (ident _ _ val _) := val
| _ := none
def isOfKind {α} : Syntax α → SyntaxNodeKind → Bool
| (node kind _) k := k == kind
| _ _ := false