Spaces:
Sleeping
Sleeping
class ConstraintParser:
    """Translate JSON-style constraint ASTs into constraints on a solver model.

    An AST is a nested structure of dicts, strings and numbers:

    * a number or bool is a literal;
    * a string names a local (loop) variable, or is a literal if no such
      variable is bound;
    * ``{"<prop>": "<var>"}`` reads a key (or a numeric index) from a bound
      variable;
    * ``{"target": ..., "args": [...]}`` looks up the schedule or calls a
      registered target function;
    * ``{"operator": ..., "left": ..., "right": ...}`` is a binary operation;
    * ``{"operator": "sum", "over": [...], "expression": ..., "where": ...}``
      is a (optionally filtered) sum over nested loops.
    """

    # Comparison operators accepted both inside expressions and in "assert".
    # Lambdas are used (rather than direct calls) so that operand types which
    # overload comparison operators build their own expression objects.
    _COMPARATORS = {
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
    }

    # Every binary operator understood by evaluate_expression().
    _BINARY_OPERATORS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        **_COMPARATORS,
        "AND": lambda a, b: a and b,
        "OR": lambda a, b: a or b,
        "in": lambda a, b: a in b,
        "not_in": lambda a, b: a not in b,
    }

    def __init__(self, model, schedule, global_data, target_functions=None):
        """Store the collaborators used while applying constraints.

        Args:
            model: Object exposing ``Add(constraint)``; receives every
                constraint produced by ``apply_assert``.
            schedule: Mapping looked up with a tuple of resolved arguments
                when an expression targets ``"schedule"``.
            global_data: Mapping from collection name to the iterable that
                loops ("forall" / "over") may iterate.
            target_functions: Optional mapping from target name to a callable
                invoked as ``func(schedule, *args)``.
        """
        self.model = model
        self.schedule = schedule
        self.global_data = global_data
        self.target_functions = target_functions or {}

    def resolve_arg(self, arg, local_vars):
        """Resolve a leaf argument against the currently bound variables.

        Args:
            arg: A number/bool (returned unchanged), a string (replaced by
                the local variable of that name, or returned as a literal if
                no such variable exists), or a dict whose first item is
                ``{prop: var_name}``.
            local_vars: Mapping of variable name to bound value.

        Returns:
            The resolved value.

        Raises:
            KeyError: If a dict argument names an unbound variable, or the
                property is missing from a dict-valued variable.
            ValueError: If the argument has an unsupported shape.
        """
        if isinstance(arg, (int, float, bool)):
            return arg
        if isinstance(arg, str):
            return local_vars.get(arg, arg)
        if isinstance(arg, dict):
            # Only the first item of the dict is considered.
            prop, var_name = list(arg.items())[0]
            parent_obj = local_vars[var_name]
            # A purely numeric property indexes into a sequence.
            if prop.isdigit() and isinstance(parent_obj, (list, tuple)):
                return parent_obj[int(prop)]
            if isinstance(parent_obj, dict):
                return parent_obj[prop]
        raise ValueError(f"Could not resolve argument: {arg}")

    def apply_assert(self, assert_ast, local_vars):
        """Translate an 'assert' block into a ``model.Add()`` call.

        Args:
            assert_ast: Dict with ``"left"``, ``"operator"`` and ``"right"``.
            local_vars: Mapping of variable name to bound value.

        Raises:
            ValueError: If the operator is not a supported comparison.
        """
        operator = assert_ast["operator"]
        compare = (
            self._COMPARATORS.get(operator) if isinstance(operator, str) else None
        )
        if compare is None:
            # Dropping the constraint silently would hide typos in the AST.
            raise ValueError(f"Unknown assert operator: {operator}")

        left_val = self.evaluate_expression(assert_ast["left"], local_vars)
        right_val = self.evaluate_expression(assert_ast["right"], local_vars)

        # Two plain numbers involve no model variable, so nothing is posted.
        # NOTE(review): a constant comparison that is *false* is skipped as
        # well instead of making the model infeasible - confirm this is
        # intended.
        if isinstance(left_val, (int, float)) and isinstance(right_val, (int, float)):
            return

        self.model.Add(compare(left_val, right_val))

    def evaluate_expression(self, expr, local_vars):
        """Recursively evaluate an expression (math, logic, lookups, sums).

        Args:
            expr: Expression AST node (see the class docstring).
            local_vars: Mapping of variable name to bound value.

        Returns:
            The value of the expression.

        Raises:
            ValueError: For an unknown target or an unknown operator.
        """
        if not isinstance(expr, dict):
            return self.resolve_arg(expr, local_vars)

        if "target" in expr:
            return self._evaluate_target(expr, local_vars)

        op = expr.get("operator")
        if not op:
            # No operator: treat the dict as a {prop: var_name} access.
            return self.resolve_arg(expr, local_vars)

        # 'sum' has no left/right operands, so handle it before the binary
        # operators.
        if op == "sum":
            return self._evaluate_sum(expr, local_vars)

        apply_op = self._BINARY_OPERATORS.get(op) if isinstance(op, str) else None
        if apply_op is None:
            raise ValueError(f"Unknown operator: {op}")

        left_val = self.evaluate_expression(expr.get("left"), local_vars)
        right_val = self.evaluate_expression(expr.get("right"), local_vars)
        return apply_op(left_val, right_val)

    def _evaluate_target(self, expr, local_vars):
        """Evaluate a ``{"target": ..., "args": [...]}`` node."""
        args = tuple(self.resolve_arg(a, local_vars) for a in expr["args"])
        target = expr["target"]
        if target == "schedule":
            # Missing schedule entries count as 0.
            return self.schedule.get(args, 0)
        if target in self.target_functions:
            return self.target_functions[target](self.schedule, *args)
        raise ValueError(f"Unknown target function: {expr['target']}")

    def _evaluate_sum(self, expr, local_vars):
        """Evaluate a ``sum`` node over the nested loops listed in ``over``.

        The optional ``where`` filter is checked once per complete
        combination of iterator values, so it may refer to any iterator of
        the sum regardless of nesting depth.
        """
        terms = []
        for bound_vars in self._iter_bindings(expr["over"], 0, local_vars):
            if "where" in expr and not self.evaluate_expression(
                expr["where"], bound_vars
            ):
                continue
            terms.append(self.evaluate_expression(expr["expression"], bound_vars))
        return sum(terms)

    def _resolve_iterable(self, iterator_source, local_vars):
        """Return the iterable a loop should run over.

        Names found in ``global_data`` take precedence; anything else is
        resolved against the local variables.
        """
        if isinstance(iterator_source, str):
            if iterator_source in self.global_data:
                return self.global_data[iterator_source]
            if iterator_source not in local_vars:
                # resolve_arg would hand the name back as a literal string and
                # the loop would then iterate over its characters.
                raise ValueError(f"Unknown iterator source: {iterator_source}")
        return self.resolve_arg(iterator_source, local_vars)

    def _iter_bindings(self, loop_array, depth, local_vars):
        """Lazily yield one variable mapping per combination of loop values.

        Loops from index ``depth`` onwards are expanded depth-first; inner
        iterables are resolved with the outer iterators already bound, so
        they may depend on them.
        """
        if depth == len(loop_array):
            yield local_vars
            return

        iterator_name, iterator_source = list(loop_array[depth].items())[0]
        for item in self._resolve_iterable(iterator_source, local_vars):
            new_vars = local_vars.copy()
            new_vars[iterator_name] = item
            yield from self._iter_bindings(loop_array, depth + 1, new_vars)

    def execute_loops(self, loop_array, current_depth, local_vars, ast):
        """Expand the 'forall' loops and apply the assert for each binding.

        Args:
            loop_array: List of single-entry dicts ``{iterator: source}``.
            current_depth: Index of the first loop still to be expanded.
            local_vars: Variables bound so far.
            ast: Constraint AST providing ``"assert"`` and optional
                ``"where"``; bindings for which ``"where"`` is falsy are
                skipped.
        """
        for bound_vars in self._iter_bindings(loop_array, current_depth, local_vars):
            if "where" in ast and not self.evaluate_expression(
                ast["where"], bound_vars
            ):
                continue
            self.apply_assert(ast["assert"], bound_vars)

    def parse_and_apply(self, ast):
        """Main entry point to execute a JSON AST constraint.

        Only constraints whose ``"type"`` is ``"hard"`` are applied; any
        other type is ignored.
        """
        if ast["type"] == "hard":
            self.execute_loops(ast["forall"], 0, {}, ast)