From 7003fb6447d6a16ed106f5bbdb40271c8f4a3f31 Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Wed, 21 Nov 2018 18:11:17 +0100 Subject: [PATCH] feat(library/init/lean/expander): command-level notations --- library/init/control/state.lean | 5 ++- library/init/lean/elaborator.lean | 37 +++++++++++++------ library/init/lean/expander.lean | 59 +++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 12 deletions(-) diff --git a/library/init/control/state.lean b/library/init/control/state.lean index 4074bd7c48..e55b9797dc 100644 --- a/library/init/control/state.lean +++ b/library/init/control/state.lean @@ -16,6 +16,9 @@ def state_t (σ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) @[inline] def state_t.run {σ : Type u} {m : Type u → Type v} {α : Type u} (x : state_t σ m α) (s : σ) : m (α × σ) := x s +@[inline] def state_t.run' {σ : Type u} {m : Type u → Type v} [functor m] {α : Type u} (x : state_t σ m α) (s : σ) : m α := +prod.fst <$> x s + @[reducible] def state (σ α : Type u) : Type u := state_t σ id α namespace state_t @@ -160,7 +163,7 @@ instance [monad m] : monad_state_adapter σ σ' (state_t σ m) (state_t σ' m) : end instance (σ : Type u) (m out : Type u → Type v) [functor m] [monad_run out m] : monad_run (λ α, σ → out α) (state_t σ m) := -⟨λ α x, run ∘ (λ σ, prod.fst <$> (x σ))⟩ +⟨λ α x, run ∘ state_t.run' x⟩ class monad_state_runner (σ : Type u) (m m' : Type u → Type u) := (run_state {} {α : Type u} : m α → σ → m' α) diff --git a/library/init/lean/elaborator.lean b/library/init/lean/elaborator.lean index 3c5a9e0c67..76da757ad7 100644 --- a/library/init/lean/elaborator.lean +++ b/library/init/lean/elaborator.lean @@ -16,17 +16,19 @@ open parser open parser.term open parser.command open parser.command.notation_spec +open expander structure elaborator_config := (filename : string) -(local_notations : list notation.view := []) +(local_notations : list notation_macro := []) (initial_parser_cfg : module_parser_config) structure elaborator_state := -- TODO(Sebastian): retrieve from environment (reserved_notations : list reserve_notation.view := []) --- TODO(Sebastian): retrieve from environment -(nonlocal_notations : list notation.view := []) +(nonlocal_notations : list notation_macro := []) +(notation_counter := 0) + (messages : message_log := message_log.empty) (parser_cfg : module_parser_config) (expander_cfg : expander.expander_config) @@ -162,10 +164,9 @@ do spec.rules.mfoldl (λ (cfg : command_parser_config) r, match r.symbol with pure {cfg with tokens := cfg.tokens.insert a.val.trim {«prefix» := a.val.trim, lbp := prec_to_nat prec}} | _ := throw "register_notation_tokens: unreachable") cfg -def command_parser_config.register_notation_parser (spec : notation_spec.view) (cfg : command_parser_config) : - except string command_parser_config := +def command_parser_config.register_notation_parser (k : syntax_node_kind) (spec : notation_spec.view) + (cfg : command_parser_config) : except string command_parser_config := do -- build and register parser - let k : syntax_node_kind := {name := "notation"}, ps ← spec.rules.mmap (λ r : rule.view, do psym ← match r.symbol with | notation_symbol.view.quoted {symbol := some a ..} := @@ -208,10 +209,10 @@ do st ← get, | except.ok ccfg := pure ccfg | except.error e := error (review reserve_notation rnota) e) ccfg, ccfg ← (st.nonlocal_notations ++ cfg.local_notations).mfoldl (λ ccfg nota, - match command_parser_config.register_notation_tokens nota.spec ccfg >>= - command_parser_config.register_notation_parser nota.spec with + match command_parser_config.register_notation_tokens nota.nota.spec ccfg >>= + command_parser_config.register_notation_parser nota.kind nota.nota.spec with | except.ok ccfg := pure ccfg - | except.error e := error (review «notation» nota) e) ccfg, + | except.error e := error (review «notation» nota.nota) e) ccfg, put {st with parser_cfg := {cfg.initial_parser_cfg with to_command_parser_config := ccfg}} def yield_to_outside : coelaborator_m unit := @@ -288,6 +289,18 @@ def notation.elaborate_aux : notation.view → elaborator_m notation.view := -- TODO: sanity checks pure {nota with spec := postprocess_notation_spec nota.spec} +def mk_notation_kind : elaborator_m syntax_node_kind := +do st ← get, + put {st with notation_counter := st.notation_counter + 1}, + pure {name := (`_notation).mk_numeral st.notation_counter} + +def register_notation_macro (nota : notation.view) : elaborator_m notation_macro := +do k ← mk_notation_kind, + let m : notation_macro := ⟨k, nota⟩, + let transf := mk_notation_transformer m, + modify $ λ st, {st with expander_cfg := {st.expander_cfg with transformers := st.expander_cfg.transformers.insert k.name transf}}, + pure m + def notation.elaborate : elaborator := λ stx, do let nota := view «notation» stx, @@ -303,7 +316,8 @@ def notation.elaborate : elaborator := severity := message_severity.warning, text := "ignoring notation using 'fold' action"}} } else do { nota ← notation.elaborate_aux nota, - modify $ λ st, {st with nonlocal_notations := nota::st.nonlocal_notations}, + m ← register_notation_macro nota, + modify $ λ st, {st with nonlocal_notations := m::st.nonlocal_notations}, update_parser_config } @@ -332,8 +346,9 @@ def commands.elaborate (stop_on_end_cmd : bool) : ℕ → coelaborator let nota := view «notation» cmd, if nota.local.is_some then do { nota ← notation.elaborate_aux nota, + m ← register_notation_macro nota, -- add local notation scoped to the remaining commands - adapt_reader (λ cfg : elaborator_config, {cfg with local_notations := nota::cfg.local_notations}) $ do { + adapt_reader (λ cfg : elaborator_config, {cfg with local_notations := m::cfg.local_notations}) $ do { (update_parser_config : coelaborator), yield_to_outside, commands.elaborate n diff --git a/library/init/lean/expander.lean b/library/init/lean/expander.lean index 5f0f8a2ad8..11f16f9c02 100644 --- a/library/init/lean/expander.lean +++ b/library/init/lean/expander.lean @@ -45,6 +45,65 @@ instance coe_ident_binder_id : has_coe syntax_ident binder_ident.view := instance coe_binders {α : Type} [has_coe_t α binder_ident.view] : has_coe (list α) term.binders.view := ⟨λ ids, {leading_ids := ids.map coe}⟩ +instance coe_binders_binders' : has_coe term.binders.view term.binders'.view := +⟨term.binders'.view.extended⟩ + +/-- A notation together with a unique node kind. -/ +structure notation_macro := +(kind : syntax_node_kind) +(nota : notation.view) + +structure notation_transformer_state := +(stx : syntax) +-- children of `stx` that have not been consumed yet +(stx_args : list syntax := []) +-- substitutions for notation variables +(substs : list (syntax_ident × syntax) := []) +-- filled by `binders` transitions, consumed by `scoped` actions +(scoped : option $ term.binders.view := none) + +private def pop_stx_arg : state_t notation_transformer_state transform_m syntax := +do st ← get, + match st.stx_args with + | arg::args := put {st with stx_args := args} *> pure arg + | _ := error st.stx "mk_notation_transformer: unreachable" + +def mk_notation_transformer (nota : notation_macro) : expander.transformer := +λ stx, do + some {args := stx_args, ..} ← pure stx.as_node + | error stx "mk_notation_transformer: unreachable", + flip state_t.run' {notation_transformer_state . stx := stx, stx_args := stx_args} $ do + let spec := nota.nota.spec, + match spec.prefix_arg with + | none := pure () + | some arg := do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (arg, stx_arg)::st.substs} }, + nota.nota.spec.rules.mfor (λ r : rule.view, do + match r.symbol with + | notation_symbol.view.quoted {symbol := some a ..} := pop_stx_arg + | _ := error stx "mk_notation_transformer: unreachable", + match r.transition with + | some (transition.view.binders b) := + do { stx_arg ← pop_stx_arg, modify $ λ st, {st with scoped := some $ view term.binders.parser stx_arg} } + | some (transition.view.arg {action := none, id := id}) := + do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (id, stx_arg)::st.substs} } + | some (transition.view.arg {action := some {kind := action_kind.view.prec _}, id := id}) := + do { stx_arg ← pop_stx_arg, modify $ λ st, {st with substs := (id, stx_arg)::st.substs} } + | some (transition.view.arg {action := some {kind := action_kind.view.scoped sc}, id := id}) := do + stx_arg ← pop_stx_arg, + {scoped := some bnders, ..} ← get + | error stx "mk_notation_transformer: unreachable", + -- TODO(Sebastian): not correct with multiple binders + let sc_lam := review lambda {binders := [sc.id], body := sc.term}, + let lam := review lambda {binders := binders'.view.extended bnders, body := stx_arg}, + let arg := review app {fn := sc_lam, arg := lam}, + modify $ λ st, {st with substs := (id, arg)::st.substs} + | none := pure () + | _ := error stx "mk_notation_transformer: unimplemented"), + st ← get, + -- apply substitutions [(x1, e1), ...] via `(λ x1 ..., nota.nota.term) e1 ...` + let lam := review lambda {binders := st.substs.map prod.fst, body := nota.nota.term}, + pure $ some $ st.substs.foldl (λ fn subst, review app {fn := fn, arg := subst.2}) lam + def mk_simple_binder (id : syntax_ident) (bi : binder_info) (type : syntax) : binders'.view := let bc : binder_content.view := {ids := [id], type := some {type := type}} in binders'.view.simple $ match bi with