feat(library/init/lean/expander): command-level notations
This commit is contained in:
parent
fd121f03bd
commit
7003fb6447
3 changed files with 89 additions and 12 deletions
|
|
@ -16,6 +16,9 @@ def state_t (σ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v)
|
|||
@[inline] def state_t.run {σ : Type u} {m : Type u → Type v} {α : Type u} (x : state_t σ m α) (s : σ) : m (α × σ) :=
|
||||
x s
|
||||
|
||||
@[inline] def state_t.run' {σ : Type u} {m : Type u → Type v} [functor m] {α : Type u} (x : state_t σ m α) (s : σ) : m α :=
|
||||
prod.fst <$> x s
|
||||
|
||||
@[reducible] def state (σ α : Type u) : Type u := state_t σ id α
|
||||
|
||||
namespace state_t
|
||||
|
|
@ -160,7 +163,7 @@ instance [monad m] : monad_state_adapter σ σ' (state_t σ m) (state_t σ' m) :
|
|||
end
|
||||
|
||||
instance (σ : Type u) (m out : Type u → Type v) [functor m] [monad_run out m] : monad_run (λ α, σ → out α) (state_t σ m) :=
|
||||
⟨λ α x, run ∘ (λ σ, prod.fst <$> (x σ))⟩
|
||||
⟨λ α x, run ∘ state_t.run' x⟩
|
||||
|
||||
class monad_state_runner (σ : Type u) (m m' : Type u → Type u) :=
|
||||
(run_state {} {α : Type u} : m α → σ → m' α)
|
||||
|
|
|
|||
|
|
@ -16,17 +16,19 @@ open parser
|
|||
open parser.term
|
||||
open parser.command
|
||||
open parser.command.notation_spec
|
||||
open expander
|
||||
|
||||
structure elaborator_config :=
|
||||
(filename : string)
|
||||
(local_notations : list notation.view := [])
|
||||
(local_notations : list notation_macro := [])
|
||||
(initial_parser_cfg : module_parser_config)
|
||||
|
||||
structure elaborator_state :=
|
||||
-- TODO(Sebastian): retrieve from environment
|
||||
(reserved_notations : list reserve_notation.view := [])
|
||||
-- TODO(Sebastian): retrieve from environment
|
||||
(nonlocal_notations : list notation.view := [])
|
||||
(nonlocal_notations : list notation_macro := [])
|
||||
(notation_counter := 0)
|
||||
|
||||
(messages : message_log := message_log.empty)
|
||||
(parser_cfg : module_parser_config)
|
||||
(expander_cfg : expander.expander_config)
|
||||
|
|
@ -162,10 +164,9 @@ do spec.rules.mfoldl (λ (cfg : command_parser_config) r, match r.symbol with
|
|||
pure {cfg with tokens := cfg.tokens.insert a.val.trim {«prefix» := a.val.trim, lbp := prec_to_nat prec}}
|
||||
| _ := throw "register_notation_tokens: unreachable") cfg
|
||||
|
||||
def command_parser_config.register_notation_parser (spec : notation_spec.view) (cfg : command_parser_config) :
|
||||
except string command_parser_config :=
|
||||
def command_parser_config.register_notation_parser (k : syntax_node_kind) (spec : notation_spec.view)
|
||||
(cfg : command_parser_config) : except string command_parser_config :=
|
||||
do -- build and register parser
|
||||
let k : syntax_node_kind := {name := "notation<TODO>"},
|
||||
ps ← spec.rules.mmap (λ r : rule.view, do
|
||||
psym ← match r.symbol with
|
||||
| notation_symbol.view.quoted {symbol := some a ..} :=
|
||||
|
|
@ -208,10 +209,10 @@ do st ← get,
|
|||
| except.ok ccfg := pure ccfg
|
||||
| except.error e := error (review reserve_notation rnota) e) ccfg,
|
||||
ccfg ← (st.nonlocal_notations ++ cfg.local_notations).mfoldl (λ ccfg nota,
|
||||
match command_parser_config.register_notation_tokens nota.spec ccfg >>=
|
||||
command_parser_config.register_notation_parser nota.spec with
|
||||
match command_parser_config.register_notation_tokens nota.nota.spec ccfg >>=
|
||||
command_parser_config.register_notation_parser nota.kind nota.nota.spec with
|
||||
| except.ok ccfg := pure ccfg
|
||||
| except.error e := error (review «notation» nota) e) ccfg,
|
||||
| except.error e := error (review «notation» nota.nota) e) ccfg,
|
||||
put {st with parser_cfg := {cfg.initial_parser_cfg with to_command_parser_config := ccfg}}
|
||||
|
||||
def yield_to_outside : coelaborator_m unit :=
|
||||
|
|
@ -288,6 +289,18 @@ def notation.elaborate_aux : notation.view → elaborator_m notation.view :=
|
|||
-- TODO: sanity checks
|
||||
pure {nota with spec := postprocess_notation_spec nota.spec}
|
||||
|
||||
def mk_notation_kind : elaborator_m syntax_node_kind :=
|
||||
do st ← get,
|
||||
put {st with notation_counter := st.notation_counter + 1},
|
||||
pure {name := (`_notation).mk_numeral st.notation_counter}
|
||||
|
||||
def register_notation_macro (nota : notation.view) : elaborator_m notation_macro :=
|
||||
do k ← mk_notation_kind,
|
||||
let m : notation_macro := ⟨k, nota⟩,
|
||||
let transf := mk_notation_transformer m,
|
||||
modify $ λ st, {st with expander_cfg := {st.expander_cfg with transformers := st.expander_cfg.transformers.insert k.name transf}},
|
||||
pure m
|
||||
|
||||
def notation.elaborate : elaborator :=
|
||||
λ stx, do
|
||||
let nota := view «notation» stx,
|
||||
|
|
@ -303,7 +316,8 @@ def notation.elaborate : elaborator :=
|
|||
severity := message_severity.warning, text := "ignoring notation using 'fold' action"}}
|
||||
} else do {
|
||||
nota ← notation.elaborate_aux nota,
|
||||
modify $ λ st, {st with nonlocal_notations := nota::st.nonlocal_notations},
|
||||
m ← register_notation_macro nota,
|
||||
modify $ λ st, {st with nonlocal_notations := m::st.nonlocal_notations},
|
||||
update_parser_config
|
||||
}
|
||||
|
||||
|
|
@ -332,8 +346,9 @@ def commands.elaborate (stop_on_end_cmd : bool) : ℕ → coelaborator
|
|||
let nota := view «notation» cmd,
|
||||
if nota.local.is_some then do {
|
||||
nota ← notation.elaborate_aux nota,
|
||||
m ← register_notation_macro nota,
|
||||
-- add local notation scoped to the remaining commands
|
||||
adapt_reader (λ cfg : elaborator_config, {cfg with local_notations := nota::cfg.local_notations}) $ do {
|
||||
adapt_reader (λ cfg : elaborator_config, {cfg with local_notations := m::cfg.local_notations}) $ do {
|
||||
(update_parser_config : coelaborator),
|
||||
yield_to_outside,
|
||||
commands.elaborate n
|
||||
|
|
|
|||
|
|
@ -45,6 +45,65 @@ instance coe_ident_binder_id : has_coe syntax_ident binder_ident.view :=
|
|||
instance coe_binders {α : Type} [has_coe_t α binder_ident.view] : has_coe (list α) term.binders.view :=
|
||||
⟨λ ids, {leading_ids := ids.map coe}⟩
|
||||
|
||||
instance coe_binders_binders' : has_coe term.binders.view term.binders'.view :=
|
||||
⟨term.binders'.view.extended⟩
|
||||
|
||||
/-- A notation together with a unique node kind. -/
|
||||
structure notation_macro :=
|
||||
(kind : syntax_node_kind)
|
||||
(nota : notation.view)
|
||||
|
||||
structure notation_transformer_state :=
|
||||
(stx : syntax)
|
||||
-- children of `stx` that have not been consumed yet
|
||||
(stx_args : list syntax := [])
|
||||
-- substitutions for notation variables
|
||||
(substs : list (syntax_ident × syntax) := [])
|
||||
-- filled by `binders` transitions, consumed by `scoped` actions
|
||||
(scoped : option $ term.binders.view := none)
|
||||
|
||||
private def pop_stx_arg : state_t notation_transformer_state transform_m syntax :=
|
||||
do st ← get,
|
||||
match st.stx_args with
|
||||
| arg::args := put {st with stx_args := args} *> pure arg
|
||||
| _ := error st.stx "mk_notation_transformer: unreachable"
|
||||
|
||||
def mk_notation_transformer (nota : notation_macro) : expander.transformer :=
|
||||
λ stx, do
|
||||
some {args := stx_args, ..} ← pure stx.as_node
|
||||
| error stx "mk_notation_transformer: unreachable",
|
||||
flip state_t.run' {notation_transformer_state . stx := stx, stx_args := stx_args} $ do
|
||||
let spec := nota.nota.spec,
|
||||
match spec.prefix_arg with
|
||||
| none := pure ()
|
||||
| some arg := do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (arg, stx_arg)::st.substs} },
|
||||
nota.nota.spec.rules.mfor (λ r : rule.view, do
|
||||
match r.symbol with
|
||||
| notation_symbol.view.quoted {symbol := some a ..} := pop_stx_arg
|
||||
| _ := error stx "mk_notation_transformer: unreachable",
|
||||
match r.transition with
|
||||
| some (transition.view.binders b) :=
|
||||
do { stx_arg ← pop_stx_arg, modify $ λ st, {st with scoped := some $ view term.binders.parser stx_arg} }
|
||||
| some (transition.view.arg {action := none, id := id}) :=
|
||||
do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (id, stx_arg)::st.substs} }
|
||||
| some (transition.view.arg {action := some {kind := action_kind.view.prec _}, id := id}) :=
|
||||
do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (id, stx_arg)::st.substs} }
|
||||
| some (transition.view.arg {action := some {kind := action_kind.view.scoped sc}, id := id}) := do
|
||||
stx_arg ← pop_stx_arg,
|
||||
{scoped := some bnders, ..} ← get
|
||||
| error stx "mk_notation_transformer: unreachable",
|
||||
-- TODO(Sebastian): not correct with multiple binders
|
||||
let sc_lam := review lambda {binders := [sc.id], body := sc.term},
|
||||
let lam := review lambda {binders := binders'.view.extended bnders, body := stx_arg},
|
||||
let arg := review app {fn := sc_lam, arg := lam},
|
||||
modify $ λ st, {st with substs := (id, arg)::st.substs}
|
||||
| none := pure ()
|
||||
| _ := error stx "mk_notation_transformer: unimplemented"),
|
||||
st ← get,
|
||||
-- apply substitutions [(x1, e1), ...] via `(λ x1 ..., nota.nota.term) e1 ...`
|
||||
let lam := review lambda {binders := st.substs.map prod.fst, body := nota.nota.term},
|
||||
pure $ some $ st.substs.foldl (λ fn subst, review app {fn := fn, arg := subst.2}) lam
|
||||
|
||||
def mk_simple_binder (id : syntax_ident) (bi : binder_info) (type : syntax) : binders'.view :=
|
||||
let bc : binder_content.view := {ids := [id], type := some {type := type}} in
|
||||
binders'.view.simple $ match bi with
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue