Spaces (runtime error)

Soumic committed · Commit 7bc0339 · Parent(s): 09cc1f0

:rocket: Code is stable

app.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from typing import Any
 
@@ -6,13 +7,14 @@ import numpy as np
 import pandas as pd
 from pytorch_lightning import Trainer, LightningModule, LightningDataModule
 from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, Dataset
 from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
 from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
 from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 import torch
 from torch import nn
-from datasets import load_dataset
+from datasets import load_dataset, IterableDataset
 from huggingface_hub import PyTorchModelHubMixin
 
 timber = logging.getLogger()
@@ -136,19 +138,25 @@ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
   return output
 
 
-class MQTLDataset(Dataset):
-  def __init__(self, …
-    self.dataset = …
+class MQTLDataset(IterableDataset):
+  def __init__(self, m_dataset, seq_len, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
+    self.dataset = m_dataset
     self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
     self.debug_motif = "ATCGCCTA"
+    self.seq_len = seq_len
     pass
 
-  def …
-  …
+  def __iter__(self):
+    for row in self.dataset:
+      processed = self.preprocess(row)
+      if processed is not None:
+        yield processed
+
+  def preprocess(self, row):
+    seq = row['sequence']  # Fetch the 'sequence' column
+    if len(seq) != self.seq_len:
+      return None  # skip problematic row!
+    label = row['label']  # Fetch the 'label' column (or whatever target you use)
     if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
       seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
     seq_rc = reverse_complement_dna_seq(seq)
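Note: `MQTLDataset` now subclasses the streaming `IterableDataset` (imported here from `datasets`), so rows are pulled lazily and malformed rows are skipped instead of crashing a worker. Below is a minimal self-contained sketch of the same skip-bad-rows pattern, written against `torch.utils.data.IterableDataset`; the toy rows and names are assumptions for illustration, not part of this commit.

```python
# Minimal sketch of the streaming skip-bad-rows pattern (illustrative only).
from torch.utils.data import IterableDataset, DataLoader


class ToyStreamingDataset(IterableDataset):
  def __init__(self, rows, seq_len):
    self.rows = rows        # any iterable of dicts, e.g. a streaming HF split
    self.seq_len = seq_len

  def __iter__(self):
    for row in self.rows:
      if len(row["sequence"]) != self.seq_len:
        continue            # skip problematic rows, mirroring preprocess() above
      yield row["sequence"], row["label"]


rows = [{"sequence": "ACGT", "label": 1}, {"sequence": "ACG", "label": 0}]
loader = DataLoader(ToyStreamingDataset(rows, seq_len=4), batch_size=2)
for seqs, labels in loader:
  print(seqs, labels)       # only the length-4 row survives the filter
```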
@@ -162,16 +170,36 @@ class MQTLDataset(Dataset):
     return [ohe_seq, ohe_seq_rc], label_np_array
 
 
+# def collate_fn(batch):
+#   sequences, labels = zip(*batch)
+#   ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
+#   # Pad sequences to the maximum length in this batch
+#   padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
+#   padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
+#   # Convert labels to a tensor
+#   labels = torch.stack(labels)
+#   return [padded_sequences, padded_sequences_rc], labels
+
+
 class MqtlDataModule(LightningDataModule):
   def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
     super().__init__()
     self.batch_size = batch_size
-    self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=…
-    …
+    self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
+                                   # collate_fn=collate_fn,
+                                   num_workers=15,
+                                   # persistent_workers=True
+                                   )
+    self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
+                                      # collate_fn=collate_fn,
+                                      num_workers=15,
+                                      # persistent_workers=True
+                                      )
+    self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
+                                  # collate_fn=collate_fn,
+                                  num_workers=15,
+                                  # persistent_workers=True
+                                  )
     pass
 
   def prepare_data(self):
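The `collate_fn` stays commented out, and `shuffle=False` is required here: `DataLoader` raises a `ValueError` when `shuffle=True` is combined with an iterable-style dataset. For reference, a runnable version of the padding idea the commented draft sketches; the item shapes are assumptions, not from this commit.

```python
# Sketch of a pad_sequence-based collate_fn; not active in this commit.
# Assumes each item is (tensor of shape [seq_len, channels], scalar label).
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch):
  sequences, labels = zip(*batch)
  # Pad variable-length sequences up to the longest one in this batch.
  padded = pad_sequence(list(sequences), batch_first=True, padding_value=0.0)
  return padded, torch.stack(labels)


batch = [(torch.ones(5, 4), torch.tensor(1.0)), (torch.ones(3, 4), torch.tensor(0.0))]
x, y = collate_fn(batch)
print(x.shape, y.shape)  # torch.Size([2, 5, 4]) torch.Size([2])
```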
@@ -376,11 +404,34 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
   if is_binned:
     file_suffix = "_binned"
 
-  …
+  data_files = {
+    # small samples
+    "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
+    "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
+    "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
+    # large samples
+    "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
+    "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
+    "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
+  }
+
+  dataset_map = None
+  is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
+  if is_my_laptop:
+    dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
+  else:
+    dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
+
+  train_dataset = MQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
+                              check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                              seq_len=WINDOW
+                              )
+  val_dataset = MQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
+                            check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                            seq_len=WINDOW)
+  test_dataset = MQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
+                             check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                             seq_len=WINDOW)
 
   data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
 
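With `streaming=True`, `load_dataset` returns iterable splits keyed like the `data_files` dict, and nothing is materialized up front. A minimal sketch of that call; the CSV path and column contents are assumptions, not files from this repo.

```python
# Sketch: streaming a CSV split with Hugging Face datasets (illustrative path).
from datasets import load_dataset

dataset_map = load_dataset("csv", data_files={"train": "train.csv"}, streaming=True)
for row in dataset_map["train"]:
  print(row)   # rows arrive lazily as plain dicts, one at a time
  break
```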
@@ -409,7 +460,8 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
   classifier_model.save_pretrained(f"my-awesome-model-{WINDOW}")
 
   # push to the hub
-  classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model", …
+  classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model",
+                               commit_message=":tada: Push model using huggingface_hub")
 
   # reload
   model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
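The save/push/reload round-trip comes from `PyTorchModelHubMixin`. A minimal sketch of that pattern; `TinyModel` and the local directory name are placeholders, not from this commit.

```python
# Sketch of the PyTorchModelHubMixin save/reload round-trip (placeholder names).
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(nn.Module, PyTorchModelHubMixin):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(4, 1)

  def forward(self, x):
    return self.linear(x)


model = TinyModel()
model.save_pretrained("tiny-model")                 # writes config + weights locally
# model.push_to_hub("user/tiny-model")              # needs an authenticated HF token
reloaded = TinyModel.from_pretrained("tiny-model")
```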
@@ -434,13 +486,16 @@ if __name__ == '__main__':
   simple_cnn.enable_logging = True
 
   start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
-        dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=…
+        dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=5)
 
   pass
 
-…
 """
 lightning_logs/
 *.pth
 my-awesome-model
-…
+
+INFO:root:validate_acc = 0.5625, validate_auc = 0.5490195751190186, validate_f1_score = 0.30000001192092896, validate_precision = 0.6000000238418579, validate_recall = 0.20000000298023224
+/home/soumic/Codes/mqtl-classification/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
+
+"""
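The log pasted into the docstring still shows Lightning's "does not have many workers" warning; plausibly it was captured before this commit, since the commit sets `num_workers=15`, exactly the value the warning suggests. A sketch of the relevant loader knobs, with illustrative values:

```python
# Sketch: DataLoader worker settings that address the warning quoted above.
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
loader = DataLoader(ds, batch_size=16, num_workers=15,   # parallel data loading
                    persistent_workers=True)             # reuse workers across epochs
```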