diff --git a/tests/playground/lowtech_expander.lean b/tests/playground/lowtech_expander.lean index c707784068..7905258ceb 100644 --- a/tests/playground/lowtech_expander.lean +++ b/tests/playground/lowtech_expander.lean @@ -39,6 +39,9 @@ do m ← nameToKindTable.get, nextUniqId.set (id+1), pure { name := k, id := id } +def mkNullKind : IO SyntaxNodeKind := nextKind `null +@[init mkNullKind] constant nullKind : SyntaxNodeKind := default _ + inductive Syntax | missing | node (kind : SyntaxNodeKind) (args : Array Syntax) (scopes : MacroScopes) @@ -67,7 +70,7 @@ inductive IsNode : Syntax → Prop def SyntaxNode : Type := {s : Syntax // IsNode s } -def notIsNodeMissing (h : IsNode Syntax.missing) : False := match h with end +def notIsNodeMissing (h : IsNode Syntax.missing) : False := match h with end def notIsNodeAtom {info val} (h : IsNode (Syntax.atom info val)) : False := match h with end def notIsNodeIdent {info rawVal val preresolved scopes} (h : IsNode (Syntax.ident info rawVal val preresolved scopes)) : False := match h with end @@ -87,13 +90,46 @@ match s with else base | other := base +@[inline] def mkAtom (val : String) : Syntax := +Syntax.atom none val + +def mkOptionSomeKind : IO SyntaxNodeKind := nextKind `some +@[init mkOptionSomeKind] constant optionSomeKind : SyntaxNodeKind := default _ +def mkOptionNoneKind : IO SyntaxNodeKind := nextKind `none +@[init mkOptionSomeKind] constant optionNoneKind : SyntaxNodeKind := default _ +def mkManyKind : IO SyntaxNodeKind := nextKind `many +@[init mkManyKind] constant manyKind : SyntaxNodeKind := default _ +def mkHoleKind : IO SyntaxNodeKind := nextKind `hole +@[init mkHoleKind] constant holeKind : SyntaxNodeKind := default _ def mkNotKind : IO SyntaxNodeKind := nextKind `not @[init mkNotKind] constant notKind : SyntaxNodeKind := default _ def mkIfKind : IO SyntaxNodeKind := nextKind `if @[init mkIfKind] constant ifKind : SyntaxNodeKind := default _ +def mkLetKind : IO SyntaxNodeKind := nextKind `let +@[init mkLetKind] constant letKind : SyntaxNodeKind := default _ +def mkLetLhsIdKind : IO SyntaxNodeKind := nextKind `letLhsId +@[init mkLetLhsIdKind] constant letLhsIdKind : SyntaxNodeKind := default _ +def mkLetLhsPatternKind : IO SyntaxNodeKind := nextKind `letLhsPattern +@[init mkLetLhsPatternKind] constant letLhsPatternKind : SyntaxNodeKind := default _ -@[inline] def mkAtom (val : String) : Syntax := -Syntax.atom none val +@[inline] def Syntax.getKind (n : Syntax) : SyntaxNodeKind := +match n with +| Syntax.node k _ _ := k +| other := nullKind + +@[inline] def withArgs {α : Type} (n : SyntaxNode) (fn : Array Syntax → α) : α := +match n with +| ⟨Syntax.node _ args _, _⟩ := fn args +| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h +| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h +| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h + +@[inline] def updateArgs (n : SyntaxNode) (fn : Array Syntax → Array Syntax) : Syntax := +match n with +| ⟨Syntax.node kind args scopes, _⟩ := Syntax.node kind (fn args) scopes +| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h +| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h +| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h @[inline] def mkNotAux (tk : Syntax) (c : Syntax) : Syntax := Syntax.node notKind [tk, c].toArray [] @@ -102,25 +138,10 @@ Syntax.node notKind [tk, c].toArray [] mkNotAux (mkAtom "not") c @[inline] def withNot {α : Type} (n : SyntaxNode) (fn : Syntax → α) : α := -match n with -| ⟨Syntax.node _ args _, _⟩ := fn (args.get 1) -| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h -| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h -| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h - -@[inline] def isNot {α : Type} (n : Syntax) (base : α) (fn : Syntax → α) : α := -match n with -| Syntax.node k args _ := if k == notKind then fn (args.get 1) else base -| Syntax.missing := base -| Syntax.atom _ _ := base -| Syntax.ident _ _ _ _ _ := base +withArgs n $ λ args, fn (args.get 1) @[inline] def updateNot (src : SyntaxNode) (c : Syntax) : Syntax := -match src with -| ⟨Syntax.node kind args scopes, _⟩ := Syntax.node kind (args.set 1 c) scopes -| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h -| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h -| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h +updateArgs src $ λ args, args.set 1 c @[inline] def mkIfAux (ifTk : Syntax) (condNode : Syntax) (thenTk : Syntax) (thenNode : Syntax) (elseTk : Syntax) (elseNode: Syntax) : Syntax := Syntax.node ifKind [ifTk, condNode, thenTk, thenNode, elseTk, elseNode].toArray [] @@ -128,38 +149,129 @@ Syntax.node ifKind [ifTk, condNode, thenTk, thenNode, elseTk, elseNode].toArray @[inline] def mkIf (c t e : Syntax) : Syntax := mkIfAux (mkAtom "if") c (mkAtom "then") t (mkAtom "else") e -@[inline] def withIf {α : Type} (src : SyntaxNode) (fn : Syntax → Syntax → Syntax → α) : α := -match src with -| ⟨Syntax.node _ args _, _⟩ := fn (args.get 1) (args.get 3) (args.get 5) -| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h -| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h -| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h +@[inline] def withIf {α : Type} (n : SyntaxNode) (fn : Syntax → Syntax → Syntax → α) : α := +withArgs n $ λ args, fn (args.get 1) (args.get 3) (args.get 5) @[inline] def updateIf (src : SyntaxNode) (c t e : Syntax) : Syntax := -match src with -| ⟨Syntax.node kind args scopes, _⟩ := +updateArgs src $ λ args, let args := args.set 1 c in let args := args.set 3 t in let args := args.set 5 e in - Syntax.node kind args scopes -| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h -| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h -| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h + args + +@[inline] def mkLetAux (letTk : Syntax) (lhs : Syntax) (assignTk : Syntax) (val : Syntax) (inTk : Syntax) (body : Syntax) : Syntax := +Syntax.node letKind [letTk, lhs, assignTk, val, inTk, body].toArray [] + +@[inline] def mkLet (lhs : Syntax) (val : Syntax) (body : Syntax) : Syntax := +mkLetAux (mkAtom "let") lhs (mkAtom ":=") val (mkAtom "in") body + +@[inline] def withLet {α : Type} (n : SyntaxNode) (fn : Syntax → Syntax → Syntax → α) : α := +withArgs n $ λ args, fn (args.get 1) (args.get 3) (args.get 5) + +@[inline] def updateLet (src : SyntaxNode) (lhs val body : Syntax) : Syntax := +updateArgs src $ λ args, + let args := args.set 1 lhs in + let args := args.set 3 val in + let args := args.set 5 body in + args + +@[inline] def mkLetLhsId (id : Syntax) (binders : Syntax) (type : Syntax) : Syntax := +Syntax.node letLhsIdKind [id, binders, type].toArray [] + +@[inline] def withLetLhsId {α : Type} (n : SyntaxNode) (fn : Syntax → Syntax → Syntax → α) : α := +withArgs n $ λ args, fn (args.get 0) (args.get 1) (args.get 2) + +@[inline] def updateLhsId (src : SyntaxNode) (id binders type : Syntax) : Syntax := +updateArgs src $ λ args, + let args := args.set 0 id in + let args := args.set 1 binders in + let args := args.set 2 type in + args + +@[inline] def mkLetLhsPattern (pattern : Syntax) : Syntax := +Syntax.node letLhsPatternKind [pattern].toArray [] + +@[inline] def withLetLhsPattern {α : Type} (n : SyntaxNode) (fn : Syntax → α) : α := +withArgs n $ λ args, fn (args.get 0) + +@[inline] def withOptionSome {α : Type} (n : SyntaxNode) (fn : Syntax → α) : α := +withArgs n $ λ args, fn (args.get 0) + +def Syntax.getNumChildren (n : Syntax) : Nat := +match n with +| Syntax.node _ args _ := args.size +| _ := 0 + +def hole : Syntax := Syntax.node holeKind ∅ [] + +def mkOptionSome (s : Syntax) := Syntax.node optionSomeKind [s].toArray [] abbrev FrontendConfig := Bool -- placeholder abbrev Message := String -- placeholder abbrev TransformM := ReaderT FrontendConfig $ ExceptT Message Id abbrev Transformer := SyntaxNode → TransformM (Option Syntax) +def noExpansion : TransformM (Option Syntax) := pure none + +@[inline] def Syntax.case {α : Type} (n : Syntax) (k : SyntaxNodeKind) (fn : SyntaxNode → TransformM (Option α)) : TransformM (Option α) := +match n with +| Syntax.node k' args s := if k == k' then fn ⟨Syntax.node k' args s, IsNode.mk _ _ _⟩ else pure none +| _ := pure none + +@[inline] def TransformM.orCase {α : Type} (x y : TransformM (Option α)) : TransformM (Option α) := +λ cfg, match x cfg with + | Except.ok none := y cfg + | other := other + +infix ``:2 := TransformM.orCase + set_option pp.implicit true set_option trace.compiler.boxed true def flipIf : Transformer := λ n, withIf n $ λ c t e, - isNot c (pure n.val) $ λ c', - pure (updateIf n c' e t) + c.case notKind $ λ c, withNot c $ λ c', + pure $ updateIf n c' e t -/- -The generated code can be still be improved if we modify ExceptT using the trick described in -our paper. --/ +def letTransformer : Transformer := +λ n, withLet n $ λ lhs val body, + (lhs.case letLhsIdKind $ λ lhs, withLetLhsId lhs $ λ id binders type, + if binders.getNumChildren == 0 then + type.case optionNoneKind $ λ _, + let newLhs := updateLhsId lhs id binders (mkOptionSome hole) in + pure (some (updateLet n newLhs val body)) + else + -- TODO + noExpansion) + + (lhs.case letLhsPatternKind $ λ lhs, + -- TODO + noExpansion) + +@[inline] def Syntax.isNode {α : Type} (n : Syntax) (fn : SyntaxNodeKind → SyntaxNode → TransformM (Option α)) : TransformM (Option α) := +match n with +| Syntax.node k args s := fn k ⟨Syntax.node k args s, IsNode.mk _ _ _⟩ +| other := pure none + +def SyntaxNode.getKind (n : SyntaxNode) : SyntaxNodeKind := +match n with +| ⟨Syntax.node k _ _, _⟩ := k +| ⟨Syntax.missing, h⟩ := unreachIsNodeMissing h +| ⟨Syntax.atom _ _, h⟩ := unreachIsNodeAtom h +| ⟨Syntax.ident _ _ _ _ _, h⟩ := unreachIsNodeIdent h + +/- Version without using the combinator . -/ +def letTransformer' : Transformer := +λ n, withLet n $ λ lhs val body, + lhs.isNode $ λ k lhs, -- lhs is now a SyntaxNode + if k == letLhsIdKind then withLetLhsId lhs $ λ id binders type, + if binders.getNumChildren == 0 then + type.case optionNoneKind $ λ _, + let newLhs := updateLhsId lhs id binders (mkOptionSome hole) in + pure (some (updateLet n newLhs val body)) + else + -- TODO + noExpansion + else withLetLhsPattern lhs $ λ pattern, + -- TODO + noExpansion