# QA-CLIP/eval/extract_features.py
# -*- coding: utf-8 -*-
'''
This script extracts image and text features for evaluation (single GPU).
'''
import os
import argparse
import logging
from pathlib import Path
import json

import torch
from tqdm import tqdm

from clip.model import convert_weights, CLIP
from eval.data import get_eval_img_dataset, get_eval_txt_dataset


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--extract-image-feats',
        action="store_true",
        default=False,
        help="Whether to extract image features."
    )
    parser.add_argument(
        '--extract-text-feats',
        action="store_true",
        default=False,
        help="Whether to extract text features."
    )
    parser.add_argument(
        '--image-data',
        type=str,
        default="../Multimodal_Retrieval/lmdb/test/imgs",
        help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
    )
    parser.add_argument(
        '--text-data',
        type=str,
        default="../Multimodal_Retrieval/test_texts.jsonl",
        help="If --extract-text-feats is True, specify the path of the input text JSONL file."
    )
    parser.add_argument(
        '--image-feat-output-path',
        type=str,
        default=None,
        help="If --extract-image-feats is True, specify the path of the output image features."
    )
    parser.add_argument(
        '--text-feat-output-path',
        type=str,
        default=None,
        help="If --extract-text-feats is True, specify the path of the output text features."
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--text-batch-size", type=int, default=64, help="Text batch size."
    )
    parser.add_argument(
        "--context-length", type=int, default=64, help="The maximum length of input text (including [CLS] & [SEP] tokens)."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="Path to the latest checkpoint (default: none).",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-16", "ViT-L-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    args = parser.parse_args()
    return args
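
# Example invocation (illustrative; adjust paths, checkpoint and backbones to your setup):
#   python -u eval/extract_features.py \
#       --extract-image-feats --extract-text-feats \
#       --image-data ../Multimodal_Retrieval/lmdb/test/imgs \
#       --text-data ../Multimodal_Retrieval/test_texts.jsonl \
#       --resume <path/to/checkpoint.pt> \
#       --vision-model ViT-B-16 --text-model RoBERTa-wwm-ext-base-chinese \
#       --precision amp --context-length 64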


# From https://github.com/openai/CLIP/issues/83; used below when --precision is "amp" or "fp32".
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        # `if p.grad:` would raise on multi-element tensors, so test for None explicitly.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


if __name__ == "__main__":
    args = parse_args()

    assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!"

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f" {name}: {val}")

    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Initialize the model.
    vision_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            # some configs store vision_layers as a string literal (e.g. ResNet block counts); convert it back
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        for k, v in json.load(ft).items():
            model_info[k] = v

    model = CLIP(**model_info)
    convert_weights(model)

    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision == "amp" or args.precision == "fp32":
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)
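
    # Note: the inference loops below do not wrap forward passes in torch.cuda.amp.autocast, so
    # "amp" effectively runs inference in fp32 here, while "fp16" keeps the half-precision weights
    # produced by convert_weights().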

    # Get data.
    if args.extract_image_feats:
        print("Preparing image inference dataset.")
        img_data = get_eval_img_dataset(args)
    if args.extract_text_feats:
        print("Preparing text inference dataset.")
        text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)
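    # Both helpers are expected to return a wrapper exposing a `.dataloader` attribute (iterated below);
    # texts are tokenized with a maximum length of --context-length tokens.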

    # Resume from a checkpoint.
    print("Loading model checkpoint from {}.".format(args.resume))
    assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
    # The checkpoint is loaded on CPU first; load_state_dict below copies the weights into the model already placed on `loc`.
    loc = "cuda:{}".format(args.gpu)
    checkpoint = torch.load(args.resume, map_location='cpu')
    start_epoch = checkpoint["epoch"]
    sd = checkpoint["state_dict"]
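    # Checkpoints saved under DistributedDataParallel prefix every key with "module."; strip it so the keys
    # match this single-GPU model, and drop "bert.pooler" weights (presumably unused for feature extraction).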
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    # Run inference for texts
    if args.extract_text_feats:
        print('Running inference for texts...')
        if args.text_feat_output_path is None:
            args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
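        # Each output line is a JSON object: {"text_id": <int>, "feature": [<float>, ...]},
        # where the feature vector is L2-normalized before being written out.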
        write_cnt = 0
        with open(args.text_feat_output_path, "w") as fout:
            model.eval()
            dataloader = text_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    text_ids, texts = batch
                    texts = texts.cuda(args.gpu, non_blocking=True)
                    text_features = model(None, texts)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
                        write_cnt += 1
        print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))

    # Run inference for images
    if args.extract_image_feats:
        print('Running inference for images...')
        if args.image_feat_output_path is None:
            # by default, store the image features in the same directory as the text features
            args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
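        # Mirrors the text branch: each line is {"image_id": <int>, "feature": [<float>, ...]} with an L2-normalized vector.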
        write_cnt = 0
        with open(args.image_feat_output_path, "w") as fout:
            model.eval()
            dataloader = img_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    image_ids, images = batch
                    images = images.cuda(args.gpu, non_blocking=True)
                    image_features = model(images, None)
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
                        write_cnt += 1
        print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))
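
    # The *_feat.jsonl files written above are intended as inputs to the downstream retrieval
    # evaluation (e.g. nearest-neighbour search over the normalized feature vectors).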
print("Done!")