# HMMC_t2v_search / app.py
# Author: cheetah003 — commit 61f2d16 ("update video index")
import numpy as np
# import gradio
import torch
from transformers import BertTokenizer
import argparse
import gradio as gr
import time
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.modeling import BirdModel
# --- Retrieval settings -------------------------------------------------
show_num = 9    # results returned per query; search_demo builds one player per result
max_words = 32  # fixed token length that get_text pads/truncates every query to

# --- Precomputed gallery features ---------------------------------------
# Loaded once at import time and wrapped as torch tensors
# (torch.from_numpy shares memory with the loaded NumPy array).
video_path_zh = "features/Chinese_batch_visual_output_list.npy"
frame_path_zh = "features/Chinese_batch_frame_output_list.npy"
video_fea_zh = torch.from_numpy(np.load(video_path_zh))
frame_fea_zh = torch.from_numpy(np.load(frame_path_zh))

video_path_en = "features/English_batch_visual_output_list.npy"
frame_path_en = "features/English_batch_frame_output_list.npy"
video_fea_en = torch.from_numpy(np.load(video_path_en))
frame_fea_en = torch.from_numpy(np.load(frame_path_en))

# --- Gallery index -------------------------------------------------------
# One video id per line. Search results index into video_ids, so the line
# order presumably matches the row order of the feature arrays — TODO confirm.
test_path = "test_list.txt"
video_dir = "test1500/"  # an alternative clip directory was "test1500_400_400/"
with open(test_path, 'r', encoding='utf8') as f_list:
    lines = list(f_list)
video_ids = [entry.strip() + ".mp4" for entry in lines]
def get_videoname(idx):
    """Map gallery indices to the matching video file names and paths.

    Args:
        idx: iterable of integer positions into ``video_ids`` (for example the
            index tensor produced by ``topk``).

    Returns:
        ``(videoname, videopath)``: two lists in the same order as ``idx``.
        Each path is ``video_dir`` prepended to the file name.
    """
    names = [video_ids[i] for i in idx]
    paths = [video_dir + name for name in names]
    return names, paths
def get_text(caption, tokenizer):
    """Tokenize a query caption into fixed-length id and mask tensors.

    Processing steps:
      1. Tokenize the caption.
      2. Prefix ``"<|startoftext|>"``.
      3. Truncate so that, after appending ``"<|endoftext|>"``, the sequence
         has at most ``max_words`` tokens.
      4. Append ``"<|endoftext|>"``.
      5. Right-pad the ids with 0 up to ``max_words``.

    Args:
        caption: query text.
        tokenizer: object providing ``tokenize(str) -> list`` and
            ``convert_tokens_to_ids(list) -> list``.

    Returns:
        ``(pairs_text, pairs_mask)``: int64 tensors of shape ``(1, max_words)``.
        The mask is 1 for real tokens and 0 for padding.

    Raises:
        ValueError: if the tokenizer yields more ids than ``max_words``.
    """
    # NOTE(review): these CLIP-style start/end tokens are also used with the
    # BERT tokenizer for Chinese, where they are probably not in the vocabulary
    # (mapped to an unknown-token id). Confirm this matches the training setup.
    words = ["<|startoftext|>"] + tokenizer.tokenize(caption)
    # Reserve one slot for the end token so the total never exceeds max_words.
    words = words[:max_words - 1] + ["<|endoftext|>"]
    input_ids = list(tokenizer.convert_tokens_to_ids(words))

    pad = max_words - len(input_ids)
    if pad < 0:
        raise ValueError(
            "tokenizer produced {} ids, more than max_words={}".format(len(input_ids), max_words)
        )
    input_mask = [1] * len(input_ids) + [0] * pad
    input_ids += [0] * pad

    # Build the tensors as int64 directly: np.array(list_of_ints) is int32 on
    # some platforms (e.g. Windows), which would change the tensor dtype.
    pairs_text = torch.tensor(input_ids, dtype=torch.long).view(-1, max_words)
    pairs_mask = torch.tensor(input_mask, dtype=torch.long).view(-1, max_words)
    return pairs_text, pairs_mask
