sxtforreal committed
Commit 3957f36
1 Parent(s): 8647d89

Create train.py


Run this file to train models.

Files changed (1)
  1. train.py +257 -0
train.py ADDED
@@ -0,0 +1,257 @@
+ from lightning.pytorch import seed_everything
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+ import lightning.pytorch as pl
+ from lightning.pytorch.loggers import TensorBoardLogger
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from transformers import AutoTokenizer
+ from ast import literal_eval
+
+ # imports from our own modules
+ import config
+ from model import (
+     BERTContrastiveLearning_simcse,
+     BERTContrastiveLearning_simcse_w,
+     BERTContrastiveLearning_samp,
+     BERTContrastiveLearning_samp_w,
+ )
+ from dataset import (
+     ContrastiveLearningDataModule_simcse,
+     ContrastiveLearningDataModule_simcse_w,
+     ContrastiveLearningDataModule_samp,
+     ContrastiveLearningDataModule_samp_w,
+ )
+
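+ # The script below trains four contrastive-learning variants back to back:
+ # plain SimCSE, weighted SimCSE (simcse_w), a sampling-based variant
+ # (samp), and its weighted counterpart (samp_w). Each gets its own data
+ # module, model, checkpoint callback, and trainer.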
+ if __name__ == "__main__":
+     seed_everything(0, workers=True)
+
+     # Initialize the TensorBoard logger
+     logger = TensorBoardLogger("logs", name="MIMIC-tr")
+
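+     # Load the preprocessed MIMIC queries. The concept/code columns are
+     # stored as stringified lists in the CSV, so literal_eval restores
+     # them to Python lists before the train/validation split.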
+     query_df = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
+     )
+     # query_df = query_df.head(1000)
+     query_df["concepts"] = query_df["concepts"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(
+         lambda x: [val for val in x if val is not None]
+     )  # remove None in lists
+     query_df = query_df.drop(columns=["one_hot"])
+     train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
+
+     tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
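+     # Auxiliary resources for the weighted/sampling variants: pairwise
+     # concept similarity scores, plus a concept table whose synonym and
+     # ancestor columns feed a concept -> synonyms dictionary.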
+     sim_df = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
+     )
+
+     all_d = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
+     )
+     all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
+     all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
+     dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
+
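+     # Each of the four training blocks below follows the same recipe:
+     # build a data module, construct the model (n_batches is batches per
+     # epoch), checkpoint every config.log_every_n_steps steps with
+     # monitor=None and save_top_k=-1 (keep every periodic checkpoint),
+     # and stop early once validation_loss stops improving.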
+     # SimCSE
+     data_module1 = ContrastiveLearningDataModule_simcse(
+         train_df,
+         val_df,
+         tokenizer,
+     )
+     data_module1.setup()
+
+     print("Number of training samples:", len(data_module1.train_dataset))
+     print("Number of validation samples:", len(data_module1.val_dataset))
+
+     model1 = BERTContrastiveLearning_simcse(
+         n_batches=len(data_module1.train_dataset) / config.batch_size,
+         n_epochs=config.max_epochs,
+         lr=config.learning_rate,
+         unfreeze=config.unfreeze_ratio,
+     )
+
+     checkpoint1 = ModelCheckpoint(
+         dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse/v1",
+         filename="{epoch}-{step}",
+         # save_weights_only=True,
+         save_last=True,
+         every_n_train_steps=config.log_every_n_steps,
+         monitor=None,
+         save_top_k=-1,
+     )
+
+     trainer1 = pl.Trainer(
+         accelerator=config.accelerator,
+         devices=config.devices,
+         strategy="ddp",
+         logger=logger,
+         max_epochs=config.max_epochs,
+         min_epochs=config.min_epochs,
+         precision=config.precision,
+         callbacks=[
+             EarlyStopping(
+                 monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
+             ),
+             checkpoint1,
+         ],
+         profiler="simple",
+         log_every_n_steps=config.log_every_n_steps,
+     )
+
+     trainer1.fit(model1, data_module1)
+
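+     # The weighted variant additionally takes query_df, sim_df, and
+     # all_d, presumably so pair weights can be derived from the pairwise
+     # similarity scores and concept metadata.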
+     # SimCSE_w
+     data_module2 = ContrastiveLearningDataModule_simcse_w(
+         train_df,
+         val_df,
+         query_df,
+         tokenizer,
+         sim_df,
+         all_d,
+     )
+     data_module2.setup()
+
+     print("Number of training samples:", len(data_module2.train_dataset))
+     print("Number of validation samples:", len(data_module2.val_dataset))
+
+     model2 = BERTContrastiveLearning_simcse_w(
+         n_batches=len(data_module2.train_dataset) / config.batch_size,
+         n_epochs=config.max_epochs,
+         lr=config.learning_rate,
+         unfreeze=config.unfreeze_ratio,
+     )
+
+     checkpoint2 = ModelCheckpoint(
+         dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse_w/v1",
+         filename="{epoch}-{step}",
+         # save_weights_only=True,
+         save_last=True,
+         every_n_train_steps=config.log_every_n_steps,
+         monitor=None,
+         save_top_k=-1,
+     )
+
+     trainer2 = pl.Trainer(
+         accelerator=config.accelerator,
+         devices=config.devices,
+         strategy="ddp",
+         logger=logger,
+         max_epochs=config.max_epochs,
+         min_epochs=config.min_epochs,
+         precision=config.precision,
+         callbacks=[
+             EarlyStopping(
+                 monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
+             ),
+             checkpoint2,
+         ],
+         profiler="simple",
+         log_every_n_steps=config.log_every_n_steps,
+     )
+
+     trainer2.fit(model2, data_module2)
+
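+     # The sampling variant draws on the concept -> synonyms dictionary,
+     # presumably to sample synonym positives rather than relying on
+     # dropout noise alone.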
+     # Samp
+     data_module3 = ContrastiveLearningDataModule_samp(
+         train_df,
+         val_df,
+         query_df,
+         tokenizer,
+         dictionary,
+         sim_df,
+     )
+     data_module3.setup()
+
+     print("Number of training samples:", len(data_module3.train_dataset))
+     print("Number of validation samples:", len(data_module3.val_dataset))
+
+     model3 = BERTContrastiveLearning_samp(
+         n_batches=len(data_module3.train_dataset) / config.batch_size,
+         n_epochs=config.max_epochs,
+         lr=config.learning_rate,
+         unfreeze=config.unfreeze_ratio,
+     )
+
+     checkpoint3 = ModelCheckpoint(
+         dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp/v1",
+         filename="{epoch}-{step}",
+         # save_weights_only=True,
+         save_last=True,
+         every_n_train_steps=config.log_every_n_steps,
+         monitor=None,
+         save_top_k=-1,
+     )
+
+     trainer3 = pl.Trainer(
+         accelerator=config.accelerator,
+         devices=config.devices,
+         strategy="ddp",
+         logger=logger,
+         max_epochs=config.max_epochs,
+         min_epochs=config.min_epochs,
+         precision=config.precision,
+         callbacks=[
+             EarlyStopping(
+                 monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
+             ),
+             checkpoint3,
+         ],
+         profiler="simple",
+         log_every_n_steps=config.log_every_n_steps,
+     )
+
+     trainer3.fit(model3, data_module3)
+
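+     # The final variant combines both ideas: synonym sampling plus
+     # similarity-derived weighting (it receives dictionary, sim_df, and
+     # all_d together).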
+     # Samp_w
+     data_module4 = ContrastiveLearningDataModule_samp_w(
+         train_df,
+         val_df,
+         query_df,
+         tokenizer,
+         dictionary,
+         sim_df,
+         all_d,
+     )
+     data_module4.setup()
+
+     print("Number of training samples:", len(data_module4.train_dataset))
+     print("Number of validation samples:", len(data_module4.val_dataset))
+
+     model4 = BERTContrastiveLearning_samp_w(
+         n_batches=len(data_module4.train_dataset) / config.batch_size,
+         n_epochs=config.max_epochs,
+         lr=config.learning_rate,
+         unfreeze=config.unfreeze_ratio,
+     )
+
+     checkpoint4 = ModelCheckpoint(
+         dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp_w/v1",
+         filename="{epoch}-{step}",
+         # save_weights_only=True,
+         save_last=True,
+         every_n_train_steps=config.log_every_n_steps,
+         monitor=None,
+         save_top_k=-1,
+     )
+
+     trainer4 = pl.Trainer(
+         accelerator=config.accelerator,
+         devices=config.devices,
+         strategy="ddp",
+         logger=logger,
+         max_epochs=config.max_epochs,
+         min_epochs=config.min_epochs,
+         precision=config.precision,
+         callbacks=[
+             EarlyStopping(
+                 monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
+             ),
+             checkpoint4,
+         ],
+         profiler="simple",
+         log_every_n_steps=config.log_every_n_steps,
+     )
+
+     trainer4.fit(model4, data_module4)