# -*- coding: utf-8 -*-
'''
This script extracts image and text features for evaluation. (with single-GPU)
'''

import os
import argparse
import logging
from pathlib import Path
import json

import torch
from tqdm import tqdm

from clip.model import convert_weights, CLIP
from eval.data import get_eval_img_dataset, get_eval_txt_dataset


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--extract-image-feats',
        action="store_true",
        default=False,
        help="Whether to extract image features."
    )
    parser.add_argument(
        '--extract-text-feats',
        action="store_true",
        default=False,
        help="Whether to extract text features."
    )
    parser.add_argument(
        '--image-data',
        type=str,
        default="../Multimodal_Retrieval/lmdb/test/imgs",
        help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
    )
    parser.add_argument(
        '--text-data',
        type=str,
        default="../Multimodal_Retrieval/test_texts.jsonl",
        help="If --extract-text-feats is True, specify the path of the input text JSONL file."
    )
    parser.add_argument(
        '--image-feat-output-path',
        type=str,
        default=None,
        help="If --extract-image-feats is True, specify the output path of the image features."
    )
    parser.add_argument(
        '--text-feat-output-path',
        type=str,
        default=None,
        help="If --extract-text-feats is True, specify the output path of the text features."
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--text-batch-size", type=int, default=64, help="Text batch size."
    )
    parser.add_argument(
        "--context-length", type=int, default=64,
        help="The maximum length of input text (including [CLS] & [SEP] tokens)."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="Path to the latest checkpoint (default: none).",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-16", "ViT-L-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    args = parser.parse_args()
    return args


# From https://github.com/openai/CLIP/issues/83: cast all parameters (and any
# existing gradients) back to fp32. Used below for amp/fp32 precision.
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


if __name__ == "__main__":
    args = parse_args()

    assert args.extract_image_feats or args.extract_text_feats, \
        "--extract-image-feats and --extract-text-feats cannot both be False!"

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f" {name}: {val}")

    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Initialize the model.
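    # The vision and text configs are two separate JSON files; below they are merged
    # into a single keyword-argument dict for the CLIP constructor, with the text
    # config's keys added on top of the vision config's.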
    vision_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        for k, v in json.load(ft).items():
            model_info[k] = v

    model = CLIP(**model_info)
    convert_weights(model)

    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision == "amp" or args.precision == "fp32":
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)

    # Get data.
    if args.extract_image_feats:
        print("Preparing image inference dataset.")
        img_data = get_eval_img_dataset(args)
    if args.extract_text_feats:
        print("Preparing text inference dataset.")
        text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)

    # Resume from a checkpoint.
    print("Begin to load model checkpoint from {}.".format(args.resume))
    assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
    # Load the checkpoint on CPU; the weights are moved to the GPU when loaded into the model.
    checkpoint = torch.load(args.resume, map_location='cpu')
    start_epoch = checkpoint["epoch"]
    sd = checkpoint["state_dict"]
    # Strip the DistributedDataParallel "module." prefix if present.
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    # Make inference for texts.
    if args.extract_text_feats:
        print('Make inference for texts...')
        if args.text_feat_output_path is None:
            args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
        write_cnt = 0
        with open(args.text_feat_output_path, "w") as fout:
            model.eval()
            dataloader = text_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    text_ids, texts = batch
                    texts = texts.cuda(args.gpu, non_blocking=True)
                    text_features = model(None, texts)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
                        write_cnt += 1
        print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))

    # Make inference for images.
    if args.extract_image_feats:
        print('Make inference for images...')
        if args.image_feat_output_path is None:
            # By default, we store the image features under the same directory as the text features.
            args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
        write_cnt = 0
        with open(args.image_feat_output_path, "w") as fout:
            model.eval()
            dataloader = img_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    image_ids, images = batch
                    images = images.cuda(args.gpu, non_blocking=True)
                    image_features = model(images, None)
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
                        write_cnt += 1
        print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))

    print("Done!")
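
# Illustrative invocation (a sketch, not part of the original script; the script
# filename, dataset paths, and checkpoint path below are placeholders):
#
#   python extract_features.py \
#       --extract-image-feats --extract-text-feats \
#       --image-data ../Multimodal_Retrieval/lmdb/test/imgs \
#       --text-data ../Multimodal_Retrieval/test_texts.jsonl \
#       --resume path/to/checkpoint.pt \
#       --vision-model ViT-B-16 --text-model RoBERTa-wwm-ext-base-chinese \
#       --context-length 64 --img-batch-size 64 --text-batch-size 64
#
# With the default output paths, this writes test_texts.txt_feat.jsonl and
# test_imgs.img_feat.jsonl next to the text input: one JSON object per line,
# each holding an id and an L2-normalized feature vector.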