from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch
import json
import os


class PipelineWrapper:
    """Wrapper for classifying government dataset titles into the
    Musterdatenkatalog taxonomy.

    It uses the sentence-transformers library to encode the titles into
    embeddings and assigns labels via semantic search against the
    pre-computed taxonomy embeddings.
    """

    def __init__(self, path=""):
        # Use GPU if available, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = os.path.join(path, "taxonomy.json")
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: list) -> list:
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        # Build "group - label" strings from the taxonomy file; the catch-all
        # entry "Sonstiges - Sonstiges" is excluded from the candidate labels.
        with open(self.taxonomy, "r") as f:
            taxonomy = json.load(f)
        self.taxonomy_labels = [el["group"] + " - " + el["label"] for el in taxonomy]
        self.taxonomy_labels.remove("Sonstiges - Sonstiges")

    def get_taxonomy_embeddings(self) -> None:
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: list) -> list:
        """Predicts the taxonomy labels for the given queries.

        Parameters
        ----------
        queries : list
            List of queries to predict. Format is a list of dictionaries
            with the following keys: "id", "title".

        Returns
        -------
        list
            List of dictionaries with the following keys: "id", "title",
            "prediction".
        """
        texts = [el["title"] for el in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # Retrieve the single closest taxonomy label (top_k=1) per query.
        predictions = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )
        results = []
        for query, prediction in zip(queries, predictions):
            results.append(
                {
                    "id": query["id"],
                    "title": query["title"],
                    "prediction": self.taxonomy_labels[prediction[0]["corpus_id"]],
                }
            )
        return results
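

# A minimal usage sketch, assuming a local model directory that holds the
# fine-tuned sentence-transformers model together with "taxonomy.json".
# The path "model/" and the example titles below are placeholders, not values
# defined by this module. The sketch shows the expected input format for
# predict(): a list of dicts with "id" and "title" keys, and how the
# returned predictions can be read.
if __name__ == "__main__":
    pipeline = PipelineWrapper(path="model/")
    queries = [
        {"id": "1", "title": "Radwege in der Innenstadt"},
        {"id": "2", "title": "Haushaltsplan 2023"},
    ]
    for result in pipeline(queries):
        print(result["id"], result["title"], "->", result["prediction"])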