"""Parquet-backed image database searched by CLIP feature similarity.

Defines :class:`Imagebase`, which keeps image records (plus one CLIP feature
per record) in a pandas DataFrame, and :func:`scan_and_update_imagebase`,
which adds the not-yet-known ``.jpg`` files of a folder to such a database.
"""
import os
from glob import glob

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm  # unused in this module; kept in case other code relies on it


class Imagebase:
    """In-memory table of image records with similarity search.

    Attributes are plain and public:

    * ``default_parquet_path`` -- fallback storage location.
    * ``parquet_path`` -- storage location chosen at construction time.
    * ``datas`` -- ``pandas.DataFrame`` of records, or ``None`` when nothing
      has been loaded or added yet.
    * ``clip_extractor`` -- lazily created feature extractor, or ``None``.
    """

    def __init__(self, parquet_path=None):
        """Create an empty database bound to ``parquet_path``.

        Args:
            parquet_path: Parquet file used by default when saving. Falls back
                to ``'datas/imagebase.parquet'`` when omitted or falsy.
        """
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        self.datas = None
        # NOTE(review): an existing parquet file is NOT loaded automatically
        # (the auto-load was disabled in the original code). Call
        # load_from_parquet() explicitly; otherwise the first save overwrites
        # whatever is stored at the target path.
        self.clip_extractor = None

    def random_sample(self, num_samples=12):
        """Return up to ``num_samples`` randomly chosen records as dicts.

        Returns an empty list when the database holds no rows. The request is
        capped at the number of stored rows, because ``DataFrame.sample``
        raises when asked for more rows than exist (without replacement).
        """
        if self.datas is None or self.datas.empty:
            return []
        count = min(num_samples, len(self.datas))
        return self.datas.sample(count).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory table with the contents of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Write the table to disk; does nothing when the table is ``None``.

        Args:
            parquet_path: Destination file. Defaults to the path this instance
                was constructed with (which is itself the default path when
                none was given).
        """
        parquet_path = parquet_path or self.parquet_path
        if self.datas is None:
            return
        parent_dir = os.path.dirname(parquet_path)
        if parent_dir:
            # Writing fails if the target directory is missing.
            os.makedirs(parent_dir, exist_ok=True)
        self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Create the CLIP extractor on first use; later calls are no-ops."""
        if self.clip_extractor is not None:
            return
        # The project module may be importable top-level or under ``src``.
        # Only import failures trigger the fallback; other errors propagate.
        try:
            from CLIPExtractor import CLIPExtractor
        except ImportError:
            from src.CLIPExtractor import CLIPExtractor
        cache_dir = "models"
        self.clip_extractor = CLIPExtractor(
            model_name="openai/clip-vit-large-patch14",
            cache_dir=cache_dir,
        )

    def top_k_search(self, query_feature, top_k=15):
        """Return the ``top_k`` records most similar to ``query_feature``.

        Args:
            query_feature: Feature vector of the query; it is flattened into a
                single row before comparison.
            top_k: Maximum number of records to return.

        Returns:
            A list of record dicts ordered by descending cosine similarity.
            Each dict carries a ``'similarity'`` entry and omits
            ``'clip_feature'``. Empty when there is no data, no row has a
            feature, or ``top_k`` is not positive.

        Raises:
            ValueError: If the table has no ``'clip_feature'`` column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")

        # Filter the rows themselves (not only the feature column) so that the
        # positions in ``similarities`` line up with the rows selected below.
        candidates = self.datas[self.datas['clip_feature'].notna()]
        if candidates.empty or top_k <= 0:
            return []

        query = np.asarray(query_feature).reshape(1, -1)
        features = np.stack(candidates['clip_feature'].values)
        similarities = cosine_similarity(query, features)[0]

        best_first = np.argsort(similarities)[::-1][:top_k]
        results = candidates.iloc[best_first].copy()
        results['similarity'] = similarities[best_first]
        # The raw feature vectors are not part of the search result.
        results = results.drop(columns=['clip_feature'])
        return results.to_dict(orient='records')

    def search_with_image_name(self, image_name, top_k=15):
        """Search using the image file at ``image_name`` as the query.

        The feature comes from the extractor's ``extract_image_from_file``;
        see :meth:`top_k_search` for the result format.
        """
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image_from_file(image_name)
        return self.top_k_search(img_feature, top_k=top_k)

    def search_with_image(self, image, if_opencv=False, top_k=15):
        """Search using an in-memory ``image`` as the query.

        ``if_opencv`` is forwarded unchanged to the extractor's
        ``extract_image`` (presumably it marks OpenCV-style image arrays --
        confirm against CLIPExtractor). See :meth:`top_k_search` for the
        result format.
        """
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
        return self.top_k_search(img_feature, top_k=top_k)

    def _build_record(self, data, image_feature=None):
        """Validate ``data`` and return a copy with ``'clip_feature'`` set."""
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")

        # Work on a copy so the caller's mapping is left untouched.
        record = dict(data)
        if image_feature is None:
            self.init_clip_extractor()
            image_feature = self.clip_extractor.extract_image_from_file(record['image_name'])
        record['clip_feature'] = image_feature
        return record

    def _append_records(self, records):
        """Append prepared record dicts to the table with a single concat."""
        if not records:
            return
        new_rows = pd.DataFrame(records)
        if self.datas is None:
            self.datas = new_rows
        else:
            self.datas = pd.concat([self.datas, new_rows], ignore_index=True)

    def add_image(self, data, if_save=True, image_feature=None):
        """Add one image record to the database.

        Args:
            data: Mapping containing at least ``'image_name'``, ``'keyword'``
                and ``'translated_word'``. It is not modified.
            if_save: When true, persist the table afterwards.
            image_feature: Precomputed feature for the image. When ``None``,
                the feature is extracted from the file at ``data['image_name']``.

        Raises:
            ValueError: If a required field is missing from ``data``.
        """
        self._append_records([self._build_record(data, image_feature)])
        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Add several records (see :meth:`add_image`) and save once.

        All records are validated and featurized before any is appended, so a
        failure leaves the table unchanged.
        """
        records = [self._build_record(data) for data in datas]
        self._append_records(records)
        self.save_to_parquet()


