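"""FastAPI service for semantic similarity search.

Embeds a query and a list of paragraphs with intfloat/e5-small-v2, builds an
hnswlib index over the paragraph embeddings, and returns the indices of the
top-k paragraphs most similar to the query.
"""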
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import numpy as np
import datetime
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
if torch.cuda.is_available():
    print("CUDA is available! Inference on GPU!")
else:
    print("CUDA is not available. Inference on CPU.")
seperator = "-HFSEP-"
base_name = "intfloat/e5-small-v2"
# Fall back to CPU when no GPU is available so the model can still load.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_length = 512
max_batch_size = 500

tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
model.eval()
def current_timestamp():
    return datetime.datetime.utcnow().timestamp()
def get_embeddings(input_texts):
    # Cap the batch size to keep memory usage bounded.
    input_texts = input_texts[:max_batch_size]
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        outputs = model(**batch_dict)
        # Mean-pool the token embeddings, ignoring padding positions.
        embeddings = _average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask']
        )
        # L2-normalize so inner product equals cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        embeddings_np = embeddings.cpu().numpy()
    if device == "cuda":
        # Release GPU memory held by the intermediate tensors.
        del embeddings
        torch.cuda.empty_cache()
    return embeddings_np
def _average_pool(last_hidden_states, attention_mask):
    # Zero out padding positions, then average the remaining token embeddings.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    # Inner-product space: with L2-normalized embeddings this ranks by cosine similarity.
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    return index
def preprocess_texts(query, paragraphs):
    # E5 models expect "query: " / "passage: " prefixes on their inputs.
    query = f'query: {query}'
    paragraphs = [f'passage: {p}' for p in paragraphs]
    return [query] + paragraphs
app = FastAPI()

class EmbeddingsSimilarityReq(BaseModel):
    paragraphs: List[str]
    query: str
    top_k: int
@app.post("/")
async def find_similar_paragraphs(req: EmbeddingsSimilarityReq):
    print("Number of paragraphs", len(req.paragraphs))
    print("creating embeddings", current_timestamp())
    inputs = preprocess_texts(req.query, req.paragraphs)
    embeddings_np = get_embeddings(inputs)
    # The first embedding is the query; the rest belong to the paragraphs.
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
    print("creating index", current_timestamp())
    search_index = create_hnsw_index(chunks_embeddings)
    print("searching index", current_timestamp())
    labels, _ = search_index.knn_query(query_embedding, k=min(int(req.top_k), len(chunks_embeddings)))
    # Return the indices of the top-k most similar paragraphs.
    labels = labels[0].tolist()
    return labels
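
# Minimal local usage sketch (assumes this file is saved as app.py and that
# uvicorn and requests are installed; names, host, and port are placeholders):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/",
#       json={
#           "paragraphs": ["Paris is the capital of France.", "Berlin is in Germany."],
#           "query": "What is the capital of France?",
#           "top_k": 1,
#       },
#   )
#   print(resp.json())  # e.g. [0] -> index of the most similar paragraph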