feat: add helper Syntax.node* functions

This commit is contained in:
Leonardo de Moura 2022-10-16 08:20:01 -07:00
parent 729fd63b29
commit c20febff31
4 changed files with 68 additions and 26 deletions

View file

@ -3577,6 +3577,38 @@ inductive Syntax where
-/
| ident (info : SourceInfo) (rawVal : Substring) (val : Name) (preresolved : List Syntax.Preresolved) : Syntax
/-- Create syntax node with 1 child -/
def Syntax.node1 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray1 a₁)
/-- Create syntax node with 2 children -/
def Syntax.node2 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray2 a₁ a₂)
/-- Create syntax node with 3 children -/
def Syntax.node3 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray3 a₁ a₂ a₃)
/-- Create syntax node with 4 children -/
def Syntax.node4 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ a₄ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray4 a₁ a₂ a₃ a₄)
/-- Create syntax node with 5 children -/
def Syntax.node5 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ a₄ a₅ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray5 a₁ a₂ a₃ a₄ a₅)
/-- Create syntax node with 6 children -/
def Syntax.node6 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ a₄ a₅ a₆ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray6 a₁ a₂ a₃ a₄ a₅ a₆)
/-- Create syntax node with 7 children -/
def Syntax.node7 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ a₄ a₅ a₆ a₇ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray7 a₁ a₂ a₃ a₄ a₅ a₆ a₇)
/-- Create syntax node with 8 children -/
def Syntax.node8 (info : SourceInfo) (kind : SyntaxNodeKind) (a₁ a₂ a₃ a₄ a₅ a₆ a₇ a₈ : Syntax) : Syntax :=
Syntax.node info kind (Array.mkArray8 a₁ a₂ a₃ a₄ a₅ a₆ a₇ a₈)
/-- `SyntaxNodeKinds` is a set of `SyntaxNodeKind` (implemented as a list). -/
def SyntaxNodeKinds := List SyntaxNodeKind

View file

