import os
import numpy as np
import onnxruntime as ort
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager
from transformers.utils import logging

logging.set_verbosity_error()


class ONNXDistilUSEModel:
    def __init__(self,
                 model_name="sentence-transformers/distiluse-base-multilingual-cased-v2",
                 onnx_path="onnx_model/model.onnx"):
        self.model_name = model_name
        self.onnx_path = onnx_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if not os.path.exists(onnx_path):
            self.export_to_onnx()
        self.session = ort.InferenceSession(onnx_path)

    def export_to_onnx(self):
        model = AutoModel.from_pretrained(self.model_name)
        save_dir = Path(self.onnx_path).parent
        save_dir.mkdir(parents=True, exist_ok=True)
        _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model)
        onnx_config = model_onnx_config(model.config)
        export(preprocessor=self.tokenizer,
               model=model,
               config=onnx_config,
               opset=14,
               output=Path(self.onnx_path))

    def mean_pooling(self, token_embeddings, attention_mask):
        input_mask_expanded = np.expand_dims(attention_mask, -1).astype(np.float32)
        pooled = np.sum(token_embeddings * input_mask_expanded, axis=1) / np.clip(
            np.sum(input_mask_expanded, axis=1), 1e-9, None)
        return pooled

    def encode(self, texts, normalize=True, debug=False):
        tokens = self.tokenizer(texts, padding=True, truncation=True, return_tensors="np")
        if debug:
            print("[DEBUG] Tokens:", self.tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]))
        # Ensure correct dtype for ONNX Runtime
        inputs = {
            "input_ids": tokens["input_ids"].astype(np.int64),
            "attention_mask": tokens["attention_mask"].astype(np.int64)
        }
        outputs = self.session.run(None, inputs)
        embeddings = self.mean_pooling(outputs[0], tokens["attention_mask"])
        # Normalize embeddings (recommended for FlatL2)
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.clip(norms, 1e-9, None)
        return embeddings
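

# Minimal usage sketch (not part of the original module): encode two sample
# sentences and compare them with a dot product, which equals cosine similarity
# once the embeddings are L2-normalized. The sentences below are illustrative
# placeholders; the first call will download and export the model if the ONNX
# file does not already exist.
if __name__ == "__main__":
    model = ONNXDistilUSEModel()
    sentences = [
        "ONNX Runtime speeds up transformer inference.",
        "Running transformers with ONNX Runtime is faster.",
    ]
    embeddings = model.encode(sentences, normalize=True)
    print("Embedding shape:", embeddings.shape)
    similarity = float(np.dot(embeddings[0], embeddings[1]))
    print("Cosine similarity:", round(similarity, 4))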