|
|
from sentence_transformers import models
|
|
|
import torch
|
|
|
from transformers import AutoTokenizer
|
|
|
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
|
|
import numpy as np
|
|
|
|
|
|
# Local directory holding the exported ONNX EmbeddingGemma model.
model_path = "./embeddinggemma-300m"

# NOTE(review): the tokenizer is loaded from a different checkpoint than
# ``model_path``; this presumes both share the same vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "google/embeddinggemma-300m-qat-q4_0-unquantized"
)

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_path)
onnx_model = onnx_model.to(device)
|
|
|
|
|
|
class ONNXSentenceTransformer:
    """Minimal ``SentenceTransformer``-style encoder backed by an ONNX model.

    The encoder works in three steps:

    1. Tokenize the sentences in batches.
    2. Run ``model`` on each batch to obtain ``last_hidden_state`` token embeddings.
    3. Mean-pool those embeddings over the attended (non-padding) tokens, using a
       ``sentence_transformers.models.Pooling`` module in mean-token mode.

    NOTE(review): only mean pooling is applied. If the reference
    SentenceTransformer pipeline has further modules after pooling
    (e.g. Dense or Normalize), their effect is not reproduced here —
    confirm before comparing against ``SentenceTransformer.encode``.
    """

    def __init__(self, model, tokenizer, word_embedding_dimension=768):
        """Wrap ``model`` and ``tokenizer``.

        Args:
            model: Callable invoked with the tokenizer outputs plus
                ``position_ids``; must return an object exposing
                ``last_hidden_state``.
            tokenizer: Hugging Face-style tokenizer callable.
            word_embedding_dimension: Size of the token embeddings produced
                by ``model``; forwarded to the pooling module. Defaults to 768.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.word_embedding_dimension = word_embedding_dimension
        self.pooling = models.Pooling(
            word_embedding_dimension=self.word_embedding_dimension,
            pooling_mode_mean_tokens=True,
        )

    @staticmethod
    def _position_ids(attention_mask):
        """Return 0-based positions that count only attended (non-pad) tokens.

        A plain ``arange`` over the padded width shifts the positions of
        real tokens whenever padding precedes them (left padding). A cumulative
        sum over the mask gives every sequence positions 0..n-1 on either
        padding side. Pad slots are clamped to 0; the mask excludes them from
        pooling anyway.
        """
        return (attention_mask.long().cumsum(dim=-1) - 1).clamp(min=0)

    def encode(self, sentences, batch_size=32):
        """Embed one sentence or an iterable of sentences.

        Args:
            sentences: A single string or an iterable of strings.
            batch_size: Number of sentences tokenized and run per forward pass.

        Returns:
            ``numpy.ndarray`` with one pooled embedding per sentence, stacked
            along axis 0. Empty input yields shape
            ``(0, word_embedding_dimension)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        sentences = [sentences] if isinstance(sentences, str) else list(sentences)
        if not sentences:
            # torch.cat([]) would raise; return a well-shaped empty result instead.
            return np.zeros((0, self.word_embedding_dimension), dtype=np.float32)

        embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            inputs["position_ids"] = self._position_ids(inputs["attention_mask"])

            with torch.no_grad():
                outputs = self.model(**inputs)
                last_hidden = outputs.last_hidden_state
                # Model outputs may live on a different device than the tokenizer tensors.
                attention_mask = inputs["attention_mask"].to(last_hidden.device)
                features = {"token_embeddings": last_hidden, "attention_mask": attention_mask}
                embeddings.append(self.pooling(features)["sentence_embedding"])

        return torch.cat(embeddings, dim=0).detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# Wrap the ONNX model and embed a few sample words.
onnx_st = ONNXSentenceTransformer(onnx_model, tokenizer)
words = ["apple", "banana", "car"]
embeddings = onnx_st.encode(words)
print(embeddings)

# Report the shape of each per-word embedding, numbered from 1.
for number, vector in enumerate(embeddings, start=1):
    print(f"Embedding {number}: {vector.shape}")
|
|
|
|
|
|
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity between two vectors.

    Both inputs are converted to float64 arrays and flattened. Any array-likes
    (lists, 1-D or N-D arrays) with the same total number of elements are
    therefore accepted.

    Returns:
        ``float`` in ``[-1, 1]``. If either vector has zero norm, returns 0.0
        instead of the nan (with a RuntimeWarning) that 0/0 would produce.
    """
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)
|
|
|
|
|
|
print("\nCosine similarities:")
# Each entry: (left label, right label, left index, right index) into `embeddings`.
for left, right, i, j in (
    ("apple", "banana", 0, 1),
    ("apple", "car", 0, 2),
    ("banana", "car", 1, 2),
):
    score = cosine_similarity(embeddings[i], embeddings[j])
    print(f"{left} vs {right}: {score:.4f}")
|
|
|
|