from __future__ import annotations

import logging
from collections import Counter
from itertools import chain
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, TypeVar, cast

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback, EarlyStopping
from sklearn.metrics import jaccard_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from torch import nn
from tqdm import trange

from distiller.model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label
from distiller.model2vec.train.base import FinetunableStaticModel, TextDataset

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import OptimizerLRScheduler
    from tokenizers import Tokenizer

logger = logging.getLogger(__name__)
_RANDOM_SEED = 42

LabelType = TypeVar("LabelType", list[str], list[list[str]])


class StaticModelForClassification(FinetunableStaticModel):
    def __init__(
        self,
        *,
        vectors: torch.Tensor,
        tokenizer: Tokenizer,
        n_layers: int = 1,
        hidden_dim: int = 512,
        out_dim: int = 2,
        pad_id: int = 0,
    ) -> None:
        """Initialize a standard classifier model."""
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # Follows the scikit-learn convention; set to dummy classes until `fit` is called.
        self.classes_: list[str] = [str(x) for x in range(out_dim)]
        # multilabel flag will be set based on the type of `y` passed to fit.
        self.multilabel: bool = False
        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

    @property
    def classes(self) -> np.ndarray:
        """Return all clasess in the correct order."""
        return np.array(self.classes_)

    def construct_head(self) -> nn.Sequential:
        """Constructs a simple classifier head."""
        if self.n_layers == 0:
            return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
        modules = [
            nn.Linear(self.embed_dim, self.hidden_dim),
            nn.ReLU(),
        ]
        for _ in range(self.n_layers - 1):
            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

        for module in modules:
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        return nn.Sequential(*modules)

    def predict(
        self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5
    ) -> np.ndarray:
        """
        Predict labels for a set of texts.

        In single-label mode, each prediction is a single class.
        In multilabel mode, each prediction is a list of classes.

        :param X: The texts to predict on.
        :param show_progress_bar: Whether to show a progress bar.
        :param batch_size: The batch size.
        :param threshold: The threshold for multilabel classification.
        :return: The predictions.
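
        Example (illustrative; assumes `clf` is an already-fitted StaticModelForClassification):

            >>> clf.predict(["a great product", "a terrible product"])  # doctest: +SKIP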
        """
        pred = []
        for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
            logits = self._predict_single_batch(X[batch : batch + batch_size])
            if self.multilabel:
                probs = torch.sigmoid(logits)
                mask = (probs > threshold).cpu().numpy()
                pred.extend([self.classes[np.flatnonzero(row)] for row in mask])
            else:
                pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
        if self.multilabel:
            # Return as object array to allow for lists of varying lengths.
            return np.array(pred, dtype=object)
        return np.array(pred)

    @torch.no_grad()
    def _predict_single_batch(self, X: list[str]) -> torch.Tensor:
        input_ids = self.tokenize(X)
        head_out, _ = self.forward(input_ids)
        return head_out

    def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray:
        """
        Predict probabilities for each class.

        In single-label mode, returns softmax probabilities.
        In multilabel mode, returns sigmoid probabilities.

        :param X: The texts to predict on.
        :param show_progress_bar: Whether to show a progress bar.
        :param batch_size: The batch size.
        :return: An array of shape (len(X), n_classes) with the per-class probabilities.
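
        Example (illustrative; assumes `clf` is fitted):

            >>> probs = clf.predict_proba(["some text"])  # doctest: +SKIP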
        """
        pred = []
        for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
            logits = self._predict_single_batch(X[batch : batch + batch_size])
            if self.multilabel:
                pred.append(torch.sigmoid(logits).cpu().numpy())
            else:
                pred.append(torch.softmax(logits, dim=1).cpu().numpy())
        return np.concatenate(pred, axis=0)

    def fit(
        self,
        X: list[str],
        y: LabelType,
        learning_rate: float = 1e-3,
        batch_size: int | None = None,
        min_epochs: int | None = None,
        max_epochs: int | None = -1,
        early_stopping_patience: int | None = 5,
        test_size: float = 0.1,
        device: str = "auto",
        X_val: list[str] | None = None,
        y_val: LabelType | None = None,
    ) -> StaticModelForClassification:
        """
        Fit a model.

        This function creates a Lightning Trainer object and fits the model to the data.
        It supports both single-label and multi-label classification.
        We use early stopping. After training, the weights of the best model are loaded back into the model.

        This function seeds everything with a seed of 42, so the results are reproducible.
        It also splits the data into a train and validation set, again with a random seed.

        If `X_val` and `y_val` are not provided, the function will automatically
        split the training data into a train and validation set using `test_size`.

        :param X: The texts to train on.
        :param y: The labels to train on. If the first element is a list, multi-label classification is assumed.
        :param learning_rate: The learning rate.
        :param batch_size: The batch size. If None, a good batch size is chosen automatically.
        :param min_epochs: The minimum number of epochs to train for.
        :param max_epochs: The maximum number of epochs to train for.
            If this is -1, the model trains until early stopping is triggered.
        :param early_stopping_patience: The patience for early stopping.
            If this is None, early stopping is disabled.
        :param test_size: The test size for the train-test split.
        :param device: The device to train on. If this is "auto", the device is chosen automatically.
        :param X_val: The texts to be used for validation.
        :param y_val: The labels to be used for validation.
        :return: The fitted model.
        :raises ValueError: If either X_val or y_val are provided, but not both.
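
        Example (illustrative sketch; `clf`, `texts`, and `labels` are placeholders for an
        initialized classifier and your own data):

            >>> clf = clf.fit(texts, labels)  # doctest: +SKIP
            >>> clf.predict(texts[:2])  # doctest: +SKIP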
        """
        pl.seed_everything(_RANDOM_SEED)
        logger.info("Re-initializing model.")

        # Determine whether the task is multilabel based on the type of y.

        self._initialize(y)

        if (X_val is not None) != (y_val is not None):
            msg = "Both X_val and y_val must be provided together, or neither."
            raise ValueError(msg)

        if X_val is not None and y_val is not None:
            # Additional check to ensure the labels in y_val are of the same type as those in y.
            if type(y_val[0]) is not type(y[0]):
                msg = "y_val must contain labels of the same type as y."
                raise ValueError(msg)

            train_texts = X
            train_labels = y
            validation_texts = X_val
            validation_labels = y_val
        else:
            train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
                X,
                y,
                test_size=test_size,
            )

        if batch_size is None:
            # Aim for roughly 30 batches per epoch: round len(train_texts) / 30 down to a
            # multiple of 32, then clamp the batch size to the range [32, 512].
            base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
            batch_size = int(base_number * 32)
            logger.info("Batch size automatically set to %d.", batch_size)

        logger.info("Preparing train dataset.")
        train_dataset = self._prepare_dataset(train_texts, train_labels)
        logger.info("Preparing validation dataset.")
        val_dataset = self._prepare_dataset(validation_texts, validation_labels)

        c = _ClassifierLightningModule(self, learning_rate=learning_rate)

        n_train_batches = len(train_dataset) // batch_size
        callbacks: list[Callback] = []
        if early_stopping_patience is not None:
            callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience)
            callbacks.append(callback)

        # If the dataset is small, we check the validation set every epoch.
        # If the dataset is large, we check the validation set every 250 batches.
        if n_train_batches < 250:
            val_check_interval = None
            check_val_every_epoch = 1
        else:
            val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
            check_val_every_epoch = None

        with TemporaryDirectory() as tempdir:
            trainer = pl.Trainer(
                min_epochs=min_epochs,
                max_epochs=max_epochs,
                callbacks=callbacks,
                val_check_interval=val_check_interval,
                check_val_every_n_epoch=check_val_every_epoch,
                accelerator=device,
                default_root_dir=tempdir,
            )

            trainer.fit(
                c,
                train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
                val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
            )
            best_model_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            best_model_weights = torch.load(best_model_path, weights_only=True)

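        # Lightning stores the wrapped model's parameters under a "model." prefix;
        # strip it so the keys match this module's own state dict.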
        state_dict = {}
        for weight_name, weight in best_model_weights["state_dict"].items():
            state_dict[weight_name.removeprefix("model.")] = weight

        self.load_state_dict(state_dict)
        self.eval()
        return self

    def evaluate(
        self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
    ) -> str | dict[str, dict[str, float]]:
        """
        Evaluate the classifier on a given dataset using scikit-learn's classification report.

        :param X: The texts to predict on.
        :param y: The ground truth labels.
        :param batch_size: The batch size.
        :param threshold: The threshold for multilabel classification.
        :param output_dict: Whether to output the classification report as a dictionary.
        :return: A classification report.
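
        Example (illustrative; assumes `clf` is fitted and `test_texts`/`test_labels` are held-out data):

            >>> print(clf.evaluate(test_texts, test_labels))  # doctest: +SKIP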
        """
        self.eval()
        predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
        return evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

    def _initialize(self, y: LabelType) -> None:
        """
        Sets the output dimensionality and the classes, and initializes the head.

        :param y: The labels.
        :raises ValueError: If the labels are inconsistent.
        """
        if isinstance(y[0], (str, int)):
            # Check if all labels are strings or integers.
            if not all(isinstance(label, (str, int)) for label in y):
                msg = "Inconsistent label types in y. All labels must be strings or integers."
                raise ValueError(msg)
            self.multilabel = False
            classes = sorted(set(y))
        else:
            # Check if all labels are lists or tuples.
            if not all(isinstance(label, (list, tuple)) for label in y):
                msg = "Inconsistent label types in y. All labels must be lists or tuples."
                raise ValueError(msg)
            self.multilabel = True
            classes = sorted(set(chain.from_iterable(y)))

        self.classes_ = classes
        self.out_dim = len(self.classes_)  # Update output dimension
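        # Re-initialize the trainable components: the classifier head, the embeddings
        # (copied from the stored vectors and left unfrozen), and the weights.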
        self.head = self.construct_head()
        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
        self.w = self.construct_weights()
        self.train()

    def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset:
        """
        Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector.

        :param X: The texts.
        :param y: The labels.
        :param max_length: The maximum length of the input.
        :return: A TextDataset.
        """
        # Speed optimization: truncate the raw strings before tokenization.
        # Assumes a mean token length of 10 characters, which is generous, so nothing within max_length is lost.
        truncate_length = max_length * 10
        X = [x[:truncate_length] for x in X]
        tokenized: list[list[int]] = [
            encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False)
        ]
        if self.multilabel:
            # Convert labels to multi-hot vectors
            num_classes = len(self.classes_)
            labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float)
            mapping = {label: idx for idx, label in enumerate(self.classes_)}
            for i, sample_labels in enumerate(y):
                indices = [mapping[label] for label in sample_labels]
                labels_tensor[i, indices] = 1.0
        else:
            labels_tensor = torch.tensor(
                [self.classes_.index(label) for label in cast("list[str]", y)], dtype=torch.long
            )
        return TextDataset(tokenized, labels_tensor)

    def _train_test_split(
        self,
        X: list[str],
        y: list[str] | list[list[str]],
        test_size: float,
    ) -> tuple[list[str], list[str], LabelType, LabelType]:
        """
        Split the data.

        For single-label classification, stratification is attempted (if possible).
        For multilabel classification, a random split is performed.
        """
        if not self.multilabel:
            label_counts = Counter(y)
            if min(label_counts.values()) < 2:
                logger.info("Some classes have less than 2 samples. Stratification is disabled.")
                return train_test_split(X, y, test_size=test_size, random_state=_RANDOM_SEED, shuffle=True)
            return train_test_split(X, y, test_size=test_size, random_state=_RANDOM_SEED, shuffle=True, stratify=y)
        # Multilabel classification does not support stratification.
        return train_test_split(X, y, test_size=test_size, random_state=_RANDOM_SEED, shuffle=True)

    def to_pipeline(self) -> StaticModelPipeline:
        """Convert the model to an sklearn pipeline."""
        static_model = self.to_static_model()

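        # Fit the scikit-learn MLP on dummy random data of the right shape so that its internal
        # structure (layer sizes and classes) is initialized; the real weights are copied in below.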
        random_state = np.random.RandomState(_RANDOM_SEED)
        n_items = len(self.classes)
        X = random_state.randn(n_items, static_model.dim)
        y = self.classes

        converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers))
        converted.fit(X, y)
        mlp_head: MLPClassifier = converted[-1]

        for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]):
            mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T
            mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy()
        # Below is necessary to ensure that the converted model works correctly.
        # In scikit-learn, a binary classifier only has a single vector of output coefficients
        # and a single intercept. We use two output vectors.
        # To convert correctly, we need to set the outputs correctly, and fix the activation function.
        # Make sure n_outputs is set to > 1.
        mlp_head.n_outputs_ = self.out_dim
        # Set to softmax or sigmoid
        mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"

        return StaticModelPipeline(static_model, converted)


class _ClassifierLightningModule(pl.LightningModule):
    def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None:
        """Initialize the LightningModule."""
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.loss_function = nn.CrossEntropyLoss() if not model.multilabel else nn.BCEWithLogitsLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Simple forward pass."""
        return self.model(x)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training."""
        x, y = batch
        head_out, _ = self.model(x)
        loss = self.loss_function(head_out, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Validation step computing loss and accuracy."""
        x, y = batch
        head_out, _ = self.model(x)
        loss = self.loss_function(head_out, y)
        if self.model.multilabel:
            preds = (torch.sigmoid(head_out) > 0.5).float()
            # Multilabel accuracy is defined as the Jaccard score averaged over samples.
            accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples")
        else:
            accuracy = (head_out.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure optimizer and learning rate scheduler."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            threshold=0.03,
            threshold_mode="rel",
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}