File size: 7,197 Bytes
0319a9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
from glob import glob
try:
from src.Database import Database
from src.Captioner import Captioner
from src.ImageBase import Imagebase
from src.get_major_object import get_major_object, verify_keyword_in_base
from src.generate_cultivation import generate_cultivation_with_rag
except:
from Database import Database
from Captioner import Captioner
from ImageBase import Imagebase
from get_major_object import get_major_object, verify_keyword_in_base
from generate_cultivation import generate_cultivation_with_rag
class GameMaster:
    """Coordinate the text database, the image database and the captioner.

    Two lookup paths are offered:

    * :meth:`search_with_path` -- cheap: CLIP feature extraction plus DB queries.
    * :meth:`generate_cultivation_data` -- expensive: captions the image, calls
      ``get_major_object`` / ``verify_keyword_in_base`` and, if needed,
      ``generate_cultivation_with_rag``, then writes new records to the DBs.
    """

    def __init__(self):
        self.textdb = self.init_textdb()
        # init_textdb initialises a CLIP extractor on the text DB; reuse it
        # for image feature extraction instead of loading a second one.
        self.clip_extractor = self.textdb.clip_extractor
        self.imgdb = self.init_imgdb()
        self.captioner = Captioner()
        # Similarity the top-1 image-DB hit must strictly exceed for
        # search_with_path to treat it as a match (when no threshold is given).
        self.minimal_image_threshold = 0.9

    def init_textdb(self):
        """Create the text Database and initialise its BGE and CLIP extractors."""
        text_db = Database()
        text_db.init_bge_extractor()
        text_db.init_clip_extractor()
        return text_db

    def init_imgdb(self):
        """Create and return the image database."""
        return Imagebase()

    def random_image_text_data(self, n=12, blank_image_path="datas/blank_item.jpg"):
        """Sample ``n`` images from the image DB and pair each with a description.

        Args:
            n: number of records requested from ``imgdb.random_sample``.
            blank_image_path: placeholder path substituted for sampled images
                whose file does not exist on disk.

        Returns:
            A list of ``(image_path, description)`` tuples, in sample order.
            See :meth:`_describe_keyword` for how the description is built.
        """
        pairs = []
        for sample in self.imgdb.random_sample(n):
            image_path = sample['image_name']
            if not os.path.exists(image_path):
                image_path = blank_image_path
            pairs.append((image_path, self._describe_keyword(sample['translated_word'])))
        # A list (not a one-shot zip iterator) so callers can iterate it
        # more than once, index it or take its len().
        return pairs

    def _describe_keyword(self, keyword):
        """Return the display description for an English keyword.

        ``"<name_in_cultivation>--<description_in_cultivation>"`` when the
        text-DB record has both fields, only the description when the name is
        missing, and ``""`` when there is no record or no description.
        """
        result = self.textdb.search_by_en_keyword(keyword)
        if not result or "description_in_cultivation" not in result:
            return ""
        description = result['description_in_cultivation']
        if "name_in_cultivation" in result:
            description = result['name_in_cultivation'] + "--" + description
        return description

    def search_with_path(self, image_path, threshold=None):
        """Look up the cultivation entry for an image via CLIP similarity.

        This is the relatively lightweight search path.

        Args:
            image_path: path of the image file to search for.
            threshold: similarity the top-1 image-DB hit must strictly exceed;
                defaults to ``self.minimal_image_threshold``.

        Returns:
            ``(search_result, backup_results, image_feature)``:

            * ``search_result`` -- a copy of the matching text-DB record with
              an added ``'similarity'`` key, or ``None`` when there is no
              confident match (or the record lacks ``name_in_cultivation``).
            * ``backup_results`` -- top-5 text-DB hits on ``'clip_feature'``,
              always computed.
            * ``image_feature`` -- the extracted CLIP feature of the image.
        """
        image_feature = self.clip_extractor.extract_image_from_file(image_path)
        image_hits = self.imgdb.top_k_search(image_feature, top_k=1)
        if threshold is None:
            threshold = self.minimal_image_threshold

        search_result = None
        if image_hits and image_hits[0]['similarity'] > threshold:
            best_hit = image_hits[0]
            # Map the image hit back to its text record via the translated word.
            record = self.textdb.search_by_en_keyword(best_hit['translated_word'])
            if record and "name_in_cultivation" in record:
                # Copy so the added 'similarity' key does not leak into the
                # object owned by the text DB (it may be a shared/cached record).
                search_result = dict(record)
                search_result['similarity'] = best_hit['similarity']
            else:
                print("Warning! Unfound keyword: ", best_hit['translated_word'])

        # Text-side candidates are computed unconditionally so callers can pass
        # them on to generate_cultivation_data.
        backup_results = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
        return search_result, backup_results, image_feature

    def generate_cultivation_data(self, image_path, image_feature, text_search_result):
        """Caption an image and resolve or create its cultivation entry (expensive).

        Steps: caption the image, ask ``get_major_object`` for the main object
        given the candidate keywords, then check it with
        ``verify_keyword_in_base``:

        * an existing entry is found -> register the image in the image DB and
          return that entry;
        * only alternative data is found -> generate a new entry with
          ``generate_cultivation_with_rag``, add it to the text DB, register the
          image in the image DB and return the new entry.

        Args:
            image_path: path of the image file.
            image_feature: CLIP feature of the image (e.g. the third value
                returned by :meth:`search_with_path`); stored with the image.
            text_search_result: candidate text-DB hits, each with a
                ``'translated_word'`` key, or ``None`` to run a top-5
                ``'clip_feature'`` search here.

        Returns:
            The existing or newly created entry dict, or ``None`` when a stage
            fails (the error is printed, not raised) or when
            ``verify_keyword_in_base`` yields neither kind of data.
        """
        try:
            caption_response = self.captioner.caption(image_path)
        except Exception as err:
            print("Error occurred while captioning the image ", image_path, err)
            return None

        if text_search_result is None:
            # No candidates supplied: run the full text search here.
            text_search_result = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
        # Unique candidate keywords, keeping first-seen order.
        keywords = list(dict.fromkeys(res['translated_word'] for res in text_search_result))

        try:
            json_response = get_major_object(caption_response, keywords)
        except Exception as err:
            print("Error occurred while getting major object from caption ", caption_response, err)
            return None

        in_base_data, alt_data = verify_keyword_in_base(json_response, self.textdb)
        if in_base_data is not None:
            # The object already has an entry: this is just a new image for it,
            # so no additional entry needs to be generated.
            self._register_image(image_path, in_base_data, image_feature)
            return in_base_data

        if alt_data is None:
            return None

        try:
            generated = generate_cultivation_with_rag(alt_data, text_search_result)
        except Exception as err:
            print("Error occurred while generating cultivation data", err)
            return None

        new_data = {
            "keyword": alt_data['keyword'],
            "name_in_cultivation": generated['new_name'],
            "description_in_cultivation": generated['final_enhanced_description'],
            "translated_word": alt_data['translated_word'],
            "description": alt_data['description'],
        }
        self.textdb.add_data(new_data)
        print("Added new data to textdb: ", new_data["name_in_cultivation"])
        image_data = self._register_image(image_path, new_data, image_feature)
        print("Added new image to imgdb: ", image_data["keyword"])
        return new_data

    def _register_image(self, image_path, entry, image_feature):
        """Add ``image_path`` to the image DB under ``entry``'s keywords; return the record."""
        image_data = {
            'image_name': image_path,
            'keyword': entry['keyword'],
            'translated_word': entry['translated_word'],
        }
        # The positional True is passed through unchanged; its meaning is
        # defined by Imagebase.add_image (not visible here).
        self.imgdb.add_image(image_data, True, image_feature)
        return image_data
if __name__ == "__main__":
    # Send outbound HTTP and HTTPS traffic through a proxy on localhost:8234.
    local_proxy = 'http://localhost:8234'
    for proxy_var in ('HTTP_PROXY', 'HTTPS_PROXY'):
        os.environ[proxy_var] = local_proxy

    game_master = GameMaster()

    # Smoke test 1: run the cheap search over temp_images/*.jpg and stop at
    # the first image that yields a confident match.
    candidate_images = glob(os.path.join("temp_images", "*.jpg"))
    for idx, candidate in enumerate(candidate_images):
        print("index:", idx)
        search_result, backup_results, image_feature = game_master.search_with_path(candidate)
        if search_result:
            print(search_result)
            break

    # Smoke test 2: run the full (expensive) generation pipeline on a
    # sunflower image, reusing the cheap search's backup results.
    sample_image = "temp_images/向日葵.jpg"
    search_result, backup_results, image_feature = game_master.search_with_path(sample_image)
    cultivation_data = game_master.generate_cultivation_data(
        sample_image,
        image_feature,
        backup_results,
    )
    print(cultivation_data)