lean4-htt/library/tools/super/cdcl_solver.lean
2017-01-10 09:07:37 -08:00

431 lines
14 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2016 Gabriel Ebner. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
import .clause
open tactic expr monad super
namespace cdcl
@[reducible] meta def prop_var := expr
@[reducible] meta def proof_term := expr
@[reducible] meta def proof_hyp := expr
meta inductive trail_elem
| dec : prop_var → bool → proof_hyp → trail_elem
| propg : prop_var → bool → proof_term → proof_hyp → trail_elem
| dbl_neg_propg : prop_var → bool → proof_term → proof_hyp → trail_elem
namespace trail_elem
meta def var : trail_elem → prop_var
| (dec v _ _) := v
| (propg v _ _ _) := v
| (dbl_neg_propg v _ _ _) := v
meta def phase : trail_elem → bool
| (dec _ ph _) := ph
| (propg _ ph _ _) := ph
| (dbl_neg_propg _ ph _ _) := ph
meta def hyp : trail_elem → proof_hyp
| (dec _ _ h) := h
| (propg _ _ _ h) := h
| (dbl_neg_propg _ _ _ h) := h
meta def is_decision : trail_elem → bool
| (dec _ _ _) := tt
| (propg _ _ _ _) := ff
| (dbl_neg_propg _ _ _ _) := ff
end trail_elem
meta structure var_state :=
(phase : bool)
(assigned : option proof_hyp)
meta structure learned_clause :=
(c : clause)
(actual_proof : proof_term)
meta inductive prop_lit
| neg : prop_var → prop_lit
| pos : prop_var → prop_lit
namespace prop_lit
meta instance : has_ordering prop_lit :=
⟨λl₁ l₂, match l₁, l₂ with
| pos _, neg _ := ordering.gt
| neg _, pos _ := ordering.lt
| pos v₁, pos v₂ := has_ordering.cmp v₁ v₂
| neg v₁, neg v₂ := has_ordering.cmp v₁ v₂
end⟩
meta def of_cls_lit : clause.literal → prop_lit
| (clause.literal.left v) := neg v
| (clause.literal.right v) := pos v
meta def of_var_and_phase (v : prop_var) : bool → prop_lit
| tt := pos v
| ff := neg v
end prop_lit
meta def watch_map := rb_map name ( × × clause)
meta structure state :=
(trail : list trail_elem)
(vars : rb_map prop_var var_state)
(unassigned : rb_map prop_var prop_var)
(clauses : list clause)
(learned : list learned_clause)
(watches : rb_map prop_lit watch_map)
(conflict : option proof_term)
(unitp_queue : list prop_var)
(local_false : expr)
namespace state
meta def initial (local_false : expr) : state := {
trail := [],
vars := rb_map.mk _ _,
unassigned := rb_map.mk _ _,
clauses := [],
learned := [],
watches := rb_map.mk _ _,
conflict := none,
unitp_queue := [],
local_false := local_false
}
meta def watches_for (st : state) (pl : prop_lit) : watch_map :=
(st^.watches^.find pl)^.get_or_else (rb_map.mk _ _)
end state
meta def solver := state_t state tactic
meta instance : monad solver := state_t.monad _ _
meta instance : has_monad_lift tactic solver := monad.monad_transformer_lift (state_t state) tactic
meta instance (α : Type) : has_coe (tactic α) (solver α) :=
⟨monad.monad_lift⟩
meta def fail {A B} [has_to_format B] (b : B) : solver A :=
@tactic.fail A B _ b
meta def get_local_false : solver expr :=
do st ← state_t.read, return st^.local_false
meta def mk_var_core (v : prop_var) (ph : bool) : solver unit := do
state_t.modify $ λst, match st^.vars^.find v with
| (some _) := st
| none := { st with
vars := st^.vars^.insert v ⟨ph, none⟩,
unassigned := st^.unassigned^.insert v v
}
end
meta def mk_var (v : prop_var) : solver unit := mk_var_core v ff
meta def set_conflict (proof : proof_term) : solver unit :=
state_t.modify $ λst, { st with conflict := some proof }
meta def has_conflict : solver bool :=
do st ← state_t.read, return st^.conflict^.is_some
meta def push_trail (elem : trail_elem) : solver unit := do
st ← state_t.read,
match st^.vars^.find elem^.var with
| none := fail $ "unknown variable: " ++ elem^.var^.to_string
| some ⟨_, some _⟩ := fail $ "adding already assigned variable to trail: " ++ elem^.var^.to_string
| some ⟨_, none⟩ :=
state_t.write { st with
vars := st^.vars^.insert elem^.var ⟨elem^.phase, some elem^.hyp⟩,
unassigned := st^.unassigned^.erase elem^.var,
trail := elem :: st^.trail,
unitp_queue := elem^.var :: st^.unitp_queue
}
end
meta def pop_trail_core : solver (option trail_elem) := do
st ← state_t.read,
match st^.trail with
| elem :: rest := do
state_t.write { st with trail := rest,
vars := st^.vars^.insert elem^.var ⟨elem^.phase, none⟩,
unassigned := st^.unassigned^.insert elem^.var elem^.var,
unitp_queue := [] },
return $ some elem
| [] := return none
end
meta def is_decision_level_zero : solver bool :=
do st ← state_t.read, return $ st^.trail^.for_all $ λelem, ¬elem^.is_decision
meta def revert_to_decision_level_zero : unit → solver unit | () := do
is_dl0 ← is_decision_level_zero,
if is_dl0 then return ()
else do pop_trail_core, revert_to_decision_level_zero ()
meta def formula_of_lit (local_false : expr) (v : prop_var) (ph : bool) :=
if ph then v else imp v local_false
meta def lookup_var (v : prop_var) : solver (option var_state) :=
do st ← state_t.read, return $ st^.vars^.find v
meta def add_propagation (v : prop_var) (ph : bool) (just : proof_term) (just_is_dn : bool) : solver unit :=
do v_st ← lookup_var v, local_false ← get_local_false, match v_st with
| none := fail $ "propagating unknown variable: " ++ v^.to_string
| some ⟨assg_ph, some proof⟩ :=
if ph = assg_ph then
return ()
else if assg_ph ∧ ¬just_is_dn then
set_conflict (app just proof)
else
set_conflict (app proof just)
| some ⟨_, none⟩ := do
hyp_name ← mk_fresh_name,
hyp ← return $ local_const hyp_name hyp_name binder_info.default (formula_of_lit local_false v ph),
if just_is_dn then do
push_trail $ trail_elem.dbl_neg_propg v ph just hyp
else do
push_trail $ trail_elem.propg v ph just hyp
end
meta def add_decision (v : prop_var) (ph : bool) : solver unit := do
hyp_name ← mk_fresh_name,
local_false ← get_local_false,
hyp ← return $ local_const hyp_name hyp_name binder_info.default (formula_of_lit local_false v ph),
push_trail $ trail_elem.dec v ph hyp
meta def lookup_lit (l : clause.literal) : solver (option (bool × proof_hyp)) :=
do var_st_opt ← lookup_var l^.formula, match var_st_opt with
| none := return none
| some ⟨ph, none⟩ := return none
| some ⟨ph, some proof⟩ :=
return $ some (if l^.is_neg then bnot ph else ph, proof)
end
meta def lit_is_false (l : clause.literal) : solver bool :=
do s ← lookup_lit l, return $ match s with
| some (ff, _) := tt
| _ := ff
end
meta def lit_is_not_false (l : clause.literal) : solver bool :=
do isf ← lit_is_false l, return $ bnot isf
meta def cls_is_false (c : clause) : solver bool :=
lift list.band $ mapm lit_is_false c^.get_lits
private meta def unit_propg_cls' : clause → solver (option prop_var) | c :=
if c^.num_lits = 0 then return (some c^.proof)
else let hd := c^.get_lit 0 in
do lit_st ← lookup_lit hd, match lit_st with
| some (ff, isf_prf) := unit_propg_cls' (c^.inst isf_prf)
| _ := return none
end
meta def unit_propg_cls : clause → solver unit | c :=
do has_confl ← has_conflict,
if has_confl then return () else
if c^.num_lits = 0 then do set_conflict c^.proof
else let hd := c^.get_lit 0 in
do lit_st ← lookup_lit hd, match lit_st with
| some (ff, isf_prf) := unit_propg_cls (c^.inst isf_prf)
| some (tt, _) := return ()
| none :=
do fls_prf_opt ← unit_propg_cls' (c^.inst (expr.mk_var 0)),
match fls_prf_opt with
| some fls_prf := do
fls_prf' ← return $ lam `H binder_info.default c^.type^.binding_domain fls_prf,
if hd^.is_neg then
add_propagation hd^.formula ff fls_prf' ff
else
add_propagation hd^.formula tt fls_prf' tt
| none := return ()
end
end
private meta def modify_watches_for (pl : prop_lit) (f : watch_map → watch_map) : solver unit :=
state_t.modify $ λst, { st with
watches := st^.watches^.insert pl $ f $ st^.watches_for pl
}
private meta def add_watch (n : name) (c : clause) (i j : ) : solver unit :=
let l := c^.get_lit i, pl := prop_lit.of_cls_lit l in
modify_watches_for pl $ λw, w^.insert n (i,j,c)
private meta def remove_watch (n : name) (c : clause) (i : ) : solver unit :=
let l := c^.get_lit i, pl := prop_lit.of_cls_lit l in
modify_watches_for pl $ λw, w^.erase n
private meta def set_watches (n : name) (c : clause) : solver unit :=
if c^.num_lits = 0 then
set_conflict c^.proof
else if c^.num_lits = 1 then
unit_propg_cls c
else do
not_false_lits ← filter (λi, lit_is_not_false (c^.get_lit i)) (list.range c^.num_lits),
match not_false_lits with
| [] := do
add_watch n c 0 1,
add_watch n c 1 0,
unit_propg_cls c
| [i] :=
let j := if i = 0 then 1 else 0 in do
add_watch n c i j,
add_watch n c j i,
unit_propg_cls c
| (i::j::_) := do
add_watch n c i j,
add_watch n c j i
end
meta def update_watches (n : name) (c : clause) (i₁ i₂ : ) : solver unit := do
remove_watch n c i₁,
remove_watch n c i₂,
set_watches n c
meta def mk_clause (c : clause) : solver unit := do
c : clause ← c^.distinct,
for c^.get_lits (λl, mk_var l^.formula),
revert_to_decision_level_zero (),
state_t.modify $ λst, { st with clauses := c :: st^.clauses },
c_name ← mk_fresh_name,
set_watches c_name c
meta def unit_propg_var (v : prop_var) : solver unit :=
do st ← state_t.read, if st^.conflict^.is_some then return () else
match st^.vars^.find v with
| some ⟨ph, none⟩ := fail $ "propagating unassigned variable: " ++ v^.to_string
| none := fail $ "unknown variable: " ++ v^.to_string
| some ⟨ph, some _⟩ :=
let watches := st^.watches_for $ prop_lit.of_var_and_phase v (bnot ph) in
for' watches^.to_list $ λw, update_watches w.1 w.2.2.2 w.2.1 w.2.2.1
end
meta def analyze_conflict' (local_false : expr) : proof_term → list trail_elem → clause
| proof (trail_elem.dec v ph hyp :: es) :=
let abs_prf := abstract_local proof hyp^.local_uniq_name in
if has_var abs_prf then
clause.close_const (analyze_conflict' proof es) hyp
else
analyze_conflict' proof es
| proof (trail_elem.propg v ph l_prf hyp :: es) :=
let abs_prf := abstract_local proof hyp^.local_uniq_name in
if has_var abs_prf then
analyze_conflict' (app (lam hyp^.local_pp_name binder_info.default (formula_of_lit local_false v ph) abs_prf) l_prf) es
else
analyze_conflict' proof es
| proof (trail_elem.dbl_neg_propg v ph l_prf hyp :: es) :=
let abs_prf := abstract_local proof hyp^.local_uniq_name in
if has_var abs_prf then
analyze_conflict' (app l_prf (lambdas [hyp] proof)) es
else
analyze_conflict' proof es
| proof [] := ⟨0, 0, proof, local_false, local_false⟩
meta def analyze_conflict (proof : proof_term) : solver clause :=
do st ← state_t.read, return $ analyze_conflict' st^.local_false proof st^.trail
meta def add_learned (c : clause) : solver unit := do
prf_abbrev_name ← mk_fresh_name,
c' ← return { c with proof := local_const prf_abbrev_name prf_abbrev_name binder_info.default c^.type },
state_t.modify $ λst, { st with learned := ⟨c', c^.proof⟩ :: st^.learned },
c_name ← mk_fresh_name,
set_watches c_name c'
meta def backtrack_with : clause → solver unit | conflict_clause := do
isf ← cls_is_false conflict_clause,
if ¬isf then
state_t.modify (λst, { st with conflict := none })
else do
removed_elem ← pop_trail_core,
if removed_elem^.is_some then
backtrack_with conflict_clause
else
return ()
meta def replace_learned_clauses' : proof_term → list learned_clause → proof_term
| proof [] := proof
| proof (⟨c, actual_proof⟩ :: lcs) :=
let abs_prf := abstract_local proof c^.proof^.local_uniq_name in
if has_var abs_prf then
replace_learned_clauses' (elet c^.proof^.local_pp_name c^.type actual_proof abs_prf) lcs
else
replace_learned_clauses' proof lcs
meta def replace_learned_clauses (proof : proof_term) : solver proof_term :=
do st ← state_t.read, return $ replace_learned_clauses' proof st^.learned
meta inductive result
| unsat : proof_term → result
| sat : rb_map prop_var bool → result
variable theory_solver : solver (option proof_term)
meta def unit_propg : unit → solver unit | () := do
st ← state_t.read,
if st^.conflict^.is_some then return () else
match st^.unitp_queue with
| [] := return ()
| (v::vs) := do
state_t.write { st with unitp_queue := vs },
unit_propg_var v,
unit_propg ()
end
private meta def run' : unit → solver result | () := do
unit_propg (),
st ← state_t.read,
match st^.conflict with
| some conflict := do
conflict_clause ← analyze_conflict conflict,
if conflict_clause^.num_lits = 0 then do
proof ← replace_learned_clauses conflict_clause^.proof,
return (result.unsat proof)
else do
backtrack_with conflict_clause,
add_learned conflict_clause,
run' ()
| none :=
match st^.unassigned^.min with
| none := do
theory_conflict ← theory_solver,
match theory_conflict with
| some conflict := do set_conflict conflict, run' ()
| none := return $ result.sat (st^.vars^.for (λvar_st, var_st^.phase))
end
| some unassigned :=
match st^.vars^.find unassigned with
| some ⟨ph, none⟩ := do add_decision unassigned ph, run' ()
| _ := fail $ "unassigned variable is assigned: " ++ unassigned^.to_string
end
end
end
meta def run : solver result := run' theory_solver ()
meta def solve (local_false : expr) (clauses : list clause) : tactic result := do
res ← (do for clauses mk_clause, run theory_solver) (state.initial local_false),
return res.1
meta def theory_solver_of_tactic (th_solver : tactic unit) : cdcl.solver (option cdcl.proof_term) :=
do s ← state_t.read, ↑do
hyps ← return $ s^.trail^.for (λe, e^.hyp),
subgoal ← mk_meta_var s^.local_false,
goals ← get_goals,
set_goals [subgoal],
hvs ← for hyps (λhyp, assertv hyp^.local_pp_name hyp^.local_type hyp),
solved ← (do th_solver, now, return tt) <|> return ff,
set_goals goals,
if solved then do
proof ← instantiate_mvars subgoal,
proof' ← whnf proof, -- gets rid of the unnecessary asserts
return $ some proof'
else
return none
end cdcl