refactor(frontends/lean/elaborator): mark thunk as opaque, and thunk A to A is now a coercion
@kha I was working in the new declaration type and using tasks there.
Since we don't have tasks yet in Lean, I decided to start refactoring
the `thunk` type. I defined it as:
```
-- TODO(Leo): mark as opaque, it is implemented by the new runtime
structure thunk (α : Type u) : Type u :=
(fn : unit → α)
def thunk.pure {α : Type u} (a : α) : thunk α :=
⟨λ _, a⟩
def thunk.get {α : Type u} (t : thunk α) : α :=
t.fn ()
```
The idea is to use the runtime primitives to implement them.
Then, I realized the support for `thunk`s in the elaborator are quite
hacky. Given `f x`, if `f`'s domain has type `thunk A`, we elaborate
`f x` as `f (fun _, x)` even if `x` has type `thunk A`.
This is quite bad, for example, suppose we have
```
def f (x : thunk A) := ...
```
Then, the following definition is type incorrect.
```
def g (x : thunk A) := f x
```
and we are forced to write
```
def g (x : thunk A) := f (x ())
```
The term `f (x ())` will be elaborated as `f (fun _, x ())` and an
unnecessary closure is created at runtime.
This mechanism inherited from Lean 3 is also incompatible with the
new thunk definition. Given `x : thunk A`, I want to write `x.get`
to retrieve the value instead of `x ()` as in Lean 3.
However, `x.get` expands into the nonsensical `(fun _, x).get`.
So, I decided to view the mapping `A` to `thunk A` as a "coercion".
I used double quotes, because it is a macro instead of a function.
If it were a coercion, then we would be using `thunk.pure` to coerce
values but this is not we want most of the time.
For example, given `f : thunk A -> B` and a term `t : A`, when we write
`f t`, we want it to be converted into `f (fun _, t)` instead of
`f (thunk.pure t)` which would eagerly compute `t`. The transformation
`t` into `fun _, t` is syntactic.
We cannot implement it using type classes. I implemented it as
a hard-coded extra case like the one from `Prop` to `bool`.
We can also add a coercion from `thunk A` to `A` to avoid the `.get`.
That being said, I had a few breakages in the code base since we only
use coercions when the given and expected type do not contain
metavariables.
This commit is contained in:
parent
dc477db71e
commit
261dc999d0
7 changed files with 41 additions and 62 deletions
|
|
@ -130,17 +130,15 @@ abbreviation unit : Type := punit
|
|||
|
||||
@[pattern] abbreviation unit.star : unit := punit.star
|
||||
|
||||
/--
|
||||
Gadget for defining thunks, thunk parameters have special treatment.
|
||||
Example: given
|
||||
def f (s : string) (t : thunk nat) : nat
|
||||
an application
|
||||
f "hello" 10
|
||||
is converted into
|
||||
f "hello" (λ _, 10)
|
||||
-/
|
||||
@[reducible] def thunk (α : Type u) : Type u :=
|
||||
unit → α
|
||||
-- TODO(Leo): mark as opaque, it is implemented by the new runtime
|
||||
structure thunk (α : Type u) : Type u :=
|
||||
(fn : unit → α)
|
||||
|
||||
def thunk.pure {α : Type u} (a : α) : thunk α :=
|
||||
⟨λ _, a⟩
|
||||
|
||||
def thunk.get {α : Type u} (t : thunk α) : α :=
|
||||
t.fn ()
|
||||
|
||||
inductive true : Prop
|
||||
| intro : true
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ TODO: mark as `@[inline]` as soon as we fix the code inliner.
|
|||
-/
|
||||
private def merge (w : nat) (r₁ : space_result) (r₂ : thunk space_result) : space_result :=
|
||||
if r₁.exceeded || r₁.found then r₁
|
||||
else let y := r₂ () in
|
||||
else let y := r₂.get in
|
||||
if y.exceeded || y.found then y
|
||||
else let new_space := r₁.space + y.space in
|
||||
{ space := new_space, exceeded := new_space > w }
|
||||
|
|
|
|||
|
|
@ -194,11 +194,12 @@ def not_followed_by (p : parsec' α) (msg : string := "input") : parsec' unit :=
|
|||
|
||||
def parsec.dbg (label : string) (p : parsec μ α) : parsec μ α :=
|
||||
λ it, trace ("DBG " ++ label ++ ": '" ++ (it.extract (it.nextn 40)).get_or_else "" ++ "'") $
|
||||
match p it with
|
||||
| ok a it' := trace ("consumed ok : '" ++ (it.extract it').get_or_else "" ++ "'") $ ok a it'
|
||||
| ok_eps a it' ex := trace ("empty ok : '" ++ (it.extract it').get_or_else "" ++ "'") $ ok_eps a it' ex
|
||||
| error msg tt := trace ("consumed error : '" ++ (it.extract msg.it).get_or_else "" ++ "'\n" ++ to_string msg) $ error msg tt
|
||||
| error msg ff := trace ("empty error : '" ++ (it.extract msg.it).get_or_else "" ++ "'\n" ++ to_string msg) $ error msg ff
|
||||
match p it : _ → result μ α with
|
||||
| ok a it' := trace ("consumed ok : '" ++ (it.extract it').get_or_else "" ++ "'") $ @ok μ α a it'
|
||||
| ok_eps a it' ex := trace ("empty ok : '" ++ (it.extract it').get_or_else "" ++ "'") $ @ok_eps μ α a it' ex
|
||||
| error msg tt := trace ("consumed error : '" ++ (it.extract msg.it).get_or_else "" ++ "'\n" ++ to_string msg) $ @error μ α msg tt
|
||||
| error msg ff := trace ("empty error : '" ++ (it.extract msg.it).get_or_else "" ++ "'\n" ++ to_string msg) $ @error μ α msg ff
|
||||
|
||||
end parsec
|
||||
|
||||
/- Type class for abstracting from concrete monad stacks containing a `parsec` somewhere. -/
|
||||
|
|
|
|||
|
|
@ -45,26 +45,26 @@ class monad_tracer (m : Type → Type u) :=
|
|||
export monad_tracer (trace_root trace_ctx)
|
||||
|
||||
def trace {m} [monad m] [monad_tracer m] (cls : name) (msg : message) : m unit :=
|
||||
trace_ctx cls msg (pure ())
|
||||
trace_ctx cls msg (pure () : m unit)
|
||||
|
||||
instance (m) [monad m] : monad_tracer (trace_t m) :=
|
||||
{ trace_root := λ α pos cls msg ctx, do {
|
||||
st ← get,
|
||||
if st.opts.get_bool cls = some tt then do {
|
||||
modify $ λ st, {cur_pos := pos, cur_traces := [], ..st},
|
||||
a ← ctx (),
|
||||
a ← ctx.get,
|
||||
modify $ λ (st : trace_state), {roots := st.roots.insert pos ⟨msg, st.cur_traces⟩, ..st},
|
||||
pure a
|
||||
} else ctx ()
|
||||
} else ctx.get
|
||||
},
|
||||
trace_ctx := λ α cls msg ctx, do {
|
||||
st ← get,
|
||||
-- tracing enabled?
|
||||
some _ ← pure st.cur_pos | ctx (),
|
||||
some _ ← pure st.cur_pos | ctx.get,
|
||||
-- trace class enabled?
|
||||
if st.opts.get_bool cls = some tt then do {
|
||||
put {cur_traces := [], ..st},
|
||||
a ← ctx (),
|
||||
a ← ctx.get,
|
||||
modify $ λ (st' : trace_state), {cur_traces := st.cur_traces ++ [⟨msg, st'.cur_traces⟩], ..st'},
|
||||
pure a
|
||||
} else
|
||||
|
|
@ -72,7 +72,7 @@ instance (m) [monad m] : monad_tracer (trace_t m) :=
|
|||
adapt_state'
|
||||
(λ _, {cur_pos := none, ..st})
|
||||
(λ st', {cur_pos := st.cur_pos, ..st'})
|
||||
(ctx ())
|
||||
ctx.get
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -221,7 +221,7 @@ meta def trace_call_stack : tactic unit :=
|
|||
assume state, _root_.trace_call_stack (success () state)
|
||||
|
||||
meta def timetac {α : Type u} (desc : string) (t : thunk (tactic α)) : tactic α :=
|
||||
λ s, timeit desc (t () s)
|
||||
λ s, timeit desc (t.get s)
|
||||
|
||||
meta def trace_state : tactic unit :=
|
||||
do s ← read,
|
||||
|
|
@ -501,7 +501,7 @@ meta def step {α : Type u} (t : tactic α) : tactic unit :=
|
|||
t >>[tactic] cleanup
|
||||
|
||||
meta def istep {α : Type u} (line0 col0 : ℕ) (line col : ℕ) (t : tactic α) : tactic unit :=
|
||||
λ s, (@scope_trace _ line col (λ _, step t s)).clamp_pos line0 line col
|
||||
λ s, (@scope_trace _ line col ⟨λ _, step t s⟩).clamp_pos line0 line col
|
||||
|
||||
meta def is_prop (e : expr) : tactic bool :=
|
||||
do t ← infer_type e,
|
||||
|
|
@ -1039,7 +1039,7 @@ do ns ← open_namespaces,
|
|||
long running tactics. -/
|
||||
meta def try_for {α} (max : nat) (tac : tactic α) : tactic α :=
|
||||
λ s,
|
||||
match _root_.try_for max (tac s) with
|
||||
match _root_.try_for max ⟨λ _, tac s⟩ with
|
||||
| some r := r
|
||||
| none := mk_exception "try_for tactic failed, timeout" none s
|
||||
|
||||
|
|
|
|||
|
|
@ -9,23 +9,23 @@ universes u
|
|||
|
||||
/-- This function has a native implementation that tracks time. -/
|
||||
def timeit {α : Type u} (s : string) (f : thunk α) : α :=
|
||||
f ()
|
||||
f.get
|
||||
|
||||
/-- This function has a native implementation that displays the given string in the regular output stream. -/
|
||||
def trace {α : Type u} (s : string) (f : thunk α) : α :=
|
||||
f ()
|
||||
f.get
|
||||
|
||||
meta def trace_val {α : Type u} [has_to_format α] (f : α) : α :=
|
||||
trace (to_fmt f).to_string f
|
||||
|
||||
/-- This function has a native implementation that shows the VM call stack. -/
|
||||
def trace_call_stack {α : Type u} (f : thunk α) : α :=
|
||||
f ()
|
||||
f.get
|
||||
|
||||
/-- This function has a native implementation that displays in the given position all trace messages used in f.
|
||||
The arguments line and col are filled by the elaborator. -/
|
||||
def scope_trace {α : Type u} {line col: nat} (f : thunk α) : α :=
|
||||
f ()
|
||||
f.get
|
||||
|
||||
/--
|
||||
This function has a native implementation where
|
||||
|
|
@ -33,7 +33,7 @@ f ()
|
|||
The heartbeat is approx. the maximum number of memory allocations (in thousands) performed by 'f ()'.
|
||||
This is a deterministic way of interrupting long running tasks. -/
|
||||
meta def try_for {α : Type u} (max : nat) (f : thunk α) : option α :=
|
||||
some (f ())
|
||||
some f.get
|
||||
|
||||
meta constant undefined_core {α : Sort u} (message : string) : α
|
||||
|
||||
|
|
|
|||
|
|
@ -555,9 +555,17 @@ optional<expr> elaborator::mk_Prop_to_bool_coercion(expr const & e, expr const &
|
|||
return some_expr(r);
|
||||
}
|
||||
|
||||
static bool is_thunk(expr const & e) {
|
||||
return is_app_of(e, get_thunk_name(), 1);
|
||||
}
|
||||
|
||||
optional<expr> elaborator::mk_coercion_core(expr const & e, expr const & e_type, expr const & type, expr const & ref) {
|
||||
if (e_type == mk_Prop() && m_ctx.is_def_eq(type, mk_bool())) {
|
||||
return mk_Prop_to_bool_coercion(e, ref);
|
||||
} else if (is_thunk(type) && m_ctx.is_def_eq(e_type, app_arg(type))) {
|
||||
return some_expr(::lean::mk_app(mk_constant(name{"thunk", "mk"}, const_levels(app_fn(type))),
|
||||
app_arg(type),
|
||||
::lean::mk_lambda("_", mk_constant(get_unit_name()), e)));
|
||||
} else if (!has_expr_metavar(e_type) && !has_expr_metavar(type)) {
|
||||
expr has_coe_t;
|
||||
try {
|
||||
|
|
@ -1189,21 +1197,6 @@ struct elaborator::first_pass_info {
|
|||
buffer<expr> eta_args;
|
||||
};
|
||||
|
||||
static optional<expr> is_thunk(expr const & e) {
|
||||
if (is_app_of(e, get_thunk_name(), 1)) {
|
||||
return some_expr(app_arg(e));
|
||||
} else {
|
||||
return none_expr();
|
||||
}
|
||||
}
|
||||
|
||||
static expr mk_thunk_if_needed(expr const & e, optional<expr> const & is_thunk) {
|
||||
if (is_thunk)
|
||||
return mk_lambda("_", mk_constant(get_unit_name()), e);
|
||||
else
|
||||
return e;
|
||||
}
|
||||
|
||||
expr elaborator::mk_auto_param(expr const & name_lit, expr const & expected_type, expr const & ref) {
|
||||
auto c = name_lit_to_name(name_lit);
|
||||
if (!c)
|
||||
|
|
@ -1337,14 +1330,8 @@ void elaborator::first_pass(expr const & fn, buffer<expr> const & args,
|
|||
See discussion at #1403
|
||||
*/
|
||||
new_arg = get_as_is_arg(args[i]);
|
||||
optional<expr> thunk_of;
|
||||
if (!m_in_pattern) thunk_of = is_thunk(d);
|
||||
expr arg_expected_type;
|
||||
if (thunk_of)
|
||||
arg_expected_type = *thunk_of;
|
||||
else
|
||||
arg_expected_type = d;
|
||||
new_arg = mk_thunk_if_needed(new_arg, thunk_of);
|
||||
arg_expected_type = d;
|
||||
expr new_arg_type = infer_type(new_arg);
|
||||
optional<expr> new_new_arg = ensure_has_type(new_arg, new_arg_type, arg_expected_type, arg_ref);
|
||||
if (!new_new_arg) {
|
||||
|
|
@ -1401,11 +1388,8 @@ void elaborator::first_pass(expr const & fn, buffer<expr> const & args,
|
|||
}
|
||||
|
||||
std::tuple<expr, expr, optional<expr>> elaborator::elaborate_arg(expr const & arg, expr const & expected_type, expr const & ref) {
|
||||
optional<expr> thunk_of;
|
||||
if (!m_in_pattern) thunk_of = is_thunk(expected_type);
|
||||
expr aux_expected_type = thunk_of ? *thunk_of : expected_type;
|
||||
expr aux_expected_type = expected_type;
|
||||
expr new_arg = visit(arg, some_expr(aux_expected_type));
|
||||
new_arg = mk_thunk_if_needed(new_arg, thunk_of);
|
||||
expr new_arg_type = infer_type(new_arg);
|
||||
return std::make_tuple(new_arg, new_arg_type, ensure_has_type(new_arg, new_arg_type, expected_type, ref));
|
||||
}
|
||||
|
|
@ -1509,13 +1493,10 @@ expr elaborator::visit_base_app_simple(expr const & _fn, arg_mask amask, buffer<
|
|||
new_arg = post_process_implicit_arg(new_arg, ref);
|
||||
} else if (i < args.size()) {
|
||||
expr expected_type = d;
|
||||
optional<expr> thunk_of;
|
||||
if (!m_in_pattern && amask == arg_mask::Default) thunk_of = is_thunk(d);
|
||||
if (thunk_of) expected_type = *thunk_of;
|
||||
// explicit argument
|
||||
expr ref_arg = get_ref_for_child(args[i], ref);
|
||||
if (args_already_visited) {
|
||||
new_arg = mk_thunk_if_needed(args[i], thunk_of);
|
||||
new_arg = args[i];
|
||||
} else if (is_inst_implicit(bi) && is_placeholder(args[i])) {
|
||||
lean_assert(amask != arg_mask::Default);
|
||||
/* If '@' or '@@' have been used, and the argument is '_', then
|
||||
|
|
@ -1523,7 +1504,6 @@ expr elaborator::visit_base_app_simple(expr const & _fn, arg_mask amask, buffer<
|
|||
new_arg = mk_instance(d, ref);
|
||||
} else {
|
||||
new_arg = visit(args[i], some_expr(expected_type));
|
||||
new_arg = mk_thunk_if_needed(new_arg, thunk_of);
|
||||
}
|
||||
expr new_arg_type = infer_type(new_arg);
|
||||
if (optional<expr> new_new_arg = ensure_has_type(new_arg, new_arg_type, d, ref_arg)) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue