Spaces:
Sleeping
Sleeping
import os | |
import site | |
import subprocess | |
import sys | |
# Third-party packages (pip distribution names, as passed to `pip install`)
# that the app needs at runtime.
required_packages = [
    "open-clip-torch",
    "gradio",
    "huggingface_hub",
    "matplotlib",
    "chromadb",
]
# Install the required packages if they are not installed yet.
def install_packages(packages):
    """Pip-install every distribution in *packages* that is not installed.

    Args:
        packages: Iterable of pip distribution names (e.g. "open-clip-torch"),
            i.e. the names given to ``pip install`` -- not import names.

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` invocation fails.
    """
    from importlib import metadata

    for package in packages:
        # Checking `package in sys.modules` compared distribution names with
        # *imported module* names ("open-clip-torch" imports as "open_clip",
        # and nothing has been imported yet when this runs), so every package
        # was reinstalled on each start. Ask the installed-distribution
        # metadata instead.
        try:
            metadata.distribution(package)
        except metadata.PackageNotFoundError:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Make sure every dependency is present before the third-party imports below.
install_packages(packages=required_packages)
from huggingface_hub import hf_hub_download | |
import shutil | |
from typing_extensions import Counter | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import chromadb | |
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
# chromadb's OpenCLIP embedding function (default constructor arguments).
embedding_function = OpenCLIPEmbeddingFunction()
# The dataset is expected to be unpacked under ./data already. To fetch it
# from the Hugging Face Hub instead, uncomment the two lines below:
# data_path = hf_hub_download(repo_id="locboyf1/ImageRetrieval", filename="data.zip")
# shutil.unpack_archive(data_path, "data")
ROOT = "data"
# Class labels are the sub-directories of data/train, sorted for a stable
# order. Only directories are kept: a stray file (e.g. ".DS_Store") would
# otherwise be treated as a class and make get_files_path() fail when it
# tries to list it.
CLASS_NAME = sorted(
    name
    for name in os.listdir(f'{ROOT}/train')
    if os.path.isdir(f'{ROOT}/train/{name}')
)
def read_image_from_path(path, size):
    """Load the image at *path*, convert it to RGB and resize it.

    Args:
        path: Path of the image file.
        size: Target size as ``(width, height)`` -- the order PIL's
            ``Image.resize`` expects.

    Returns:
        A NumPy array of shape ``(height, width, 3)``.
    """
    # The context manager releases the file handle deterministically instead
    # of leaving it to garbage collection.
    with Image.open(path) as im:
        return np.array(im.convert('RGB').resize(size))
def folder_to_images(folder, size):
    """Load every file in *folder* as an RGB image array of a common size.

    Args:
        folder: Directory whose entries are all image files.
        size: Target size as ``(width, height)``, forwarded to
            ``read_image_from_path`` (PIL ``Image.resize`` order).

    Returns:
        ``(images_np, images_path)``: a float64 array of shape
        ``(n_images, height, width, 3)`` and a NumPy array of the matching
        file paths, both in ``os.listdir`` order.
    """
    list_dir = [folder + '/' + name for name in os.listdir(folder)]
    width, height = size
    # PIL sizes are (width, height) but image arrays are (height, width, 3).
    # Allocating (*size, 3) only worked for square sizes; any other size
    # failed to broadcast when assigning the first image.
    images_np = np.zeros(shape=(len(list_dir), height, width, 3))
    for i, path in enumerate(list_dir):
        images_np[i] = read_image_from_path(path, size)
    return images_np, np.array(list_dir)
def mean_square_difference(query, data):
    """Return the mean squared difference between *query* and each item of *data*.

    *data* holds one item per index along axis 0. *query* is broadcast
    against it, and the squared error is averaged over every other axis,
    giving one value per item.
    """
    reduce_axes = tuple(range(1, data.ndim))
    squared_error = (data - query) ** 2
    return squared_error.mean(axis=reduce_axes)
def get_single_image_embedding(image):
    """Embed one image with the module-level ``embedding_function``.

    Args:
        image: The image to embed (callers pass a NumPy array).

    Returns:
        The embedding wrapped in a NumPy array.
    """
    # NOTE(review): `_encode_image` is a private chromadb method and may change
    # between chromadb releases; the public call `embedding_function([image])`
    # is the supported route -- verify it returns the same vector before
    # switching.
    return np.array(embedding_function._encode_image(image=image))
def get_files_path(path):
    """List every entry under ``path/<label>`` for each label in ``CLASS_NAME``.

    Args:
        path: Root directory holding one sub-directory per class label.

    Returns:
        Paths joined with "/", grouped by label in ``CLASS_NAME`` order;
        within a label, entries follow ``os.listdir`` order.
    """
    collected = []
    for label in CLASS_NAME:
        label_dir = path + "/" + label
        collected.extend(label_dir + '/' + entry for entry in os.listdir(label_dir))
    return collected
# Every training image path, grouped by class in CLASS_NAME order.
data_path = "/".join([ROOT, "train"])
files_path = get_files_path(data_path)
def add_embedding(collection, files_path):
    """Embed every image in *files_path* and add the vectors to *collection*.

    The image at index ``i`` gets the id ``f'id_{i}'``; ``plot_results`` and
    ``display`` parse that number back to look the path up in ``files_path``.

    Args:
        collection: Chroma collection that receives the embeddings.
        files_path: Sequence of image file paths.
    """
    ids = []
    embeddings = []
    # Wrapping the sequence itself (not an enumerate object) lets tqdm know
    # the total and show real progress.
    for index, filepath in enumerate(tqdm(files_path)):
        ids.append(f'id_{index}')
        # Convert to RGB, as read_image_from_path does: palette ("P"), CMYK
        # and similar modes would otherwise produce arrays that are not RGB
        # pixels. The context manager closes the file handle.
        with Image.open(filepath) as image:
            image_np = np.array(image.convert('RGB'))
        embeddings.append(get_single_image_embedding(image=image_np))
    if not ids:
        # Chroma validates that `ids` is non-empty, so there is nothing to add.
        return
    collection.add(
        embeddings=embeddings,
        ids=ids
    )
