# llava_finetune/run_finetune_llava.py
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from peft import LoraConfig, get_peft_model, PeftModel
from PIL import Image
import requests
import copy
import torch
import argparse
from dataset.SurgDataset import SurgDataset
from accelerate import Accelerator
from llava.model.SurgLLaVA import SurgLLaVA
import os
from tqdm import tqdm
import json
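
# The environment settings below tweak runtime diagnostics (a brief note, since the
# flags are otherwise opaque):
#   - TORCH_USE_CUDA_DSA requests CUDA device-side assertions (only effective if the
#     PyTorch build supports them), which helps localize device-side errors.
#   - Explicitly setting TOKENIZERS_PARALLELISM silences the Rust tokenizers' fork
#     warning when DataLoader workers are used.
#   - TORCH_DISTRIBUTED_DEBUG=DETAIL makes torch.distributed log verbose information,
#     e.g. which parameters did not receive gradients.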
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = '1'
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'


def parse_args():
    parser = argparse.ArgumentParser()
    # General arguments
    parser.add_argument('--data_path', type=str, default='/mnt1/lyc/llava_finetune/data_json/instruct_sample_18430_0713_rephrase', help='Data path')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_project', type=str, default='SurgLlaVA')
    parser.add_argument('--wandb_process_name', type=str, default='finetune')
    parser.add_argument('--lora_rank', type=int, default=64, help='Rank of the LoRA matrices')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--log_interval', type=int, default=1)
    parser.add_argument('--eval_interval', type=int, default=3)
    parser.add_argument('--save_interval', type=int, default=3)
    parser.add_argument('--ckpt_dir', type=str, default='model_ckpt', help='Directory in which checkpoints are stored')
    parser.add_argument('--model_name', type=str, default='llava3_mix_instr', help='Model name. Used to create a subdirectory in ckpt_dir and as the run name in wandb')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=12)
    parser.add_argument('--step_size', type=int, default=300)
    parser.add_argument('--gamma', type=float, default=0.95, help='Gamma value of the LR scheduler')
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--lora', action='store_true', help='Use LoRA if set')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--lora_ckpt_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--output_dir', type=str, default='4dor_output', help='Output directory for generated text')
    return parser.parse_args()
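
# Example invocations (paths and run names are illustrative, not taken from the
# original repo):
#   Training with LoRA and wandb logging:
#     accelerate launch run_finetune_llava.py --lora --wandb \
#         --data_path /path/to/instruct_json --lora_rank 64 --batch_size 1
#   Testing from a full checkpoint saved by accelerator.save_state:
#     accelerate launch run_finetune_llava.py --test \
#         --ckpt_path model_ckpt/llava3_mix_instr/checkpoints/checkpoint_00003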


def main():
    args = parse_args()
    accelerator = Accelerator(project_dir=os.path.join(args.ckpt_dir, args.model_name),
                              log_with="wandb" if args.wandb else None,
                              gradient_accumulation_steps=args.gradient_accumulation_steps)
    if args.wandb:
        accelerator.init_trackers(
            project_name=args.wandb_project,
            config=args,
            init_kwargs={"wandb": {"name": args.wandb_process_name}}
        )
        accelerator.print("[INFO] Using wandb for logging...")

    pretrained = "lmms-lab/llama3-llava-next-8b"
    model_name = "llava_llama3"
    tokenizer, llm_model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map='cuda')  # Add any other thing you want to pass in llava_model_args
    # The Llama 3 tokenizer ships without a pad token; fall back to the EOS token so
    # batched tokenization and padding work.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

    train_dataset = SurgDataset(args, image_processor, llm_model.config, mode='train')
    test_dataset = SurgDataset(args, image_processor, llm_model.config, mode='test')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)

    print(f'[INFO] Freezing llm model')
    for param in llm_model.parameters():
        param.requires_grad = False
    llm_model.eval()
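    # With the base weights frozen, only the LoRA matrices injected below (when
    # --lora is given) will have requires_grad=True and therefore train.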

    if args.lora:
        if args.lora_ckpt_path is not None:
            print(f'[INFO] Loading LoRA model checkpoint...')
            llm_model = PeftModel.from_pretrained(llm_model, args.lora_ckpt_path)
            llm_model = llm_model.merge_and_unload()
        else:
            print(f'[INFO] Creating LoRA ...')
            peft_config = LoraConfig(
                lora_alpha=args.lora_rank,
                lora_dropout=0.05,
                r=args.lora_rank,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                    "lm_head",
                ],
            )
            lora_llm = get_peft_model(llm_model, peft_config)
            llm_model = lora_llm.model
    train_params = llm_model.parameters()
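    # Note on the LoRA setup above: lora_alpha equals r, so the effective LoRA
    # scaling factor (alpha / r) is 1. get_peft_model(...).model is the base model
    # with the LoRA layers injected in place, presumably kept so the original forward
    # signature that SurgLLaVA expects is preserved.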

    print(f'[INFO] Creating Model ...')
    model = SurgLLaVA(args, llm_model, tokenizer)
    model = model.to(torch.bfloat16)
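    # SurgLLaVA is assumed (from the call sites below) to wrap the LLM so that
    # model(image, image_sizes, question, answer) returns an output with a .loss,
    # while model(image, image_sizes, question) returns generated token ids.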
    optimizer = torch.optim.AdamW(train_params, lr=args.lr, eps=1e-7)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader) * args.step_size // args.gradient_accumulation_steps, gamma=args.gamma)

    if args.ckpt_path is not None:
        print(f'[INFO] Load whole pretrained checkpoint...')
        whole_model = torch.load(os.path.join(args.ckpt_path, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(whole_model)

    print(f'[INFO] Preparing accelerator...')
    model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader = accelerator.prepare(model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader)
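    # accelerator.prepare wraps the model, optimizer, scheduler and dataloaders for
    # (distributed) training; objects it does not recognize, such as the tokenizer,
    # are passed through unchanged.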

    if args.test:
        # Testing: generate predictions on the test set and dump them to JSON.
        accelerator.print(f'[INFO] Start testing...')
        model.eval()
        with torch.no_grad():
            os.makedirs(args.output_dir, exist_ok=True)
            output_list = []
            output_tasks = {}
            for i, batch in tqdm(enumerate(test_dataloader)):
                raw_data, question, answer, image, image_sizes = batch
                image = [img for img in image]
                image_sizes = image_sizes[0]
                if len(image_sizes) != args.batch_size:
                    image_sizes = [torch.cat(image_sizes)]
                output = model(image, image_sizes, question)
                text_output = tokenizer.batch_decode(output, skip_special_tokens=True)
                output_data = raw_data
                output_data.update({'answer': text_output, 'question': question})
                # Turn the collated dict of lists into a list of per-sample dicts.
                output_data = [dict(zip(output_data, t)) for t in zip(*output_data.values())]
                # Each process writes its own file to avoid conflicts between processes.
                with open(f'./temp_{accelerator.process_index}.json', 'w') as f:
                    json.dump(output_data, f, indent=4)
                accelerator.wait_for_everyone()
                # The main process merges the per-process results.
                if accelerator.is_main_process:
                    for j in range(accelerator.num_processes):
                        with open(f'./temp_{j}.json', 'r') as f:
                            temp_output = json.load(f)
                        for t in temp_output:
                            if t['task'] not in output_tasks.keys():
                                output_tasks[t['task']] = []
                            output_tasks[t['task']].append(t)
                            output_list.append(t)
                        os.remove(f'./temp_{j}.json')
                    with open(os.path.join(args.output_dir, f'preds.json'), 'w') as f:
                        json.dump(output_list, f, indent=4)
                    for k in output_tasks.keys():
                        with open(os.path.join(args.output_dir, f'preds_{k}.json'), 'w') as f:
                            json.dump(output_tasks[k], f, indent=4)
                accelerator.wait_for_everyone()
    else:
        # Training loop over epochs.
        accelerator.print(f'[INFO] Start training...')
        for epoch in tqdm(range(args.num_epochs)):
            model.train()
            total_train_loss = 0
            for i, batch in enumerate(train_dataloader):
                # accumulate() applies the gradient_accumulation_steps configured on the
                # Accelerator: the prepared optimizer and scheduler only step on
                # accumulation boundaries.
                with accelerator.accumulate(model):
                    optimizer.zero_grad()
                    img_id, question, answer, image, image_sizes = batch
                    image = [img for img in image]
                    image_sizes = image_sizes[0]
                    if len(image_sizes) != args.batch_size:
                        image_sizes = [torch.cat(image_sizes)]
                    output = model(image, image_sizes, question, answer)
                    loss = output.loss
                    # Accelerator requires every parameter to take part in the backward pass;
                    # this zero-valued 'dummy loss' term avoids unused-parameter errors.
                    for param in model.parameters():
                        loss += param.sum() * 0.0
                    accelerator.backward(loss)
                    optimizer.step()
                    scheduler.step()
                total_train_loss += loss.item()
                if i % 100 == 0:
                    accelerator.print(f'[Epoch {epoch} Iter {i}] loss: {loss.item()}')
                    if args.wandb:
                        accelerator.log({
                            'train_loss': loss.item(),
                            'lr': scheduler.get_last_lr()[0],
                        })
            total_train_loss /= len(train_dataloader)
            total_test_loss = None
            if epoch % args.eval_interval == 0:
                # Evaluation is generation-only: predictions are printed for inspection,
                # but no loss is accumulated, so total_test_loss stays at 0.
                total_test_loss = 0
                model.eval()
                with torch.no_grad():
                    for i, batch in enumerate(test_dataloader):
                        raw_data, question, answer, image, image_sizes = batch
                        image = [img for img in image]
                        image_sizes = image_sizes[0]
                        if len(image_sizes) != args.batch_size:
                            image_sizes = [torch.cat(image_sizes)]
                        output = model(image, image_sizes, question)
                        text_output = tokenizer.batch_decode(output, skip_special_tokens=True)
                        if i % 100 == 0:
                            img_id = raw_data[0]['id']
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] pred text: {text_output[0]}')
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] G T  text: {answer[0]}')
                            accelerator.print()
                total_test_loss /= len(test_dataloader)
            if epoch % args.save_interval == 0:
                accelerator.print(f'[INFO] Save model...')
                save_model_dir = os.path.join(args.ckpt_dir, args.model_name, 'checkpoints', f'checkpoint_{epoch:05d}')
                lora_save_dir = os.path.join(args.ckpt_dir, args.model_name, 'lora')
                accelerator.save_state(save_model_dir, safe_serialization=False, total_limit=5)
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.model.save_pretrained(
                    lora_save_dir,
                    save_function=accelerator.save,
                    safe_serialization=False
                )
                accelerator.print(f"Model saved at {save_model_dir}")
            if args.wandb:
                accelerator.log({
                    'train_loss': total_train_loss,
                    'eval_loss': total_test_loss,
                    'lr': scheduler.get_last_lr()[0],
                })


if __name__ == '__main__':
    main()