# Spaces status banner captured with this file: "Runtime error"
import argparse
import os
import time

# import gradio
import gradio as gr
import numpy as np
import torch
from transformers import BertTokenizer

from modules.modeling import BirdModel
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
# Number of results shown per query, and the fixed token length of a query.
show_num = 9
max_words = 32


def _load_feature(path):
    """Read a pre-extracted ``.npy`` array from *path* and wrap it as a torch tensor."""
    return torch.from_numpy(np.load(path))


# Pre-extracted video-level ("visual_output") and frame-level ("frame_output")
# feature arrays, one pair per language model. Loaded in the order zh-video,
# zh-frame, en-video, en-frame.
video_path_zh = "features/Chinese_batch_visual_output_list.npy"
frame_path_zh = "features/Chinese_batch_frame_output_list.npy"
video_fea_zh = _load_feature(video_path_zh)
frame_fea_zh = _load_feature(frame_path_zh)

video_path_en = "features/English_batch_visual_output_list.npy"
frame_path_en = "features/English_batch_frame_output_list.npy"
video_fea_en = _load_feature(video_path_en)
frame_fea_en = _load_feature(frame_path_en)

# Test split: one video id per line. get_videoname indexes this list directly
# with similarity ranks, so its order is assumed to match the feature rows.
test_path = "test_list.txt"
# Alternative video directory: "test1500_400_400/"
video_dir = "test1500/"
with open(test_path, 'r', encoding='utf8') as f_list:
    lines = f_list.readlines()
video_ids = ["{}.mp4".format(entry.strip()) for entry in lines]
def get_videoname(idx, ids=None, root=None):
    """Map ranked video indices to file names and playable paths.

    Args:
        idx: iterable of indices into ``ids`` -- plain ints, or the elements
            of the index tensor returned by ``topk`` in the search functions.
        ids: list of video file names; defaults to the module-level
            ``video_ids`` read from the test list.
        root: directory containing the videos; defaults to the module-level
            ``video_dir``.

    Returns:
        (videoname, videopath): two parallel lists in the order of ``idx``.
    """
    if ids is None:
        ids = video_ids
    if root is None:
        root = video_dir
    videoname = [ids[i] for i in idx]
    # os.path.join adds a separator only when root lacks one, so a directory
    # given without a trailing slash still produces valid paths.
    videopath = [os.path.join(root, name) for name in videoname]
    return videoname, videopath
def get_text(caption, tokenizer, max_len=None):
    """Tokenize *caption* into fixed-length id and mask tensors for the text encoder.

    The token sequence is ``"<|startoftext|>"`` + caption tokens +
    ``"<|endoftext|>"``, truncated so it never exceeds *max_len* tokens, then
    right-padded with id 0.

    Args:
        caption: query text.
        tokenizer: object providing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]`` (this file passes
            either the BertTokenizer or the CLIP SimpleTokenizer).
        max_len: output sequence length; defaults to the module-level
            ``max_words``.

    Returns:
        (pairs_text, pairs_mask): int64 tensors of shape ``(1, max_len)``; the
        mask is 1 for real tokens and 0 for padding.
    """
    if max_len is None:
        max_len = max_words
    tokens = tokenizer.tokenize(caption)
    # Start token plus caption may use at most max_len - 1 slots; the last
    # slot is reserved for the end token.
    words = (["<|startoftext|>"] + tokens)[:max_len - 1]
    words.append("<|endoftext|>")
    input_ids = list(tokenizer.convert_tokens_to_ids(words))
    input_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    input_ids += [0] * padding
    input_mask += [0] * padding
    # Internal invariant: truncation above guarantees the exact length.
    assert len(input_ids) == max_len
    assert len(input_mask) == max_len
    # NumPy's default integer is platform-dependent (32-bit on Windows before
    # NumPy 2.0), so pin int64 explicitly instead of going through np.array.
    pairs_text = torch.tensor(input_ids, dtype=torch.long).view(-1, max_len)
    pairs_mask = torch.tensor(input_mask, dtype=torch.long).view(-1, max_len)
    return pairs_text, pairs_mask
