# DoAn/core/rag/embedding_model.py
# Author: hungnha — commit b91b0a5 ("change commit")
from __future__ import annotations
import os
import logging
import time
from dataclasses import dataclass
from typing import List, Sequence
import numpy as np
from openai import OpenAI
from langchain_core.embeddings import Embeddings
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass()
class EmbeddingConfig:
    """Settings for the Qwen embedding model served through SiliconFlow.

    Every field has a default, so ``EmbeddingConfig()`` is usable as-is.

    Fields:
        api_base_url: Base URL given to the OpenAI-compatible client.
        model: Name of the embedding model to request.
        dimension: Intended size of each embedding vector.
        batch_size: Maximum number of texts sent in one API request.
    """

    api_base_url: str = "https://api.siliconflow.com/v1"
    model: str = "Qwen/Qwen3-Embedding-4B"
    # NOTE(review): QwenEmbeddings (this file) does not pass this value to the
    # API, so returned vectors have the model's default size (believed to be
    # 2560 for Qwen3-Embedding-4B) — confirm it matches what consumers expect.
    dimension: int = 2048
    batch_size: int = 16
# Shared config instance; stays None until get_embedding_config() first runs.
_embed_config: EmbeddingConfig | None = None


def get_embedding_config() -> EmbeddingConfig:
    """Return the process-wide EmbeddingConfig, building a default one lazily.

    The first call stores a fresh ``EmbeddingConfig()`` in the module global
    ``_embed_config``. Later calls return that same object.
    """
    global _embed_config
    if _embed_config is not None:
        return _embed_config
    _embed_config = EmbeddingConfig()
    return _embed_config
class QwenEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a Qwen model on SiliconFlow.

    Requests go through the OpenAI-compatible client pointed at
    ``config.api_base_url``. Inputs are split into batches of
    ``config.batch_size``. A batch that fails with a rate-limit error is
    retried with exponential backoff.
    """

    # Total attempts per batch (first call + retries) for rate-limit errors.
    _MAX_RETRIES = 3
    # Wait before the k-th retry (1-based) is _RETRY_BASE_DELAY * 2**(k-1)
    # seconds, i.e. 1s then 2s with the defaults above.
    _RETRY_BASE_DELAY = 1
    # Lower-case phrases that mark an error message as a rate limit. Whole
    # phrases are matched because a bare "rate" substring also matches
    # unrelated words such as "generate" or "separate".
    _RATE_LIMIT_MARKERS = (
        "rate limit",
        "rate_limit",
        "ratelimit",
        "too many requests",
    )

    def __init__(self, config: EmbeddingConfig | None = None):
        """Create the API client.

        Args:
            config: Embedding settings. Defaults to the shared config returned
                by ``get_embedding_config()``.

        Raises:
            ValueError: If ``SILICONFLOW_API_KEY`` is unset or blank.
        """
        self.config = config or get_embedding_config()
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")
        self._client = OpenAI(
            api_key=api_key,
            base_url=self.config.api_base_url,
        )
        logger.info("Đã khởi tạo QwenEmbeddings: %s", self.config.model)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (used at search time)."""
        return self._embed_texts([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed many documents (used at indexing time), one vector per text."""
        return self._embed_texts(texts)

    def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch and return vectors in input order.

        Empty input returns ``[]`` without calling the API.
        """
        if not texts:
            return []
        batch_size = self.config.batch_size
        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            all_embeddings.extend(self._embed_batch(batch))
        return all_embeddings

    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch, retrying with exponential backoff on rate limits.

        Raises:
            RuntimeError: If the API returns a different number of vectors
                than texts sent, so results cannot be aligned with inputs.
            Exception: Any non-rate-limit error from the client, or the last
                rate-limit error once ``_MAX_RETRIES`` attempts are used up.
        """
        attempt = 0
        while True:
            try:
                # NOTE(review): config.dimension is not forwarded (no
                # `dimensions=` argument), so vectors come back at the model's
                # default size — confirm that matches downstream expectations.
                response = self._client.embeddings.create(
                    model=self.config.model,
                    input=batch,
                )
            except Exception as exc:
                attempt += 1
                if attempt >= self._MAX_RETRIES or not self._is_rate_limit_error(exc):
                    raise
                wait_time = self._RETRY_BASE_DELAY * 2 ** (attempt - 1)
                logger.warning("Bị rate limit, đợi %ss...", wait_time)
                time.sleep(wait_time)
                continue
            # Only the API call sits inside `try`, so a failure while reading
            # the response is never retried (a retry could duplicate vectors).
            return self._extract_vectors(response, len(batch))

    @staticmethod
    def _extract_vectors(response, expected: int) -> List[List[float]]:
        """Return the embeddings from ``response`` ordered like the inputs.

        Raises:
            RuntimeError: If the number of returned items differs from
                ``expected``.
        """
        # The OpenAI embeddings schema gives each item an `index` pointing at
        # its input. Sort on it rather than trusting list order. sorted() is
        # stable, so items lacking `index` keep their arrival order.
        items = sorted(response.data, key=lambda item: getattr(item, "index", 0))
        if len(items) != expected:
            raise RuntimeError(
                f"Embedding API returned {len(items)} vectors for {expected} inputs"
            )
        return [item.embedding for item in items]

    @classmethod
    def _is_rate_limit_error(cls, exc: BaseException) -> bool:
        """Return True if ``exc`` looks like an HTTP 429 / rate-limit failure."""
        # openai's status errors expose `status_code`. A 429 message may not
        # contain the word "rate" at all, so check the code first.
        if getattr(exc, "status_code", None) == 429:
            return True
        message = str(exc).lower()
        return any(marker in message for marker in cls._RATE_LIMIT_MARKERS)

    def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
        """Embed ``texts`` and return the vectors as a float32 NumPy array.

        Non-empty input gives one row per text. Empty input gives an empty
        1-D array.
        """
        return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
# Backward-compatible aliases so older imports of these names keep working.
get_config = get_embedding_config
SiliconFlowConfig = EmbeddingConfig