| | from .base_agent import BaseAgent |
| | from prompt.constants import modeling_methods |
| | from prompt.template import (TASK_ANALYSIS_PROMPT, TASK_RESULT_PROMPT, TASK_ANSWER_PROMPT, |
| | TASK_FORMULAS_PROMPT, TASK_FORMULAS_CRITIQUE_PROMPT, TASK_FORMULAS_IMPROVEMENT_PROMPT, |
| | TASK_MODELING_PROMPT, TASK_MODELING_CRITIQUE_PROMPT, TASK_MODELING_IMPROVEMENT_PROMPT, |
| | TASK_CODING_PROMPT, TASK_CODING_DEBUG_PROMPT, CODE_STRUCTURE_PROMPT, |
| | TASK_RESULT_WITH_CODE_PROMPT, COO_PROMPT, TASK_CODING_WO_COO_PROMPT) |
| | import sys |
| | import os |
| | import subprocess |
| | import selectors |
| | import tiktoken |
| | import json |
| |
|
| |
|
class EnvException(Exception):
    """Error raised when a generated script cannot be launched or run.

    The text is kept on ``message`` and is also forwarded to ``Exception`` so
    that ``args``, ``repr()`` and pickling carry it as well.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message

    def __str__(self):
        # str() must return a real string even if a non-str message was given.
        return str(self.message)
| | |
| |
|
def execute_script(script_path, work_dir, python="python", device=0):
    """Run a Python script in a subprocess and return its output as text.

    The child's stdout and stderr are echoed to this process's stdout
    (prefixed with ``STDOUT:`` / ``STDERR:``) while they are collected.

    Args:
        script_path: Script to run; a relative path is resolved against
            ``work_dir`` because the child is started with that cwd.
        work_dir: Working directory for the child process.
        python: Interpreter executable to launch (defaults to ``"python"``
            looked up on PATH, as before).
        device: Value exported to the child as ``CUDA_VISIBLE_DEVICES``.

    Returns:
        ``"The script has been executed. Here is the output:\\n"`` followed by
        the collected stderr when the exit code is non-zero, otherwise the
        collected stdout (falling back to stderr when stdout is empty).

    Raises:
        EnvException: If the process cannot be started or reading its output
            fails.
    """
    try:
        # Pass the command as an argument list instead of a shell string so
        # that paths containing spaces or shell metacharacters run verbatim.
        cmd = [python, "-u", script_path]
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(device))

        stdout_lines = []
        stderr_lines = []

        # Both context managers guarantee the pipes and the selector are
        # released even if reading fails part-way through.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=work_dir,
            env=env,
        ) as process, selectors.DefaultSelector() as selector:
            selector.register(process.stdout, selectors.EVENT_READ, ("STDOUT:", stdout_lines))
            selector.register(process.stderr, selectors.EVENT_READ, ("STDERR:", stderr_lines))

            # Keep reading until both pipes have reported EOF.
            while selector.get_map():
                for key, _ in selector.select(timeout=1):
                    line = key.fileobj.readline()
                    if not line:
                        # Empty read means EOF: stop polling this pipe so it
                        # is not reported as readable over and over.
                        selector.unregister(key.fileobj)
                        continue
                    label, sink = key.data
                    print(label, line, end=" ")
                    sink.append(line)

            return_code = process.wait()

        if return_code != 0:
            observation = "".join(stderr_lines)
        else:
            observation = "".join(stdout_lines)
        if observation == "" and return_code == 0:
            # A successful run that printed nothing to stdout may still have
            # written something (e.g. warnings) to stderr.
            observation = "".join(stderr_lines)
        return "The script has been executed. Here is the output:\n" + observation
    except Exception as e:
        print("++++", "Wrong!")
        raise EnvException(f"Something went wrong in executing {script_path}: {e}. Please check if it is ready to be executed.") from e
| |
|
| |
|
class Task(BaseAgent):
    """LLM-backed agent that carries one task through its pipeline stages.

    Each public method formats one of the imported prompt templates and sends
    it to ``self.llm``. The coding stage additionally writes the generated
    script to disk, runs it with ``execute_script`` and feeds failures back to
    a debugging prompt.
    """

    # Prefix that execute_script puts in front of every observation.
    _OUTPUT_HEADER = "The script has been executed. Here is the output:\n"

    # Substrings that mark an observation as a failed run.
    _FAILURE_MARKERS = (
        "Traceback (most recent call last):",
        "SyntaxError: invalid syntax",
        "IndentationError",
    )

    def __init__(self, llm, coo=True, rag=True):
        """Store the LLM handle and the two feature switches.

        Args:
            llm: Object handed to ``BaseAgent``; its ``generate(prompt)`` is
                used by every stage.
            coo: When true, ``COO_PROMPT`` is injected into the prompts and
                the coding stage uses ``TASK_CODING_PROMPT``; otherwise the
                injected text is empty and ``TASK_CODING_WO_COO_PROMPT`` is
                used.
            rag: When true, ``formulas`` runs its critique/improvement rounds.
        """
        super().__init__(llm)
        self.coo = coo
        self.rag = rag
        self.coo_prompt = COO_PROMPT if coo else ""

    def analysis(self, prompt: str, task_description: str, user_prompt: str = ''):
        """Return the LLM's analysis of the task."""
        prompt = TASK_ANALYSIS_PROMPT.format(
            prompt=prompt,
            coo_prompt=self.coo_prompt,
            task_description=task_description,
            user_prompt=user_prompt,
        ).strip()
        return self.llm.generate(prompt)

    def formulas_actor(self, prompt: str, data_summary: str, task_description: str, task_analysis: str, modeling_methods: str, user_prompt: str = ''):
        """Return a first draft of the modeling formulas."""
        prompt = TASK_FORMULAS_PROMPT.format(
            prompt=prompt,
            coo_prompt=self.coo_prompt,
            data_summary=data_summary,
            task_description=task_description,
            task_analysis=task_analysis,
            modeling_methods=modeling_methods,
            user_prompt=user_prompt,
        ).strip()
        return self.llm.generate(prompt)

    def formulas_critic(self, data_summary: str, task_description: str, task_analysis: str, modeling_formulas: str):
        """Return the LLM's critique of the given formulas."""
        prompt = TASK_FORMULAS_CRITIQUE_PROMPT.format(
            data_summary=data_summary,
            task_description=task_description,
            task_analysis=task_analysis,
            modeling_formulas=modeling_formulas,
        ).strip()
        return self.llm.generate(prompt)

    def formulas_improvement(self, data_summary: str, task_description: str, task_analysis: str, modeling_formulas: str, modeling_formulas_critique: str, user_prompt: str = ''):
        """Return formulas revised in light of a critique."""
        prompt = TASK_FORMULAS_IMPROVEMENT_PROMPT.format(
            data_summary=data_summary,
            task_description=task_description,
            task_analysis=task_analysis,
            modeling_formulas=modeling_formulas,
            modeling_formulas_critique=modeling_formulas_critique,
            user_prompt=user_prompt,
        ).strip()
        return self.llm.generate(prompt)

    def formulas(self, prompt: str, data_summary: str, task_description: str, task_analysis: str, modeling_methods: str, round: int = 1, user_prompt: str = ''):
        """Draft the formulas and, if ``self.rag`` is set, refine them.

        Args:
            round: Number of critique -> improvement rounds to run.

        Returns:
            The final formulas text.
        """
        formulas = self.formulas_actor(prompt, data_summary, task_description, task_analysis, modeling_methods, user_prompt)
        if self.rag:
            for i in range(round):
                print(f'FORMULAS Round {i+1}')
                formulas_critique = self.formulas_critic(data_summary, task_description, task_analysis, formulas)
                formulas = self.formulas_improvement(data_summary, task_description, task_analysis, formulas, formulas_critique, user_prompt)

        return formulas

    def modeling_actor(self, prompt: str, data_summary: str, task_description: str, task_analysis: str, formulas: str, modeling: str = '', user_prompt: str = ''):
        """Return the LLM's description of the modeling process.

        This is the single merged definition of what used to be two
        ``modeling_actor`` methods (the later one silently replaced the
        earlier). ``modeling`` keeps its position from the later definition
        and now defaults to ``''`` so callers that have no modeling-methods
        text can omit it.

        Args:
            modeling: Text passed to the template as ``modeling_methods``.
        """
        prompt = TASK_MODELING_PROMPT.format(
            prompt=prompt,
            coo_prompt=self.coo_prompt,
            data_summary=data_summary,
            task_description=task_description,
            task_analysis=task_analysis,
            modeling_formulas=formulas,
            modeling_methods=modeling,
            user_prompt=user_prompt,
        ).strip()
        return self.llm.generate(prompt)

    def modeling(self, prompt: str, data_summary: str, task_description: str, task_analysis: str, formulas: str, round: int = 1, user_prompt: str = ''):
        """Run the modeling stage (single pass; ``round`` is currently unused)."""
        # user_prompt must go by keyword: the positional slot after `formulas`
        # is `modeling`, not `user_prompt`.
        return self.modeling_actor(prompt, data_summary, task_description, task_analysis, formulas, user_prompt=user_prompt)

    def _generate_code(self, prompt: str, max_retry: int = 5) -> str:
        """Query the LLM until the reply contains a ```python fenced block.

        Returns:
            The stripped contents of the first ```python block.

        Raises:
            EnvException: If no attempt yields a parsable block.
        """
        for _ in range(max_retry):
            try:
                completion = self.llm.generate(prompt)
                return completion.split("```python")[1].split("```")[0].strip()
            except Exception:
                # Best-effort retry: covers both LLM call failures and replies
                # without a ```python fence.
                print("Retry! The code does not start with ```python")
        raise EnvException(f"The LLM did not return a ```python code block after {max_retry} attempts.")

    def _write_and_run(self, code: str, script_name: str, work_dir: str) -> str:
        """Write ``code`` to ``work_dir/script_name``, run it, return the observation.

        The observation is cut to its first 2000 characters when it encodes to
        2000 or more ``cl100k_base`` tokens.
        """
        with open(os.path.join(work_dir, script_name), "w") as f:
            f.write(code)

        observation = None
        try:
            observation = execute_script(script_name, work_dir)
            enc = tiktoken.get_encoding("cl100k_base")
            if len(enc.encode(observation)) >= 2000:
                observation = observation[:2000]
        except Exception as e:
            print(e)
            input("Ah oh, Got stuck! Press any key to continue.")
            if observation is None:
                # The script never produced an observation. Report the error
                # in the same shape as a failed run so that coding() hands it
                # to the debugger instead of crashing on an unbound variable.
                observation = f"{self._OUTPUT_HEADER}Traceback (most recent call last):\n{type(e).__name__}: {e}"
        return observation

    def _run_succeeded(self, observation: str) -> bool:
        """Return True when the observation shows none of the failure markers."""
        return not any(marker in observation for marker in self._FAILURE_MARKERS)

    def coding_actor(self, data_file, data_summary, variable_description, task_description: str, task_analysis: str, formulas: str, modeling: str, dependent_file_prompt: str, code_template: str, script_name: str, work_dir: str, user_prompt: str = ''):
        """Generate a script for the task, save it and execute it.

        Returns:
            ``(code, observation)`` - the generated source and the text
            produced by running it.

        Raises:
            EnvException: If the LLM never returns a ```python block.
        """
        if self.coo:
            prompt = TASK_CODING_PROMPT.format(
                data_file=data_file,
                data_summary=data_summary,
                variable_description=variable_description,
                task_description=task_description,
                task_analysis=task_analysis,
                modeling_formulas=formulas,
                modeling_process=modeling,
                dependent_file_prompt=dependent_file_prompt,
                code_template=code_template,
                user_prompt=user_prompt,
            ).strip()
        else:
            prompt = TASK_CODING_WO_COO_PROMPT.format(
                data_file=data_file,
                data_summary=data_summary,
                variable_description=variable_description,
                task_description=task_description,
                task_analysis=task_analysis,
                modeling_formulas=formulas,
                modeling_process=modeling,
                code_template=code_template,
                user_prompt=user_prompt,
            ).strip()

        new_content = self._generate_code(prompt)
        observation = self._write_and_run(new_content, script_name, work_dir)
        return new_content, observation

    def coding_debugger(self, code_template: str, modeling: str, code: str, observation: str, script_name: str, work_dir: str, user_prompt: str = ''):
        """Ask the LLM to repair ``code`` given its failing ``observation``.

        The repaired script overwrites ``work_dir/script_name`` and is run.

        Returns:
            ``(code, observation)`` for the repaired script.

        Raises:
            EnvException: If the LLM never returns a ```python block.
        """
        prompt = TASK_CODING_DEBUG_PROMPT.format(
            code_template=code_template,
            modeling_process=modeling,
            code=code,
            observation=observation,
            user_prompt=user_prompt,
        ).strip()

        new_content = self._generate_code(prompt)
        observation = self._write_and_run(new_content, script_name, work_dir)
        return new_content, observation

    def coding(self, data_file, data_summary, variable_description, task_description: str, task_analysis: str, formulas: str, modeling: str, dependent_file_prompt: str, code_template: str, script_name: str, work_dir: str, try_num: int = 5, round: int = 1, user_prompt: str = ''):
        """Generate working code, retrying and debugging on failure.

        Each of the ``try_num`` tries starts from a fresh generation and then
        allows up to two debugging passes (``round`` is currently unused).

        Returns:
            ``(code, True, output)`` as soon as a run shows no failure marker,
            where ``output`` is the observation without its header line;
            ``(code, False, None)`` with the last attempted code (``None`` if
            nothing was attempted) when every try fails.
        """
        code = None
        max_iteration = 3
        for i in range(try_num):
            print("="*10 + f" Try: {i + 1} " + "="*10)
            for iteration in range(max_iteration):
                print("="*10 + f" Iteration: {iteration + 1} " + "="*10)
                if iteration == 0:
                    code, observation = self.coding_actor(data_file, data_summary, variable_description, task_description, task_analysis, formulas, modeling, dependent_file_prompt, code_template, script_name, work_dir, user_prompt)
                else:
                    code, observation = self.coding_debugger(code_template, modeling, code, observation, script_name, work_dir, user_prompt)

                if self._run_succeeded(observation):
                    # Drop the header; if it is missing, return the text as is.
                    return code, True, observation.split(self._OUTPUT_HEADER, 1)[-1]

        return code, False, None

    def result(self, task_description: str, task_analysis: str, task_formulas: str, task_modeling: str, user_prompt: str = '', execution_result: str = ''):
        """Return the LLM's write-up of the task result.

        The code-aware template is used when ``execution_result`` is non-empty.
        """
        if execution_result == '':
            prompt = TASK_RESULT_PROMPT.format(
                task_description=task_description,
                task_analysis=task_analysis,
                task_formulas=task_formulas,
                task_modeling=task_modeling,
                user_prompt=user_prompt,
            ).strip()
        else:
            prompt = TASK_RESULT_WITH_CODE_PROMPT.format(
                task_description=task_description,
                task_analysis=task_analysis,
                task_formulas=task_formulas,
                task_modeling=task_modeling,
                user_prompt=user_prompt,
                execution_result=execution_result,
            ).strip()
        return self.llm.generate(prompt)

    def answer(self, task_description: str, task_analysis: str, task_formulas: str, task_modeling: str, task_result: str, user_prompt: str = ''):
        """Return the LLM's final answer for the task."""
        prompt = TASK_ANSWER_PROMPT.format(
            task_description=task_description,
            task_analysis=task_analysis,
            task_formulas=task_formulas,
            task_modeling=task_modeling,
            task_result=task_result,
            user_prompt=user_prompt,
        ).strip()
        return self.llm.generate(prompt)

    def extract_code_structure(self, task_id, code: str, save_path: str):
        """Ask the LLM for a JSON description of ``code`` and parse it.

        Every entry of ``file_outputs`` gets its ``file_description`` prefixed
        with a note naming the task that generated the file.

        Returns:
            The parsed JSON object.

        Raises:
            SystemExit: If five attempts in a row fail to yield usable JSON.
        """
        prompt = CODE_STRUCTURE_PROMPT.format(code=code, save_path=save_path)
        for _ in range(5):
            try:
                structure = self.llm.generate(prompt)
                # str.strip(chars) removes a character *set*; this is enough
                # to peel a surrounding ```json fence off a JSON object, whose
                # own first/last characters are braces.
                structure_string = structure.strip().strip('```json\n').strip('```')
                structure_json = json.loads(structure_string)
                for file_output in structure_json['file_outputs']:
                    file_output['file_description'] = 'This file is generated by code for Task {}. '.format(task_id) + file_output['file_description']
                return structure_json
            except Exception:
                continue
        # Reached only when every attempt failed.
        sys.exit("Fail at extract_code_structure")
| |
|