|
from typing import Any, Dict, List |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers import util |
|
import torch |
|
import json |
|
|
|
|
|
class PipelineWrapper:
    """Classify government dataset titles into the Musterdatenkatalog taxonomy.

    Every taxonomy entry is turned into a label string ``"<group> - <label>"``
    and encoded once with a sentence-transformers model. At prediction time,
    each query title is encoded and assigned the single closest taxonomy
    label found by ``sentence_transformers.util.semantic_search``.
    """

    # Catch-all label removed from the candidate set so it is never predicted.
    EXCLUDED_LABELS = ("Sonstiges - Sonstiges",)

    def __init__(self, path="", taxonomy_path="taxonomy.json"):
        """Load the model and pre-compute the taxonomy label embeddings.

        Parameters
        ----------
        path : str
            Model name or path passed to ``SentenceTransformer``.
        taxonomy_path : str
            Path to the taxonomy JSON file: a list of objects with "group"
            and "label" keys. The default is ``"taxonomy.json"``, resolved
            relative to the current working directory, as before.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): newer sentence-transformers releases rename
        # `use_auth_token` to `token`; kept as-is — confirm against the
        # installed library version.
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = taxonomy_path
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Alias for :meth:`predict`."""
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        """Read the taxonomy file into ``self.taxonomy_labels``.

        Each entry becomes ``"<group> - <label>"``, in file order. Labels
        listed in ``EXCLUDED_LABELS`` are dropped. If an excluded label is
        absent from the file, that is not an error.
        """
        # Explicit UTF-8: the taxonomy is German text and may contain
        # non-ASCII characters. The platform default encoding (e.g. cp1252
        # on Windows) would mis-decode them.
        with open(self.taxonomy, "r", encoding="utf-8") as f:
            taxonomy = json.load(f)
        labels = (el["group"] + " - " + el["label"] for el in taxonomy)
        self.taxonomy_labels = [
            label for label in labels if label not in self.EXCLUDED_LABELS
        ]

    def get_taxonomy_embeddings(self) -> None:
        """Encode ``self.taxonomy_labels`` into ``self.taxonomy_embeddings``.

        The embeddings are computed as a tensor so they can be passed
        directly to ``util.semantic_search``.
        """
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Predict the taxonomy label for each query.

        Parameters
        ----------
        queries : list of dict
            Queries to classify. Each dictionary must have the keys "id"
            and "title".

        Returns
        -------
        list of dict
            One dictionary per query, in input order, with the keys "id",
            "title" and "prediction". "prediction" is the closest taxonomy
            label. An empty input yields an empty list.
        """
        # Some sentence-transformers versions fail when asked to encode an
        # empty batch with convert_to_tensor=True, so short-circuit here.
        if not queries:
            return []

        texts = [query["title"] for query in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        hits = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )

        # semantic_search returns one ranked hit list per query. With
        # top_k=1, the best match is query_hits[0], and its "corpus_id"
        # indexes into taxonomy_labels.
        return [
            {
                "id": query["id"],
                "title": query["title"],
                "prediction": self.taxonomy_labels[query_hits[0]["corpus_id"]],
            }
            for query, query_hits in zip(queries, hits)
        ]
|
|