|
import pandas as pd |
|
import os |
|
from tqdm import tqdm |
|
|
|
import numpy as np |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
class Database:
    """In-memory keyword table with parquet persistence and similarity search.

    Rows live in the pandas DataFrame ``self.datas``.  It is loaded from one
    main parquet file plus any number of "customized" parquet files; records
    added at runtime are appended to the last customized file.  Two feature
    columns drive similarity search:

    * ``clip_feature`` -- produced by ``CLIPExtractor.extract_text`` from
      ``translated_word + '.' + description``.
    * ``bge_feature`` -- produced by ``TextExtractor.extract`` from
      ``keyword``.

    Both extractors are imported and constructed lazily, on first use.
    """

    # Fields every record must provide (also the non-null columns required
    # when importing from Excel).
    REQUIRED_COLUMNS = (
        'keyword',
        'name_in_cultivation',
        'description_in_cultivation',
        'translated_word',
        'description',
    )

    # Embedding columns that are stripped from every search result.
    FEATURE_COLUMNS = ('clip_feature', 'bge_feature')

    def __init__(self, parquet_path=None, customized_parquets=None):
        """Load the main parquet file (if present) and the customized ones.

        Args:
            parquet_path: Main parquet file; falls back to
                ``datas/database_4000.parquet`` when falsy.
            customized_parquets: List of customized parquet files; falls back
                to ``["datas/customized_database_0.parquet"]`` when falsy.
        """
        self.default_parquet_path = 'datas/database_4000.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path

        self.default_customized_parquets = ["datas/customized_database_0.parquet"]
        self.customized_parquets = customized_parquets or self.default_customized_parquets

        # All state is initialised before any loading so the load methods can
        # safely reset the keyword index.
        self.datas = None
        self.last_save_table = None
        self.clip_extractor = None
        self.bge_extractor = None
        self.en_keyword2data = {}

        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)

        self.load_from_customized(self.customized_parquets)

    @staticmethod
    def _clip_text(translated_word, description):
        """Return the text fed to the CLIP text encoder for one record."""
        return translated_word + '.' + description

    def _require_data(self):
        """Raise ``ValueError`` when no table has been loaded or built yet."""
        if self.datas is None:
            raise ValueError("No data loaded in the database.")

    def build_en_keyword2index(self):
        """Rebuild the lower-cased ``translated_word`` -> row lookup table.

        Rows whose ``translated_word`` is not a string (e.g. NaN) are skipped.
        When several rows share a word, the last one wins.
        """
        index = {}
        if self.datas is not None:
            for _, row in self.datas.iterrows():
                word = row['translated_word']
                if isinstance(word, str):
                    index[word.lower()] = row
        self.en_keyword2data = index

    def search_by_en_keyword(self, keyword):
        """Look up a record by its English keyword, case-insensitively.

        Returns:
            The matching row as a dict without the feature columns, or
            ``None`` when the keyword is unknown.
        """
        if not self.en_keyword2data:
            # The lookup table is built lazily on first use.
            self.build_en_keyword2index()

        row = self.en_keyword2data.get(keyword.lower())
        if row is None:
            return None

        ans = row.to_dict()
        for column in self.FEATURE_COLUMNS:
            # pop() tolerates tables that lack a feature column.
            ans.pop(column, None)
        return ans

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory table with the contents of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)
        # The table changed, so any previously built keyword index is stale.
        self.en_keyword2data = {}

    def load_from_customized(self, customized_parquets=None):
        """Append every existing customized parquet file to the table.

        Missing files are skipped.  If the *last* listed file exists, its
        contents are kept in ``self.last_save_table`` so that later additions
        are appended to that file rather than overwriting it.
        """
        customized_parquets = customized_parquets or self.customized_parquets
        last_index = len(customized_parquets) - 1

        for index, parquet_file in enumerate(customized_parquets):
            if not os.path.exists(parquet_file):
                continue

            temp_df = pd.read_parquet(parquet_file)
            if self.datas is None:
                self.datas = temp_df
            else:
                self.datas = pd.concat([self.datas, temp_df], ignore_index=True)
            # Force a lazy rebuild of the keyword index on the next lookup.
            self.en_keyword2data = {}

            if index == last_index:
                self.last_save_table = temp_df

    def add_data(self, data, if_save=True):
        """Append one record, computing its CLIP and BGE features.

        Note that ``data`` is modified in place: ``founder`` (default ``""``),
        ``clip_feature`` and ``bge_feature`` are written into it.

        Args:
            data: Mapping containing at least ``REQUIRED_COLUMNS``.
            if_save: When true, write ``last_save_table`` to the last
                customized parquet file afterwards.

        Raises:
            ValueError: If a required field is missing.
        """
        for column in self.REQUIRED_COLUMNS:
            if column not in data:
                raise ValueError(f"Missing required field: {column}")

        if 'founder' not in data:
            data['founder'] = ""

        self.init_clip_extractor()
        self.init_bge_extractor()

        data['clip_feature'] = self.clip_extractor.extract_text(
            self._clip_text(data['translated_word'], data['description'])
        )
        data['bge_feature'] = self.bge_extractor.extract([data['keyword']])[0].tolist()

        data_df = pd.DataFrame([data])
        if self.datas is None:
            self.datas = data_df
        else:
            self.datas = pd.concat([self.datas, data_df], ignore_index=True)

        # Only extend an index that has already been built.  Seeding an empty
        # index with this single row would make it look "built" and hide all
        # pre-existing rows from search_by_en_keyword.
        if self.en_keyword2data:
            self.en_keyword2data[data['translated_word'].lower()] = self.datas.iloc[-1]

        if self.last_save_table is None:
            # Same column layout as the full table, without concatenating
            # onto an empty placeholder frame (deprecated in recent pandas).
            self.last_save_table = data_df.reindex(columns=self.datas.columns)
        else:
            self.last_save_table = pd.concat(
                [self.last_save_table, data_df], ignore_index=True
            )

        if if_save:
            self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)

    def add_datas(self, datas, if_save=True):
        """Append several records, saving at most once at the end."""
        for data in datas:
            self.add_data(data, if_save=False)

        # With nothing added and no customized table yet, saving would fall
        # back to writing the whole main table into the customized file.
        if if_save and self.last_save_table is not None:
            self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)

    def init_from_excel(self, excel_path):
        """Rebuild the table from an Excel sheet and extract all features.

        Rows with a missing value in any of ``REQUIRED_COLUMNS`` are dropped.
        """
        df = pd.read_excel(excel_path)
        df.dropna(subset=list(self.REQUIRED_COLUMNS), inplace=True)

        df['clip_feature'] = None
        df['bge_feature'] = None

        self.datas = df
        self.en_keyword2data = {}

        self.extract_clip()
        self.extract_bge()

    def save_to_parquet(self, parquet_path=None, df=None):
        """Write ``df`` (or the whole table when ``df`` is None) to parquet.

        Args:
            parquet_path: Destination; falls back to
                ``self.default_parquet_path`` (not ``self.parquet_path``)
                when falsy.
            df: Frame to write.  When None, ``self.datas`` is written, and
                nothing happens if that is None as well.
        """
        parquet_path = parquet_path or self.default_parquet_path

        table = self.datas if df is None else df
        if table is None:
            return

        parent = os.path.dirname(parquet_path)
        if parent:
            # Make sure the target directory exists before writing.
            os.makedirs(parent, exist_ok=True)
        table.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Create the CLIP extractor on first use (no-op afterwards)."""
        if self.clip_extractor is None:
            try:
                from CLIPExtractor import CLIPExtractor
            except ImportError:
                # Fall back to the package-qualified module path.
                from src.CLIPExtractor import CLIPExtractor

            cache_dir = "models"

            self.clip_extractor = CLIPExtractor(
                model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir
            )

    def extract_clip(self):
        """Recompute ``clip_feature`` for every row of the table.

        Raises:
            ValueError: If no table is loaded.
        """
        self._require_data()
        self.init_clip_extractor()

        # The joined text always contains '.', so it is never empty and every
        # row gets a feature.
        rows = tqdm(
            self.datas.iterrows(),
            desc='Extracting CLIP features',
            total=len(self.datas),
        )
        clip_features = [
            self.clip_extractor.extract_text(
                self._clip_text(row['translated_word'], row['description'])
            )
            for _, row in rows
        ]

        self.datas['clip_feature'] = clip_features

    def init_bge_extractor(self):
        """Create the BGE text extractor on first use (no-op afterwards)."""
        if self.bge_extractor is None:
            try:
                from text_embedding import TextExtractor
            except ImportError:
                # Fall back to the package-qualified module path.
                from src.text_embedding import TextExtractor

            self.bge_extractor = TextExtractor('BAAI/bge-small-zh-v1.5')

    def top_k_search(self, query_feature, attribute, top_k=15):
        """Return the ``top_k`` rows most cosine-similar to ``query_feature``.

        Args:
            query_feature: Query embedding; it is reshaped to a single row.
            attribute: Name of the feature column to compare against.
            top_k: Maximum number of results.

        Returns:
            A list of row dicts sorted by descending similarity, each with a
            ``similarity`` entry and without the feature columns.  Empty when
            no row has a feature or ``top_k`` is not positive.

        Raises:
            ValueError: If no table is loaded or ``attribute`` is not a column.
        """
        self._require_data()
        if attribute not in self.datas.columns:
            raise ValueError(f"Attribute {attribute} not found in the data.")

        # Keep the rows and their features aligned: positions in the
        # similarity vector must index the same filtered frame, otherwise any
        # row with a missing feature shifts every later result.
        candidates = self.datas[self.datas[attribute].notna()]
        if candidates.empty or top_k <= 0:
            return []

        query_feature = np.array(query_feature).reshape(1, -1)
        attribute_features = np.stack(candidates[attribute].values)

        similarities = cosine_similarity(query_feature, attribute_features)[0]

        # argsort is ascending: take the tail and reverse for best-first.
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]

        top_k_results = candidates.iloc[top_k_indices].copy()
        top_k_results = top_k_results.drop(
            columns=list(self.FEATURE_COLUMNS), errors='ignore'
        )
        top_k_results['similarity'] = similarities[top_k_indices]

        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Search ``clip_feature`` using the image stored at ``image_name``."""
        self.init_clip_extractor()

        img_feature = self.clip_extractor.extract_image_from_file(image_name)

        return self.top_k_search(img_feature, 'clip_feature')

    def search_with_image(self, image, if_opencv=False):
        """Search ``clip_feature`` using an in-memory image.

        ``if_opencv`` is forwarded unchanged to the extractor's
        ``extract_image``.
        """
        self.init_clip_extractor()

        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)

        return self.top_k_search(img_feature, 'clip_feature')

    def search_with_chinese(self, text):
        """Search ``bge_feature`` using the BGE embedding of ``text``."""
        self.init_bge_extractor()

        text_feature = self.bge_extractor.extract([text])[0].tolist()

        return self.top_k_search(text_feature, 'bge_feature')

    def extract_bge(self):
        """Recompute ``bge_feature`` for every row from its ``keyword``.

        Rows with a falsy keyword get ``None`` instead of a feature.

        Raises:
            ValueError: If no table is loaded.
        """
        self._require_data()
        self.init_bge_extractor()

        bge_features = [
            self.bge_extractor.extract([text])[0].tolist() if text else None
            for text in tqdm(self.datas['keyword'], desc='Extracting BGE features')
        ]

        self.datas['bge_feature'] = bge_features
|
|
|
if __name__ == '__main__':
    # Manual smoke test: make sure a database exists, then run one text
    # query and one image query and print the three best hits of each.
    force_rebuild = False
    database = Database()

    if force_rebuild or database.datas is None:
        print("Rebuilding database from excel file")
        database.init_from_excel('datas/database_4000.xlsx')
        database.save_to_parquet()

    # Text query, matched against the BGE features.
    text_hits = database.search_with_chinese("钢琴")

    # Show which fields a result record carries.
    print(text_hits[0].keys())

    for hit in text_hits[:3]:
        print(hit)

    # Image query, matched against the CLIP features.
    image_hits = database.search_with_image_name("datas/老虎.jpg")

    for hit in image_hits[:3]:
        print(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|