musterdatenkatalog_clf / pipeline.py
Rahka's picture
fix taxonomy not found error
67f47da
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch
import json
import os
class PipelineWrapper:
    """Classify government dataset titles into the Musterdatenkatalog taxonomy.

    Titles and taxonomy labels are encoded with a sentence-transformers model.
    Each title is then assigned the taxonomy label that semantic search ranks
    first among the label embeddings.
    """

    # Catch-all taxonomy entry that must never be offered as a prediction.
    _CATCH_ALL_LABEL = "Sonstiges - Sonstiges"

    def __init__(self, path=""):
        """Load the model and the taxonomy stored alongside it.

        Parameters
        ----------
        path : str
            Model location handed to ``SentenceTransformer``. A file named
            ``taxonomy.json`` is read from this same path.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): newer sentence-transformers releases appear to prefer
        # ``token`` over ``use_auth_token`` -- confirm before upgrading.
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = os.path.join(path, "taxonomy.json")
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: list) -> list:
        """Shorthand for :meth:`predict`."""
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        """Read the taxonomy file and build the list of candidate labels.

        Each taxonomy entry is rendered as ``"<group> - <label>"``. The
        catch-all entry is dropped so it can never be returned as a prediction;
        a taxonomy that does not contain it is accepted as-is.
        """
        # Explicit UTF-8: the labels contain non-ASCII characters and must not
        # depend on the platform's default encoding.
        with open(self.taxonomy, "r", encoding="utf-8") as f:
            taxonomy = json.load(f)
        labels = [el["group"] + " - " + el["label"] for el in taxonomy]
        # Filtering (instead of list.remove) tolerates a missing catch-all
        # entry and also removes accidental duplicates of it.
        self.taxonomy_labels = [
            label for label in labels if label != self._CATCH_ALL_LABEL
        ]

    def get_taxonomy_embeddings(self) -> None:
        """Encode all taxonomy labels once and cache the resulting tensor."""
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: list) -> list:
        """Predict the taxonomy label for each of the given queries.

        Parameters
        ----------
        queries : list
            List of queries to predict. Format is a list of dictionaries with
            the following keys: "id", "title"

        Returns
        -------
        list
            List of dictionaries with the following keys: "id", "title",
            "prediction". An empty input yields an empty list.
        """
        # Nothing to classify: skip the model and the search entirely.
        if not queries:
            return []

        texts = [el["title"] for el in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # One hit list per query; top_k=1 keeps only the best-ranked label.
        predictions = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )
        return [
            {
                "id": query["id"],
                "title": query["title"],
                # corpus_id indexes into the label list that was embedded.
                "prediction": self.taxonomy_labels[prediction[0]["corpus_id"]],
            }
            for query, prediction in zip(queries, predictions)
        ]