# crop-diag-module/prepare_script/image_caption_embeddings.py
import json
import sqlite3
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from PIL import Image
import clip
import faiss
import numpy as np
import glob
# Storage paths
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'
# Data paths
DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')
# Connect to SQLite
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
cursor = conn.cursor()
# Create the embedding tables for images and text
cursor.execute('''
    CREATE TABLE IF NOT EXISTS image_embeddings (
        e_index INTEGER PRIMARY KEY,
        image_path TEXT NOT NULL,
        caption TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')
cursor.execute('''
    CREATE TABLE IF NOT EXISTS text_embeddings (
        e_index INTEGER PRIMARY KEY,
        text TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')

def insert_image_embedding(e_index, image_path, caption, category, subcategory):
"""Thêm embedding ảnh vào SQLite."""
cursor.execute('''
INSERT INTO image_embeddings (e_index, image_path, caption, category, subcategory)
VALUES (?, ?, ?, ?, ?)
''', (e_index, image_path, caption, category, subcategory))
conn.commit()
print(f"Đã thêm embedding ảnh: {image_path}")
def insert_text_embedding(e_index, text, category, subcategory):
"""Thêm embedding văn bản vào SQLite."""
cursor.execute('''
INSERT INTO text_embeddings (e_index, text, category, subcategory)
VALUES (?, ?, ?, ?)
''', (e_index, text, category, subcategory))
conn.commit()
print(f"Đã thêm embedding văn bản: {text[:50]}...")
def save_faiss_index(index, index_file):
"""Lưu FAISS index vào file."""
faiss.write_index(index, index_file)
print(f"Đã lưu FAISS index vào {index_file}")
def load_faiss_index(index_file):
"""Nạp FAISS index từ file."""
if os.path.exists(index_file):
index = faiss.read_index(index_file)
print(f"Đã nạp FAISS index từ {index_file}")
return index
return None
def compute_embeddings():
"""Tính toán embeddings cho ảnh và văn bản sử dụng CLIP."""
print("Loading CLIP model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print("Model loaded")
# Lấy danh sách các thư mục con (categories)
categories = [d for d in os.listdir(MAIN_DATA_PATH) if os.path.isdir(os.path.join(MAIN_DATA_PATH, d))]
image_paths = []
captions = []
texts = []
categories_list = []
subcategories_list = []
# Chuẩn bị dữ liệu
print("Processing data from directories...")
for category in categories:
# Đường dẫn đến thư mục category
category_path = os.path.join(MAIN_DATA_PATH, category)
# Lấy danh sách các subcategories
subcategories = [d for d in os.listdir(category_path) if os.path.isdir(os.path.join(category_path, d))]
for subcategory in subcategories:
# Đường dẫn đến thư mục ảnh và caption của subcategory
subcategory_image_path = os.path.join(category_path, subcategory)
subcategory_caption_path = os.path.join(CAPTIONS_PATH, category, subcategory)
# Lấy danh sách ảnh
image_files = glob.glob(os.path.join(subcategory_image_path, '*.*'))
for img_path in image_files:
# Lấy tên file không có phần mở rộng
base_name = os.path.splitext(os.path.basename(img_path))[0]
caption_file = os.path.join(subcategory_caption_path, f"{base_name}.txt")
if os.path.exists(caption_file):
try:
# Đọc caption
with open(caption_file, 'r', encoding='utf-8') as f:
caption = f.read().strip()
# Thêm vào danh sách
image_paths.append(img_path)
captions.append(caption)
texts.append(caption) # Sử dụng caption làm text
categories_list.append(category)
subcategories_list.append(subcategory)
except Exception as e:
print(f"Error processing {img_path}: {e}")
continue

    # Defaults so the return below works even though the image step is disabled
    image_index = None
    text_index = None

    # Compute image embeddings (block currently commented out)
    # if image_paths:
    #     print("Computing image embeddings...")
    #     image_embeddings = []
    #     for idx, img_path in enumerate(image_paths):
    #         try:
    #             image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    #             with torch.no_grad():
    #                 image_features = model.encode_image(image)
    #             image_features = image_features.cpu().numpy()
    #             faiss.normalize_L2(image_features)
    #             image_embeddings.append(image_features[0])
    #             insert_image_embedding(idx, img_path, captions[idx], categories_list[idx], subcategories_list[idx])
    #         except Exception as e:
    #             print(f"Error processing image {img_path}: {e}")
    #             continue
    #     if image_embeddings:
    #         image_embeddings = np.array(image_embeddings)
    #         d = image_embeddings.shape[1]
    #         image_index = faiss.IndexFlatIP(d)
    #         image_index.add(image_embeddings)
    #         save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)

    # Compute text embeddings
    if texts:
        print("Computing text embeddings...")
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        print("text_tokens shape:", text_tokens.shape)
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
        # FAISS requires float32 (CLIP features are float16 on GPU)
        text_features = text_features.cpu().numpy().astype(np.float32)
        faiss.normalize_L2(text_features)

        d = text_features.shape[1]
        text_index = faiss.IndexFlatIP(d)
        text_index.add(text_features)

        # Store the text embedding metadata in SQLite
        for idx, (text, category, subcategory) in enumerate(zip(texts, categories_list, subcategories_list)):
            insert_text_embedding(idx, text, category, subcategory)

        save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)

    print("Processing completed")
    return image_index, text_index
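
# Sketch of a hypothetical helper (predict_text is not a name from this script): a
# text-query counterpart to predict_image() that embeds a query string with CLIP and
# searches the text index saved above, assuming TEXT_FAISS_INDEX_PATH exists on disk.
def predict_text(query, k=10):
    """Embed a text query with CLIP and search the text FAISS index."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load("ViT-B/32", device=device)
    tokens = clip.tokenize([query], truncate=True).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    # FAISS requires float32 (CLIP features are float16 on GPU)
    text_features = text_features.cpu().numpy().astype(np.float32)
    faiss.normalize_L2(text_features)
    index = load_faiss_index(TEXT_FAISS_INDEX_PATH)
    if index is None:
        raise FileNotFoundError(f"No FAISS index found at {TEXT_FAISS_INDEX_PATH}")
    distances, indices = index.search(text_features, k)
    return distances, indices
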
def predict_image(image_path):
    """Embed a query image with CLIP and search the image FAISS index (top 10)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    # FAISS requires float32 (CLIP features are float16 on GPU)
    image_features = image_features.cpu().numpy().astype(np.float32)
    faiss.normalize_L2(image_features)
    index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
    if index is None:
        raise FileNotFoundError(f"No FAISS index found at {IMAGE_FAISS_INDEX_PATH}")
    distances, indices = index.search(image_features, k=10)
    return distances, indices
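
# Sketch of a hypothetical helper (lookup_image_matches is not a name from this
# script): maps the FAISS indices returned by predict_image() back to the metadata
# rows stored in SQLite, assuming image_embeddings has been populated.
def lookup_image_matches(distances, indices):
    """Resolve FAISS result indices to rows of the image_embeddings table."""
    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx == -1:  # FAISS pads missing neighbours with -1
            continue
        cursor.execute(
            'SELECT image_path, caption, category, subcategory '
            'FROM image_embeddings WHERE e_index = ?',
            (int(idx),)
        )
        row = cursor.fetchone()
        if row is not None:
            results.append({
                'score': float(dist),
                'image_path': row[0],
                'caption': row[1],
                'category': row[2],
                'subcategory': row[3],
            })
    return results
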
if __name__ == '__main__':
    # Build the embedding tables and FAISS indices
    try:
        image_index, text_index = compute_embeddings()
        if image_index is not None:
            print(f"Image index ready with {image_index.ntotal} embeddings")
        if text_index is not None:
            print(f"Text index ready with {text_index.ntotal} embeddings")
    finally:
        conn.close()
        print("SQLite connection closed")