diff --git a/src/extra/macros.lua b/src/extra/macros.lua index a2cc8fb213..ead2418cab 100644 --- a/src/extra/macros.lua +++ b/src/extra/macros.lua @@ -26,7 +26,7 @@ function binder_macro(name, f, farity, typepos, lambdapos) local precedence = 0 macro(name, { macro_arg.Bindings, macro_arg.Comma, macro_arg.Expr }, - function (bindings, body) + function (env, bindings, body) local r = body for i = #bindings, 1, -1 do local bname = bindings[i][1] @@ -74,7 +74,7 @@ function nary_macro(name, f, farity) end macro(name, { macro_arg.Expr, macro_arg.Expr, macro_arg.Exprs }, - function (e1, e2, rest) + function (env, e1, e2, rest) local r = bin_app(e1, e2) for i = 1, #rest do r = bin_app(r, rest[i]) diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index c582eb333f..a94518ae92 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -133,7 +133,7 @@ static unsigned g_level_cup_prec = 5; // are syntax sugar for (Pi (_ : A), B) static name g_unused = name::mk_internal_unique_name(); -enum class macro_arg_kind { Expr, Exprs, Bindings, Id, Comma, Assign }; +enum class macro_arg_kind { Expr, Exprs, Bindings, Id, Comma, Assign, Tactic }; struct macro { list m_arg_kinds; luaref m_fn; @@ -432,7 +432,7 @@ class parser::imp { break; case script_exception::source::Unknown: display_error_pos(m_last_script_pos); - regular(m_io_state) << " executing script, but could not decode position information, " << ex.what() << endl; + regular(m_io_state) << " executing script, exact error position is not available, " << ex.what() << endl; break; } } @@ -924,11 +924,18 @@ class parser::imp { } typedef buffer> macro_arg_stack; + struct macro_result { + optional m_expr; + optional m_tactic; + macro_result(expr const & e):m_expr(e) {} + macro_result(tactic const & t):m_tactic(t) {} + }; /** \brief Parse a macro implemented in Lua */ - expr parse_macro(list const & arg_kinds, luaref const & fn, unsigned prec, macro_arg_stack & args, pos_info const & p) { + macro_result parse_macro(list const & arg_kinds, luaref const & fn, unsigned prec, macro_arg_stack & args, + pos_info const & p) { if (arg_kinds) { auto k = head(arg_kinds); switch (k) { @@ -962,6 +969,11 @@ class parser::imp { name n = curr_name(); args.emplace_back(k, &n); return parse_macro(tail(arg_kinds), fn, prec, args, p); + } + case macro_arg_kind::Tactic: { + tactic t = parse_tactic_expr(); + args.emplace_back(k, &t); + return parse_macro(tail(arg_kinds), fn, prec, args, p); }} lean_unreachable(); } else { @@ -969,6 +981,7 @@ class parser::imp { m_last_script_pos = p; return m_script_state->apply([&](lua_State * L) { fn.push(); + push_environment(L, m_env); for (auto p : args) { macro_arg_kind k = p.first; void * arg = p.second; @@ -1005,19 +1018,26 @@ class parser::imp { case macro_arg_kind::Id: push_name(L, *static_cast(arg)); break; + case macro_arg_kind::Tactic: + push_tactic(L, *static_cast(arg)); + break; default: lean_unreachable(); } } - pcall(L, args.size(), 1, 0); + pcall(L, args.size() + 1, 1, 0); if (is_expr(L, -1)) { expr r = to_expr(L, -1); lua_pop(L, 1); propagate_position(r, p); - return r; + return macro_result(r); + } else if (is_tactic(L, -1)) { + tactic t = to_tactic(L, -1); + lua_pop(L, 1); + return macro_result(t); } else { lua_pop(L, 1); - throw parser_error("failed to execute macro", p); + throw parser_error("failed to execute macro, unexpected result type", p); } }); } @@ -1027,7 +1047,12 @@ class parser::imp { lean_assert(m_macros && m_macros->find(id) != m_macros->end()); auto m = m_macros->find(id)->second; macro_arg_stack args; - return parse_macro(m.m_arg_kinds, m.m_fn, m.m_precedence, args, p); + auto r = parse_macro(m.m_arg_kinds, m.m_fn, m.m_precedence, args, p); + if (r.m_expr) { + return *(r.m_expr); + } else { + throw parser_error("failed to execute macro, unexpected result type", p); + } } /** @@ -2740,6 +2765,7 @@ void open_macros(lua_State * L) { SET_ENUM("Id", macro_arg_kind::Id); SET_ENUM("Comma", macro_arg_kind::Comma); SET_ENUM("Assign", macro_arg_kind::Assign); + SET_ENUM("Tactic", macro_arg_kind::Tactic); lua_setglobal(L, "macro_arg"); } } diff --git a/tests/lean/lua18.lean b/tests/lean/lua18.lean index e11aa271c5..caf86afec7 100644 --- a/tests/lean/lua18.lean +++ b/tests/lean/lua18.lean @@ -1,10 +1,10 @@ (** macro("MyMacro", { macro_arg.Expr, macro_arg.Comma, macro_arg.Expr }, - function (e1, e2) + function (env, e1, e2) return Const({"Int", "add"})(e1, e2) end) macro("Sum", { macro_arg.Exprs }, - function (es) + function (env, es) if #es == 0 then return iVal(0) end