"""Gradio demo for Chinese / English text-to-video retrieval with BirdModel.

Video- and frame-level features for the test set are loaded once from ``.npy``
files at import time; each query is encoded by the language-specific text
encoder and scored against those cached features.
"""
import argparse
import time

import gradio as gr
import numpy as np
import torch
from transformers import BertTokenizer

from modules.modeling import BirdModel
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer

# Number of retrieved videos shown per query (3 large tiles + the rest small).
show_num = 9
# Fixed token length of every text query fed to the text encoder.
max_words = 32


def _load_features(path):
    """Load a ``.npy`` array from ``path`` and wrap it as a ``torch.Tensor``."""
    return torch.from_numpy(np.load(path))


video_path_zh = "features/Chinese_batch_visual_output_list.npy"
frame_path_zh = "features/Chinese_batch_frame_output_list.npy"
video_fea_zh = _load_features(video_path_zh)
frame_fea_zh = _load_features(frame_path_zh)

video_path_en = "features/English_batch_visual_output_list.npy"
frame_path_en = "features/English_batch_frame_output_list.npy"
video_fea_en = _load_features(video_path_en)
frame_fea_en = _load_features(frame_path_en)

test_path = "test_list.txt"
video_dir = "test1500/"
with open(test_path, 'r', encoding='utf8') as f_list:
    lines = f_list.readlines()
# The search code maps result row i of the similarity scores to video_ids[i],
# so this file's line order must match the row order of the feature files.
video_ids = [itm.strip() + ".mp4" for itm in lines]


def get_videoname(idx):
    """Map feature-row indices to video file names and playable paths.

    Args:
        idx: iterable of integer indices (e.g. the index tensor from ``topk``).

    Returns:
        (videoname, videopath): lists parallel to ``idx`` holding the
        ``"<id>.mp4"`` names and the same names prefixed with ``video_dir``.
    """
    videoname = [video_ids[int(i)] for i in idx]
    videopath = [video_dir + name for name in videoname]
    return videoname, videopath


def get_text(caption, tokenizer):
    """Tokenize ``caption`` into fixed-length id and mask tensors.

    The tokens are prefixed with ``<|startoftext|>``, truncated so that the
    appended ``<|endoftext|>`` still fits in ``max_words``, then zero-padded
    up to ``max_words``.

    Args:
        caption: query text.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.

    Returns:
        (pairs_text, pairs_mask): tensors of shape ``(1, max_words)``; the mask
        is 1 for real tokens and 0 for padding.
    """
    # NOTE(review): the Chinese path passes a BertTokenizer here; these
    # CLIP-style markers are probably not in its vocab -- confirm intended.
    words = ["<|startoftext|>"] + tokenizer.tokenize(caption)
    # Reserve the last slot for the end token.
    words = words[:max_words - 1] + ["<|endoftext|>"]

    input_ids = list(tokenizer.convert_tokens_to_ids(words))
    input_mask = [1] * len(input_ids)
    padding = max_words - len(input_ids)
    input_ids += [0] * padding
    input_mask += [0] * padding
    assert len(input_ids) == max_words
    assert len(input_mask) == max_words

    pairs_text = torch.from_numpy(np.array(input_ids).reshape(-1, max_words))
    pairs_mask = torch.from_numpy(np.array(input_mask).reshape(-1, max_words))
    return pairs_text, pairs_mask


