diff --git a/src/frontends/lean/parse_table.cpp b/src/frontends/lean/parse_table.cpp index 3ad364a43a..11efb03da0 100644 --- a/src/frontends/lean/parse_table.cpp +++ b/src/frontends/lean/parse_table.cpp @@ -212,18 +212,17 @@ bool action::is_equivalent(action const & a) const { return is_equal(a); } } + +static bool is_compatible_core(action_kind k1, action_kind k2) { + return k1 == action_kind::Skip && (k2 == action_kind::Expr || k2 == action_kind::Exprs || k2 == action_kind::ScopedExpr); +} + bool action::is_compatible(action const & a) const { if (is_equivalent(a)) return true; auto k1 = kind(); auto k2 = a.kind(); - if (k1 == action_kind::Skip && (k2 == action_kind::Expr || k2 == action_kind::Exprs)) - return true; - if (k1 == action_kind::Expr && k2 == action_kind::Skip) - return true; - if (k2 == action_kind::Exprs && (k2 == action_kind::Skip || k2 == action_kind::Exprs)) - return true; - return false; + return is_compatible_core(k1, k2) || is_compatible_core(k2, k1); } void action::display(io_state_stream & out) const { switch (kind()) { diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 26c69838c2..d881494929 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -1174,6 +1174,41 @@ void parser::process_postponed(buffer const & args, bool is_left, } } +// Return true iff the current token is the terminator of some Exprs action, and store the matching pair in r +static bool curr_is_terminator_of_exprs_action(parser const & p, list> const & lst, pair const * & r) { + for (auto const & pr : lst) { + notation::action const & a = pr.first; + if (a.kind() == notation::action_kind::Exprs && + a.get_terminator() && + p.curr_is_token(*a.get_terminator())) { + r = ≺ + return true; + } + } + return false; +} + +// Return true iff \c lst contains a Skip action, and store the matching pair in r. +static bool has_skip(list> const & lst, pair const * & r) { + for (auto const & p : lst) { + notation::action const & a = p.first; + if (a.kind() == notation::action_kind::Skip) { + r = &p; + return true; + } + } + return false; +} + +static pair const * get_non_skip(list> const & lst) { + for (auto const & p : lst) { + notation::action const & a = p.first; + if (a.kind() != notation::action_kind::Skip) + return &p; + } + return nullptr; +} + expr parser::parse_notation_core(parse_table t, expr * left, bool as_tactic) { lean_assert(curr() == scanner::token_kind::Keyword); auto p = pos(); @@ -1199,19 +1234,34 @@ expr parser::parse_notation_core(parse_table t, expr * left, bool as_tactic) { auto r = t.find(get_token_info().value()); if (!r) break; - // TODO(Leo): handle multiple actions - notation::action const & a = head(r).first; + pair const * curr_pair = nullptr; + if (tail(r)) { + // There is more than one possible actions. + // In the current implementation, we support the following possible cases (Skip, Expr), (Skip, Exprs) amd (Skip, ScopedExpr) + next(); + if (curr_is_terminator_of_exprs_action(*this, r, curr_pair)) { + lean_assert(curr_pair->first.kind() == notation::action_kind::Exprs); + } else if (has_skip(r, curr_pair) && !curr_starts_expr()) { + lean_assert(curr_pair->first.kind() == notation::action_kind::Skip); + } else { + curr_pair = get_non_skip(r); + } + } else { + // there is only one possible action + curr_pair = &head(r); + if (curr_pair->first.kind() != notation::action_kind::Ext) + next(); + } + lean_assert(curr_pair); + notation::action const & a = curr_pair->first; switch (a.kind()) { case notation::action_kind::Skip: - next(); break; case notation::action_kind::Expr: - next(); args.push_back(parse_expr_or_tactic(a.rbp(), as_tactic)); kinds.push_back(a.kind()); break; case notation::action_kind::Exprs: { - next(); buffer r_args; auto terminator = a.get_terminator(); if (!terminator || !curr_is_token(*terminator)) { @@ -1234,17 +1284,14 @@ expr parser::parse_notation_core(parse_table t, expr * left, bool as_tactic) { break; } case notation::action_kind::Binder: - next(); binder_pos = pos(); ps.push_back(parse_binder(a.rbp())); break; case notation::action_kind::Binders: - next(); binder_pos = pos(); lenv = parse_binders(ps, a.rbp()); break; case notation::action_kind::ScopedExpr: { - next(); expr r = parse_scoped_expr(ps, lenv, a.rbp()); args.push_back(r); kinds.push_back(a.kind()); @@ -1252,7 +1299,6 @@ expr parser::parse_notation_core(parse_table t, expr * left, bool as_tactic) { break; } case notation::action_kind::LuaExt: - next(); m_last_script_pos = p; using_script([&](lua_State * L) { scoped_set_parser scope(L, *this); @@ -1279,7 +1325,7 @@ expr parser::parse_notation_core(parse_table t, expr * left, bool as_tactic) { kinds.push_back(a.kind()); break; } - t = head(r).second; // TODO(Leo): + t = curr_pair->second; } list const & as = t.is_accepting(); save_overload_notation(as, p); @@ -1578,6 +1624,22 @@ expr parser::parse_nud() { } } +// Return true if the current token can be the beginning of an expression +bool parser::curr_starts_expr() { + switch (curr()) { + case scanner::token_kind::Keyword: + return !is_nil(nud().find(get_token_info().value())); + case scanner::token_kind::Identifier: + case scanner::token_kind::Numeral: + case scanner::token_kind::Decimal: + case scanner::token_kind::String: + case scanner::token_kind::Backtick: + return true; + default: + return false; + } +} + expr parser::parse_led(expr left) { switch (curr()) { case scanner::token_kind::Keyword: return parse_led_notation(left); diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index 01c8794cfa..6ce2a571ca 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -214,6 +214,7 @@ class parser { expr parse_nud_notation(); expr parse_led_notation(expr left); expr parse_nud(); + bool curr_starts_expr(); expr parse_numeral_expr(bool user_notation = true); expr parse_decimal_expr(); expr parse_string_expr(); diff --git a/tests/lean/800.lean b/tests/lean/800.lean new file mode 100644 index 0000000000..b1528e773c --- /dev/null +++ b/tests/lean/800.lean @@ -0,0 +1,17 @@ +import data.matrix data.list +open matrix nat list + +variables {A : Type} {m n : nat} + +definition row_vector [reducible] (A : Type) (n : nat) := matrix A 1 n +definition get_row [reducible] (M : matrix A m n) (row : fin m) : row_vector A n := +λ i j, M row j + +variables (M : matrix A m n) (row : fin m) (col : fin n) + +notation M `[` i `,` j `]` := val M i j +check M[row,col] +notation M `[` i `,` `:` `]` := get_row M i +check M[row,:] +check M[row,col] +check [1, 2, 3] diff --git a/tests/lean/800.lean.expected.out b/tests/lean/800.lean.expected.out new file mode 100644 index 0000000000..0f47b37630 --- /dev/null +++ b/tests/lean/800.lean.expected.out @@ -0,0 +1,4 @@ +M[row,col] : A +M[row,:] : row_vector A n +M[row,col] : A +[1, 2, 3] : list num