| import lightning.pytorch as pl |
| from transformers import ( |
| AdamW, |
| AutoModel, |
| AutoConfig, |
| get_linear_schedule_with_warmup, |
| ) |
| from transformers.models.bert.modeling_bert import BertLMPredictionHead |
| import torch |
| from torch import nn |
| from loss import PCCL |
| import config |
|
|
|
|
class CL_model(pl.LightningModule):
    """Bio_ClinicalBERT encoder trained with a joint objective:

    ``loss = (1 - mlm_weight) * PCCL + mlm_weight * CrossEntropy(MLM)``

    The contrastive representation is the concatenation of the pooled
    [CLS] embedding and the mean of the hidden states at the masked
    positions of each example.

    Args:
        n_batches: number of training batches per epoch (drives the LR
            schedule length).
        n_epochs: total number of training epochs (drives the LR
            schedule length).
        lr: learning rate for AdamW.
        mlm_weight: weight of the MLM loss in the combined objective;
            the contrastive loss gets ``1 - mlm_weight``.
        **kwargs: optional ``unfreeze=<float>`` — fraction (0.0–1.0) of
            trailing BERT parameter tensors left trainable; the leading
            rest are frozen.
    """

    def __init__(
        self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
    ):
        super().__init__()

        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr
        self.mlm_weight = mlm_weight
        self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )

        # Optionally freeze the leading fraction of BERT parameters.
        # NOTE(review): "layer" here means a parameter *tensor* (one entry
        # of named_parameters()), not a transformer block.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        for key, value in kwargs.items():
            if key == "unfreeze" and isinstance(value, float):
                # raise (not assert): asserts are stripped under `python -O`.
                if not 0.0 <= value <= 1.0:
                    raise ValueError(
                        "value must be a ratio between 0.0 and 1.0"
                    )
                self.ratio_unfreeze_layer = value
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
            for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
                param.requires_grad = False

        # MLM head sized from the same config (hidden size, vocab size).
        self.lm_head = BertLMPredictionHead(self.config)

        print("Model Initialized!")

        self.cl_loss = PCCL()
        self.mlm_loss = nn.CrossEntropyLoss()

        # Running-average accumulators. Train and val keep SEPARATE batch
        # counters: Lightning interleaves validation with training (and
        # runs a sanity-check val pass first), so the original single
        # shared counter corrupted both running averages.
        self.train_num_batches, self.val_num_batches = 0, 0
        self.train_loss, self.val_loss = 0, 0
        self.train_loss_cl, self.val_loss_cl = 0, 0
        self.train_loss_mlm, self.val_loss_mlm = 0, 0
        self.training_step_outputs, self.validation_step_outputs = [], []

    def forward(self, input_ids, attention_mask, masked_indices, eval=False):
        """Encode a batch and build the contrastive / MLM outputs.

        Args:
            input_ids: token-id tensor, (batch, seq_len).
            attention_mask: attention mask, (batch, seq_len).
            masked_indices: per-example index collection of masked
                positions (one entry per batch element).
            eval: if True, return ``(cls_tokens, mask_tokens,
                cls_concat_mask)`` for evaluation instead of the training
                outputs. (Parameter name shadows the builtin but is kept
                for caller compatibility.)

        Returns:
            ``(cls_concat_mask, mlm_pred)`` where ``mlm_pred`` is
            flattened to (batch * seq_len, vocab_size) for
            CrossEntropyLoss.
        """
        embs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_tokens = embs.pooler_output
        # Mean of the hidden states at each example's masked positions.
        mask_tokens = torch.stack(
            [
                torch.mean(embs.last_hidden_state[idx][value], dim=0)
                for idx, value in enumerate(masked_indices)
            ]
        )
        cls_concat_mask = torch.cat((cls_tokens, mask_tokens), dim=1)
        if eval is True:
            return cls_tokens, mask_tokens, cls_concat_mask

        mlm_pred = self.lm_head(embs.last_hidden_state)
        mlm_pred = mlm_pred.view(-1, self.config.vocab_size)
        return cls_concat_mask, mlm_pred

    def _compute_losses(self, batch):
        """Shared train/val forward pass; returns (loss, loss_cl, loss_mlm)."""
        cls_concat_mask, mlm_pred = self(
            batch["input_ids"], batch["attention_mask"], batch["masked_indices"]
        )
        loss_cl = self.cl_loss(cls_concat_mask, batch["tags"], batch["scores"])
        loss_mlm = self.mlm_loss(mlm_pred, batch["mlm_labels"].reshape(-1))
        loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm
        return loss, loss_cl, loss_mlm

    def training_step(self, batch, batch_idx):
        """One optimization step; logs per-step and windowed-average losses."""
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        # Detach before storing/accumulating: keeping live tensors would
        # retain every step's autograd graph for the whole epoch (leak).
        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)

        self.train_num_batches += 1
        self.train_loss_cl += logs["loss_cl"]
        self.train_loss_mlm += logs["loss_mlm"]
        self.train_loss += logs["loss"]

        if self.train_num_batches % config.log_every_n_steps == 0:
            self.log(
                "train_avg_cl_loss",
                self.train_loss_cl / self.train_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "train_avg_mlm_loss",
                self.train_loss_mlm / self.train_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "train_avg_loss",
                self.train_loss / self.train_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.train_loss_cl = 0
            self.train_loss_mlm = 0
            self.train_loss = 0
            self.train_num_batches = 0

        return loss

    @staticmethod
    def _epoch_mean(outputs, key):
        """Mean of a stored scalar over an epoch, as a CPU numpy scalar."""
        return torch.stack([x[key] for x in outputs]).mean().detach().cpu().numpy()

    def on_train_epoch_end(self):
        """Log + print epoch-mean train losses, then clear the step buffer."""
        e_t_avg_loss = self._epoch_mean(self.training_step_outputs, "loss")
        e_t_avg_loss_cl = self._epoch_mean(self.training_step_outputs, "loss_cl")
        e_t_avg_loss_mlm = self._epoch_mean(self.training_step_outputs, "loss_mlm")
        self.log(
            "avg_loss_train_epoch",
            e_t_avg_loss.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_cl_train_epoch",
            e_t_avg_loss_cl.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_mlm_train_epoch",
            e_t_avg_loss_mlm.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        print(
            "train_epoch:",
            self.current_epoch,
            "avg_loss:",
            e_t_avg_loss,
            "avg_cl_loss:",
            e_t_avg_loss_cl,
            "avg_mlm_loss:",
            e_t_avg_loss_mlm,
        )
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """One validation step; mirrors training_step without backprop."""
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.validation_step_outputs.append(logs)
        self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)

        self.val_num_batches += 1
        self.val_loss_cl += logs["loss_cl"]
        self.val_loss_mlm += logs["loss_mlm"]
        self.val_loss += logs["loss"]

        if self.val_num_batches % config.log_every_n_steps == 0:
            self.log(
                "val_avg_cl_loss",
                self.val_loss_cl / self.val_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "val_avg_mlm_loss",
                self.val_loss_mlm / self.val_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "val_avg_loss",
                self.val_loss / self.val_num_batches,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.val_loss_cl = 0
            self.val_loss_mlm = 0
            self.val_loss = 0
            self.val_num_batches = 0

        return loss

    def on_validation_epoch_end(self):
        """Log + print epoch-mean val losses, then clear the step buffer."""
        e_v_avg_loss = self._epoch_mean(self.validation_step_outputs, "loss")
        e_v_avg_loss_cl = self._epoch_mean(self.validation_step_outputs, "loss_cl")
        e_v_avg_loss_mlm = self._epoch_mean(self.validation_step_outputs, "loss_mlm")
        self.log(
            "avg_loss_val_epoch",
            e_v_avg_loss.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_cl_val_epoch",
            e_v_avg_loss_cl.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_mlm_val_epoch",
            e_v_avg_loss_mlm.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        print(
            "val_epoch:",
            self.current_epoch,
            "avg_loss:",
            e_v_avg_loss,
            "avg_cl_loss:",
            e_v_avg_loss_cl,
            "avg_mlm_loss:",
            e_v_avg_loss_mlm,
        )
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW over trainable params with per-step linear warmup/decay.

        Returns:
            ``([optimizer], [lr_scheduler_config])`` with the scheduler
            advancing every optimizer step.
        """
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        # torch.optim.AdamW: transformers.AdamW is deprecated/removed in
        # current transformers releases.
        optimizer = torch.optim.AdamW(self.trainable_params, lr=self.lr)

        warmup_steps = self.n_batches // 3
        # num_training_steps is the TOTAL step count INCLUDING warmup;
        # subtracting warmup (as before) ended the decay one warmup-length
        # of steps too early.
        total_steps = self.n_batches * self.n_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
        # interval="step": warmup is measured in batches, so the scheduler
        # must step per batch, not per epoch (Lightning's default).
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
|