|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import torch.nn.functional as F |
|
import hnswlib |
|
import numpy as np |
|
import datetime |
|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from typing import List |
|
|
|
# Select the inference device from actual hardware availability.
# BUG FIX: previously `device` was hard-coded to "cuda" regardless of this
# check, which crashed `.to(device)` on CPU-only hosts.
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA is available! Inference on GPU!")
else:
    device = "cpu"
    print("CUDA is not available. Inference on CPU.")

seperator = "-HFSEP-"  # NOTE(review): unused here; typo kept for backward compatibility

base_name = "intfloat/e5-small-v2"   # HF model id for the embedding model
max_length = 512                     # per-text token truncation limit
max_batch_size = 500                 # get_embeddings silently drops texts beyond this

tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
|
|
|
def current_timestamp():
    """Return the current Unix epoch time in seconds (UTC), as a float.

    BUG FIX: the previous implementation used ``datetime.utcnow().timestamp()``;
    ``utcnow()`` returns a *naive* datetime, which ``.timestamp()`` then
    interprets as *local* time, producing an epoch value shifted by the host's
    UTC offset on any non-UTC machine. ``utcnow()`` is also deprecated since
    Python 3.12. Using a timezone-aware "now" gives the correct epoch value.
    """
    return datetime.datetime.now(datetime.timezone.utc).timestamp()
|
|
|
def get_embeddings(input_texts):
    """Embed a batch of texts with the module-level e5 model.

    Only the first ``max_batch_size`` texts are embedded; any overflow is
    silently dropped. Returns an L2-normalized ``numpy`` array of shape
    (n_texts, hidden_dim).
    """
    batch = input_texts[:max_batch_size]
    encoded = tokenizer(
        batch,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        model_output = model(**encoded)
        pooled = _average_pool(
            model_output.last_hidden_state, encoded['attention_mask']
        )
        pooled = F.normalize(pooled, p=2, dim=1)
        result = pooled.cpu().numpy()

    if device == "cuda":
        # Drop the GPU tensor reference and release cached memory so the
        # service does not accumulate VRAM across requests.
        del pooled
        torch.cuda.empty_cache()

    return result
|
|
|
def _average_pool( |
|
last_hidden_states, |
|
attention_mask |
|
): |
|
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) |
|
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] |
|
|
|
def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16, ef_search=None):
    """Build an HNSW index over the rows of ``embeddings_np``.

    Args:
        embeddings_np: 2-D array of shape (n_items, dim); row i gets label i.
        space: hnswlib distance space ('ip' = inner product, suits the
            L2-normalized embeddings produced upstream).
        ef_construction: build-time accuracy/speed trade-off.
        M: max number of graph links per element.
        ef_search: query-time ``ef``; defaults to covering every element.

    BUG FIX: hnswlib's default query-time ``ef`` is 10, so ``knn_query`` with
    k > 10 fails or returns too few neighbors. We now call ``set_ef`` so any
    k up to the element count is answerable.
    """
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    if ef_search is None:
        # ef must be >= k at query time; covering all elements keeps every
        # valid k usable (element counts here are small, <= max_batch_size).
        ef_search = max(ef_construction, embeddings_np.shape[0])
    index.set_ef(ef_search)
    return index
|
|
|
def preprocess_texts(query, paragraphs):
    """Apply the e5 instruction prefixes; the query comes first, then passages."""
    tagged_query = 'query: ' + query
    tagged_passages = ['passage: ' + paragraph for paragraph in paragraphs]
    return [tagged_query, *tagged_passages]
|
|
|
|
|
# FastAPI application exposing a single similarity-search endpoint at "/".
app = FastAPI()
|
|
|
class EmbeddingsSimilarityReq(BaseModel):
    # Request body for POST "/": rank `paragraphs` by similarity to `query`.
    paragraphs: List[str]  # candidate passages to rank

    query: str  # free-text search query

    top_k: int  # number of most-similar paragraph indices to return
|
|
|
@app.post("/")
async def find_similar_paragraphsitem(req: EmbeddingsSimilarityReq):
    """Return the indices of the ``top_k`` paragraphs most similar to the query.

    Embeds the query and all paragraphs in one batch, builds a transient HNSW
    index over the paragraph embeddings, and returns paragraph indices
    (0-based, into ``req.paragraphs``) ordered by similarity.
    """
    print("Len of batches", len(req.paragraphs))

    # Guard: with no paragraphs there is nothing to index — the index build
    # would crash on an empty embedding array.
    if not req.paragraphs:
        return []

    print("creating embeddings", current_timestamp())
    inputs = preprocess_texts(req.query, req.paragraphs)
    embeddings_np = get_embeddings(inputs)
    # Row 0 is the query embedding; the rest are paragraph embeddings.
    # NOTE(review): paragraphs beyond max_batch_size - 1 are silently dropped
    # by get_embeddings.
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]

    # Guard: hnswlib rejects k <= 0 queries.
    k = min(int(req.top_k), len(chunks_embeddings))
    if k <= 0:
        return []

    print("creating index", current_timestamp())
    search_index = create_hnsw_index(chunks_embeddings)
    print("searching index", current_timestamp())
    labels, _ = search_index.knn_query(query_embedding, k=k)
    return labels[0].tolist()
|
|