|
|
|
|
|
import logging |
|
|
import os |
|
|
import pickle |
|
|
|
|
|
import faiss |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Root logging is configured at import time so this module's INFO records are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The project root is the parent of the directory that contains this file.
_module_path = os.path.abspath(__file__)
PROJECT_ROOT = os.path.dirname(os.path.dirname(_module_path))

# Storage directory: the FAISS_INDEX_DIR environment variable wins; otherwise
# fall back to <PROJECT_ROOT>/faiss/data. The directory is created on import.
_default_index_dir = os.path.join(PROJECT_ROOT, 'faiss', 'data')
FAISS_INDEX_DIR = os.getenv('FAISS_INDEX_DIR', _default_index_dir)
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# The serialized FAISS index and the pickled id list live side by side.
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")

# Vector dimensionality used when a new index is created (VECTOR_DIM env var, default 512).
VECTOR_DIM = int(os.getenv("VECTOR_DIM", 512))

# Module-level store state; both stay None until init_vector_store() succeeds.
index = None
id_map = None
|
|
|
|
|
def init_vector_store():
    """Load the persisted vector store, or create an empty one.

    When both the index file and the id-map file exist they are read from
    disk; otherwise a fresh inner-product index of ``VECTOR_DIM`` dimensions
    and an empty id list are created.

    The module globals ``index`` and ``id_map`` are only replaced once
    everything has loaded successfully, so a failed call never leaves a new
    index paired with a stale id map.

    Returns:
        bool: True on success, False if any step raised (the error is logged).
    """
    global index, id_map

    try:
        index_on_disk = os.path.exists(FAISS_INDEX_PATH)
        id_map_on_disk = os.path.exists(ID_MAP_PATH)

        if index_on_disk and id_map_on_disk:
            loaded_index = faiss.read_index(FAISS_INDEX_PATH)
            # NOTE(security): pickle.load can execute arbitrary code contained
            # in the file; this is only safe while FAISS_INDEX_DIR is trusted.
            with open(ID_MAP_PATH, "rb") as f:
                loaded_ids = pickle.load(f)

            vector_count = len(loaded_ids)
            if loaded_index.ntotal != vector_count:
                # Row i of the index is resolved through id_map[i]; a length
                # mismatch means some rows cannot be mapped to an id.
                logger.warning(
                    f"Vector store is inconsistent: index holds {loaded_index.ntotal} "
                    f"vectors but id map holds {vector_count} entries"
                )

            # Publish both objects together, only after every load step worked.
            index, id_map = loaded_index, loaded_ids
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            if index_on_disk or id_map_on_disk:
                # Only one of the two files is present; it will be overwritten
                # by the next save of the fresh store.
                logger.warning(
                    f"Incomplete vector store found in {FAISS_INDEX_DIR}; starting with an empty store"
                )
            index, id_map = faiss.IndexFlatIP(VECTOR_DIM), []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False
|
|
|
|
|
def is_vector_store_available():
    """Report whether the vector store is usable.

    Returns:
        bool: True only when both the module-level ``index`` and ``id_map``
        have been set (neither is None).
    """
    if index is None:
        return False
    return id_map is not None
|
|
|
|
|
def check_image_exists(image_path: str) -> bool:
    """Check whether an image is already registered in the vector store.

    Args:
        image_path: Image path / identifier to look up in the id map.

    Returns:
        bool: True if the identifier is present; False if it is absent, the
        store is not initialized, or the lookup raised (the error is logged).
    """
    try:
        store_ready = is_vector_store_available()
        # Short-circuit keeps id_map untouched while the store is unavailable.
        found = store_ready and image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
    return found
|
|
|
|
|
def add_image_vector(image_path: str, vector: torch.Tensor):
    """Add one image embedding to the store and persist it.

    Args:
        image_path: Image path / identifier recorded in the id map at the
            same position as the vector's row in the index.
        vector: Embedding tensor; a leading dimension of size 1 is squeezed
            away before the vector is handed to the index as a single row.
            May live on any device and may carry autograd history.

    Raises:
        RuntimeError: If the vector store has not been initialized.

    Note:
        Persistence goes through ``save_index()``, which logs failures
        instead of raising, so a failed write does not surface here.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # drop the autograd graph and move to host memory before converting.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')

    # The index receives a batch of rows; wrap the single vector as one row.
    index.add(np_vector[np.newaxis, :])
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")
|
|
|
|
|
def search_text_vector(vector: torch.Tensor, top_k=5):
    """Search the store for the entries most similar to a query vector.

    Args:
        vector: Query embedding tensor; a leading dimension of size 1 is
            squeezed away. May live on any device and may carry autograd
            history.
        top_k: Maximum number of neighbours requested from the index.

    Returns:
        list[tuple]: ``(identifier, score)`` pairs in the order returned by
        the index. Slots the index could not fill (index ``-1``) and rows
        without a matching id-map entry are dropped, so the list may be
        shorter than ``top_k``.

    Raises:
        RuntimeError: If the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # drop the autograd graph and move to host memory before converting.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    scores, indices = index.search(np_vector[np.newaxis, :], top_k)

    if indices is None or len(indices[0]) == 0:
        return []

    known = len(id_map)
    results = []
    for score, raw_idx in zip(scores[0], indices[0]):
        idx = int(raw_idx)
        # Reject every negative index (not just -1): a negative value would
        # otherwise silently wrap around to the end of id_map.
        if 0 <= idx < known:
            results.append((id_map[idx], float(score)))
    return results
|
|
|
|
|
def save_index():
    """Persist the index and the id map to disk (best effort).

    Both files are first written to temporary siblings and then swapped into
    place with ``os.replace``, so an interrupted or failed write leaves the
    previously saved files untouched instead of truncated.

    Failures are logged and swallowed; nothing is raised and nothing is
    returned.
    """
    tmp_index_path = FAISS_INDEX_PATH + ".tmp"
    tmp_id_map_path = ID_MAP_PATH + ".tmp"

    try:
        # Stage both artifacts before touching the live files.
        faiss.write_index(index, tmp_index_path)
        with open(tmp_id_map_path, "wb") as f:
            pickle.dump(id_map, f)

        os.replace(tmp_index_path, FAISS_INDEX_PATH)
        os.replace(tmp_id_map_path, ID_MAP_PATH)
        logger.info("Vector index saved")
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
        # Drop any staged leftovers; a missing temp file is not an error.
        for leftover in (tmp_index_path, tmp_id_map_path):
            try:
                os.remove(leftover)
            except OSError:
                pass
|
|
|
|
|
def get_vector_store_info():
    """Summarize the current state of the vector store.

    Returns:
        dict: ``{"status": "not_initialized", "count": 0}`` while the store
        is unavailable; otherwise the status ``"available"`` together with
        the number of id-map entries, the configured vector dimension and
        the index file path.
    """
    info = {"status": "not_initialized", "count": 0}
    if is_vector_store_available():
        info["status"] = "available"
        info["count"] = len(id_map)
        info["vector_dim"] = VECTOR_DIM
        info["index_path"] = FAISS_INDEX_PATH
    return info
|
|
|