|
import os |
|
from glob import glob |
|
|
|
try: |
|
from src.Database import Database |
|
from src.Captioner import Captioner |
|
from src.ImageBase import Imagebase |
|
from src.get_major_object import get_major_object, verify_keyword_in_base |
|
from src.generate_cultivation import generate_cultivation_with_rag |
|
except: |
|
from Database import Database |
|
from Captioner import Captioner |
|
from ImageBase import Imagebase |
|
from get_major_object import get_major_object, verify_keyword_in_base |
|
from generate_cultivation import generate_cultivation_with_rag |
|
|
|
|
|
class GameMaster:
    """Coordinate the text database, the image base and the captioner.

    Images are matched against the image base by CLIP feature; a match is
    resolved to its "cultivation" record in the text database. Images with
    no match can be captioned and turned into a new cultivation record that
    is written back to both databases (see ``generate_cultivation_data``).
    """

    def __init__(self, minimal_image_threshold=0.9):
        """Build the text DB, image base and captioner.

        Args:
            minimal_image_threshold: default similarity an image-base hit
                must exceed for ``search_with_path`` to accept it.
        """
        self.textdb = self.init_textdb()

        # Share the CLIP extractor that init_textdb() initialised on the
        # text DB rather than creating another one.
        self.clip_extractor = self.textdb.clip_extractor

        self.imgdb = self.init_imgdb()

        self.captioner = Captioner()

        self.minimal_image_threshold = minimal_image_threshold

    def init_textdb(self):
        """Create a ``Database`` with its BGE and CLIP extractors initialised."""
        text_db = Database()
        text_db.init_bge_extractor()
        text_db.init_clip_extractor()
        return text_db

    def init_imgdb(self):
        """Create and return the ``Imagebase``."""
        return Imagebase()

    def _describe_keyword(self, keyword):
        """Return the display description for an English keyword, or "" if unknown."""
        result = self.textdb.search_by_en_keyword(keyword)
        if not result or "description_in_cultivation" not in result:
            return ""
        description = result['description_in_cultivation']
        if "name_in_cultivation" in result:
            description = result['name_in_cultivation'] + "--" + description
        return description

    def random_image_text_data(self, n=12, blank_image_path="datas/blank_item.jpg"):
        """Sample ``n`` images and pair each with its cultivation description.

        Args:
            n: number of records requested from ``imgdb.random_sample``.
            blank_image_path: placeholder path substituted for sampled images
                whose file does not exist on disk.

        Returns:
            A ``zip`` iterator of ``(image_path, description)`` pairs. The
            description is ``"<name_in_cultivation>--<description_in_cultivation>"``
            (just the description when no name is stored), or ``""`` when the
            text DB has no cultivation description for the sample's
            ``translated_word``.
        """
        samples = self.imgdb.random_sample(n)

        image_names = [
            sample['image_name'] if os.path.exists(sample['image_name']) else blank_image_path
            for sample in samples
        ]

        descriptions = [self._describe_keyword(sample['translated_word']) for sample in samples]

        return zip(image_names, descriptions)

    def search_with_path(self, image_path, threshold=None):
        """Look up an image file in the image base and resolve its text record.

        Args:
            image_path: path of the image to search for.
            threshold: similarity the best image-base hit must strictly exceed;
                defaults to ``self.minimal_image_threshold``.

        Returns:
            ``(search_result, backup_results, image_feature)``:

            * ``search_result`` is a copy of the matching text-DB record with
              an added ``'similarity'`` key, or ``None`` when no hit clears the
              threshold or the hit's keyword has no ``name_in_cultivation``.
            * ``backup_results`` is ``textdb.top_k_search`` over
              ``'clip_feature'`` with ``top_k=5``, computed in every case.
            * ``image_feature`` is the CLIP feature extracted from the file.
        """
        image_feature = self.clip_extractor.extract_image_from_file(image_path)

        image_search_result = self.imgdb.top_k_search(image_feature, top_k=1)

        if threshold is None:
            threshold = self.minimal_image_threshold

        search_result = None

        if image_search_result and image_search_result[0]['similarity'] > threshold:
            best = image_search_result[0]
            result = self.textdb.search_by_en_keyword(best['translated_word'])
            if result and "name_in_cultivation" in result:
                # Copy so the 'similarity' annotation is not written into the
                # object the text DB handed back (it may be its own record).
                search_result = dict(result)
                search_result['similarity'] = best['similarity']
            else:
                print("Warning! Unfound keyword: ", best['translated_word'])

        backup_results = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)

        return search_result, backup_results, image_feature

    @staticmethod
    def _unique_keywords(search_results):
        """Return distinct ``translated_word`` values in first-seen order."""
        return list(dict.fromkeys(res['translated_word'] for res in search_results))

    def _register_image(self, image_path, keyword, translated_word, image_feature):
        """Add ``image_path`` to the image base under the given keyword; return the record."""
        image_data = {
            'image_name': image_path,
            'keyword': keyword,
            'translated_word': translated_word,
        }
        self.imgdb.add_image(image_data, True, image_feature)
        return image_data

    def generate_cultivation_data(self, image_path, image_feature, text_search_result):
        """Caption an image and map it to an existing or newly generated record.

        Steps:
            1. Caption the image with ``self.captioner``.
            2. Build candidate keywords from ``text_search_result``. If it is
               ``None``, use a fresh ``'clip_feature'`` top-5 text-DB search.
            3. Ask ``get_major_object`` to pick the object from the caption.
            4. Check the answer with ``verify_keyword_in_base``.

        If verification returns an in-base record, the image is added to the
        image base under that keyword, and that record is returned.

        Otherwise, if an alternative record is returned,
        ``generate_cultivation_with_rag`` produces a new name and description.
        The new entry is added to the text DB, the image is added to the image
        base, and the new entry dict is returned.

        Args:
            image_path: path of the image to process.
            image_feature: CLIP feature of the image, stored alongside it.
            text_search_result: candidate text-DB records, or ``None``.

        Returns:
            The resulting record dict, or ``None``. ``None`` is returned when
            captioning, object extraction or generation fails, when generation
            returns nothing, or when verification yields neither record.
        """
        try:
            caption_response = self.captioner.caption(image_path)
        except Exception as err:  # best-effort: report and give up on this image
            print("Error occurred while captioning the image ", image_path, err)
            return None

        if text_search_result is None:
            text_search_result = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)

        keywords = self._unique_keywords(text_search_result)

        try:
            json_response = get_major_object(caption_response, keywords)
        except Exception as err:
            print("Error occurred while getting major object from caption ", caption_response, err)
            return None

        in_base_data, alt_data = verify_keyword_in_base(json_response, self.textdb)

        if in_base_data is not None:
            self._register_image(
                image_path, in_base_data['keyword'], in_base_data['translated_word'], image_feature
            )
            return in_base_data

        if alt_data is None:
            return None

        try:
            generated = generate_cultivation_with_rag(alt_data, text_search_result)
        except Exception as err:
            print("Error occurred while generating cultivation data", err)
            return None

        if not generated:
            # Previously this crashed with TypeError on the key lookups below.
            print("Error occurred while generating cultivation data")
            return None

        new_data = {
            "keyword": alt_data['keyword'],
            "name_in_cultivation": generated['new_name'],
            "description_in_cultivation": generated['final_enhanced_description'],
            "translated_word": alt_data['translated_word'],
            "description": alt_data['description'],
        }

        self.textdb.add_data(new_data)
        print("Added new data to textdb: ", new_data["name_in_cultivation"])

        image_data = self._register_image(
            image_path, new_data['keyword'], new_data['translated_word'], image_feature
        )
        print("Added new image to imgdb: ", image_data["keyword"])

        return new_data
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development setup: route HTTP(S) traffic through a local proxy on port
    # 8234 (presumably required by the remote services the pipeline calls —
    # confirm before relying on it).
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    game_master = GameMaster()

    target_folder = "temp_images"

    # glob() returns files in arbitrary order; sort so the printed indices
    # are reproducible between runs.
    image_files = sorted(glob(os.path.join(target_folder, "*.jpg")))

    # Search the folder until the first image that matches the image base.
    # NOTE(review): the original indentation was lost; placing `break` inside
    # the `if` (stop at first match) is inferred — confirm intent.
    for index, image_path in enumerate(image_files):
        print("index:", index)

        search_result, backup_results, image_feature = game_master.search_with_path(image_path)

        if search_result:
            print(search_result)
            break

    test_image_path = "temp_images/向日葵.jpg"

    if not os.path.exists(test_image_path):
        print("Test image not found: ", test_image_path)
    else:
        search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
        cultivation_data = game_master.generate_cultivation_data(
            test_image_path, image_feature, backup_results
        )
        print(cultivation_data)