# NOTE(review): the three lines that used to be here ("Spaces:" /
# "Runtime error" x2) were paste artifacts from an online judge, not
# Python — converted to this comment so the module parses.
import os

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
class Imagebase:
    """Small image-metadata store backed by a parquet file.

    Each row holds at least ``image_name``, ``keyword`` and
    ``translated_word`` plus a ``clip_feature`` embedding vector used for
    cosine-similarity search.  ``self.datas`` is a ``pandas.DataFrame``
    (or ``None`` while the store is empty).
    """

    def __init__(self, parquet_path=None):
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        self.datas = None  # DataFrame of records, or None when nothing is loaded
        if os.path.exists(self.parquet_path):
            # Loading is intentionally deferred; callers invoke
            # load_from_parquet() explicitly when they want the data.
            pass
        self.clip_extractor = None  # created lazily by init_clip_extractor()

    def random_sample(self, num_samples=12):
        """Return up to *num_samples* random records as a list of dicts.

        Returns [] when the store is empty.  (The original raised inside
        ``DataFrame.sample`` when fewer than *num_samples* rows existed;
        we clamp instead.)
        """
        if self.datas is None or self.datas.empty:
            return []
        n = min(num_samples, len(self.datas))
        return self.datas.sample(n).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory table with the contents of *parquet_path*."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Persist the in-memory table (no-op while the store is empty)."""
        parquet_path = parquet_path or self.default_parquet_path
        if self.datas is not None:
            self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Lazily build the CLIP feature extractor (idempotent)."""
        if self.clip_extractor is None:
            try:
                from CLIPExtractor import CLIPExtractor
            except ImportError:
                # Fall back to the package-relative layout.
                from src.CLIPExtractor import CLIPExtractor
            cache_dir = "models"
            self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)

    def top_k_search(self, query_feature, top_k=15):
        """Return the *top_k* most cosine-similar records to *query_feature*.

        Rows whose ``clip_feature`` is missing are skipped.  Results are
        dicts sorted by descending similarity, with the raw feature column
        dropped and a ``similarity`` field added.

        Raises
        ------
        ValueError
            If the table has no ``clip_feature`` column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")
        # BUGFIX: the original built the feature matrix from dropna() but
        # indexed self.datas.iloc with positions into the *dropped* array,
        # returning wrong rows whenever any feature was NaN.  Subset the
        # frame first so positions line up.
        valid = self.datas[self.datas['clip_feature'].notna()]
        if valid.empty:
            return []
        query = np.asarray(query_feature, dtype=float).reshape(-1)
        features = np.stack(valid['clip_feature'].values).astype(float)
        # Cosine similarity computed directly with numpy (equivalent to
        # sklearn's cosine_similarity for a single query; the epsilon
        # guards against zero-norm vectors).
        denom = np.linalg.norm(features, axis=1) * np.linalg.norm(query) + 1e-12
        similarities = features @ query / denom
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]
        top_k_results = valid.iloc[top_k_indices].copy()
        top_k_results['similarity'] = similarities[top_k_indices]
        # Drop the raw feature vectors from the returned records.
        top_k_results = top_k_results.drop(columns=['clip_feature'])
        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Extract a CLIP feature from the file at *image_name* and search."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image_from_file(image_name)
        return self.top_k_search(img_feature)

    def search_with_image(self, image, if_opencv=False):
        """Extract a CLIP feature from an in-memory image and search."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
        return self.top_k_search(img_feature)

    def add_image(self, data, if_save=True, image_feature=None):
        """Append one record to the store.

        Parameters
        ----------
        data : dict
            Must contain ``image_name``, ``keyword``, ``translated_word``.
        if_save : bool
            Persist to parquet after appending.
        image_feature : optional
            Precomputed CLIP feature; when None the extractor is used.

        Raises
        ------
        ValueError
            If a required field is missing.
        """
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")
        # BUGFIX: copy so the caller's dict is not silently mutated with
        # the injected 'clip_feature' key.
        record = dict(data)
        if image_feature is None:
            self.init_clip_extractor()
            record['clip_feature'] = self.clip_extractor.extract_image_from_file(record['image_name'])
        else:
            record['clip_feature'] = image_feature
        if self.datas is None:
            self.datas = pd.DataFrame([record])
        else:
            self.datas = pd.concat([self.datas, pd.DataFrame([record])], ignore_index=True)
        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Append many records, saving to parquet once at the end."""
        for data in datas:
            self.add_image(data, if_save=False)
        self.save_to_parquet()
import os
from glob import glob


def scan_and_update_imagebase(db, target_folder="temp_images", patterns=("*.jpg",),
                              similarity_threshold=0.9):
    """Scan *target_folder* for image files and add new ones to *db*.

    A file whose best match in the database scores above
    *similarity_threshold* is reported as a duplicate and skipped;
    everything else is added with its basename (sans extension) as both
    ``keyword`` and ``translated_word``.

    Parameters
    ----------
    db : Imagebase
        Database providing ``search_with_image_name`` and ``add_image``.
    target_folder : str
        Directory to scan.
    patterns : tuple of str
        Glob patterns to match (default preserves the original ``*.jpg``-only
        behavior; pass e.g. ``("*.jpg", "*.png")`` for more formats).
    similarity_threshold : float
        Cosine-similarity cutoff above which a file counts as a duplicate.
    """
    # Collect every file matching any of the requested patterns.
    image_files = []
    for pattern in patterns:
        image_files.extend(glob(os.path.join(target_folder, pattern)))
    duplicate_count = 0
    added_count = 0
    for image_path in image_files:
        # The filename (without extension) doubles as the keyword.
        keyword = os.path.basename(image_path).rsplit('.', 1)[0]
        translated_word = keyword  # adjust here if a real translation is wanted
        # Ask the database for the closest existing images.
        results = db.search_with_image_name(image_path)
        if results and results[0]['similarity'] > similarity_threshold:
            print(f"Image '{image_path}' is considered a duplicate.")
            duplicate_count += 1
        else:
            new_image_data = {
                'image_name': image_path,
                'keyword': keyword,
                'translated_word': translated_word
            }
            db.add_image(new_image_data)
            print(f"Image '{image_path}' added to the database.")
            added_count += 1
    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")
if __name__ == '__main__':
    # Open (or lazily create) the image database, then sync it with the
    # scratch folder of downloaded images.
    img_db = Imagebase()
    scan_and_update_imagebase(img_db, "temp_images")
# Usage example
# img_db = Imagebase()
# new_image_data = {
#     'image_name': "datas/老虎.jpg",
#     'keyword': 'tiger',
#     'translated_word': '老虎'
# }
# img_db.add_image(new_image_data)
# image_path = "datas/老虎.jpg"
# results = img_db.search_with_image_name(image_path)
# for result in results[:3]:
#     print(result)