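"""Critic agent for the task-generation pipeline: reflects on a newly
generated task and its code, and decides whether the task should be added
to the online task buffer."""
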
import copy
import os
import random

import IPython

from gensim.utils import (
    save_text,
    add_to_txt,
    extract_dict,
    format_dict_prompt,
    generate_feedback,
)

class Critic:
    """
    Agent that reflects on and critiques a newly generated task, deciding
    whether it should be kept.
    """
    def __init__(self, cfg, memory):
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = self.memory.chat_log
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]

    def error_review(self, new_task):
        """ review the new task against commonly made errors """
        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt") and "task-name" in new_task:
            self.chat_log = add_to_txt(self.chat_log, "================= Error Book Preview!", with_print=True)
            with open(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt") as f:
                errorbook_prompt_text = f.read()
            errorbook_prompt_text = errorbook_prompt_text.replace("TASK_NAME_TEMPLATE", new_task["task-name"])
            # run the error-book prompt; the return value is not used here
            generate_feedback(errorbook_prompt_text, temperature=0., interaction_txt=self.chat_log)

    def reflection(self, new_task, new_code, current_tasks=None):
        """ reflect on whether the new task should be added to the task buffer """
        all_add_to_the_task_list_flag = True

        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt"):
            # only consider successfully generated tasks
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reflect!", with_print=True)
            total_tasks = copy.deepcopy(self.memory.online_task_buffer)
            if current_tasks is not None:
                # include all tasks from the current run; the new task should at least not overlap with these
                for t in current_tasks:
                    total_tasks[t['task-name']] = t

            # cap the number of reference tasks included in the prompt
            MAX_NUM = 40
            if len(total_tasks) > MAX_NUM:
                total_tasks = dict(random.sample(list(total_tasks.items()), MAX_NUM))

            print("reflection history task num:", len(total_tasks))
            task_descriptions_replacement_str = format_dict_prompt(total_tasks, -1)

            # append current new task
            code_reflection_prompt_text = open(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt").read()
            code_reflection_prompt_text = code_reflection_prompt_text.replace("CURRENT_TASK_NAME_TEMPLATE", str(task_descriptions_replacement_str))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_STRING_TEMPLATE", str(new_task))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_CODE_TEMPLATE", str(new_code))
            if len(self.cfg['target_task_name']) > 0:
                code_reflection_prompt_text = code_reflection_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

            # record the new task in the snapshot regardless of the outcome
            total_tasks[new_task["task-name"].replace("-", "_")] = str(new_task)
            # query multiple critics; the task is kept only if all of them agree
            res = generate_feedback(code_reflection_prompt_text, temperature=0.4, interaction_txt=self.chat_log, n=int(self.cfg['reflection_agreement_num']))

            for idx, r in enumerate(res):
                # each critic response is expected to define a `task_reflection` dict, e.g.
                #   task_reflection = {"add_to_the_task_list": "False", "reasons": "..."}
                # (illustrative shape; the exact keys depend on the prompt template)
                reflection_def_cmd = extract_dict(r, prefix='task_reflection')
                try:
                    exec(reflection_def_cmd, globals())
                    print(f"critic {idx}:", task_reflection)

                    # require unanimous agreement: a single negative vote rejects the task
                    if task_reflection["add_to_the_task_list"] == 'False':
                        all_add_to_the_task_list_flag = False
                        print(f"critic {idx} suggests not adding this task to the buffer!")
                except Exception:
                    IPython.embed()
            save_text(self.model_output_dir, new_task['task-name'] + "_reflection_output", str(task_reflection))

        return all_add_to_the_task_list_flag
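
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `DummyMemory`, the cfg keys, and
# the task/code strings below are assumptions inferred from the attributes
# Critic reads; they are not a confirmed API. A real run additionally needs
# the prompt templates on disk and a working `generate_feedback` LLM backend.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyMemory:
        """Stands in for the pipeline's memory object; only the attributes Critic touches."""
        def __init__(self):
            self.chat_log = ""
            self.online_task_buffer = {}

    cfg = {
        "prompt_folder": "vanilla_task_generation_prompt",  # hypothetical folder name
        "model_output_dir": "output/model_output",
        "target_task_name": "",
        "reflection_agreement_num": 1,
    }
    critic = Critic(cfg, DummyMemory())
    new_task = {"task-name": "stack-blocks", "task-description": "Stack blocks into a tower."}
    new_code = "class StackBlocks(Task): ..."
    # with no reflection prompt template on disk, this skips the LLM call and returns True
    keep = critic.reflection(new_task, new_code)
    print("add to task list:", keep)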