from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np

# Use a model with PyTorch weights available
MODEL_NAME = "thenlper/gte-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()


def get_embeddings(texts, max_length=512):
    """
    Generate embeddings for long text by chunking and averaging.

    Args:
        texts (str or list): One or more texts to embed.
        max_length (int): Maximum tokens per chunk, including special tokens (default 512).

    Returns:
        np.ndarray: One averaged embedding per input text.
    """
    if isinstance(texts, str):
        texts = [texts]

    final_embeddings = []
    for text in texts:
        # Tokenize and split into chunks, leaving room for the special tokens
        tokens = tokenizer.tokenize(text)
        chunk_size = max_length - tokenizer.num_special_tokens_to_add()
        chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]

        chunk_embeddings = []
        for chunk in chunks:
            chunk_ids = tokenizer.convert_tokens_to_ids(chunk)
            # Add [CLS]/[SEP] so each chunk matches the model's expected input format
            input_ids = torch.tensor([tokenizer.build_inputs_with_special_tokens(chunk_ids)])
            with torch.no_grad():
                output = model(input_ids=input_ids)
            embedding = output.last_hidden_state.mean(dim=1)  # Mean pooling over tokens
            chunk_embeddings.append(embedding)

        # Average the embeddings of all chunks
        if chunk_embeddings:
            avg_embedding = torch.stack(chunk_embeddings).mean(dim=0)
            final_embeddings.append(avg_embedding.squeeze(0).numpy())
        else:
            # Empty input: fall back to a zero vector of the model's hidden size
            final_embeddings.append(np.zeros(model.config.hidden_size))

    return np.array(final_embeddings)
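
As a quick sanity check, the sketch below calls get_embeddings on a short text and on a longer, repeated one and compares the results; the sample strings and the cosine-similarity computation are illustrative assumptions, not part of the original snippet.

# Usage sketch (sample texts and similarity check are assumptions for illustration)
docs = [
    "Transformers produce contextual embeddings for text.",
    "A much longer document gets split into chunks and averaged. " * 80,
]
embs = get_embeddings(docs)
print(embs.shape)  # (2, hidden_size); gte-small should give 384-dimensional vectors

# Cosine similarity between the two averaged embeddings
a, b = embs
cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
print(f"Cosine similarity: {cos_sim:.3f}")

Note that averaging chunk embeddings is a simple way to cover texts longer than the model's context window; it treats every chunk equally, so a length-weighted average is a reasonable variation if chunks differ greatly in size.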