# Image-to-Text
# Chinese
# English
# FLIP / FLIP-demo / main.py
# OpenFace-CQUPT
# Upload 14 files
# 6e6d6a7 verified
import argparse
import numpy as np
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from models.FFLIP import FLIP
from models import utils
from eval.pretrain_eval import evaluation, itm_eval
from data import create_dataset, create_sampler, create_loader
def main(args):
    """Evaluate a pretrained FLIP model on the 'facecaption' test split.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    datasets and loaders, build the model (optionally wrapped in
    DistributedDataParallel), compute the image-to-text / text-to-image
    score matrices and print the metrics returned by ``itm_eval`` on the
    main process.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``seed``,
            ``distributed`` and ``pretrained`` directly, plus ``gpu`` when
            ``distributed`` is true (presumably set by
            ``utils.init_distributed_mode`` -- verify). The whole namespace
            is also forwarded to ``create_dataset`` and ``evaluation``.
    """
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    # Offset the base seed by the process rank and seed every RNG used here.
    rank_seed = args.seed + utils.get_rank()
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
    random.seed(rank_seed)
    cudnn.benchmark = True

    # ---- Data (reference code for creating the dataset) ----
    print("Creating dataset")
    train_dataset, test_dataset = create_dataset(args, 'facecaption')

    # Only the training split gets a sampler; the test split never does.
    samplers = [None, None]
    if args.distributed:
        world_size = utils.get_world_size()
        rank = utils.get_rank()
        train_samplers = create_sampler([train_dataset], [True], world_size, rank)
        samplers = train_samplers + [None]

    # NOTE: train_loader is created but not used by this evaluation-only script.
    train_loader, test_loader = create_loader(
        [train_dataset, test_dataset],
        samplers,
        batch_size=[80, 80],
        num_workers=[8, 8],
        is_trains=[True, False],
        collate_fns=[None, None],
    )

    # ---- Model ----
    print("Creating model")
    model = FLIP(pretrained=args.pretrained, vit='base', queue_size=61440).to(device)

    # Evaluation always runs on the bare module, not on the DDP wrapper.
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # ---- Evaluation ----
    print("Start evaluation")
    score_test_i2t, score_test_t2i = evaluation(args, model_without_ddp, test_loader, device)

    if utils.is_main_process():
        eval_set = test_loader.dataset
        test_result = itm_eval(score_test_i2t, score_test_t2i,
                               eval_set.txt2img, eval_set.img2txt)
        print(test_result)

    if args.distributed:
        dist.barrier()
def str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Args:
        value: A bool, or a string such as "true"/"false", "yes"/"no",
            "1"/"0", "on"/"off" (case-insensitive, surrounding spaces ignored).

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='./outputs')
    parser.add_argument('--img_root', default='./FaceCaption/images')
    # Fixed default: was '.FaceCaption/caption' (missing '/'), which pointed
    # at a hidden directory instead of the folder used by --img_root.
    parser.add_argument('--ann_root', default='./FaceCaption/caption')
    parser.add_argument('--pretrained', default='./FaceCaption-15M-base.pth')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # `--distributed True/False` is parsed by value; a bare `--distributed`
    # (no value) also enables distributed mode.
    parser.add_argument('--distributed', default=False, type=str2bool, nargs='?', const=True,
                        help='whether to use distributed mode to training')
    args = parser.parse_args()
    main(args)