|
import pandas as pd |
|
import os |
|
from tqdm import tqdm |
|
import numpy as np |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
class Imagebase:
    """Lightweight image database backed by a parquet file.

    Each row stores at least 'image_name', 'keyword', 'translated_word' and a
    'clip_feature' embedding used for cosine-similarity search.
    """

    def __init__(self, parquet_path=None):
        # Default on-disk location; overridable per instance.
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        # pandas.DataFrame of image records, or None while the DB is empty.
        self.datas = None

        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)

        # The CLIP extractor is created lazily (heavy model load).
        self.clip_extractor = None

    def random_sample(self, num_samples=12):
        """Return up to ``num_samples`` random records (empty list when no data).

        The sample size is clamped to the number of rows because
        DataFrame.sample raises when asked for more rows than exist.
        """
        if self.datas is None or self.datas.empty:
            return []
        n = min(num_samples, len(self.datas))
        return self.datas.sample(n).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Load all records from a parquet file into memory."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Persist current records to parquet (default path when none given)."""
        parquet_path = parquet_path or self.default_parquet_path
        if self.datas is not None:
            self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Lazily construct the CLIP feature extractor on first use."""
        if self.clip_extractor is None:
            try:
                from CLIPExtractor import CLIPExtractor
            except ImportError:
                # Fallback for running from the repository root; was a bare
                # except, which would also have swallowed unrelated errors.
                from src.CLIPExtractor import CLIPExtractor

            # NOTE(review): machine-specific cache path — consider making this
            # configurable (env var / constructor argument).
            cache_dir = "D:\\aistudio\\LubaoGithub\\models"
            self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)

    def top_k_search(self, query_feature, top_k=15):
        """Return the ``top_k`` records most cosine-similar to ``query_feature``.

        Rows whose 'clip_feature' is missing are skipped.  Each result dict
        carries a 'similarity' score (raw 'clip_feature' removed), sorted from
        most to least similar.

        Raises:
            ValueError: if the data has no 'clip_feature' column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")

        # Keep only rows that actually have an embedding, and remember their
        # index labels so scores stay aligned with the right rows.  (The
        # previous code stacked dropna() values but then indexed the FULL
        # frame positionally with iloc, mis-aligning results whenever any
        # feature was missing.)
        valid = self.datas['clip_feature'].dropna()
        if valid.empty:
            return []

        features = np.stack(valid.to_numpy())
        query = np.asarray(query_feature, dtype=float).reshape(-1)

        # Cosine similarity via numpy (no sklearn needed); zero-norm vectors
        # get similarity 0, matching sklearn's cosine_similarity behavior.
        denom = np.linalg.norm(features, axis=1) * np.linalg.norm(query)
        with np.errstate(divide='ignore', invalid='ignore'):
            similarities = np.where(denom > 0, features @ query / denom, 0.0)

        # Clamp k so the slice below is well-defined when top_k > row count.
        k = min(top_k, len(valid))
        order = np.argsort(similarities)[-k:][::-1]  # best match first

        top_k_results = self.datas.loc[valid.index[order]].copy()
        top_k_results['similarity'] = similarities[order]
        top_k_results = top_k_results.drop(columns=['clip_feature'])
        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Search by image file path: extract its CLIP feature, then top-k search."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image_from_file(image_name)
        return self.top_k_search(img_feature)

    def search_with_image(self, image, if_opencv=False):
        """Search by an in-memory image (``if_opencv`` selects OpenCV input handling)."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
        return self.top_k_search(img_feature)

    def add_image(self, data, if_save=True, image_feature=None):
        """Append one image record, computing its CLIP feature when not supplied.

        Args:
            data: dict with at least 'image_name', 'keyword', 'translated_word'.
            if_save: persist to parquet immediately after adding.
            image_feature: precomputed embedding; skips the extractor when given.

        Raises:
            ValueError: if a required field is missing.
        """
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")

        image_name = data['image_name']
        if image_feature is None:
            self.init_clip_extractor()
            data['clip_feature'] = self.clip_extractor.extract_image_from_file(image_name)
        else:
            data['clip_feature'] = image_feature

        if self.datas is None:
            self.datas = pd.DataFrame([data])
        else:
            self.datas = pd.concat([self.datas, pd.DataFrame([data])], ignore_index=True)
        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Add many records, saving to parquet once at the end."""
        for data in datas:
            self.add_image(data, if_save=False)
        self.save_to_parquet()
|
|
|
import os |
|
from glob import glob |
|
|
|
def scan_and_update_imagebase(db, target_folder="temp_images", similarity_threshold=0.9):
    """Scan ``target_folder`` for .jpg files and add non-duplicates to ``db``.

    A file counts as a duplicate when its best match in the database has a
    similarity strictly above ``similarity_threshold``.  The keyword for each
    image is its file name with the extension stripped; the translated word
    defaults to the same value.

    Args:
        db: Imagebase-like object exposing ``search_with_image_name`` and
            ``add_image``.
        target_folder: directory scanned (non-recursively) for '*.jpg' files.
        similarity_threshold: duplicate cut-off (previously hard-coded 0.9).
    """
    image_files = glob(os.path.join(target_folder, "*.jpg"))

    duplicate_count = 0
    added_count = 0

    for image_path in image_files:
        # File name without extension doubles as the keyword.
        keyword = os.path.splitext(os.path.basename(image_path))[0]
        translated_word = keyword

        results = db.search_with_image_name(image_path)

        if results and results[0]['similarity'] > similarity_threshold:
            print(f"Image '{image_path}' is considered a duplicate.")
            duplicate_count += 1
        else:
            new_image_data = {
                'image_name': image_path,
                'keyword': keyword,
                'translated_word': translated_word,
            }
            db.add_image(new_image_data)
            print(f"Image '{image_path}' added to the database.")
            added_count += 1

    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")
|
|
|
if __name__ == '__main__':
    # Build the image database (loads the default parquet file when present),
    # then sync it with any .jpg images found in the target folder.
    database = Imagebase()
    folder = "temp_images"
    scan_and_update_imagebase(database, folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|