Soumic commited on
Commit
7bc0339
·
1 Parent(s): 09cc1f0

:rocket: Code is stable

Browse files
Files changed (1) hide show
  1. app.py +80 -25
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
  import random
3
  from typing import Any
4
 
@@ -6,13 +7,14 @@ import numpy as np
6
  import pandas as pd
7
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
8
  from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
 
9
  from torch.utils.data import DataLoader, Dataset
10
  from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
11
  from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
12
  from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
13
  import torch
14
  from torch import nn
15
- from datasets import load_dataset
16
  from huggingface_hub import PyTorchModelHubMixin
17
 
18
  timber = logging.getLogger()
@@ -136,19 +138,25 @@ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
136
  return output
137
 
138
 
139
- class MQTLDataset(Dataset):
140
- def __init__(self, dataset, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
141
- self.dataset = dataset
142
  self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
143
  self.debug_motif = "ATCGCCTA"
 
144
  pass
145
 
146
- def __len__(self):
147
- return len(self.dataset)
148
-
149
- def __getitem__(self, idx):
150
- seq = self.dataset[idx]['sequence'] # Fetch the 'sequence' column
151
- label = self.dataset[idx]['label'] # Fetch the 'label' column (or whatever target you use)
 
 
 
 
 
152
  if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
153
  seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
154
  seq_rc = reverse_complement_dna_seq(seq)
@@ -162,16 +170,36 @@ class MQTLDataset(Dataset):
162
  return [ohe_seq, ohe_seq_rc], label_np_array
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
165
  class MqtlDataModule(LightningDataModule):
166
  def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
167
  super().__init__()
168
  self.batch_size = batch_size
169
- self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True, num_workers=15,
170
- persistent_workers=True)
171
- self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
172
- persistent_workers=True)
173
- self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
174
- persistent_workers=True)
 
 
 
 
 
 
 
 
 
175
  pass
176
 
177
  def prepare_data(self):
@@ -376,11 +404,34 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
376
  if is_binned:
377
  file_suffix = "_binned"
378
 
379
- dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets")
380
-
381
- train_dataset = MQTLDataset(dataset_map[f"train_binned_{WINDOW}"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
382
- val_dataset = MQTLDataset(dataset_map[f"validate_binned_{WINDOW}"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
383
- test_dataset = MQTLDataset(dataset_map[f"test_binned_{WINDOW}"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
  data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
386
 
@@ -409,7 +460,8 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
409
  classifier_model.save_pretrained(f"my-awesome-model-{WINDOW}")
410
 
411
  # push to the hub
412
- classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model", commit_message=":tada: Push model using huggingface_hub")
 
413
 
414
  # reload
415
  model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
@@ -434,13 +486,16 @@ if __name__ == '__main__':
434
  simple_cnn.enable_logging = True
435
 
436
  start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
437
- dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
438
 
439
  pass
440
 
441
-
442
  """
443
  lightning_logs/
444
  *.pth
445
  my-awesome-model
446
- """
 
 
 
 
 
1
  import logging
2
+ import os
3
  import random
4
  from typing import Any
5
 
 
7
  import pandas as pd
8
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
9
  from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
10
+ from torch.nn.utils.rnn import pad_sequence
11
  from torch.utils.data import DataLoader, Dataset
12
  from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
  from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
14
  from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
  import torch
16
  from torch import nn
17
+ from datasets import load_dataset, IterableDataset
18
  from huggingface_hub import PyTorchModelHubMixin
19
 
20
  timber = logging.getLogger()
 
138
  return output
139
 
140
 
141
+ class MQTLDataset(IterableDataset):
142
+ def __init__(self, m_dataset, seq_len, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
143
+ self.dataset = m_dataset
144
  self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
145
  self.debug_motif = "ATCGCCTA"
146
+ self.seq_len = seq_len
147
  pass
148
 
149
+ def __iter__(self):
150
+ for row in self.dataset:
151
+ processed = self.preprocess(row)
152
+ if processed is not None:
153
+ yield processed
154
+
155
+ def preprocess(self, row):
156
+ seq = row['sequence'] # Fetch the 'sequence' column
157
+ if len(seq) != self.seq_len:
158
+ return None # skip problematic row!
159
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
160
  if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
161
  seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
162
  seq_rc = reverse_complement_dna_seq(seq)
 
170
  return [ohe_seq, ohe_seq_rc], label_np_array
171
 
172
 
173
+ # def collate_fn(batch):
174
+ # sequences, labels = zip(*batch)
175
+ # ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
176
+ # # Pad sequences to the maximum length in this batch
177
+ # padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
178
+ # padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
179
+ # # Convert labels to a tensor
180
+ # labels = torch.stack(labels)
181
+ # return [padded_sequences, padded_sequences_rc], labels
182
+
183
+
184
  class MqtlDataModule(LightningDataModule):
185
  def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
186
  super().__init__()
187
  self.batch_size = batch_size
188
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
189
+ # collate_fn=collate_fn,
190
+ num_workers=15,
191
+ # persistent_workers=True
192
+ )
193
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
194
+ # collate_fn=collate_fn,
195
+ num_workers=15,
196
+ # persistent_workers=True
197
+ )
198
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
199
+ # collate_fn=collate_fn,
200
+ num_workers=15,
201
+ # persistent_workers=True
202
+ )
203
  pass
204
 
205
  def prepare_data(self):
 
404
  if is_binned:
405
  file_suffix = "_binned"
406
 
407
+ data_files = {
408
+ # small samples
409
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
410
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
411
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
412
+ # large samples
413
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
414
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
415
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
416
+ }
417
+
418
+ dataset_map = None
419
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
420
+ if is_my_laptop:
421
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
422
+ else:
423
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
424
+
425
+ train_dataset = MQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
426
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
427
+ seq_len=WINDOW
428
+ )
429
+ val_dataset = MQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
430
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
431
+ seq_len=WINDOW)
432
+ test_dataset = MQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
433
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
434
+ seq_len=WINDOW)
435
 
436
  data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
437
 
 
460
  classifier_model.save_pretrained(f"my-awesome-model-{WINDOW}")
461
 
462
  # push to the hub
463
+ classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model",
464
+ commit_message=":tada: Push model using huggingface_hub")
465
 
466
  # reload
467
  model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
 
486
  simple_cnn.enable_logging = True
487
 
488
  start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
489
+ dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=5)
490
 
491
  pass
492
 
 
493
  """
494
  lightning_logs/
495
  *.pth
496
  my-awesome-model
497
+
498
+ INFO:root:validate_acc = 0.5625, validate_auc = 0.5490195751190186, validate_f1_score = 0.30000001192092896, validate_precision = 0.6000000238418579, validate_recall = 0.20000000298023224
499
+ /home/soumic/Codes/mqtl-classification/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
500
+
501
+ """