File size: 2,788 Bytes
b552d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Dict, List, Any
import numpy as np
import pickle

from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
import torch

from eurovoc import EurovocTagger

# Hugging Face model id used both for the tokenizer below and as the
# ``bert_model_name`` passed to ``EurovocTagger.from_pretrained``.
BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
# Used as the tokenizer ``max_length`` (in tokens) and also as the size, in
# characters, of each text chunk produced by ``EndpointHandler.get_prediction``.
MAX_LEN = 512
# Upper bound on how many leading characters of the input are chunked at all
# (i.e. at most 50 chunks of MAX_LEN characters); the rest is ignored.
TEXT_MAX_LEN = MAX_LEN * 50
# Module-level tokenizer shared by every handler instance; it is loaded when
# this module is imported.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)


class EndpointHandler:
    """Inference handler that tags a text with its best-scoring labels.

    The handler loads a pickled label binarizer and an ``EurovocTagger``
    model from ``path``.  Calling the instance with a request dict scores
    the text chunk by chunk, averages the chunk scores and returns the
    labels whose averaged score passes a threshold.
    """

    # Fitted label binarizer; set per instance in ``__init__``.  Its
    # ``classes_`` attribute supplies the label for each score column.
    mlb: "MultiLabelBinarizer"

    def __init__(self, path=""):
        """Load the label binarizer and the model stored under ``path``.

        Args:
            path: Directory containing ``mlb.pickle`` and the model files
                understood by ``EurovocTagger.from_pretrained``.
        """
        # NOTE(security): unpickling executes arbitrary code, so
        # ``mlb.pickle`` must come from a trusted model repository.
        with open(f"{path}/mlb.pickle", "rb") as mlb_file:
            self.mlb = pickle.load(mlb_file)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(path,
                                                   bert_model_name=BERT_MODEL_NAME,
                                                   n_classes=len(self.mlb.classes_),
                                                   map_location=self.device)
        # Inference only: switch to eval mode and freeze the model.
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score the request text and return its top labels.

        Args:
            data: Request payload.  The following keys are popped from it:
                ``inputs`` (the text; falls back to ``data`` itself when
                missing), ``topk`` (maximum number of labels, default 5),
                ``threshold`` (a label is kept only if its score is
                strictly greater, default 0.16) and ``debug`` (default
                False).

        Returns:
            ``{"results": [...]}`` where each entry is
            ``{"label": ..., "score": float}``, sorted by descending score.
            When ``debug`` is truthy the dict also carries ``"values"``
            (the averaged prediction array) and ``"input"`` (the text).
            Empty input yields an empty ``results`` list.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        prediction = self.get_prediction(text)
        # Row 0 of the prediction holds one score per entry of ``classes_``.
        scored_labels = zip(self.mlb.classes_, prediction[0].tolist())
        # Filter before sorting so only the surviving labels are ordered.
        results = [{"label": label, "score": float(score)}
                   for label, score in scored_labels
                   if score > threshold]
        results.sort(key=lambda entry: entry["score"], reverse=True)
        results = results[:topk]

        if debug:
            return {"results": results, "values": prediction, "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return the prediction for ``text`` averaged over its chunks.

        The first ``TEXT_MAX_LEN`` characters are cut into pieces of
        ``MAX_LEN`` characters, each piece is scored on its own and the
        per-chunk scores are averaged.

        Args:
            text: The text to score.

        Returns:
            A numpy array whose row 0 holds one score per label.  For empty
            input an all-zero array of shape ``(1, n_classes)`` is returned
            instead of averaging over zero chunks.
        """
        if not text:
            # No chunks to score: averaging an empty list would produce NaN
            # and break the ``prediction[0]`` indexing done by the caller.
            return np.zeros((1, len(self.mlb.classes_)))

        # NOTE(review): chunks are cut by characters, while the tokenizer
        # limit of MAX_LEN is in tokens -- confirm this is intended.
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[start:start + MAX_LEN] for start in range(0, limit, MAX_LEN)]
        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Tokenize one chunk, run the model and return its output as numpy.

        Args:
            text: A single chunk of text.

        Returns:
            The second element of the model output, converted to a numpy
            array on the CPU.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')
        # NOTE(review): the input tensors stay on the CPU; confirm the model
        # is on the CPU as well, or move both to ``self.device``.
        with torch.no_grad():  # inference only, no autograd bookkeeping
            _, prediction = self.model(item["input_ids"], item["attention_mask"])
        return prediction.cpu().detach().numpy()