# Standard library
import datetime
import os
import pickle
import random
import subprocess
import sys
import time
from typing import List, Optional, Tuple, Union

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytz
import seaborn as sns
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm.notebook import tqdm

from datasets import Dataset, load_from_disk
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertModel,
    BertPreTrainedModel,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import (
    BertLMPredictionHead,
    BertOnlyMLMHead,
    BertPredictionHeadTransform,
)

# Geneformer
from geneformer import GeneformerPretrainer
from geneformer.pretrainer import token_dictionary

class CustomBertForMaskedLM(BertPreTrainedModel):
    """BERT masked-LM variant with a weight-tied decoder, probability post-processing
    that pins known (unmasked) genes, and a similarity penalty added to the MLM loss."""

    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.transform = BertPredictionHeadTransform(config)

        # Output projection to vocabulary logits; the weight is tied to the input
        # word embeddings, so only the separate bias is learned here.
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie the output decoder weights to the input word-embedding weights."""
        self.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Down-weight genes that are already known or likely predicted at earlier positions.

        Each gene's probability at position i is multiplied by the cumulative product of
        (1 - p(gene)) over all earlier positions, with the known (unmasked) genes injected
        as certainties at the start of the sequence, then renormalized.
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()

        # Positions whose label is -100 are not masked, i.e. the gene is known from the input.
        non_mask = labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)
        known_gene_indices = input_ids[non_mask]

        # Multi-hot (batch, 1, vocab) slot marking every known gene in each sequence.
        known = torch.zeros((batch_size, 1, vocab_size), device=device)
        known[non_mask_indices[0], 0, known_gene_indices] = 1.0

        # Prepend the known-gene slot and shift the predicted distributions right by one position.
        probs_shifted = torch.cat((known, probs[:, :-1, :]), dim=1)
        inv_probs_shifted = 1 - probs_shifted

        # Cumulative probability that each gene has not appeared at any earlier position.
        cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
        modified_probs = probs * cumprod_inv_probs

        # Renormalize each position to a valid distribution (clamp avoids division by zero).
        norm_const = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
        return modified_probs / norm_const

    def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Force the distribution at every unmasked position to be one-hot at the known input gene."""
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Align labels with the input length before locating the unmasked (known) positions.
        truncated_labels = labels[:, :input_seq_length]
        non_mask = truncated_labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)

        # `keep` zeroes out the model's distribution at known positions;
        # `onehot` re-inserts probability 1.0 at the known gene id.
        keep = torch.ones((batch_size, seq_length, vocab_size), device=device)
        onehot = torch.zeros((batch_size, seq_length, vocab_size), device=device)

        known_gene_indices = input_ids[non_mask]
        keep[non_mask_indices[0], non_mask_indices[1], :] = 0.0
        onehot[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0

        modified_probs = probs * keep + onehot

        # Renormalize (clamp avoids division by zero).
        return modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)

    def compute_similarity_on_probs(self, probs: Tensor) -> Tensor:
        """Average pairwise cosine similarity between all positions in each sequence.

        Args:
            probs (torch.Tensor): Probability tensor of shape (batch_size, seq_length, vocab_size).

        Returns:
            torch.Tensor: Scalar average similarity, used as a regularization term in the loss.
        """
        batch_size, seq_length, vocab_size = probs.size()

        # Normalize each position's distribution, then take all pairwise dot products
        # within each sequence (i.e. cosine similarities).
        probs_norm = F.normalize(probs, dim=-1)
        similarities = torch.einsum("biv,bjv->bij", probs_norm, probs_norm)

        # Keep only the strict upper triangle so each pair of positions is counted once.
        mask_sim = torch.triu(torch.ones(seq_length, seq_length, device=probs.device), diagonal=1)
        valid_similarities = similarities * mask_sim

        total_similarity = valid_similarities.sum()
        total_comparisons = mask_sim.sum().item() * batch_size

        return total_similarity / total_comparisons

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        token_type_ids: Tensor | None = None,
        position_ids: Tensor | None = None,
        head_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        encoder_hidden_states: Tensor | None = None,
        encoder_attention_mask: Tensor | None = None,
        labels: Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the encoder output to vocabulary logits via the tied decoder.
        hidden_states = outputs[0]
        hidden_transform = self.transform(hidden_states)
        logits = self.decoder(hidden_transform) + self.bias

        probs = F.softmax(logits, dim=-1)

        # Post-process the distributions: pin known (unmasked) genes to probability 1,
        # apply the sequential down-weighting, then pin the known genes again.
        probs = self.assign_known_gene_probs(probs, input_ids, labels)
        convert_probs = self.probability_convert(probs, input_ids, labels)
        assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)

        masked_lm_loss = None
        if labels is not None:
            probs_flat = probs.view(-1, self.config.vocab_size)
            labels_flat = labels.view(-1)
            mask = (labels != -100).float().view(-1)

            # Negative log-likelihood of the true token, averaged over masked positions only.
            # The -100 padding labels are clamped to 0 so they index a valid column; those
            # terms are zeroed out by `mask` anyway.
            target_probs = probs_flat[torch.arange(len(labels_flat)), labels_flat.clamp(min=0)]
            masked_lm_loss = -torch.log(torch.clamp(target_probs, min=1e-18)) * mask
            masked_lm_loss = masked_lm_loss.sum() / mask.sum()

            # Penalize positions within a sequence that predict similar distributions.
            similarity_loss = self.compute_similarity_on_probs(assigned_probs)
            lambda_similarity = 200.0
            masked_lm_loss = masked_lm_loss + lambda_similarity * similarity_loss

        if not return_dict:
            output = (assigned_probs,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            # Note: `logits` carries the post-processed probabilities, not raw logits.
            logits=probs,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # A dummy PAD token is appended during generation, so the PAD id must be defined.
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
        )
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
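

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original model code). The config
# sizes, token ids, and masking pattern below are made-up toy values chosen
# only to exercise CustomBertForMaskedLM end to end; a real run would use the
# Geneformer token_dictionary vocabulary and tokenized expression data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=1,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=16,
        pad_token_id=0,
    )
    model = CustomBertForMaskedLM(toy_config).eval()

    # Two sequences of 6 tokens; token id 1 stands in for the <mask> token here.
    # Label -100 marks positions that are NOT masked; other labels hold the true gene id.
    input_ids = torch.tensor([[11, 12,  1, 14,  1, 16],
                              [21,  1, 23, 24, 25,  1]])
    labels = torch.tensor([[-100, -100, 13, -100, 15, -100],
                           [-100, 22, -100, -100, -100, 26]])

    with torch.no_grad():
        out = model(input_ids=input_ids, labels=labels)

    print("loss:", out.loss.item())           # masked-LM NLL + 200 * similarity penalty
    print("probs shape:", out.logits.shape)   # (2, 6, 100); post-processed probabilities
    # Unmasked positions are pinned to the known input gene, so out.logits[0, 0]
    # places essentially all of its mass on token 11.
    print("known-gene prob:", out.logits[0, 0, 11].item())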