File size: 4,126 Bytes
d5fb6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from sentence_transformers import models
import torch
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction
import numpy as np
import os
import onnxruntime

# ONNX pipeline for Gemma3 embedding model
model_dir = "embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
# The transformer is executed by ONNX Runtime, not torch, so a visible GPU is
# not enough: ONNX Runtime itself must offer the CUDA execution provider,
# otherwise moving the model to "cuda" fails. Fall back to CPU in that case.
device = (
    "cuda"
    if torch.cuda.is_available()
    and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
    else "cpu"
)
onnx_model = ORTModelForFeatureExtraction.from_pretrained(
    model_dir,
    file_name="model.onnx"
).to(device)

class ONNXTransformer:
    """Token-level encoder backed by an ONNX feature-extraction model.

    Tokenizes text, supplies explicit ``position_ids`` and returns the
    model's ``last_hidden_state``.
    """

    def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
        self.onnx_model = onnx_model
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def encode(self, sentences):
        """Return per-token hidden states for ``sentences``.

        The batch is padded and truncated to at most ``self.max_seq_length``
        tokens by the tokenizer before being fed to the model.
        """
        batch = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )
        token_ids = batch['input_ids']
        n_rows, n_cols = token_ids.shape
        # Every row gets the same 0..n_cols-1 positions; the ONNX model is fed
        # position_ids explicitly alongside the tokenizer outputs.
        positions = torch.arange(n_cols).unsqueeze(0).expand(n_rows, n_cols)
        batch['position_ids'] = positions.to(token_ids.device)
        with torch.no_grad():
            result = self.onnx_model(**batch)
        return result.last_hidden_state

# Assemble the pipeline: the ONNX transformer first, then the post-processing
# stages stored in the numbered sub-folders / ONNX files under ``model_dir``.
modules = []
onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
modules.append(onnx_transformer)
for idx, name in [(1, "Pooling"), (2, "Dense"), (3, "Dense"), (4, "Normalize")]:
    module_path = os.path.join(model_dir, f"{idx}_{name}")
    if name == "Pooling":
        # Pooling's constructor takes the embedding dimension, not a path;
        # ``load`` reads the saved config.json (pooling mode etc.) instead of
        # silently falling back to constructor defaults.
        modules.append(models.Pooling.load(module_path))
    elif name == "Dense":
        # Use ONNXRuntime for Dense layers (idx 2 -> dense1.onnx, idx 3 -> dense2.onnx)
        dense_onnx_path = os.path.join(model_dir, "onnx", f"dense{idx-1}.onnx")
        modules.append(onnxruntime.InferenceSession(dense_onnx_path, providers=["CPUExecutionProvider"]))
    elif name == "Normalize":
        modules.append(models.Normalize())

class ONNXSentenceTransformer:
    """Sentence encoder chaining an ONNX transformer with post-processing.

    ``modules[0]`` must provide ``encode(sentences)`` returning token
    embeddings. Every later module must be a ``models.Pooling``, an
    ``onnxruntime.InferenceSession`` (used as a Dense layer) or a
    ``models.Normalize``, and is applied in list order.
    """

    def __init__(self, modules):
        self.modules = modules

    def _attention_mask(self, sentences, token_embeddings):
        """Return the padding mask that lines up with ``token_embeddings``.

        When the first module pads the batch (``ONNXTransformer.encode`` calls
        its tokenizer with ``padding=True``), an all-ones mask would let
        padding tokens leak into the pooled embedding of shorter sentences.
        The mask is therefore rebuilt by re-tokenizing with the same settings
        the first module uses. If that module exposes no ``tokenizer``, an
        all-ones mask is used.

        Raises:
            ValueError: if the rebuilt mask does not match the embeddings'
                (batch, sequence) shape.
        """
        encoder = self.modules[0]
        tokenizer = getattr(encoder, "tokenizer", None)
        if tokenizer is None:
            return torch.ones(token_embeddings.shape[:2], device=token_embeddings.device)
        mask = tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(encoder, "max_seq_length", None),
        )["attention_mask"]
        if tuple(mask.shape) != tuple(token_embeddings.shape[:2]):
            raise ValueError(
                f"attention mask shape {tuple(mask.shape)} does not match "
                f"token embeddings shape {tuple(token_embeddings.shape[:2])}"
            )
        return mask.to(token_embeddings.device)

    def encode(self, sentences):
        """Encode ``sentences`` into sentence embeddings.

        Args:
            sentences: Text or list of texts, passed unchanged to
                ``modules[0].encode``.

        Returns:
            A numpy array of sentence embeddings, one row per input sentence.

        Raises:
            TypeError: if a module after the first has an unsupported type.
        """
        features = self.modules[0].encode(sentences)
        for module in self.modules[1:]:
            if isinstance(module, models.Pooling):
                attention_mask = self._attention_mask(sentences, features)
                features = module({'token_embeddings': features, 'attention_mask': attention_mask})['sentence_embedding']
            elif isinstance(module, onnxruntime.InferenceSession):
                # ONNX Dense layer expects shape [1, in_features], so process each embedding separately
                if isinstance(features, torch.Tensor):
                    features = features.cpu().detach().numpy()
                input_name = module.get_inputs()[0].name
                outputs = [
                    module.run(None, {input_name: vec.reshape(1, -1)})[0].squeeze(0)
                    for vec in features
                ]
                features = np.stack(outputs, axis=0)
            elif isinstance(module, models.Normalize):
                # Normalize still uses PyTorch
                if not isinstance(features, torch.Tensor):
                    features = torch.from_numpy(features)
                features = module({'sentence_embedding': features})['sentence_embedding']
            else:
                # Previously such modules were silently skipped, producing
                # embeddings that looked valid but omitted a stage.
                raise TypeError(f"Unsupported module type: {type(module).__name__}")
        if isinstance(features, torch.Tensor):
            return features.cpu().detach().numpy()
        return features

# Module-level encoder instance wrapping the assembled pipeline.
onnx_st = ONNXSentenceTransformer(modules=modules)

def cosine_similarity(a, b):
    """Return the cosine of the angle between ``a`` and ``b``.

    Both arrays are flattened to 1-D first, so any shapes holding the same
    number of elements are accepted.
    """
    u = a.flatten()
    v = b.flatten()
    denominator = np.linalg.norm(u) * np.linalg.norm(v)
    return np.dot(u, v) / denominator

if __name__ == "__main__":
    # Small demo: embed three words and compare every pair.
    words = ["apple", "banana", "car"]
    embeddings = onnx_st.encode(words)
    print(embeddings)
    for position, vector in enumerate(embeddings, start=1):
        print(f"Embedding {position}: {vector.shape}")

    print("\nCosine similarities:")
    for i, j in [(0, 1), (0, 2), (1, 2)]:
        score = cosine_similarity(embeddings[i], embeddings[j])
        print(f"{words[i]} vs {words[j]}: {score:.4f}")