def scan_and_update_imagebase(db, target_folder="temp_images", similarity_threshold=0.9):
    """Add every not-yet-known ``.jpg`` in ``target_folder`` to ``db``.

    An image counts as a duplicate when its best match in ``db`` has a
    similarity strictly greater than ``similarity_threshold``. New images are
    stored with the file name (without extension) as both ``keyword`` and
    ``translated_word``. Progress and totals are printed.

    Args:
        db: An :class:`Imagebase` instance.
        target_folder: Directory scanned (non-recursively) for ``*.jpg`` files.
        similarity_threshold: Duplicate cut-off for the best match.
    """
    # All .jpg files directly inside target_folder, in a deterministic order.
    image_files = sorted(glob(os.path.join(target_folder, "*.jpg")))

    duplicate_count = 0
    added_count = 0

    if image_files:
        db.init_clip_extractor()

    try:
        for image_path in image_files:
            # The file name (without extension) serves as the keyword.
            keyword = os.path.splitext(os.path.basename(image_path))[0]
            translated_word = keyword  # adjust translated_word here if needed

            # Extract the feature once and reuse it for both the duplicate
            # check and the insertion.
            feature = db.clip_extractor.extract_image_from_file(image_path)
            results = db.top_k_search(feature, top_k=1)

            if results and results[0]['similarity'] > similarity_threshold:
                print(f"Image '{image_path}' is considered a duplicate.")
                duplicate_count += 1
            else:
                new_image_data = {
                    'image_name': image_path,
                    'keyword': keyword,
                    'translated_word': translated_word,
                }
                db.add_image(new_image_data, if_save=False, image_feature=feature)
                print(f"Image '{image_path}' added to the database.")
                added_count += 1
    finally:
        # Persist once instead of rewriting the file per image; running in
        # ``finally`` keeps already-added images if a later file fails.
        if added_count:
            db.save_to_parquet()

    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")


if __name__ == '__main__':
    # NOTE(review): Imagebase() starts empty (see the note in __init__), so
    # this run does not compare against records already stored on disk.
    img_db = Imagebase()

    # Target directory
    target_folder = "temp_images"

    # Scan the directory and update the database
    scan_and_update_imagebase(img_db, target_folder)

# Usage example
# img_db = Imagebase()
# new_image_data = {
#     'image_name': "datas/老虎.jpg",
#     'keyword': 'tiger',
#     'translated_word': '老虎'
# }
# img_db.add_image(new_image_data)
# image_path = "datas/老虎.jpg"
# results = img_db.search_with_image_name(image_path)
# for result in results[:3]:
#     print(result)