Python ONNX code

#2
by emraza110 - opened

[GENERATED FROM gemini-pro-dev-api. Not tested by running it. Looks fine to me.]

import onnxruntime
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load the quantized ONNX model
model_path = "bge-m3-model-files/sentence_transformers_quantized.onnx"
onnx_session = onnxruntime.InferenceSession(model_path)

# Load the quantized ONNX data
data_path = "bge-m3-model-files/sentence_transformers_quantized.onnx_data"
with open(data_path, "rb") as f:
    data = f.read()

# Create a feature extractor function
def extract_embeddings(texts):
    # Preprocess the texts
    texts = [text.lower() for text in texts]

    # Convert texts to ONNX compatible format
    inputs = {"input_ids": np.array([text.encode("utf-8") for text in texts])}

    # Run the ONNX model
    outputs = onnx_session.run(data, inputs)

    # Extract the embeddings
    embeddings = outputs[0]

    return embeddings

# Define a query to use for retrieval
query = "What is BGE M3?"

# List of documents you want to embed
texts = [
    "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
    "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document",
]

# Compute sentence embeddings
embeddings = extract_embeddings(texts)

# Compute query embeddings
query_embeddings = extract_embeddings([query])

# Compute cosine similarity scores
scores = cosine_similarity(query_embeddings, embeddings)

# Sort by cosine similarity score
sorted_scores = np.argsort(scores.flatten())[::-1]

# Print the top 5 most similar documents
print("Top 5 most similar documents:")
for i in sorted_scores[:5]:
    print(f"Document {i}: {texts[i]}")

@emraza110 Thanks for the contribution. Could you please try this and let me know if it works?

@emraza110 where to get file sentence_transformers_quantized.onnx_data ? I tried using sentence_transformers.onnx_data and not working
Thanks

Edit : when I tried using sentence_transformers.onnx_data
I get this error
ValueError: Required inputs (['attention_mask']) are missing from input feed (['input_ids']).

Xenova changed discussion status to closed

Sign up or log in to comment