feat(frontends/lean/builtin_exprs): add 'else case' for do-match notation
This commit is contained in:
parent
23565ff43c
commit
fbec9053dc
2 changed files with 34 additions and 6 deletions
|
|
@ -533,11 +533,12 @@ static expr fix_do_action_lhs(parser & p, expr const & lhs, expr const & type, p
|
|||
}
|
||||
}
|
||||
|
||||
static std::tuple<optional<expr>, expr, expr> parse_do_action(parser & p, buffer<expr> & new_locals) {
|
||||
static std::tuple<optional<expr>, expr, expr, optional<expr>> parse_do_action(parser & p, buffer<expr> & new_locals) {
|
||||
auto lhs_pos = p.pos();
|
||||
optional<expr> lhs;
|
||||
lhs = parse_match_pattern(p, new_locals);
|
||||
expr type, curr;
|
||||
optional<expr> else_case;
|
||||
if (p.curr_is_token(get_colon_tk())) {
|
||||
p.next();
|
||||
type = p.parse_expr();
|
||||
|
|
@ -554,6 +555,10 @@ static std::tuple<optional<expr>, expr, expr> parse_do_action(parser & p, buffer
|
|||
if (!is_local(*lhs))
|
||||
validate_match_pattern(p, *lhs, new_locals);
|
||||
curr = p.parse_expr();
|
||||
if (p.curr_is_token(get_bar_tk())) {
|
||||
p.next();
|
||||
else_case = p.parse_expr();
|
||||
}
|
||||
} else {
|
||||
if (!new_locals.empty()) {
|
||||
expr undef = new_locals[0];
|
||||
|
|
@ -564,13 +569,14 @@ static std::tuple<optional<expr>, expr, expr> parse_do_action(parser & p, buffer
|
|||
type = mk_expr_placeholder();
|
||||
lhs = none_expr();
|
||||
}
|
||||
return std::make_tuple(lhs, type, curr);
|
||||
return std::make_tuple(lhs, type, curr, else_case);
|
||||
}
|
||||
|
||||
static expr parse_do(parser & p, unsigned, expr const *, pos_info const & pos) {
|
||||
parser::local_scope scope(p);
|
||||
buffer<expr> es;
|
||||
buffer<optional<expr>> lhss;
|
||||
buffer<optional<expr>> else_cases;
|
||||
buffer<list<expr>> lhss_locals;
|
||||
bool has_braces = false;
|
||||
if (p.curr_is_token(get_lcurly_tk())) {
|
||||
|
|
@ -580,9 +586,9 @@ static expr parse_do(parser & p, unsigned, expr const *, pos_info const & pos) {
|
|||
while (true) {
|
||||
auto lhs_pos = p.pos();
|
||||
buffer<expr> new_locals;
|
||||
optional<expr> lhs;
|
||||
optional<expr> lhs, else_case;
|
||||
expr type, curr;
|
||||
std::tie(lhs, type, curr) = parse_do_action(p, new_locals);
|
||||
std::tie(lhs, type, curr, else_case) = parse_do_action(p, new_locals);
|
||||
es.push_back(curr);
|
||||
if (p.curr_is_token(get_comma_tk())) {
|
||||
p.next();
|
||||
|
|
@ -595,6 +601,7 @@ static expr parse_do(parser & p, unsigned, expr const *, pos_info const & pos) {
|
|||
lhss_locals.push_back(list<expr>());
|
||||
}
|
||||
lhss.push_back(lhs);
|
||||
else_cases.push_back(else_case);
|
||||
} else {
|
||||
if (lhs) {
|
||||
throw parser_error("invalid 'do' expression, unnecessary binder", lhs_pos);
|
||||
|
|
@ -623,8 +630,14 @@ static expr parse_do(parser & p, unsigned, expr const *, pos_info const & pos) {
|
|||
buffer<expr> locals;
|
||||
to_buffer(lhss_locals[i], locals);
|
||||
auto pos = p.pos_of(*lhs);
|
||||
expr eq = Fun(fn, Fun(locals, p.save_pos(mk_equation(mk_app(fn, *lhs), r), pos), p));
|
||||
expr eqns = p.save_pos(mk_equations(1, 1, &eq), pos);
|
||||
buffer<expr> eqs;
|
||||
eqs.push_back(Fun(fn, Fun(locals, p.save_pos(mk_equation(mk_app(fn, *lhs), r), pos), p)));
|
||||
if (optional<expr> else_case = else_cases[i]) {
|
||||
// add case
|
||||
// _ := else_case
|
||||
eqs.push_back(Fun(fn, p.save_pos(mk_equation(mk_app(fn, mk_expr_placeholder()), *else_case), pos)));
|
||||
}
|
||||
expr eqns = p.save_pos(mk_equations(1, eqs.size(), eqs.data()), pos);
|
||||
expr local = mk_local("p", mk_expr_placeholder());
|
||||
expr match = p.mk_app(eqns, local, pos);
|
||||
r = mk_app(mk_constant(get_monad_bind_name()), es[i], Fun(local, match));
|
||||
|
|
|
|||
15
tests/lean/run/do_match_else.lean
Normal file
15
tests/lean/run/do_match_else.lean
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
open tactic
|
||||
|
||||
set_option pp.all true
|
||||
|
||||
example (a b c x y : nat) (H : nat.add (nat.add x y) y = 0) : true :=
|
||||
by do
|
||||
a ← get_local "a", b ← get_local "b", c ← get_local "c",
|
||||
nat_add : expr ← mk_const ("nat" <.> "add"),
|
||||
p : pattern ← mk_pattern [] [a, b] (nat_add a b) [nat_add b a, a, b],
|
||||
trace (pattern.output p),
|
||||
H ← get_local "H" >>= infer_type,
|
||||
lhs_rhs ← match_eq H,
|
||||
[v₁, v₂, v₃] ← match_pattern p (prod.pr1 lhs_rhs) | failed,
|
||||
trace v₁,
|
||||
constructor
|
||||
Loading…
Add table
Reference in a new issue