eP-ALM / float32 /audio_caption.py
mshukor
init
3eb682b
raw
history blame
No virus
14.6 kB
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from models.epalm import ePALM
from models.utils import freeze_whole_model, unfreeze_parameters, print_trainable_params_percentage
from transformers import AutoTokenizer
import utils
from dataset.audio_caption import get_loader
from scheduler import create_scheduler
from optim import create_optimizer
from models.utils import filter_state, filter_msg, exclude_list
def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
    """Run one training epoch of caption language modelling.

    For every batch, the ``"sent"`` captions go through three steps:

    * optionally normalised to end with exactly one tokenizer EOS token
      (``config['append_eos_token']``);
    * optionally prefixed with ``config['task_prompt']``;
    * tokenized with longest-padding and used as labels, with pad positions
      set to -100.

    The model is called with ``reduction='none'``. Its loss is summed,
    divided by the batch size and scaled by ``config['lm_loss_weight']``.

    During epoch 0 the scheduler is stepped every 100 iterations, up to
    ``warmup_steps * 100`` iterations.

    Args:
        model: model called as ``model(image=..., text=..., labels=..., mode='train', ...)``.
        data_loader: iterable of dict batches with ``"images"`` and ``"sent"`` keys.
        optimizer: optimizer whose param group 0 (and group 1 when
            ``config['optimizer']`` defines ``prompt_lr``) learning rates are logged.
        tokenizer: HF-style tokenizer providing ``eos_token`` and ``pad_token_id``.
        epoch: current epoch index.
        warmup_steps: number of warmup scheduler steps taken during epoch 0.
        device: device the inputs are moved to.
        scheduler: LR scheduler stepped during warmup.
        config: experiment configuration dict.

    Returns:
        dict mapping each logged meter name to its global average, formatted
        with three decimals.
    """
    model.train()

    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = f'Train Epoch: [{epoch}]'
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps * step_size

    lm_loss_weight = config.get('lm_loss_weight', 1)
    append_eos_token = config.get('append_eos_token', False)
    task_prompt = config.get('task_prompt', None)
    eos_token = tokenizer.eos_token

    optim_cfg = utils.AttrDict(config['optimizer'])
    prompt_lr = getattr(optim_cfg, 'prompt_lr', None)
    # a separate prompt learning rate is read from optimizer.param_groups[1]
    has_prompt_group = prompt_lr is not None
    if has_prompt_group:
        logger.add_meter('prompt_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    for step, batch in enumerate(logger.log_every(data_loader, print_freq, header)):
        image = batch["images"].to(device, non_blocking=True)

        captions = batch["sent"]
        if append_eos_token:
            # drop any EOS already inside the caption, then append exactly one at the end
            captions = [c.replace(eos_token, '') + eos_token for c in captions]
        if task_prompt is not None:
            captions = [' '.join((task_prompt, c)) for c in captions]

        encoded = tokenizer(captions, padding='longest', return_tensors="pt").to(device)
        # padded positions are ignored by the LM loss via the -100 label
        labels = encoded.input_ids.masked_fill(encoded.input_ids == tokenizer.pad_token_id, -100)

        output = model(
            image=image,
            text=encoded,
            labels=labels,
            return_dict=True,
            mode='train',
            reduction='none',
        )
        loss = output.loss.sum() / image.size(0) * lm_loss_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.update(loss=loss.item())
        logger.update(lr=optimizer.param_groups[0]["lr"])
        if has_prompt_group:
            logger.update(prompt_lr=optimizer.param_groups[1]["lr"])

        in_warmup = epoch == 0 and step <= warmup_iterations
        if in_warmup and step % step_size == 0:
            scheduler.step(step // step_size)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())
    return {name: f"{meter.global_avg:.3f}" for name, meter in logger.meters.items()}
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Generate a caption for every sample in ``data_loader`` and score them.

    Generation is conditioned on the images plus, when configured,
    ``config['task_prompt']`` as the only text. It uses
    ``mode='generate'`` with ``max_length=30`` and ``do_sample=True``.

    Each generated sequence is decoded. The task prompt is removed, and the
    text after the first ``'</s>'`` (up to the next one) is kept, with pad
    and EOS tokens stripped. If that post-processing fails, the raw output is
    printed and the prediction falls back to ``' '``.

    Failures covered:

    * ``TypeError``, e.g. ``tokenizer.pad_token`` being ``None``;
    * ``IndexError``, when the decoded text contains no ``'</s>'``.

    Args:
        model: model supporting ``model(image=..., text=..., mode='generate', ...)``.
        data_loader: iterable of dict batches with an ``"images"`` key and,
            optionally, a ``"targets"`` key. It must also expose an ``evaluator``
            attribute with an ``evaluate(predictions, targets)`` method.
        tokenizer: HF-style tokenizer used for prompt encoding and decoding.
        device: device the inputs are moved to.
        config: experiment configuration dict (``task_prompt`` is read).

    Returns:
        dict mapping ``'Valid/<metric name>'`` to the score returned by the
        evaluator.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Generate Caption test result:'
    print_freq = 50

    predictions = []
    targets = []

    task_prompt = config.get('task_prompt', None)
    pad_token = tokenizer.pad_token
    eos_token = tokenizer.eos_token

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        image = batch["images"].to(device, non_blocking=True)

        # one (possibly prompt-only) text input per image
        text = ['' for _ in image]
        if task_prompt is not None:
            text = [task_prompt + ' ' + t for t in text]
        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)

        out = model(image=image, text=text_input, mode='generate', return_dict=True, max_length=30, do_sample=True)

        out_decode = []
        for o in out:
            try:
                res = tokenizer.decode(o)
                if task_prompt is not None:
                    res = res.replace(task_prompt, '')
                # Keep the segment after the first '</s>'. This presumably
                # skips a leading BOS/separator token (OPT uses '</s>') —
                # confirm for other tokenizers.
                response = res.split('</s>')[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '')  # skip_special_tokens=True
            except (TypeError, IndexError):
                # Best-effort fallback: a sequence without '</s>' (IndexError)
                # or a missing pad/eos token string (TypeError) must not abort
                # the whole evaluation.
                print(o)
                response = ' '
            out_decode.append(response)
        predictions.extend(out_decode)

        if 'targets' in batch:
            targets.extend(batch['targets'])

    evaluator = data_loader.evaluator
    eval_results = evaluator.evaluate(predictions, targets)

    wandb_log_dict = {f'Valid/{score_name}': score for score_name, score in eval_results.items()}
    print(wandb_log_dict)
    return wandb_log_dict
def main(args, config):
    """Train (unless ``args.evaluate``) and then evaluate ePALM on audio captioning.

    Steps, in order:

    1. Initialise distributed mode and seed RNGs (offset by rank).
    2. Build the tokenizer and the train/val/test loaders.
    3. Build the model, optimizer and scheduler.
    4. Optionally load ``args.checkpoint`` and, with ``args.resume``, restore
       optimizer/scheduler/epoch state.
    5. Freeze all parameters, then unfreeze those selected by ``config``.
    6. Wrap the model in DDP when distributed.
    7. Run the epoch loop. Each epoch trains, validates, and on the main
       process appends to ``log.txt`` and writes checkpoints.
    8. Finally, evaluate on the test split.

    When training ran, the test phase first reloads the saved checkpoint:
    ``checkpoint_best.pth`` if ``args.save_best`` is set, otherwise
    ``checkpoint_last.pth`` (the only one written in that case).

    Side effects: mutates ``args``, adding the data/audio options read from
    ``config``, and writes files into ``args.output_dir``.

    Args:
        args: parsed command-line namespace.
        config: experiment configuration dict (note the ``'schedular'`` key spelling).
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    start_epoch = 0
    max_epoch = config['schedular']['epochs']
    warmup_steps = config['schedular']['warmup_epochs']

    print(args, config)

    tokenizer = AutoTokenizer.from_pretrained(args.text_model, use_fast=False, local_files_only=True)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
    else:
        num_tasks = None
        global_rank = None

    #### Data ####
    num_workers = config.get('num_workers', 4)
    train_topk = config.get('train_topk', -1)
    valid_topk = config.get('valid_topk', -1)
    data_dir = args.data_dir

    args.image_size = config.get('image_res', 224)
    args.use_data_augmentation = True

    black_image = config.get('black_image', False)
    print("black image:", black_image)

    # audio
    args.melbins = config.get('melbins', 128)
    args.target_length = config.get('target_length', 1024)
    args.num_tries = config.get('num_tries', 1)
    args.skip_norm = config.get('skip_norm', True)
    args.norm_mean = config.get('norm_mean', None)
    args.norm_std = config.get('norm_std', None)
    args.noise = config.get('noise', False)
    args.freqm_p = config.get('freqm_p', 48)
    args.timem_p = config.get('timem_p', 192)

    train_split = config.get('train_split', 'train')
    val_split = config.get('val_split', 'val')
    test_split = config.get('test_split', 'test')

    train_loader = get_loader(
        args,
        split=train_split, mode='train', batch_size=config['batch_size_train'],
        distributed=args.distributed,
        workers=num_workers,
        topk=train_topk,
        data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True, black_image=black_image
    )
    print('# len train loader:', len(train_loader))

    print('Building val loader')
    val_loader = get_loader(
        args,
        split=val_split, mode='val', batch_size=config['batch_size_test'],
        distributed=False,
        workers=4,
        topk=valid_topk, data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True, black_image=black_image
    )
    print('# len val loader:', len(val_loader))

    print('Building test loader')
    # Pass black_image like the train/val loaders do, so the test split uses
    # the same input setting.
    test_loader = get_loader(
        args,
        split=test_split, mode='val', batch_size=config['batch_size_test'],
        distributed=False,
        workers=4,
        topk=valid_topk, data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True, black_image=black_image
    )
    print('# len test loader:', len(test_loader))

    #### Model ####
    print("Creating model")
    start_layer_idx = config.get('start_layer_idx', 0)
    end_layer_idx = config.get('end_layer_idx', 0)
    vision_model_name = config.get('vision_model_name', args.vision_model)

    model = ePALM(opt_model_name=args.text_model,
                  vision_model_name=vision_model_name,
                  use_vis_prefix=True,
                  start_layer_idx=start_layer_idx,
                  end_layer_idx=end_layer_idx,
                  return_hidden_state_vision=True,
                  config=config,
                  )
    model = model.to(device)

    arg_opt = utils.AttrDict(config['optimizer'])
    optimizer = create_optimizer(arg_opt, model, config=config)

    if hasattr(arg_opt, 'prompt_lr') and arg_opt.prompt_lr is not None:
        print('\tInitial other params params lr: %f' % optimizer.param_groups[0]['lr'])
        print('\tInitial prompt params lr: %f' % optimizer.param_groups[1]['lr'])

    arg_sche = utils.AttrDict(config['schedular'])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

    best_epoch = 0
    best_valid = 0

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        state_dict = checkpoint['model']
        msg = model.load_state_dict(state_dict, strict=False)
        msg = filter_msg(msg, exclude_list)
        print('load checkpoint from %s' % args.checkpoint)
        print(msg)

        if args.resume:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            print(checkpoint.keys())
            if 'best_valid' in checkpoint:
                best_valid = checkpoint['best_valid']
                best_epoch = checkpoint['best_epoch']
                print("load best valid {} at epoch {}".format(best_valid, best_epoch))

    freeze_whole_model(model)
    unfreeze_parameters(model, config)
    print_trainable_params_percentage(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start training")
    start_time = time.time()

    for epoch in range(start_epoch, max_epoch):
        if epoch > 0:
            lr_scheduler.step(epoch + warmup_steps)

        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)

        if args.evaluate:
            break

        valid_results = evaluation(model, val_loader, tokenizer, device, config)

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            save_obj = {
                'model': filter_state(model_without_ddp.state_dict(), exclude_list),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'config': config,
                'epoch': epoch,
                'best_valid': best_valid,
                'best_epoch': best_epoch,
            }

            if args.save_best:
                valid_score = valid_results['Valid/CIDEr']
                if valid_score > best_valid or epoch == 0:
                    best_valid = valid_score
                    best_epoch = epoch
                    print("Save best epoch:", best_epoch)
                    save_obj['best_valid'] = best_valid
                    save_obj['best_epoch'] = best_epoch
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))

            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_last.pth'))

        # Other ranks wait for the main process to finish writing checkpoints.
        # A barrier is only valid once a process group exists.
        if args.distributed:
            dist.barrier()

    if not args.evaluate:
        # checkpoint_best.pth is only written with --save_best
        ckpt_name = 'checkpoint_best.pth' if args.save_best else 'checkpoint_last.pth'
        ckpt_path = os.path.join(args.output_dir, ckpt_name)
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        state_dict = checkpoint['model']
        # model_without_ddp works whether or not the model is DDP-wrapped
        msg = model_without_ddp.load_state_dict(state_dict, strict=False)
        msg = filter_msg(msg, exclude_list)
        print('load checkpoint for test from %s' % ckpt_path)
        print(msg)

    # test-split scores are printed by evaluation()
    evaluation(model, test_loader, tokenizer, device, config)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` treats every non-empty string, including
        "False", as True, so the text is interpreted explicitly here.
        """
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/VQA.yaml')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--output_dir', default='output/vqa')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--text_model', default='facebook/opt-350m')
    parser.add_argument('--vision_model', default='vit_base_patch16_224')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): utils.init_distributed_mode may overwrite args.distributed — confirm.
    parser.add_argument('--distributed', default=True, type=_str2bool)
    parser.add_argument('--data_dir', default='/data/mshukor/data')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--save_best', action='store_true')
    parser.add_argument('--image_dir', default='/data/mshukor/data')
    args = parser.parse_args()

    # NOTE(review): yaml.Loader is not the safe loader and can build arbitrary
    # objects; only load trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # keep a copy of the exact config used for this run next to its outputs
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    main(args, config)