|
|
|
''' |
|
This script extracts image and text features for evaluation on a single GPU.
|
''' |
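# Example invocation (script name and checkpoint path are illustrative):
#
#   python extract_features.py \
#       --extract-image-feats --extract-text-feats \
#       --image-data ../Multimodal_Retrieval/lmdb/test/imgs \
#       --text-data ../Multimodal_Retrieval/test_texts.jsonl \
#       --resume ckpt/epoch_latest.pt \
#       --vision-model ViT-B-16 --text-model RoBERTa-wwm-ext-base-chinese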
|
|
|
import os |
|
import argparse |
|
import logging |
|
from pathlib import Path |
|
import json |
|
|
|
import torch |
|
from tqdm import tqdm |
|
|
|
from clip.model import convert_weights, CLIP |
|
from eval.data import get_eval_img_dataset, get_eval_txt_dataset |
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
'--extract-image-feats', |
|
action="store_true", |
|
default=False, |
|
help="Whether to extract image features." |
|
) |
|
parser.add_argument( |
|
'--extract-text-feats', |
|
action="store_true", |
|
default=False, |
|
help="Whether to extract text features." |
|
) |
|
parser.add_argument( |
|
'--image-data', |
|
type=str, |
|
default="../Multimodal_Retrieval/lmdb/test/imgs", |
|
help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings." |
|
) |
|
parser.add_argument( |
|
'--text-data', |
|
type=str, |
|
default="../Multimodal_Retrieval/test_texts.jsonl", |
|
help="If --extract-text-feats is True, specify the path of input text Jsonl file." |
|
) |
|
parser.add_argument( |
|
'--image-feat-output-path', |
|
type=str, |
|
default=None, |
|
help="If --extract-image-feats is True, specify the path of output image features." |
|
) |
|
parser.add_argument( |
|
'--text-feat-output-path', |
|
type=str, |
|
default=None, |
|
        help="If --extract-text-feats is True, specify the path of output text features."
|
) |
|
parser.add_argument( |
|
"--img-batch-size", type=int, default=64, help="Image batch size." |
|
) |
|
parser.add_argument( |
|
"--text-batch-size", type=int, default=64, help="Text batch size." |
|
) |
|
    parser.add_argument(

        "--context-length", type=int, default=64, help="The maximum length of input text (including [CLS] and [SEP] tokens)."

    )
|
parser.add_argument( |
|
"--resume", |
|
default=None, |
|
type=str, |
|
help="path to latest checkpoint (default: none)", |
|
) |
|
parser.add_argument( |
|
"--precision", |
|
choices=["amp", "fp16", "fp32"], |
|
default="amp", |
|
        help="Floating point precision."
|
) |
|
parser.add_argument( |
|
"--vision-model", |
|
choices=["ViT-B-16", "ViT-L-14", "RN50"], |
|
default="ViT-B-16", |
|
help="Name of the vision backbone to use.", |
|
) |
|
parser.add_argument( |
|
"--text-model", |
|
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"], |
|
default="RoBERTa-wwm-ext-base-chinese", |
|
help="Name of the text backbone to use.", |
|
) |
|
parser.add_argument( |
|
"--debug", |
|
default=False, |
|
action="store_true", |
|
help="If true, more information is logged." |
|
) |
|
args = parser.parse_args() |
|
|
|
return args |
|
|
|
|
|
|
|
def convert_models_to_fp32(model): |
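    """Cast model parameters (and any existing gradients) to fp32 in place."""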
|
for p in model.parameters(): |
|
p.data = p.data.float() |
|
        if p.grad is not None:
|
p.grad.data = p.grad.data.float() |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
|
|
assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!" |
|
|
|
|
|
print("Params:") |
|
for name in sorted(vars(args)): |
|
val = getattr(args, name) |
|
print(f" {name}: {val}") |
|
|
|
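    # Single-GPU evaluation: pin all work to device 0.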
args.gpu = 0 |
|
torch.cuda.set_device(args.gpu) |
|
|
|
|
|
vision_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json" |
|
print('Loading vision model config from', vision_model_config_file) |
|
assert os.path.exists(vision_model_config_file) |
|
|
|
text_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json" |
|
print('Loading text model config from', text_model_config_file) |
|
assert os.path.exists(text_model_config_file) |
|
|
|
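    # Merge the vision and text backbone configs into a single kwargs dict for CLIP.
    # vision_layers may be stored as a string (e.g. a ResNet depth tuple) and is parsed with eval below.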
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft: |
|
model_info = json.load(fv) |
|
if isinstance(model_info['vision_layers'], str): |
|
model_info['vision_layers'] = eval(model_info['vision_layers']) |
|
for k, v in json.load(ft).items(): |
|
model_info[k] = v |
|
|
|
model = CLIP(**model_info) |
|
convert_weights(model) |
|
|
|
|
|
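    # convert_weights (above) casts eligible weights to fp16; amp and fp32 runs need them restored to fp32.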
    if args.precision in ("amp", "fp32"):
|
convert_models_to_fp32(model) |
|
model.cuda(args.gpu) |
|
if args.precision == "fp16": |
|
convert_weights(model) |
|
|
|
|
|
if args.extract_image_feats: |
|
print("Preparing image inference dataset.") |
|
img_data = get_eval_img_dataset(args) |
|
if args.extract_text_feats: |
|
print("Preparing text inference dataset.") |
|
text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length) |
|
|
|
|
|
    print("Loading model checkpoint from {}.".format(args.resume))
|
    assert args.resume is not None and os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
|
|
|
    checkpoint = torch.load(args.resume, map_location='cpu')

    sd = checkpoint["state_dict"]
|
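    # Checkpoints saved with DistributedDataParallel prefix every key with "module.";
    # strip the prefix (and drop the bert.pooler weights) before loading.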
    if next(iter(sd)).startswith('module.'):
|
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k} |
|
model.load_state_dict(sd) |
|
print( |
|
f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)" |
|
) |
|
|
|
|
|
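    # Text branch: encode, L2-normalize, and write one JSON object per line.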
if args.extract_text_feats: |
|
        print('Extracting text features...')
|
if args.text_feat_output_path is None: |
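            # Default: strip the ".jsonl" suffix from --text-data and append ".txt_feat.jsonl".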
|
args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6]) |
|
write_cnt = 0 |
|
with open(args.text_feat_output_path, "w") as fout: |
|
model.eval() |
|
dataloader = text_data.dataloader |
|
with torch.no_grad(): |
|
for batch in tqdm(dataloader): |
|
text_ids, texts = batch |
|
texts = texts.cuda(args.gpu, non_blocking=True) |
|
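                    # Passing None as the image argument runs the text tower only.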
text_features = model(None, texts) |
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()): |
|
fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature}))) |
|
write_cnt += 1 |
|
print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path)) |
|
|
|
|
|
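    # Image branch: mirrors the text branch above.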
if args.extract_image_feats: |
|
        print('Extracting image features...')
|
if args.image_feat_output_path is None: |
|
|
|
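            # By default, store the image features alongside the text features, deriving the name from --text-data.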
args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs")) |
|
write_cnt = 0 |
|
with open(args.image_feat_output_path, "w") as fout: |
|
model.eval() |
|
dataloader = img_data.dataloader |
|
with torch.no_grad(): |
|
for batch in tqdm(dataloader): |
|
image_ids, images = batch |
|
images = images.cuda(args.gpu, non_blocking=True) |
|
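                    # Passing None as the text argument runs the image tower only.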
image_features = model(images, None) |
|
image_features /= image_features.norm(dim=-1, keepdim=True) |
|
for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()): |
|
fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature}))) |
|
write_cnt += 1 |
|
print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path)) |
|
|
|
print("Done!") |