def get_args(description='Retrieval Task', argv=None):
    """Build the BirdModel task configuration from command-line style arguments.

    Args:
        description: description shown by ``--help``.
        argv: list of arguments to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace holding every option below, with ``do_eval``,
        ``use_frame_fea`` and ``use_temp`` forced to True regardless of the
        command line.

    Unrecognized arguments are reported and ignored (``parse_known_args``)
    rather than aborting the process. This lets the demo still start when the
    hosting environment (e.g. a Jupyter kernel's ``-f`` flag) adds its own
    argv entries. ``-h`` still prints help and exits.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_params", action='store_true', help="text the params of the model.")
    parser.add_argument("--use_frame_fea", action='store_true', help="whether use frame feature matching text")
    parser.add_argument('--task', type=str, default="retrieval", choices=["retrieval_VT", "retrieval"],
                        help="choose downstream task.")
    parser.add_argument('--dataset', type=str, default="bird", choices=["bird", "msrvtt", "vatex", "msvd"],
                        help="choose dataset.")
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--text_lr', type=float, default=0.00001, help='text encoder learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--weight_decay', type=float, default=0.2, help='weight decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=32, help='')
    parser.add_argument('--max_frames', type=int, default=12, help='')
    parser.add_argument('--top_frames', type=int, default=3, help='')
    parser.add_argument('--frame_sample', type=str, default="uniform",
                        choices=["uniform", "random", "uniform_random"],
                        help='frame sample strategy')
    parser.add_argument('--frame_sample_len', type=str, default="fix", choices=["dynamic", "fix"],
                        help='use dynamic frame length or fixed frame length')
    parser.add_argument('--language', type=str, default="chinese", choices=["chinese", "english"],
                        help='language for text encoder')
    parser.add_argument('--use_temp', action='store_true', help='whether to use temporal transformer')
    parser.add_argument("--logdir", default=None, type=str, required=False,
                        help="log dir for tensorboardX writer")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--pretrained_text", default="hfl/chinese-roberta-wwm-ext", type=str, required=False,
                        help="pretrained_text")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    # '%%' is required: argparse %-formats help strings.
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument('--enable_amp', action='store_true', help="whether to use pytorch amp")
    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument("--rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print("get_args: ignoring unrecognized arguments: {}".format(unknown))

    # Force these flags on regardless of what the command line says.
    args.do_eval = True
    args.use_frame_fea = True
    args.use_temp = True
    return args
def init_model(language):
    """Load the retrieval model checkpoint and matching tokenizer for *language*.

    Args:
        language: ``"chinese"`` or ``"english"`` (exact, lowercase).

    Returns:
        (model, tokenizer): the BirdModel on CPU in eval mode, plus a
        BertTokenizer (``"chinese"``) or the CLIP SimpleTokenizer
        (``"english"``).

    Raises:
        ValueError: if *language* is neither ``"chinese"`` nor ``"english"``.
            (ValueError subclasses Exception, so callers catching the
            previously raised Exception still work.)
    """
    time1 = time.time()
    # Validate before parsing args or downloading anything.
    if language not in ("chinese", "english"):
        raise ValueError(
            "language should be 'chinese' or 'english', got {!r}".format(language))
    args = get_args()
    args.language = language
    if language == "chinese":
        model_path = "models/Chinese_vatex.bin"
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
    else:
        model_path = "models/English_vatex.bin"
        tokenizer = ClipTokenizer()
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    model_state_dict = torch.load(model_path, map_location='cpu')
    cross_model = "cross-base"
    model = BirdModel.from_pretrained(cross_model, state_dict=model_state_dict, task_config=args)
    device = torch.device("cpu")
    # NOTE(review): switching to CUDA would also require moving the query
    # tensors (get_text) and the pre-loaded feature tensors to the same device.
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print("language={}".format(language))
    print("init model time: {}".format(time.time() - time1))
    print("device:{}".format(device))
    return model, tokenizer
# Load both language pipelines once at import time (Chinese first, then
# English); the search functions below read these module-level names.
model_zh, tokenizer_zh = init_model("chinese")
model_en, tokenizer_en = init_model("english")
def _t2v_search(text, model, tokenizer, video_fea, frame_fea):
    """Rank the pre-extracted videos against a text query (shared by both languages).

    Each video's score is its ``video_fea`` similarity plus the mean of its
    ``model.top_frames`` highest ``frame_fea`` similarities. The ``show_num``
    highest-scoring videos are returned, best match first. Timing and ranking
    details are printed for debugging.

    Returns:
        list of ``show_num`` video file paths.
    """
    with torch.no_grad():
        time1 = time.time()
        text_ids, text_mask = get_text(text, tokenizer)
        print("get_text time: {}".format(time.time() - time1))

        time1 = time.time()
        text_fea = model.text_encoder(text_ids, text_mask)
        print("text_encoder time: {}".format(time.time() - time1))

        time1 = time.time()
        sim_video = model.loose_similarity(text_fea, video_fea)
        sim_frame = model.loose_similarity(text_fea, frame_fea)
        # Average the top_frames best scores along dim 1 (assumed to be the
        # frame axis of the frame similarity matrix -- TODO confirm).
        sim_frame = torch.topk(sim_frame, k=model.top_frames, dim=1)[0]
        sim_frame = torch.mean(sim_frame, dim=1)
        sim = sim_video + sim_frame
        value, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
        print("calculate_similarity time: {}".format(time.time() - time1))
        print("value:{}".format(value))
        print("index:{}".format(index))

        videoname, videopath = get_videoname(index)
        print("videoname:{}".format(videoname))
        print("videopath:{}".format(videopath))
    return videopath


def t2v_search_zh(text):
    """Search the test videos with a Chinese query; return ``show_num`` video paths."""
    return _t2v_search(text, model_zh, tokenizer_zh, video_fea_zh, frame_fea_zh)


def t2v_search_en(text):
    """Search the test videos with an English query; return ``show_num`` video paths."""
    return _t2v_search(text, model_en, tokenizer_en, video_fea_en, frame_fea_en)
def hello_world(name):
    """Return a greeting sentence containing *name* (which must be a str)."""
    return "".join(("hello world, my name is ", name, "!"))
def _build_search_tab(tab_label, input_label, placeholder, button_label,
                      video_label, search_fn):
    """Add one retrieval tab to the enclosing ``gr.Blocks`` context.

    The tab holds a query textbox, a search button and ``show_num`` video
    players labelled ``video_label + "1"`` .. ``video_label + str(show_num)``.
    The button is wired to *search_fn*, which must return one video path per
    player.
    """
    # NOTE(review): ``.style(...)`` is the Gradio 3.x styling API and is not
    # available in Gradio 4 -- confirm against the installed Gradio version.
    with gr.Tab(tab_label):
        with gr.Column(variant="panel"):
            with gr.Row(variant="compact"):
                input_text = gr.Textbox(
                    label=input_label,
                    show_label=False,
                    max_lines=1,
                    placeholder=placeholder,
                ).style(
                    container=False,
                )
                btn = gr.Button(button_label).style(full_width=False)
            # The first three results get large players, the rest small ones.
            with gr.Column(variant="panel", scale=2):
                with gr.Row(variant="compact"):
                    videos_top = [gr.Video(
                        format="mp4", label=video_label + str(i + 1),
                    ).style(height=300, width=300) for i in range(3)]
            with gr.Column(variant="panel", scale=1):
                with gr.Row(variant="compact"):
                    # i + 1 continues the 1-based numbering of the top row (4..show_num).
                    videos_rest = [gr.Video(
                        format="mp4", label=video_label + str(i + 1),
                    ).style(height=150, width=150) for i in range(3, show_num)]
        btn.click(search_fn, inputs=input_text, outputs=videos_top + videos_rest)


def search_demo():
    """Build and launch the bilingual (Chinese / English) text-to-video search UI."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# <div align='center'>HMMC中英文本-视频检索 "
            "<a style='font-size:18px;color: #000000' "
            "href='https://github.com/cheetah003/HMMC'> Github </a></div>")
        demo.title = "HMMC中英文本-视频检索"
        _build_search_tab("中文", "输入文本", "请输入检索文本...", "搜索",
                          "视频 ", t2v_search_zh)
        _build_search_tab("English", "input text", "Please input text to search...", "Search",
                          "video ", t2v_search_en)
    demo.launch()
if __name__ == "__main__":
    # Start the Gradio app. For a quick check without the UI, call a search
    # function directly instead, e.g. with this Chinese query ("Two men are
    # dancing to the music; they are trying to do the macarena dance moves."):
    #   text = "两个男人正在随着音乐跳舞,他们正在努力做着macarena舞蹈的动作。"
    #   t2v_search_zh(text)
    search_demo()