from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class EmbeddingScorer:
    """
    Semantic-similarity scorer built on a Hugging Face encoder model.

    Defaults to Alibaba-NLP/gte-multilingual-base. Each text is embedded as the
    hidden state of its first token, truncated to the leading ``dimension``
    features, and L2-normalised so that a dot product between two embeddings
    is their cosine similarity.
    """

    def __init__(self, model_name='Alibaba-NLP/gte-multilingual-base', dimension=768, max_length=8192):
        """
        Load the tokenizer and model.

        Args:
            model_name (str): Hugging Face model identifier or local path.
            dimension (int): Number of leading hidden features kept per
                embedding. The default of 768 keeps the full embedding of the
                default model; smaller values truncate it.
            max_length (int): Token limit passed to the tokenizer; longer texts
                are truncated.

        Raises:
            ValueError: If ``dimension`` is not a positive integer.
        """
        if dimension <= 0:
            raise ValueError(f"dimension must be positive, got {dimension}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # trust_remote_code lets transformers execute modelling code shipped
        # with the checkpoint on the Hub (required for custom architectures).
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.dimension = dimension
        self.max_length = max_length

    def score_method(self, query: str, methods: List[dict]) -> List[dict]:
        """
        Score every method against the query by cosine similarity.

        Each method is rendered as ``"<method>: <description>"`` (a missing
        ``description`` becomes an empty string) and embedded in one batch
        together with the query.

        Args:
            query (str): The query sentence.
            methods (list of dict): Method records; each must have a
                ``'method'`` key and may have a ``'description'`` key.

        Returns:
            list of dict: One entry per method, in input order, shaped
            ``{"method_index": i, "score": s}``. Here ``i`` is the 1-based
            position of the method in ``methods``, and ``s`` is the cosine
            similarity scaled by 100 (so it lies in [-100, 100]). An empty
            ``methods`` list yields ``[]`` without invoking the model.

        Raises:
            KeyError: If a method dict lacks the ``'method'`` key.
        """
        if not methods:
            return []

        sentences = [f"{method['method']}: {method.get('description', '')}" for method in methods]
        texts = [query] + sentences

        batch_dict = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

        with torch.no_grad():
            outputs = self.model(**batch_dict)

        # First-token state of every text, keeping the leading `dimension`
        # features. The slice must be on the feature axis: slicing the batch
        # axis would drop texts once the batch is larger than `dimension`.
        embeddings = outputs.last_hidden_state[:, 0, :self.dimension]
        embeddings = F.normalize(embeddings, p=2, dim=1)

        query_embedding = embeddings[0]
        method_embeddings = embeddings[1:]

        # (num_methods, dim) @ (dim,) -> (num_methods,): always 1-D, even for a
        # single method, so no squeeze / scalar special-casing is needed.
        similarities = (method_embeddings @ query_embedding) * 100

        return [
            {"method_index": i, "score": float(score)}
            for i, score in enumerate(similarities.tolist(), start=1)
        ]


if __name__ == "__main__":
    # Small smoke run: load the default model and score two toy methods.
    scorer = EmbeddingScorer()
    sample_methods = [
        {"method": "Method 1", "description": "Description 1"},
        {"method": "Method 2", "description": "Description 2"},
    ]
    print(scorer.score_method("How to solve the problem of the user", sample_methods))