# NOTE(review): removed web-page extraction residue (viewer chrome, file-size
# banner, and a line-number gutter) that preceded the actual Python source.
import numpy as np
import os,json
import re
import logging
from gpt_dialogue import Dialogue
import openai
from tenacity import (
retry,
before_sleep_log,
stop_after_attempt,
wait_random_exponential,
wait_exponential,
wait_exponential_jitter,
RetryError
) # for exponential backoff
# OpenAI auth comes from the environment; never hard-code the key.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Module logger used by tenacity's before_sleep_log to report retry
# back-off events on stderr at ERROR level.
logger = logging.getLogger(__name__ + 'logger')
logger.setLevel(logging.ERROR)

_console_handler = logging.StreamHandler()
_console_handler.setLevel(logging.ERROR)
_console_handler.setFormatter(
    logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
)
logger.addHandler(_console_handler)
class ObjectFilter(Dialogue):
    """GPT-backed dialogue that narrows a scene's object list down to the
    objects relevant to a natural-language description.

    The heavy lifting is done by the pretext prompt loaded in __init__;
    this class builds the per-scene prompt, calls the model, and parses
    the answer (a dict) out of the reply's last line.
    """

    def __init__(self, model='gpt-4'):
        """Configure the underlying Dialogue: deterministic decoding
        (temperature=0, top_p=0) and the object-filter pretext file.

        Args:
            model: OpenAI chat model name passed through to Dialogue.
        """
        config = {
            'model': model,
            'temperature': 0,
            'top_p': 0.0,
            'max_tokens': 8192,
            'load_path': './object_filter_pretext_new.json',
            'debug': False
        }
        super().__init__(**config)

    def extract_all_int_lists_from_text(self, text) -> list:
        """Collect every bracketed group in *text* whose elements are all
        integers and return their values flattened into a single list.

        A bracketed group containing any non-integer element is skipped
        entirely (the all-or-nothing check below).
        """
        # Match the content between square brackets (no nesting).
        pattern = r'\[([^\[\]]+)\]'
        matches = re.findall(pattern, text)
        int_lists = []
        for match in matches:
            elements = match.split(',')
            int_list = []
            for element in elements:
                element = element.strip()
                try:
                    int_list.append(int(element))
                except ValueError:
                    pass
            # Accept the group only if every element parsed as an int.
            if len(int_list) == len(elements):
                int_lists = int_lists + int_list
        return int_lists

    def extract_dict_from_text(self, text) -> dict:
        """Extract the first '{...}' span in *text* and evaluate it as a
        Python dict.

        Returns:
            The parsed dict, or None when no dict is found or parsing fails.
        """
        # Non-greedy match of the first brace-delimited span.
        match = re.search(r'{\s*(.*?)\s*}', text)
        if match:
            dict_str = match.group(1)
            try:
                # SECURITY: eval() on model output can execute arbitrary
                # code. If the answers are always literals, switch to
                # ast.literal_eval for a safe parse.
                result_dict = eval('{' + dict_str + '}')
                return result_dict
            except Exception as e:
                print(f"Error converting string to dictionary: {e}")
                return None
        else:
            print("No dictionary found in the given text.")
            return None

    @retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger,logging.ERROR)) #20s,40s,80s,120s + random.uniform(0,20)
    def filter_objects_by_description(self, description, use_npy_file, objects_info_path=None, object_info_list=None, to_print=True):
        """Ask the model which objects are relevant to *description*.

        Args:
            description: Natural-language referring expression.
            use_npy_file: When True, read object info from the .npy file at
                *objects_info_path*; otherwise use *object_info_list*
                (robot-demo path).
            objects_info_path: Path to an objects_info .npy file.
            object_info_list: Pre-loaded list of object-info dicts.
            to_print: When True, dump the dialogue pretext and the answer.

        Returns:
            (answer, token_usage) where answer is the parsed dict from the
            reply's last line, or None when no usable dict was produced.
        """
        print("looking for relevant objects based on description:\n'%s'" % description)
        # Build the prompt: the description followed by a "name=...,id=...;"
        # entry per object in the scene.
        prompt = ""
        prompt = prompt + "description:\n'%s'\nobject list:\n" % description
        if use_npy_file:
            data = np.load(objects_info_path, allow_pickle=True)
            for obj in data:
                # Skip the generic placeholder label.
                if obj['label'] == 'object':
                    continue
                line = "name=%s,id=%d; " % (obj['label'], obj['id'])
                prompt = prompt + line
        else:  # object info list given, used for robot demo
            data = object_info_list
            for obj in data:
                # Demo data may label objects under 'cls' instead of 'label'.
                label = obj.get('cls')
                if label is None:
                    label = obj.get('label')
                # Skip uninformative catch-all categories.
                if label in ['object', 'otherfurniture', 'other', 'others']:
                    continue
                line = "name=%s,id=%d; " % (label, obj['id'])
                prompt = prompt + line
        # Query the model and parse the answer from the reply's last line.
        response, token_usage = self.call_openai(prompt)
        response = response['content']
        last_line = response.splitlines()[-1] if len(response) > 0 else ''
        answer = self.extract_dict_from_text(last_line)
        if to_print:
            self.print_pretext()
            print("answer:", answer)
            print("\n\n")
        # FIX: extract_dict_from_text may return None; the original called
        # len(answer) unconditionally and raised TypeError in that case.
        # Normalize both None and an empty dict to None.
        if not answer:
            answer = None
        return answer, token_usage
if __name__ == "__main__":
    # Batch-run the object filter over a sampled ScanRefer split, saving
    # one dialogue JSON per sample into a timestamped folder.
    scanrefer_path = "/share/data/ripl/vincenttann/sr3d/data/scanrefer/ScanRefer_filtered_train_sampled1000.json"
    with open(scanrefer_path, 'r') as json_file:
        scanrefer_data = json.load(json_file)

    from datetime import datetime
    # Timestamp becomes the output folder name so runs never collide.
    current_time = datetime.now()
    formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
    print("formatted_time:", formatted_time)
    folder_path = "/share/data/ripl/vincenttann/sr3d/object_filter_dialogue/%s/" % formatted_time
    os.makedirs(folder_path, exist_ok=True)

    for idx, data in enumerate(scanrefer_data):
        print("processing %d/%d..." % (idx + 1, len(scanrefer_data)))
        description = data['description']
        scan_id = data['scene_id']
        target_id = data['object_id']
        path = "/share/data/ripl/scannet_raw/train/objects_info/objects_info_%s.npy" % scan_id
        # Fresh instance per sample so each dialogue pretext starts clean.
        of = ObjectFilter()
        # FIX: the original called filter_objects_by_description(path,
        # description), binding the .npy path to `description` and the text
        # to `use_npy_file`, and never setting objects_info_path. Pass the
        # arguments by keyword to match the method's signature.
        of.filter_objects_by_description(description, use_npy_file=True, objects_info_path=path)
        object_filter_json_name = "%d_%s_%s_object_filter.json" % (idx, scan_id, target_id)
        of.save_pretext(folder_path, object_filter_json_name)