# Create the Chroma client (default settings).
chroma_client = chromadb.Client()
# Create (or reuse) the collection holding the image embeddings.
# Chroma reads the distance function from the "hnsw:space" metadata key. The
# previous key "hnsw" is not recognised, so the "l2" request was silently
# ignored (results only matched because l2 is Chroma's default).
HNSW_SPACE = "hnsw:space"
l2_collection = chroma_client.get_or_create_collection(
    name="l2_collection",
    metadata={HNSW_SPACE: "l2"},
)
add_embedding(collection=l2_collection, files_path=files_path)
def search(query_image, collection, n_results):
    """Query *collection* for the stored embeddings closest to *query_image*.

    Args:
        query_image: Image to embed with ``get_single_image_embedding``.
        collection: Chroma collection to query.
        n_results: Number of results to request.

    Returns:
        The unmodified result of ``collection.query`` for this single query.
    """
    embedding = get_single_image_embedding(query_image)
    return collection.query(query_embeddings=[embedding], n_results=n_results)
def plot_results(query_image, files_path, results):
    """Plot the query image followed by the retrieved images in a grid.

    Args:
        query_image: Image for the first panel (anything ``plt.imshow``
            accepts, e.g. a NumPy array).
        files_path: Image paths indexed by the number in the ``id_<n>`` ids
            that ``add_embedding`` assigns.
        results: Result of ``collection.query``; only ``results['ids'][0]``
            (the ids for the first query) is used.

    Ids that cannot be parsed or that do not index into *files_path* are
    reported with ``print``, and their panel is left empty.
    """
    result_ids = results['ids'][0]
    # Five panels per row, with enough rows for the query plus every result
    # (never fewer than the original two rows). The old fixed 2x5 grid fitted
    # only 9 results: for a 10th, plt.subplot(2, 5, 11) raised a ValueError
    # that the except clause swallowed as if the id were malformed.
    n_cols = 5
    n_panels = len(result_ids) + 1
    n_rows = max(2, (n_panels + n_cols - 1) // n_cols)
    plt.figure(figsize=(20, 5 * n_rows))

    # Query image in the first panel ("Ảnh gốc" = "original image").
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(query_image)
    plt.title('Ảnh gốc')
    plt.axis('off')

    # Retrieved images, one panel each ("Kết quả" = "result").
    for i, result_id in enumerate(result_ids):
        # Reduce the id to a bare "id_<n>" string, tolerating list-like reprs.
        result_id_str = str(result_id).strip("[]").replace("'", "").split(",")[0].strip()
        # Only id parsing and lookup are guarded, so plotting errors are not
        # misreported as bad ids.
        try:
            result_index = int(result_id_str.split('_')[1])
            result_image_path = files_path[result_index]
        except (IndexError, ValueError) as e:
            print(f"Có lỗi thực thi, mã lỗi: {result_id_str}: {e}")
            continue
        plt.subplot(n_rows, n_cols, i + 2)
        # matplotlib converts the PIL image to an array inside imshow, so the
        # file can be closed right after.
        with Image.open(result_image_path) as result_image:
            plt.imshow(result_image)
        plt.title(f'Kết quả {i + 1}')
        plt.axis('off')
    plt.show()
def display(query_image, files_path, l2_collection, n_results=10):
    """Return the paths of the images most similar to *query_image*.

    Used as the Gradio button callback.

    Args:
        query_image: Query image from the ``gr.Image`` component, or ``None``
            when the button is clicked before an image is uploaded.
        files_path: Image paths indexed by the number in the ``id_<n>`` ids
            that ``add_embedding`` assigns.
        l2_collection: Chroma collection to search.
        n_results: Number of results to retrieve (10, the previously
            hard-coded value, by default).

    Returns:
        List of image paths in the order Chroma returned them, suitable as a
        ``gr.Gallery`` value; empty when there is no query image.
    """
    if query_image is None:
        # Nothing uploaded: show an empty gallery instead of failing inside
        # the embedding model on np.array(None).
        return []
    image_np = np.array(query_image)
    results = search(query_image=image_np, collection=l2_collection, n_results=n_results)
    result_image_paths = []
    for result_id in results['ids'][0]:
        # Reduce the id to a bare "id_<n>" string and map <n> back to a path.
        result_id_str = str(result_id).strip("[]").replace("'", "").split(",")[0].strip()
        result_index = int(result_id_str.split('_')[1])
        result_image_paths.append(files_path[result_index])
    return result_image_paths
# Gradio UI: upload a query image on the left, press the button, and the
# retrieved images appear in the gallery on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            query_image = gr.Image(label="Tải lên ảnh", height=600)
        with gr.Column():
            result_images = gr.Gallery(
                label="Bộ ảnh kết quả",
                height=600,
                columns=2,
                rows=5,
            )
    button = gr.Button("Tìm kiếm")
    # The callback receives the uploaded image; the path list and collection
    # are bound from module scope.
    button.click(
        fn=lambda uploaded: display(uploaded, files_path, l2_collection),
        inputs=query_image,
        outputs=result_images,
    )

demo.launch()