lixiangchun committed
Commit 25f71fa
1 Parent(s): 32d3c5a

first commit

README.md ADDED
@@ -0,0 +1,9 @@
+ # DECIDIA-code
+ Source code for weakly supervised classification of cancer versus control exclusively from bisulfite sequencing reads.
+
+
+ ## Training
+ ```bash
+ bash train.sh
+ ```
+
data/log-reads-200-patients-trn200-val200-test622-tiny.txt ADDED
@@ -0,0 +1 @@
+ epoch train_loss train_acc val_loss val_acc eval_loss eval_acc
data/test.csv.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43aacd0406822240dd50dbc6c24123fe188f5672c1243574eec0643b38c44d6b
+ size 15856551
data/trn.csv.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:04d798b3dd0735b3a777abcfacd1fe71600a49a8e11666bf410ee307a6f8c176
+ size 25539107
data/val.csv.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6dc2370881027c81e206cabbf3d361ba2f28ccb94ed7b1e801762bea579eed63
+ size 5028508
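The three LFS-tracked CSVs above are consumed by train.py through their `patient`, `label` and `seq` columns (one row per sequencing read, with every patient contributing the same number of reads). A minimal sketch of that layout; the values below are made-up placeholders, and the 0 = control / 1 = cancer encoding is inferred from the `p_normal`/`p_cancer` column order used in train.py:

```python
import pandas as pd

# Placeholder rows illustrating the columns train.py expects in
# data/trn.csv.gz, data/val.csv.gz and data/test.csv.gz.
# train.py asserts that every patient has the same number of reads.
example = pd.DataFrame({
    "patient": ["P001", "P001", "P002", "P002"],   # bag identifier per read
    "label":   [1, 1, 0, 0],                       # patient label repeated on each read
    "seq":     ["TTCGAATTTGT", "TTGGAATTCGT",      # read sequences embedded by the OPT encoder
                "TTTGAATTTGT", "TTTGAATTCGT"],
})
example.to_csv("trn_example.csv.gz", index=False, compression="gzip")
```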
model.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def exists(val):
+     return val is not None
+
+ def initialize_weights(module):
+     for m in module.modules():
+         if isinstance(m, nn.Linear):
+             nn.init.xavier_normal_(m.weight)
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+ """
+ Attention network with sigmoid gating (3 fc layers).
+ args:
+     L: input feature dimension
+     D: hidden layer dimension
+     n_tasks: number of attention scores produced per instance
+ Dropout (p = 0.25) is always applied in both gating branches.
+ """
+ class Attn_Net_Gated(nn.Module):
+
+     def __init__(self, L=1024, D=256, n_tasks=1):
+         super(Attn_Net_Gated, self).__init__()
+         self.attention_a = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(0.25))
+         self.attention_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(0.25))
+         self.attention_c = nn.Linear(D, n_tasks)
+
+     def forward(self, x):
+         a = self.attention_a(x)
+         b = self.attention_b(x)
+         A = a.mul(b)                 # gated attention: tanh branch * sigmoid branch
+         A = self.attention_c(A)      # N x n_tasks
+         return A, x
+
+
+ """
+ Attention-based multiple-instance learning (MIL) classifier.
+ Code borrowed from: https://github.com/mahmoodlab/TOAD
+
+ args:
+     input_dim: dimension of the per-instance (per-read) input features
+     size_arg: size config of the attention network ("small" or "big")
+     n_classes: number of classes
+ """
+
+ class DeepAttnMIL(nn.Module):
+
+     def __init__(self, input_dim=1024, size_arg="big", n_classes=2):
+         super(DeepAttnMIL, self).__init__()
+         self.size_dict = {"small": [input_dim, 512, 256], "big": [input_dim, 512, 384]}
+         size = self.size_dict[size_arg]
+
+         self.attention_net = nn.Sequential(
+             nn.Linear(size[0], size[1]),
+             nn.ReLU(),
+             nn.Dropout(0.25),
+             Attn_Net_Gated(L=size[1], D=size[2], n_tasks=1))
+
+         self.classifier = nn.Linear(size[1], n_classes)
+
+         initialize_weights(self)
+
+     def forward(self, h, return_features=False, attention_only=False):
+         A, h = self.attention_net(h)   # A: N x 1 attention scores, h: N x size[1] features
+         A = torch.transpose(A, 1, 0)   # 1 x N
+         if attention_only:
+             return A[0]
+
+         A = F.softmax(A, dim=1)        # normalize attention over the N instances
+         M = torch.mm(A, h)             # 1 x size[1] attention-pooled bag representation
+
+         if return_features:
+             return M
+
+         logits = self.classifier(M)    # 1 x n_classes
+
+         return logits
+
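A minimal usage sketch of DeepAttnMIL as wired up in train.py: one forward pass over a bag of mean-pooled read embeddings for a single patient. The 384-dim feature size matches the `hidden_size` in sequence_embedding/config.json and the 200-read bag matches train.sh; the random tensor is only a stand-in for real embeddings.

```python
import torch
from model import DeepAttnMIL

model = DeepAttnMIL(input_dim=384, size_arg="big", n_classes=2)
model.eval()  # disable the dropout layers for inference

bag = torch.randn(200, 384)   # 200 read embeddings for one patient (placeholder values)
with torch.inference_mode():
    logits = model(bag)                          # shape (1, 2): control-vs-cancer logits
    scores = model(bag, attention_only=True)     # shape (200,): unnormalized per-read attention scores
probs = logits.softmax(dim=1)                    # per-patient class probabilities
```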
sequence_embedding/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "_name_or_path": "facebook/opt",
+   "_remove_final_layer_norm": false,
+   "activation_dropout": 0.0,
+   "activation_function": "relu",
+   "architectures": [
+     "OPTForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 2,
+   "do_layer_norm_before": true,
+   "dropout": 0.1,
+   "eos_token_id": 2,
+   "ffn_dim": 1536,
+   "hidden_size": 384,
+   "init_std": 0.02,
+   "layerdrop": 0.0,
+   "max_position_embeddings": 512,
+   "model_type": "opt",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 1,
+   "pad_token_id": 1,
+   "prefix": "</s>",
+   "torch_dtype": "float32",
+   "transformers_version": "4.21.1",
+   "use_cache": true,
+   "vocab_size": 30,
+   "word_embed_proj_dim": 384
+ }
sequence_embedding/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd73e58e7cfeb3faf6b3b3062ad16295066c115a9f1a57ed08d3aa5ce495fe6d
+ size 7942569
sequence_embedding/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "</s>",
+   "eos_token": "</s>",
+   "pad_token": "<pad>",
+   "unk_token": "</s>"
+ }
sequence_embedding/tokenizer.json ADDED
@@ -0,0 +1,153 @@
+ {
+   "version": "1.0",
+   "truncation": null,
+   "padding": null,
+   "added_tokens": [
+     {
+       "id": 0,
+       "content": "<s>",
+       "single_word": false,
+       "lstrip": false,
+       "rstrip": false,
+       "normalized": false,
+       "special": true
+     },
+     {
+       "id": 1,
+       "content": "<pad>",
+       "single_word": false,
+       "lstrip": false,
+       "rstrip": false,
+       "normalized": false,
+       "special": true
+     },
+     {
+       "id": 2,
+       "content": "</s>",
+       "single_word": false,
+       "lstrip": false,
+       "rstrip": false,
+       "normalized": false,
+       "special": true
+     },
+     {
+       "id": 3,
+       "content": "<unk>",
+       "single_word": false,
+       "lstrip": false,
+       "rstrip": false,
+       "normalized": false,
+       "special": true
+     }
+   ],
+   "normalizer": null,
+   "pre_tokenizer": {
+     "type": "Whitespace"
+   },
+   "post_processor": {
+     "type": "TemplateProcessing",
+     "single": [
+       {
+         "SpecialToken": {
+           "id": "</s>",
+           "type_id": 0
+         }
+       },
+       {
+         "Sequence": {
+           "id": "A",
+           "type_id": 0
+         }
+       }
+     ],
+     "pair": [
+       {
+         "Sequence": {
+           "id": "A",
+           "type_id": 0
+         }
+       },
+       {
+         "Sequence": {
+           "id": "B",
+           "type_id": 1
+         }
+       }
+     ],
+     "special_tokens": {
+       "</s>": {
+         "id": "</s>",
+         "ids": [
+           2
+         ],
+         "tokens": [
+           "</s>"
+         ]
+       },
+       "<pad>": {
+         "id": "<pad>",
+         "ids": [
+           1
+         ],
+         "tokens": [
+           "<pad>"
+         ]
+       },
+       "<s>": {
+         "id": "<s>",
+         "ids": [
+           0
+         ],
+         "tokens": [
+           "<s>"
+         ]
+       },
+       "<unk>": {
+         "id": "<unk>",
+         "ids": [
+           3
+         ],
+         "tokens": [
+           "<unk>"
+         ]
+       }
+     }
+   },
+   "decoder": null,
+   "model": {
+     "type": "WordLevel",
+     "vocab": {
+       "<s>": 0,
+       "<pad>": 1,
+       "</s>": 2,
+       "<unk>": 3,
+       "A": 4,
+       "B": 5,
+       "C": 6,
+       "D": 7,
+       "E": 8,
+       "F": 9,
+       "G": 10,
+       "H": 11,
+       "I": 12,
+       "J": 13,
+       "K": 14,
+       "L": 15,
+       "M": 16,
+       "N": 17,
+       "O": 18,
+       "P": 19,
+       "Q": 20,
+       "R": 21,
+       "S": 22,
+       "T": 23,
+       "U": 24,
+       "V": 25,
+       "W": 26,
+       "X": 27,
+       "Y": 28,
+       "Z": 29
+     },
+     "unk_token": "<unk>"
+   }
+ }
sequence_embedding/tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "bos_token": "</s>",
+   "eos_token": "</s>",
+   "model_max_length": 2048,
+   "pad_token": "<pad>",
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "unk_token": "</s>"
+ }
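The files above define a tiny word-level OPT encoder (one hidden layer, 384-dim hidden states, 30-symbol vocabulary) that train.py uses as a frozen read embedder. A minimal sketch of that embedding step, mirroring the loop in train.py; the read string is a made-up placeholder and real reads come from the data/*.csv.gz files:

```python
import torch
from transformers import PreTrainedTokenizerFast, OPTForCausalLM

tokenizer = PreTrainedTokenizerFast.from_pretrained("sequence_embedding")
net = OPTForCausalLM.from_pretrained("sequence_embedding").eval()

read = "TTCGAATTTGT"  # placeholder read
# train.py space-separates the characters so the Whitespace pre-tokenizer
# maps each letter to one word-level token.
inputs = tokenizer(' '.join(list(read)), max_length=100, padding='max_length',
                   truncation=True, return_tensors='pt', return_token_type_ids=False)
with torch.inference_mode():
    out = net.model(**inputs)               # decoder hidden states only, no LM head
feature = out.last_hidden_state.mean(1)     # 1 x 384 mean-pooled read embedding
```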
train.py ADDED
@@ -0,0 +1,258 @@
+ #!/opt/software/install/miniconda38/bin/python
+ import argparse
+ parser = argparse.ArgumentParser(description='DECIDIA training program')
+ parser.add_argument('--input_dir', type=str, help='input directory')
+ parser.add_argument('--sequence_embedding', type=str, help='sequence embedding directory')
+ parser.add_argument('--num_hidden_layers', type=int, default=1, help='num_hidden_layers [1]')
+ parser.add_argument('--train_file', type=str, help='training file')
+ parser.add_argument('--val_file', type=str, help='validation file')
+ parser.add_argument('--device', type=str, help='device', default='cuda:1')
+ parser.add_argument('--num_classes', type=int, help='num_classes [32]', default=32)
+ parser.add_argument('--diseases', type=str, default=None, help='diseases included, e.g. "LUAD,LUSC"')
+ parser.add_argument('--weight_decay', type=float, help='weight_decay [1e-5]', default=1e-5)
+ parser.add_argument('--modeling_context', action='store_true', help='whether to use OPT to model context dependency')
+ parser.add_argument("--lr_scheduler_type", type=str,
+                     choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+                     default="constant", help="The scheduler type to use.")
+ parser.add_argument('--pretrained_weight', type=str, help='pretrained weight')
+ parser.add_argument('--pretrained_cls_token', type=str, help='pretrained cls token')
+ parser.add_argument('--epochs', type=int, default=100, help='epochs (default: 100)')
+ parser.add_argument('--num_sequences', type=int, default=None, help='num of sequences to sample from training set')
+ parser.add_argument('--num_train_patients', type=int, default=None, help='num of patients to sample from training set')
+
+ args = parser.parse_args()
+
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+ import sys
+ import glob
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+ from torch.optim import AdamW, Adam, SGD, Adagrad
+ from sklearn.utils import resample
+ from transformers import get_scheduler
+ import numpy as np
+ import pandas as pd
+ import random
+ import time
+ from transformers import (
+     PreTrainedTokenizerFast,
+     OPTForCausalLM
+ )
+ from model import DeepAttnMIL
+
+ torch.set_num_threads(2)
+ device = args.device
+ random.seed(123)
+
+ # Frozen OPT encoder used only to embed reads; the MIL classifier is trained on top of it.
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(args.sequence_embedding)
+ net = OPTForCausalLM.from_pretrained(args.sequence_embedding)
+ net = net.to(device)
+ net.eval()
+
+ feature_dim = net.config.hidden_size
+
+ trn_df = pd.read_csv(f'{args.input_dir}/trn.csv.gz')
+ reads_per_patient = trn_df.patient.value_counts().unique()
+ assert len(reads_per_patient) == 1, 'every patient must contribute the same number of reads'
+ reads_per_patient = reads_per_patient[0]
+ if args.num_sequences is None or args.num_sequences > reads_per_patient:
+     args.num_sequences = reads_per_patient
+ if args.num_sequences < reads_per_patient:
+     trn_df = pd.concat([df.sample(args.num_sequences, random_state=123) for patient, df in trn_df.groupby('patient')])
+
+ num_train_samples = len(trn_df.patient.unique())
+ if args.num_train_patients is None:
+     args.num_train_patients = num_train_samples
+ if args.num_train_patients < num_train_samples:
+     trn_df = trn_df[trn_df.patient.isin(random.sample(trn_df.patient.unique().tolist(), args.num_train_patients))]
+
+ trn_x = torch.zeros(args.num_train_patients, args.num_sequences, feature_dim)
+ trn_y = torch.as_tensor([-1] * args.num_train_patients)
+
+ test_df = pd.read_csv(f'{args.input_dir}/test.csv.gz')
+ num_test_samples = len(test_df.patient.unique())
+ test_x = torch.zeros(num_test_samples, reads_per_patient, feature_dim)
+ test_y = torch.as_tensor([-1] * num_test_samples)
+ test_patients = []
+
+ val_df = pd.read_csv(f'{args.input_dir}/val.csv.gz')
+ num_val_samples = len(val_df.patient.unique())
+ val_x = torch.zeros(num_val_samples, reads_per_patient, feature_dim)
+ val_y = torch.as_tensor([-1] * num_val_samples)
+ val_patients = []
+
+
+ pad_token_id = net.config.pad_token_id
+
+
+ # Embed every read of every patient with the frozen OPT encoder (mean-pooled hidden states).
+ for i, (patient, e) in tqdm(enumerate(trn_df.groupby('patient')), total=args.num_train_patients):
+     a = [' '.join(list(s)) for s in e.seq]
+     inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True, return_tensors='pt', return_token_type_ids=False)
+     for k, v in inputs.items():
+         inputs[k] = v.to(device)
+     with torch.inference_mode():
+         out = net.model(**inputs)
+         features = out.last_hidden_state.mean(1).cpu()
+     trn_x[i] = features
+     trn_y[i] = e.label.tolist()[0]
+
+
+ for i, (patient, e) in tqdm(enumerate(test_df.groupby('patient')), total=num_test_samples):
+     a = [' '.join(list(s)) for s in e.seq]
+     inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True, return_tensors='pt', return_token_type_ids=False)
+     for k, v in inputs.items():
+         inputs[k] = v.to(device)
+     with torch.inference_mode():
+         out = net.model(**inputs)
+         features = out.last_hidden_state.mean(1).cpu()
+     test_x[i] = features
+     test_y[i] = e.label.tolist()[0]
+     test_patients.append(patient)
+
+ for i, (patient, e) in tqdm(enumerate(val_df.groupby('patient')), total=num_val_samples):
+     a = [' '.join(list(s)) for s in e.seq]
+     inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True, return_tensors='pt', return_token_type_ids=False)
+     for k, v in inputs.items():
+         inputs[k] = v.to(device)
+     with torch.inference_mode():
+         out = net.model(**inputs)
+         features = out.last_hidden_state.mean(1).cpu()
+     val_x[i] = features
+     val_y[i] = e.label.tolist()[0]
+     val_patients.append(patient)
+
+
+
+ fout = open(f'{args.input_dir}/log-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.txt', 'w')
+ print("epoch\ttrain_loss\ttrain_acc\tval_loss\tval_acc\teval_loss\teval_acc", file=fout)
+
+ model = DeepAttnMIL(input_dim=feature_dim, n_classes=args.num_classes, size_arg='big')
+
+
+ if args.pretrained_weight:
+     state_dict = torch.load(args.pretrained_weight, map_location='cpu')
+     # Drop the classification head if its shape does not match the requested number of classes.
+     if state_dict['classifier.weight'].size(0) != args.num_classes:
+         del state_dict['classifier.weight']
+         del state_dict['classifier.bias']
+
+     msg = model.load_state_dict(state_dict, strict=False)
+     print(msg)
+
+ model = model.to(device)
+
+ print(model)
+
+
+ criterion = nn.CrossEntropyLoss()
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+     {
+         "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+         "weight_decay": args.weight_decay,
+     },
+     {
+         "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+         "weight_decay": 0.0,
+     },
+ ]
+ opt = AdamW(optimizer_grouped_parameters, lr=2e-5)
+
+
+ num_update_steps_per_epoch = len(trn_df)
+ max_train_steps = args.epochs * num_update_steps_per_epoch
+ lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=opt, num_warmup_steps=num_update_steps_per_epoch*1, num_training_steps=max_train_steps)
+
+
+ best_eval_acc = 0.0
+ best_eval_loss = 100000.0
+ best_val_loss = 100000.0
+ for epoch in range(args.epochs):
+     model.train()
+     total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
+     idxs = random.sample(range(len(trn_y)), len(trn_y))  # shuffle patients each epoch
+     for idx in idxs:
+         x = trn_x[idx]
+         y = trn_y[idx].unsqueeze(0)
+         x = x.to(device)
+         y = y.to(device)
+
+         logit = model(x)
+         loss = criterion(logit, y)
+
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+         lr_scheduler.step()
+
+         total_loss += loss.item()
+         total_batch += 1
+         total_num += len(y)
+         correct_k += logit.argmax(1).eq(y).sum()
+
+     train_acc = correct_k / total_num
+     train_loss = total_loss / total_batch
+
+     ####### Evaluate on test set ################
+     model.eval()
+     total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
+     eval_probs = []
+     for x, y, pid in zip(test_x, test_y, test_patients):
+         y = y.unsqueeze(0).to(device)
+         x = x.to(device)
+
+         with torch.inference_mode():
+             logit = model(x)
+             loss = criterion(logit, y)
+
+         eval_probs.append(logit.flatten().softmax(0).tolist())
+
+         total_loss += loss.item()
+         total_batch += 1
+         total_num += len(y)
+         correct_k += logit.argmax(1).eq(y).sum()
+
+     eval_acc = correct_k / total_num
+     eval_loss = total_loss / total_batch
+
+     ####### Evaluate on val set ################
+     model.eval()
+     total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
+     val_probs = []
+     for x, y, pid in zip(val_x, val_y, val_patients):
+         y = y.unsqueeze(0).to(device)
+         x = x.to(device)
+
+         with torch.inference_mode():
+             logit = model(x)
+             loss = criterion(logit, y)
+
+         val_probs.append(logit.flatten().softmax(0).tolist())
+
+         total_loss += loss.item()
+         total_batch += 1
+         total_num += len(y)
+         correct_k += logit.argmax(1).eq(y).sum()
+
+     val_acc = correct_k / total_num
+     val_loss = total_loss / total_batch
+
+
+     print(f"{epoch+1}\t{train_loss}\t{train_acc}\t{val_loss}\t{val_acc}\t{eval_loss}\t{eval_acc}", file=fout)
+     fout.flush()
+
+     # Checkpoint the model whenever the validation loss improves.
+     if val_loss < best_val_loss:
+         torch.save(model.state_dict(), f'{args.input_dir}/model-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.pt')
+         best_val_loss = val_loss
+
+     # Dump the current per-patient probabilities (files are overwritten each epoch).
+     eval_probs = pd.DataFrame(eval_probs, columns=['p_normal', 'p_cancer'])
+     info = pd.DataFrame({'patient': test_patients, 'label': test_y.tolist()})
+     info = pd.concat([info, eval_probs], axis=1)
+     info.to_csv(f'{args.input_dir}/test_prediction-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.csv', index=False)
+
+     val_probs = pd.DataFrame(val_probs, columns=['p_normal', 'p_cancer'])
+     info = pd.DataFrame({'patient': val_patients, 'label': val_y.tolist()})
+     info = pd.concat([info, val_probs], axis=1)
+     info.to_csv(f'{args.input_dir}/val_prediction-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.csv', index=False)
+
+ fout.close()
+
+
+
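train.py writes per-patient class probabilities to test_prediction-*.csv and val_prediction-*.csv with columns patient, label, p_normal and p_cancer. A minimal sketch of scoring such a file with ROC AUC; the exact filename below is inferred from the naming pattern in train.py and the bundled log file (200 reads, 200 training patients, 200 validation and 622 test patients), so adjust it to whatever your run produced:

```python
import pandas as pd
from sklearn.metrics import roc_auc_score, accuracy_score

# Filename inferred from train.py's naming pattern; replace with your run's output.
pred = pd.read_csv("data/test_prediction-reads-200-patients-trn200-val200-test622-tiny.csv")

auc = roc_auc_score(pred["label"], pred["p_cancer"])          # patient-level cancer-vs-control AUC
acc = accuracy_score(pred["label"], (pred["p_cancer"] > 0.5).astype(int))
print(f"test AUC = {auc:.3f}, accuracy at 0.5 cutoff = {acc:.3f}")
```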
train.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ python train.py \
+     --sequence_embedding sequence_embedding \
+     --input_dir data \
+     --num_classes 2 \
+     --device 'cuda:0' \
+     --num_train_patients 200 \
+     --num_sequences 200
+