from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import numpy as np
import datetime
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

if torch.cuda.is_available():
    print("CUDA is available! Inference on GPU!")
else:
    print("CUDA is not available. Inference on CPU.")

separator = "-HFSEP-"  # sentinel separator (unused in this file)
base_name = "intfloat/e5-small-v2"
# Fall back to CPU so the service still starts on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_length = 512
max_batch_size = 500

tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)


def current_timestamp():
    return datetime.datetime.now(datetime.timezone.utc).timestamp()


def get_embeddings(input_texts):
    # Cap the batch size; anything beyond max_batch_size is silently dropped.
    input_texts = input_texts[:max_batch_size]
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = _average_pool(
            outputs.last_hidden_state,
            batch_dict['attention_mask']
        )
        # L2-normalize so that inner product equals cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        embeddings_np = embeddings.cpu().numpy()
        if device == "cuda":
            del embeddings
            torch.cuda.empty_cache()
        return embeddings_np


def _average_pool(last_hidden_states, attention_mask):
    # Zero out padding positions, then average token embeddings over the sequence.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(
        max_elements=len(embeddings_np),
        ef_construction=ef_construction,
        M=M
    )
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    return index


def preprocess_texts(query, paragraphs):
    # e5 models expect "query: " / "passage: " prefixes on their inputs.
    query = f'query: {query}'
    paragraphs = [f'passage: {p}' for p in paragraphs]
    return [query] + paragraphs


app = FastAPI()


class EmbeddingsSimilarityReq(BaseModel):
    paragraphs: List[str]
    query: str
    top_k: int


@app.post("/")
async def find_similar_paragraphs(req: EmbeddingsSimilarityReq):
    print("Number of paragraphs:", len(req.paragraphs))
    print("creating embeddings", current_timestamp())
    # Embed the query and all paragraphs in a single batch; row 0 is the query.
    inputs = preprocess_texts(req.query, req.paragraphs)
    embeddings_np = get_embeddings(inputs)
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]

    print("creating index", current_timestamp())
    search_index = create_hnsw_index(chunks_embeddings)

    print("searching index", current_timestamp())
    labels, _ = search_index.knn_query(
        query_embedding,
        k=min(int(req.top_k), len(chunks_embeddings))
    )
    # Return the indices of the top-k most similar paragraphs.
    return labels[0].tolist()
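

# A minimal way to run and exercise the service, added here as a sketch.
# Assumptions: uvicorn is installed (FastAPI's usual ASGI server), this file
# is executed directly, and port 8000 is free. An example request via curl:
#
#   curl -X POST http://localhost:8000/ \
#     -H "Content-Type: application/json" \
#     -d '{"paragraphs": ["The cat sat on the mat.", "GPUs speed up inference."],
#          "query": "What accelerates model inference?", "top_k": 1}'
#
# The response is a JSON list of indices into `paragraphs`, nearest first.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)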