def get_args(description='Retrieval Task'):
    """Parse command-line options into the ``task_config`` used by BirdModel.

    ``do_eval``, ``use_frame_fea`` and ``use_temp`` are always forced to True,
    whatever was passed on the command line.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_params", action='store_true', help="text the params of the model.")
    parser.add_argument("--use_frame_fea", action='store_true', help="whether use frame feature matching text")
    parser.add_argument('--task', type=str, default="retrieval", choices=["retrieval_VT", "retrieval"],
                        help="choose downstream task.")
    parser.add_argument('--dataset', type=str, default="bird", choices=["bird", "msrvtt", "vatex", "msvd"],
                        help="choose dataset.")
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--text_lr', type=float, default=0.00001, help='text encoder learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--weight_decay', type=float, default=0.2, help='Weight decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=32, help='')
    parser.add_argument('--max_frames', type=int, default=12, help='')
    parser.add_argument('--top_frames', type=int, default=3, help='')
    parser.add_argument('--frame_sample', type=str, default="uniform", choices=["uniform", "random", "uniform_random"],
                        help='frame sample strategy')
    parser.add_argument('--frame_sample_len', type=str, default="fix", choices=["dynamic", "fix"],
                        help='use dynamic frame length of fix frame length')
    parser.add_argument('--language', type=str, default="chinese", choices=["chinese", "english"],
                        help='language for text encoder')
    parser.add_argument('--use_temp', action='store_true', help='whether to use temporal transformer')
    parser.add_argument("--logdir", default=None, type=str, required=False, help="log dir for tensorboardX writer")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--pretrained_text", default="hfl/chinese-roberta-wwm-ext", type=str, required=False,
                        help="pretrained_text")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument('--enable_amp', action='store_true', help="whether to use pytorch amp")
    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument("--rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    args = parser.parse_args()

    # Force the flags this demo relies on, overriding any CLI value.
    args.do_eval = True
    args.use_frame_fea = True
    args.use_temp = True
    return args


def init_model(language):
    """Load the retrieval model and tokenizer for ``language``.

    Args:
        language: ``"chinese"`` or ``"english"``.

    Returns:
        (model, tokenizer): a BirdModel on CPU in eval mode, plus the matching
        tokenizer (BertTokenizer for Chinese, CLIP tokenizer for English).

    Raises:
        ValueError: if ``language`` is not one of the two supported values.
    """
    time1 = time.time()
    args = get_args()
    args.language = language
    if language == "chinese":
        model_path = "models/Chinese_vatex.bin"
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
    elif language == "english":
        model_path = "models/English_vatex.bin"
        tokenizer = ClipTokenizer()
    else:
        raise ValueError("language should be Chinese or English!")

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model_state_dict = torch.load(model_path, map_location='cpu')
    cross_model = "cross-base"
    model = BirdModel.from_pretrained(cross_model, state_dict=model_state_dict, task_config=args)
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    print("language={}".format(language))
    print("init model time: {}".format(time.time() - time1))
    print("device:{}".format(device))
    return model, tokenizer


model_zh, tokenizer_zh = init_model(language="chinese")
model_en, tokenizer_en = init_model(language="english")


def _t2v_search(text, model, tokenizer, video_fea, frame_fea):
    """Score every test video against ``text`` and return the best ``show_num`` paths.

    Each video's score is its video-level similarity plus the mean of its
    ``model.top_frames`` highest frame-level similarities. Timing and the
    retrieved results are printed for debugging.
    """
    with torch.no_grad():
        start = time.time()
        text_ids, text_mask = get_text(text, tokenizer)
        print("get_text time: {}".format(time.time() - start))

        start = time.time()
        text_fea = model.text_encoder(text_ids, text_mask)
        print("text_encoder time: {}".format(time.time() - start))

        start = time.time()
        sim_video = model.loose_similarity(text_fea, video_fea)
        sim_frame = model.loose_similarity(text_fea, frame_fea)
        # Average the top-k scores along dim 1 (presumably the frame axis).
        sim_frame = torch.topk(sim_frame, k=model.top_frames, dim=1)[0]
        sim_frame = torch.mean(sim_frame, dim=1)
        sim = sim_video + sim_frame
        value, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
        print("calculate_similarity time: {}".format(time.time() - start))
        print("value:{}".format(value))
        print("index:{}".format(index))

        videoname, videopath = get_videoname(index)
        print("videoname:{}".format(videoname))
        print("videopath:{}".format(videopath))
    return videopath


def t2v_search_zh(text):
    """Gradio handler: retrieve test videos for a Chinese query."""
    return _t2v_search(text, model_zh, tokenizer_zh, video_fea_zh, frame_fea_zh)


def t2v_search_en(text):
    """Gradio handler: retrieve test videos for an English query."""
    return _t2v_search(text, model_en, tokenizer_en, video_fea_en, frame_fea_en)


def hello_world(name):
    return "hello world, my name is " + name + "!"


def _build_search_tab(tab_label, input_label, placeholder, button_label, video_label, search_fn):
    """Add one search tab (query box, button, ``show_num`` result videos) to the open Blocks."""
    with gr.Tab(tab_label):
        # NOTE(review): the source this was recovered from had its indentation
        # collapsed; the nesting of the result columns is reconstructed --
        # confirm against the intended layout.
        with gr.Column(variant="panel"):
            with gr.Row(variant="compact"):
                input_text = gr.Textbox(
                    label=input_label,
                    show_label=False,
                    max_lines=1,
                    placeholder=placeholder,
                ).style(
                    container=False,
                )
                btn = gr.Button(button_label).style(full_width=False)

            # The first three results get large tiles, the rest small ones.
            with gr.Column(variant="panel", scale=2):
                with gr.Row(variant="compact"):
                    videos_top = [
                        gr.Video(format="mp4", label=video_label + str(i + 1)).style(height=300, width=300)
                        for i in range(3)
                    ]
            with gr.Column(variant="panel", scale=1):
                with gr.Row(variant="compact"):
                    videos_rest = [
                        gr.Video(format="mp4", label=video_label + str(i + 1)).style(height=150, width=150)
                        for i in range(3, show_num)
                    ]

            searched_videos = videos_top + videos_rest
            btn.click(search_fn, inputs=input_text, outputs=searched_videos)


def search_demo():
    """Build and launch the bilingual text-to-video retrieval web UI."""
    with gr.Blocks() as demo:
        # NOTE(review): this header looks like it lost HTML markup in transcription.
        gr.Markdown("#\nHMMC中英文本-视频检索 \\ Github\n")
        demo.title = "HMMC中英文本-视频检索"
        _build_search_tab("中文", "输入文本", "请输入检索文本...", "搜索", "视频 ", t2v_search_zh)
        _build_search_tab("English", "input text", "Please input text to search...", "Search", "video ",
                          t2v_search_en)
    demo.launch()


if __name__ == '__main__':
    search_demo()