import math
from ast import literal_eval

from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import lightning.pytorch as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from pytorch_lightning.loggers import TensorBoardLogger

import config
from model import CL_model
from dataset import CLDataModule


if __name__ == "__main__":
    # Training entry point: load the pre-processed CSVs, build the data module
    # and the model, then fit with periodic checkpointing and early stopping.

    # Seed first: train_test_split below is given no random_state, so the
    # train/val split is only reproducible through the globally seeded RNG.
    seed_everything(0, workers=True)

    # Take the logger class from `lightning.pytorch` (the package the Trainer
    # and callbacks come from) instead of the separate `pytorch_lightning`
    # distribution; the Trainer's isinstance checks only know its own classes.
    logger = pl.loggers.TensorBoardLogger(
        "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs", name="CL"
    )

    # NOTE(review): some paths below start at /home/sunx/data/... and others
    # at /data/... — presumably the same tree via a symlink; confirm.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    # List-valued columns are stored as Python-literal strings in the CSV.
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # Drop None entries from each row's list of codes.
    query_df["codes"] = query_df["codes"].apply(
        lambda codes: [code for code in codes if code is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    # Lookup table: concept -> its parsed synonyms value.
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    # Called by hand because `train_dataset` is read below, before the Trainer
    # gets a chance to run setup itself (presumably setup() is what creates it).
    data_module.setup()

    # A batch count has to be a whole number; plain `/` handed the model a
    # float. Rounding up assumes the train DataLoader keeps the last partial
    # batch (drop_last=False) — TODO confirm against CLDataModule.
    # NOTE(review): under DDP each process may see fewer batches than this if
    # the dataset is sharded across devices; verify how CL_model uses it.
    n_batches = math.ceil(len(data_module.train_dataset) / config.batch_size)

    model = CL_model(
        n_batches=n_batches,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )

    # Step-based snapshots: nothing is monitored and save_top_k=-1 keeps every
    # checkpoint that gets written, plus a rolling "last" one. Weights only.
    checkpoint = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2",
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    # Stops once "val_loss" fails to improve by at least 1e-3 for 2 checks.
    # Assumes CL_model logs a metric under that exact name — TODO confirm.
    early_stopping = EarlyStopping(
        monitor="val_loss", min_delta=1e-3, patience=2, mode="min"
    )

    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[early_stopping, checkpoint],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer.fit(model, data_module)