# ImageRetrieval / app.py
# (Hugging Face page header: last commit 7d7d030 "Update app.py" by locboyf1, verified)
import os
import site
import subprocess
import sys
# Pip distribution names this app needs at runtime; passed to
# install_packages() during start-up.
required_packages = [
    "open-clip-torch",
    "gradio",
    "huggingface_hub",
    "matplotlib",
    "chromadb",
]
# Cài đặt các gói nếu chưa được cài đặt
def install_packages(packages):
for package in packages:
if package not in sys.modules:
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Make sure every dependency is present before the third-party imports below.
install_packages(packages=required_packages)
from huggingface_hub import hf_hub_download
import shutil
from typing_extensions import Counter
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
# Single OpenCLIP embedding function shared by indexing and querying.
embedding_function = OpenCLIPEmbeddingFunction()

# Optional: download and unpack the dataset archive from the Hugging Face Hub.
# data_path = hf_hub_download(repo_id="locboyf1/ImageRetrieval", filename="data.zip")
# shutil.unpack_archive(data_path, "data")

# Dataset root; the class names are the (sorted) entries of <ROOT>/train.
ROOT = "data"
CLASS_NAME = sorted(os.listdir(f'{ROOT}/train'))
def read_image_from_path(path, size):
    """Load an image file as an RGB NumPy array resized to *size*.

    Args:
        path: Path of the image file.
        size: Target size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.

    Returns:
        An array of shape ``(height, width, 3)``.
    """
    # The context manager closes the file handle deterministically instead of
    # leaving it to Pillow / garbage collection (matters for lazily-loaded or
    # multi-frame files).
    with Image.open(path) as source:
        rgb = source.convert('RGB').resize(size)
    return np.array(rgb)
def folder_to_images(folder, size):
    """Load every entry of *folder* as an image resized to *size*.

    Every name returned by ``os.listdir(folder)`` is read as an image, so the
    folder is expected to contain image files only.

    Args:
        folder: Directory path (a string; paths are joined with "/").
        size: ``(width, height)`` as passed to ``read_image_from_path``.

    Returns:
        ``(images_np, images_path)``. ``images_np`` is a float64 array of
        shape ``(n, height, width, 3)``. ``images_path`` is a NumPy array of
        the matching file paths, in ``os.listdir`` order.
    """
    list_dir = [folder + '/' + name for name in os.listdir(folder)]
    width, height = size
    # PIL sizes are (width, height) but image arrays are (height, width, 3).
    # Allocating with (*size, 3) made every non-square size fail on assignment.
    images_np = np.zeros(shape=(len(list_dir), height, width, 3))
    for i, path in enumerate(list_dir):
        images_np[i] = read_image_from_path(path, size)
    return images_np, np.array(list_dir)
def mean_square_difference(query, data):
    """Mean squared difference between *query* and each item of *data*.

    ``query`` is broadcast against ``data``. The mean is taken over every
    axis except the first, giving one value per item along axis 0.
    """
    reduce_axes = tuple(range(1, len(data.shape)))
    squared_error = (data - query) ** 2
    return np.mean(squared_error, axis=reduce_axes)
def get_single_image_embedding(image):
    """Embed one image with the module-level OpenCLIP embedding function.

    Args:
        image: Passed unchanged to ``embedding_function._encode_image``.
            Callers in this file pass a NumPy array.

    Returns:
        The embedding as a NumPy array.
    """
    # NOTE(review): ``_encode_image`` is a private chromadb method and may
    # change between chromadb releases.
    return np.array(embedding_function._encode_image(image=image))
def get_files_path(path):
    """List every file under ``path/<label>`` for each label in ``CLASS_NAME``.

    Files are grouped label by label in ``CLASS_NAME`` order. Within a label
    they follow ``os.listdir`` order. Paths are joined with "/".
    """
    collected = []
    for label in CLASS_NAME:
        class_dir = path + "/" + label
        collected.extend(class_dir + '/' + name for name in os.listdir(class_dir))
    return collected
# Every training image path. This list's order defines the "id_<index>" ids
# used when the images are indexed.
data_path = f'{ROOT}/train'
files_path = get_files_path(data_path)
def add_embedding(collection, files_path):
    """Embed every image in *files_path* and add the embeddings to *collection*.

    The image at ``files_path[i]`` is stored under the id ``f'id_{i}'``; search
    results are mapped back to file paths through that index.

    Args:
        collection: Chroma collection receiving the embeddings.
        files_path: Sequence of image file paths.
    """
    ids = []
    embeddings = []
    # Wrap the sequence itself (not an enumerate object) so tqdm can read its
    # length and show a progress bar with a total.
    for id_filepath, filepath in enumerate(tqdm(files_path)):
        ids.append(f'id_{id_filepath}')
        # Close each file promptly. Convert to RGB so every indexed image has
        # three channels (palette/CMYK/RGBA files otherwise produce different
        # arrays); read_image_from_path applies the same conversion.
        with Image.open(filepath) as image:
            image_np = np.array(image.convert('RGB'))
        embeddings.append(get_single_image_embedding(image=image_np))
    if not ids:
        # Nothing to index; recent Chroma versions reject add() with no ids.
        return
    collection.add(
        embeddings=embeddings,
        ids=ids
    )
