431 lines
14 KiB
Text
431 lines
14 KiB
Text
/-
|
||
Copyright (c) 2016 Gabriel Ebner. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Gabriel Ebner
|
||
-/
|
||
import .clause
|
||
open tactic expr monad super
|
||
|
||
namespace cdcl
|
||
|
||
@[reducible] meta def prop_var := expr
|
||
@[reducible] meta def proof_term := expr
|
||
@[reducible] meta def proof_hyp := expr
|
||
|
||
meta inductive trail_elem
|
||
| dec : prop_var → bool → proof_hyp → trail_elem
|
||
| propg : prop_var → bool → proof_term → proof_hyp → trail_elem
|
||
| dbl_neg_propg : prop_var → bool → proof_term → proof_hyp → trail_elem
|
||
|
||
namespace trail_elem
|
||
|
||
meta def var : trail_elem → prop_var
|
||
| (dec v _ _) := v
|
||
| (propg v _ _ _) := v
|
||
| (dbl_neg_propg v _ _ _) := v
|
||
|
||
meta def phase : trail_elem → bool
|
||
| (dec _ ph _) := ph
|
||
| (propg _ ph _ _) := ph
|
||
| (dbl_neg_propg _ ph _ _) := ph
|
||
|
||
meta def hyp : trail_elem → proof_hyp
|
||
| (dec _ _ h) := h
|
||
| (propg _ _ _ h) := h
|
||
| (dbl_neg_propg _ _ _ h) := h
|
||
|
||
meta def is_decision : trail_elem → bool
|
||
| (dec _ _ _) := tt
|
||
| (propg _ _ _ _) := ff
|
||
| (dbl_neg_propg _ _ _ _) := ff
|
||
|
||
end trail_elem
|
||
|
||
meta structure var_state :=
|
||
(phase : bool)
|
||
(assigned : option proof_hyp)
|
||
|
||
meta structure learned_clause :=
|
||
(c : clause)
|
||
(actual_proof : proof_term)
|
||
|
||
meta inductive prop_lit
|
||
| neg : prop_var → prop_lit
|
||
| pos : prop_var → prop_lit
|
||
|
||
namespace prop_lit
|
||
|
||
meta instance : has_ordering prop_lit :=
|
||
⟨λl₁ l₂, match l₁, l₂ with
|
||
| pos _, neg _ := ordering.gt
|
||
| neg _, pos _ := ordering.lt
|
||
| pos v₁, pos v₂ := has_ordering.cmp v₁ v₂
|
||
| neg v₁, neg v₂ := has_ordering.cmp v₁ v₂
|
||
end⟩
|
||
|
||
meta def of_cls_lit : clause.literal → prop_lit
|
||
| (clause.literal.left v) := neg v
|
||
| (clause.literal.right v) := pos v
|
||
|
||
meta def of_var_and_phase (v : prop_var) : bool → prop_lit
|
||
| tt := pos v
|
||
| ff := neg v
|
||
|
||
end prop_lit
|
||
|
||
meta def watch_map := rb_map name (ℕ × ℕ × clause)
|
||
|
||
meta structure state :=
|
||
(trail : list trail_elem)
|
||
(vars : rb_map prop_var var_state)
|
||
(unassigned : rb_map prop_var prop_var)
|
||
(clauses : list clause)
|
||
(learned : list learned_clause)
|
||
(watches : rb_map prop_lit watch_map)
|
||
(conflict : option proof_term)
|
||
(unitp_queue : list prop_var)
|
||
(local_false : expr)
|
||
|
||
namespace state
|
||
|
||
meta def initial (local_false : expr) : state := {
|
||
trail := [],
|
||
vars := rb_map.mk _ _,
|
||
unassigned := rb_map.mk _ _,
|
||
clauses := [],
|
||
learned := [],
|
||
watches := rb_map.mk _ _,
|
||
conflict := none,
|
||
unitp_queue := [],
|
||
local_false := local_false
|
||
}
|
||
|
||
meta def watches_for (st : state) (pl : prop_lit) : watch_map :=
|
||
(st^.watches^.find pl)^.get_or_else (rb_map.mk _ _)
|
||
|
||
end state
|
||
|
||
meta def solver := state_t state tactic
|
||
meta instance : monad solver := state_t.monad _ _
|
||
meta instance : has_monad_lift tactic solver := monad.monad_transformer_lift (state_t state) tactic
|
||
meta instance (α : Type) : has_coe (tactic α) (solver α) :=
|
||
⟨monad.monad_lift⟩
|
||
|
||
meta def fail {A B} [has_to_format B] (b : B) : solver A :=
|
||
@tactic.fail A B _ b
|
||
|
||
meta def get_local_false : solver expr :=
|
||
do st ← state_t.read, return st^.local_false
|
||
|
||
meta def mk_var_core (v : prop_var) (ph : bool) : solver unit := do
|
||
state_t.modify $ λst, match st^.vars^.find v with
|
||
| (some _) := st
|
||
| none := { st with
|
||
vars := st^.vars^.insert v ⟨ph, none⟩,
|
||
unassigned := st^.unassigned^.insert v v
|
||
}
|
||
end
|
||
|
||
meta def mk_var (v : prop_var) : solver unit := mk_var_core v ff
|
||
|
||
meta def set_conflict (proof : proof_term) : solver unit :=
|
||
state_t.modify $ λst, { st with conflict := some proof }
|
||
|
||
meta def has_conflict : solver bool :=
|
||
do st ← state_t.read, return st^.conflict^.is_some
|
||
|
||
meta def push_trail (elem : trail_elem) : solver unit := do
|
||
st ← state_t.read,
|
||
match st^.vars^.find elem^.var with
|
||
| none := fail $ "unknown variable: " ++ elem^.var^.to_string
|
||
| some ⟨_, some _⟩ := fail $ "adding already assigned variable to trail: " ++ elem^.var^.to_string
|
||
| some ⟨_, none⟩ :=
|
||
state_t.write { st with
|
||
vars := st^.vars^.insert elem^.var ⟨elem^.phase, some elem^.hyp⟩,
|
||
unassigned := st^.unassigned^.erase elem^.var,
|
||
trail := elem :: st^.trail,
|
||
unitp_queue := elem^.var :: st^.unitp_queue
|
||
}
|
||
end
|
||
|
||
meta def pop_trail_core : solver (option trail_elem) := do
|
||
st ← state_t.read,
|
||
match st^.trail with
|
||
| elem :: rest := do
|
||
state_t.write { st with trail := rest,
|
||
vars := st^.vars^.insert elem^.var ⟨elem^.phase, none⟩,
|
||
unassigned := st^.unassigned^.insert elem^.var elem^.var,
|
||
unitp_queue := [] },
|
||
return $ some elem
|
||
| [] := return none
|
||
end
|
||
|
||
meta def is_decision_level_zero : solver bool :=
|
||
do st ← state_t.read, return $ st^.trail^.for_all $ λelem, ¬elem^.is_decision
|
||
|
||
meta def revert_to_decision_level_zero : unit → solver unit | () := do
|
||
is_dl0 ← is_decision_level_zero,
|
||
if is_dl0 then return ()
|
||
else do pop_trail_core, revert_to_decision_level_zero ()
|
||
|
||
meta def formula_of_lit (local_false : expr) (v : prop_var) (ph : bool) :=
|
||
if ph then v else imp v local_false
|
||
|
||
meta def lookup_var (v : prop_var) : solver (option var_state) :=
|
||
do st ← state_t.read, return $ st^.vars^.find v
|
||
|
||
meta def add_propagation (v : prop_var) (ph : bool) (just : proof_term) (just_is_dn : bool) : solver unit :=
|
||
do v_st ← lookup_var v, local_false ← get_local_false, match v_st with
|
||
| none := fail $ "propagating unknown variable: " ++ v^.to_string
|
||
| some ⟨assg_ph, some proof⟩ :=
|
||
if ph = assg_ph then
|
||
return ()
|
||
else if assg_ph ∧ ¬just_is_dn then
|
||
set_conflict (app just proof)
|
||
else
|
||
set_conflict (app proof just)
|
||
| some ⟨_, none⟩ := do
|
||
hyp_name ← mk_fresh_name,
|
||
hyp ← return $ local_const hyp_name hyp_name binder_info.default (formula_of_lit local_false v ph),
|
||
if just_is_dn then do
|
||
push_trail $ trail_elem.dbl_neg_propg v ph just hyp
|
||
else do
|
||
push_trail $ trail_elem.propg v ph just hyp
|
||
end
|
||
|
||
meta def add_decision (v : prop_var) (ph : bool) : solver unit := do
|
||
hyp_name ← mk_fresh_name,
|
||
local_false ← get_local_false,
|
||
hyp ← return $ local_const hyp_name hyp_name binder_info.default (formula_of_lit local_false v ph),
|
||
push_trail $ trail_elem.dec v ph hyp
|
||
|
||
meta def lookup_lit (l : clause.literal) : solver (option (bool × proof_hyp)) :=
|
||
do var_st_opt ← lookup_var l^.formula, match var_st_opt with
|
||
| none := return none
|
||
| some ⟨ph, none⟩ := return none
|
||
| some ⟨ph, some proof⟩ :=
|
||
return $ some (if l^.is_neg then bnot ph else ph, proof)
|
||
end
|
||
|
||
meta def lit_is_false (l : clause.literal) : solver bool :=
|
||
do s ← lookup_lit l, return $ match s with
|
||
| some (ff, _) := tt
|
||
| _ := ff
|
||
end
|
||
|
||
meta def lit_is_not_false (l : clause.literal) : solver bool :=
|
||
do isf ← lit_is_false l, return $ bnot isf
|
||
|
||
meta def cls_is_false (c : clause) : solver bool :=
|
||
lift list.band $ mapm lit_is_false c^.get_lits
|
||
|
||
private meta def unit_propg_cls' : clause → solver (option prop_var) | c :=
|
||
if c^.num_lits = 0 then return (some c^.proof)
|
||
else let hd := c^.get_lit 0 in
|
||
do lit_st ← lookup_lit hd, match lit_st with
|
||
| some (ff, isf_prf) := unit_propg_cls' (c^.inst isf_prf)
|
||
| _ := return none
|
||
end
|
||
|
||
meta def unit_propg_cls : clause → solver unit | c :=
|
||
do has_confl ← has_conflict,
|
||
if has_confl then return () else
|
||
if c^.num_lits = 0 then do set_conflict c^.proof
|
||
else let hd := c^.get_lit 0 in
|
||
do lit_st ← lookup_lit hd, match lit_st with
|
||
| some (ff, isf_prf) := unit_propg_cls (c^.inst isf_prf)
|
||
| some (tt, _) := return ()
|
||
| none :=
|
||
do fls_prf_opt ← unit_propg_cls' (c^.inst (expr.mk_var 0)),
|
||
match fls_prf_opt with
|
||
| some fls_prf := do
|
||
fls_prf' ← return $ lam `H binder_info.default c^.type^.binding_domain fls_prf,
|
||
if hd^.is_neg then
|
||
add_propagation hd^.formula ff fls_prf' ff
|
||
else
|
||
add_propagation hd^.formula tt fls_prf' tt
|
||
| none := return ()
|
||
end
|
||
end
|
||
|
||
private meta def modify_watches_for (pl : prop_lit) (f : watch_map → watch_map) : solver unit :=
|
||
state_t.modify $ λst, { st with
|
||
watches := st^.watches^.insert pl $ f $ st^.watches_for pl
|
||
}
|
||
|
||
private meta def add_watch (n : name) (c : clause) (i j : ℕ) : solver unit :=
|
||
let l := c^.get_lit i, pl := prop_lit.of_cls_lit l in
|
||
modify_watches_for pl $ λw, w^.insert n (i,j,c)
|
||
|
||
private meta def remove_watch (n : name) (c : clause) (i : ℕ) : solver unit :=
|
||
let l := c^.get_lit i, pl := prop_lit.of_cls_lit l in
|
||
modify_watches_for pl $ λw, w^.erase n
|
||
|
||
private meta def set_watches (n : name) (c : clause) : solver unit :=
|
||
if c^.num_lits = 0 then
|
||
set_conflict c^.proof
|
||
else if c^.num_lits = 1 then
|
||
unit_propg_cls c
|
||
else do
|
||
not_false_lits ← filter (λi, lit_is_not_false (c^.get_lit i)) (list.range c^.num_lits),
|
||
match not_false_lits with
|
||
| [] := do
|
||
add_watch n c 0 1,
|
||
add_watch n c 1 0,
|
||
unit_propg_cls c
|
||
| [i] :=
|
||
let j := if i = 0 then 1 else 0 in do
|
||
add_watch n c i j,
|
||
add_watch n c j i,
|
||
unit_propg_cls c
|
||
| (i::j::_) := do
|
||
add_watch n c i j,
|
||
add_watch n c j i
|
||
end
|
||
|
||
meta def update_watches (n : name) (c : clause) (i₁ i₂ : ℕ) : solver unit := do
|
||
remove_watch n c i₁,
|
||
remove_watch n c i₂,
|
||
set_watches n c
|
||
|
||
meta def mk_clause (c : clause) : solver unit := do
|
||
c : clause ← c^.distinct,
|
||
for c^.get_lits (λl, mk_var l^.formula),
|
||
revert_to_decision_level_zero (),
|
||
state_t.modify $ λst, { st with clauses := c :: st^.clauses },
|
||
c_name ← mk_fresh_name,
|
||
set_watches c_name c
|
||
|
||
meta def unit_propg_var (v : prop_var) : solver unit :=
|
||
do st ← state_t.read, if st^.conflict^.is_some then return () else
|
||
match st^.vars^.find v with
|
||
| some ⟨ph, none⟩ := fail $ "propagating unassigned variable: " ++ v^.to_string
|
||
| none := fail $ "unknown variable: " ++ v^.to_string
|
||
| some ⟨ph, some _⟩ :=
|
||
let watches := st^.watches_for $ prop_lit.of_var_and_phase v (bnot ph) in
|
||
for' watches^.to_list $ λw, update_watches w.1 w.2.2.2 w.2.1 w.2.2.1
|
||
end
|
||
|
||
meta def analyze_conflict' (local_false : expr) : proof_term → list trail_elem → clause
|
||
| proof (trail_elem.dec v ph hyp :: es) :=
|
||
let abs_prf := abstract_local proof hyp^.local_uniq_name in
|
||
if has_var abs_prf then
|
||
clause.close_const (analyze_conflict' proof es) hyp
|
||
else
|
||
analyze_conflict' proof es
|
||
| proof (trail_elem.propg v ph l_prf hyp :: es) :=
|
||
let abs_prf := abstract_local proof hyp^.local_uniq_name in
|
||
if has_var abs_prf then
|
||
analyze_conflict' (app (lam hyp^.local_pp_name binder_info.default (formula_of_lit local_false v ph) abs_prf) l_prf) es
|
||
else
|
||
analyze_conflict' proof es
|
||
| proof (trail_elem.dbl_neg_propg v ph l_prf hyp :: es) :=
|
||
let abs_prf := abstract_local proof hyp^.local_uniq_name in
|
||
if has_var abs_prf then
|
||
analyze_conflict' (app l_prf (lambdas [hyp] proof)) es
|
||
else
|
||
analyze_conflict' proof es
|
||
| proof [] := ⟨0, 0, proof, local_false, local_false⟩
|
||
|
||
meta def analyze_conflict (proof : proof_term) : solver clause :=
|
||
do st ← state_t.read, return $ analyze_conflict' st^.local_false proof st^.trail
|
||
|
||
meta def add_learned (c : clause) : solver unit := do
|
||
prf_abbrev_name ← mk_fresh_name,
|
||
c' ← return { c with proof := local_const prf_abbrev_name prf_abbrev_name binder_info.default c^.type },
|
||
state_t.modify $ λst, { st with learned := ⟨c', c^.proof⟩ :: st^.learned },
|
||
c_name ← mk_fresh_name,
|
||
set_watches c_name c'
|
||
|
||
meta def backtrack_with : clause → solver unit | conflict_clause := do
|
||
isf ← cls_is_false conflict_clause,
|
||
if ¬isf then
|
||
state_t.modify (λst, { st with conflict := none })
|
||
else do
|
||
removed_elem ← pop_trail_core,
|
||
if removed_elem^.is_some then
|
||
backtrack_with conflict_clause
|
||
else
|
||
return ()
|
||
|
||
meta def replace_learned_clauses' : proof_term → list learned_clause → proof_term
|
||
| proof [] := proof
|
||
| proof (⟨c, actual_proof⟩ :: lcs) :=
|
||
let abs_prf := abstract_local proof c^.proof^.local_uniq_name in
|
||
if has_var abs_prf then
|
||
replace_learned_clauses' (elet c^.proof^.local_pp_name c^.type actual_proof abs_prf) lcs
|
||
else
|
||
replace_learned_clauses' proof lcs
|
||
|
||
meta def replace_learned_clauses (proof : proof_term) : solver proof_term :=
|
||
do st ← state_t.read, return $ replace_learned_clauses' proof st^.learned
|
||
|
||
meta inductive result
|
||
| unsat : proof_term → result
|
||
| sat : rb_map prop_var bool → result
|
||
|
||
variable theory_solver : solver (option proof_term)
|
||
|
||
meta def unit_propg : unit → solver unit | () := do
|
||
st ← state_t.read,
|
||
if st^.conflict^.is_some then return () else
|
||
match st^.unitp_queue with
|
||
| [] := return ()
|
||
| (v::vs) := do
|
||
state_t.write { st with unitp_queue := vs },
|
||
unit_propg_var v,
|
||
unit_propg ()
|
||
end
|
||
|
||
private meta def run' : unit → solver result | () := do
|
||
unit_propg (),
|
||
st ← state_t.read,
|
||
match st^.conflict with
|
||
| some conflict := do
|
||
conflict_clause ← analyze_conflict conflict,
|
||
if conflict_clause^.num_lits = 0 then do
|
||
proof ← replace_learned_clauses conflict_clause^.proof,
|
||
return (result.unsat proof)
|
||
else do
|
||
backtrack_with conflict_clause,
|
||
add_learned conflict_clause,
|
||
run' ()
|
||
| none :=
|
||
match st^.unassigned^.min with
|
||
| none := do
|
||
theory_conflict ← theory_solver,
|
||
match theory_conflict with
|
||
| some conflict := do set_conflict conflict, run' ()
|
||
| none := return $ result.sat (st^.vars^.for (λvar_st, var_st^.phase))
|
||
end
|
||
| some unassigned :=
|
||
match st^.vars^.find unassigned with
|
||
| some ⟨ph, none⟩ := do add_decision unassigned ph, run' ()
|
||
| _ := fail $ "unassigned variable is assigned: " ++ unassigned^.to_string
|
||
end
|
||
end
|
||
end
|
||
|
||
meta def run : solver result := run' theory_solver ()
|
||
|
||
meta def solve (local_false : expr) (clauses : list clause) : tactic result := do
|
||
res ← (do for clauses mk_clause, run theory_solver) (state.initial local_false),
|
||
return res.1
|
||
|
||
meta def theory_solver_of_tactic (th_solver : tactic unit) : cdcl.solver (option cdcl.proof_term) :=
|
||
do s ← state_t.read, ↑do
|
||
hyps ← return $ s^.trail^.for (λe, e^.hyp),
|
||
subgoal ← mk_meta_var s^.local_false,
|
||
goals ← get_goals,
|
||
set_goals [subgoal],
|
||
hvs ← for hyps (λhyp, assertv hyp^.local_pp_name hyp^.local_type hyp),
|
||
solved ← (do th_solver, now, return tt) <|> return ff,
|
||
set_goals goals,
|
||
if solved then do
|
||
proof ← instantiate_mvars subgoal,
|
||
proof' ← whnf proof, -- gets rid of the unnecessary asserts
|
||
return $ some proof'
|
||
else
|
||
return none
|
||
|
||
end cdcl
|