@ -77,19 +77,32 @@ def ArrayStxBuilder := Sum (Array Term) Term
namespace ArrayStxBuilder
def empty : ArrayStxBuilder := Sum.inl #[]
def empty : ArrayStxBuilder := .inl #[]
def build : ArrayStxBuilder → Term
| Sum.inl elems => quote elems
| Sum.inr arr => arr
| .inl elems => quote elems
| .inr arr => arr
def push (b : ArrayStxBuilder) (elem : Syntax) : ArrayStxBuilder :=
match b with
| Sum.inl elems => Sum.inl <| elems.push elem
| Sum.inr arr => Sum.inr <| mkCApp ``Array.push #[arr, elem]
| .inl elems => .inl <| elems.push elem
| .inr arr => .inr <| mkCApp ``Array.push #[arr, elem]
def append (b : ArrayStxBuilder) (arr : Syntax) (appendName := ``Array.append) : ArrayStxBuilder :=
Sum.inr <| mkCApp appendName #[b.build, arr]
.inr <| mkCApp appendName #[b.build, arr]
def mkNode (b : ArrayStxBuilder) (k : SyntaxNodeKind) : TermElabM Term := do
let k := quote k
match b with
| .inl #[a₁] => `(Syntax.node1 info $(k) $(a₁))
| .inl #[a₁, a₂] => `(Syntax.node2 info $(k) $(a₁) $(a₂))
| .inl #[a₁, a₂, a₃] => `(Syntax.node3 info $(k) $(a₁) $(a₂) $(a₃))
| .inl #[a₁, a₂, a₃, a₄] => `(Syntax.node4 info $(k) $(a₁) $(a₂) $(a₃) $(a₄))
| .inl #[a₁, a₂, a₃, a₄, a₅] => `(Syntax.node5 info $(k) $(a₁) $(a₂) $(a₃) $(a₄) $(a₅))
| .inl #[a₁, a₂, a₃, a₄, a₅, a₆] => `(Syntax.node6 info $(k) $(a₁) $(a₂) $(a₃) $(a₄) $(a₅) $(a₆))
| .inl #[a₁, a₂, a₃, a₄, a₅, a₆, a₇] => `(Syntax.node7 info $(k) $(a₁) $(a₂) $(a₃) $(a₄) $(a₅) $(a₆) $(a₇))
| .inl #[a₁, a₂, a₃, a₄, a₅, a₆, a₇, a₈] => `(Syntax.node8 info $(k) $(a₁) $(a₂) $(a₃) $(a₄) $(a₅) $(a₆) $(a₇) $(a₈))
| _ => `(Syntax.node info $(k) $(b.build))
end ArrayStxBuilder
@ -186,7 +199,7 @@ private partial def quoteSyntax : Syntax → TermElabM Term
else do
let arg ← quoteSyntax arg
args := args.push arg
`(Syntax.node info $(quote k) $(args.build))
args.mkNode k
| Syntax.atom _ val =>
`(Syntax.atom info $(quote val))
| Syntax.missing => throwUnsupportedSyntax

View file

@ -16,11 +16,10 @@
let mainModule ← Lean.getMainModule
pure
{ raw :=
Lean.Syntax.node info `Lean.Parser.Term.app
#[Lean.Syntax.ident info (String.toSubstring' "Nat.add")
(Lean.addMacroScope mainModule `Nat.add scp)
[Lean.Syntax.Preresolved.decl `Nat.add [], Lean.Syntax.Preresolved.namespace `Nat.add],
Lean.Syntax.node info `null #[lhs.raw, rhs.raw]] }.raw
Lean.Syntax.node2 info `Lean.Parser.Term.app
(Lean.Syntax.ident info (String.toSubstring' "Nat.add") (Lean.addMacroScope mainModule `Nat.add scp)
[Lean.Syntax.Preresolved.decl `Nat.add [], Lean.Syntax.Preresolved.namespace `Nat.add])
(Lean.Syntax.node2 info `null lhs.raw rhs.raw) }.raw
else
let_fun __discr := x;
throw Lean.Macro.Exception.unsupportedSyntax
@ -41,7 +40,7 @@
let info ← Lean.MonadRef.mkInfoFromRefPos
let _ ← Lean.getCurrMacroScope
let _ ← Lean.getMainModule
pure { raw := Lean.Syntax.node info `term_+++_ #[lhs.raw, Lean.Syntax.atom info "+++", rhs.raw] }.raw
pure { raw := Lean.Syntax.node3 info `term_+++_ lhs.raw (Lean.Syntax.atom info "+++") rhs.raw }.raw
else
let_fun __discr := Lean.Syntax.getArg __discr 1;
throw ()
@ -65,9 +64,9 @@
let _ ← Lean.getMainModule
pure
{ raw :=
Lean.Syntax.node info `Lean.Parser.Term.app
#[Lean.Syntax.node info `term_+++_ #[lhs.raw, Lean.Syntax.atom info "+++", rhs.raw],
Lean.Syntax.node info `null (Array.append #[] (Lean.TSyntaxArray.raw moreArgs))] }.raw
Lean.Syntax.node2 info `Lean.Parser.Term.app
(Lean.Syntax.node3 info `term_+++_ lhs.raw (Lean.Syntax.atom info "+++") rhs.raw)
(Lean.Syntax.node info `null (Array.append #[] (Lean.TSyntaxArray.raw moreArgs))) }.raw
else
let_fun __discr_5 := Lean.Syntax.getArg __discr_2 1;
let_fun __discr := Lean.Syntax.getArg __discr 1;

View file

@ -58,13 +58,12 @@ def foo (x_1 : obj) : obj :=
let x_2 : obj := Nat.repr x_1;
let x_3 : obj := ctor_2[Lean.SourceInfo.none];
let x_4 : obj := Lean.Syntax.mkNumLit x_2 x_3;
let x_5 : obj := foo._closed_10;
let x_6 : obj := foo._closed_12;
let x_7 : obj := Array.mkArray3._rarg x_5 x_6 x_4;
let x_8 : obj := foo._closed_1;
let x_9 : obj := foo._closed_5;
let x_10 : obj := ctor_1[Lean.Syntax.node] x_8 x_9 x_7;
ret x_10[Compiler.result] size: 19
let x_5 : obj := foo._closed_1;
let x_6 : obj := foo._closed_5;
let x_7 : obj := foo._closed_10;
let x_8 : obj := foo._closed_12;
let x_9 : obj := Lean.Syntax.node3 x_5 x_6 x_7 x_8 x_4;
ret x_9[Compiler.result] size: 18
def foo n : Syntax :=
let fst.1 := Syntax.missing
let fst.2 := 1
@ -83,6 +82,5 @@ def foo (x_1 : obj) : obj :=
let _x.15 := "+"
let _x.16 := Syntax.atom fst.4 _x.15
let _x.17 := Lean.instQuoteNatNumLitKind._elambda_0 n
let _x.18 := Array.mkArray3 _ _x.14 _x.16 _x.17
let _x.19 := Syntax.node fst.4 _x.8 _x.18
_x.19
let _x.18 := Syntax.node3 fst.4 _x.8 _x.14 _x.16 _x.17
_x.18