File size: 2,393 Bytes
fb00699
 
 
 
67f47da
fb00699
 
 
 
 
 
 
 
 
 
67f47da
fb00699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch
import json
import os

class PipelineWrapper:

    """Wrapper for classifying gov dataset titles into the Musterdatenkatalog taxonomy.

    Encodes the taxonomy labels and incoming query titles with a
    sentence-transformers model, then assigns each query its nearest
    taxonomy label via semantic search over the embeddings.
    """

    def __init__(self, path=""):
        """Load the model and taxonomy from *path*.

        Parameters
        ----------
        path : str
            Directory containing the SentenceTransformer model files and a
            ``taxonomy.json`` file.
        """
        # Prefer GPU when available; encoding the label set is much faster there.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = os.path.join(path, "taxonomy.json")
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: list) -> list:
        """Alias for :meth:`predict`."""
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        """Read ``taxonomy.json`` and build ``"<group> - <label>"`` strings.

        The catch-all entry ``"Sonstiges - Sonstiges"`` is excluded so that
        queries are never mapped to it.
        """
        # Explicit encoding: the taxonomy contains German umlauts, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(self.taxonomy, "r", encoding="utf-8") as f:
            taxonomy = json.load(f)
        # Filter instead of list.remove(): remove() raises ValueError when the
        # catch-all entry is absent and only deletes the first occurrence.
        candidates = [el["group"] + " - " + el["label"] for el in taxonomy]
        self.taxonomy_labels = [
            label for label in candidates if label != "Sonstiges - Sonstiges"
        ]

    def get_taxonomy_embeddings(self) -> None:
        """Encode all taxonomy labels once; reused for every predict() call."""
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: list) -> list:
        """Predicts the taxonomy labels for the given queries.

        Parameters
        ----------
        queries : list
            List of queries to predict. Format is a list of dictionaries with the following keys: "id", "title"

        Returns
        -------
        list
            List of dictionaries with the following keys: "id", "title", "prediction"
        """
        texts = [el["title"] for el in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # top_k=1: keep only the single best-matching taxonomy label per query.
        predictions = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )

        results = []

        for query, prediction in zip(queries, predictions):
            results.append(
                {
                    "id": query["id"],
                    "title": query["title"],
                    "prediction": self.taxonomy_labels[prediction[0]["corpus_id"]],
                }
            )

        return results