Create src/model.py
src/model.py  ADDED  (+134, -0)
@@ -0,0 +1,134 @@
"""
Copy pasted model from https://www.kaggle.com/code/yasufuminakama/fb3-deberta-v3-base-baseline-train/notebook
"""
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoConfig, AutoModel, AutoTokenizer

from src.data_reader import load_train_test_df
from src.losses import MCRMSELoss
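# MCRMSELoss comes from src/losses.py, which is not part of this diff. Based on
# how it is used below, it is assumed to implement mean column-wise RMSE: the
# RMSE of each of the six target columns, averaged over the columns, with
# class_mcrmse() returning the per-column values instead of the mean.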


def num_train_samples():
    train_df, _ = load_train_test_df()
    return len(train_df)


class MeanPooling(nn.Module):
    # taking mean of last hidden state with mask

    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class BertLightningModel(pl.LightningModule):

    def __init__(self, config: dict):
        super(BertLightningModel, self).__init__()

        self.config = config

        huggingface_config = AutoConfig.from_pretrained(self.config['model_name'], output_hidden_states=True)
        huggingface_config.hidden_dropout = 0.
        huggingface_config.hidden_dropout_prob = 0.
        huggingface_config.attention_dropout = 0.
        huggingface_config.attention_probs_dropout_prob = 0.

        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = AutoModel.from_pretrained(self.config['model_name'], config=huggingface_config)

        self.pool = MeanPooling()

        self.fc = nn.Linear(in_features=1024, out_features=6)

        self.loss = MCRMSELoss()

        # freezing first 20 layers of DeBERTa from 24
        modules = [self.model.embeddings, self.model.encoder.layer[:self.config['num_frozen_layers']]]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

        self.class_metric = None
        self.best_metric = None

    def forward(self, inputs):
        outputs = self.model(**inputs)
        last_hidden_state = outputs.last_hidden_state

        bert_features = self.pool(last_hidden_state, inputs['attention_mask'])

        logits = self.fc(bert_features)

        return logits

    def training_step(self, batch, batch_idx):
        inputs = batch
        labels = inputs.pop("labels", None)
        logits = self(inputs)
        loss = self.loss(logits, labels)

        self.log('train/loss', loss)

        return {
            'loss': loss,
            'mc_rmse': loss
        }

    def training_epoch_end(self, outputs):
        mean_mc_rmse = sum(output['mc_rmse'].item() for output in outputs) / len(outputs)
        self.log("train/epoch_loss", mean_mc_rmse)

    def validation_step(self, batch, batch_idx):
        inputs = batch
        labels = inputs.pop("labels", None)
        logits = self(inputs)
        loss = self.loss(logits, labels)
        class_rmse = self.loss.class_mcrmse(logits, labels)

        self.log('val/loss', loss)

        return {
            'loss': loss,
            'mc_rmse': loss,
            'class_mc_rmse': class_rmse
        }

    def validation_epoch_end(self, outputs):
        mean_mc_rmse = sum(output['mc_rmse'].item() for output in outputs) / len(outputs)
        class_metrics = torch.stack([output['class_mc_rmse'] for output in outputs]).mean(0).tolist()
        class_metrics = [round(item, 4) for item in class_metrics]
        self.log('val/epoch_loss', mean_mc_rmse)

        if self.best_metric is None or mean_mc_rmse < self.best_metric:
            self.best_metric = mean_mc_rmse
            self.class_metric = class_metrics

    def configure_optimizers(self):
        # weight_decay = self.config['weight_decay']
        lr = self.config['lr']

        # In original solution authors add weight decaying to some parameters

        optimizer = AdamW(self.parameters(), lr=lr, weight_decay=0.0, eps=1e-6, betas=(0.9, 0.999))

        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.config['max_epochs'],
        )
        return [optimizer], [scheduler]

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        inputs = batch
        inputs.pop("labels", None)
        logits = self(inputs)

        return logits
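For reference, a minimal usage sketch of the module above follows. It relies only on the config keys the file itself reads ('model_name', 'num_frozen_layers', 'lr', 'max_epochs'); the concrete values, including the deberta-v3-large checkpoint suggested by the hard-coded 1024-dim pooling output and the "first 20 layers of 24" comment, are illustrative assumptions rather than values taken from this repo's configuration.

import torch

from src.model import BertLightningModel

# Illustrative config; the checkpoint name and hyperparameter values are assumptions.
config = {
    'model_name': 'microsoft/deberta-v3-large',
    'num_frozen_layers': 20,
    'lr': 2e-5,
    'max_epochs': 5,
}

model = BertLightningModel(config)

# Tokenize one essay and run a forward pass; the head outputs six regression scores.
inputs = model.tokenizer(
    ["An example essay to score."],
    padding=True,
    truncation=True,
    return_tensors='pt',
)
with torch.no_grad():
    scores = model(inputs)  # tensor of shape (1, 6)

In training, the same forward pass is driven by a pytorch_lightning Trainer, which consumes the training_step/validation_step hooks and the optimizer returned by configure_optimizers.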