# Create the Chroma client. chromadb.Client() with no settings is ephemeral
# (in-memory), so the index is rebuilt from files_path on every start.
chroma_client = chromadb.Client()
# Create the collection. Chroma reads the distance function from the
# "hnsw:space" metadata key. The previous key "hnsw" was ignored, and the
# metric was L2 only because that is Chroma's default.
HNSW_SPACE = "hnsw:space"
l2_collection = chroma_client.get_or_create_collection(
    name="l2_collection", metadata={HNSW_SPACE: "l2"}
)
add_embedding(collection=l2_collection, files_path=files_path)
def search(query_image, collection, n_results):
    """Query *collection* with the embedding of *query_image*.

    Args:
        query_image: Image passed to ``get_single_image_embedding``.
        collection: Chroma collection to query.
        n_results: Number of results to request.

    Returns:
        The raw result object from ``collection.query`` (one query).
    """
    embedding = get_single_image_embedding(query_image)
    return collection.query(query_embeddings=[embedding], n_results=n_results)
def plot_results(query_image, files_path, results, n_cols=5):
    """Plot the query image followed by the retrieved images in a grid.

    Args:
        query_image: Image drawn in the first cell (anything ``plt.imshow``
            accepts).
        files_path: Indexed image paths; result id ``"id_<k>"`` refers to
            ``files_path[k]``.
        results: Result of ``collection.query``; only ``results['ids'][0]``
            (the ids for the first query) is used.
        n_cols: Number of grid columns (default 5, the previous layout).

    Result ids that cannot be parsed or point outside ``files_path`` are
    reported with ``print`` and skipped.
    """
    result_ids = results['ids'][0]
    # One cell for the query plus one per result. The grid used to be a fixed
    # 2 x 5, so a 10th result (display() queries n_results=10) asked for cell
    # 11; plt.subplot then raised ValueError, which was misreported as a bad
    # id. Keep at least two rows so small result sets look as before.
    n_cells = len(result_ids) + 1
    n_rows = max(2, (n_cells + n_cols - 1) // n_cols)  # ceiling division

    # Show the query image.
    plt.figure(figsize=(20 , 10))
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(query_image)
    plt.title('Ảnh gốc')
    plt.axis('off')

    # Show the result images.
    for i, result_id in enumerate(result_ids):
        result_id_str = str(result_id).strip("[]").replace("'", "").split(",")[0].strip()
        # Only the id parsing and lookup can raise the errors reported here.
        try:
            result_index = int(result_id_str.split('_')[1])
            result_image_path = files_path[result_index]
        except (IndexError, ValueError) as e:
            print(f"Có lỗi thực thi, mã lỗi: {result_id_str}: {e}")
            continue
        # copy() loads the pixels into memory so the file can be closed here.
        with Image.open(result_image_path) as opened:
            result_image = opened.copy()
        plt.subplot(n_rows, n_cols, i + 2)
        plt.imshow(result_image)
        plt.title(f'Kết quả {i + 1}')
        plt.axis('off')
    plt.show()
def display(query_image, files_path, l2_collection, n_results=10):
    """Return the paths of the indexed images closest to *query_image*.

    Used as the Gradio button callback.

    Args:
        query_image: The uploaded image (converted with ``np.array``), or
            ``None`` when the button is pressed without an upload.
        files_path: Indexed image paths; result id ``"id_<k>"`` maps to
            ``files_path[k]``.
        l2_collection: Chroma collection to search.
        n_results: Number of results to request (default 10, as before).

    Returns:
        A list of image paths in the order returned by ``collection.query``.
        The list is empty when ``query_image`` is ``None``.
    """
    if query_image is None:
        # np.array(None) is not an image; embedding it would fail with an
        # obscure error, so show an empty gallery instead.
        return []
    image_np = np.array(query_image)
    results = search(query_image=image_np, collection=l2_collection, n_results=n_results)
    result_image_paths = []
    for result_id in results['ids'][0]:
        # Ids have the form "id_<index>"; the suffix indexes files_path.
        result_index = int(str(result_id).split('_')[1])
        result_image_paths.append(files_path[result_index])
    return result_image_paths
# print(display(query_image=query_image, files_path=files_path, l2_collection=l2_collection))
# Gradio UI: upload an image on the left, see the retrieved images on the right.
# NOTE(review): the original indentation was lost; the nesting below (button
# placed under the row, inside the Blocks context) is a reconstruction — confirm.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Query image upload.
            query_image = gr.Image(label="Tải lên ảnh", height=600)
        with gr.Column():
            # Gallery that receives the list of result paths returned by display().
            result_images = gr.Gallery(label="Bộ ảnh kết quả",height=600, columns=2, rows=5)
    button = gr.Button("Tìm kiếm")
    # The lambda binds the module-level files_path and l2_collection, so the
    # callback only receives the uploaded image from Gradio.
    button.click(fn=lambda query_image: display(query_image, files_path, l2_collection), inputs=query_image, outputs=result_images)
# Start the web app when the module runs.
demo.launch()