from typing import Optional

import emoji
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import R2Score

from src.utils import get_sentiment
from src.utils.neural_networks import set_layer
from config import DEVICE

torch.set_default_dtype(torch.float32)


class DecoderPL(pl.LightningModule):
    """MLP decoder that regresses a scalar target (CTR) from input embeddings."""

    def __init__(
        self,
        input_dim: int = 774,
        layer_norm: bool = True,
        layer_dict: Optional[dict] = None,
        device=DEVICE,
        T_max: int = 10,
        start_lr: float = 5e-4,
    ):
        super().__init__()
        # Avoid the mutable-default-argument pitfall: resolve the fallback here.
        layer_dict = layer_dict or {}

        self.linear1 = set_layer(
            layer_dict=layer_dict,
            name="linear1",
            alternative=nn.Linear(in_features=input_dim, out_features=512),
        )
        self.linear2 = set_layer(
            layer_dict=layer_dict,
            name="linear2",
            alternative=nn.Linear(in_features=512, out_features=264),
        )
        self.linear3 = set_layer(
            layer_dict=layer_dict,
            name="linear3",
            alternative=nn.Linear(in_features=264, out_features=64),
        )
        self.linear4 = set_layer(
            layer_dict=layer_dict,
            name="linear4",
            alternative=nn.Linear(in_features=64, out_features=1),
        )

        self.activation = nn.LeakyReLU(negative_slope=0.01)

        # self.layers only fixes the order the sub-modules are applied in forward();
        # the modules themselves are registered via the attribute assignments above.
        if not layer_norm:
            self.layers = [
                self.linear1,
                self.activation,
                self.linear2,
                self.activation,
                self.linear3,
                self.activation,
                self.linear4,
            ]
        else:
            self.layernorm1 = nn.LayerNorm(normalized_shape=(1, self.linear1.out_features))
            self.layernorm2 = nn.LayerNorm(normalized_shape=(1, self.linear2.out_features))
            self.layernorm3 = nn.LayerNorm(normalized_shape=(1, self.linear3.out_features))
            self.layers = [
                self.linear1,
                self.layernorm1,
                self.activation,
                self.linear2,
                self.layernorm2,
                self.activation,
                self.linear3,
                self.layernorm3,
                self.activation,
                self.linear4,
            ]

        for layer in self.layers:
            self.initialize_weights(layer)

        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=start_lr
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)

        self.save_hyperparameters(ignore=["model"])
        self.MSE = nn.MSELoss()
        self.R2 = R2Score()

    def initialize_weights(self, module):
        """Xavier-initialize linear layers and set their biases to a small constant."""
        if isinstance(module, nn.Linear):
            logger.debug("linear weights initialized")
            torch.nn.init.xavier_uniform_(module.weight)
            module.bias.data.fill_(0.01)

    def forward(self, x: torch.Tensor):
        # The LayerNorm layers normalize over (1, features), so ensure a
        # (batch, 1, features) shape before applying the layers.
        if x.dim() == 2:
            x = x.unsqueeze(dim=1)

        for layer in self.layers:
            x = layer(x)

        # Collapse to a 1-D tensor of predictions; keep at least one dimension
        # so a single-sample batch still yields a tensor of shape (1,).
        x = x.squeeze()
        if x.dim() == 0:
            x = x.unsqueeze(dim=0)

        return x.to(torch.float32)

    def training_step(self, batch):
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        pred = loss_and_metrics["pred"]
        act = loss_and_metrics["act"]
        loss = loss_and_metrics["loss"]

        self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)

        return {"loss": loss, "pred": pred, "act": act}

    def configure_optimizers(self):
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        logger.debug(scheduler)
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def validation_step(self, batch, batch_idx):
        """Used for logging validation metrics."""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]

        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def training_epoch_end(self, training_step_outputs):
        training_step_outputs = list(training_step_outputs)
        # Drop the final batch, which may be smaller than the others and would break torch.stack.
        training_step_outputs.pop()

        output_dict = {k: [dic[k] for dic in training_step_outputs] for k in training_step_outputs[0]}
        pred = torch.stack(output_dict["pred"])
        act = torch.stack(output_dict["act"])

        # R2 = 1 - RSS / TSS, with RSS the mean squared residual and
        # TSS the (biased) variance of the targets.
        residual = torch.sub(pred, act)
        residual_sq = torch.square(residual)
        TSS = float(torch.var(act, unbiased=False))
        RSS = float(torch.mean(residual_sq))
        R2 = 1 - RSS / TSS

        self.log("train_R2", R2, prog_bar=True, logger=True)

    def _get_loss(self, batch, get_metrics: bool = False):
        """Convenience function since train/valid/test steps are similar."""
        pred = self.forward(x=batch["embedding"]).to(torch.float32)

        act, loss = None, None
        if "ctr" in batch.keys():
            act = batch["ctr"].to(torch.float32)
            loss = self.MSE(pred, act).to(torch.float32)

        return {"loss": loss, "pred": pred, "act": act}
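

# Usage sketch (illustrative only): a minimal smoke test, assuming set_layer
# falls back to the provided `alternative` module when layer_dict is empty.
# Random tensors stand in for the real embedding/CTR batches produced by the
# project's dataloaders; all shapes below are hypothetical.
if __name__ == "__main__":
    model = DecoderPL(input_dim=774, layer_norm=True)
    batch = {
        "embedding": torch.randn(8, 774),  # (batch_size, input_dim)
        "ctr": torch.rand(8),              # one scalar target per sample
    }
    out = model._get_loss(batch)
    print("pred shape:", tuple(out["pred"].shape), "loss:", float(out["loss"]))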