diff --git a/src/frontends/lean/notation_cmd.cpp b/src/frontends/lean/notation_cmd.cpp index dfaa18bcca..b46c0a3739 100644 --- a/src/frontends/lean/notation_cmd.cpp +++ b/src/frontends/lean/notation_cmd.cpp @@ -27,6 +27,7 @@ static name g_infixl("infixl"); static name g_infixr("infixr"); static name g_postfix("postfix"); static name g_notation("notation"); +static name g_call("call"); static std::string parse_symbol(parser & p, char const * msg) { name n; @@ -74,6 +75,7 @@ using notation::mk_binders_action; using notation::mk_exprs_action; using notation::mk_scoped_expr_action; using notation::mk_skip_action; +using notation::mk_ext_lua_action; using notation::transition; using notation::action; @@ -180,6 +182,10 @@ static action parse_action(parser & p, buffer & locals, buffer & locals, buffer #include #include "util/rb_map.h" #include "util/sstream.h" @@ -58,6 +59,12 @@ struct ext_action_cell : public action_cell { action_cell(action_kind::Ext), m_parse_fn(fn) {} }; +struct ext_lua_action_cell : public action_cell { + std::string m_lua_fn; + ext_lua_action_cell(char const * fn): + action_cell(action_kind::LuaExt), m_lua_fn(fn) {} +}; + action::action(action_cell * ptr):m_ptr(ptr) { lean_assert(ptr); } action::action():action(mk_skip_action()) {} action::action(action const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } @@ -82,6 +89,10 @@ ext_action_cell * to_ext_action(action_cell * c) { lean_assert(c->m_kind == action_kind::Ext); return static_cast(c); } +ext_lua_action_cell * to_ext_lua_action(action_cell * c) { + lean_assert(c->m_kind == action_kind::LuaExt); + return static_cast(c); +} unsigned action::rbp() const { return to_expr_action(m_ptr)->m_rbp; } name const & action::get_sep() const { return to_exprs_action(m_ptr)->m_token_sep; } expr const & action::get_rec() const { @@ -94,6 +105,7 @@ bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; } parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; } +std::string const & action::get_lua_fn() const { return to_ext_lua_action(m_ptr)->m_lua_fn; } bool action::is_compatible(action const & a) const { if (kind() != a.kind()) return false; @@ -102,6 +114,8 @@ bool action::is_compatible(action const & a) const { return true; case action_kind::Ext: return m_ptr == a.m_ptr; + case action_kind::LuaExt: + return get_lua_fn() == a.get_lua_fn(); case action_kind::Expr: return rbp() == a.rbp(); case action_kind::Exprs: @@ -124,6 +138,7 @@ void action_cell::dealloc() { case action_kind::Exprs: delete(to_exprs_action(this)); break; case action_kind::ScopedExpr: delete(to_scoped_expr_action(this)); break; case action_kind::Ext: delete(to_ext_action(this)); break; + case action_kind::LuaExt: delete(to_ext_lua_action(this)); break; default: delete this; break; } } @@ -154,6 +169,7 @@ action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) { return action(new scoped_expr_action_cell(rec, rb, lambda)); } action mk_ext_action(parse_fn const & fn) { return action(new ext_action_cell(fn)); } +action mk_ext_lua_action(char const * fn) { return action(new ext_lua_action_cell(fn)); } struct parse_table::cell { bool m_nud; @@ -196,7 +212,7 @@ static void validate_transitions(bool nud, unsigned num, transition const * ts, case action_kind::Binder: case action_kind::Binders: found_binder = true; break; - case action_kind::Expr: case action_kind::Exprs: case action_kind::Ext: + case action_kind::Expr: case action_kind::Exprs: case action_kind::Ext: case action_kind::LuaExt: nargs++; break; case action_kind::ScopedExpr: @@ -296,6 +312,14 @@ static int mk_scoped_expr_action(lua_State * L) { bool lambda = (nargs <= 2) || lua_toboolean(L, 3); return push_notation_action(L, mk_scoped_expr_action(to_expr(L, 1), rbp, lambda)); } +static int mk_ext_lua_action(lua_State * L) { + char const * fn = lua_tostring(L, 1); + lua_getglobal(L, fn); + if (lua_isnil(L, -1)) + throw exception("arg #1 is a unknown function name"); + lua_pop(L, 1); + return push_notation_action(L, mk_ext_lua_action(fn)); +} static int is_compatible(lua_State * L) { return push_boolean(L, to_notation_action(L, 1).is_compatible(to_notation_action(L, 2))); } @@ -329,6 +353,10 @@ static int use_lambda_abstraction(lua_State * L) { check_action(L, 1, { action_kind::ScopedExpr }); return push_boolean(L, to_notation_action(L, 1).use_lambda_abstraction()); } +static int fn(lua_State * L) { + check_action(L, 1, { action_kind::LuaExt }); + return push_string(L, to_notation_action(L, 1).get_lua_fn().c_str()); +} static const struct luaL_Reg notation_action_m[] = { {"__gc", notation_action_gc}, @@ -341,6 +369,7 @@ static const struct luaL_Reg notation_action_m[] = { {"initial", safe_function}, {"is_fold_right", safe_function}, {"use_lambda_abstraction", safe_function}, + {"fn", safe_function}, {0, 0} }; @@ -357,6 +386,7 @@ static void open_notation_action(lua_State * L) { SET_GLOBAL_FUN(mk_expr_action, "expr_notation_action"); SET_GLOBAL_FUN(mk_exprs_action, "exprs_notation_action"); SET_GLOBAL_FUN(mk_scoped_expr_action, "scoped_expr_notation_action"); + SET_GLOBAL_FUN(mk_ext_lua_action, "ext_action"); push_notation_action(L, mk_skip_action()); lua_setglobal(L, "Skip"); @@ -373,6 +403,7 @@ static void open_notation_action(lua_State * L) { SET_ENUM("Binders", action_kind::Binders); SET_ENUM("ScopedExpr", action_kind::ScopedExpr); SET_ENUM("Ext", action_kind::Ext); + SET_ENUM("LuaExt", action_kind::LuaExt); lua_setglobal(L, "notation_action_kind"); } diff --git a/src/frontends/lean/parse_table.h b/src/frontends/lean/parse_table.h index f2482483bb..d487a10fdb 100644 --- a/src/frontends/lean/parse_table.h +++ b/src/frontends/lean/parse_table.h @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include #include #include "util/buffer.h" #include "util/lua.h" @@ -17,7 +18,7 @@ class parser; namespace notation { typedef std::function parse_fn; -enum class action_kind { Skip, Expr, Exprs, Binder, Binders, ScopedExpr, Ext }; +enum class action_kind { Skip, Expr, Exprs, Binder, Binders, ScopedExpr, Ext, LuaExt }; struct action_cell; /** @@ -63,6 +64,7 @@ public: friend action mk_binders_action(); friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda); friend action mk_ext_action(parse_fn const & fn); + friend action mk_ext_lua_action(char const * lua_fn); action_kind kind() const; unsigned rbp() const; @@ -72,6 +74,7 @@ public: bool is_fold_right() const; bool use_lambda_abstraction() const; parse_fn const & get_parse_fn() const; + std::string const & get_lua_fn() const; bool is_compatible(action const & a) const; }; @@ -83,6 +86,7 @@ action mk_binder_action(); action mk_binders_action(); action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0, bool lambda = true); action mk_ext_action(parse_fn const & fn); +action mk_ext_lua_action(char const * lua_fn); class transition { name m_token; diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index d6d9df29f8..61b6790f61 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -64,6 +64,42 @@ parser::no_undef_id_error_scope::~no_undef_id_error_scope() { m_p.m_no_undef_id_error = m_old; } +static char g_parser_key; +void set_global_parser(lua_State * L, parser * p) { + lua_pushlightuserdata(L, static_cast(&g_parser_key)); + lua_pushlightuserdata(L, static_cast(p)); + lua_settable(L, LUA_REGISTRYINDEX); +} + +parser * get_global_parser_ptr(lua_State * L) { + lua_pushlightuserdata(L, static_cast(&g_parser_key)); + lua_gettable(L, LUA_REGISTRYINDEX); + if (!lua_islightuserdata(L, -1)) + return nullptr; + parser * p = static_cast(const_cast(lua_topointer(L, -1))); + lua_pop(L, 1); + return p; +} + +parser & get_global_parser(lua_State * L) { + parser * p = get_global_parser_ptr(L); + if (p == nullptr) + throw exception("there is no Lean parser on the Lua stack"); + return *p; +} + +struct scoped_set_parser { + lua_State * m_state; + parser * m_old; + scoped_set_parser(lua_State * L, parser & p):m_state(L) { + m_old = get_global_parser_ptr(L); + set_global_parser(L, &p); + } + ~scoped_set_parser() { + set_global_parser(m_state, m_old); + } +}; + parser::parser(environment const & env, io_state const & ios, std::istream & strm, char const * strm_name, script_state * ss, bool use_exceptions, unsigned num_threads, @@ -694,10 +730,32 @@ expr parser::parse_notation(parse_table t, expr * left) { args.push_back(r); break; } + case notation::action_kind::LuaExt: + if (!m_ss) + throw parser_error("failed to use notation implemented in Lua, parser does not contain a Lua state", p); + using_script([&](lua_State * L) { + scoped_set_parser scope(L, *this); + lua_getglobal(L, a.get_lua_fn().c_str()); + if (!lua_isfunction(L, -1)) + throw parser_error(sstream() << "failed to use notation implemented in Lua, Lua state does not contain function '" + << a.get_lua_fn() << "'", p); + lua_pushinteger(L, p.first); + lua_pushinteger(L, p.second); + for (unsigned i = 0; i < args.size(); i++) + push_expr(L, args[i]); + pcall(L, args.size() + 2, 1, 0); + if (!is_expr(L, -1)) + throw parser_error(sstream() << "failed to use notation implemented in Lua, value returned by function '" + << a.get_lua_fn() << "' is not an expression", p); + args.push_back(to_expr(L, -1)); + lua_pop(L, 1); + }); + break; case notation::action_kind::Ext: args.push_back(a.get_parse_fn()(*this, args.size(), args.data(), p)); break; } + t = r->second; } list const & as = t.is_accepting(); @@ -979,7 +1037,6 @@ bool parser::parse_commands() { return !m_found_errors; } - bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, script_state * S, bool use_exceptions, unsigned num_threads) { parser p(env, ios, in, strm_name, S, use_exceptions, num_threads); @@ -995,4 +1052,18 @@ bool parse_commands(environment & env, io_state & ios, char const * fname, scrip throw exception(sstream() << "failed to open file '" << fname << "'"); return parse_commands(env, ios, in, fname, S, use_exceptions, num_threads); } + +static int parse_expr(lua_State * L) { + script_state S = to_script_state(L); + int nargs = lua_gettop(L); + expr r; + S.exec_unprotected([&]() { + r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1)); + }); + return push_expr(L, r); +} + +void open_parser(lua_State * L) { + SET_GLOBAL_FUN(parse_expr, "parse_expr"); +} } diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index a9ce8ffd98..10ea2a9111 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -270,4 +270,5 @@ bool parse_commands(environment & env, io_state & ios, std::istream & in, char c script_state * S, bool use_exceptions, unsigned num_threads); bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S, bool use_exceptions, unsigned num_threads); +void open_parser(lua_State * L); } diff --git a/src/frontends/lean/parser_config.cpp b/src/frontends/lean/parser_config.cpp index 4ba12986d5..bc18aa0310 100644 --- a/src/frontends/lean/parser_config.cpp +++ b/src/frontends/lean/parser_config.cpp @@ -75,6 +75,9 @@ serializer & operator<<(serializer & s, action const & a) { case action_kind::ScopedExpr: s << a.get_rec() << a.rbp() << a.use_lambda_abstraction(); break; + case action_kind::LuaExt: + s << a.get_lua_fn(); + break; case action_kind::Ext: lean_unreachable(); } @@ -104,6 +107,8 @@ action read_action(deserializer & d) { d >> rec >> rbp >> use_lambda_abstraction; return notation::mk_scoped_expr_action(rec, rbp, use_lambda_abstraction); } + case action_kind::LuaExt: + return notation::mk_ext_lua_action(d.read_string().c_str()); case action_kind::Ext: break; } diff --git a/src/frontends/lean/register_module.cpp b/src/frontends/lean/register_module.cpp index 4a94ca7b75..e425f2d29b 100644 --- a/src/frontends/lean/register_module.cpp +++ b/src/frontends/lean/register_module.cpp @@ -9,11 +9,14 @@ Author: Leonardo de Moura #include "util/script_state.h" #include "frontends/lean/token_table.h" #include "frontends/lean/parse_table.h" +#include "frontends/lean/parser.h" namespace lean { + void open_frontend_lean(lua_State * L) { open_token_table(L); open_parse_table(L); + open_parser(L); } void register_frontend_lean_module() { script_state::register_module(open_frontend_lean);