| from sympy.printing.smtlib import smtlib_code | |
| from sympy.assumptions.assume import AppliedPredicate | |
| from sympy.assumptions.cnf import EncodedCNF | |
| from sympy.assumptions.ask import Q | |
| from sympy.core import Add, Mul | |
| from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan | |
| from sympy.functions.elementary.complexes import Abs | |
| from sympy.functions.elementary.exponential import Pow | |
| from sympy.functions.elementary.miscellaneous import Min, Max | |
| from sympy.logic.boolalg import And, Or, Xor, Implies | |
| from sympy.logic.boolalg import Not, ITE | |
| from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate | |
| from sympy.external import import_module | |
def z3_satisfiable(expr, all_models=False):
    """Check the satisfiability of ``expr`` with the Z3 SMT solver.

    Parameters
    ==========

    expr
        Either an ``EncodedCNF`` instance, which is used as is, or any
        object accepted by ``EncodedCNF.add_prop``, which is wrapped in a
        fresh ``EncodedCNF`` first.
    all_models : bool, optional
        Currently unused by this function.

    Returns
    =======

    ``False`` when Z3 answers ``unsat``; the dictionary built by
    ``z3_model_to_sympy_model`` when it answers ``sat``; ``None`` for any
    other answer.

    Raises
    ======

    ImportError
        If the ``z3`` module cannot be imported.
    """
    if isinstance(expr, EncodedCNF):
        encoded = expr
    else:
        encoded = EncodedCNF()
        encoded.add_prop(expr)

    z3 = import_module("z3")
    if z3 is None:
        raise ImportError("z3 is not installed")

    solver = encoded_cnf_to_z3_solver(encoded, z3)
    verdict = str(solver.check())

    if verdict == "sat":
        return z3_model_to_sympy_model(solver.model(), encoded)
    # Any answer other than "sat"/"unsat" is reported as None.
    return False if verdict == "unsat" else None
def z3_model_to_sympy_model(z3_model, enc_cnf):
    """Translate a Z3 model into a ``{proposition: bool}`` dictionary.

    Parameters
    ==========

    z3_model
        Iterable of model constants; each item must provide ``name()`` and
        be usable as an index into ``z3_model`` to fetch its value.
    enc_cnf
        Object whose ``encoding`` attribute maps propositions to the
        integer ``n`` used in the Boolean constant name ``d<n>``.

    Returns
    =======

    dict
        Maps every proposition whose ``d<n>`` constant occurs in the model
        to the truth value of that constant.

    Notes
    =====

    Constants that are not named ``d<n>`` for a known ``n`` are skipped.
    ``encoded_cnf_to_z3_solver`` also declares ``Real`` constants for the
    free symbols of arithmetic predicates; those have no entry in
    ``enc_cnf.encoding`` and previously made the name parsing fail.
    """
    rev_enc = {index: prop for prop, index in enc_cnf.encoding.items()}

    model = {}
    for var in z3_model:
        name = var.name()
        suffix = name[1:]
        # Only "d" followed by ASCII digits denotes an encoded proposition.
        if not (name.startswith("d") and suffix.isascii() and suffix.isdigit()):
            continue
        index = int(suffix)
        if index in rev_enc:
            model[rev_enc[index]] = bool(z3_model[var])
    return model
def clause_to_assertion(clause):
    """Render one encoded CNF clause as an SMT-LIB ``assert`` command.

    Parameters
    ==========

    clause
        Iterable of integer literals. A literal ``n > 0`` is rendered as
        ``d<n>``; any other literal is rendered as ``(not d<abs(n)>)``.

    Returns
    =======

    str
        ``(assert (or <literals>))`` for a non-empty clause. An empty
        clause is an empty disjunction, which is false, so it is rendered
        as ``(assert false)`` rather than an ``or`` with no operands.
    """
    literals = [
        f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})"
        for lit in clause
    ]
    if not literals:
        return "(assert false)"
    return "(assert (or " + " ".join(literals) + "))"
def encoded_cnf_to_z3_solver(enc_cnf, z3):
    """Build a Z3 solver loaded with the contents of ``enc_cnf``.

    Parameters
    ==========

    enc_cnf
        Object providing ``variables`` (integers ``n`` declared as Boolean
        constants ``d<n>``), ``data`` (clauses rendered through
        ``clause_to_assertion``) and ``encoding`` (a mapping from
        propositions to their integer ``n``).
    z3
        The imported ``z3`` module; only ``z3.Solver`` is used.

    Returns
    =======

    The ``z3.Solver()`` instance after the generated SMT-LIB text has been
    loaded through two ``from_string`` calls: declarations first, then
    assertions.

    Notes
    =====

    For every ``AppliedPredicate`` key of ``enc_cnf.encoding`` whose
    function is one of the supported relational/sign predicates, the
    assertion ``(implies d<n> <predicate>)`` is added, where the predicate
    text comes from ``smtlib_code``. The free symbols of those predicates
    are declared as ``Real`` constants. Each distinct declaration is
    emitted once, in sorted order, so the generated text is deterministic.
    """
    solver = z3.Solver()

    declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
    assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]

    symbols = set()
    for pred, enc in enc_cnf.encoding.items():
        if not isinstance(pred, AppliedPredicate):
            continue
        if pred.function not in (
            Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq,
            Q.positive, Q.negative,
            Q.extended_negative, Q.extended_positive,
            Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive,
            Q.extended_nonzero, Q.extended_nonnegative,
            Q.extended_nonpositive,
        ):
            continue

        pred_str = smtlib_code(
            pred,
            auto_declare=False,
            auto_assert=False,
            known_functions=known_functions,
        )
        symbols |= pred.free_symbols
        # The Boolean constant implies the arithmetic meaning of the predicate.
        assertions.append(f"(assert (implies d{enc} {pred_str}))")

    # A set comprehension drops repeated declarations for equal names;
    # sorting fixes the otherwise arbitrary set-iteration order.
    declarations.extend(
        sorted({f"(declare-const {sym} Real)" for sym in symbols})
    )

    solver.from_string("\n".join(declarations))
    solver.from_string("\n".join(assertions))
    return solver
# Operator table handed to ``smtlib_code`` (via its ``known_functions``
# argument) by ``encoded_cnf_to_z3_solver``. Keys are SymPy classes or, for
# the binary relation predicates, predicate instances; values are the
# operator names emitted in the generated SMT-LIB text.
known_functions = {
    # Arithmetic.
    Add: '+',
    Mul: '*',

    # Relational classes.
    Equality: '=',
    LessThan: '<=',
    GreaterThan: '>=',
    StrictLessThan: '<',
    StrictGreaterThan: '>',

    # Relation predicates; note that the keys are instances, not classes.
    EqualityPredicate(): '=',
    LessThanPredicate(): '<=',
    GreaterThanPredicate(): '>=',
    StrictLessThanPredicate(): '<',
    StrictGreaterThanPredicate(): '>',

    # Elementary functions.
    Abs: 'abs',
    Min: 'min',
    Max: 'max',
    Pow: '^',

    # Boolean connectives.
    And: 'and',
    Or: 'or',
    Xor: 'xor',
    Not: 'not',
    ITE: 'ite',
    Implies: '=>',
}