diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 4d46bf7904..5f77bdc781 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -508,32 +508,30 @@ expr mk_runtime_type(type_checker::state & st, local_ctx const & lctx, expr e) { /* If `e` is a trivial structure such as `Subtype`, then convert the only relevant field to a runtime type. */ - if (is_app(e)) { - expr const & fn = get_app_fn(e); - if (is_constant(fn) && is_inductive(st.env(), const_name(fn))) { - name const & I_name = const_name(fn); - environment const & env = st.env(); - if (optional fidx = has_trivial_structure(env, I_name)) { - /* Retrieve field `fidx` type */ - inductive_val I_val = env.get(I_name).to_inductive_val(); - name K = head(I_val.get_cnstrs()); - unsigned nparams = I_val.get_nparams(); - buffer e_args; - get_app_args(e, e_args); - lean_assert(nparams <= e_args.size()); - expr k_app = mk_app(mk_constant(K, const_levels(fn)), nparams, e_args.data()); - expr type = tc.whnf(tc.infer(k_app)); - local_ctx aux_lctx = lctx; - unsigned idx = 0; - while (is_pi(type)) { - if (idx == *fidx) { - return mk_runtime_type(st, aux_lctx, binding_domain(type)); - } - expr local = aux_lctx.mk_local_decl(st.ngen(), binding_name(type), binding_domain(type), binding_info(type)); - type = instantiate(binding_body(type), local); - type = type_checker(st, aux_lctx).whnf(type); - idx++; + expr const & fn = get_app_fn(e); + if (is_constant(fn) && is_inductive(st.env(), const_name(fn))) { + name const & I_name = const_name(fn); + environment const & env = st.env(); + if (optional fidx = has_trivial_structure(env, I_name)) { + /* Retrieve field `fidx` type */ + inductive_val I_val = env.get(I_name).to_inductive_val(); + name K = head(I_val.get_cnstrs()); + unsigned nparams = I_val.get_nparams(); + buffer e_args; + get_app_args(e, e_args); + lean_assert(nparams <= e_args.size()); + expr k_app = mk_app(mk_constant(K, const_levels(fn)), nparams, e_args.data()); + expr type = tc.whnf(tc.infer(k_app)); + local_ctx aux_lctx = lctx; + unsigned idx = 0; + while (is_pi(type)) { + if (idx == *fidx) { + return mk_runtime_type(st, aux_lctx, binding_domain(type)); } + expr local = aux_lctx.mk_local_decl(st.ngen(), binding_name(type), binding_domain(type), binding_info(type)); + type = instantiate(binding_body(type), local); + type = type_checker(st, aux_lctx).whnf(type); + idx++; } } } diff --git a/tests/lean/unboxStruct.lean b/tests/lean/unboxStruct.lean new file mode 100644 index 0000000000..7b6fccb237 --- /dev/null +++ b/tests/lean/unboxStruct.lean @@ -0,0 +1,10 @@ + +structure AddrSpace where + index : UInt32 + +@[extern "foo"] +constant foo (addrSpace : AddrSpace) : IO PUnit + +set_option trace.compiler.ir.result true in +-- should accept and pass an unboxed `uint32` +def test2 : AddrSpace → IO PUnit := foo diff --git a/tests/lean/unboxStruct.lean.expected.out b/tests/lean/unboxStruct.lean.expected.out new file mode 100644 index 0000000000..feffd37563 --- /dev/null +++ b/tests/lean/unboxStruct.lean.expected.out @@ -0,0 +1,10 @@ + +[result] +def test2 (x_1 : u32) (x_2 : obj) : obj := + let x_3 : obj := foo x_1 x_2; + ret x_3 +def test2._boxed (x_1 : obj) (x_2 : obj) : obj := + let x_3 : u32 := unbox x_1; + dec x_1; + let x_4 : obj := test2 x_3 x_2; + ret x_4