import numpy as np
import os
import json
import re
import ast  # used to safely parse dict literals extracted from model output
import logging
from gpt_dialogue import Dialogue
import openai
from tenacity import (
    retry,
    before_sleep_log,
    stop_after_attempt,
    wait_exponential_jitter,
)  # for exponential backoff with jitter

openai.api_key = os.getenv("OPENAI_API_KEY")

logger = logging.getLogger(__name__+'logger')
logger.setLevel(logging.ERROR)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.ERROR)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

class ObjectFilter(Dialogue):
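    """Dialogue wrapper that asks a GPT model which objects in a scene are relevant to a referring description."""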
    def __init__(self, model='gpt-4'):
        config = {
            # 'model': 'gpt-4',
            # 'model': 'gpt-4-1106-preview',
            'model': model,
            'temperature': 0,
            'top_p': 0.0,
            'max_tokens': 8192,
            # 'load_path': './object_filter_pretext.json',
            'load_path': './object_filter_pretext_new.json',
            'debug': False
        }
        super().__init__(**config)
    
    def extract_all_int_lists_from_text(self, text) -> list:
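        """Return a single flat list of the integers found in every bracketed, all-integer list in `text`."""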
        # match content inside square brackets
        pattern = r'\[([^\[\]]+)\]'
        matches = re.findall(pattern, text)

        int_lists = []

        for match in matches:
            elements = match.split(',')
            int_list = []

            for element in elements:
                element = element.strip()
                try:
                    int_value = int(element)
                    int_list.append(int_value)
                except ValueError:
                    pass
                
            # only keep lists whose elements all parsed as ints; flatten them into one list
            if len(int_list) == len(elements):
                int_lists.extend(int_list)

        return int_lists

    def extract_dict_from_text(self, text) -> dict:
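        """Parse the first '{...}' span in `text` into a dict; return None if nothing valid is found."""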
        # Use regular expression to match the dictionary in the text
        match = re.search(r'{\s*(.*?)\s*}', text)
        if match:
            # Get the matched dictionary content
            dict_str = match.group(1)
            # Convert the dictionary string to an actual dictionary object
            try:
                result_dict = eval('{' + dict_str + '}')
                return result_dict
            except Exception as e:
                print(f"Error converting string to dictionary: {e}")
                return None
        else:
            print("No dictionary found in the given text.")
            return None

    @retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger,logging.ERROR)) #20s,40s,80s,120s + random.uniform(0,20)
    def filter_objects_by_description(self, description, use_npy_file, objects_info_path=None, object_info_list=None, to_print=True):
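        """Prompt the model with `description` plus the scene's object names/ids and return (answer, token_usage).

        Object metadata is read from the .npy file at `objects_info_path` when `use_npy_file` is True,
        otherwise from the in-memory `object_info_list` (used for the robot demo). `answer` is the dict
        parsed from the last line of the model's reply, or None if no valid dict was found.
        """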
        # first, create the prompt
        print("looking for relevant objects based on description:\n'%s'"%description)
        prompt=""
        prompt=prompt+"description:\n'%s'\nobject list:\n"%description
        # load object info data and add to prompt
        if use_npy_file:
            data=np.load(objects_info_path,allow_pickle=True)
            for obj in data:
                if obj['label']=='object':
                    continue
                line="name=%s,id=%d; "%(obj['label'],obj['id'])
                prompt=prompt+line
        else: # object info list given, used for robot demo
            data=object_info_list
            for obj in data:
                label=obj.get('cls')
                if label is None:
                    label=obj.get('label')
                # if obj['cls']=='object':
                #     continue
                if label in ['object','otherfurniture','other','others']:
                    continue
                line="name=%s,id=%d; "%(label,obj['id'])
                prompt=prompt+line
        
        
        # get response from gpt
        response,token_usage=self.call_openai(prompt)
        response=response['content']
        # print("response:",response)
        last_line = response.splitlines()[-1] if len(response) > 0 else ''

        # extract the answer (list/dict) from the last line of the response
        # answer=self.extract_all_int_lists_from_text(last_line)
        answer=self.extract_dict_from_text(last_line)
        if to_print:
            self.print_pretext()
            print("answer:",answer)
            print("\n\n")
        if answer is None or len(answer) == 0:
            answer = None
        return answer,token_usage
    

    
if __name__ == "__main__":
    # scanrefer_path="/share/data/ripl/vincenttann/sr3d/data/scanrefer/ScanRefer_filtered_sampled50.json"
    scanrefer_path="/share/data/ripl/vincenttann/sr3d/data/scanrefer/ScanRefer_filtered_train_sampled1000.json"
    with open(scanrefer_path, 'r') as json_file:
        scanrefer_data=json.load(json_file)
    
    from datetime import datetime
    # use the current time to name the output folder
    current_time = datetime.now()
    formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
    print("formatted_time:",formatted_time)
    folder_path="/share/data/ripl/vincenttann/sr3d/object_filter_dialogue/%s/"%formatted_time
    os.makedirs(folder_path)

    for idx,data in enumerate(scanrefer_data):
        print("processing %d/%d..."%(idx+1,len(scanrefer_data)))
        description=data['description']
        scan_id=data['scene_id']
        target_id=data['object_id']
        # path="/share/data/ripl/scannet_raw/train/objects_info_gf/objects_info_gf_%s.npy"%scan_id
        path="/share/data/ripl/scannet_raw/train/objects_info/objects_info_%s.npy"%scan_id
        of=ObjectFilter()
        of.filter_objects_by_description(description, use_npy_file=True, objects_info_path=path)
        object_filter_json_name="%d_%s_%s_object_filter.json"%(idx,scan_id,target_id)
        of.save_pretext(folder_path,object_filter_json_name)