|
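# Fine-tunes WPathVLM, a whole-slide-image (WSI) vision-language model, on WSI
# description and question-answering datasets using a custom Hugging Face trainer.
# PIL is configured first so that truncated image files load without raising errors.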
from PIL import ImageFile |
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
import os |
|
import ast |
|
import random |
|
import torch |
|
import pandas as pd |
|
from functools import partial |
|
from transformers import TrainingArguments, AutoTokenizer, HfArgumentParser |
|
from utils.my_trainer import CustomTrainer |
|
from utils.utils import my_compute_metrics,seed_everything |
|
from typing import Optional |
|
from dataclasses import dataclass, field |
|
from model.my_model import WPathVLM |
|
from model.my_model_vision import WPathVLM as WPathVLM_Vision |
|
from peft import LoraConfig, get_peft_model |
|
from datasets import load_dataset, concatenate_datasets, load_from_disk |
|
from utils.data_collator import MyDataCollatorForWPathVLM |
|
from utils.formatting_funcs import wsi_formatting_des, wsi_formatting_qa_open, wsi_formatting_qa_close |
|
|
|
@dataclass |
|
class ScriptArguments: |
|
""" |
|
    The causal LM we wish to fine-tune with the SFT-style CustomTrainer, together with the data, model, and training options.
|
""" |
|
|
|
|
|
    gpu: Optional[str] = field(default="0", metadata={"help": "GPU id(s) to expose via CUDA_VISIBLE_DEVICES, e.g. '0' or '0,1'"})
|
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})

    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
|
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) |
|
token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"}) |
|
    seed: Optional[int] = field(default=42, metadata={"help": "random seed for reproducibility"})
|
|
|
|
|
llm_name: Optional[str] = field(default="/data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct", metadata={"help": "the model name, mistralai/Mistral-7B-Instruct-v0.2, meta-llama/Meta-Llama-3-8B"}) |
|
    vision_adaptor: Optional[bool] = field(default=False, metadata={"help": "True or False (adds interaction with text); used for longnet and qformer."})
|
    hierachical_token: Optional[bool] = field(default=True, metadata={"help": "True or False; add hierarchical <|High|>/<|Mid|>/<|Low|> image tokens"})
|
hierachical_adaptor: Optional[bool] = field(default=True, metadata={"help": "True or False, only for longnet and qformer"}) |
|
|
|
|
|
    select_data_num: Optional[int] = field(default=-1, metadata={"help": "the number of training samples to use; -1 means use all data"})
|
dataset_name_list: Optional[str] = field(default="CNX-PathLLM/TCGA-WSI-Description,CNX-PathLLM/GTEx-WSI-Description") |
|
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"}) |
|
    data_cache_dir: Optional[str] = field(default="/data_local/pxb/CNX-PathLLM/.cache", metadata={"help": "the cache dir for the dataset and model, e.g. /bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache"})
|
    data_local_dir: Optional[str] = field(default=None, metadata={"help": "if not None, load the dataset from this local path"})
|
    fea_root: Optional[str] = field(default="/data_local/pxb/CNX-PathLLM/GTEx-TCGA-Embeddings", metadata={"help": "the root path for the WSI patch features"})

    gmm_root: Optional[str] = field(default="/data_local/pxb/CNX-PathLLM/GMM_PT", metadata={"help": "the root path for the GMM WSI features"})
|
    ckpt_path: Optional[str] = field(default=None, metadata={"help": "path to a pre-trained checkpoint to load before training"})
|
|
|
|
|
|
|
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"}) |
|
output_dir: Optional[str] = field(default="/data_local/pxb/LLM_output/test_merge", metadata={"help": "the output directory"}) |
|
logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"}) |
|
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps; -1 means train for num_train_epochs"})
|
warmup_steps: Optional[int] = field(default=20, metadata={"help": "the number of warmup steps"}) |
|
    save_steps: Optional[int] = field(default=120, metadata={"help": "Number of update steps between two checkpoint saves"})
|
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."}) |
|
|
|
    llm_requires_grad: Optional[bool] = field(default=False, metadata={"help": "whether the LLM backbone is trainable (True) or frozen (False)"})
|
    resume_from_checkpoint: Optional[bool] = field(default=False, metadata={"help": "True to resume from the latest checkpoint in output_dir"})
|
|
|
|
|
learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"}) |
|
    train_batch_size: Optional[int] = field(default=40, metadata={"help": "the training batch size per device"})

    eval_batch_size: Optional[int] = field(default=48, metadata={"help": "the evaluation batch size per device"})
|
max_seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"}) |
|
gradient_accumulation_steps: Optional[int] = field(default=8, metadata={"help": "the number of gradient accumulation steps"}) |
|
num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"}) |
|
|
|
|
|
    n_level: Optional[int] = field(default=3, metadata={"help": "the number of hierarchical levels for WSI embedding"})
|
embed_dim: Optional[int] = field(default=512, metadata={"help": "embedding dimension of each patch, conch: 512, gmm: 2*d+1"}) |
|
agg_strategy: Optional[str] = field(default='gmm,longnet', metadata={"help": "the strategy for WSI aggregation, sample, kmeans, gmm, abmil, qformer, longnet"}) |
|
n_heads: Optional[str] = field(default='32,16,8', metadata={"help": "the number of attention heads for WSI aggregation, for sample and abmil"}) |
|
|
|
|
|
    evaluation_strategy: Optional[str] = field(default="steps", metadata={"help": "'epoch' or 'steps'"})
|
eval_steps: Optional[int] = field(default=100000, metadata={"help": "eval_steps"}) |
|
|
|
|
|
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"}) |
|
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
|
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"}) |
|
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"}) |
|
|
|
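# Parse the command-line arguments into ScriptArguments and fix all random seeds.
# Example invocation (script name and values are illustrative only):
#   python run_wsi_training.py --gpu 0 --agg_strategy gmm,longnet --train_batch_size 8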
parser = HfArgumentParser(ScriptArguments) |
|
script_args = parser.parse_args_into_dataclasses()[0] |
|
seed_everything(script_args.seed) |
|
|
|
|
|
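# Restrict the visible GPU(s) before any CUDA context is created; PCI_BUS_ID keeps
# device ids consistent with nvidia-smi.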
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
os.environ["CUDA_VISIBLE_DEVICES"] = script_args.gpu |
|
device = 'cuda' |
|
print(script_args) |
|
|
|
|
|
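# Tokenizer setup: right-side padding and truncation; '<|finetune_right_pad_id|>' is a
# reserved Llama 3.1 special token used here as the pad token.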
tokenizer = AutoTokenizer.from_pretrained(script_args.llm_name) |
|
|
|
tokenizer.pad_token = "<|finetune_right_pad_id|>" |
|
tokenizer.padding_side = 'right' |
|
tokenizer.truncation_side = 'right' |
|
|
|
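# Register special marker tokens; the hierarchical variant adds <|High|>/<|Mid|>/<|Low|>
# tokens that mark the WSI embedding levels.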
if script_args.hierachical_token: |
|
    new_tokens = ['<|Question|>', '<|Prompt|>', '<|Answer|>', '<|Image|>', '<|High|>', '<|Mid|>', '<|Low|>']
|
else: |
|
new_tokens = ['<|Question|>', '<|Prompt|>', '<|Answer|>', '<|Image|>'] |
|
|
|
num_added_toks = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}) |
|
|
|
new_tokens_ids = tokenizer.convert_tokens_to_ids(new_tokens) |
|
print("new_tokens_ids: ", new_tokens_ids) |
|
|
|
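# Optionally subsample the training split (select_data_num > 0) for quick experiments.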
if script_args.select_data_num > 0:
|
split_text = "train[:{}]".format(script_args.select_data_num) |
|
else: |
|
split_text = "train" |
|
|
|
|
|
dataset = [] |
|
|
|
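# Load each dataset, drop slide-level metadata columns, and apply the matching
# prompt-formatting function: open/closed QA for QA datasets, otherwise description.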
for dataset_name in script_args.dataset_name_list.split(","): |
|
columns_to_remove = ['slide_id'] |
|
one_dataset = load_dataset(dataset_name, split=split_text, cache_dir=script_args.data_cache_dir) |
|
if 'project' in one_dataset.column_names: |
|
columns_to_remove.append('project') |
|
elif 'site' in one_dataset.column_names: |
|
columns_to_remove.append('site') |
|
|
|
if 'QA' in dataset_name: |
|
columns_to_remove += ['question', 'answer'] |
|
if 'Open' in dataset_name: |
|
one_dataset = one_dataset.map(wsi_formatting_qa_open, fn_kwargs={'tokenizer': tokenizer}, |
|
num_proc=20, remove_columns=columns_to_remove) |
|
else: |
|
one_dataset = one_dataset.map(wsi_formatting_qa_close, fn_kwargs={'tokenizer': tokenizer}, |
|
num_proc=20, remove_columns=columns_to_remove) |
|
else: |
|
columns_to_remove += ['description'] |
|
one_dataset = one_dataset.map(wsi_formatting_des, fn_kwargs={'tokenizer': tokenizer}, |
|
num_proc=20, remove_columns=columns_to_remove) |
|
dataset.append(one_dataset) |
|
|
|
dataset = concatenate_datasets(dataset) |
|
|
|
|
|
|
|
|
|
train_dataset = dataset |
|
eval_dataset = None |
|
|
print(train_dataset) |
|
print(eval_dataset) |
|
|
|
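# Build the multimodal model; the vision-adaptor variant (WPathVLM_Vision) lets the
# WSI aggregator interact with the text, per the vision_adaptor flag above.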
model_cls = WPathVLM_Vision if script_args.vision_adaptor else WPathVLM

model = model_cls(script_args.llm_requires_grad,
                  script_args.load_in_8bit,
                  script_args.load_in_4bit,
                  script_args.llm_name,
                  script_args.trust_remote_code,
                  script_args.token,
                  tokenizer,
                  image_token_id=new_tokens_ids[3:],
                  n_heads=script_args.n_heads,
                  n_level=script_args.n_level,
                  embed_dim=script_args.embed_dim,
                  agg_strategy=script_args.agg_strategy,
                  hierachical_token=script_args.hierachical_token,
                  hierachical_adaptor=script_args.hierachical_adaptor,
                  data_cache_dir=script_args.data_cache_dir,
                  )
|
|
|
model.print_parameter_counts() |
|
model.print_llm_parameters() |
|
|
|
print("output dir is set to: {}".format(script_args.output_dir)) |
|
|
|
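# remove_unused_columns=False keeps any non-text columns available to the custom
# data collator instead of letting the Trainer drop them.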
training_args = TrainingArguments( |
|
output_dir=script_args.output_dir, |
|
per_device_train_batch_size=script_args.train_batch_size, |
|
per_device_eval_batch_size=script_args.eval_batch_size, |
|
gradient_accumulation_steps=script_args.gradient_accumulation_steps, |
|
|
|
learning_rate=script_args.learning_rate, |
|
lr_scheduler_type="constant_with_warmup", |
|
logging_steps=script_args.logging_steps, |
|
num_train_epochs=script_args.num_train_epochs, |
|
max_steps=script_args.max_steps, |
|
report_to=script_args.log_with, |
|
save_steps=script_args.save_steps, |
|
save_total_limit=script_args.save_total_limit, |
|
bf16=True, |
|
warmup_steps=script_args.warmup_steps, |
|
evaluation_strategy=script_args.evaluation_strategy, |
|
eval_steps=script_args.eval_steps, |
|
logging_first_step=True, |
|
remove_unused_columns=False, |
|
label_names=["labels"] |
|
) |
|
|
|
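# Optionally wrap only the LLM sub-module with LoRA adapters.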
if script_args.use_peft: |
|
peft_config = LoraConfig( |
|
r=script_args.peft_lora_r, |
|
lora_alpha=script_args.peft_lora_alpha, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
model.llm = get_peft_model(model.llm, peft_config) |
|
model.llm.print_trainable_parameters() |
|
else: |
|
peft_config = None |
|
|
|
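# Optionally warm-start from a saved state dict; strict=False tolerates missing or
# unexpected keys.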
if script_args.ckpt_path is not None: |
|
model.load_state_dict(torch.load(script_args.ckpt_path, map_location=device), strict=False) |
|
|
|
print("load pre-trained model from: {}".format(script_args.ckpt_path)) |
|
model.print_llm_parameters() |
|
else: |
|
print("no pretrained weights loaded from users!") |
|
|
|
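# The collator loads per-slide WSI features from fea_root (and GMM features from
# gmm_root) and batches them together with the tokenized text.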
data_collator = MyDataCollatorForWPathVLM(tokenizer=tokenizer, |
|
fea_root=script_args.fea_root, |
|
gmm_root = script_args.gmm_root, |
|
fea_dim=script_args.embed_dim, |
|
n_level=script_args.n_level, |
|
n_heads=list(map(int, script_args.n_heads.split(','))), |
|
agg_strategy=script_args.agg_strategy) |
|
|
|
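# Assemble the custom trainer with the multimodal collator and evaluation metrics.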
trainer = CustomTrainer( |
|
model=model, |
|
args=training_args, |
|
max_seq_length=script_args.max_seq_length, |
|
train_dataset=train_dataset, |
|
eval_dataset=eval_dataset, |
|
dataset_text_field=script_args.dataset_text_field, |
|
    peft_config=None,  # LoRA (when use_peft is set) is already applied to model.llm above, so it is not passed again here
|
tokenizer=tokenizer, |
|
data_collator=data_collator, |
|
compute_metrics=my_compute_metrics, |
|
) |
|
|
|
trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint) |