File size: 3,587 Bytes
b438028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import List

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PerceiverTokenizer


def _map_outputs(predictions, threshold: float = 0.5):
    """
    Map model outputs to classes.

    :param predictions: model output batch; an iterable of rows, each row holding one score per
        label, in the same order as the label tuple below
    :param threshold: a label is kept when its score is strictly greater than this value
    :return: one list of label names per row of ``predictions``
    """

    # Position in this tuple is the class index of the corresponding model output.
    labels = (
        "admiration",
        "amusement",
        "anger",
        "annoyance",
        "approval",
        "caring",
        "confusion",
        "curiosity",
        "desire",
        "disappointment",
        "disapproval",
        "disgust",
        "embarrassment",
        "excitement",
        "fear",
        "gratitude",
        "grief",
        "joy",
        "love",
        "nervousness",
        "optimism",
        "pride",
        "realization",
        "relief",
        "remorse",
        "sadness",
        "surprise",
        "neutral",
    )

    # Multi-label: every score above the threshold contributes its label, so a row may map to
    # zero, one or several labels.
    return [
        [labels[index] for index, score in enumerate(example) if score > threshold]
        for example in predictions
    ]


class MultiLabelPipeline:
    """
    Multi label classification pipeline.

    Loads a serialized model with ``torch.load``, tokenizes the ``text`` column of a dataset with
    the ``deepmind/language-perceiver`` tokenizer and maps the model logits to label names.
    """

    def __init__(self, model_path):
        """
        Init MLC pipeline.

        :param model_path: path of the serialized model to load
        """

        # Init attributes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # map_location puts the weights on the selected device whatever device they were saved
        # from, so no CPU/GPU branching is needed (a torch.device should not be compared to the
        # string 'cuda' to pick a branch).
        # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
        self.model = torch.load(model_path, map_location=self.device).eval().to(self.device)
        self.tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')

    def __call__(self, dataset, batch_size: int = 4):
        """
        Processing pipeline.

        :param dataset: dataset with a ``text`` column; the column is tokenized and then removed
        :param batch_size: number of examples per forward pass
        :return: list holding, for each example of the dataset, the list of predicted label names
        """

        # Tokenize inputs
        dataset = dataset.map(lambda row: self.tokenizer(row['text'], padding="max_length", truncation=True),
                              batched=True, remove_columns=['text'], desc='Tokenizing')
        dataset.set_format('torch', columns=['input_ids', 'attention_mask'])
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Define output classes
        classes = []

        # Running average of the reserved GPU memory, kept as a sum and a count so the progress bar
        # update stays O(1) per batch.
        on_cuda = self.device.type == 'cuda'
        memory_total = 0.0
        batch_count = 0

        # Inference only: no_grad avoids building the autograd graph for every batch.
        with tqdm(dataloader, unit='batches', desc='Inference') as progression, torch.no_grad():
            for batch in progression:
                # Forward
                outputs = self.model(inputs=batch['input_ids'].to(self.device),
                                     attention_mask=batch['attention_mask'].to(self.device))

                # Outputs
                # NOTE(review): raw logits are handed to _map_outputs, which compares them to 0.5;
                # confirm the model head already applies a sigmoid.
                predictions = outputs.logits.detach().cpu().numpy()

                # Map predictions to classes
                classes.extend(_map_outputs(predictions))

                # Retrieve memory usage (only meaningful when running on a CUDA device)
                memory = round(torch.cuda.memory_reserved(self.device) / 1e9, 2) if on_cuda else 0.0
                memory_total += memory
                batch_count += 1

                # Update pbar
                progression.set_postfix(memory=f"{round(memory_total / batch_count, 2)}Go")

        return classes


def inputs_to_dataset(inputs: List[str]) -> "Dataset":
    """
    Convert a list of strings to a dataset object.

    :param inputs: list of strings
    :return: dataset with a single ``text`` column holding a copy of ``inputs``
    """

    # list() copies the input without shadowing the builtin ``input`` in a comprehension.
    return Dataset.from_dict({'text': list(inputs)})