def get_args(description='Retrieval Task', argv=None):
    """Build the configuration namespace that is passed to ``BirdModel``.

    The parser declares the full option set, including training options such
    as learning rates and epochs. After parsing, ``do_eval``, ``use_frame_fea``
    and ``use_temp`` are forced on, whatever the command line says.

    Args:
        description: description shown by ``--help``.
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` as before. Pass ``[]`` to get the defaults without
            reading the real command line (e.g. in tests or when embedded in
            another launcher).

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_params", action='store_true', help="text the params of the model.")
    parser.add_argument("--use_frame_fea", action='store_true', help="whether use frame feature matching text")
    parser.add_argument('--task', type=str, default="retrieval", choices=["retrieval_VT", "retrieval"],
                        help="choose downstream task.")
    parser.add_argument('--dataset', type=str, default="bird", choices=["bird", "msrvtt", "vatex", "msvd"],
                        help="choose dataset.")
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--text_lr', type=float, default=0.00001, help='text encoder learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--weight_decay', type=float, default=0.2, help='Weight decay coefficient')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=32, help='')
    parser.add_argument('--max_frames', type=int, default=12, help='')
    parser.add_argument('--top_frames', type=int, default=3, help='')
    parser.add_argument('--frame_sample', type=str, default="uniform", choices=["uniform", "random", "uniform_random"],
                        help='frame sample strategy')
    parser.add_argument('--frame_sample_len', type=str, default="fix", choices=["dynamic", "fix"],
                        help='use dynamic frame length or fix frame length')
    parser.add_argument('--language', type=str, default="chinese", choices=["chinese", "english"],
                        help='language for text encoder')
    parser.add_argument('--use_temp', action='store_true', help='whether to use temporal transformer')
    parser.add_argument("--logdir", default=None, type=str, required=False, help="log dir for tensorboardX writer")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--pretrained_text", default="hfl/chinese-roberta-wwm-ext", type=str, required=False,
                        help="pretrained_text")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument('--enable_amp', action='store_true', help="whether to use pytorch amp")
    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument("--rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    args = parser.parse_args(argv)

    # Check parameters: the demo always runs in eval mode with frame-level
    # features and the temporal transformer enabled.
    args.do_eval = True
    args.use_frame_fea = True
    args.use_temp = True
    return args
def init_model(language):
    """Load the retrieval model and its tokenizer for one query language.

    Args:
        language: ``"chinese"`` selects checkpoint ``models/Chinese_vatex.bin``
            with the ``hfl/chinese-roberta-wwm-ext`` BERT tokenizer.
            ``"english"`` selects ``models/English_vatex.bin`` with the CLIP
            tokenizer.

    Returns:
        ``(model, tokenizer)``. The model is moved to CPU and put in eval mode.

    Raises:
        ValueError: for any other ``language``. ValueError subclasses
            ``Exception``, so handlers written for the previous generic
            ``Exception`` still catch it.
    """
    # Validate first so a bad value fails before CLI parsing or any file I/O.
    if language not in ("chinese", "english"):
        raise ValueError(
            "language should be 'chinese' or 'english', got {!r}".format(language)
        )

    start = time.time()
    args = get_args()
    # args is passed to BirdModel as task_config; this overrides any --language flag.
    args.language = language
    if language == "chinese":
        model_path = "models/Chinese_vatex.bin"
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
    else:
        model_path = "models/English_vatex.bin"
        tokenizer = ClipTokenizer()

    # torch.load unpickles the checkpoint, so only load trusted, local files.
    model_state_dict = torch.load(model_path, map_location='cpu')
    model = BirdModel.from_pretrained("cross-base", state_dict=model_state_dict, task_config=args)

    # Inference is pinned to CPU. Use
    # torch.device("cuda" if torch.cuda.is_available() else "cpu") for GPU.
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    print("language={}".format(language))
    print("init model time: {}".format(time.time() - start))
    print("device:{}".format(device))
    return model, tokenizer
# Both language models are loaded once, at import time; every search request reuses them.
model_zh, tokenizer_zh = init_model("chinese")
model_en, tokenizer_en = init_model("english")
def _t2v_search(text, model, tokenizer, video_fea, frame_fea):
    """Rank the precomputed video gallery against one text query.

    The score of each gallery item has two parts:
      - the video-level similarity;
      - the mean of the ``model.top_frames`` largest frame-level similarities
        (taken along dim 1).

    Args:
        text: query string.
        model: loaded model exposing ``text_encoder``, ``loose_similarity``
            and ``top_frames``.
        tokenizer: tokenizer that matches ``model``'s text encoder.
        video_fea: precomputed video-level gallery features.
        frame_fea: precomputed frame-level gallery features.

    Returns:
        List of the ``show_num`` best-scoring video paths, best first.
    """
    with torch.no_grad():
        start = time.time()
        text_ids, text_mask = get_text(text, tokenizer)
        print("get_text time: {}".format(time.time() - start))

        start = time.time()
        text_fea = model.text_encoder(text_ids, text_mask)
        print("text_encoder time: {}".format(time.time() - start))

        start = time.time()
        # Assumes sim_video is 1-D over the gallery (topk below runs on dim 0) — TODO confirm.
        sim_video = model.loose_similarity(text_fea, video_fea)
        sim_frame = model.loose_similarity(text_fea, frame_fea)
        # Average the top_frames best similarities along dim 1 (presumably frames).
        sim_frame = torch.topk(sim_frame, k=model.top_frames, dim=1)[0].mean(dim=1)
        sim = sim_video + sim_frame
        value, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
        print("calculate_similarity time: {}".format(time.time() - start))
        print("value:{}".format(value))
        print("index:{}".format(index))

        videoname, videopath = get_videoname(index)
        print("videoname:{}".format(videoname))
        print("videopath:{}".format(videopath))
    return videopath


def t2v_search_zh(text):
    """Search the Chinese-model gallery and return the top ``show_num`` video paths."""
    return _t2v_search(text, model_zh, tokenizer_zh, video_fea_zh, frame_fea_zh)


def t2v_search_en(text):
    """Search the English-model gallery and return the top ``show_num`` video paths."""
    return _t2v_search(text, model_en, tokenizer_en, video_fea_en, frame_fea_en)
def hello_world(name):
    """Return a fixed greeting that embeds ``name`` (not referenced elsewhere in this file)."""
    return "".join(("hello world, my name is ", name, "!"))
def _build_search_tab(tab_label, input_label, placeholder, button_text,
                      video_label, search_fn):
    """Add one language tab with a query box, a button and result players.

    The button runs ``search_fn`` on the query. The tab holds ``show_num``
    video players for the ranked results.

    Args:
        tab_label: tab title.
        input_label: label of the query textbox (hidden via ``show_label=False``).
        placeholder: placeholder text of the query textbox.
        button_text: caption of the search button.
        video_label: prefix for each player label; the 1-based rank is appended.
        search_fn: callable taking the query string and returning one video
            path per player, best match first.
    """
    # Up to three results get large players and the rest get small ones.
    # Clamp so the total number of players always equals show_num.
    n_top = min(3, show_num)
    # NOTE(review): Component.style() is the Gradio 3.x styling API; it was
    # removed in Gradio 4, so this file needs a 3.x release.
    with gr.Tab(tab_label):
        with gr.Column(variant="panel"):
            with gr.Row(variant="compact"):
                input_text = gr.Textbox(
                    label=input_label,
                    show_label=False,
                    max_lines=1,
                    placeholder=placeholder,
                ).style(container=False)
                btn = gr.Button(button_text).style(full_width=False)
            with gr.Column(variant="panel", scale=2):
                with gr.Row(variant="compact"):
                    videos_top = [
                        gr.Video(format="mp4", label=video_label + str(i + 1))
                        .style(height=300, width=300)
                        for i in range(n_top)
                    ]
            with gr.Column(variant="panel", scale=1):
                with gr.Row(variant="compact"):
                    videos_rest = [
                        gr.Video(format="mp4", label=video_label + str(i + 1))
                        .style(height=150, width=150)
                        for i in range(n_top, show_num)
                    ]
        # Gradio assigns the returned list to these outputs by position.
        btn.click(search_fn, inputs=input_text, outputs=videos_top + videos_rest)


def search_demo():
    """Build the bilingual (Chinese / English) text-to-video search UI and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# <div align='center'>HMMC中英文本-视频检索 "
            "<a style='font-size:18px;color: #000000' "
            "href='https://github.com/cheetah003/HMMC'> Github </a></div>"
        )
        demo.title = "HMMC中英文本-视频检索"
        _build_search_tab("中文", "输入文本", "请输入检索文本...", "搜索",
                          "视频 ", t2v_search_zh)
        _build_search_tab("English", "input text", "Please input text to search...",
                          "Search", "video ", t2v_search_en)
    # Launch only after the Blocks context has closed (Gradio's documented pattern).
    demo.launch()
if __name__ == "__main__":
    # Start the web UI. For a quick console check without the UI, call a
    # search function directly instead, e.g.
    #     t2v_search_en("two men are dancing to music, trying to do the macarena moves.")
    search_demo()