perf: add lean_instantiate_level_mvars (#4910)
The new code is not active yet because of bootstrapping issues. It requires an `update_stage0`.
This commit is contained in:
parent
647a5e9492
commit
1e9d96be22
6 changed files with 132 additions and 3 deletions
|
|
@ -336,6 +336,8 @@ structure MetavarContext where
|
|||
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
|
||||
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
|
||||
|
||||
instance : Inhabited MetavarContext := ⟨{}⟩
|
||||
|
||||
/-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/
|
||||
class MonadMCtx (m : Type → Type) where
|
||||
getMCtx : m MetavarContext
|
||||
|
|
@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
|
|||
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) :=
|
||||
return (← getMCtx).lAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_lmvar_assignment]
|
||||
def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level :=
|
||||
m.lAssignment.find? mvarId
|
||||
|
||||
def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
|
||||
m.eAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_mvar_assignment]
|
||||
def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
|
||||
m.eAssignment.find? mvarId
|
||||
|
||||
def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) :=
|
||||
return (← getMCtx).getExprAssignmentCore? mvarId
|
||||
|
||||
def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
|
||||
mctx.dAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_delayed_mvar_assignment]
|
||||
def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
|
||||
mctx.dAssignment.find? mvarId
|
||||
|
||||
def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) :=
|
||||
return (← getMCtx).getDelayedMVarAssignmentCore? mvarId
|
||||
|
||||
|
|
@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
|
|||
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
|
||||
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
|
||||
|
||||
@[export lean_assign_lmvar]
|
||||
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
|
||||
{ m with lAssignment := m.lAssignment.insert mvarId val }
|
||||
|
||||
/--
|
||||
Add `mvarId := x` to the metavariable assignment.
|
||||
This method does not check whether `mvarId` is already assigned, nor it checks whether
|
||||
|
|
@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
|
|||
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
|
||||
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
|
||||
|
||||
@[export lean_assign_mvar]
|
||||
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
|
||||
{ m with eAssignment := m.eAssignment.insert mvarId val }
|
||||
|
||||
/--
|
||||
Add a delayed assignment for the given metavariable. You must make sure that
|
||||
the metavariable is not already assigned or delayed-assigned.
|
||||
|
|
@ -516,6 +538,9 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating
|
|||
This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`.
|
||||
-/
|
||||
|
||||
@[extern "lean_instantiate_level_mvars"]
|
||||
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level
|
||||
|
||||
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
|
||||
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl (← instantiateLevelMVars lvl₁)
|
||||
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
|
||||
|
|
@ -531,6 +556,9 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
|
|||
| none => pure lvl
|
||||
| lvl => pure lvl
|
||||
|
||||
@[extern "lean_instantiate_expr_mvars"]
|
||||
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr
|
||||
|
||||
/-- instantiateExprMVars main function -/
|
||||
partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
|
||||
if !e.hasMVar then
|
||||
|
|
@ -792,8 +820,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf :
|
|||
|
||||
namespace MetavarContext
|
||||
|
||||
instance : Inhabited MetavarContext := ⟨{}⟩
|
||||
|
||||
@[export lean_mk_metavar_ctx]
|
||||
def mkMetavarContext : Unit → MetavarContext := fun _ => {}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@ add_library(kernel OBJECT level.cpp expr.cpp expr_eq_fn.cpp
|
|||
for_each_fn.cpp replace_fn.cpp abstract.cpp instantiate.cpp
|
||||
local_ctx.cpp declaration.cpp environment.cpp type_checker.cpp
|
||||
init_module.cpp expr_cache.cpp equiv_manager.cpp quot.cpp
|
||||
inductive.cpp trace.cpp)
|
||||
inductive.cpp trace.cpp instantiate_mvars.cpp)
|
||||
|
|
|
|||
95
src/kernel/instantiate_mvars.cpp
Normal file
95
src/kernel/instantiate_mvars.cpp
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Authors: Leonardo de Moura
|
||||
*/
|
||||
#include <unordered_map>
|
||||
#include "runtime/option_ref.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
|
||||
/*
|
||||
This module is not used by the kernel. It just provides an efficient implementation of
|
||||
`instantiateExprMVars`
|
||||
*/
|
||||
|
||||
namespace lean {
|
||||
|
||||
extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid);
|
||||
extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val);
|
||||
|
||||
typedef object_ref metavar_ctx;
|
||||
void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
|
||||
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
|
||||
mctx.set_box(r);
|
||||
}
|
||||
|
||||
option_ref<level> get_lmvar_assignment(metavar_ctx & mctx, name const & mid) {
|
||||
return option_ref<level>(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
|
||||
}
|
||||
|
||||
class instantiate_lmvar_fn {
|
||||
metavar_ctx & m_mctx;
|
||||
std::unordered_map<lean_object *, lean_object *> m_cache;
|
||||
|
||||
inline level cache(level const & l, level && r, bool shared) {
|
||||
if (shared) {
|
||||
m_cache.insert(mk_pair(l.raw(), r.raw()));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
public:
|
||||
instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {}
|
||||
level visit(level const & l) {
|
||||
if (!has_mvar(l))
|
||||
return l;
|
||||
bool shared = false;
|
||||
if (is_shared(l)) {
|
||||
auto it = m_cache.find(l.raw());
|
||||
if (it != m_cache.end()) {
|
||||
return level(it->second, true);
|
||||
}
|
||||
shared = true;
|
||||
}
|
||||
switch (l.kind()) {
|
||||
case level_kind::Succ:
|
||||
return cache(l, update_succ(l, visit(succ_of(l))), shared);
|
||||
case level_kind::Max: case level_kind::IMax:
|
||||
return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared);
|
||||
case level_kind::Zero: case level_kind::Param:
|
||||
lean_unreachable();
|
||||
case level_kind::MVar: {
|
||||
option_ref<level> r = get_lmvar_assignment(m_mctx, mvar_id(l));
|
||||
if (!r) {
|
||||
return l;
|
||||
} else {
|
||||
level a(r.get_val());
|
||||
if (!has_mvar(a)) {
|
||||
return a;
|
||||
} else {
|
||||
level a_new = visit(a);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
assign_lmvar(m_mctx, mvar_id(l), a_new);
|
||||
}
|
||||
return a_new;
|
||||
}
|
||||
}
|
||||
}}
|
||||
}
|
||||
level operator()(level const & l) { return visit(l); }
|
||||
};
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) {
|
||||
metavar_ctx mctx(m);
|
||||
level l_new = instantiate_lmvar_fn(mctx)(level(l));
|
||||
object * r = alloc_cnstr(0, 2, 0);
|
||||
cnstr_set(r, 0, mctx.steal());
|
||||
cnstr_set(r, 1, l_new.steal());
|
||||
return r;
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) {
|
||||
lean_internal_panic("not implemented yet");
|
||||
}
|
||||
}
|
||||
|
|
@ -82,6 +82,8 @@ inline bool operator!=(level const & l1, level const & l2) { return !operator==(
|
|||
struct level_hash { unsigned operator()(level const & n) const { return n.hash(); } };
|
||||
struct level_eq { bool operator()(level const & n1, level const & n2) const { return n1 == n2; } };
|
||||
|
||||
inline bool is_shared(level const & l) { return !is_exclusive(l.raw()); }
|
||||
|
||||
inline optional<level> none_level() { return optional<level>(); }
|
||||
inline optional<level> some_level(level const & e) { return optional<level>(e); }
|
||||
inline optional<level> some_level(level && e) { return optional<level>(std::forward<level>(e)); }
|
||||
|
|
|
|||
|
|
@ -35,6 +35,10 @@ public:
|
|||
s.m_obj = box(0);
|
||||
return *this;
|
||||
}
|
||||
void set_box(object * o) {
|
||||
lean_assert(is_scalar(m_obj));
|
||||
m_obj = o;
|
||||
}
|
||||
object * raw() const { return m_obj; }
|
||||
object * steal() { object * r = m_obj; m_obj = box(0); return r; }
|
||||
object * to_obj_arg() const { inc(m_obj); return m_obj; }
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ public:
|
|||
explicit operator bool() const { return !is_scalar(raw()); }
|
||||
optional<T> get() const { return *this ? some(static_cast<T const &>(cnstr_get_ref(*this, 0))) : optional<T>(); }
|
||||
|
||||
T get_val() const { lean_assert(*this); return static_cast<T const &>(cnstr_get_ref(*this, 0)); }
|
||||
|
||||
/** \brief Structural equality. */
|
||||
friend bool operator==(option_ref const & o1, option_ref const & o2) {
|
||||
return o1.get() == o2.get();
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue