Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import time | |
| from typing import Any | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from pytorch_lightning import Trainer, LightningModule, LightningDataModule | |
| from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS | |
| from torch.utils.data import DataLoader, Dataset, IterableDataset | |
| from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall | |
| from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments | |
| from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions | |
| import torch | |
| from torch import nn | |
| from datasets import load_dataset | |
# Root logger used by every component in this file.
# Change the level to logging.DEBUG to print more logs.
logging.basicConfig(level=logging.INFO)
timber = logging.getLogger()

# Regularization selector codes: bit 0 enables L1, bit 1 enables L2.
NO_REGULARIZATION = 0
L1_REGULARIZATION_CODE = 1
L2_REGULARIZATION_CODE = 2
L1_AND_L2_REGULARIZATION_CODE = L1_REGULARIZATION_CODE | L2_REGULARIZATION_CODE  # == 3

# ANSI foreground colour escape sequences (SGR codes 30..37), used to tint log lines.
black, red, green, yellow, blue, magenta, cyan, white = (
    f"\u001b[{sgr_code}m" for sgr_code in range(30, 38)
)

FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"

# Hugging Face Hub id of the pretrained DNABERT (6-mer) checkpoint.
DNA_BERT_6 = "zhihan1996/DNA_bert_6"
class CommonAttentionLayer(nn.Module):
    """Softmax-weighted pooling along dimension 1 of the input.

    A single linear unit scores every feature vector, the scores are
    normalised with a softmax over dimension 1, and the input is summed
    along dimension 1 using those weights.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        """Create the scoring layer.

        Args:
            hidden_size: size of the last dimension of the tensors passed to
                ``forward``.
            *args, **kwargs: forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.attention_linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """Pool ``hidden_states`` along dimension 1.

        Returns:
            A ``(context_vector, attn_weights)`` tuple: the weighted sum over
            dimension 1 and the softmax weights that produced it.
        """
        # One scalar score per feature vector (last dimension becomes 1).
        raw_scores = self.attention_linear(hidden_states)
        # Normalise the scores across dimension 1.
        attn_weights = raw_scores.softmax(dim=1)
        # Broadcast the weights over the features and collapse dimension 1.
        context_vector = (attn_weights * hidden_states).sum(dim=1)
        return context_vector, attn_weights
class DNABert6MqtlClassifier(nn.Module, PyTorchModelHubMixin):
    """DNABERT-6 encoder followed by attention pooling and a linear head.

    The model runs BERT over every entry along dimension 0 of the inputs,
    mean-pools each BERT output over its dimension 1, concatenates the pooled
    embeddings, passes them through ``CommonAttentionLayer`` and finally
    through a linear classifier that emits one raw score per class.
    """

    def __init__(self,
                 bert_model=None,
                 hidden_size=768,
                 num_classes=1,
                 *args,
                 classifier_input_size=8,
                 **kwargs):
        """Build the classifier.

        Args:
            bert_model: encoder to use. When ``None`` (the default) the
                pretrained ``DNA_BERT_6`` checkpoint is loaded here, at
                construction time, rather than once when the class is defined.
            hidden_size: feature size handed to the attention layer.
            num_classes: number of outputs of the linear head.
            *args, **kwargs: forwarded to the parent constructors.
            classifier_input_size: keyword-only; number of input features of
                the linear head. It must match the length of the context
                vector produced in ``forward``.
        """
        super().__init__(*args, **kwargs)
        self.model_name = "DNABert6MqtlClassifier"
        if bert_model is None:
            # Loaded lazily so that importing this module (or passing a custom
            # encoder) does not trigger a checkpoint download, and so that
            # separately constructed classifiers do not share one encoder.
            bert_model = BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
        self.bert_model = bert_model
        self.attention = CommonAttentionLayer(hidden_size)  # Optional if you want to use attention
        # The default of 8 was chosen empirically to avoid a mat-mul shape error.
        self.classifier = nn.Linear(classifier_input_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the raw (un-activated) classifier output.

        Args:
            input_ids, attention_mask: tensors indexed along dimension 0; each
                slice is handed to BERT as one batch.
                Assumes shape (batch, num_chunks, seq_len) as produced by a
                DataLoader over PagingMQTLDnaBertDataset - TODO confirm.
            token_type_ids: same layout, or ``None``.
        """
        # Run BERT on each slice along dimension 0 and collect the embeddings.
        embeddings = []
        for i in range(input_ids.size(0)):
            outputs = self.bert_model(
                input_ids=input_ids[i],
                attention_mask=attention_mask[i],
                token_type_ids=token_type_ids[i] if token_type_ids is not None else None
            )
            last_hidden_state = outputs.last_hidden_state
            # Mean over dimension 1 of the BERT output.
            embedding = last_hidden_state.mean(dim=1)
            embeddings.append(embedding)
        # Concatenate the pooled embeddings of all slices along dimension 1.
        concatenated_embedding = torch.cat(embeddings, dim=1)
        # NOTE(review): when concatenated_embedding is 2-D, the attention layer's
        # softmax over dim=1 acts on a size-1 axis, so every weight is 1 and the
        # context vector is just the row-wise sum of the features. The linear
        # head then needs classifier_input_size == number of rows. Presumably this
        # only lines up for batch size 1 with 8 chunks - confirm before changing
        # batch size or window length.
        context_vector, _ = self.attention(concatenated_embedding)
        # Classify: raw scores, no sigmoid applied here.
        y_probability = self.classifier(context_vector)
        return y_probability  # float / double
class TorchMetrics:
    """Bundle of the five binary-classification metrics tracked per phase.

    Holds accuracy, AUROC, F1, precision and recall metric objects and
    provides helpers to update, compute-and-log, and reset them together.
    """

    def __init__(self):
        self.binary_accuracy = BinaryAccuracy()
        self.binary_auc = BinaryAUROC()
        self.binary_f1_score = BinaryF1Score()
        self.binary_precision = BinaryPrecision()
        self.binary_recall = BinaryRecall()

    def _metrics(self):
        """Return the metric objects in their canonical order."""
        return (
            self.binary_accuracy,
            self.binary_auc,
            self.binary_f1_score,
            self.binary_precision,
            self.binary_recall,
        )

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
        """Feed one batch of predictions and targets to every metric."""
        # The metric API takes the predictions through the `input` keyword
        # (older versions of the library called it `preds`).
        for metric in self._metrics():
            metric.update(input=batch_predicted_labels, target=batch_actual_labels)

    def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
        """Compute every metric, write a coloured summary line, and log each value.

        Args:
            log: callable invoked as ``log(name, value)`` once per metric.
            log_prefix: prefix for the metric names, e.g. ``"train"``.
            log_color: ANSI escape sequence prepended to the summary line.
        """
        b_accuracy, b_auc, b_f1_score, b_precision, b_recall = (
            metric.compute() for metric in self._metrics()
        )
        summary = f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}"
        timber.info(log_color + summary)
        named_values = (
            ("accuracy", b_accuracy),
            ("auc", b_auc),
            ("f1_score", b_f1_score),
            ("precision", b_precision),
            ("recall", b_recall),
        )
        for suffix, value in named_values:
            log(f"{log_prefix}_{suffix}", value)

    def reset_on_epoch_end(self):
        """Clear the accumulated state of every metric."""
        for metric in self._metrics():
            metric.reset()
class MQtlBertClassifierLightningModule(LightningModule):
    """Lightning wrapper that trains and evaluates a binary mQTL classifier.

    Handles the loss, optional L1 / L2 regularization, and per-phase metric
    tracking (train / validate / test) for the wrapped ``classifier``.
    """

    def __init__(self,
                 classifier: nn.Module,
                 criterion=None,
                 regularization: int = L2_REGULARIZATION_CODE,
                 # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
                 l1_lambda=0.0001,
                 l2_wright_decay=0.0001,
                 *args: Any,
                 **kwargs: Any):
        """Store the model, loss and regularization settings.

        Args:
            classifier: module called as
                ``classifier(input_ids, attention_mask, token_type_ids)``.
            criterion: loss module; ``None`` (default) creates a fresh
                ``nn.BCEWithLogitsLoss`` for this instance.
            regularization: one of the ``*_REGULARIZATION*`` codes.
            l1_lambda: weight of the L1 penalty added to the training loss.
            l2_wright_decay: optimizer weight decay used for L2
                regularization (the parameter name keeps its historical
                spelling so existing keyword callers keep working).
            *args, **kwargs: forwarded to ``LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.classifier = classifier
        self.criterion = nn.BCEWithLogitsLoss() if criterion is None else criterion
        self.train_metrics = TorchMetrics()
        self.validate_metrics = TorchMetrics()
        self.test_metrics = TorchMetrics()
        self.regularization = regularization
        self.l1_lambda = l1_lambda
        self.l2_weight_decay = l2_wright_decay

    def forward(self, input_ids, attention_mask, token_type_ids, *args: Any, **kwargs: Any) -> Any:
        """Delegate to the wrapped classifier and return its raw output."""
        return self.classifier(input_ids, attention_mask, token_type_ids)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return Adam; weight decay is enabled only for the L2 codes."""
        weight_decay = 0.0
        if self.regularization in (L2_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE):
            weight_decay = self.l2_weight_decay
        return torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=weight_decay)

    def _predict_class(self, model_output):
        """Turn the classifier output into hard 0/1 labels.

        With ``BCEWithLogitsLoss`` the classifier output is a logit, so it is
        mapped through a sigmoid before the 0.5 threshold; thresholding the raw
        logit at 0.5 would mislabel every probability between 0.5 and ~0.62.
        For any other criterion the output is thresholded as-is.
        """
        if isinstance(self.criterion, nn.BCEWithLogitsLoss):
            model_output = torch.sigmoid(model_output)
        return (model_output >= 0.5).int()

    def _shared_step(self, batch):
        """Run the forward pass once and return ``(loss, predicted_class, y)``."""
        input_ids, attention_mask, token_type_ids, y = batch
        model_output = self.forward(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(model_output, y.float())
        return loss, self._predict_class(model_output), y

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the (optionally L1-regularized) training loss and log metrics."""
        loss, predicted_class, y = self._shared_step(batch)
        if self.regularization in (L1_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE):
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            # Out-of-place add keeps the criterion's output tensor untouched.
            loss = loss + self.l1_lambda * l1_norm
        self.log("train_loss", loss)
        # Running metrics are updated and logged after every step.
        self.train_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
        return loss

    def on_train_epoch_end(self) -> None:
        """Log the epoch-level training metrics, then reset them."""
        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
        self.train_metrics.reset_on_epoch_end()

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the validation loss and log running validation metrics."""
        loss, predicted_class, y = self._shared_step(batch)
        self.log("valid_loss", loss)
        self.validate_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
        return loss

    def on_validation_epoch_end(self) -> None:
        """Log the epoch-level validation metrics, then reset them."""
        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
        self.validate_metrics.reset_on_epoch_end()
        return None

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the test loss and log running test metrics."""
        loss, predicted_class, y = self._shared_step(batch)
        self.log("test_loss", loss)
        self.test_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
        return loss

    def on_test_epoch_end(self) -> None:
        """Log the epoch-level test metrics, then reset them."""
        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
        self.test_metrics.reset_on_epoch_end()
        return None
class PagingMQTLDnaBertDataset(IterableDataset):
    """Iterable dataset that tokenizes each DNA sequence in fixed-size chunks.

    Every row of the wrapped dataset must provide a ``'sequence'`` and a
    ``'label'`` entry. The sequence is cut into pieces of ``max_length``
    characters and each piece is tokenized separately.
    """

    def __init__(self, dataset, tokenizer, max_length=512):
        """
        Args:
            dataset: iterable of rows with ``'sequence'`` and ``'label'`` keys.
            tokenizer: callable invoked as ``tokenizer(chunk, truncation=True,
                padding='max_length', max_length=..., return_tensors='pt')``.
            max_length: chunk size in characters, also passed to the tokenizer
                as its ``max_length``.
        """
        self.dataset = dataset
        self.bert_tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        """Yield the preprocessed form of every row, skipping unusable rows."""
        for row in self.dataset:
            processed = self.preprocess(row)
            if processed is not None:
                yield processed

    def preprocess(self, row):
        """Tokenize one row.

        Returns:
            ``(input_ids, attention_mask, token_type_ids, label)`` where the
            first three are the per-chunk tensors stacked along a new leading
            dimension and ``label`` is an int tensor, or ``None`` when the
            sequence is empty (there is nothing to stack in that case).
        """
        sequence = row['sequence']
        label = row['label']
        if not sequence:
            return None

        input_ids, attention_masks, token_type_ids = [], [], []
        # Walk the sequence in windows of max_length characters.
        for start in range(0, len(sequence), self.max_length):
            encoded_chunk = self.bert_tokenizer(
                sequence[start:start + self.max_length],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            chunk_ids = encoded_chunk['input_ids'].squeeze(0)
            input_ids.append(chunk_ids)
            attention_masks.append(encoded_chunk['attention_mask'].squeeze(0))
            if 'token_type_ids' in encoded_chunk:
                token_type_ids.append(encoded_chunk['token_type_ids'].squeeze(0))
            else:
                # Tokenizer gave no segment ids: use all zeros (single segment)
                # so the returned tuple always has four tensors.
                token_type_ids.append(torch.zeros_like(chunk_ids))

        return (
            torch.stack(input_ids),
            torch.stack(attention_masks),
            torch.stack(token_type_ids),
            torch.tensor(label).int(),
        )
class DNABERTDataModule(LightningDataModule):
    """Data module that serves the mQTL datasets tokenized for DNABERT.

    Loads the CSV splits from a local checkout when ``is_local`` is true and
    from the Hugging Face Hub otherwise, then wraps the ``tiny_*`` splits in
    ``PagingMQTLDnaBertDataset`` instances.
    """

    def __init__(self, model_name=DNA_BERT_6, batch_size=8, WINDOW=-1, is_local=False):
        """
        Args:
            model_name: checkpoint id used to load the ``BertTokenizer``.
            batch_size: batch size of every DataLoader.
            WINDOW: stored as ``self.window``; not read elsewhere in this class.
            is_local: load the CSV files from the local path instead of the Hub.
        """
        super().__init__()
        self.tokenized_dataset = None
        self.dataset = None
        self.train_dataset: PagingMQTLDnaBertDataset = None
        self.validate_dataset: PagingMQTLDnaBertDataset = None
        self.test_dataset: PagingMQTLDnaBertDataset = None
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
        self.batch_size = batch_size
        self.is_local = is_local
        self.window = WINDOW

    def prepare_data(self):
        """Load the dataset splits into ``self.dataset``."""
        if not self.is_local:
            self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets")
            return

        input_dir = "/home/soumic/Codes/mqtl-classification/src/inputdata"
        split_names = ("train", "validate", "test")
        # Small (200), medium (1000) and large (4000) binned samples.
        data_files = {
            f"{split}_binned_{size}": f"{input_dir}/dataset_{size}_{split}_binned.csv"
            for size in (200, 1000, 4000)
            for split in split_names
        }
        # The "tiny_*" splits currently point at the medium-sized 4000 files.
        data_files.update({
            f"tiny_{split}": f"{input_dir}/medium_dataset_4000_{split}_binned.csv"
            for split in split_names
        })
        self.dataset = load_dataset("csv", data_files=data_files, streaming=True)

    def setup(self, stage=None):
        """Wrap the ``tiny_*`` splits; requires ``prepare_data`` to have run."""
        # Train on the training split; it was previously built from
        # 'tiny_test', which trained the model on the evaluation data.
        self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_train'], self.tokenizer)
        self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_validate'], self.tokenizer)
        self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer)

    def _loader(self, dataset):
        """Build a DataLoader with this module's batch size and one worker."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=1)

    def train_dataloader(self):
        return self._loader(self.train_dataset)

    def val_dataloader(self):
        return self._loader(self.validate_dataset)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self._loader(self.test_dataset)
def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4,
               is_binned=True, is_debug=False, max_epochs=10, regularization_code=L2_REGULARIZATION_CODE):
    """Train, test, save and publish the classifier.

    Args:
        classifier_model: model to train; must also provide
            ``save_pretrained`` and ``push_to_hub``.
        model_save_path: file that receives the Lightning module's state dict.
        criterion: loss module handed to the Lightning module.
        WINDOW: window size, used for the data module and in the local
            directory, Hub repository name and commit message.
        batch_size: batch size of the data loaders.
        is_binned: accepted for backward compatibility; currently unused.
        is_debug: accepted for backward compatibility; currently unused.
        max_epochs: number of training epochs.
        regularization_code: one of the ``*_REGULARIZATION*`` codes.
    """
    # The presence of the local CSV decides between local files and the Hub.
    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")

    model_local_directory = f"my-awesome-model-{WINDOW}"
    model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"

    data_module = DNABERTDataModule(batch_size=batch_size, WINDOW=WINDOW, is_local=is_my_laptop)
    classifier_module = MQtlBertClassifierLightningModule(
        classifier=classifier_model,
        regularization=regularization_code, criterion=criterion)

    # Prepare data using the DataModule
    data_module.prepare_data()
    data_module.setup()

    trainer = Trainer(max_epochs=max_epochs, precision="32")
    trainer.fit(model=classifier_module, datamodule=data_module)
    trainer.test(model=classifier_module, datamodule=data_module)

    # Persist the weights: raw state dict first, then the Hub-style directory.
    torch.save(classifier_module.state_dict(), model_save_path)
    classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)

    # Push to the hub; the commit message records where the run happened.
    origin = "zephyrus" if is_my_laptop else "huggingface space"
    commit_message = f":tada: Push model for window size {WINDOW} from {origin}"
    classifier_model.push_to_hub(
        repo_id=model_remote_repository,
        # The `subfolder` argument did not work here, so one repo per window is used.
        commit_message=commit_message,
    )
def main():
    """Train the DNABERT-6 mQTL classifier once and print the wall-clock runtime."""
    start_time = time.time()

    pytorch_model = DNABert6MqtlClassifier()
    start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
               criterion=nn.BCEWithLogitsLoss(), WINDOW=4000, batch_size=1,  # 12, # max 14 on my laptop...
               max_epochs=1, regularization_code=L2_REGULARIZATION_CODE)

    # Elapsed time from start to finish of the whole run.
    duration = time.time() - start_time
    print(f"Runtime: {duration:.2f} seconds")


if __name__ == "__main__":
    main()