import os
import copy
import random
import traceback
import json

import numpy as np
import IPython

from gensim.utils import (
    save_text,
    add_to_txt,
    extract_dict,
    format_dict_prompt,
    generate_feedback,
)


class Critic:
    """
    Class that reflects on newly generated tasks and critiques them for improvement.
    """

    def __init__(self, cfg, memory):
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = self.memory.chat_log
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]

    def error_review(self, new_task):
        """Review commonly made errors for the newly generated task."""
        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt") and "task-name" in new_task:
            self.chat_log = add_to_txt(self.chat_log, "================= Error Book Preview!", with_print=True)

            # Fill the common-error ("error book") template with the new task name and
            # query the LLM for feedback on frequently made implementation mistakes.
            errorbook_prompt_text = open(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt").read()
            errorbook_prompt_text = errorbook_prompt_text.replace("TASK_NAME_TEMPLATE", new_task["task-name"])
            res = generate_feedback(errorbook_prompt_text, temperature=0., interaction_txt=self.chat_log)

    def reflection(self, new_task, new_code, current_tasks=None):
        """Reflect on whether the new task should be added to the task buffer."""
        all_add_to_the_task_list_flag = True

        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt"):
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reflect!", with_print=True)

            # Build the reference task set shown to the critic: past tasks from memory
            # plus any tasks generated in the current run.
            total_tasks = copy.deepcopy(self.memory.online_task_buffer)
            if current_tasks is not None:
                for t in current_tasks:
                    total_tasks[t['task-name']] = t

            # Reload the full online task buffer as the reference set.
            total_tasks = self.memory.online_task_buffer

            # Subsample the reference tasks to keep the prompt at a manageable size.
            MAX_NUM = 40
            if len(total_tasks) > MAX_NUM:
                total_tasks = dict(random.sample(list(total_tasks.items()), MAX_NUM))

            print("reflection history task num:", len(total_tasks))
            task_descriptions_replacement_str = format_dict_prompt(total_tasks, -1)

            # Fill the reflection prompt with the reference task descriptions, the new
            # task description, and its generated code.
            code_reflection_prompt_text = open(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt").read()
            code_reflection_prompt_text = code_reflection_prompt_text.replace("CURRENT_TASK_NAME_TEMPLATE", str(task_descriptions_replacement_str))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_STRING_TEMPLATE", str(new_task))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_CODE_TEMPLATE", str(new_code))
            if len(self.cfg['target_task_name']) > 0:
                code_reflection_prompt_text = code_reflection_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

            total_tasks[new_task["task-name"].replace("-", "_")] = str(new_task)

            # Query several critics and require unanimous agreement before accepting the task.
            res = generate_feedback(code_reflection_prompt_text, temperature=0.4, interaction_txt=self.chat_log, n=int(self.cfg['reflection_agreement_num']))
            all_add_to_the_task_list_flag = True

            for idx, r in enumerate(res):
                # Each critic response defines a `task_reflection` dict; extract and
                # evaluate it, then check whether this critic rejects the task.
                reflection_def_cmd = extract_dict(r, prefix='task_reflection')
                exec(reflection_def_cmd, globals())
                try:
                    print(f"critic {idx}:", task_reflection)

                    if task_reflection["add_to_the_task_list"] == 'False':
                        all_add_to_the_task_list_flag = False
                        print(f"critic {idx} suggests not adding this task to the buffer!")
                except Exception:
                    # Drop into an interactive shell to inspect malformed critic output.
                    IPython.embed()

            # Save the last critic's reflection output for this task.
            save_text(self.model_output_dir, new_task['task-name'] + "_reflection_output", str(task_reflection))

        return all_add_to_the_task_list_flag
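
# Example usage (a minimal sketch, not part of the pipeline): the cfg values and the
# `memory` object below are assumptions inferred from how Critic reads them; in practice
# both come from the experiment config and the shared memory instance.
#
#   cfg = {
#       "prompt_folder": "bottomup_task_generation_prompt",  # hypothetical folder name
#       "model_output_dir": "output/model_output",
#       "target_task_name": "",
#       "reflection_agreement_num": 3,
#   }
#   critic = Critic(cfg, memory)   # `memory` must expose `chat_log` and `online_task_buffer`
#   critic.error_review(new_task)  # preview commonly made errors for the generated task
#   keep = critic.reflection(new_task, new_code, current_tasks=[new_task])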