File size: 7,866 Bytes
88cc76c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import json
import sqlite3
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from PIL import Image
import clip
import faiss
import numpy as np
import glob

# Đường dẫn lưu trữ
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'

# Đường dẫn dữ liệu
DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')

# Kết nối SQLite
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
cursor = conn.cursor()

# Tạo bảng embeddings cho ảnh và văn bản
cursor.execute('''
    CREATE TABLE IF NOT EXISTS image_embeddings (
        e_index INTEGER PRIMARY KEY,
        image_path TEXT NOT NULL,
        caption TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')

cursor.execute('''
    CREATE TABLE IF NOT EXISTS text_embeddings (
        e_index INTEGER PRIMARY KEY,
        text TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')

def insert_image_embedding(e_index, image_path, caption, category, subcategory):
    """Thêm embedding ảnh vào SQLite."""
    cursor.execute('''
        INSERT INTO image_embeddings (e_index, image_path, caption, category, subcategory)
        VALUES (?, ?, ?, ?, ?)
    ''', (e_index, image_path, caption, category, subcategory))
    conn.commit()
    print(f"Đã thêm embedding ảnh: {image_path}")

def insert_text_embedding(e_index, text, category, subcategory):
    """Thêm embedding văn bản vào SQLite."""
    cursor.execute('''
        INSERT INTO text_embeddings (e_index, text, category, subcategory)
        VALUES (?, ?, ?, ?)
    ''', (e_index, text, category, subcategory))
    conn.commit()
    print(f"Đã thêm embedding văn bản: {text[:50]}...")

def save_faiss_index(index, index_file):
    """Lưu FAISS index vào file."""
    faiss.write_index(index, index_file)
    print(f"Đã lưu FAISS index vào {index_file}")

def load_faiss_index(index_file):
    """Nạp FAISS index từ file."""
    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
        print(f"Đã nạp FAISS index từ {index_file}")
        return index
    return None

def compute_embeddings():
    """Tính toán embeddings cho ảnh và văn bản sử dụng CLIP."""
    print("Loading CLIP model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    print("Model loaded")

    # Lấy danh sách các thư mục con (categories)
    categories = [d for d in os.listdir(MAIN_DATA_PATH) if os.path.isdir(os.path.join(MAIN_DATA_PATH, d))]

    image_paths = []
    captions = []
    texts = []
    categories_list = []
    subcategories_list = []

    # Chuẩn bị dữ liệu
    print("Processing data from directories...")
    for category in categories:
        # Đường dẫn đến thư mục category
        category_path = os.path.join(MAIN_DATA_PATH, category)

        # Lấy danh sách các subcategories
        subcategories = [d for d in os.listdir(category_path) if os.path.isdir(os.path.join(category_path, d))]

        for subcategory in subcategories:
            # Đường dẫn đến thư mục ảnh và caption của subcategory
            subcategory_image_path = os.path.join(category_path, subcategory)
            subcategory_caption_path = os.path.join(CAPTIONS_PATH, category, subcategory)


            # Lấy danh sách ảnh
            image_files = glob.glob(os.path.join(subcategory_image_path, '*.*'))

            for img_path in image_files:
                # Lấy tên file không có phần mở rộng
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                caption_file = os.path.join(subcategory_caption_path, f"{base_name}.txt")

                if os.path.exists(caption_file):
                    try:
                        # Đọc caption
                        with open(caption_file, 'r', encoding='utf-8') as f:
                            caption = f.read().strip()

                        # Thêm vào danh sách
                        image_paths.append(img_path)
                        captions.append(caption)
                        texts.append(caption)  # Sử dụng caption làm text
                        categories_list.append(category)
                        subcategories_list.append(subcategory)

                    except Exception as e:
                        print(f"Error processing {img_path}: {e}")
                        continue

    # Tính toán embeddings cho ảnh
    # if image_paths:
    #     print("Computing image embeddings...")
    #     image_embeddings = []
    #     for idx, img_path in enumerate(image_paths):
    #         try:
    #             image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    #             with torch.no_grad():
    #                 image_features = model.encode_image(image)
    #                 image_features = image_features.cpu().numpy()
    #                 faiss.normalize_L2(image_features)
    #                 image_embeddings.append(image_features[0])
    #                 insert_image_embedding(idx, img_path, captions[idx], categories_list[idx], subcategories_list[idx])
    #         except Exception as e:
    #             print(f"Error processing image {img_path}: {e}")
    #             continue

    #     if image_embeddings:
    #         image_embeddings = np.array(image_embeddings)
    #         d = image_embeddings.shape[1]
    #         image_index = faiss.IndexFlatIP(d)
    #         image_index.add(image_embeddings)
    #         save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)

    # Tính toán embeddings cho văn bản
    if texts:
        print("Computing text embeddings...")
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        print("Kích thước của text_tokens:", text_tokens.shape)
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            text_features = text_features.cpu().numpy()
            faiss.normalize_L2(text_features)

            d = text_features.shape[1]
            text_index = faiss.IndexFlatIP(d)
            text_index.add(text_features)

            # Lưu text embeddings vào SQLite
            for idx, (text, category, subcategory) in enumerate(zip(texts, categories_list, subcategories_list)):
                insert_text_embedding(idx, text, category, subcategory)

            save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)

    print("Processing completed")
    return image_index if image_paths else None, text_index if texts else None

def predict_image(image_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features = image_features.cpu().numpy()
        faiss.normalize_L2(image_features)

    index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
    distances, indices = index.search(image_features, k=10)

    return distances, indices

if __name__ == '__main__':
    ## Predict

    try:
        image_index, text_index = compute_embeddings()
        if image_index:
            print(f"Image index ready with {image_index.ntotal} embeddings")
        if text_index:
            print(f"Text index ready with {text_index.ntotal} embeddings")
    finally:
        conn.close()
        print("SQLite connection closed")