from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from peft import LoraConfig, get_peft_model, PeftModel
from PIL import Image
import requests
import copy
import torch
import argparse
from dataset.SurgDataset import SurgDataset
from accelerate import Accelerator
from llava.model.SurgLLaVA import SurgLLaVA
import os
from tqdm import tqdm
import json
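
# Runtime/debug flags: request CUDA device-side assertions, enable tokenizer
# parallelism, and ask torch.distributed for detailed error reporting.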
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = '1'
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_path', type=str, default='/mnt1/lyc/llava_finetune/data_json/instruct_sample_18430_0713_rephrase', help='Data path')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_project', type=str, default='SurgLlaVA')
    parser.add_argument('--wandb_process_name', type=str, default='finetune')
    parser.add_argument('--lora_rank', type=int, default=64, help='Rank of the LoRA matrices')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--log_interval', type=int, default=1)
    parser.add_argument('--eval_interval', type=int, default=3)
    parser.add_argument('--save_interval', type=int, default=3)
    parser.add_argument('--ckpt_dir', type=str, default='model_ckpt', help='Directory in which checkpoints are stored')
    parser.add_argument('--model_name', type=str, default='llava3_mix_instr', help='Run name; used as a subdirectory of ckpt_dir and shown in wandb')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=12)
    parser.add_argument('--step_size', type=int, default=300)
    parser.add_argument('--gamma', type=float, default=0.95, help='Gamma (decay factor) of the StepLR scheduler')
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--lora', action='store_true', help='Use LoRA if set')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--lora_ckpt_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--output_dir', type=str, default='4dor_output', help='Output directory for the generated prediction files')

    return parser.parse_args()
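

# Training/evaluation driver: load the llama3-based LLaVA-NeXT backbone, optionally
# attach LoRA adapters, build the SurgDataset splits, and run either inference
# (--test) or fine-tuning under Accelerate.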
def main():
    args = parse_args()
    accelerator = Accelerator(project_dir=os.path.join(args.ckpt_dir, args.model_name),
                              log_with="wandb" if args.wandb else None,
                              gradient_accumulation_steps=args.gradient_accumulation_steps)

    if args.wandb:
        accelerator.print('[INFO] Using wandb for logging...')
        accelerator.init_trackers(
            project_name=args.wandb_project,
            config=args,
            init_kwargs={"wandb": {"name": args.wandb_process_name}}
        )

    pretrained = "lmms-lab/llama3-llava-next-8b"
    model_name = "llava_llama3"
    tokenizer, llm_model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map='cuda')

    if tokenizer.pad_token is None:
        # The tokenizer ships without a pad token; reuse the EOS token for padding.
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

    train_dataset = SurgDataset(args, image_processor, llm_model.config, mode='train')
    test_dataset = SurgDataset(args, image_processor, llm_model.config, mode='test')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)

    print('[INFO] Freezing llm model')
    for param in llm_model.parameters():
        param.requires_grad = False
    llm_model.eval()
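
    # Optionally attach LoRA adapters to the frozen LLM: either load (and merge) an
    # existing adapter checkpoint, or create fresh adapters on the attention/MLP
    # projections and the LM head.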
    if args.lora:
        if args.lora_ckpt_path is not None:
            print('[INFO] Loading LoRA model checkpoint...')
            llm_model = PeftModel.from_pretrained(llm_model, args.lora_ckpt_path)
            llm_model = llm_model.merge_and_unload()
        else:
            print('[INFO] Creating LoRA ...')
            peft_config = LoraConfig(
                lora_alpha=args.lora_rank,
                lora_dropout=0.05,
                r=args.lora_rank,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                    "lm_head",
                ],
            )
            lora_llm = get_peft_model(llm_model, peft_config)
            # Hand the LoRA-injected model to the rest of the pipeline.
            llm_model = lora_llm.model

    # The optimizer receives the LLM parameters; the frozen weights get no gradients,
    # so only whatever was left trainable (e.g. the LoRA adapters) is updated.
    train_params = llm_model.parameters()
    print('[INFO] Creating Model ...')
    model = SurgLLaVA(args, llm_model, tokenizer)
    model = model.to(torch.bfloat16)
    optimizer = torch.optim.AdamW(train_params, lr=args.lr, eps=1e-7)
    # StepLR decays the learning rate by `gamma` every `step_size` scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader) * args.step_size // args.gradient_accumulation_steps, gamma=args.gamma)

    if args.ckpt_path is not None:
        print('[INFO] Load whole pretrained checkpoint...')
        whole_model = torch.load(os.path.join(args.ckpt_path, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(whole_model)

    print('[INFO] Preparing accelerator...')
    model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader = accelerator.prepare(model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader)
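
    # Inference-only mode: generate answers for the test split; each process writes its
    # predictions to a temporary JSON file and the main process merges them into
    # preds.json plus one preds_<task>.json file per task.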
    if args.test:
        accelerator.print('[INFO] Start testing...')
        model.eval()
        with torch.no_grad():
            os.makedirs(args.output_dir, exist_ok=True)
            output_list = []
            output_tasks = {}
            for i, batch in tqdm(enumerate(test_dataloader)):
                raw_data, question, answer, image, image_sizes = batch
                image = [img for img in image]
                image_sizes = image_sizes[0]
                if len(image_sizes) != args.batch_size:
                    image_sizes = [torch.cat(image_sizes)]
                output = model(image, image_sizes, question)
                text_output = tokenizer.batch_decode(output, skip_special_tokens=True)

                # Attach the predictions to the raw records and turn the batched
                # dict-of-lists into a list of per-sample dicts.
                output_data = raw_data
                output_data.update({'answer': text_output, 'question': question})
                output_data = [dict(zip(output_data, t)) for t in zip(*output_data.values())]

                # Every process dumps its batch, then the main process merges the
                # per-process files into the running prediction lists.
                with open(f'./temp_{accelerator.process_index}.json', 'w') as f:
                    json.dump(output_data, f, indent=4)
                accelerator.wait_for_everyone()

                if accelerator.is_main_process:
                    for j in range(accelerator.num_processes):
                        with open(f'./temp_{j}.json', 'r') as f:
                            temp_output = json.load(f)
                        for t in temp_output:
                            if t['task'] not in output_tasks:
                                output_tasks[t['task']] = []
                            output_tasks[t['task']].append(t)
                            output_list.append(t)
                        os.remove(f'./temp_{j}.json')
                    with open(os.path.join(args.output_dir, 'preds.json'), 'w') as f:
                        json.dump(output_list, f, indent=4)
                    for k in output_tasks.keys():
                        with open(os.path.join(args.output_dir, f'preds_{k}.json'), 'w') as f:
                            json.dump(output_tasks[k], f, indent=4)
                accelerator.wait_for_everyone()
    else:
        accelerator.print('[INFO] Start training...')
        for epoch in tqdm(range(args.num_epochs)):
            model.train()
            total_train_loss = 0
            for i, batch in enumerate(train_dataloader):
                # Run each step inside the Accelerate gradient-accumulation context so
                # that the configured gradient_accumulation_steps actually take effect.
                with accelerator.accumulate(model):
                    img_id, question, answer, image, image_sizes = batch
                    image = [img for img in image]
                    image_sizes = image_sizes[0]
                    if len(image_sizes) != args.batch_size:
                        image_sizes = [torch.cat(image_sizes)]
                    output = model(image, image_sizes, question, answer)
                    loss = output.loss

                    # Zero-weighted sum over all parameters: keeps DDP from complaining
                    # about parameters that did not receive a gradient.
                    for param in model.parameters():
                        loss = loss + param.sum() * 0.0
                    accelerator.backward(loss)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                total_train_loss += loss.item()
                if i % 100 == 0:
                    accelerator.print(f'[Epoch {epoch} Iter {i}] loss: {loss.item()}')
                    if args.wandb:
                        accelerator.log({
                            'train_loss': loss.item(),
                            'lr': scheduler.get_last_lr()[0],
                        })

            total_train_loss /= len(train_dataloader)

            total_test_loss = None
            if epoch % args.eval_interval == 0:
                # Qualitative evaluation: generate predictions and print a sample;
                # no loss is accumulated here, so total_test_loss stays at 0.
                total_test_loss = 0
                model.eval()
                with torch.no_grad():
                    for i, batch in enumerate(test_dataloader):
                        raw_data, question, answer, image, image_sizes = batch
                        image = [img for img in image]
                        image_sizes = image_sizes[0]
                        if len(image_sizes) != args.batch_size:
                            image_sizes = [torch.cat(image_sizes)]
                        output = model(image, image_sizes, question)
                        text_output = tokenizer.batch_decode(output, skip_special_tokens=True)
                        if i % 100 == 0:
                            img_id = raw_data['id'][0]  # raw_data is collated as a dict of lists
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] pred text: {text_output[0]}')
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] G T text: {answer[0]}')
                            accelerator.print()
                total_test_loss /= len(test_dataloader)

            if epoch % args.save_interval == 0:
                accelerator.print('[INFO] Save model...')
                save_model_dir = os.path.join(args.ckpt_dir, args.model_name, 'checkpoints', f'checkpoint_{epoch:05d}')
                lora_save_dir = os.path.join(args.ckpt_dir, args.model_name, 'lora')
                accelerator.save_state(save_model_dir, safe_serialization=False, total_limit=5)
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.model.save_pretrained(
                    lora_save_dir,
                    save_function=accelerator.save,
                    safe_serialization=False
                )
                accelerator.print(f"Model saved at {save_model_dir}")

            if args.wandb:
                accelerator.log({
                    'train_loss': total_train_loss,
                    'eval_loss': total_test_loss,
                    'lr': scheduler.get_last_lr()[0],
                })


if __name__ == '__main__':
    main()
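
# Example launch commands (illustrative only; the script filename and checkpoint
# paths are assumptions, not fixed by this code):
#   accelerate launch train_surgllava.py --lora --wandb --batch_size 1
#   accelerate launch train_surgllava.py --test --lora --lora_ckpt_path <adapter_dir> --output_dir 4dor_output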