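"""Graph utilities for the AFlow workflow optimizer.

Provides GraphUtils, which creates per-round directories, loads and reads
workflow graphs, builds operator descriptions and graph-optimization prompts,
parses optimizer responses, and writes the generated graph and prompt files.
"""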
import os
import re
import time
import traceback
from pathlib import Path
from typing import List

from ...core.logging import logger
from ...prompts.optimizers.aflow_optimizer import (
    WORKFLOW_INPUT,
    WORKFLOW_OPTIMIZE_PROMPT,
    WORKFLOW_CUSTOM_USE,
    WORKFLOW_TEMPLATE,
)
from ...models.base_model import BaseLLM
from ...workflow.operators import (
    Operator, Custom, CustomCodeGenerate,
    ScEnsemble, Test, AnswerGenerate, QAScEnsemble, Programmer
)

# Operator names accepted by the optimizer, mapped to their implementing classes.
OPERATOR_MAP = {
    "Custom": Custom,
    "CustomCodeGenerate": CustomCodeGenerate,
    "ScEnsemble": ScEnsemble,
    "Test": Test,
    "AnswerGenerate": AnswerGenerate,
    "QAScEnsemble": QAScEnsemble,
    "Programmer": Programmer,
}


class GraphUtils:
    """Utility helpers for loading, describing, optimizing, and persisting workflow graphs."""

    def __init__(self, root_path: str):
        self.root_path = root_path

    def create_round_directory(self, graph_path: str, round_number: int) -> str:
        """Create (if needed) and return the directory for a given optimization round."""
        directory = os.path.join(graph_path, f"round_{round_number}")
        os.makedirs(directory, exist_ok=True)
        return directory

    def load_graph(self, round_number: int, workflows_path: str):
        """Dynamically import and return the Workflow class for the given round."""
        # Convert the filesystem path into a dotted module path before importing.
        workflows_path = workflows_path.replace("\\", ".").replace("/", ".")
        graph_module_name = f"{workflows_path}.round_{round_number}.graph"
        try:
            graph_module = __import__(graph_module_name, fromlist=[""])
            graph_class = getattr(graph_module, "Workflow")
            return graph_class
        except ImportError as e:
            logger.info(f"Error loading graph for round {round_number}: {e}")
            raise

    def read_graph_files(self, round_number: int, workflows_path: str):
        """Read and return the prompt.py and graph.py sources for the given round."""
        prompt_file_path = os.path.join(workflows_path, f"round_{round_number}", "prompt.py")
        graph_file_path = os.path.join(workflows_path, f"round_{round_number}", "graph.py")

        try:
            with open(prompt_file_path, "r", encoding="utf-8") as file:
                prompt_content = file.read()
            with open(graph_file_path, "r", encoding="utf-8") as file:
                graph_content = file.read()
        except FileNotFoundError as e:
            logger.info(f"Error: File not found for round {round_number}: {e}")
            raise
        except Exception as e:
            logger.info(f"Error reading graph files for round {round_number}: {e}")
            raise
        return prompt_content, graph_content

    def extract_solve_graph(self, graph_load: str) -> List[str]:
        """Extract the Workflow class definition from a graph.py source string."""
        pattern = r"class Workflow:.+"
        return re.findall(pattern, graph_load, re.DOTALL)

    def load_operators_description(self, operators: List[str], llm: BaseLLM) -> str:
        """Build a numbered description of the requested operators for the optimize prompt."""
        operators_description = ""
        for index, operator in enumerate(operators):
            operator_description = self._load_operator_description(index + 1, operator, llm)
            operators_description += f"{operator_description}\n"
        return operators_description

    def _load_operator_description(self, index: int, operator_name: str, llm: BaseLLM) -> str:
        """Instantiate a single operator and return its numbered description line."""
        if operator_name not in OPERATOR_MAP:
            raise ValueError(
                f"Operator {operator_name} not found in OPERATOR_MAP! "
                f"Available operators: {list(OPERATOR_MAP.keys())}"
            )
        operator: Operator = OPERATOR_MAP[operator_name](llm=llm)
        return f"{index}. {operator_name}: {operator.description}, with interface {operator.interface}."

    def create_graph_optimize_prompt(
        self,
        experience: str,
        score: float,
        graph: str,
        prompt: str,
        operator_description: str,
        type: str,
        log_data: str,
    ) -> str:
        """Assemble the full graph-optimization prompt from the workflow templates."""
        graph_input = WORKFLOW_INPUT.format(
            experience=experience,
            score=score,
            graph=graph,
            prompt=prompt,
            operator_description=operator_description,
            type=type,
            log=log_data,
        )
        graph_system = WORKFLOW_OPTIMIZE_PROMPT.format(type=type)
        return graph_input + WORKFLOW_CUSTOM_USE + graph_system

    def get_graph_optimize_response(self, graph_optimize_node):
        """Extract the structured optimization response from a node, retrying on failure."""
        max_retries = 5
        retries = 0

        while retries < max_retries:
            try:
                response = graph_optimize_node.instruct_content.model_dump()
                return response
            except Exception as e:
                retries += 1
                logger.info(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
                if retries == max_retries:
                    logger.info("Maximum retries reached. Skipping this sample.")
                    break
                traceback.print_exc()
                time.sleep(5)
        return None

    def write_graph_files(self, directory: str, response: dict):
        """Write the optimized graph.py, prompt.py, and __init__.py files for a round."""
        graph = WORKFLOW_TEMPLATE.format(graph=response["graph"])
        with open(os.path.join(directory, "graph.py"), "w", encoding="utf-8") as file:
            file.write(graph)
        with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as file:
            # The generated graph refers to prompts through the prompt_custom module,
            # so strip that prefix when writing prompt.py itself.
            prompt = response["prompt"].replace("prompt_custom.", "")
            file.write(prompt)
        with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as file:
            file.write("")
        # Point the generated graph.py at the prompt module that was just written.
        self.update_prompt_import(os.path.join(directory, "graph.py"), directory)

    def update_prompt_import(self, graph_file: str, prompt_folder: str):
        """Rewrite the prompt_custom import in graph.py to point at prompt_folder."""
        project_root = Path(os.getcwd())
        prompt_folder_path = Path(prompt_folder)

        # Resolve the prompt folder against the project root if it is relative.
        if not prompt_folder_path.is_absolute():
            prompt_folder_full_path = Path(os.path.join(project_root, prompt_folder))
            if not prompt_folder_full_path.exists():
                raise ValueError(f"Prompt folder {prompt_folder_full_path} does not exist!")
            prompt_folder_path = prompt_folder_full_path

        try:
            relative_path = prompt_folder_path.relative_to(project_root)
        except ValueError:
            raise ValueError(f"Prompt folder {prompt_folder} must be within the project directory")

        # Convert the relative filesystem path into a dotted import path.
        import_path = str(relative_path).replace(os.sep, ".")
        if import_path.startswith("."):
            import_path = import_path[1:]

        with open(graph_file, "r", encoding="utf-8") as file:
            graph_content = file.read()

        # Replace the existing `import <...>.prompt as prompt_custom` statement.
        pattern = r'import .*?\.prompt as prompt_custom'
        replacement = f'import {import_path}.prompt as prompt_custom'
        new_content = re.sub(pattern, replacement, graph_content)

        with open(graph_file, "w", encoding="utf-8") as file:
            file.write(new_content)