lean4-htt/library/tools/super/prover_state.lean

454 lines
15 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2016 Gabriel Ebner. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
import .clause .lpo .cdcl_solver
open tactic monad expr
namespace super
structure score :=
(priority : )
(in_sos : bool)
(cost : )
(age : )
namespace score
def prio.immediate : := 0
def prio.default : := 1
def prio.never : := 2
def sched_default (sc : score) : score := { sc with priority := prio.default }
def sched_now (sc : score) : score := { sc with priority := prio.immediate }
def inc_cost (sc : score) (n : ) : score := { sc with cost := sc^.cost + n }
def min (a b : score) : score :=
{ priority := nat.min a^.priority b^.priority,
in_sos := a^.in_sos && b^.in_sos,
cost := nat.min a^.cost b^.cost,
age := nat.min a^.age b^.age }
def combine (a b : score) : score :=
{ priority := nat.max a^.priority b^.priority,
in_sos := a^.in_sos && b^.in_sos,
cost := a^.cost + b^.cost,
age := nat.max a^.age b^.age }
end score
namespace score
meta instance : has_to_string score :=
⟨λe, "[" ++ to_string e^.priority ++
"," ++ to_string e^.cost ++
"," ++ to_string e^.age ++
",sos=" ++ to_string e^.in_sos ++ "]"⟩
end score
def clause_id :=
namespace clause_id
def to_nat (id : clause_id) : := id
instance : decidable_eq clause_id := nat.decidable_eq
instance : has_ordering clause_id := nat.has_ordering
end clause_id
meta structure derived_clause :=
(id : clause_id)
(c : clause)
(selected : list )
(assertions : list expr)
(sc : score)
namespace derived_clause
meta instance : has_to_tactic_format derived_clause :=
⟨λc, do
prf_fmt ← pp c^.c^.proof,
c_fmt ← pp c^.c,
ass_fmt ← pp (c^.assertions^.map (λa, a^.local_type)),
return $
to_string c^.sc ++ " " ++
prf_fmt ++ " " ++
c_fmt ++ " <- " ++ ass_fmt ++
" (selected: " ++ to_fmt c^.selected ++
")"
meta def clause_with_assertions (ac : derived_clause) : clause :=
ac^.c^.close_constn ac^.assertions
meta def update_proof (dc : derived_clause) (p : expr) : derived_clause :=
{ dc with c := { (dc^.c) with proof := p } }
end derived_clause
meta structure locked_clause :=
(dc : derived_clause)
(reasons : list (list expr))
namespace locked_clause
meta instance : has_to_tactic_format locked_clause :=
⟨λc, do
c_fmt ← pp c^.dc,
reasons_fmt ← pp (c^.reasons^.map (λr, r^.for (λa, a^.local_type))),
return $ c_fmt ++ " (locked in case of: " ++ reasons_fmt ++ ")"
end locked_clause
meta structure prover_state :=
(active : rb_map clause_id derived_clause)
(passive : rb_map clause_id derived_clause)
(newly_derived : list derived_clause)
(prec : list expr)
(locked : list locked_clause)
(local_false : expr)
(sat_solver : cdcl.state)
(current_model : rb_map expr bool)
(sat_hyps : rb_map expr (expr × expr))
(needs_sat_run : bool)
(clause_counter : nat)
open prover_state
private meta def join_with_nl : list format → format :=
list.foldl (λx y, x ++ format.line ++ y) format.nil
private meta def prover_state_tactic_fmt (s : prover_state) : tactic format := do
active_fmts ← mapm pp $ rb_map.values s^.active,
passive_fmts ← mapm pp $ rb_map.values s^.passive,
new_fmts ← mapm pp s^.newly_derived,
locked_fmts ← mapm pp s^.locked,
sat_fmts ← mapm pp s^.sat_solver^.clauses,
sat_model_fmts ← for s^.current_model^.to_list (λx, if x.2 = tt then pp x.1 else pp ```(not %%x.1)),
prec_fmts ← mapm pp s^.prec,
return (join_with_nl
([to_fmt "active:"] ++ map (append (to_fmt " ")) active_fmts ++
[to_fmt "passive:"] ++ map (append (to_fmt " ")) passive_fmts ++
[to_fmt "new:"] ++ map (append (to_fmt " ")) new_fmts ++
[to_fmt "locked:"] ++ map (append (to_fmt " ")) locked_fmts ++
[to_fmt "sat formulas:"] ++ map (append (to_fmt " ")) sat_fmts ++
[to_fmt "sat model:"] ++ map (append (to_fmt " ")) sat_model_fmts ++
[to_fmt "precedence order: " ++ to_fmt prec_fmts]))
meta instance : has_to_tactic_format prover_state :=
⟨prover_state_tactic_fmt⟩
meta def prover := state_t prover_state tactic
namespace prover
meta instance : monad prover := state_t.monad _ _
meta instance : has_monad_lift tactic prover :=
monad.monad_transformer_lift (state_t prover_state) tactic
meta instance (α : Type) : has_coe (tactic α) (prover α) :=
⟨monad.monad_lift⟩
meta def fail {α β : Type} [has_to_format β] (msg : β) : prover α :=
tactic.fail msg
meta def orelse (A : Type) (p1 p2 : prover A) : prover A :=
take state, p1 state <|> p2 state
meta instance : alternative prover :=
{ prover.monad with
failure := λα, fail "failed",
orelse := orelse }
end prover
meta def selection_strategy := derived_clause → prover derived_clause
meta def get_active : prover (rb_map clause_id derived_clause) :=
do state ← state_t.read, return state^.active
meta def add_active (a : derived_clause) : prover unit :=
do state ← state_t.read,
state_t.write { state with active := state^.active^.insert a^.id a }
meta def get_passive : prover (rb_map clause_id derived_clause) :=
lift passive state_t.read
meta def get_precedence : prover (list expr) :=
do state ← state_t.read, return state^.prec
meta def get_term_order : prover (expr → expr → bool) := do
state ← state_t.read,
return $ mk_lpo (map name_of_funsym state^.prec)
private meta def set_precedence (new_prec : list expr) : prover unit :=
do state ← state_t.read, state_t.write { state with prec := new_prec }
meta def register_consts_in_precedence (consts : list expr) := do
p ← get_precedence,
p_set ← return (rb_map.set_of_list (map name_of_funsym p)),
new_syms ← return $ list.filter (λc, ¬p_set^.contains (name_of_funsym c)) consts,
set_precedence (new_syms ++ p)
meta def in_sat_solver {A} (cmd : cdcl.solver A) : prover A := do
state ← state_t.read,
result ← cmd state^.sat_solver,
state_t.write { state with sat_solver := result.2 },
return result.1
meta def collect_ass_hyps (c : clause) : prover (list expr) :=
let lcs := contained_lconsts c^.proof in
do st ← state_t.read,
return (do
hs ← st^.sat_hyps^.values,
h ← [hs.1, hs.2],
guard $ lcs^.contains h^.local_uniq_name,
[h])
meta def get_clause_count : prover :=
do s ← state_t.read, return s^.clause_counter
meta def get_new_cls_id : prover clause_id := do
state ← state_t.read,
state_t.write { state with clause_counter := state^.clause_counter + 1 },
return state^.clause_counter
meta def mk_derived (c : clause) (sc : score) : prover derived_clause := do
ass ← collect_ass_hyps c,
id ← get_new_cls_id,
return { id := id, c := c, selected := [], assertions := ass, sc := sc }
meta def add_inferred (c : derived_clause) : prover unit := do
c' ← c^.c^.normalize, c' ← return { c with c := c' },
register_consts_in_precedence (contained_funsyms c'^.c^.type)^.values,
state ← state_t.read,
state_t.write { state with newly_derived := c' :: state^.newly_derived }
-- FIXME: what if we've seen the variable before, but with a weaker score?
meta def mk_sat_var (v : expr) (suggested_ph : bool) (suggested_ev : score) : prover unit :=
do st ← state_t.read, if st^.sat_hyps^.contains v then return () else do
hpv ← mk_local_def `h v,
hnv ← mk_local_def `hn $ imp v st^.local_false,
state_t.modify $ λst, { st with sat_hyps := st^.sat_hyps^.insert v (hpv, hnv) },
in_sat_solver $ cdcl.mk_var_core v suggested_ph,
match v with
| (pi _ _ _ _) := do
c ← clause.of_proof st^.local_false hpv,
mk_derived c suggested_ev >>= add_inferred
| _ := do cp ← clause.of_proof st^.local_false hpv, mk_derived cp suggested_ev >>= add_inferred,
cn ← clause.of_proof st^.local_false hnv, mk_derived cn suggested_ev >>= add_inferred
end
meta def get_sat_hyp_core (v : expr) (ph : bool) : prover (option expr) :=
flip monad.lift state_t.read $ λst,
match st^.sat_hyps^.find v with
| some (hp, hn) := some $ if ph then hp else hn
| none := none
end
meta def get_sat_hyp (v : expr) (ph : bool) : prover expr :=
do hyp_opt ← get_sat_hyp_core v ph,
match hyp_opt with
| some hyp := return hyp
| none := fail $ "unknown sat variable: " ++ v^.to_string
end
meta def add_sat_clause (c : clause) (suggested_ev : score) : prover unit := do
c ← c^.distinct,
already_added ← flip monad.lift state_t.read $ λst, decidable.to_bool $
c^.type ∈ st^.sat_solver^.clauses^.map (λd, d^.type),
if already_added then return () else do
for c^.get_lits $ λl, mk_sat_var l^.formula l^.is_neg suggested_ev,
in_sat_solver $ cdcl.mk_clause c,
state_t.modify $ λst, { st with needs_sat_run := tt }
meta def sat_eval_lit (v : expr) (pol : bool) : prover bool :=
do v_st ← flip monad.lift state_t.read $ λst, st^.current_model^.find v,
match v_st with
| some ph := return $ if pol then ph else bnot ph
| none := return tt
end
meta def sat_eval_assertion (assertion : expr) : prover bool :=
do lf ← flip monad.lift state_t.read $ λst, st^.local_false,
match is_local_not lf assertion^.local_type with
| some v := sat_eval_lit v ff
| none := sat_eval_lit assertion^.local_type tt
end
meta def sat_eval_assertions : list expr → prover bool
| (a::ass) := do v_a ← sat_eval_assertion a,
if v_a then
sat_eval_assertions ass
else
return ff
| [] := return tt
private meta def intern_clause (c : derived_clause) : prover derived_clause := do
hyp_name ← get_unused_name (mk_simple_name $ "clause_" ++ to_string c^.id^.to_nat) none,
c' ← return $ c^.c^.close_constn c^.assertions,
assertv hyp_name c'^.type c'^.proof,
proof' ← get_local hyp_name,
type ← infer_type proof', -- FIXME: otherwise ""
return $ c^.update_proof $ app_of_list proof' c^.assertions
meta def register_as_passive (c : derived_clause) : prover unit := do
c ← intern_clause c,
ass_v ← sat_eval_assertions c^.assertions,
if c^.c^.num_quants = 0 ∧ c^.c^.num_lits = 0 then
add_sat_clause c^.clause_with_assertions c^.sc
else if ¬ass_v then do
state_t.modify $ λst, { st with locked := ⟨c, []⟩ :: st^.locked }
else do
state_t.modify $ λst, { st with passive := st^.passive^.insert c^.id c }
meta def remove_passive (id : clause_id) : prover unit :=
do state ← state_t.read, state_t.write { state with passive := state^.passive^.erase id }
meta def move_locked_to_passive : prover unit := do
locked ← flip monad.lift state_t.read (λst, st^.locked),
new_locked ← flip filter locked (λlc, do
reason_vals ← mapm sat_eval_assertions lc^.reasons,
c_val ← sat_eval_assertions lc^.dc^.assertions,
if reason_vals^.for_all (λr, r = ff) ∧ c_val then do
state_t.modify $ λst, { st with passive := st^.passive^.insert lc^.dc^.id lc^.dc },
return ff
else
return tt
),
state_t.modify $ λst, { st with locked := new_locked }
meta def move_active_to_locked : prover unit :=
do active ← get_active, for' active^.values $ λac, do
c_val ← sat_eval_assertions ac^.assertions,
if ¬c_val then do
state_t.modify $ λst, { st with
active := st^.active^.erase ac^.id,
locked := ⟨ac, []⟩ :: st^.locked
}
else
return ()
meta def move_passive_to_locked : prover unit :=
do passive ← flip monad.lift state_t.read $ λst, st^.passive, for' passive^.to_list $ λpc, do
c_val ← sat_eval_assertions pc.2^.assertions,
if ¬c_val then do
state_t.modify $ λst, { st with
passive := st^.passive^.erase pc.1,
locked := ⟨pc.2, []⟩ :: st^.locked
}
else
return ()
def super_cc_config : cc_config :=
{ em := ff }
meta def do_sat_run : prover (option expr) :=
do sat_result ← in_sat_solver $ cdcl.run (cdcl.theory_solver_of_tactic $ using_smt $ return ()),
state_t.modify $ λst, { st with needs_sat_run := ff },
old_model ← lift prover_state.current_model state_t.read,
match sat_result with
| (cdcl.result.unsat proof) := return (some proof)
| (cdcl.result.sat new_model) := do
state_t.modify $ λst, { st with current_model := new_model },
move_locked_to_passive,
move_active_to_locked,
move_passive_to_locked,
return none
end
meta def take_newly_derived : prover (list derived_clause) := do
state ← state_t.read,
state_t.write { state with newly_derived := [] },
return state^.newly_derived
meta def remove_redundant (id : clause_id) (parents : list derived_clause) : prover unit := do
when (not $ parents^.for_all $ λp, p^.id ≠ id) (fail "clause is redundant because of itself"),
red ← flip monad.lift state_t.read (λst, st^.active^.find id),
match red with
| none := return ()
| some red := do
let reasons := parents^.map (λp, p^.assertions),
let assertion := red^.assertions,
if reasons^.for_all $ λr, r^.subset_of assertion then do
state_t.modify $ λst, { st with active := st^.active^.erase id }
else do
state_t.modify $ λst, { st with active := st^.active^.erase id,
locked := ⟨red, reasons⟩ :: st^.locked }
end
meta def inference := derived_clause → prover unit
meta structure inf_decl := (prio : ) (inf : inference)
@[user_attribute]
meta def inf_attr : user_attribute :=
⟨ `super.inf, "inference for the super prover" ⟩
meta def seq_inferences : list inference → inference
| [] := λgiven, return ()
| (inf::infs) := λgiven, do
inf given,
now_active ← get_active,
if rb_map.contains now_active given^.id then
seq_inferences infs given
else
return ()
meta def simp_inference (simpl : derived_clause → prover (option clause)) : inference :=
λgiven, do maybe_simpld ← simpl given,
match maybe_simpld with
| some simpld := do
derived_simpld ← mk_derived simpld given^.sc^.sched_now,
add_inferred derived_simpld,
remove_redundant given^.id []
| none := return ()
end
meta def preprocessing_rule (f : list derived_clause → prover (list derived_clause)) : prover unit := do
state ← state_t.read,
newly_derived' ← f state^.newly_derived,
state' ← state_t.read,
state_t.write { state' with newly_derived := newly_derived' }
meta def clause_selection_strategy := → prover clause_id
namespace prover_state
meta def empty (local_false : expr) : prover_state :=
{ active := rb_map.mk _ _, passive := rb_map.mk _ _,
newly_derived := [], prec := [], clause_counter := 0,
local_false := local_false,
locked := [], sat_solver := cdcl.state.initial local_false,
current_model := rb_map.mk _ _, sat_hyps := rb_map.mk _ _, needs_sat_run := ff }
meta def initial (local_false : expr) (clauses : list clause) : tactic prover_state := do
after_setup ← for' clauses (λc,
let in_sos := ((contained_lconsts c^.proof)^.erase local_false^.local_uniq_name)^.size = 0 in
do mk_derived c { priority := score.prio.immediate, in_sos := in_sos,
age := 0, cost := 0 } >>= add_inferred
) $ empty local_false,
return after_setup.2
end prover_state
meta def inf_score (add_cost : ) (scores : list score) : prover score := do
age ← get_clause_count,
return $ list.foldl score.combine { priority := score.prio.default,
in_sos := tt,
age := age,
cost := add_cost
} scores
meta def inf_if_successful (add_cost : ) (parent : derived_clause) (tac : tactic (list clause)) : prover unit :=
(do inferred ← tac,
for' inferred $ λc,
inf_score add_cost [parent^.sc] >>= mk_derived c >>= add_inferred)
<|> return ()
meta def simp_if_successful (parent : derived_clause) (tac : tactic (list clause)) : prover unit :=
(do inferred ← tac,
for' inferred $ λc,
mk_derived c parent^.sc^.sched_now >>= add_inferred,
remove_redundant parent^.id [])
<|> return ()
end super