# Scraped from a Hugging Face Space that was reporting "Runtime error".
# Pretrained encoder checkpoint used for the emotion tagger.
BERT_MODEL_NAME = 'bert-base-cased'

# One output unit per emotion; order must match the label columns of the
# training data and the per-class metric logging.
LABEL_COLUMNS = ['anger', 'joy', 'fear', 'surprise', 'sadness', 'neutral']
class EmotionTagger(pl.LightningModule):
    """Multi-label emotion classifier: a BERT encoder with a linear head.

    Produces one independent sigmoid probability per entry of
    ``LABEL_COLUMNS``; trained with element-wise binary cross-entropy.
    """

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        """
        Args:
            n_classes: number of output labels (size of the linear head).
            n_training_steps: total optimizer steps, used by the LR scheduler.
            n_warmup_steps: linear warmup steps for the LR scheduler.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        # NOTE(review): BCELoss on already-sigmoided outputs works but
        # nn.BCEWithLogitsLoss on raw logits is more numerically stable;
        # kept as-is to preserve existing behavior.
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder + head.

        Returns:
            (loss, probabilities) — ``loss`` is the BCE loss when ``labels``
            is given, otherwise the integer ``0``; ``probabilities`` is a
            per-class sigmoid tensor.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        # pooler_output: the [CLS]-derived pooled representation.
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        """One training batch; returns predictions/labels for epoch-end metrics."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        """One validation batch; logs the loss only."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """One test batch; logs the loss only."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        """Log per-class ROC-AUC over the whole training epoch.

        FIX(review): the original snippet contained only the per-class loop,
        with no enclosing method and with ``predictions``/``labels`` undefined
        — the ``def`` line and the code that gathers ``training_step`` outputs
        had been lost. Reconstructed here from the dict returned by
        ``training_step``; confirm against the original tutorial/source.
        """
        labels = []
        predictions = []
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(LABEL_COLUMNS):
            # NOTE(review): pytorch_lightning.metrics was removed in
            # Lightning >= 1.5; on a modern install this raises at runtime
            # (likely the Space's "Runtime error"). Migrate to
            # torchmetrics.functional.auroc(..., task="binary") when possible.
            class_roc_auc = pytorch_lightning.metrics.functional.auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """AdamW with a linear warmup-then-decay schedule, stepped per batch."""
        optimizer = AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.n_warmup_steps,
            num_training_steps=self.n_training_steps
        )
        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval='step'  # step the LR every optimizer step, not per epoch
            )
        )