Spaces: Runtime error
KevinQHLin committed commit 9d0a4ae (1 parent: eab7b75)
Upload 60 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +3 -0
- app.py +236 -0
- examples/charades.mp4 +3 -0
- examples/ego4d.mp4 +3 -0
- examples/youtube.mp4 +3 -0
- main/__init__.py +0 -0
- main/_train_qfvs.py +293 -0
- main/config.py +378 -0
- main/config_hl.py +190 -0
- main/config_qfvs.json +14 -0
- main/dataset.py +1261 -0
- main/dataset_qfvs.py +284 -0
- main/inference_demo.py +81 -0
- main/inference_hl.py +229 -0
- main/inference_mr.py +273 -0
- main/inference_qfvs.py +342 -0
- main/train_hl.py +229 -0
- main/train_mr.py +266 -0
- main/train_qfvs.py +325 -0
- main/train_vlp.py +278 -0
- main/train_vlp_ddp.py +288 -0
- model/base.py +449 -0
- model/base_albef.py +478 -0
- model/base_droppath.py +449 -0
- model/base_droppath_ablation.py +474 -0
- model/base_droppath_qfvs.py +476 -0
- model/base_prompt.py +460 -0
- model/base_qfvs.py +476 -0
- model/matcher.py +107 -0
- model/moment_detr.py +462 -0
- model/position_encoding.py +126 -0
- model/transformer.py +471 -0
- model/transformer_encoder.py +159 -0
- model/transformer_encoder_droppath.py +194 -0
- model/univtg.py +450 -0
- model/univtg_ablation.py +474 -0
- model/univtg_qfvs.py +476 -0
- requirements.txt +355 -0
- results/omni/opt.json +111 -0
- run_on_video/__init__.py +1 -0
- run_on_video/clip/__init__.py +1 -0
- run_on_video/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- run_on_video/clip/clip.py +195 -0
- run_on_video/clip/model.py +432 -0
- run_on_video/clip/simple_tokenizer.py +132 -0
- run_on_video/clip_feature_extractor.py +101 -0
- run_on_video/data_utils.py +170 -0
- run_on_video/preprocessing.py +25 -0
- run_on_video/text_extractor.py +36 -0
- run_on_video/video_extractor.py +94 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/charades.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/ego4d.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/youtube.mp4 filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,236 @@
import os
import pdb
import time
import torch
import gradio as gr
import numpy as np
import argparse
import subprocess
from run_on_video import clip, vid2clip, txt2clip

parser = argparse.ArgumentParser(description='')
parser.add_argument('--save_dir', type=str, default='./tmp')
parser.add_argument('--resume', type=str, default='./results/omni/model_best.ckpt')
parser.add_argument("--gpu_id", type=int, default=2)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

#################################
model_version = "ViT-B/32"
output_feat_size = 512
clip_len = 2
overwrite = True
num_decoding_thread = 4
half_precision = False

clip_model, _ = clip.load(model_version, device=args.gpu_id, jit=False)

import logging
import torch.backends.cudnn as cudnn
from main.config import TestOptions, setup_model
from utils.basic_utils import l2_normalize_np_array

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def load_model():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse(args)
    # pdb.set_trace()
    cudnn.benchmark = True
    cudnn.deterministic = False

    if opt.lr_warmup > 0:
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    model, criterion, _, _ = setup_model(opt)
    return model

vtg_model = load_model()

def convert_to_hms(seconds):
    return time.strftime('%H:%M:%S', time.gmtime(seconds))

def load_data(save_dir):
    vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
    txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)

    vid = torch.from_numpy(l2_normalize_np_array(vid))
    txt = torch.from_numpy(l2_normalize_np_array(txt))
    clip_len = 2
    ctx_l = vid.shape[0]

    timestamp = ((torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

    if True:
        tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
        tef_ed = tef_st + 1.0 / ctx_l
        tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
        vid = torch.cat([vid, tef], dim=1)  # (Lv, Dv+2)

    src_vid = vid.unsqueeze(0).cuda()
    src_txt = txt.unsqueeze(0).cuda()
    src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
    src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()

    return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l

def forward(model, save_dir, query):
    src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
    src_vid = src_vid.cuda(args.gpu_id)
    src_txt = src_txt.cuda(args.gpu_id)
    src_vid_mask = src_vid_mask.cuda(args.gpu_id)
    src_txt_mask = src_txt_mask.cuda(args.gpu_id)

    with torch.no_grad():
        output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)

    # prepare the model prediction
    pred_logits = output['pred_logits'][0].cpu()
    pred_spans = output['pred_spans'][0].cpu()
    pred_saliency = output['saliency_scores'].cpu()

    # prepare the model prediction
    pred_windows = (pred_spans + timestamp) * ctx_l * clip_len
    pred_confidence = pred_logits

    # grounding
    top1_window = pred_windows[torch.argmax(pred_confidence)].tolist()
    top5_values, top5_indices = torch.topk(pred_confidence.flatten(), k=5)
    top5_windows = pred_windows[top5_indices].tolist()

    # print(f"The video duration is {convert_to_hms(src_vid.shape[1]*clip_len)}.")
    q_response = f"For query: {query}"

    mr_res = " - ".join([convert_to_hms(int(i)) for i in top1_window])
    mr_response = f"The Top-1 interval is: {mr_res}"

    hl_res = convert_to_hms(torch.argmax(pred_saliency) * clip_len)
    hl_response = f"The Top-1 highlight is: {hl_res}"
    return '\n'.join([q_response, mr_response, hl_response])

def extract_vid(vid_path, state):
    history = state['messages']
    vid_features = vid2clip(clip_model, vid_path, args.save_dir)
    history.append({"role": "user", "content": "Finish extracting video features."})
    history.append({"role": "system", "content": "Please Enter the text query."})
    chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history), 2)]
    return '', chat_messages, state

def extract_txt(txt):
    txt_features = txt2clip(clip_model, txt, args.save_dir)
    return

def download_video(url, save_dir='./examples', size=768):
    save_path = f'{save_dir}/{url}.mp4'
    cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}'
    if not os.path.exists(save_path):
        try:
            subprocess.call(cmd, shell=True)
        except:
            return None
    return save_path

def get_empty_state():
    return {"total_tokens": 0, "messages": []}

def submit_message(prompt, state):
    history = state['messages']

    if not prompt:
        return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state

    prompt_msg = {"role": "user", "content": prompt}

    try:
        history.append(prompt_msg)
        # answer = vlogger.chat2video(prompt)
        # answer = prompt
        extract_txt(prompt)
        answer = forward(vtg_model, args.save_dir, prompt)
        history.append({"role": "system", "content": answer})

    except Exception as e:
        history.append(prompt_msg)
        history.append({
            "role": "system",
            "content": f"Error: {e}"
        })

    chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)]
    return '', chat_messages, state


def clear_conversation():
    return gr.update(value=None, visible=True), gr.update(value=None, interactive=True), None, gr.update(value=None, visible=True), get_empty_state()


def subvid_fn(vid):
    save_path = download_video(vid)
    return gr.update(value=save_path)


css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#video_inp {min-height: 100px}
#chatbox {min-height: 100px;}
#header {text-align: center;}
#hint {font-size: 1.0em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""

with gr.Blocks(css=css) as demo:

    state = gr.State(get_empty_state())


    with gr.Column(elem_id="col-container"):
        gr.Markdown("""## 🤖️ UniVTG: Towards Unified Video-Language Temporal Grounding
                    Given a video and text query, return relevant window and highlight.""",
                    elem_id="header")

        with gr.Row():
            with gr.Column():
                video_inp = gr.Video(label="video_input")
                gr.Markdown("👋 **Step1**: Select a video in Examples (bottom) or input youtube video_id in this textbox, *e.g.* *G7zJK6lcbyU* for https://www.youtube.com/watch?v=G7zJK6lcbyU", elem_id="hint")
                with gr.Row():
                    video_id = gr.Textbox(value="", placeholder="Youtube video url", show_label=False)
                    vidsub_btn = gr.Button("(Optional) Submit Youtube id")

            with gr.Column():
                vid_ext = gr.Button("Step2: Extract video feature, may takes a while")
                # vlog_outp = gr.Textbox(label="Document output", lines=40)
                total_tokens_str = gr.Markdown(elem_id="total_tokens_str")

                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter text query and press enter", visible=True).style(container=False)
                btn_submit = gr.Button("Step3: Enter your text query")
                btn_clear_conversation = gr.Button("🔃 Clear")

        examples = gr.Examples(
            examples=[
                ["./examples/youtube.mp4"],
                ["./examples/charades.mp4"],
                ["./examples/ego4d.mp4"],
            ],
            inputs=[video_inp],
        )

        gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/anzorq/chatgpt-demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br></center>''')

    btn_submit.click(submit_message, [input_message, state], [input_message, chatbot])
    input_message.submit(submit_message, [input_message, state], [input_message, chatbot])
    # btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, vlog_outp, state])
    btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, state])
    vid_ext.click(extract_vid, [video_inp, state], [input_message, chatbot])
    vidsub_btn.click(subvid_fn, [video_id], [video_inp])

    demo.load(queue=False)


demo.queue(concurrency_count=10)
demo.launch(height='800px', server_port=2253, debug=True, share=True)
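To make the post-processing arithmetic in forward() easier to follow in isolation, here is a minimal, self-contained sketch that uses random dummy tensors in place of real UniVTG outputs. The shapes, the 2-second clip_len, and the formulas mirror the code above; the values themselves are placeholders only.

# Minimal sketch of app.py's span post-processing with dummy model outputs.
import torch

ctx_l, clip_len = 8, 2                      # 8 clips of 2 seconds each -> a 16 s video
pred_spans = torch.rand(ctx_l, 2) * 0.2     # stand-in for output['pred_spans'][0]
pred_logits = torch.rand(ctx_l, 1)          # stand-in for output['pred_logits'][0]
pred_saliency = torch.rand(ctx_l)           # stand-in for output['saliency_scores']

# normalized clip centers, repeated for (start, end), exactly as in load_data()
timestamp = ((torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

# predicted offsets + clip centers, rescaled from normalized units back to seconds
pred_windows = (pred_spans + timestamp) * ctx_l * clip_len

top1_window = pred_windows[torch.argmax(pred_logits)].tolist()   # moment retrieval, [start, end] in seconds
top1_highlight = torch.argmax(pred_saliency).item() * clip_len   # highlight position in seconds
print(top1_window, top1_highlight)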
examples/charades.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa3d1ba99bf28103844e1313cc5543b7c626d87c42a1c18108c2a69479a6d679
size 1301669
examples/ego4d.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cf1271d42415c793e659bebbd48394326cc50e970d44e6fdd0af5dfb4cb4ede4
size 28306388
examples/youtube.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1dd6b483e5346a777b5d6448460c5e30b8fe46aa1133cf6bba94c84dd7262b49
size 47353846
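The three example videos above are committed as Git LFS pointer files (version, oid, size) rather than raw media. As a small illustration of that pointer format — not part of the repository, and the path used is hypothetical — a sketch that reads the pointer fields:

# Sketch: parse the key/value fields of a Git LFS pointer file like the ones above.
def parse_lfs_pointer(path):
    fields = {}
    with open(path, "r") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# e.g. parse_lfs_pointer("examples/charades.mp4.pointer")
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:...', 'size': '1301669'}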
main/__init__.py
ADDED
File without changes
|
main/_train_qfvs.py
ADDED
@@ -0,0 +1,293 @@
import os
import pdb
import time
import json
import pprint
import random
import importlib
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import h5py
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/data/home/qinghonglin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle
from utils.model_utils import count_parameters
from eval.qfvs import calculate_semantic_matching, load_videos_tag

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def eval_epoch(model, config, opt):
    model.eval()
    f1_sum = 0; p_sum = 0; r_sum = 0

    assert len(config['test_videos']) == 1
    video_id = config['test_videos'][0]
    embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")

    feat_type = config['vid_feature']
    feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
    features = torch.tensor(feat['feature'][()]).unsqueeze(0).cuda()
    # pdb.set_trace()
    # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()

    # dim = features.shape[-1]
    # ctx_l = seg_len.sum().cpu()

    dim = features.shape[-1]
    ctx_l = features.shape[1]
    seg_len = torch.ones(ctx_l)
    features = features.reshape(-1, dim)[:ctx_l]

    tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    tef_ed = tef_st + 1.0 / ctx_l
    tef = torch.stack([tef_st, tef_ed], dim=1).cuda()  # (Lv, 2)
    features = torch.cat([features, tef], dim=1)  # (Lv, Dv+2)

    transfer = {"Cupglass": "Glass",
                "Musicalinstrument": "Instrument",
                "Petsanimal": "Animal"}

    for _, _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
        evaluation_num = len(files)
        for file in files:
            summaries_GT = []
            with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file, "r") as f:
                for line in f.readlines():
                    summaries_GT.append(int(line.strip()))

            concept1, concept2 = file.split('_')[0:2]

            ##############
            if concept1 in transfer:
                concept1 = transfer[concept1]
            if concept2 in transfer:
                concept2 = transfer[concept2]
            concept1 = embedding[concept1]
            concept2 = embedding[concept2]

            data = {
                'features': features,
                'seg_len': seg_len,
                'tokens_pad1': torch.from_numpy(concept1),
                'tokens_pad2': torch.from_numpy(concept2),
            }

            input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)

            summaries_GT = [x - 1 for x in summaries_GT]
            video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")


            output_type = 'pred_logits'  # only saliency.
            # if opt.f_loss_coef == 0:
            #     output_type = 'saliency_scores' # only saliency.
            # elif opt.s_loss_intra_coef == 0:
            #     output_type = 'pred_logits' # cls is default.
            # else:
            #     output_type = ['pred_logits', 'saliency_scores']

            # if opt.qfvs_score_multiple > 0:
            #     output_type = ['pred_logits', 'saliency_scores']

            with torch.no_grad():
                if not isinstance(output_type, list):
                    score1 = model(**input1)[output_type].squeeze()
                    # score1 = score1.masked_select(mask)
                    score2 = model(**input2)[output_type].squeeze()
                    # score2 = score2.masked_select(mask)

                    score = model(**input_oracle)[output_type].squeeze()
                    # score = score.masked_select(mask)
                else:
                    score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
                    for output_t in output_type:
                        # score1 *= model(**input1)[output_t].squeeze() #.masked_select(mask)
                        # score2 *= model(**input2)[output_t].squeeze() #.masked_select(mask)
                        # score *= model(**input_oracle)[output_t].squeeze() #.masked_select(mask)
                        score1 += model(**input1)[output_t].squeeze()  #.masked_select(mask)
                        score2 += model(**input2)[output_t].squeeze()  #.masked_select(mask)
                        score += model(**input_oracle)[output_t].squeeze()  #.masked_select(mask)

            score = score
            # score = score + score1 + score2

            # since video4 features dim is greater than video_shots_tag.
            score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
            _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
            p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
            f1_sum += f1; r_sum += r; p_sum += p

    return {'F': round(100 * f1_sum / evaluation_num, 2),
            'R': round(100 * r_sum / evaluation_num, 2),
            'P': round(100 * p_sum / evaluation_num, 2)}

def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    timer_dataloading = time.time()
    loss_total = 0

    # optimizer.zero_grad()
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_input1, model_input2, model_input_oracle, \
            model_gt1, model_gt2, model_gt_oracle, \
            mask_GT = prepare_batch_inputs_qfvs(batch, config)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        output1 = model(**model_input1)
        output2 = model(**model_input2)
        output_oracle = model(**model_input_oracle)

        loss_dict = {}
        loss_dict1 = criterion(output1, model_gt1)
        loss_dict2 = criterion(output2, model_gt2)
        loss_dict3 = criterion(output_oracle, model_gt_oracle)

        weight_dict = criterion.weight_dict
        for k in loss_dict1.keys():
            loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]

        # print(loss_dict)
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        loss_total += losses.item()

        time_meters["model_forward_time"].update(time.time() - timer_start)
        timer_start = time.time()
        # optimizer.zero_grad()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        # if ((batch_idx + 1) % opt.bsz==0) or (batch_idx == len(train_loader)-1):
        #     pdb.set_trace()
        #     optimizer.step()
        #     optimizer.zero_grad()
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        timer_dataloading = time.time()
    return round(loss_total / len(train_loader), 2)

# train in single domain.
def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    prev_best_score = {'Fscore': 0, 'Precision': 0, 'Recall': 0}
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    val_score = eval_epoch(model, config, opt)
    tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
    logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                val_score = eval_epoch(model, config, opt)
            tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
            logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                        f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                        f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")

            if prev_best_score['Fscore'] < val_score['F']:
                prev_best_score['Fscore'] = val_score['F']
                prev_best_score['Precision'] = val_score['P']
                prev_best_score['Recall'] = val_score['R']

                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
    tb_writer.close()
    return prev_best_score

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    config = load_json("./main/config_qfvs.json")

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)

    # key -> test video; value -> training videos.
    qfvs_split = {1: [2, 3, 4],
                  2: [1, 3, 4],
                  3: [1, 2, 4],
                  4: [1, 2, 3]}
    # qfvs_split = {
    #     2: [1, 3, 4],
    #     3: [1, 2, 4],
    # }

    scores_videos = {}
    for test_id, splits in qfvs_split.items():
        logger.info(f"Start Training {opt.dset_name}: {test_id}")
        config['train_videos'] = qfvs_split[test_id]
        config['test_videos'] = [test_id]
        train_dataset = DatasetQFVS(config)
        train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)

        model, criterion, optimizer, lr_scheduler = setup_model(opt)
        count_parameters(model)
        best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
        scores_videos['V'+str(test_id)] = best_score

    # save the final results.
    avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
    scores_videos['avg'] = {'Fscore': avg_fscore, 'Precision': avg_precision, 'Recall': avg_recall}

    save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
    save_json(scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)

    tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
    tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
    tb_writer.close()

    print(scores_videos)
    return

if __name__ == '__main__':
    start_training()
    results = logger.info("\n\n\nFINISHED TRAINING!!!")
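The evaluation loop above keeps the highest-scoring top_percent of shots for each query and compares them against an oracle summary. The toy, self-contained sketch below shows just that selection step plus a plain precision/recall/F1 computation; the scores and ground-truth indices are invented, and the real matching is performed by calculate_semantic_matching() in eval/qfvs.py using Tags.mat.

# Toy sketch of the top-percent shot selection and P/R/F1 scoring in eval_epoch().
import torch

score = torch.rand(200)                  # dummy per-shot scores for one test video
top_percent = 0.02                       # same default as --top_percent
_, top_index = score.topk(int(score.shape[0] * top_percent))

selected = set(top_index.tolist())
summaries_gt = {3, 17, 58, 120}          # dummy oracle summary (0-based shot ids)

tp = len(selected & summaries_gt)
precision = tp / max(len(selected), 1)
recall = tp / max(len(summaries_gt), 1)
f1 = 0.0 if tp == 0 else 2 * precision * recall / (precision + recall)
print(round(precision, 3), round(recall, 3), round(f1, 3))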
main/config.py
ADDED
@@ -0,0 +1,378 @@
import os
import pdb
import time
import torch
import logging
import argparse
import importlib
from utils.basic_utils import mkdirp, remkdirp, \
    load_json, save_json, make_zipfile, dict_to_markdown

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

class BaseOptions(object):
    saved_option_filename = "opt.json"
    ckpt_filename = "model.ckpt"
    tensorboard_log_dir = "tensorboard_log"
    train_log_filename = "train.log.txt"
    eval_log_filename = "eval.log.txt"

    def __init__(self):
        self.parser = None
        self.initialized = False
        self.opt = None

    def initialize(self):
        self.initialized = True
        parser = argparse.ArgumentParser()
        # * Running configs
        parser.add_argument("--dset_type", type=str, choices=["mr", "hl", "vs", "vlp"])  # moment retrieval, highlight detection, and video summarization
        parser.add_argument("--dset_name", type=str, choices=["qvhighlights", "charades", "anet", "tvsum", "youtube", "summe", "ego4d", "qfvs", "video2gif", "coin", "hacs", "vlp", "videocc", "tacos"])
        parser.add_argument("--domain_name", type=str, default=None)
        parser.add_argument("--model_id", type=str, default="moment_detr")
        parser.add_argument("--exp_id", type=str, default="debug", help="id of this run, required at training")
        parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
        parser.add_argument("--gpu_id", type=int, default=0)
        parser.add_argument("--debug", action="store_true",
                            help="debug (fast) mode, break all loops, do not load all data into memory.")
        parser.add_argument("--seed", type=int, default=2018, help="random seed")

        # * DDP
        parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')


        parser.add_argument("--eval_split_name", type=str, default="val",
                            help="should match keys in video_duration_idx_path, must set for VCMR")
        parser.add_argument("--data_ratio", type=float, default=1.0,
                            help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
                                 "Use small portion for debug purposes. Note this is different from --debug, "
                                 "which works by breaking the loops, typically they are not used together.")
        parser.add_argument("--results_root", type=str, default="results")
        parser.add_argument("--num_workers", type=int, default=0,
                            help="num subprocesses used to load the data, 0: use main process")
        parser.add_argument("--no_pin_memory", action="store_true",
                            help="Don't use pin_memory=True for dataloader. "
                                 "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")

        # * Training configs
        parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
        parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run")
        parser.add_argument("--max_es_cnt", type=int, default=200,
                            help="number of epochs to early stop, use -1 to disable early stop")
        parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
        parser.add_argument("--lr_drop", type=int, default=400, help="drop learning rate to 1/10 every lr_drop epochs")
        parser.add_argument("--lr_gamma", type=float, default=0.1, help="lr reduces the gamma times after the `drop' epoch")
        parser.add_argument("--lr_warmup", type=float, default=-1, help="linear warmup scheme")
        parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
        parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable")

        # ** Loss coefficients
        # *** boundary branch
        parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'],
                            help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.")
        parser.add_argument('--b_loss_coef', default=10, type=float)  # boundary regression e.g., l1
        parser.add_argument('--g_loss_coef', default=1, type=float)  # giou loss
        # *** foreground branch
        parser.add_argument('--eos_coef', default=0.1, type=float, help="relative classification weight of the no-object class")
        parser.add_argument('--f_loss_coef', default=4, type=float)  # cls loss for foreground
        # *** saliency branch
        parser.add_argument("--s_loss_intra_coef", type=float, default=1., help="inter-video (frame-level) saliency loss e.g. momentdetr saliency loss")
        parser.add_argument("--s_loss_inter_coef", type=float, default=0., help="intra-video (sample-level) saliency loss,")

        # * Eval configs
        parser.add_argument("--main_metric", type=str, default="MR-full-mAP")
        parser.add_argument('--eval_mode', default=None, type=str,
                            help="how to integrate foreground and saliency for better prediction")
        parser.add_argument("--eval_bsz", type=int, default=100,
                            help="mini-batch size at inference, for query")
        parser.add_argument("--eval_epoch", type=int, default=5,
                            help="number of epochs for once inference")
        parser.add_argument("--eval_init", action="store_true", help="evaluate model before training i.e. `epoch=-1'")
        parser.add_argument("--save_interval", type=int, default=50)

        parser.add_argument("--resume", type=str, default=None,
                            help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
        parser.add_argument("--resume_dir", type=str, default=None,
                            help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
        parser.add_argument("--resume_all", action="store_true",
                            help="if --resume_all, load optimizer/scheduler/epoch as well")
        parser.add_argument("--start_epoch", type=int, default=None,
                            help="if None, will be set automatically when using --resume_all")

        # ** NMS configs
        parser.add_argument("--no_sort_results", action="store_true",
                            help="do not sort results, use this for moment query visualization")
        parser.add_argument("--max_before_nms", type=int, default=10)
        parser.add_argument("--max_after_nms", type=int, default=10)
        parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd")
        parser.add_argument("--nms_thd", type=float, default=-1,
                            help="additionally use non-maximum suppression "
                                 "(or non-minimum suppression for distance)"
                                 "to post-processing the predictions. "
                                 "-1: do not use nms. [0, 1]")

        # * Dataset configs
        parser.add_argument("--use_cache", type=int, default=-1, help="Preload features into cache for fast IO")
        parser.add_argument("--max_q_l", type=int, default=75)
        parser.add_argument("--max_v_l", type=int, default=75)
        parser.add_argument("--clip_length", type=float, default=1.0)
        parser.add_argument("--clip_len_list", type=int, nargs='+')
        parser.add_argument("--max_windows", type=int, default=5)

        parser.add_argument("--add_easy_negative", type=int, default=1)
        parser.add_argument("--easy_negative_only", type=int, default=1)
        parser.add_argument("--round_multiple", type=int, default=1)

        parser.add_argument("--train_path", type=str, default=None, nargs='+')
        parser.add_argument("--eval_path", type=str, default=None,
                            help="Evaluating during training, for Dev set. If None, will only do training, ")
        parser.add_argument("--train_path_list", type=str, nargs='+')
        parser.add_argument("--eval_path_list", type=str, nargs='+')
        parser.add_argument("--feat_root_list", type=str, nargs='+')

        parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat")
        parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat")
        parser.add_argument("--v_feat_dirs", type=str, nargs="+",
                            help="video feature dirs. If more than one, will concat their features. "
                                 "Note that sub ctx features are also accepted here.")
        parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir")
        parser.add_argument("--v_feat_dim", type=int, help="video feature dim")
        parser.add_argument("--t_feat_dim", type=int, help="text/query feature dim")
        parser.add_argument("--ctx_mode", type=str, default="video_tef")
        parser.add_argument("--v_feat_types", type=str)
        parser.add_argument("--t_feat_type", type=str)

        # * Model configs
        parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                            help="Type of positional embedding to use on top of the image features")
        parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to vid/txt projector")
        parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss")

        # ** Transformer
        parser.add_argument('--enc_layers', default=4, type=int,
                            help="Number of encoding layers in the transformer")
        parser.add_argument('--sub_enc_layers', default=2, type=int,
                            help="Number of encoding layers in the video / text transformer in albef-style.")
        parser.add_argument('--dec_layers', default=2, type=int,
                            help="Number of decoding layers in the transformer, N/A for UniVTG")
        parser.add_argument('--dim_feedforward', default=1024, type=int,
                            help="Intermediate size of the feedforward layers in the transformer blocks")
        parser.add_argument('--hidden_dim', default=256, type=int,
                            help="Size of the embeddings (dimension of the transformer)")
        parser.add_argument('--input_dropout', default=0.5, type=float,
                            help="Dropout applied in input")
        parser.add_argument('--dropout', default=0.1, type=float,
                            help="Dropout applied in the transformer")
        parser.add_argument('--droppath', default=0.1, type=float,
                            help="Droppath applied in the transformer")
        parser.add_argument("--txt_drop_ratio", default=0, type=float,
                            help="drop txt_drop_ratio tokens from text input. 0.1=10%")
        parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.")
        parser.add_argument('--nheads', default=8, type=int,
                            help="Number of attention heads inside the transformer's attentions")
        parser.add_argument('--num_queries', default=10, type=int,
                            help="Number of query slots")
        parser.add_argument('--pre_norm', action='store_true')

        # ** momentdetr configs e.g. Matcher, saliency margin
        parser.add_argument('--set_cost_span', default=10, type=float,
                            help="L1 span coefficient in the matching cost")
        parser.add_argument('--set_cost_giou', default=1, type=float,
                            help="giou span coefficient in the matching cost")
        parser.add_argument('--set_cost_class', default=4, type=float,
                            help="Class coefficient in the matching cost")
        parser.add_argument("--saliency_margin", type=float, default=0.2)
        parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_true',
                            help="Disables auxiliary decoding losses (loss at each layer)")

        # * Query-Force Video Summarization
        parser.add_argument("--max_segment_num", type=int, default=20)
        parser.add_argument("--max_frame_num", type=int, default=200)
        parser.add_argument("--top_percent", type=float, default=0.02)

        parser.add_argument("--qfvs_vid_feature", type=str, default='fps1')
        parser.add_argument("--qfvs_txt_feature", type=str, default='query')
        parser.add_argument("--qfvs_split", type=int, default=-1)

        parser.add_argument("--qfvs_dense_shot", type=int, default=-1)
        parser.add_argument("--qfvs_score_ensemble", type=int, default=-1)
        parser.add_argument("--qfvs_score_gather", type=int, default=-1)
        parser.add_argument("--qfvs_loss_gather", type=int, default=-1)
        self.parser = parser

    def display_save(self, opt):
        args = vars(opt)
        # Display settings
        print(dict_to_markdown(vars(opt), max_str_len=120))
        # Save settings
        if not isinstance(self, TestOptions):
            option_file_path = os.path.join(opt.results_dir, self.saved_option_filename)  # not yaml file indeed
            save_json(args, option_file_path, save_pretty=True)

    def parse(self, args=None):
        if not self.initialized:
            self.initialize()
        opt = self.parser.parse_args()

        if args is not None:
            args_dict = vars(args)
            opt_dict = vars(opt)
            for key, value in args_dict.items():
                opt_dict[key] = value
            opt = argparse.Namespace(**opt_dict)
            opt.model_dir = os.path.dirname(opt.resume)
            torch.cuda.set_device(opt.gpu_id)

        if opt.debug:
            opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
            opt.num_workers = 0

        if isinstance(self, TestOptions):
            # modify model_dir to absolute path
            # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
            opt.model_dir = os.path.dirname(opt.resume)
            saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
            for arg in saved_options:  # use saved options to overwrite all BaseOptions args.
                if arg not in ["results_root", "num_workers", "nms_thd", "debug", "max_before_nms", "max_after_nms"
                               "max_pred_l", "min_pred_l", "gpu_id",
                               "resume", "resume_all", "no_sort_results",
                               "eval_path", "eval_split_name"]:
                    # "dset_name", "v_feat_dirs", "t_feat_dir"]:
                    setattr(opt, arg, saved_options[arg])
            # opt.no_core_driver = True
            if opt.eval_results_dir is not None:
                opt.results_dir = opt.eval_results_dir
        else:
            if opt.exp_id is None:
                raise ValueError("--exp_id is required for at a training option!")

            # ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode

            if 'debug' not in opt.exp_id:
                opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), "-".join([opt.exp_id, opt.v_feat_types, opt.t_feat_type, time.strftime("%Y_%m_%d_%H")]))
            else:
                opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), opt.exp_id)  # debug mode.

            if int(opt.local_rank) in [0, -1]:
                # mkdirp(opt.results_dir)
                remkdirp(opt.results_dir)  # remove dir and remkdir it.

                # save a copy of current code
                code_dir = os.path.dirname(os.path.realpath(__file__))
                code_zip_filename = os.path.join(opt.results_dir, "code.zip")
                make_zipfile(code_dir, code_zip_filename,
                             enclosing_dir="code",
                             exclude_dirs_substring="results",
                             exclude_dirs=["results", "debug_results", "__pycache__"],
                             exclude_extensions=[".pyc", ".ipynb", ".swap"], )

        if int(opt.local_rank) in [0, -1]:
            self.display_save(opt)
        opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
        opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
        opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
        opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
        # opt.device = torch.device("cuda" if opt.device >= 0 else "cpu")

        if int(opt.local_rank) in [-1]:
            torch.cuda.set_device(opt.gpu_id)
        opt.pin_memory = not opt.no_pin_memory

        if opt.local_rank == -1:
            torch.cuda.set_device(opt.gpu_id)

        opt.use_tef = "tef" in opt.ctx_mode
        opt.use_video = "video" in opt.ctx_mode
        if not opt.use_video:
            opt.v_feat_dim = 0
        if opt.use_tef:
            opt.v_feat_dim += 2

        self.opt = opt
        return opt

class TestOptions(BaseOptions):
    """add additional options for evaluating"""

    def initialize(self):
        BaseOptions.initialize(self)
        # also need to specify --eval_split_name
        self.parser.add_argument("--eval_id", type=str, help="evaluation id")
        self.parser.add_argument("--eval_results_dir", type=str, default=None,
                                 help="dir to save results, if not set, fall back to training results_dir")
        self.parser.add_argument("--model_dir", type=str,
                                 help="dir contains the model file, will be converted to absolute path afterwards")

class WarmupStepLR(torch.optim.lr_scheduler.StepLR):
    def __init__(self, optimizer, warmup_steps, step_size, gamma=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.step_size = step_size
        self.gamma = gamma
        super(WarmupStepLR, self).__init__(optimizer, step_size, gamma=self.gamma, last_epoch=last_epoch)
    def get_lr(self):
        if not self._get_lr_called_within_step:
            import warnings
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", DeprecationWarning)
        # e.g. warmup_steps = 10, case: 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21...
        if self.last_epoch == self.warmup_steps or (self.last_epoch % self.step_size != 0 and self.last_epoch > self.warmup_steps):
            return [group['lr'] for group in self.optimizer.param_groups]
        # e.g. warmup_steps = 10, case: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
        elif self.last_epoch < self.warmup_steps:
            return [group['initial_lr'] * float(self.last_epoch + 1) / float(self.warmup_steps) for group in self.optimizer.param_groups]


        # e.g. warmup_steps = 10, case: 10, 20, 30, 40...
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]
    def _get_closed_form_lr(self):
        if self.last_epoch <= self.warmup_steps:
            return [base_lr * float(self.last_epoch) / (self.warmup_steps) for base_lr in self.base_lrs]
        else:
            return [base_lr * self.gamma ** ((self.last_epoch - self.warmup_steps) // self.step_size) for base_lr in self.base_lrs]

def setup_model(opt):
    """setup model/optimizer/scheduler and load checkpoints when needed"""
    logger.info("setup model/optimizer/scheduler")

    importer = importlib.import_module('.'.join(['model', opt.model_id]))
    model, criterion = importer.build_model(opt)

    if int(opt.device) >= 0:
        logger.info("CUDA enabled.")
        model.to(opt.gpu_id)
        criterion.to(opt.gpu_id)

    param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
    optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)

    if opt.lr_warmup != -1 and opt.lr_drop > 0:
        lr_scheduler = WarmupStepLR(optimizer, warmup_steps=opt.lr_warmup[0], step_size=opt.lr_drop, gamma=opt.lr_gamma)

    elif opt.lr_warmup != -1:
        from transformers import get_constant_schedule_with_warmup
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, opt.lr_warmup[0])

    elif opt.lr_drop > 0:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop, gamma=opt.lr_gamma)

    if opt.resume is not None:
        logger.info(f"Load checkpoint from {opt.resume}")
        checkpoint = torch.load(opt.resume, map_location="cpu")

        for key in list(checkpoint["model"].keys()):
            checkpoint["model"][key.replace('module.', '')] = checkpoint["model"].pop(key)
        model.load_state_dict(checkpoint["model"])

        if opt.resume_all:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            opt.start_epoch = checkpoint['epoch'] + 1
        logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
    else:
        logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")

    return model, criterion, optimizer, lr_scheduler
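WarmupStepLR above combines a linear warmup with StepLR-style decay. The following stand-alone sketch reproduces roughly the same schedule with torch's built-in LambdaLR so it can be run without the repo; warmup_steps, step_size, and gamma play the roles of --lr_warmup, --lr_drop, and --lr_gamma, and the concrete numbers are arbitrary.

# Illustrative stand-in for the warmup-then-step schedule, using LambdaLR.
import torch

warmup_steps, step_size, gamma = 5, 20, 0.1

def lr_lambda(epoch):
    if epoch < warmup_steps:                                 # linear warmup
        return float(epoch + 1) / warmup_steps
    return gamma ** ((epoch - warmup_steps) // step_size)    # step decay afterwards

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for epoch in range(60):
    optimizer.step()      # placeholder training step
    scheduler.step()
    if epoch in (0, 4, 5, 24, 25, 44, 45):
        print(epoch, scheduler.get_last_lr())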
main/config_hl.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) THL A29 Limited, a Tencent company. All rights reserved.

YOUTUBE_SPLITS = {
    'dog': {
        'train': [
            'BsjTtq337mM', 'eGCD1F74iy8', 'x2Za-t1yHtI', 'iyYiqa0QZXM',
            'azy9ijU6f9I', 'NNtSZ6cPiwA', 'U9CBalvFfbM', 'AZDkqJaOgJU',
            '-olTgMPAyMI', 'i35F1Ec3Ats', '6bS6-GVLBeM', 'ZGszTEn28v8',
            'EEb8iSMqwj4', 'p2hYGNkRMCw', '3kbptPDIz4U', 'iLHRqR-M9HQ',
            'zyooMDuAgCA', 'dOVsQ63N0gg', '7H_qqQvPUzY', 'Z5BEFsaYIS4',
            'iWO6io44-Fs', 'vVmGisWK0QI', 'L10kN7Btk90', '2yql1mvWbDs',
            'Iu2nbtr_Uuk', 'NSmOKAauZpM', 'PAhQGoURAro', 'uJ81Us4mBOc',
            '1krGVyfIaOw', 'p9yW6FxsrJ4', 'DLGRJfpGmCQ', '0XTXKe2TOAg',
            'qpc4OSqeV7I', 'q_PJFuBOk7k', '0Uu53hCnKQ4', '-szRD9kyNug',
            'rUPxwWmJYpg', 'hseONiKKx_8', 'BLaQcOcDfjo', 'nW5JulWYEc8',
            'rMvH1SMGwwI', 'l6KlvTJkTgk', 'O8j4U3NjNvs', '8AJTZeEeStk'
        ],
        'val': [
            'a2nj7XCo2Rk', '9rP5yF9EC3Y', 'OxSsRZqPfyk', 'bZzP2MieC1c',
            'PcvdX5OVgfQ', 'p0oxRJD1GUk', 'msjK8nHZHZ0', 'hSRyclcZyGM',
            'dlH2K9N_jSM', 'OCVXhRG2fEA', 'MkBdHvXPocc', 'yN7h90Y-04g',
            'PWqLJKZeBC8', '9D_Q8l_ruQk', 'Mp8Pz86J660', '1gjntnYm8NA',
            'O3XxuutEvoo', 'wf_qlAizlSM', 'fXx44D1sqUw', 'P0MnXh6bnKk',
            'sTd06idFa0E', 'ppNjl3I3iJs', 'Om5mczkpcVg', 'xZIN_s-qhbU'
        ]
    },
    'gymnastics': {
        'train': [
            'Wfv90YJ2YtA', 'MbD5OIR9yWc', 'fZwCJWkC_Qw', 'AyRI1CioQfY',
            'xV_5YCdVqSM', '19UO7T32DJI', 'o2gAP2Clg_s', 'ewyfAOrBzjQ',
            'CMTKpA683Ig', 'aNjphhjTgqs', 'dmJ0Nq4DF2w', '57IQ6EudvGU',
            'BAlUYtPUsVI', '_UU4XqYVDqE', 'Kq4OhBiQk_E', 'D6nyvx9kEac',
            'g-m4-zeCisU', '_45vTFtcduE', '9L-Pocc_u70', '0636XaURL-A',
            'GCabQyaHSMg', 'vUi1Scb35fQ', 'eK-Yuoou_1I', 'kkS7TgNZwJI',
            '2EFkINKg3nA', 'eKvALYDh7RU', 'Hyp3Hpk6dyA', '9rpzf3sgQkw',
            'kHNAnpewyeo', 'ydQij10qrZM', '41u2V_ZAKto', '6NSWsMKAgEU',
            'kUs_yUR-C2k', 'bs3ZBcfhvKA'
        ],
        'val': [
            '2AuigNFEsTM', 'rPsKpHKzUso', 'tzq5cJQ9NQA', 'DyZ0gZ5xmxI',
            'PEKRfJYYEgU', 'affAIVH9uRA', 'FT7yIi3-tG0', 'T_zWyrVzyvw',
            'RoiLzMA_ilA', 'nBZiGSccsTg', 'z3cNtOMKK7A', 'EwQ-aMK2sKg',
            'Rq0BpciuvBM', 's6LNwTThBgs', '-hE9v3izo4c', 'KldEfRhv7H0',
            'eUyuw2J5FaE', 'E0aRE1_ea8E', 'BU7YlQAOBkM', 'iDJM9j11U-c',
            'zr5LSPMBpiI', 'NAfBa7lqg2Q', 'eB4Toq9dUWs', 'YPd7RDN5CkE',
            '86YLsw7efDM', 'iQRMMFiYAUw', 'lzEhLAPxZyQ', 'PAjJbT1DRnY'
        ]
    },
    'parkour': {
        'train': [
            'qz1UnnxlWhI', 'MzODICzycHs', '0swXWs9yWA4', 'Nnv22OW_PaI',
            'LUhZJLY2uKc', 'yZz8z1l3XJU', '3dvjtdMC2ls', 'e27ppPer9XY',
            'HJNn2WlKFhM', 'j4OxlxnapNI', 'rhABvn7VjSQ', '3PCwXpwYqLs',
            'LECL1bIpi5w', 'w0ouP79iZWc', 'z6aKQPMJUC0', 'kATlFTwxBVY',
            '3SM6a8eyuVA', 'v-Sfc4COqRQ', '64eu8pwuIUE', '7WKm0XDk3og',
            '2F5Sc0Jgk4g'
        ],
        'val': [
            'TFdbCRkVeIA', 'uGLs9atTvNc', 'qlGPuopK3CI', 'ucTkpjZO_o4',
            '4-4BgyGphLQ', '08k4ysX_XJE', '6sMNnWqa_as', 'oT6g0I2Ok9o',
            'Be4IlnKeBOo', 'yUjJq0kvxcw', 'fLek7GRIxjE'
        ]
    },
    'skating': {
        'train': [
            '7owXLUkpoNY', '1OLM0_Jzt5M', 'b1LXb0Sbiy0', '3fGux6-ttlA',
            'HQvRun80GyA', 'a8M-5nTrll8', 'bA3CxZllhsI', 'AUAsfZtcB4E',
            'FG57uCJvQLw', 'jXIuv5uFPTI', 'eG-hdYLoS98', '2SdJBl251PU',
            '2PHJqqrGC80', 'EtZkkFhniRw', 'jUiwyguxzIw', 'FL6mXlaF78Q',
            'BdemklZtYWI', 'ATk_ncI1-BA', '4wiKDfq3X8U', 'BN7GBjVlFTo',
            'JiMZvMkkbRo', '2DIXYkSnRf4', 'dZ3i-HuhQXM', '7jZydh62m8M'
        ],
        'val': [
            '2oOe2_Ew6Ao', 'DGcO0QgcXtw', 'ixsKaNplm6o', '7TQbqKWjLcI',
            'CQZNrEstSag', 'g1WbAIzkw80', '4cyx1VpDjc4', 'BGZaaqFjoRY',
            'AJ98A2y1dVw', '1n7Afe5AZCM', '8x8ESK5MnR0'
        ]
    },
    'skiing': {
        'train': [
            '6Usy87KaF-A', 'DtjKkp_4KDQ', '4Wt7TM2wDxI', 'iKnzSGFwdbc',
            'nALCc6HPQNs', 'WL4TA--CVcA', 'dFrfsgW1M98', 'x6qmrVojcYc',
            'pvcmQ9J_BYw', 'S3VEYFAP_pk', 'pU57a3jYMEk', '33TrLdo3ook',
            'xLhHU8uo2aY', 'fAHBmka6Psc', '9HYzZk5kiJA', 'T0gjqYbeU1g',
            '7o628W-bFy0', 'YKDm_PCa-HM', 'R3DV2zDnNqg', 'NCe9YeXTvHo',
            '5tXxvscmZ-Y', 'thNiPQLbi5w', '1TtJy8cSzqA', 'zDRzOsmwa08',
            'gCI4gArPjNA', 'uw0i26NHucs', '1giAsZC_ywQ', 'OvgaPTfEnqo',
            'bFD_p5znoq4', 'uKmqaAvjKgw', '5ivw_sdCTCU', 'iwCSAYGwPq4',
            'HmmOPntPlRA', 'FHCEyiM-NoY', 'EUSFMmoE_jI', 'igvSxtdsT8w',
            'zEgMYFiEaX4', '0K2FKccDp9A', 'tdyz6h4ZtYs', 'PO7GEbi2z3c',
            'mmiu7rRmSAU', 'qL6Kic-CdTo', '0fNCsOY1WGk', 'V3J26hr1ZSE',
            'GS-qBunN3B4', 'ZLNvg8025Nw', 'puAxGH6aWMY', 'h-SlvHubhs8',
            'AdovZ4OAS8I', 'UDvA1XMa1m4', 'qdo3d7mR_9s', 'qAinbyORWIw',
            'v1JpJueAElY', 'TjH29fdjcqI', 'f76B1uucoyo', 'DNPPDcOd5eQ',
            '-GX95udKKm8', 'YRO_RQ3aBgg', '1ptV2E7lm9U', 'qa7dtf1Qcew',
            '_UJTkqYNrpA', 'md14DNKq2_o', 'tpewrb9dDyo', 'yGoWYi_dHLY',
            'DZ3NRjDHwy8', 'aMFcEuJUqpk', '6fT9KLuE7no', 'lPdQMMAuOZo'
        ],
        'val': [
            'SSlv7qJK5zA', '_BYqZjuKpKA', 'ZueaKXReGjU', 'mGST8ZekCZc',
            'JJSu7Lh9rvs', 'IyoD3G5igY0', 'MXyv-Ut9HRg', 'Z8X9WIojH1U',
            'vT33-8KUb2Q', 'HW6_sPym938', '9wtXO2lF6hM', 'mRdthCqe6Nk',
            'RGxiOb9hlS0', 'ruySf5zL7Kw', 'I7wFmP6P7p0', '0AHkDElk3ws',
            'zqXd4EgUFhE', '91lDbBHUx0w', 'iaHbK6ogafc', 'jRbst8kjWW8',
            'drHPy6wSZGs', '5VaY6LgIqDs', 'bXq9rRSbI3c', 'hjZLa2DTuqs',
            'Ka2qcp3jmWo', 'ZnA4-ggkFu8', 'iXdt4v42mbs', '8aWN-0NZErI',
            '09v0HNf81J0', 'YJCR2q-WRhQ', 'RjagI4pAUpw', '_10CbYdTG5M',
            'lhgmIgzBQxs', '2pstGBM4p0w', 'b53-VPsWom4', 'x-G4r153n6o',
            'qBbqK5qlVSM', 'XamrS9XyHuQ', 'u_n7jMS1vlw', 'AO6p0jlOd6U',
            'm-W-lcTkBQ0', 'bMuyPVIlXW8', 'kAAvTAKkIy4', 'U6vnbCurZQA',
            'dHE8q7sZ70U', 'w7fzLVRPSUc', 'FLYkD7zHuHQ', 'nhOhI24P7dM',
            'n5q2KhfoiWw', '7Hcyse0h9HE', '6_BPy_VaPSY'
        ]
    },
    'surfing': {
        'train': [
            'Ai9FwQGn5ds', 'hBl0Sm3_auw', 'LMxMeg407Vg', 'D3fk8doVui4',
            'Y9pxmLg6ti8', 'p_JsivYdbgQ', 'UokX-hcXQeo', 'VYe5QfM5ecE',
            'I48VJ92ouTQ', 'Tn-ebtUnq6E', 'eWae-nWocPU', '-Yamat_0tbw',
            'c2Fy-rdXJy4', 'xQ4NAp4vWbI', 'g9kXCIjIjoE', 'A96Jx6gv6_4',
            'e427qElqqN0', 'tTcA5hiViPo', 'wMdXzj_3aA0', 'fqNzMz1n6uA',
            'jKVOA7RFCUo', 'TJBJrk9iPPA', '_C8EjMxrS2s', 'yj7abHfZTQQ',
            'NDcqgpsyWaU', 'UJjwoivaGNo', 'GZ_XS8EnnWo', 'kJUBIcBjUZ0',
            'lWoLyR7lDAU', 'FilbyF_PGjI', 'fapRkcOe4vE', 't05r50PQqww',
            'QgStLppe610', '2TY8Q2WXUyk', '9y_ED3DyNhE', 'CGwtinVGkVU',
            'nOuRhrAMaIw', 'UN4TwjDajtQ', '-FHmVZWWgcE', 'ksx0_BfpsLg',
            'agOBPDsQrTM', 'XqggBwFOmFU', 'orNzj1J8i-4', '6ZbTCHwt1gk',
            '0un3wh_pQAc', '4u6OURBLZDs', 'us0agAKuvEM', 'mVQYl7Q-TQs',
            'cB2SdlGHLMQ', 'WK5t4To0zlA', 'NNEuH_juUHI', 'KTU7xfVOat0',
            'Y1nhbNaY1ZY', 'YlXJnZe575s', 'SH7Ns0ANzJU', '3TbZfeokCkE'
        ],
        'val': [
            'o0on6yIXJQE', '4RsZz_8d8Ro', 'p8VUjcZyK70', '0P2PZXUa0Bg',
            'p2eU5z647Mw', 'mSVxaAJcNJQ', 'bcmXVyFbsRg', 'Eiq8GHi4kEo',
            'H5FEdJYokO4', 'Mkyp0z_Cgig', 'NB5Ez5kJfMU', 'Xa0y6b6Vm6U',
            'gVcCGUtpA90', '0-fstXuo_Pw', '-d72e4v9skA', 'lbp6_wCXqvw',
|
137 |
+
'9GpZHq1n8ps', 'CefGXyYu_zU', 'SI2JbS48Upg', 'hdklRTNrq0I',
|
138 |
+
'J-P-t6g19SM', 'K0f_DpVOjfA', 'lw_1fEY9QTo', 'uUuYnKLETLw',
|
139 |
+
'HwKv3Xc5MAE', 'wvQ0h5Nwsxc', 'l8ME6z_EWKE', 's9dTu2fcbNg',
|
140 |
+
'GS09SevPYT4', 'YbwdDCzVczU', 'jaCOI_VwIjc', '3Y1Jp1_fFLQ',
|
141 |
+
'82OzgxT2tH8', 'IjQhHPlTfdE', 'KzQcJrT91jU', 't05AD0c08zE',
|
142 |
+
'rGxWxX6nYO4', 'QGp0kRzKiAc', 'pK9gDWoOyko', 'Srjd4pe6vck',
|
143 |
+
'twGcxuhCXoU', 'AshLUHPEb8M', '8En3M5CUc2E', '8sTJfTUk1d0',
|
144 |
+
'o-bubyWTw60', 'NctbssxGCtU', 'L09Qo1ql0nM'
|
145 |
+
]
|
146 |
+
}
|
147 |
+
}
|
148 |
+
|
149 |
+
TVSUM_SPLITS = {
|
150 |
+
'BK': {
|
151 |
+
'train': ['WxtbjNsCQ8A', 'EE-bNr36nyA', 'oDXZc0tZe04', 'uGu_10sucQo'],
|
152 |
+
'val': ['Se3oxnaPsz0']
|
153 |
+
},
|
154 |
+
'BT': {
|
155 |
+
'train': ['eQu1rNs0an0', 'qqR6AEXwxoQ', 'EYqVtI9YWJA', 'iVt07TCkFM0'],
|
156 |
+
'val': ['JgHubY5Vw3Y']
|
157 |
+
},
|
158 |
+
'DS': {
|
159 |
+
'train': ['kLxoNp-UchI', 'NyBmCxDoHJU', 'jcoYJXDG9sw', '-esJrBWj2d8'],
|
160 |
+
'val': ['E11zDS9XGzg']
|
161 |
+
},
|
162 |
+
'FM': {
|
163 |
+
'train': ['_xMr-HKMfVA', 'byxOvuiIJV0', 'VuWGsYPqAX8', 'xmEERLqJ2kU'],
|
164 |
+
'val': ['JKpqYvAdIsw']
|
165 |
+
},
|
166 |
+
'GA': {
|
167 |
+
'train': ['xxdtq8mxegs', 'i3wAGJaaktw', '0tmA_C6XwfM', '3eYKfiOEJNs'],
|
168 |
+
'val': ['Bhxk-O1Y7Ho']
|
169 |
+
},
|
170 |
+
'MS': {
|
171 |
+
'train': ['Hl-__g2gn_A', 'WG0MBPpPC6I', 'LRw_obCPUt0', '37rzWOQsNIw'],
|
172 |
+
'val': ['Yi4Ij2NM7U4']
|
173 |
+
},
|
174 |
+
'PK': {
|
175 |
+
'train': ['GsAD1KT1xo8', 'XkqCExn6_Us', 'b626MiF1ew4', 'PJrm840pAUI'],
|
176 |
+
'val': ['cjibtmSLxQ4']
|
177 |
+
},
|
178 |
+
'PR': {
|
179 |
+
'train': ['RBCABdttQmI', 'z_6gVvQb2d0', '4wU_LUjG5Ic', '91IHQYk1IQM'],
|
180 |
+
'val': ['fWutDQy1nnY']
|
181 |
+
},
|
182 |
+
'VT': {
|
183 |
+
'train': ['gzDbaEs1Rlg', 'XzYM3PfTM4w', '98MoyGZKHXc', 'AwmHb44_ouw'],
|
184 |
+
'val': ['J0nA4VgnoCo']
|
185 |
+
},
|
186 |
+
'VU': {
|
187 |
+
'train': ['akI8YFjEmUw', 'HT5vyqe0Xaw', 'vdmoEJ5YbrQ', 'xwqBXPGE9pQ'],
|
188 |
+
'val': ['sTEELN-vY30']
|
189 |
+
}
|
190 |
+
}
|
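
For reference, these split tables are consumed by `DatasetHL` in `main/dataset.py` below, which picks the table by dataset name and then indexes `splits[domain]['train'|'val']`. A minimal usage sketch; the helper name is illustrative and not part of this commit:

from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS

def get_split_video_ids(dset_name, domain, state='train'):
    # Mirrors DatasetHL: pick the split table, then the domain, then the train/val id list.
    splits = {'tvsum': TVSUM_SPLITS, 'youtube': YOUTUBE_SPLITS}[dset_name]
    assert domain in splits, f'unknown domain: {domain}'
    return splits[domain][state]

print(get_split_video_ids('tvsum', 'BK', 'val'))        # ['Se3oxnaPsz0']
print(len(get_split_video_ids('youtube', 'surfing', 'train')))  # number of surfing training ids
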
main/config_qfvs.json
ADDED
@@ -0,0 +1,14 @@
{
    // "max_segment_num": 20,
    // "max_frame_num": 200,

    // "train_videos": null,
    // "test_videos": null,
    // "top_percent": 0.02,

    // "vid_feature": "fps1",
    // "txt_feature": "query",
    // "txt_max_len": 5,

    // "factor": null
}
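
Note that every key in `main/config_qfvs.json` above is commented out with `//`, which the standard `json` module rejects. A minimal sketch of a comment-tolerant loader, assuming the caller wants plain JSON semantics; the function name and this loading strategy are illustrative, not code from this commit:

import json
import re

def load_qfvs_config(path='main/config_qfvs.json'):
    # Strip "//" line comments so the file parses as plain JSON.
    # As committed above, this returns {} until the keys are uncommented.
    with open(path) as f:
        text = re.sub(r'^\s*//.*$', '', f.read(), flags=re.MULTILINE)
    return json.loads(text)

cfg = load_qfvs_config()
print(cfg)  # {} for the file as committed
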
main/dataset.py
ADDED
@@ -0,0 +1,1261 @@
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
import h5py
|
4 |
+
import nncore
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import numpy as np
|
8 |
+
from tqdm import tqdm
|
9 |
+
import random
|
10 |
+
import logging
|
11 |
+
from os.path import join, exists
|
12 |
+
from nncore.dataset import DATASETS
|
13 |
+
from nncore.parallel import DataContainer
|
14 |
+
from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
|
15 |
+
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
|
16 |
+
from utils.tensor_utils import pad_sequences_1d
|
17 |
+
from utils.span_utils import span_xx_to_cxw
|
18 |
+
from random import shuffle
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
class DatasetVLP(Dataset):
|
23 |
+
Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
|
24 |
+
"""One line in data loaded from data_path."
|
25 |
+
{
|
26 |
+
"qid": 7803,
|
27 |
+
"query": "Man in gray top walks from outside to inside.",
|
28 |
+
"duration": 150,
|
29 |
+
"vid": "RoripwjYFp8_360.0_510.0",
|
30 |
+
"relevant_clip_ids": [13, 14, 15, 16, 17],
|
31 |
+
"relevant_windows": [[26, 36]]
|
32 |
+
}
|
33 |
+
"""
|
34 |
+
def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
|
35 |
+
q_feat_type="last_hidden_state",
|
36 |
+
max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
|
37 |
+
normalize_v=True, normalize_t=True, load_labels=True,
|
38 |
+
clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
|
39 |
+
use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
|
40 |
+
self.dset_name = dset_name
|
41 |
+
self.data_path = data_path
|
42 |
+
self.data_ratio = data_ratio
|
43 |
+
self.v_feat_dirs = v_feat_dirs \
|
44 |
+
if isinstance(v_feat_dirs, list) else [v_feat_dirs]
|
45 |
+
self.q_feat_dir = q_feat_dir
|
46 |
+
self.q_feat_type = q_feat_type
|
47 |
+
self.v_feat_dim = v_feat_dim
|
48 |
+
self.q_feat_dim = q_feat_dim
|
49 |
+
self.max_q_l = max_q_l
|
50 |
+
self.max_v_l = max_v_l
|
51 |
+
self.ctx_mode = ctx_mode
|
52 |
+
self.use_tef = "tef" in ctx_mode
|
53 |
+
self.use_video = "video" in ctx_mode
|
54 |
+
self.normalize_t = normalize_t
|
55 |
+
self.normalize_v = normalize_v
|
56 |
+
self.load_labels = load_labels
|
57 |
+
self.clip_len = clip_len
|
58 |
+
self.fix_len = fix_len
|
59 |
+
self.max_windows = max_windows # maximum number of windows to use as labels
|
60 |
+
self.span_loss_type = span_loss_type
|
61 |
+
self.txt_drop_ratio = txt_drop_ratio
|
62 |
+
self.use_cache = use_cache
|
63 |
+
self.add_easy_negative = add_easy_negative
|
64 |
+
self.easy_negative_only = easy_negative_only
|
65 |
+
|
66 |
+
self.vlp_mapping = {
|
67 |
+
# 'data/qvhighlights/metadata/qvhighlights_asr.jsonl': {
|
68 |
+
# 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '_asr', 'type': 'interval',
|
69 |
+
# },
|
70 |
+
# 'data/ego4d/metadata/point_train_1m.jsonl': {
|
71 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
72 |
+
# },
|
73 |
+
# 'data/ego4d/metadata/point_train_1m_0.1p.jsonl': {
|
74 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
75 |
+
# },
|
76 |
+
# 'data/ego4d/metadata/point_train_1m_0.2p.jsonl': {
|
77 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
78 |
+
# },
|
79 |
+
# 'data/ego4d/metadata/point_train_1m_0.5p.jsonl': {
|
80 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
81 |
+
# },
|
82 |
+
# 'data/ego4d/metadata/point_train_1m_0.75p.jsonl': {
|
83 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
84 |
+
# },
|
85 |
+
# 'data/ego4d/metadata/point_train_2m.jsonl': {
|
86 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
87 |
+
# },
|
88 |
+
# 'data/ego4d/metadata/point_train_1m_egoclip.jsonl': {
|
89 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
90 |
+
# },
|
91 |
+
# 'data/hacs/metadata/hacs_train_cs.jsonl': {
|
92 |
+
# 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '_cs', 'type': 'curve',
|
93 |
+
# },
|
94 |
+
# 'data/hacs/metadata/hacs_train.jsonl': {
|
95 |
+
# 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
|
96 |
+
# },
|
97 |
+
# 'data/videocc/metadata/train_300k.jsonl': {
|
98 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
99 |
+
# },
|
100 |
+
# 'data/videocc/metadata/train_600k.jsonl': {
|
101 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
102 |
+
# },
|
103 |
+
# 'data/videocc/metadata/train_600k_0.1p.jsonl': {
|
104 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
105 |
+
# },
|
106 |
+
# 'data/videocc/metadata/train_600k_0.2p.jsonl': {
|
107 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
108 |
+
# },
|
109 |
+
# 'data/videocc/metadata/train_600k_0.5p.jsonl': {
|
110 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
111 |
+
# },
|
112 |
+
# 'data/videocc/metadata/train_600k_0.75p.jsonl': {
|
113 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
114 |
+
# },
|
115 |
+
# 'data/videocc/metadata/train_900k.jsonl': {
|
116 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
117 |
+
# },
|
118 |
+
# 'data/ego4d/metadata/concept_train_top10_window.jsonl': {
|
119 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
120 |
+
# },
|
121 |
+
# 'data/ego4d/metadata/concept_train_top5_window.jsonl': {
|
122 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
123 |
+
# },
|
124 |
+
# 'data/ego4d/metadata/concept_train_top5_window_0.1p.jsonl': {
|
125 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
126 |
+
# },
|
127 |
+
# 'data/ego4d/metadata/concept_train_top5_window_0.2p.jsonl': {
|
128 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
129 |
+
# },
|
130 |
+
# 'data/ego4d/metadata/concept_train_top5_window_0.5p.jsonl': {
|
131 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
132 |
+
# },
|
133 |
+
# 'data/ego4d/metadata/concept_train_top5_window_0.75p.jsonl': {
|
134 |
+
# 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
|
135 |
+
# },
|
136 |
+
# 'data/videocc/metadata/concept_train_top10_window.jsonl': {
|
137 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
138 |
+
# },
|
139 |
+
# 'data/videocc/metadata/concept_train_top5_window.jsonl': {
|
140 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
141 |
+
# },
|
142 |
+
# 'data/videocc/metadata/concept_train_top5_window_0.1p.jsonl': {
|
143 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
144 |
+
# },
|
145 |
+
# 'data/videocc/metadata/concept_train_top5_window_0.2p.jsonl': {
|
146 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
147 |
+
# },
|
148 |
+
# 'data/videocc/metadata/concept_train_top5_window_0.5p.jsonl': {
|
149 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
150 |
+
# },
|
151 |
+
# 'data/videocc/metadata/concept_train_top5_window_0.75p.jsonl': {
|
152 |
+
# 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
153 |
+
# },
|
154 |
+
#
|
155 |
+
# pre-training
|
156 |
+
'data/ego4d/metadata/point_egoclip_wo_val.jsonl': {
|
157 |
+
'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
|
158 |
+
},
|
159 |
+
'data/videocc/metadata/interval_900k.jsonl': {
|
160 |
+
'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
161 |
+
},
|
162 |
+
'data/videocc/metadata/curve_5_window.jsonl': {
|
163 |
+
'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
|
164 |
+
},
|
165 |
+
# downstream
|
166 |
+
'data/qvhighlights/metadata/qvhighlights_train.jsonl': {
|
167 |
+
'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
|
168 |
+
},
|
169 |
+
'data/charades/metadata/charades_train.jsonl': {
|
170 |
+
'dset_name': 'charades', 'v_feat_suffix': '_2', 'q_feat_suffix': '', 'type': 'interval',
|
171 |
+
},
|
172 |
+
'data/ego4d/metadata/nlq_train.jsonl': {
|
173 |
+
'dset_name': 'ego4d', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
174 |
+
},
|
175 |
+
'data/tacos/metadata/train.jsonl': {
|
176 |
+
'dset_name': 'tacos', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
177 |
+
},
|
178 |
+
'data/anet/metadata/train.jsonl': {
|
179 |
+
'dset_name': 'anet', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
180 |
+
},
|
181 |
+
'data/didemo/metadata/train.jsonl': {
|
182 |
+
'dset_name': 'didemo', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
|
183 |
+
},
|
184 |
+
}
|
185 |
+
|
186 |
+
if "val" in data_path or "test" in data_path:
|
187 |
+
assert txt_drop_ratio == 0
|
188 |
+
|
189 |
+
# checks
|
190 |
+
assert q_feat_type in self.Q_FEAT_TYPES
|
191 |
+
|
192 |
+
# data
|
193 |
+
self.data = self.load_data()
|
194 |
+
|
195 |
+
self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
|
196 |
+
t_feat_type = q_feat_dir.split('/')[-1]
|
197 |
+
|
198 |
+
if self.use_cache > 0:
|
199 |
+
print('Loading the off-line features...')
|
200 |
+
dset_dir = os.path.join('data', self.dset_name)
|
201 |
+
vid_keys = [meta['vid'] for meta in self.data]
|
202 |
+
qid_keys = [meta['qid'] for meta in self.data]
|
203 |
+
|
204 |
+
self.vid_cache = {}
|
205 |
+
for v_feat_type in self.v_feat_types:
|
206 |
+
assert 'vid' in v_feat_type
|
207 |
+
with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
|
208 |
+
self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
|
209 |
+
|
210 |
+
assert 'txt' in t_feat_type
|
211 |
+
self.txt_cache = {}
|
212 |
+
with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
|
213 |
+
for key in tqdm(qid_keys):
|
214 |
+
try:
|
215 |
+
self.txt_cache[key] = f[str(key)][:]
|
216 |
+
except:
|
217 |
+
logger.info(f"text {key} is not in the cache.")
|
218 |
+
|
219 |
+
def load_data(self):
|
220 |
+
# datalist = load_jsonl(self.data_path[0])
|
221 |
+
datalist = []
|
222 |
+
for dset_path in self.data_path:
|
223 |
+
dset_info = self.vlp_mapping[dset_path]
|
224 |
+
dset_list = load_jsonl(dset_path)
|
225 |
+
for x in dset_list: x.update(dset_info)
|
226 |
+
datalist += dset_list
|
227 |
+
n_examples = int(len(datalist))
|
228 |
+
if self.data_ratio != 1:
|
229 |
+
n_examples = int(len(datalist) * self.data_ratio)
|
230 |
+
shuffle(datalist)
|
231 |
+
datalist = datalist[:n_examples]
|
232 |
+
logger.info("Using {}% of the data: {} examples"
|
233 |
+
.format(self.data_ratio * 100, n_examples))
|
234 |
+
return datalist
|
235 |
+
|
236 |
+
def __len__(self):
|
237 |
+
return len(self.data)
|
238 |
+
|
239 |
+
def __getitem__(self, index):
|
240 |
+
meta = self.data[index]
|
241 |
+
|
242 |
+
model_inputs = dict()
|
243 |
+
model_inputs["query_feat"] = self._get_query_feat_by_qid(meta) # (Dq, ) or (Lq, Dq)
|
244 |
+
|
245 |
+
if self.use_video:
|
246 |
+
model_inputs["video_feat"] = self._get_video_feat_by_vid(meta) # (Lv, Dv)
|
247 |
+
ctx_l = len(model_inputs["video_feat"])
|
248 |
+
else:
|
249 |
+
ctx_l = self.max_v_l
|
250 |
+
|
251 |
+
if meta['dset_name'] in ['hacs', 'ego4d', 'activitynet']:
|
252 |
+
for i, window_i in enumerate(meta["relevant_windows"]):
|
253 |
+
if window_i[1] - window_i[0] < self.clip_len:
|
254 |
+
center = (window_i[1] + window_i[0]) / 2
|
255 |
+
window_i[0] = max(0, center - 0.5 * self.clip_len)
|
256 |
+
window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
|
257 |
+
window_i[1] = max(self.clip_len, window_i[1])
|
258 |
+
|
259 |
+
model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
|
260 |
+
|
261 |
+
if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
|
262 |
+
meta["relevant_windows"] = [[0, 150]]
|
263 |
+
relevant_windows = torch.Tensor(meta["relevant_windows"])
|
264 |
+
|
265 |
+
# assign the nearest window for each timestamp i.e., qvhighlights.
|
266 |
+
num_vid_seq = model_inputs["timestamp"].shape[0]
|
267 |
+
num_windows = relevant_windows.shape[0]
|
268 |
+
|
269 |
+
relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
|
270 |
+
relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
|
271 |
+
model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
|
272 |
+
|
273 |
+
if meta['qid'] is not None:
|
274 |
+
nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
|
275 |
+
diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
|
276 |
+
diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
|
277 |
+
assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
|
278 |
+
if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
|
279 |
+
nn_window_ts = relevant_windows_ts.squeeze(1)
|
280 |
+
else:
|
281 |
+
nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
|
282 |
+
|
283 |
+
model_inputs["span_labels_nn"] = nn_window_ts
|
284 |
+
model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
|
285 |
+
|
286 |
+
# for activitynet.
|
287 |
+
if model_inputs["timestamp_window"].sum() < 1:
|
288 |
+
idx = int(meta['relevant_windows'][0][0] / self.clip_len)
|
289 |
+
idx = max(0, min(idx, ctx_l-1))
|
290 |
+
model_inputs["timestamp_window"][idx] = 1
|
291 |
+
|
292 |
+
if self.use_tef:
|
293 |
+
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
|
294 |
+
tef_ed = tef_st + 1.0 / ctx_l
|
295 |
+
tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
|
296 |
+
if self.use_video:
|
297 |
+
model_inputs["video_feat"] = torch.cat(
|
298 |
+
[model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
|
299 |
+
else:
|
300 |
+
model_inputs["video_feat"] = tef
|
301 |
+
|
302 |
+
if self.load_labels:
|
303 |
+
model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
|
304 |
+
if 'saliency_scores' in meta.keys():
|
305 |
+
# this is for highlight-only task
|
306 |
+
model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
|
307 |
+
limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
|
308 |
+
model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
|
309 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
|
310 |
+
self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
|
311 |
+
# pdb.set_trace()
|
312 |
+
else:
|
313 |
+
model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
|
314 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
|
315 |
+
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
|
316 |
+
model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
|
317 |
+
|
318 |
+
if 'type' in meta.keys():
|
319 |
+
if meta['type'] == 'point':
|
320 |
+
model_inputs['weight_ablation'] = torch.tensor([0, 0, 1, 0, 0])
|
321 |
+
if meta['type'] == 'interval':
|
322 |
+
model_inputs['weight_ablation'] = torch.tensor([1, 1, 0, 0, 0])
|
323 |
+
if meta['type'] == 'curve':
|
324 |
+
model_inputs['weight_ablation'] = torch.tensor([0, 0, 0, 1, 1])
|
325 |
+
|
326 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
327 |
+
|
328 |
+
def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
|
329 |
+
gt_st = int(gt_window[0] / self.clip_len)
|
330 |
+
gt_st = min(gt_st, ctx_l-1)
|
331 |
+
gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
|
332 |
+
if gt_st > gt_ed:
|
333 |
+
# gt_st = gt_ed
|
334 |
+
gt_ed = gt_st
|
335 |
+
|
336 |
+
if gt_st != gt_ed:
|
337 |
+
pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
|
338 |
+
else:
|
339 |
+
pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
|
340 |
+
|
341 |
+
neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
|
342 |
+
# neg_clip_indices = random.sample(neg_pool, k=max_n)
|
343 |
+
|
344 |
+
try:
|
345 |
+
neg_clip_indices = random.sample(neg_pool, k=max_n)
|
346 |
+
except:
|
347 |
+
neg_clip_indices = pos_clip_indices
|
348 |
+
|
349 |
+
return pos_clip_indices, neg_clip_indices
|
350 |
+
|
351 |
+
def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
|
352 |
+
"""Sum the scores from the three annotations, then take the two clips with the
|
353 |
+
maximum scores as positive, and two with the minimum scores as negative.
|
354 |
+
Args:
|
355 |
+
rel_clip_ids: list(int), list of relevant clip ids
|
356 |
+
scores: list([anno1_score, anno2_score, anno3_score]),
|
357 |
+
ctx_l: int
|
358 |
+
max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
|
359 |
+
add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
|
360 |
+
"""
|
361 |
+
# indices inside rel_clip_ids
|
362 |
+
scores = np.array(scores) # (#rel_clips, 3)
|
363 |
+
agg_scores = np.sum(scores, 1) # (#rel_clips, )
|
364 |
+
sort_indices = np.argsort(agg_scores) # increasing
|
365 |
+
|
366 |
+
# indices in the whole video
|
367 |
+
# the min(_, ctx_l-1) here is incorrect, but should not cause
|
368 |
+
# much troubles since this should be rarely used.
|
369 |
+
hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
|
370 |
+
hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
|
371 |
+
|
372 |
+
if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
|
373 |
+
hard_neg_clip_indices = hard_pos_clip_indices
|
374 |
+
|
375 |
+
easy_pos_clip_indices = []
|
376 |
+
easy_neg_clip_indices = []
|
377 |
+
# pdb.set_trace()
|
378 |
+
if self.add_easy_negative > 0:
|
379 |
+
easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
|
380 |
+
if len(easy_neg_pool) >= max_n:
|
381 |
+
easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
|
382 |
+
easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
|
383 |
+
else: # copy the hard ones
|
384 |
+
easy_pos_clip_indices = hard_pos_clip_indices
|
385 |
+
easy_neg_clip_indices = hard_neg_clip_indices
|
386 |
+
|
387 |
+
if self.easy_negative_only > 0:
|
388 |
+
return easy_pos_clip_indices, easy_neg_clip_indices
|
389 |
+
|
390 |
+
pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
|
391 |
+
neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
|
392 |
+
|
393 |
+
return pos_clip_indices, neg_clip_indices
|
394 |
+
|
395 |
+
def get_span_labels(self, windows, ctx_l):
|
396 |
+
"""
|
397 |
+
windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
|
398 |
+
Note a maximum of `self.max_windows` windows are used.
|
399 |
+
returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
|
400 |
+
"""
|
401 |
+
if len(windows) > self.max_windows:
|
402 |
+
random.shuffle(windows)
|
403 |
+
windows = windows[:self.max_windows]
|
404 |
+
if self.span_loss_type == "l1":
|
405 |
+
windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
|
406 |
+
windows = span_xx_to_cxw(windows) # normalized windows in cxw
|
407 |
+
elif self.span_loss_type == "ce":
|
408 |
+
windows = torch.Tensor([
|
409 |
+
[int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
|
410 |
+
for w in windows]).long() # inclusive
|
411 |
+
else:
|
412 |
+
raise NotImplementedError
|
413 |
+
return windows
|
414 |
+
|
415 |
+
def _get_query_feat_by_qid(self, meta):
|
416 |
+
qid = meta['qid']
|
417 |
+
dset_name = meta['dset_name']
|
418 |
+
q_feat_suffix = meta['q_feat_suffix']
|
419 |
+
q_feat_dir = self.q_feat_dir + q_feat_suffix
|
420 |
+
|
421 |
+
if self.use_cache > 0:
|
422 |
+
try:
|
423 |
+
q_feat = self.txt_cache[qid]
|
424 |
+
except:
|
425 |
+
q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
|
426 |
+
return torch.from_numpy(q_feat)
|
427 |
+
|
428 |
+
q_feat_path = os.path.join('data', dset_name, q_feat_dir, f"{qid}.npz")
|
429 |
+
try:
|
430 |
+
q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
|
431 |
+
except:
|
432 |
+
q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
|
433 |
+
logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
|
434 |
+
|
435 |
+
if self.q_feat_type == "last_hidden_state":
|
436 |
+
# q_feat = q_feat[:self.max_q_l]
|
437 |
+
q_feat = q_feat
|
438 |
+
if self.normalize_t:
|
439 |
+
q_feat = l2_normalize_np_array(q_feat)
|
440 |
+
if self.txt_drop_ratio > 0:
|
441 |
+
q_feat = self.random_drop_rows(q_feat)
|
442 |
+
return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
|
443 |
+
|
444 |
+
def random_drop_rows(self, embeddings):
|
445 |
+
"""randomly mask num_drop rows in embeddings to be zero.
|
446 |
+
Args:
|
447 |
+
embeddings: np.ndarray (L, D)
|
448 |
+
"""
|
449 |
+
num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
|
450 |
+
if num_drop_rows > 0:
|
451 |
+
row_indices = np.random.choice(
|
452 |
+
len(embeddings), size=num_drop_rows, replace=False)
|
453 |
+
embeddings[row_indices] = 0
|
454 |
+
return embeddings
|
455 |
+
|
456 |
+
def _get_video_feat_by_vid(self, meta):
|
457 |
+
dset_name = meta['dset_name']
|
458 |
+
v_feat_suffix = meta['v_feat_suffix']
|
459 |
+
vid = meta['vid']
|
460 |
+
|
461 |
+
v_feat_list = []
|
462 |
+
for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
|
463 |
+
v_feat_dir = _feat_dir + v_feat_suffix
|
464 |
+
if self.use_cache > 0:
|
465 |
+
_feat = self.vid_cache[feat_type][vid]
|
466 |
+
else:
|
467 |
+
_feat_path = os.path.join('data', dset_name, v_feat_dir, f"{vid}.npz")
|
468 |
+
_feat = np.load(_feat_path)["features"].astype(np.float32)
|
469 |
+
if self.normalize_v:
|
470 |
+
_feat = l2_normalize_np_array(_feat)
|
471 |
+
v_feat_list.append(_feat)
|
472 |
+
# some features are slightly longer than the others
|
473 |
+
min_len = min([len(e) for e in v_feat_list])
|
474 |
+
v_feat_list = [e[:min_len] for e in v_feat_list]
|
475 |
+
v_feat = np.concatenate(v_feat_list, axis=1)
|
476 |
+
return torch.from_numpy(v_feat) # (Lv, D)
|
477 |
+
|
478 |
+
class DatasetMR(Dataset):
|
479 |
+
Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
|
480 |
+
"""One line in data loaded from data_path."
|
481 |
+
{
|
482 |
+
"qid": 7803,
|
483 |
+
"query": "Man in gray top walks from outside to inside.",
|
484 |
+
"duration": 150,
|
485 |
+
"vid": "RoripwjYFp8_360.0_510.0",
|
486 |
+
"relevant_clip_ids": [13, 14, 15, 16, 17],
|
487 |
+
"relevant_windows": [[26, 36]]
|
488 |
+
}
|
489 |
+
"""
|
490 |
+
def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
|
491 |
+
q_feat_type="last_hidden_state",
|
492 |
+
max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
|
493 |
+
normalize_v=True, normalize_t=True, load_labels=True,
|
494 |
+
clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
|
495 |
+
use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
|
496 |
+
self.dset_name = dset_name
|
497 |
+
self.data_path = data_path[0] if isinstance(data_path, list) else data_path
|
498 |
+
self.data_ratio = data_ratio
|
499 |
+
self.v_feat_dirs = v_feat_dirs \
|
500 |
+
if isinstance(v_feat_dirs, list) else [v_feat_dirs]
|
501 |
+
self.q_feat_dir = q_feat_dir
|
502 |
+
self.q_feat_type = q_feat_type
|
503 |
+
self.v_feat_dim = v_feat_dim
|
504 |
+
self.q_feat_dim = q_feat_dim
|
505 |
+
self.max_q_l = max_q_l
|
506 |
+
self.max_v_l = max_v_l
|
507 |
+
self.ctx_mode = ctx_mode
|
508 |
+
self.use_tef = "tef" in ctx_mode
|
509 |
+
self.use_video = "video" in ctx_mode
|
510 |
+
self.normalize_t = normalize_t
|
511 |
+
self.normalize_v = normalize_v
|
512 |
+
self.load_labels = load_labels
|
513 |
+
self.clip_len = clip_len
|
514 |
+
self.fix_len = fix_len
|
515 |
+
self.max_windows = max_windows # maximum number of windows to use as labels
|
516 |
+
self.span_loss_type = span_loss_type
|
517 |
+
self.txt_drop_ratio = txt_drop_ratio
|
518 |
+
self.use_cache = use_cache
|
519 |
+
self.add_easy_negative = add_easy_negative
|
520 |
+
self.easy_negative_only = easy_negative_only
|
521 |
+
|
522 |
+
if "val" in data_path or "test" in data_path:
|
523 |
+
assert txt_drop_ratio == 0
|
524 |
+
|
525 |
+
# checks
|
526 |
+
assert q_feat_type in self.Q_FEAT_TYPES
|
527 |
+
|
528 |
+
# data
|
529 |
+
self.data = self.load_data()
|
530 |
+
|
531 |
+
self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
|
532 |
+
t_feat_type = q_feat_dir.split('/')[-1]
|
533 |
+
|
534 |
+
if self.use_cache > 0:
|
535 |
+
print('Loading the off-line features...')
|
536 |
+
dset_dir = os.path.join('data', self.dset_name)
|
537 |
+
vid_keys = [meta['vid'] for meta in self.data]
|
538 |
+
qid_keys = [meta['qid'] for meta in self.data]
|
539 |
+
|
540 |
+
self.vid_cache = {}
|
541 |
+
for v_feat_type in self.v_feat_types:
|
542 |
+
assert 'vid' in v_feat_type
|
543 |
+
with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
|
544 |
+
self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
|
545 |
+
|
546 |
+
assert 'txt' in t_feat_type
|
547 |
+
self.txt_cache = {}
|
548 |
+
with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
|
549 |
+
for key in tqdm(qid_keys):
|
550 |
+
try:
|
551 |
+
self.txt_cache[key] = f[str(key)][:]
|
552 |
+
except:
|
553 |
+
logger.info(f"text {key} is not in the cache.")
|
554 |
+
|
555 |
+
def load_data(self):
|
556 |
+
datalist = load_jsonl(self.data_path)
|
557 |
+
if self.data_ratio != 1:
|
558 |
+
n_examples = int(len(datalist) * self.data_ratio)
|
559 |
+
datalist = datalist[:n_examples]
|
560 |
+
logger.info("Using {}% of the data: {} examples"
|
561 |
+
.format(self.data_ratio * 100, n_examples))
|
562 |
+
return datalist
|
563 |
+
|
564 |
+
def __len__(self):
|
565 |
+
return len(self.data)
|
566 |
+
|
567 |
+
def __getitem__(self, index):
|
568 |
+
meta = self.data[index]
|
569 |
+
|
570 |
+
model_inputs = dict()
|
571 |
+
model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq)
|
572 |
+
|
573 |
+
if self.use_video:
|
574 |
+
model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv)
|
575 |
+
ctx_l = len(model_inputs["video_feat"])
|
576 |
+
else:
|
577 |
+
ctx_l = self.max_v_l
|
578 |
+
|
579 |
+
if self.dset_name in ['hacs', 'ego4d', 'videocc', 'activitynet']:
|
580 |
+
for i, window_i in enumerate(meta["relevant_windows"]):
|
581 |
+
if window_i[1] - window_i[0] < self.clip_len:
|
582 |
+
center = (window_i[1] + window_i[0]) / 2
|
583 |
+
window_i[0] = max(0, center - 0.5 * self.clip_len)
|
584 |
+
window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
|
585 |
+
window_i[1] = max(self.clip_len, window_i[1])
|
586 |
+
|
587 |
+
model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
|
588 |
+
|
589 |
+
if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
|
590 |
+
meta["relevant_windows"] = [[0, 150]]
|
591 |
+
relevant_windows = torch.Tensor(meta["relevant_windows"])
|
592 |
+
|
593 |
+
# assign the nearest window for each timestamp i.e., qvhighlights.
|
594 |
+
num_vid_seq = model_inputs["timestamp"].shape[0]
|
595 |
+
num_windows = relevant_windows.shape[0]
|
596 |
+
|
597 |
+
relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
|
598 |
+
relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
|
599 |
+
model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
|
600 |
+
|
601 |
+
if meta['qid'] is not None:
|
602 |
+
nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
|
603 |
+
diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
|
604 |
+
diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
|
605 |
+
assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
|
606 |
+
if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
|
607 |
+
nn_window_ts = relevant_windows_ts.squeeze(1)
|
608 |
+
else:
|
609 |
+
nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
|
610 |
+
|
611 |
+
model_inputs["span_labels_nn"] = nn_window_ts
|
612 |
+
model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
|
613 |
+
|
614 |
+
# for activitynet.
|
615 |
+
if model_inputs["timestamp_window"].sum() < 1:
|
616 |
+
idx = int(meta['relevant_windows'][0][0] / self.clip_len)
|
617 |
+
idx = max(0, min(idx, ctx_l-1))
|
618 |
+
model_inputs["timestamp_window"][idx] = 1
|
619 |
+
|
620 |
+
if self.use_tef:
|
621 |
+
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
|
622 |
+
tef_ed = tef_st + 1.0 / ctx_l
|
623 |
+
tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
|
624 |
+
if self.use_video:
|
625 |
+
model_inputs["video_feat"] = torch.cat(
|
626 |
+
[model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
|
627 |
+
else:
|
628 |
+
model_inputs["video_feat"] = tef
|
629 |
+
|
630 |
+
if self.load_labels:
|
631 |
+
model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
|
632 |
+
if 'saliency_scores' in meta.keys():
|
633 |
+
model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
|
634 |
+
limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
|
635 |
+
model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
|
636 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
|
637 |
+
self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
|
638 |
+
else:
|
639 |
+
model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
|
640 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
|
641 |
+
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
|
642 |
+
model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
|
643 |
+
|
644 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
645 |
+
|
646 |
+
def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
|
647 |
+
gt_st = int(gt_window[0] / self.clip_len)
|
648 |
+
gt_st = min(gt_st, ctx_l-1)
|
649 |
+
gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
|
650 |
+
if gt_st > gt_ed:
|
651 |
+
gt_ed = gt_st
|
652 |
+
|
653 |
+
if gt_st != gt_ed:
|
654 |
+
pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
|
655 |
+
else:
|
656 |
+
pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
|
657 |
+
|
658 |
+
neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
|
659 |
+
|
660 |
+
try:
|
661 |
+
neg_clip_indices = random.sample(neg_pool, k=max_n)
|
662 |
+
except:
|
663 |
+
neg_clip_indices = pos_clip_indices
|
664 |
+
|
665 |
+
return pos_clip_indices, neg_clip_indices
|
666 |
+
|
667 |
+
def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
|
668 |
+
"""Sum the scores from the three annotations, then take the two clips with the
|
669 |
+
maximum scores as positive, and two with the minimum scores as negative.
|
670 |
+
Args:
|
671 |
+
rel_clip_ids: list(int), list of relevant clip ids
|
672 |
+
scores: list([anno1_score, anno2_score, anno3_score]),
|
673 |
+
ctx_l: int
|
674 |
+
max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
|
675 |
+
add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
|
676 |
+
"""
|
677 |
+
# indices inside rel_clip_ids
|
678 |
+
scores = np.array(scores) # (#rel_clips, 3)
|
679 |
+
agg_scores = np.sum(scores, 1) # (#rel_clips, )
|
680 |
+
sort_indices = np.argsort(agg_scores) # increasing
|
681 |
+
|
682 |
+
# indices in the whole video
|
683 |
+
# the min(_, ctx_l-1) here is incorrect, but should not cause
|
684 |
+
# much troubles since this should be rarely used.
|
685 |
+
hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
|
686 |
+
hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
|
687 |
+
|
688 |
+
if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
|
689 |
+
hard_neg_clip_indices = hard_pos_clip_indices
|
690 |
+
|
691 |
+
easy_pos_clip_indices = []
|
692 |
+
easy_neg_clip_indices = []
|
693 |
+
|
694 |
+
if self.add_easy_negative > 0:
|
695 |
+
easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
|
696 |
+
if len(easy_neg_pool) >= max_n:
|
697 |
+
easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
|
698 |
+
easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
|
699 |
+
else: # copy the hard ones
|
700 |
+
easy_pos_clip_indices = hard_pos_clip_indices
|
701 |
+
easy_neg_clip_indices = hard_neg_clip_indices
|
702 |
+
|
703 |
+
if self.easy_negative_only > 0:
|
704 |
+
return easy_pos_clip_indices, easy_neg_clip_indices
|
705 |
+
|
706 |
+
pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
|
707 |
+
neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
|
708 |
+
return pos_clip_indices, neg_clip_indices
|
709 |
+
|
710 |
+
def get_span_labels(self, windows, ctx_l):
|
711 |
+
"""
|
712 |
+
windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
|
713 |
+
Note a maximum of `self.max_windows` windows are used.
|
714 |
+
returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
|
715 |
+
"""
|
716 |
+
if len(windows) > self.max_windows:
|
717 |
+
random.shuffle(windows)
|
718 |
+
windows = windows[:self.max_windows]
|
719 |
+
if self.span_loss_type == "l1":
|
720 |
+
windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
|
721 |
+
windows = span_xx_to_cxw(windows) # normalized windows in cxw
|
722 |
+
elif self.span_loss_type == "ce":
|
723 |
+
windows = torch.Tensor([
|
724 |
+
[int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
|
725 |
+
for w in windows]).long() # inclusive
|
726 |
+
else:
|
727 |
+
raise NotImplementedError
|
728 |
+
return windows
|
729 |
+
|
730 |
+
def _get_query_feat_by_qid(self, qid):
|
731 |
+
if self.use_cache > 0:
|
732 |
+
try:
|
733 |
+
q_feat = self.txt_cache[qid]
|
734 |
+
except:
|
735 |
+
q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
|
736 |
+
return torch.from_numpy(q_feat)
|
737 |
+
|
738 |
+
q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
|
739 |
+
try:
|
740 |
+
q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
|
741 |
+
except:
|
742 |
+
q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
|
743 |
+
logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
|
744 |
+
|
745 |
+
if self.q_feat_type == "last_hidden_state":
|
746 |
+
# q_feat = q_feat[:self.max_q_l]
|
747 |
+
q_feat = q_feat
|
748 |
+
if self.normalize_t:
|
749 |
+
q_feat = l2_normalize_np_array(q_feat)
|
750 |
+
if self.txt_drop_ratio > 0:
|
751 |
+
q_feat = self.random_drop_rows(q_feat)
|
752 |
+
return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
|
753 |
+
|
754 |
+
def random_drop_rows(self, embeddings):
|
755 |
+
"""randomly mask num_drop rows in embeddings to be zero.
|
756 |
+
Args:
|
757 |
+
embeddings: np.ndarray (L, D)
|
758 |
+
"""
|
759 |
+
num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
|
760 |
+
if num_drop_rows > 0:
|
761 |
+
row_indices = np.random.choice(
|
762 |
+
len(embeddings), size=num_drop_rows, replace=False)
|
763 |
+
embeddings[row_indices] = 0
|
764 |
+
return embeddings
|
765 |
+
|
766 |
+
def _get_video_feat_by_vid(self, vid):
|
767 |
+
v_feat_list = []
|
768 |
+
for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
|
769 |
+
if self.use_cache > 0:
|
770 |
+
_feat = self.vid_cache[feat_type][vid]
|
771 |
+
else:
|
772 |
+
_feat_path = join(_feat_dir, f"{vid}.npz")
|
773 |
+
_feat = np.load(_feat_path)["features"].astype(np.float32)
|
774 |
+
# _feat = np.load(_feat_path)["features"][:self.max_v_l].astype(np.float32)
|
775 |
+
if self.normalize_v:
|
776 |
+
_feat = l2_normalize_np_array(_feat)
|
777 |
+
v_feat_list.append(_feat)
|
778 |
+
# some features are slightly longer than the others
|
779 |
+
min_len = min([len(e) for e in v_feat_list])
|
780 |
+
v_feat_list = [e[:min_len] for e in v_feat_list]
|
781 |
+
v_feat = np.concatenate(v_feat_list, axis=1)
|
782 |
+
return torch.from_numpy(v_feat) # (Lv, D)
|
783 |
+
|
784 |
+
class DatasetHL(Dataset):
|
785 |
+
def __init__(self,
|
786 |
+
dset_name,
|
787 |
+
domain,
|
788 |
+
data_path,
|
789 |
+
v_feat_types,
|
790 |
+
v_feat_dirs,
|
791 |
+
t_feat_dir,
|
792 |
+
use_tef=False
|
793 |
+
):
|
794 |
+
assert dset_name in ['tvsum', 'youtube']
|
795 |
+
self.dset_name = dset_name
|
796 |
+
dset_domain = {'tvsum': TVSUM_SPLITS,
|
797 |
+
'youtube': YOUTUBE_SPLITS}
|
798 |
+
self.splits = dset_domain[dset_name]
|
799 |
+
assert domain in self.splits.keys()
|
800 |
+
|
801 |
+
self.domain = domain
|
802 |
+
assert len(data_path) == 1
|
803 |
+
self.data_path = data_path[0] if isinstance(data_path, list) else data_path
|
804 |
+
self.v_feat_types = v_feat_types.split('_')
|
805 |
+
self.v_feat_dirs = v_feat_dirs
|
806 |
+
self.q_feat_type = "last_hidden_state"
|
807 |
+
self.q_feat_dir = t_feat_dir
|
808 |
+
|
809 |
+
self.txt_drop_ratio = 0
|
810 |
+
self.normalize_t = True
|
811 |
+
self.normalize_v = True
|
812 |
+
|
813 |
+
self.label = nncore.load(self.data_path)
|
814 |
+
self.use_tef = use_tef
|
815 |
+
|
816 |
+
self.video_id = {
|
817 |
+
k: [s for s in self.splits[domain][k] if s in self.label]
|
818 |
+
for k in ('train', 'val')
|
819 |
+
}
|
820 |
+
self.set_state('train')
|
821 |
+
|
822 |
+
def __len__(self):
|
823 |
+
return len(self.video_id[self.state])
|
824 |
+
|
825 |
+
def __getitem__(self, idx):
|
826 |
+
vid = self.get_video_id(idx)
|
827 |
+
video = self._get_video_feat_by_vid(vid)
|
828 |
+
saliency = self.get_saliency(idx)
|
829 |
+
|
830 |
+
if self.dset_name == 'youtube':
|
831 |
+
saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
|
832 |
+
elif self.dset_name == 'tvsum':
|
833 |
+
saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
|
834 |
+
# saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency != min(saliency))[0].tolist())])
|
835 |
+
else:
|
836 |
+
raise NotImplementedError
|
837 |
+
|
838 |
+
num_clips = min(c.size(0) for c in (video, saliency))
|
839 |
+
|
840 |
+
video = video[:num_clips]
|
841 |
+
saliency = saliency[:num_clips]
|
842 |
+
|
843 |
+
if self.use_tef:
|
844 |
+
ctx_l = video.shape[0]
|
845 |
+
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
|
846 |
+
tef_ed = tef_st + 1.0 / ctx_l
|
847 |
+
tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
|
848 |
+
video = torch.cat([video, tef], dim=1) # (Lv, Dv+2)
|
849 |
+
|
850 |
+
data = dict(
|
851 |
+
video=DataContainer(video),
|
852 |
+
saliency=DataContainer(saliency, pad_value=-1),
|
853 |
+
saliency_pos_labels=saliency_pos_labels)
|
854 |
+
|
855 |
+
if self.q_feat_dir is not None:
|
856 |
+
query = self._get_query_feat_by_qid(vid)
|
857 |
+
data['query'] = DataContainer(query, pad_value=float('inf'))
|
858 |
+
return data
|
859 |
+
|
860 |
+
def set_state(self, state):
|
861 |
+
self.state = 'train' if state == 'train' else 'val'
|
862 |
+
|
863 |
+
def get_video_id(self, idx):
|
864 |
+
return self.video_id[self.state][idx]
|
865 |
+
|
866 |
+
def get_video(self, idx):
|
867 |
+
video_id = self.get_video_id(idx)
|
868 |
+
video = torch.from_numpy(self.video[video_id]).float()
|
869 |
+
optic = torch.from_numpy(self.optic[video_id]).float()
|
870 |
+
return torch.cat((video, optic), dim=1)
|
871 |
+
|
872 |
+
def _get_video_feat_by_vid(self, vid):
|
873 |
+
v_feat_list = []
|
874 |
+
for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
|
875 |
+
# if self.use_cache > 0:
|
876 |
+
# _feat = self.vid_cache[feat_type][vid]
|
877 |
+
# else:
|
878 |
+
if True:
|
879 |
+
_feat_path = join(_feat_dir, f"{vid}.npz")
|
880 |
+
_feat = np.load(_feat_path)["features"].astype(np.float32)
|
881 |
+
if self.normalize_v:
|
882 |
+
_feat = l2_normalize_np_array(_feat)
|
883 |
+
v_feat_list.append(_feat)
|
884 |
+
# some features are slightly longer than the others
|
885 |
+
min_len = min([len(e) for e in v_feat_list])
|
886 |
+
v_feat_list = [e[:min_len] for e in v_feat_list]
|
887 |
+
v_feat = np.concatenate(v_feat_list, axis=1)
|
888 |
+
return torch.from_numpy(v_feat) # (Lv, D)
|
889 |
+
|
890 |
+
def _get_query_feat_by_qid(self, qid):
|
891 |
+
# if self.use_cache > 0:
|
892 |
+
# try:
|
893 |
+
# q_feat = self.txt_cache[qid]
|
894 |
+
# except:
|
895 |
+
# q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
|
896 |
+
# return torch.from_numpy(q_feat)
|
897 |
+
|
898 |
+
        q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
        try:
            q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
        except:
            q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            logger.info(f"Something wrong when loading the query feature {q_feat_path}.")

        if self.q_feat_type == "last_hidden_state":
            # q_feat = q_feat[:self.max_q_l]
            q_feat = q_feat
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def get_saliency(self, idx):
        if self.dset_name == 'tvsum':
            video_id = self.get_video_id(idx)
            saliency = torch.Tensor(self.label[video_id]['anno'])

            # top-5 saliency scores as a threshold.
            # saliency_tmp = saliency.mean(1)
            # topk = int(saliency_tmp.shape[0] * 0.1)
            # th = saliency_tmp[torch.sort(saliency_tmp)[1][-topk]] # v4
            # saliency = saliency_tmp - th

            # saliency_tmp = saliency.mean(1) # med
            # th = saliency_tmp.median()
            # saliency = saliency_tmp - th

            saliency = (saliency - saliency.mean()).mean(dim=1)
            # saliency = (saliency.sum(dim=1) - 20) / 80    # v2

        elif self.dset_name == 'youtube':
            video_id = self.get_video_id(idx)
            saliency = [1 if s > 0 else 0 for s in self.label[video_id]['match']]
        else:
            raise NotImplementedError
        return torch.Tensor(saliency)

    def evaluate(self, blob, k=5, save_dir=None, **kwargs):
        # blob = nncore.to_dict_of_list(blob)
        collected = []

        if save_dir is not None:
            import json
            with open(os.path.join(save_dir, self.dset_name, self.domain + '.jsonl'), 'w') as f:
                for idx, score in enumerate(blob):
                    video_id = self.get_video_id(idx)
                    entry = {'vid': video_id, 'pred': score[0].tolist(), 'gt': self.get_saliency(idx).tolist(),
                             'duration': int(self.label[video_id]['frames']) / int(self.label[video_id]['fps']),
                             'domain': self.label[video_id]['domain'], 'fps': self.label[video_id]['fps']}
                    if self.dset_name == 'tvsum':
                        entry.update({'title': self.label[video_id]['title']})
                    if self.dset_name == 'youtube':
                        entry.update({'clip': self.label[video_id]['clip']})
                    f.write(json.dumps(entry) + '\n')

        if self.dset_name == 'tvsum':
            for i in range(20):
                video_ap = []
                for idx, score in enumerate(blob):
                    inds = torch.argsort(score[0], descending=True)
                    video_id = self.get_video_id(idx)
                    label = torch.Tensor(self.label[video_id]['anno'])[:, i]
                    label = torch.where(label > label.median(), 1.0, .0)
                    label = label[inds].tolist()[:k]

                    if (num_gt := sum(label)) == 0:
                        video_ap.append(0)
                        continue

                    hits = ap = rec = 0
                    prc = 1

                    for j, gt in enumerate(label):
                        hits += gt
                        _rec = hits / num_gt
                        _prc = hits / (j + 1)
                        ap += (_rec - rec) * (prc + _prc) / 2
                        rec, prc = _rec, _prc
                    video_ap.append(ap)
                collected.append(sum(video_ap) / len(video_ap))

        elif self.dset_name == 'youtube':
            for idx, score in enumerate(blob):
                inds = torch.argsort(score[0], descending=True)
                label = self.get_saliency(idx)[inds].tolist()

                if (num_gt := sum(label)) == 0:
                    collected.append(0)
                    continue

                hits = ap = rec = 0
                prc = 1

                for i, gt in enumerate(label):
                    hits += gt
                    _rec = hits / num_gt
                    _prc = hits / (i + 1)
                    ap += (_rec - rec) * (prc + _prc) / 2
                    rec, prc = _rec, _prc
                collected.append(ap)
        else:
            raise NotImplementedError

        mean_ap = sum(collected) / len(collected)
        results = dict(mAP=round(mean_ap, 5))
        return results

class DatasetQFVS(Dataset):
    def __init__(self, config, use_tef=True):
        # pdb.set_trace()
        self.config = config
        self.dataset = []
        self.use_tef = use_tef

        self.embedding = load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")

        for video_id in self.config["train_videos"]:
            for _, _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id)):
                for file in files:
                    self.dataset.append(file[:file.find("_oracle.txt")] + "_" + str(video_id))

    def __getitem__(self, index):
        video_id = self.dataset[index].split('_')[2]
        feat_type = self.config['vid_feature']
        # pdb.set_trace()
        feat_type = self.config['vid_feature']
        f = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
        features = f['feature'][()]
        # dim = features.shape[-1]
        # features = features.reshape(-1, dim)
        # seg_len = f['seg_len'][()]
        dim = features.shape[-1]
        ctx_l = features.shape[0]
        seg_len = np.ones(ctx_l)

        # mask = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        # for j in range(len(seg_len)):
        #     for k in range(seg_len[j]):
        #         mask[j][k] = 1

        # ctx_l = seg_len.sum()
        features = torch.from_numpy(features)
        # features = features[mask, :]

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            features = torch.cat([features, tef], dim=1)  # (Lv, Dv+2)

        transfer = {"Cupglass": "Glass",
                    "Musicalinstrument": "Instrument",
                    "Petsanimal": "Animal"}

        concept1, concept2 = self.dataset[index].split('_')[0:2]

        concept1_GT = torch.zeros(ctx_l)
        concept2_GT = torch.zeros(ctx_l)
        with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0" + video_id + "/P0" + video_id + ".txt", "r") as f:
            lines = f.readlines()
            for index, line in enumerate(lines):
                concepts = line.strip().split(',')
                if concept1 in concepts:
                    concept1_GT[index] = 1
                if concept2 in concepts:
                    concept2_GT[index] = 1

        # shot_num = seg_len.sum()
        # mask_GT = torch.zeros(ctx_l)
        # for i in range(shot_num):
        #     mask_GT[i] = 1
        mask_GT = torch.ones(ctx_l)

        oracle_summary = torch.zeros(ctx_l)
        GT_summary_shots = []
        with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id) + "/" + str(concept1) + "_" + str(concept2) + "_" + "oracle.txt", "r") as f:
            for line in f.readlines():
                GT_summary_shots.append(int(line.strip()))
        GT_summary_shots = [x - 1 for x in GT_summary_shots]
        for element in GT_summary_shots:
            oracle_summary[element] = 1

        if concept1 in transfer:
            concept1 = transfer[concept1]
        if concept2 in transfer:
            concept2 = transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def __len__(self):
        return len(self.dataset)

def start_end_collate_mr(batch):
    batch_meta = [e["meta"] for e in batch]  # seems no need to collate ?

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        if k == "span_labels":
            batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
            continue
        if k in ["saliency_pos_labels", "saliency_neg_labels"]:
            batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
            continue

        batched_data[k] = pad_sequences_1d(
            [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
    return batch_meta, batched_data

def start_end_collate_hl(batch):
    model_inputs_keys = batch[0].keys()

    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
    return batched_data

def start_end_collate_qfvs(batch):
    model_inputs_keys = batch[0].keys()

    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)

    return batched_data

def prepare_batch_inputs_mr(batched_model_inputs, device, non_blocking=False):
    model_inputs = dict(
        src_txt=batched_model_inputs["query_feat"][0].to(device, non_blocking=non_blocking),
        src_txt_mask=batched_model_inputs["query_feat"][1].to(device, non_blocking=non_blocking),
        src_vid=batched_model_inputs["video_feat"][0].to(device, non_blocking=non_blocking),
        src_vid_mask=batched_model_inputs["video_feat"][1].to(device, non_blocking=non_blocking),
    )
    targets = {}
    targets['timestamp'] = batched_model_inputs["timestamp"][0].to(device, non_blocking=non_blocking)
    targets['timestamp_mask'] = batched_model_inputs["timestamp"][1].to(device, non_blocking=non_blocking)
    targets['timestamp_window'] = batched_model_inputs["timestamp_window"][0].to(device, non_blocking=non_blocking)
    targets['span_labels_nn'] = batched_model_inputs["span_labels_nn"][0].to(device, non_blocking=non_blocking)

    if 'saliency_scores' in batched_model_inputs.keys():
        targets['saliency_scores'] = batched_model_inputs["saliency_scores"][0].to(device, non_blocking=non_blocking)

    if "span_labels" in batched_model_inputs:
        targets["span_labels"] = [
            dict(spans=e["spans"].to(device, non_blocking=non_blocking))
            for e in batched_model_inputs["span_labels"]
        ]
    if "saliency_pos_labels" in batched_model_inputs:
        for name in ["saliency_pos_labels", "saliency_neg_labels"]:
            targets[name] = batched_model_inputs[name].to(device, non_blocking=non_blocking)

    if "weight_ablation" in batched_model_inputs:
        targets["weight_ablation"] = batched_model_inputs["weight_ablation"][0].to(device, non_blocking=non_blocking)

    targets = None if len(targets) == 0 else targets
    return model_inputs, targets

def prepare_batch_inputs_hl(batched_model_inputs, device='cuda', non_blocking=False):
    src_vid = batched_model_inputs['video'][0].to(device, non_blocking=non_blocking)
    src_vid_mask = batched_model_inputs['video'][1].bool().to(device, non_blocking=non_blocking)
    src_txt = batched_model_inputs['query'][0].to(device, non_blocking=non_blocking) \
        if 'query' in batched_model_inputs.keys() else None
    src_txt_mask = batched_model_inputs['query'][1].bool().to(device, non_blocking=non_blocking) \
        if 'query' in batched_model_inputs.keys() else None

    model_inputs = dict(
        src_vid=src_vid, src_vid_mask=src_vid_mask,
        src_txt=src_txt, src_txt_mask=src_txt_mask)

    # if 'audio' in batched_model_inputs.keys():
    #     src_aud = batched_model_inputs['audio'][0].bool().to(device, non_blocking=non_blocking)
    #     src_aud_mask = batched_model_inputs['audio'][1].bool().to(device, non_blocking=non_blocking)
    #     model_inputs['src_aud'] = src_aud; model_inputs['src_aud_mask'] = src_aud_mask;

    targets = {}
    saliency = batched_model_inputs['saliency'][0].to(device, non_blocking=non_blocking)
    saliency_pos_labels = batched_model_inputs['saliency_pos_labels'][0].to(device, non_blocking=non_blocking)

    targets['saliency_scores'] = saliency
    targets['saliency_pos_labels'] = saliency_pos_labels.long()
    targets['timestamp_mask'] = batched_model_inputs["video"][1].to(device, non_blocking=non_blocking)
    targets['timestamp_window'] = 1 * (saliency > 0)

    return model_inputs, targets

def prepare_batch_inputs_qfvs(data, config, eval=False):
    if not eval:
        features, mask, seg_len, \
        concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2, \
        saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
            data['features'][0], data['features'][1], data['seg_len'][0], \
            data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
            data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
    else:
        features, mask, seg_len, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
            data['features'][0], data['features'][1], data['seg_len'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]

    # preprocess for vid input.
    seq = features.to('cuda')
    mask = mask.to('cuda')

    # for txt input.
    src_txt_1 = src_txt_1.to(torch.float32).to('cuda')
    src_txt_2 = src_txt_2.to(torch.float32).to('cuda')
    src_txt_mask_1 = src_txt_mask_1.to('cuda')
    src_txt_mask_2 = src_txt_mask_2.to('cuda')

    src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
    src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')

    model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
    model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
    model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)

    if not eval:
        targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
        targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
        targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))

        targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
        targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
        targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')

        return model_inputs_1, model_inputs_2, model_inputs_oracle, \
               targets_1, targets_2, targets_oracle, mask_GT
    else:
        return model_inputs_1, model_inputs_2, model_inputs_oracle, mask
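Editor's note: a minimal smoke-test sketch (not part of this commit) showing how the MR collate and prepare helpers above fit together; the item structure mirrors what DatasetMR emits, but the feature dimensions and label values here are invented for illustration.

# Hypothetical smoke test for start_end_collate_mr + prepare_batch_inputs_mr (CPU only).
import torch
from main.dataset import start_end_collate_mr, prepare_batch_inputs_mr

def fake_item(n_clips=8, vid_dim=514, txt_len=5, txt_dim=512):
    # Shapes follow the conventions above; the numbers themselves are made up.
    return dict(
        meta=dict(qid=0, vid="demo", duration=n_clips * 2.0, query="person opens a door"),
        model_inputs=dict(
            video_feat=torch.randn(n_clips, vid_dim),
            query_feat=torch.randn(txt_len, txt_dim),
            timestamp=torch.rand(n_clips, 2),
            timestamp_window=torch.ones(n_clips),
            span_labels_nn=torch.tensor([[0.0, 4.0]]),
            span_labels=torch.tensor([[0.1, 0.5]]),
            saliency_scores=torch.rand(n_clips),
            saliency_pos_labels=[1],
            saliency_neg_labels=[6],
        ),
    )

batch_meta, batched = start_end_collate_mr([fake_item(), fake_item()])
model_inputs, targets = prepare_batch_inputs_mr(batched, device="cpu")
print(model_inputs["src_vid"].shape, targets["saliency_pos_labels"].shape)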
main/dataset_qfvs.py
ADDED
@@ -0,0 +1,284 @@
import os
import pdb
import h5py
import nncore
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists
from nncore.dataset import DATASETS
from nncore.parallel import DataContainer
from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
from utils.tensor_utils import pad_sequences_1d
from utils.span_utils import span_xx_to_cxw

logger = logging.getLogger(__name__)

class DatasetQFVS(Dataset):
    def __init__(self, config, use_tef=True):
        # pdb.set_trace()
        self.config = config
        self.dataset = []
        self.use_tef = use_tef

        self.embedding = load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")

        self.transfer = {"Cupglass": "Glass",
                         "Musicalinstrument": "Instrument",
                         "Petsanimal": "Animal"}

        self.f_dict = {}
        feat_type = self.config['vid_feature']

        for video_id in self.config["train_videos"]:
            self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
            for _, _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id)):
                for file in files:
                    self.dataset.append(['Oracle', file[:file.find("_oracle.txt")] + "_" + str(video_id)])

            if self.config['qfvs_dense_shot'] > 0:
                dense_concept = {}
                feat_type = self.config['vid_feature']
                feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
                features = feat['features'][()]
                seg_len = feat['seg_len'][()]
                with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0" + str(video_id) + "/P0" + str(video_id) + ".txt", "r") as f:
                    lines = f.readlines()
                    for index, line in enumerate(lines):
                        concepts = line.strip().split(',')
                        for concept in concepts:
                            if concept in self.transfer:
                                concept = self.transfer[concept]
                            if concept not in dense_concept:
                                # dense_concept[concept] = torch.zeros(seg_len.sum())
                                dense_concept[concept] = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
                            else:
                                dense_concept[concept][index] = 1

                for key, value in dense_concept.items():
                    if value.sum().item() > 0:
                        self.dataset.append([video_id, key, value])

    def __getitem__(self, index):
        if self.dataset[index][0] == 'Oracle':
            return self.get_oracle(index)
        else:
            return self.get_dense(index)

    def get_dense(self, index):
        video_id = str(self.dataset[index][0])
        f = self.f_dict[video_id]
        # feat_type = self.config['vid_feature']
        # f = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
        features = f['features'][()]
        seg_len = f['seg_len'][()]

        dim = features.shape[-1]

        mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        for j in range(len(seg_len)):
            for k in range(seg_len[j]):
                mask_GT[j][k] = 1

        features = torch.from_numpy(features)

        concept1 = concept2 = self.dataset[index][1]
        concept1_GT = concept2_GT = oracle_summary = self.dataset[index][2]

        if concept1 in self.transfer:
            concept1 = self.transfer[concept1]
        if concept2 in self.transfer:
            concept2 = self.transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        concept1 = l2_normalize_np_array(concept1)
        concept2 = l2_normalize_np_array(concept2)

        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def get_oracle(self, index):
        video_id = self.dataset[index][1].split('_')[2]
        f = self.f_dict[video_id]
        # video_id = self.dataset[index][1].split('_')[2]
        # feat_type = self.config['vid_feature']
        # f = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
        features = f['features'][()]
        seg_len = f['seg_len'][()]

        dim = features.shape[-1]

        mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        for j in range(len(seg_len)):
            for k in range(seg_len[j]):
                mask_GT[j][k] = 1

        features = torch.from_numpy(features)

        concept1, concept2 = self.dataset[index][1].split('_')[0:2]

        concept1_GT = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        concept2_GT = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        # concept1_GT = torch.zeros(seg_len.sum())
        # concept2_GT = torch.zeros(seg_len.sum())
        with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0" + video_id + "/P0" + video_id + ".txt", "r") as f:
            lines = f.readlines()
            for index, line in enumerate(lines):
                concepts = line.strip().split(',')
                if concept1 in concepts:
                    concept1_GT[index] = 1
                if concept2 in concepts:
                    concept2_GT[index] = 1

        # oracle_summary = torch.zeros(seg_len.sum())
        oracle_summary = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        GT_summary_shots = []
        with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id) + "/" + str(concept1) + "_" + str(concept2) + "_" + "oracle.txt", "r") as f:
            for line in f.readlines():
                GT_summary_shots.append(int(line.strip()))
        GT_summary_shots = [x - 1 for x in GT_summary_shots]
        for element in GT_summary_shots:
            oracle_summary[element] = 1

        if concept1 in self.transfer:
            concept1 = self.transfer[concept1]
        if concept2 in self.transfer:
            concept2 = self.transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        concept1 = l2_normalize_np_array(concept1)
        concept2 = l2_normalize_np_array(concept2)

        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def __len__(self):
        return len(self.dataset)

def start_end_collate_qfvs(batch):
    model_inputs_keys = batch[0].keys()

    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)

    return batched_data

def prepare_batch_inputs_qfvs(data, config, eval=False):
    if not eval:
        features, mask, seg_len, \
        concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2, \
        saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
            data['features'][0], data['mask_GT'][0], data['seg_len'][0], \
            data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
            data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
    else:
        features, mask, seg_len, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
            data['features'][0], data['mask_GT'][0], data['seg_len'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]

    # preprocess for vid input.
    mask_GT = mask.to('cuda').reshape(1, -1).bool()
    seq = features.to('cuda').squeeze(0)
    mask = mask.to('cuda').squeeze(0)
    num_seg = seq.shape[0]

    ctx_l = seq.shape[1]
    tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    tef_ed = tef_st + 1.0 / ctx_l
    tef = torch.stack([tef_st, tef_ed], dim=1).to('cuda')  # (Lv, 2)

    tef = tef.squeeze(0).repeat(seq.shape[0], 1, 1)
    seq = torch.cat([seq, tef], dim=-1)

    # for txt input.
    src_txt_1 = src_txt_1.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
    src_txt_2 = src_txt_2.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
    src_txt_mask_1 = src_txt_mask_1.to('cuda').repeat(num_seg, 1)
    src_txt_mask_2 = src_txt_mask_2.to('cuda').repeat(num_seg, 1)

    src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
    src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')

    model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
    model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
    model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)

    # concept1_GT = concept1_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
    # concept2_GT = concept2_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
    # oracle_summary_GT = oracle_summary_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])

    if not eval:
        targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
        targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
        targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))

        targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
        targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
        targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')

        return model_inputs_1, model_inputs_2, model_inputs_oracle, \
               targets_1, targets_2, targets_oracle, mask_GT
    else:
        return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT
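Editor's note: a minimal sketch, assuming the QFVS features and metadata referenced above are already on disk and a CUDA device is available, of how this dataset is typically paired with its collate and prepare helpers; the config values shown are illustrative placeholders, not values taken from this commit.

# Hypothetical usage sketch for DatasetQFVS + start_end_collate_qfvs + prepare_batch_inputs_qfvs.
from torch.utils.data import DataLoader
from main.dataset_qfvs import DatasetQFVS, start_end_collate_qfvs, prepare_batch_inputs_qfvs

# Illustrative config; the real values live in main/config_qfvs.json.
config = dict(txt_feature="clip", vid_feature="clip", train_videos=[1, 2, 3],
              test_videos=[4], qfvs_dense_shot=-1, max_segment_num=20, max_frame_num=200)

dataset = DatasetQFVS(config)
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=start_end_collate_qfvs)
for batch in loader:
    # One concept pair per batch: two single-concept views plus their "oracle" union.
    in_1, in_2, in_oracle, tgt_1, tgt_2, tgt_oracle, mask_GT = \
        prepare_batch_inputs_qfvs(batch, config)
    break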
main/inference_demo.py
ADDED
@@ -0,0 +1,81 @@
import pdb
import pprint
from tqdm import tqdm, trange
import numpy as np
import os
from collections import OrderedDict, defaultdict
from utils.basic_utils import AverageMeter

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from main.config import TestOptions, setup_model
from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
from eval.eval import eval_submission
from eval.postprocessing import PostProcessorDETR
from utils.basic_utils import save_jsonl, save_json
from utils.temporal_nms import temporal_nms
from utils.span_utils import span_cxw_to_xx
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array

import logging
import importlib

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def load_model():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse()
    # pdb.set_trace()
    cudnn.benchmark = True
    cudnn.deterministic = False

    model, criterion, _, _ = setup_model(opt)
    return model

def load_data(save_dir):
    vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
    txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)

    vid = torch.from_numpy(l2_normalize_np_array(vid))
    txt = torch.from_numpy(l2_normalize_np_array(txt))
    clip_len = 2
    ctx_l = vid.shape[0]

    timestamp = ((torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

    if True:
        tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
        tef_ed = tef_st + 1.0 / ctx_l
        tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
        vid = torch.cat([vid, tef], dim=1)  # (Lv, Dv+2)

    src_vid = vid.unsqueeze(0).cuda()
    src_txt = txt.unsqueeze(0).cuda()
    src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
    src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()

    return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l

if __name__ == '__main__':
    clip_len = 2
    save_dir = '/data/home/qinghonglin/univtg/demo/tmp'

    model = load_model()
    src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
    with torch.no_grad():
        output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)

    pred_logits = output['pred_logits'][0].cpu()
    pred_spans = output['pred_spans'][0].cpu()
    pred_saliency = output['saliency_scores'].cpu()

    pdb.set_trace()
    top1 = (pred_spans + timestamp)[torch.argmax(pred_logits)] * ctx_l * clip_len
    print(top1)
    print(pred_saliency.argmax()*clip_len)
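Editor's note: the demo above only prints the single best window; a small hedged variant (not in this commit) that ranks the top-k predicted windows in seconds, reusing the same convention as the demo (predicted spans are offsets added to each clip's timestamp, then scaled by the clip count and clip length).

# Hypothetical helper: rank the top-k windows from the demo outputs.
import torch

def topk_windows(pred_spans, pred_logits, timestamp, ctx_l, clip_len=2, k=5):
    # pred_spans: (Lv, 2) per-clip offsets; pred_logits: (Lv,) or (Lv, 1) confidences.
    spans = (pred_spans + timestamp) * ctx_l * clip_len            # convert to seconds
    scores, inds = torch.sort(pred_logits.squeeze(-1), descending=True)
    return [(spans[i, 0].item(), spans[i, 1].item(), scores[j].item())
            for j, i in enumerate(inds[:k])]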
main/inference_hl.py
ADDED
@@ -0,0 +1,229 @@
import os
import pdb
import time
import json
import pprint
import random
import importlib
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/Users/kevin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def eval_epoch(model, train_val_dataset, opt):  # , nms_thresh, device):
    model.eval()

    scores = []
    train_val_dataset.set_state('val')
    val_loader = DataLoader(
        train_val_dataset,
        collate_fn=start_end_collate_hl,
        batch_size=opt.eval_bsz,
        num_workers=opt.num_workers,
        shuffle=False,
        pin_memory=opt.pin_memory
    )

    with torch.no_grad():
        for data in val_loader:
            model_inputs, targets = prepare_batch_inputs_hl(data)
            outputs = model(**model_inputs)
            # pred_cls = outputs['pred_logits'].squeeze(-1)
            # pred_cls = outputs['saliency_scores']
            # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)

            # pdb.set_trace()
            if opt.f_loss_coef == 0:
                pred_cls = outputs['saliency_scores']
            elif opt.s_loss_intra_coef == 0:
                pred_cls = outputs['pred_logits'].squeeze(-1)
            else:
                if opt.eval_mode == 'add':
                    pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
                else:
                    pred_cls = outputs['pred_logits'].squeeze(-1)

            pred_cls = pred_cls.detach().cpu()
            scores.append(pred_cls)
    map = round(train_val_dataset.evaluate(scores, save_dir='./plot')['mAP'] * 100, 4)
    return map

def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    train_val_dataset.set_state('train')
    train_loader = DataLoader(
        train_val_dataset,
        collate_fn=start_end_collate_hl,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=opt.pin_memory
    )

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in enumerate(train_loader):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs_hl(batch)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        outputs = model(**model_inputs)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    # print/add logs
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    for k, v in loss_meters.items():
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")

# train in single domain.
def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
    # if opt.device.type == "cuda":
    #     logger.info("CUDA enabled.")
    #     model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    prev_best_score = 0.
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                scores = eval_epoch(model, train_val_dataset, opt)
            tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
            if prev_best_score < scores:
                prev_best_score = scores
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
    tb_writer.close()
    return prev_best_score

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
    if opt.dset_name == "tvsum":
        domain_splits = TVSUM_SPLITS.keys()
    if opt.dset_name == "youtube":
        domain_splits = YOUTUBE_SPLITS.keys()

    scores = {}
    if opt.lr_warmup > 0:
        # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]

    for domain in domain_splits:
        dataset_config = dict(
            dset_name=opt.dset_name,
            domain=domain,
            data_path=opt.train_path,
            v_feat_types=opt.v_feat_types,
            v_feat_dirs=opt.v_feat_dirs,
            t_feat_dir=opt.t_feat_dir,
            use_tef=True
        )
        dataloader = DatasetHL(**dataset_config)

        model, criterion, optimizer, lr_scheduler = setup_model(opt)
        count_parameters(model)
        logger.info(f"Start Training {domain}")
        best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
        scores[domain] = best_score
    scores['AVG'] = sum(scores.values()) / len(scores)

    # save the final results.
    save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
    save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
    tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
    tb_writer.close()
    # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug

    print(opt.dset_name)
    print(scores)
    return

if __name__ == '__main__':
    start_training()
    results = logger.info("\n\n\nFINISHED TRAINING!!!")
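Editor's note: the mAP reported by eval_epoch above comes from the evaluate method in main/dataset.py (used by DatasetHL), which accumulates precision-recall trapezoids over the top-k ranked clips; a standalone restatement of that loop (not part of the commit) for reference, under the same convention of binary ground truth ordered by descending predicted saliency.

# Standalone restatement of the trapezoidal average-precision loop used in DatasetHL.evaluate.
def average_precision(ranked_gt):
    """ranked_gt: binary relevance of clips, ordered by descending predicted score."""
    num_gt = sum(ranked_gt)
    if num_gt == 0:
        return 0.0
    hits = ap = rec = 0.0
    prc = 1.0
    for j, gt in enumerate(ranked_gt):
        hits += gt
        _rec, _prc = hits / num_gt, hits / (j + 1)
        ap += (_rec - rec) * (prc + _prc) / 2   # trapezoidal area under the P-R curve
        rec, prc = _rec, _prc
    return ap

# e.g. average_precision([1, 0, 1, 0, 0]) is roughly 0.79 under this trapezoidal rule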
main/inference_mr.py
ADDED
@@ -0,0 +1,273 @@
import pdb
import pprint
from tqdm import tqdm, trange
import numpy as np
import os
from collections import OrderedDict, defaultdict
from utils.basic_utils import AverageMeter

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from main.config import TestOptions, setup_model
from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
from eval.eval import eval_submission
from eval.postprocessing import PostProcessorDETR
from utils.basic_utils import save_jsonl, save_json
from utils.temporal_nms import temporal_nms
from utils.span_utils import span_cxw_to_xx

import logging
import importlib

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)


def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
    mr_res_after_nms = []
    for e in mr_res:
        e["pred_relevant_windows"] = temporal_nms(
            e["pred_relevant_windows"][:max_before_nms],
            nms_thd=nms_thd,
            max_after_nms=max_after_nms
        )
        mr_res_after_nms.append(e)
    return mr_res_after_nms


def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
    # IOU_THDS = (0.5, 0.7)
    logger.info("Saving/Evaluating before nms results")
    submission_path = os.path.join(opt.results_dir, save_submission_filename)
    save_jsonl(submission, submission_path)

    if opt.eval_split_name in ["val", "test"]:  # since test_public has no GT
        metrics = eval_submission(
            submission, gt_data,
            verbose=opt.debug, match_number=not opt.debug,
        )
        save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
        save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
        latest_file_paths = [submission_path, save_metrics_path]
    else:
        metrics = None
        latest_file_paths = [submission_path, ]

    if opt.nms_thd != -1:
        logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
        submission_after_nms = post_processing_mr_nms(
            submission, nms_thd=opt.nms_thd,
            max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
        )

        logger.info("Saving/Evaluating nms results")
        submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
        save_jsonl(submission_after_nms, submission_nms_path)
        if opt.eval_split_name == "val":
            metrics_nms = eval_submission(
                submission_after_nms, gt_data,
                verbose=opt.debug, match_number=not opt.debug
            )
            save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
            save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
            latest_file_paths += [submission_nms_path, save_metrics_nms_path]
        else:
            metrics_nms = None
            latest_file_paths = [submission_nms_path, ]
    else:
        metrics_nms = None
    return metrics, metrics_nms, latest_file_paths


@torch.no_grad()
def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
    model.eval()
    if criterion:
        assert eval_loader.dataset.load_labels
        criterion.eval()

    loss_meters = defaultdict(AverageMeter)
    write_tb = tb_writer is not None and epoch_i is not None

    mr_res = []
    for batch in tqdm(eval_loader, desc="compute st ed scores"):
        query_meta = batch[0]
        model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
        outputs = model(**model_inputs)
        prob = outputs["pred_logits"]  # the last channel may be 1 or 2.
        # if opt.eval_mode == 'v1':
        #     prob = prob * outputs["saliency_scores"].unsqueeze(-1)  # v1
        # if opt.eval_mode == 'v2':
        #     prob = F.softmax(prob, dim=1) * outputs["saliency_scores"].unsqueeze(-1)  # v2
        # if opt.eval_mode == 'v3':
        #     prob = outputs["saliency_scores"].unsqueeze(-1)
        if outputs["pred_logits"].shape[-1] > 1:
            prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #queries, #classes=2)
        if opt.span_loss_type == "l1":
            scores = prob[..., 0]  # * (batch_size, #queries)  foreground label is 0, we directly take it
            pred_spans = outputs["pred_spans"]  # (bsz, #queries, 2)

            if opt.model_id not in ['moment_detr']:  # dense regression.
                start_spans = targets['timestamp']
                pred_spans = start_spans + pred_spans
                mask = targets['timestamp_mask'].bool()
                scores[~mask] = 0
                # if opt.eval_mode == 'v4':
                #     _mask = targets['timestamp_window'].bool()
                #     scores[~_mask] = 0

            if opt.eval_mode == 'add':
                # pdb.set_trace()
                _saliency_scores = outputs["saliency_scores"].half() + prob.squeeze(-1)
            else:
                _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)

            if opt.eval_mode == 'add_mr':
                prob = outputs["saliency_scores"].half().unsqueeze(-1) + prob

            saliency_scores = []
            valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
            for j in range(len(valid_vid_lengths)):
                saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
        else:
            bsz, n_queries = outputs["pred_spans"].shape[:2]  # (bsz, #queries, max_v_l *2)
            pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
            # TODO use more advanced decoding method with st_ed product
            pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1)  # 2 * (bsz, #queries, 2)
            scores = torch.prod(pred_span_scores, 2)  # (bsz, #queries)
            pred_spans[:, 1] += 1
            pred_spans *= opt.clip_length

        # compose predictions
        for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
            if opt.span_loss_type == "l1":
                if opt.model_id in ['moment_detr']:
                    spans = span_cxw_to_xx(spans) * meta["duration"]
                else:
                    spans = spans * meta["duration"]
                spans = torch.clamp(spans, 0, meta["duration"])  # added by Kevin, since window cannot be longer than video duration.

            # (#queries, 3), [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            if not opt.no_sort_results:
                cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                qid=meta["qid"],
                query=meta["query"],
                vid=meta["vid"],
                pred_relevant_windows=cur_ranked_preds,
                pred_saliency_scores=saliency_scores[idx]
            )
            mr_res.append(cur_query_pred)

        if criterion:
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
            loss_dict["loss_overall"] = float(losses)  # for logging only
            for k, v in loss_dict.items():
                loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        if opt.debug:
            break

    if write_tb and criterion:
        for k, v in loss_meters.items():
            tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)

    post_processor = PostProcessorDETR(
        clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
        min_w_l=2, max_w_l=150, move_window_method="left",
        # process_func_names=("clip_ts", "round_multiple")
        process_func_names=["round_multiple"]  # have added `clamp' op above, thus we do not need `clip_ts' again;
    )
    # todo: are we need round_multiple?
    if opt.round_multiple > 0:
        mr_res = post_processor(mr_res)
    return mr_res, loss_meters

def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
    """compute and save query and video proposal embeddings"""
    eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)  # list(dict)
    return eval_res, eval_loss_meters

def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
    logger.info("Generate submissions")
    model.eval()
    if criterion is not None and eval_dataset.load_labels:
        criterion.eval()
    else:
        criterion = None

    eval_loader = DataLoader(
        eval_dataset,
        collate_fn=start_end_collate_mr,
        batch_size=opt.eval_bsz,
        num_workers=opt.num_workers,
        shuffle=False,
        pin_memory=opt.pin_memory
    )

    submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
    if opt.no_sort_results:
        save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
    metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
        submission, opt, eval_dataset.data, save_submission_filename)
    return metrics, metrics_nms, eval_loss_meters, latest_file_paths

def start_inference():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse()
    # pdb.set_trace()
    cudnn.benchmark = True
    cudnn.deterministic = False

    assert opt.eval_path is not None
    eval_dataset = DatasetMR(
        dset_name=opt.dset_name,
        data_path=opt.eval_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        v_feat_dim=opt.v_feat_dim,
        q_feat_dim=opt.t_feat_dim,
        q_feat_type="last_hidden_state",
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        load_labels=True,  # opt.eval_split_name == "val",
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=0,
        use_cache=opt.use_cache,
    )

    if opt.lr_warmup > 0:
        # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    model, criterion, _, _ = setup_model(opt)
    save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
        opt.dset_name, opt.eval_split_name, opt.eval_id)
    logger.info("Starting inference...")
    with torch.no_grad():
        metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
            eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
    logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
    if metrics_nms is not None:
        logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))


if __name__ == '__main__':
    start_inference()
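Editor's note: a small hedged example (not part of this commit) of the NMS post-processing path above on a toy submission entry; the windows are invented, and utils.temporal_nms is assumed to follow the [start, end, score] convention used throughout this file.

# Hypothetical toy run of post_processing_mr_nms.
from main.inference_mr import post_processing_mr_nms

toy_submission = [dict(
    qid=0, query="person opens a door", vid="demo",
    pred_relevant_windows=[[10.0, 20.0, 0.9],   # kept
                           [11.0, 19.0, 0.8],   # suppressed: IoU 0.8 with the first window
                           [40.0, 50.0, 0.7]],  # kept: disjoint
    pred_saliency_scores=[0.1, 0.9, 0.3],
)]
cleaned = post_processing_mr_nms(toy_submission, nms_thd=0.7, max_before_nms=10, max_after_nms=10)
print(cleaned[0]["pred_relevant_windows"])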
main/inference_qfvs.py
ADDED
@@ -0,0 +1,342 @@
import os
import pdb
import time
import json
import pprint
import random
import importlib
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import h5py
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/Users/kevin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
from utils.model_utils import count_parameters
from eval.qfvs import calculate_semantic_matching, load_videos_tag

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def eval_epoch(model, config, opt):
    model.eval()
    f1_sum = 0; p_sum = 0; r_sum = 0

    assert len(config['test_videos']) == 1
    video_id = config['test_videos'][0]
    embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")

    feat_type = config['vid_feature']
    feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
    features = torch.from_numpy(feat['features'][()])
    seg_len = torch.from_numpy(feat['seg_len'][()])
    # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()

    # dim = features.shape[-1]
    # ctx_l = seg_len.sum().cpu()

    # dim = features.shape[-1]
    # ctx_l = features.shape[1]
    # seg_len = torch.ones(ctx_l)
    # features = features.reshape(-1, dim)[:ctx_l]

    # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    # tef_ed = tef_st + 1.0 / ctx_l
    # tef = torch.stack([tef_st, tef_ed], dim=1).cuda()  # (Lv, 2)
    # features = torch.cat([features, tef], dim=1)  # (Lv, Dv+2)

    transfer = {"Cupglass": "Glass",
                "Musicalinstrument": "Instrument",
                "Petsanimal": "Animal"}

    with open(os.path.join('./plot', opt.dset_name, str(opt.qfvs_split) +'.jsonl'), 'w') as f_write:
        for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
            evaluation_num=len(files)

            mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
            for j in range(len(seg_len)):
                for k in range(seg_len[j]):
                    mask_GT[j][k] = 1

            for file in files:
                summaries_GT=[]
                with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
                    for line in f.readlines():
                        summaries_GT.append(int(line.strip()))

                concept1, concept2 = file.split('_')[0:2]

                ##############
                if concept1 in transfer:
                    concept1 = transfer[concept1]
                if concept2 in transfer:
                    concept2 = transfer[concept2]
                concept1 = embedding[concept1]
                concept2 = embedding[concept2]

                concept1 = l2_normalize_np_array(concept1)
                concept2 = l2_normalize_np_array(concept2)

                data = {
                    'features':features,
                    'seg_len': seg_len,
                    'tokens_pad1':torch.from_numpy(concept1),
                    'tokens_pad2':torch.from_numpy(concept2),
                    'mask_GT': mask_GT
                }

                input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)

                summaries_GT = [x - 1 for x in summaries_GT]
                video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")

                if opt.f_loss_coef == 0:
                    output_type = 'saliency_scores'
                elif opt.s_loss_intra_coef == 0:
                    output_type = 'pred_logits'
                else:
                    if config['qfvs_score_ensemble'] > 0:
                        output_type = ['pred_logits', 'saliency_scores']
                    else:
                        output_type = 'pred_logits'

                with torch.no_grad():
                    if not isinstance(output_type, list):
                        score1 = model(**input1)[output_type].squeeze()
                        score1 = score1.masked_select(mask_GT)

                        score2 = model(**input2)[output_type].squeeze()
                        score2 = score2.masked_select(mask_GT)

                        score = model(**input_oracle)[output_type].squeeze()
                        score = score.masked_select(mask_GT)
                    else:
                        score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
                        for output_t in output_type:
                            score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
                            score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
                            score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)

                if config['qfvs_score_gather'] > 0:
                    score = score + score1 + score2
                else:
                    score = score

                # since video4 features dim is greater than video_shots_tag.
                score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
                _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))

                c1, c2 = file.split('_')[0:2]
                if c1 in transfer:
                    c1 = transfer[c1]
                if c2 in transfer:
                    c2 = transfer[c2]

                p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
                entry = {'concept1': c1, 'concept2': c2,
                         'score':score.tolist(),
                         'top_percent': config["top_percent"],
                         'top_pred':top_index.tolist(),
                         'gt':summaries_GT,
                         'p': p, 'r': r, 'f1': f1,
                         'shots': video_shots_tag[video_id-1].shape[0]}
                f_write.write(json.dumps(entry) + '\n')
                f1_sum+=f1; r_sum+=r; p_sum+=p
    return {'F': round(100* f1_sum/evaluation_num,2) ,
            'R': round(100* r_sum/evaluation_num,2) ,
            'P': round(100* p_sum/evaluation_num,2) }

def idx2time(idx):
    sec1, sec2 = idx*5, (idx+1)*5

    h1 = sec1 // 3600
    m1 = (sec1 - h1*3600) // 60
    s1 = sec1 % 60

    h2 = sec2 // 3600
    m2 = (sec2 - h2*3600) // 60
    s2 = sec2 % 60
    print(h1,m1,s1,'\t', h2,m2,s2)

def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    timer_dataloading = time.time()
    loss_total = 0

    for batch_idx, batch in enumerate(tqdm(train_loader)):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_input1, model_input2, model_input_oracle, \
            model_gt1, model_gt2, model_gt_oracle, \
            mask_GT = prepare_batch_inputs_qfvs(batch, config)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        output1 = model(**model_input1)
        output2 = model(**model_input2)
        output_oracle = model(**model_input_oracle)

        loss_dict = {}
        loss_dict1 = criterion(output1, model_gt1, mask_GT)
        loss_dict2 = criterion(output2, model_gt2, mask_GT)
        loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)

        weight_dict = criterion.weight_dict
        if config['qfvs_loss_gather'] > 0:
            for k in loss_dict1.keys():
                loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
        else:
            loss_dict = loss_dict3

        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        loss_total += losses.item()

        time_meters["model_forward_time"].update(time.time() - timer_start)
        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        timer_dataloading = time.time()
    return round(loss_total / len(train_loader), 2)

# train in single domain.
def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
    # if opt.device.type == "cuda":
    #     logger.info("CUDA enabled.")
    #     model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    val_score = eval_epoch(model, config, opt)
    tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
    logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                val_score = eval_epoch(model, config, opt)
                tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
                logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                            f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                            f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")

                if prev_best_score['Fscore'] < val_score['F']:
                    prev_best_score['Fscore'] = val_score['F']
                    prev_best_score['Precision'] = val_score['P']
                    prev_best_score['Recall'] = val_score['R']

                    checkpoint = {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch_i,
                        "opt": opt
                    }
                    torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
    tb_writer.close()
    return prev_best_score

def update_config(opt, config):
    # for key in ["max_segment_num", "max_frame_num", "top_percent",
    #             "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
    #             "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
    config["max_segment_num"] = opt.max_segment_num
    config["max_frame_num"] = opt.max_frame_num
    config["top_percent"] = opt.top_percent
    config["vid_feature"] = opt.qfvs_vid_feature
    config["txt_feature"] = opt.qfvs_txt_feature
    config["qfvs_dense_shot"] = opt.qfvs_dense_shot
    config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
    config["qfvs_score_gather"] = opt.qfvs_score_gather
    config["qfvs_loss_gather"] = opt.qfvs_loss_gather
    return config

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    # config = load_json("./main/config_qfvs.json")
    config = {}
    config = update_config(opt, config)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)

    # key -> test video; value -> training videos.
    qfvs_split = {
        1: [2, 3, 4],
        2: [1, 3, 4],
        3: [1, 2, 4],
        4: [1, 2, 3]
    }

    scores_videos = {}
    for test_id, splits in qfvs_split.items():
        if opt.qfvs_split != -1:
            if test_id != opt.qfvs_split:
                continue
        logger.info(f"Start Training {opt.dset_name}: {test_id}")
        config['train_videos'] = qfvs_split[test_id]
        config['test_videos'] = [test_id]
        train_dataset = DatasetQFVS(config)
        train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)

        model, criterion, optimizer, lr_scheduler = setup_model(opt)
        count_parameters(model)
        best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
        scores_videos['V'+str(test_id)] = best_score

    # save the final results.
    avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
    scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}

    save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
    save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)

    tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
    tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
    tb_writer.close()

    print(scores_videos)
    return

if __name__ == '__main__':
    start_training()
    results = logger.info("\n\n\nFINISHED TRAINING!!!")
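Note: a minimal sketch of the shot-selection step used in eval_epoch above, isolated for clarity: optionally ensemble the two head outputs ('pred_logits' and 'saliency_scores'), then keep the top `top_percent` of shots. The tensors and the select_shots helper are illustrative stand-ins, not part of this commit.

# Illustrative only: random scores stand in for real model outputs.
import torch

def select_shots(pred_logits, saliency_scores, top_percent=0.02, ensemble=True):
    # sum the two heads when score ensembling is enabled, else use pred_logits alone
    score = pred_logits + saliency_scores if ensemble else pred_logits
    k = max(1, int(score.shape[0] * top_percent))
    _, top_index = score.topk(k)
    return top_index

if __name__ == "__main__":
    n_shots = 500
    idx = select_shots(torch.rand(n_shots), torch.rand(n_shots))
    print(sorted(idx.tolist())[:10])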
main/train_hl.py
ADDED
@@ -0,0 +1,229 @@
import os
import pdb
import time
import json
import pprint
import random
import importlib
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/data/home/qinghonglin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def eval_epoch(model, train_val_dataset, opt):  #, nms_thresh, device):
    model.eval()

    scores = []
    train_val_dataset.set_state('val')
    val_loader = DataLoader(
        train_val_dataset,
        collate_fn=start_end_collate_hl,
        batch_size=opt.eval_bsz,
        num_workers=opt.num_workers,
        shuffle=False,
        pin_memory=opt.pin_memory
    )

    with torch.no_grad():
        for data in val_loader:
            model_inputs, targets = prepare_batch_inputs_hl(data)
            outputs = model(**model_inputs)
            # pred_cls = outputs['pred_logits'].squeeze(-1)
            # pred_cls = outputs['saliency_scores']
            # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)

            # pdb.set_trace()
            if opt.f_loss_coef == 0:
                pred_cls = outputs['saliency_scores']
            elif opt.s_loss_intra_coef == 0:
                pred_cls = outputs['pred_logits'].squeeze(-1)
            else:
                if opt.eval_mode == 'add':
                    pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
                else:
                    pred_cls = outputs['pred_logits'].squeeze(-1)

            pred_cls = pred_cls.detach().cpu()
            scores.append(pred_cls)
    map = round(train_val_dataset.evaluate(scores)['mAP'] * 100, 4)
    return map

def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    train_val_dataset.set_state('train')
    train_loader = DataLoader(
        train_val_dataset,
        collate_fn=start_end_collate_hl,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=opt.pin_memory
    )

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in enumerate(train_loader):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs_hl(batch)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        outputs = model(**model_inputs)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    # print/add logs
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    for k, v in loss_meters.items():
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")

# train in single domain.
def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
    # if opt.device.type == "cuda":
    #     logger.info("CUDA enabled.")
    #     model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    prev_best_score = 0.
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                scores = eval_epoch(model, train_val_dataset, opt)
            tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
            if prev_best_score < scores:
                prev_best_score = scores
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
    tb_writer.close()
    return prev_best_score

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
    if opt.dset_name == "tvsum":
        domain_splits = TVSUM_SPLITS.keys()
    if opt.dset_name == "youtube":
        domain_splits = YOUTUBE_SPLITS.keys()

    scores = {}
    if opt.lr_warmup > 0:
        # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]

    for domain in domain_splits:
        dataset_config = dict(
            dset_name=opt.dset_name,
            domain=domain,
            data_path=opt.train_path,
            v_feat_types=opt.v_feat_types,
            v_feat_dirs=opt.v_feat_dirs,
            t_feat_dir=opt.t_feat_dir,
            use_tef=True
        )
        dataloader = DatasetHL(**dataset_config)

        model, criterion, optimizer, lr_scheduler = setup_model(opt)
        count_parameters(model)
        logger.info(f"Start Training {domain}")
        best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
        scores[domain] = best_score
    scores['AVG'] = sum(scores.values()) / len(scores)

    # save the final results.
    save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
    save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
    tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
    tb_writer.close()
    # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug

    print(opt.dset_name)
    print(scores)
    return

if __name__ == '__main__':
    start_training()
    results = logger.info("\n\n\nFINISHED TRAINING!!!")
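Note: a minimal sketch of the learning-rate warmup bookkeeping used in start_training above: a value greater than 1 is treated as an absolute number of warmup epochs, while a value in (0, 1] is treated as a fraction of the total epochs. The resolve_warmup helper and the numbers are illustrative, not part of this commit.

def resolve_warmup(lr_warmup, n_epoch):
    # mirrors the opt.lr_warmup -> [warmup_steps, total_steps] conversion above
    total_steps = n_epoch
    warmup_steps = lr_warmup if lr_warmup > 1 else int(lr_warmup * total_steps)
    return [warmup_steps, total_steps]

if __name__ == "__main__":
    print(resolve_warmup(10, 200))   # absolute: [10, 200]
    print(resolve_warmup(0.1, 200))  # fractional: [20, 200]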
main/train_mr.py
ADDED
@@ -0,0 +1,266 @@
import os
import pdb
import sys
import time
import json
import pprint
import random
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

sys.path.append('/data/home/qinghonglin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import \
    DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
from main.inference_mr import eval_epoch, start_inference
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)

        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()

        # try:
        outputs = model(**model_inputs)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()

        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)  # for logging only
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()

    # print/add logs
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    for k, v in loss_meters.items():
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")


def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    train_loader = DataLoader(
        train_dataset,
        collate_fn=start_end_collate_mr,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=opt.pin_memory
    )

    prev_best_score = 0.
    es_cnt = 0
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch
    save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
                    eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)

            # log
            to_write = opt.eval_log_txt_formatter.format(
                time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                epoch=epoch_i,
                loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
                eval_metrics_str=json.dumps(metrics_no_nms))

            with open(opt.eval_log_filepath, "a") as f:
                f.write(to_write)
            logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
            if metrics_nms is not None:
                logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))

            metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
            for k, v in metrics["brief"].items():
                tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)

            # stop_score = metrics["brief"]["MR-full-mAP"]
            # pdb.set_trace()
            stop_score = metrics["brief"][opt.main_metric]
            if stop_score > prev_best_score:
                es_cnt = 0
                prev_best_score = stop_score

                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))

                best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
                for src, tgt in zip(latest_file_paths, best_file_paths):
                    os.renames(src, tgt)
                logger.info("The checkpoint file has been updated.")
            else:
                es_cnt += 1
                if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt:  # early stop
                    with open(opt.train_log_filepath, "a") as f:
                        f.write(f"Early Stop at epoch {epoch_i}")
                    logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
                    break

            # save ckpt
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch_i,
                "opt": opt
            }
            torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))

        if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0:  # additional copies
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch_i,
                "opt": opt
            }
            torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))

        if opt.debug:
            break

    tb_writer.close()


def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # keep the model run deterministically
        # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
        # Enable this only when input size is fixed.
        cudnn.benchmark = False
        cudnn.deterministic = True

    dataset_config = dict(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        v_feat_dim=opt.v_feat_dim,
        q_feat_dim=opt.t_feat_dim,
        q_feat_type="last_hidden_state",
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=opt.txt_drop_ratio,
        use_cache=opt.use_cache,
        add_easy_negative=opt.add_easy_negative,
        easy_negative_only=opt.easy_negative_only
    )

    dataset_config["data_path"] = opt.train_path
    train_dataset = DatasetMR(**dataset_config)

    if opt.eval_path is not None:
        dataset_config["data_path"] = opt.eval_path
        dataset_config["txt_drop_ratio"] = 0
        dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip")  # for pretraining
        # dataset_config["load_labels"] = False  # uncomment to calculate eval loss
        eval_dataset = DatasetMR(**dataset_config)
    else:
        eval_dataset = None

    if opt.lr_warmup > 0:
        # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]
    model, criterion, optimizer, lr_scheduler = setup_model(opt)
    logger.info(f"Model {model}")
    count_parameters(model)
    logger.info("Start Training...")
    train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
    return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug


if __name__ == '__main__':
    best_ckpt_path, eval_split_name, eval_path, debug = start_training()
    if not debug:
        input_args = ["--resume", best_ckpt_path,
                      "--eval_split_name", eval_split_name,
                      "--eval_path", eval_path]

        import sys
        sys.argv[1:] = input_args
        logger.info("\n\n\nFINISHED TRAINING!!!")
        logger.info("Evaluating model at {}".format(best_ckpt_path))
        logger.info("Input args {}".format(sys.argv[1:]))
        start_inference()
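Note: a minimal sketch of the early-stopping rule in train() above: the counter resets whenever the monitored metric improves, otherwise it counts up and training stops once it exceeds max_es_cnt (-1 disables the check). The should_stop helper and the score sequence are illustrative, not part of this commit.

def should_stop(scores, max_es_cnt=3):
    # returns (epoch at which training would stop, best score), or (None, best score)
    prev_best, es_cnt = 0.0, 0
    for epoch, score in enumerate(scores):
        if score > prev_best:
            prev_best, es_cnt = score, 0
        else:
            es_cnt += 1
            if max_es_cnt != -1 and es_cnt > max_es_cnt:
                return epoch, prev_best
    return None, prev_best

if __name__ == "__main__":
    print(should_stop([0.30, 0.35, 0.34, 0.34, 0.33, 0.32, 0.31]))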
main/train_qfvs.py
ADDED
@@ -0,0 +1,325 @@
import os
import pdb
import time
import json
import pprint
import random
import importlib
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import h5py
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/Users/kevin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
from utils.model_utils import count_parameters
from eval.qfvs import calculate_semantic_matching, load_videos_tag

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def eval_epoch(model, config, opt):
    model.eval()
    f1_sum = 0; p_sum = 0; r_sum = 0

    assert len(config['test_videos']) == 1
    video_id = config['test_videos'][0]
    embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")

    feat_type = config['vid_feature']
    feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
    features = torch.from_numpy(feat['features'][()])
    seg_len = torch.from_numpy(feat['seg_len'][()])
    # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()

    # dim = features.shape[-1]
    # ctx_l = seg_len.sum().cpu()

    # dim = features.shape[-1]
    # ctx_l = features.shape[1]
    # seg_len = torch.ones(ctx_l)
    # features = features.reshape(-1, dim)[:ctx_l]

    # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    # tef_ed = tef_st + 1.0 / ctx_l
    # tef = torch.stack([tef_st, tef_ed], dim=1).cuda()  # (Lv, 2)
    # features = torch.cat([features, tef], dim=1)  # (Lv, Dv+2)

    transfer = {"Cupglass": "Glass",
                "Musicalinstrument": "Instrument",
                "Petsanimal": "Animal"}

    for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
        evaluation_num=len(files)

        mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
        for j in range(len(seg_len)):
            for k in range(seg_len[j]):
                mask_GT[j][k] = 1

        for file in files:
            summaries_GT=[]
            with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
                for line in f.readlines():
                    summaries_GT.append(int(line.strip()))

            concept1, concept2 = file.split('_')[0:2]

            ##############
            if concept1 in transfer:
                concept1 = transfer[concept1]
            if concept2 in transfer:
                concept2 = transfer[concept2]
            concept1 = embedding[concept1]
            concept2 = embedding[concept2]

            concept1 = l2_normalize_np_array(concept1)
            concept2 = l2_normalize_np_array(concept2)

            data = {
                'features':features,
                'seg_len': seg_len,
                'tokens_pad1':torch.from_numpy(concept1),
                'tokens_pad2':torch.from_numpy(concept2),
                'mask_GT': mask_GT
            }

            input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)

            summaries_GT = [x - 1 for x in summaries_GT]
            video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")

            if opt.f_loss_coef == 0:
                output_type = 'saliency_scores'
            elif opt.s_loss_intra_coef == 0:
                output_type = 'pred_logits'
            else:
                if config['qfvs_score_ensemble'] > 0:
                    output_type = ['pred_logits', 'saliency_scores']
                else:
                    output_type = 'pred_logits'

            with torch.no_grad():
                if not isinstance(output_type, list):
                    score1 = model(**input1)[output_type].squeeze()
                    score1 = score1.masked_select(mask_GT)

                    score2 = model(**input2)[output_type].squeeze()
                    score2 = score2.masked_select(mask_GT)

                    score = model(**input_oracle)[output_type].squeeze()
                    score = score.masked_select(mask_GT)
                else:
                    score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
                    for output_t in output_type:
                        score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
                        score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
                        score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)

            if config['qfvs_score_gather'] > 0:
                score = score + score1 + score2
            else:
                score = score

            # since video4 features dim is greater than video_shots_tag.
            score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
            _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))

            p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
            f1_sum+=f1; r_sum+=r; p_sum+=p

    return {'F': round(100* f1_sum/evaluation_num,2) ,
            'R': round(100* r_sum/evaluation_num,2) ,
            'P': round(100* p_sum/evaluation_num,2) }

def idx2time(idx):
    sec1, sec2 = idx*5, (idx+1)*5

    h1 = sec1 // 3600
    m1 = (sec1 - h1*3600) // 60
    s1 = sec1 % 60

    h2 = sec2 // 3600
    m2 = (sec2 - h2*3600) // 60
    s2 = sec2 % 60
    print(h1,m1,s1,'\t', h2,m2,s2)

def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    timer_dataloading = time.time()
    loss_total = 0

    for batch_idx, batch in enumerate(tqdm(train_loader)):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_input1, model_input2, model_input_oracle, \
            model_gt1, model_gt2, model_gt_oracle, \
            mask_GT = prepare_batch_inputs_qfvs(batch, config)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        output1 = model(**model_input1)
        output2 = model(**model_input2)
        output_oracle = model(**model_input_oracle)

        loss_dict = {}
        loss_dict1 = criterion(output1, model_gt1, mask_GT)
        loss_dict2 = criterion(output2, model_gt2, mask_GT)
        loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)

        weight_dict = criterion.weight_dict
        if config['qfvs_loss_gather'] > 0:
            for k in loss_dict1.keys():
                loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
        else:
            loss_dict = loss_dict3

        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        loss_total += losses.item()

        time_meters["model_forward_time"].update(time.time() - timer_start)
        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        timer_dataloading = time.time()
    return round(loss_total / len(train_loader), 2)

# train in single domain.
def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
    # if opt.device.type == "cuda":
    #     logger.info("CUDA enabled.")
    #     model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    val_score = eval_epoch(model, config, opt)
    tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
    logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                val_score = eval_epoch(model, config, opt)
                tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
                logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
                            f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
                            f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")

                if prev_best_score['Fscore'] < val_score['F']:
                    prev_best_score['Fscore'] = val_score['F']
                    prev_best_score['Precision'] = val_score['P']
                    prev_best_score['Recall'] = val_score['R']

                    checkpoint = {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch_i,
                        "opt": opt
                    }
                    torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
    tb_writer.close()
    return prev_best_score

def update_config(opt, config):
    # for key in ["max_segment_num", "max_frame_num", "top_percent",
    #             "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
    #             "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
    config["max_segment_num"] = opt.max_segment_num
    config["max_frame_num"] = opt.max_frame_num
    config["top_percent"] = opt.top_percent
    config["vid_feature"] = opt.qfvs_vid_feature
    config["txt_feature"] = opt.qfvs_txt_feature
    config["qfvs_dense_shot"] = opt.qfvs_dense_shot
    config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
    config["qfvs_score_gather"] = opt.qfvs_score_gather
    config["qfvs_loss_gather"] = opt.qfvs_loss_gather
    return config

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    # config = load_json("./main/config_qfvs.json")
    config = {}
    config = update_config(opt, config)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)

    # key -> test video; value -> training videos.
    qfvs_split = {
        1: [2, 3, 4],
        2: [1, 3, 4],
        3: [1, 2, 4],
        4: [1, 2, 3]
    }

    scores_videos = {}
    for test_id, splits in qfvs_split.items():
        logger.info(f"Start Training {opt.dset_name}: {test_id}")
        config['train_videos'] = qfvs_split[test_id]
        config['test_videos'] = [test_id]
        train_dataset = DatasetQFVS(config)
        train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)

        model, criterion, optimizer, lr_scheduler = setup_model(opt)
        count_parameters(model)
        best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
        scores_videos['V'+str(test_id)] = best_score

    # save the final results.
    avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
    avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
    scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}

    save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
    save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)

    tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
    tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
    tb_writer.close()

    print(scores_videos)
    return

if __name__ == '__main__':
    start_training()
    results = logger.info("\n\n\nFINISHED TRAINING!!!")
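Note: a minimal sketch of the leave-one-video-out protocol used in start_training above: each of the four QFVS videos is held out once as the test video, and the per-video best scores are averaged at the end. The run_all helper and the dummy scores are illustrative, not part of this commit.

qfvs_split = {1: [2, 3, 4], 2: [1, 3, 4], 3: [1, 2, 4], 4: [1, 2, 3]}

def run_all(train_fn):
    # train_fn(train_ids, test_id) -> {'Fscore': ..., 'Precision': ..., 'Recall': ...}
    scores_videos = {}
    for test_id, train_ids in qfvs_split.items():
        scores_videos['V' + str(test_id)] = train_fn(train_ids, test_id)
    avg = {k: sum(v[k] for v in scores_videos.values()) / len(scores_videos)
           for k in ('Fscore', 'Precision', 'Recall')}
    scores_videos['avg'] = avg
    return scores_videos

if __name__ == "__main__":
    dummy = lambda train_ids, test_id: {'Fscore': 50.0, 'Precision': 48.0, 'Recall': 52.0}
    print(run_all(dummy))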
main/train_vlp.py
ADDED
@@ -0,0 +1,278 @@
import os
import pdb
import sys
import time
import json
import pprint
import random
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

sys.path.append('/data/home/qinghonglin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import \
    DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
from main.inference_mr import eval_epoch, start_inference
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls=None):
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)

        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()

        if cls is not None:
            model_inputs.update(cls)

        # try:
        outputs = model(**model_inputs)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        # except:
        #     pdb.set_trace()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)  # for logging only
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()

    # print/add logs
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    for k, v in loss_meters.items():
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")


def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    train_loader = DataLoader(
        train_dataset,
        collate_fn=start_end_collate_mr,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=opt.pin_memory
    )

    if ('tal' in opt.train_path) or ('mq' in opt.train_path):
        cls = {
            'src_cls': train_dataset.src_cls.cuda(),
            'src_cls_mask': train_dataset.src_cls_mask.cuda(),}
    else:
        cls = None
124 |
+
prev_best_score = 0.
|
125 |
+
es_cnt = 0
|
126 |
+
if opt.start_epoch is None:
|
127 |
+
start_epoch = -1 if opt.eval_init else 0
|
128 |
+
else:
|
129 |
+
start_epoch = opt.start_epoch
|
130 |
+
save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
|
131 |
+
for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
|
132 |
+
if epoch_i > -1:
|
133 |
+
train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls)
|
134 |
+
lr_scheduler.step()
|
135 |
+
eval_epoch_interval = opt.eval_epoch
|
136 |
+
if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
|
137 |
+
with torch.no_grad():
|
138 |
+
metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
|
139 |
+
eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
|
140 |
+
|
141 |
+
# log
|
142 |
+
to_write = opt.eval_log_txt_formatter.format(
|
143 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
144 |
+
epoch=epoch_i,
|
145 |
+
loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
|
146 |
+
eval_metrics_str=json.dumps(metrics_no_nms))
|
147 |
+
|
148 |
+
with open(opt.eval_log_filepath, "a") as f:
|
149 |
+
f.write(to_write)
|
150 |
+
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
|
151 |
+
if metrics_nms is not None:
|
152 |
+
logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
|
153 |
+
|
154 |
+
metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
|
155 |
+
for k, v in metrics["brief"].items():
|
156 |
+
tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
|
157 |
+
|
158 |
+
# stop_score = metrics["brief"]["MR-full-mAP"]
|
159 |
+
# pdb.set_trace()
|
160 |
+
stop_score = metrics["brief"][opt.main_metric]
|
161 |
+
if stop_score > prev_best_score:
|
162 |
+
es_cnt = 0
|
163 |
+
prev_best_score = stop_score
|
164 |
+
|
165 |
+
checkpoint = {
|
166 |
+
"model": model.state_dict(),
|
167 |
+
"optimizer": optimizer.state_dict(),
|
168 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
169 |
+
"epoch": epoch_i,
|
170 |
+
"opt": opt
|
171 |
+
}
|
172 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
|
173 |
+
|
174 |
+
best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
|
175 |
+
for src, tgt in zip(latest_file_paths, best_file_paths):
|
176 |
+
os.renames(src, tgt)
|
177 |
+
logger.info("The checkpoint file has been updated.")
|
178 |
+
else:
|
179 |
+
es_cnt += 1
|
180 |
+
if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
|
181 |
+
with open(opt.train_log_filepath, "a") as f:
|
182 |
+
f.write(f"Early Stop at epoch {epoch_i}")
|
183 |
+
logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
|
184 |
+
break
|
185 |
+
|
186 |
+
# save ckpt
|
187 |
+
checkpoint = {
|
188 |
+
"model": model.state_dict(),
|
189 |
+
"optimizer": optimizer.state_dict(),
|
190 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
191 |
+
"epoch": epoch_i,
|
192 |
+
"opt": opt
|
193 |
+
}
|
194 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
|
195 |
+
|
196 |
+
if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
|
197 |
+
checkpoint = {
|
198 |
+
"model": model.state_dict(),
|
199 |
+
"optimizer": optimizer.state_dict(),
|
200 |
+
"epoch": epoch_i,
|
201 |
+
"opt": opt
|
202 |
+
}
|
203 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
|
204 |
+
|
205 |
+
if opt.debug:
|
206 |
+
break
|
207 |
+
|
208 |
+
tb_writer.close()
|
209 |
+
|
210 |
+
|
211 |
+
def start_training():
|
212 |
+
logger.info("Setup config, data and model...")
|
213 |
+
opt = BaseOptions().parse()
|
214 |
+
set_seed(opt.seed)
|
215 |
+
if opt.debug: # keep the model run deterministically
|
216 |
+
# 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
|
217 |
+
# Enable this only when input size is fixed.
|
218 |
+
cudnn.benchmark = False
|
219 |
+
cudnn.deterministic = True
|
220 |
+
|
221 |
+
dataset_config = dict(
|
222 |
+
dset_name=opt.dset_name,
|
223 |
+
data_path=opt.train_path,
|
224 |
+
v_feat_dirs=opt.v_feat_dirs,
|
225 |
+
q_feat_dir=opt.t_feat_dir,
|
226 |
+
v_feat_dim=opt.v_feat_dim,
|
227 |
+
q_feat_dim=opt.t_feat_dim,
|
228 |
+
q_feat_type="last_hidden_state",
|
229 |
+
max_q_l=opt.max_q_l,
|
230 |
+
max_v_l=opt.max_v_l,
|
231 |
+
ctx_mode=opt.ctx_mode,
|
232 |
+
data_ratio=opt.data_ratio,
|
233 |
+
normalize_v=not opt.no_norm_vfeat,
|
234 |
+
normalize_t=not opt.no_norm_tfeat,
|
235 |
+
clip_len=opt.clip_length,
|
236 |
+
max_windows=opt.max_windows,
|
237 |
+
span_loss_type=opt.span_loss_type,
|
238 |
+
txt_drop_ratio=opt.txt_drop_ratio,
|
239 |
+
use_cache=opt.use_cache,
|
240 |
+
add_easy_negative=opt.add_easy_negative,
|
241 |
+
easy_negative_only=opt.easy_negative_only
|
242 |
+
)
|
243 |
+
|
244 |
+
dataset_config["data_path"] = opt.train_path
|
245 |
+
train_dataset = DatasetVLP(**dataset_config)
|
246 |
+
|
247 |
+
if opt.eval_path is not None:
|
248 |
+
dataset_config["data_path"] = opt.eval_path
|
249 |
+
dataset_config["txt_drop_ratio"] = 0
|
250 |
+
dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
|
251 |
+
# dataset_config["load_labels"] = False # uncomment to calculate eval loss
|
252 |
+
eval_dataset = DatasetVLP(**dataset_config)
|
253 |
+
else:
|
254 |
+
eval_dataset = None
|
255 |
+
|
256 |
+
if opt.lr_warmup > 0:
|
257 |
+
opt.lr_warmup = opt.n_epoch
|
258 |
+
model, criterion, optimizer, lr_scheduler = setup_model(opt)
|
259 |
+
logger.info(f"Model {model}")
|
260 |
+
count_parameters(model)
|
261 |
+
logger.info("Start Training...")
|
262 |
+
train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
|
263 |
+
return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
|
264 |
+
|
265 |
+
|
266 |
+
if __name__ == '__main__':
|
267 |
+
best_ckpt_path, eval_split_name, eval_path, debug = start_training()
|
268 |
+
if not debug:
|
269 |
+
input_args = ["--resume", best_ckpt_path,
|
270 |
+
"--eval_split_name", eval_split_name,
|
271 |
+
"--eval_path", eval_path]
|
272 |
+
|
273 |
+
import sys
|
274 |
+
sys.argv[1:] = input_args
|
275 |
+
logger.info("\n\n\nFINISHED TRAINING!!!")
|
276 |
+
logger.info("Evaluating model at {}".format(best_ckpt_path))
|
277 |
+
logger.info("Input args {}".format(sys.argv[1:]))
|
278 |
+
start_inference()
|
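Editor's note: in train_epoch above, the total loss is the weighted sum of only those criterion outputs that have an entry in criterion.weight_dict; entries without a weight (e.g. loss_overall) are logged but not back-propagated. The loss names loss_b / loss_g / loss_f below come from the code in this upload, while the numeric weights and values are illustrative placeholders, not the actual UniVTG configuration:

# Sketch of the weighted-loss aggregation used in train_epoch above (toy values).
import torch

loss_dict = {
    "loss_b": torch.tensor(0.8),      # boundary regression term
    "loss_g": torch.tensor(0.5),      # generalized IoU term
    "loss_f": torch.tensor(0.3),      # foreground classification term
    "debug_stat": torch.tensor(9.9),  # no entry in weight_dict -> logged only, not optimized
}
weight_dict = {"loss_b": 10.0, "loss_g": 1.0, "loss_f": 1.0}

losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
print(float(losses))  # 0.8*10 + 0.5*1 + 0.3*1 = approx. 8.8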
main/train_vlp_ddp.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
import pprint
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm, trange
|
10 |
+
from collections import defaultdict
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.distributed as dist
|
15 |
+
import torch.backends.cudnn as cudnn
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
from torch.utils.data.distributed import DistributedSampler
|
19 |
+
|
20 |
+
sys.path.append('/data/home/qinghonglin/univtg')
|
21 |
+
from main.config import BaseOptions, setup_model
|
22 |
+
from main.dataset import \
|
23 |
+
DatasetMR, DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
|
24 |
+
from main.inference_mr import eval_epoch, start_inference
|
25 |
+
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
|
26 |
+
from utils.model_utils import count_parameters
|
27 |
+
|
28 |
+
import logging
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
|
31 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
32 |
+
level=logging.INFO)
|
33 |
+
|
34 |
+
def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
|
35 |
+
logger.info(f"[Epoch {epoch_i+1}]")
|
36 |
+
model.train()
|
37 |
+
criterion.train()
|
38 |
+
|
39 |
+
# init meters
|
40 |
+
time_meters = defaultdict(AverageMeter)
|
41 |
+
loss_meters = defaultdict(AverageMeter)
|
42 |
+
|
43 |
+
num_training_examples = len(train_loader)
|
44 |
+
timer_dataloading = time.time()
|
45 |
+
for batch_idx, batch in tqdm(enumerate(train_loader),
|
46 |
+
desc="Training Iteration",
|
47 |
+
total=num_training_examples):
|
48 |
+
time_meters["dataloading_time"].update(time.time() - timer_dataloading)
|
49 |
+
|
50 |
+
timer_start = time.time()
|
51 |
+
model_inputs, targets = prepare_batch_inputs_mr(batch[1], torch.device("cuda", int(opt.local_rank)), non_blocking=opt.pin_memory)
|
52 |
+
time_meters["prepare_inputs_time"].update(time.time() - timer_start)
|
53 |
+
|
54 |
+
timer_start = time.time()
|
55 |
+
|
56 |
+
# try:
|
57 |
+
outputs = model(**model_inputs)
|
58 |
+
loss_dict = criterion(outputs, targets)
|
59 |
+
weight_dict = criterion.weight_dict
|
60 |
+
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
61 |
+
time_meters["model_forward_time"].update(time.time() - timer_start)
|
62 |
+
|
63 |
+
timer_start = time.time()
|
64 |
+
optimizer.zero_grad()
|
65 |
+
losses.backward()
|
66 |
+
|
67 |
+
if opt.grad_clip > 0:
|
68 |
+
nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
69 |
+
optimizer.step()
|
70 |
+
time_meters["model_backward_time"].update(time.time() - timer_start)
|
71 |
+
|
72 |
+
loss_dict["loss_overall"] = float(losses) # for logging only
|
73 |
+
for k, v in loss_dict.items():
|
74 |
+
loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
|
75 |
+
|
76 |
+
timer_dataloading = time.time()
|
77 |
+
|
78 |
+
# print/add logs
|
79 |
+
if int(opt.local_rank) in [0, -1]:
|
80 |
+
tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
|
81 |
+
for k, v in loss_meters.items():
|
82 |
+
tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
|
83 |
+
|
84 |
+
to_write = opt.train_log_txt_formatter.format(
|
85 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
86 |
+
epoch=epoch_i+1,
|
87 |
+
loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
|
88 |
+
with open(opt.train_log_filepath, "a") as f:
|
89 |
+
f.write(to_write)
|
90 |
+
|
91 |
+
logger.info("Epoch time stats:")
|
92 |
+
for name, meter in time_meters.items():
|
93 |
+
d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
|
94 |
+
logger.info(f"{name} ==> {d}")
|
95 |
+
|
96 |
+
|
97 |
+
def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
|
98 |
+
if int(opt.local_rank) in [0, -1]:
|
99 |
+
tb_writer = SummaryWriter(opt.tensorboard_log_dir)
|
100 |
+
tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
|
101 |
+
opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
|
102 |
+
opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
|
103 |
+
else:
|
104 |
+
tb_writer = None
|
105 |
+
|
106 |
+
train_loader = DataLoader(
|
107 |
+
train_dataset,
|
108 |
+
collate_fn=start_end_collate_mr,
|
109 |
+
batch_size=opt.bsz,
|
110 |
+
num_workers=opt.num_workers,
|
111 |
+
# shuffle=True,
|
112 |
+
pin_memory=opt.pin_memory,
|
113 |
+
sampler=DistributedSampler(train_dataset)
|
114 |
+
)
|
115 |
+
|
116 |
+
prev_best_score = 0.
|
117 |
+
es_cnt = 0
|
118 |
+
if opt.start_epoch is None:
|
119 |
+
start_epoch = -1 if opt.eval_init else 0
|
120 |
+
else:
|
121 |
+
start_epoch = opt.start_epoch
|
122 |
+
save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
|
123 |
+
for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
|
124 |
+
if epoch_i > -1:
|
125 |
+
train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
|
126 |
+
lr_scheduler.step()
|
127 |
+
eval_epoch_interval = opt.eval_epoch
|
128 |
+
if int(opt.local_rank) in [0, -1] and opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
|
129 |
+
with torch.no_grad():
|
130 |
+
metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
|
131 |
+
eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
|
132 |
+
|
133 |
+
# log
|
134 |
+
to_write = opt.eval_log_txt_formatter.format(
|
135 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
136 |
+
epoch=epoch_i,
|
137 |
+
loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
|
138 |
+
eval_metrics_str=json.dumps(metrics_no_nms))
|
139 |
+
|
140 |
+
if int(opt.local_rank) in [0, -1]:
|
141 |
+
with open(opt.eval_log_filepath, "a") as f:
|
142 |
+
f.write(to_write)
|
143 |
+
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
|
144 |
+
if metrics_nms is not None:
|
145 |
+
logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
|
146 |
+
|
147 |
+
metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
|
148 |
+
for k, v in metrics["brief"].items():
|
149 |
+
tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
|
150 |
+
|
151 |
+
# stop_score = metrics["brief"]["MR-full-mAP"]
|
152 |
+
# pdb.set_trace()
|
153 |
+
stop_score = metrics["brief"][opt.main_metric]
|
154 |
+
if stop_score > prev_best_score:
|
155 |
+
es_cnt = 0
|
156 |
+
prev_best_score = stop_score
|
157 |
+
|
158 |
+
checkpoint = {
|
159 |
+
"model": model.state_dict(),
|
160 |
+
"optimizer": optimizer.state_dict(),
|
161 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
162 |
+
"epoch": epoch_i,
|
163 |
+
"opt": opt
|
164 |
+
}
|
165 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
|
166 |
+
|
167 |
+
best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
|
168 |
+
for src, tgt in zip(latest_file_paths, best_file_paths):
|
169 |
+
os.renames(src, tgt)
|
170 |
+
logger.info("The checkpoint file has been updated.")
|
171 |
+
else:
|
172 |
+
es_cnt += 1
|
173 |
+
if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
|
174 |
+
with open(opt.train_log_filepath, "a") as f:
|
175 |
+
f.write(f"Early Stop at epoch {epoch_i}")
|
176 |
+
logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
|
177 |
+
break
|
178 |
+
|
179 |
+
# save ckpt
|
180 |
+
checkpoint = {
|
181 |
+
"model": model.state_dict(),
|
182 |
+
"optimizer": optimizer.state_dict(),
|
183 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
184 |
+
"epoch": epoch_i,
|
185 |
+
"opt": opt
|
186 |
+
}
|
187 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
|
188 |
+
|
189 |
+
if int(opt.local_rank) in [0, -1] and ((epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0): # additional copies
|
190 |
+
checkpoint = {
|
191 |
+
"model": model.state_dict(),
|
192 |
+
"optimizer": optimizer.state_dict(),
|
193 |
+
"epoch": epoch_i,
|
194 |
+
"opt": opt
|
195 |
+
}
|
196 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
|
197 |
+
|
198 |
+
if opt.debug:
|
199 |
+
break
|
200 |
+
|
201 |
+
if int(opt.local_rank) in [0, -1]:
|
202 |
+
tb_writer.close()
|
203 |
+
|
204 |
+
|
205 |
+
def start_training():
|
206 |
+
# logger.info("Setup config, data and model...")
|
207 |
+
opt = BaseOptions().parse()
|
208 |
+
set_seed(opt.seed)
|
209 |
+
if opt.debug: # keep the model run deterministically
|
210 |
+
# 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
|
211 |
+
# Enable this only when input size is fixed.
|
212 |
+
cudnn.benchmark = False
|
213 |
+
cudnn.deterministic = True
|
214 |
+
|
215 |
+
local_rank = int(opt.local_rank)
|
216 |
+
dist.init_process_group(backend='nccl')
|
217 |
+
|
218 |
+
torch.cuda.set_device(local_rank)
|
219 |
+
device = torch.device("cuda", local_rank)
|
220 |
+
|
221 |
+
dataset_config = dict(
|
222 |
+
dset_name=opt.dset_name,
|
223 |
+
data_path=opt.train_path,
|
224 |
+
v_feat_dirs=opt.v_feat_dirs,
|
225 |
+
q_feat_dir=opt.t_feat_dir,
|
226 |
+
v_feat_dim=opt.v_feat_dim,
|
227 |
+
q_feat_dim=opt.t_feat_dim,
|
228 |
+
q_feat_type="last_hidden_state",
|
229 |
+
max_q_l=opt.max_q_l,
|
230 |
+
max_v_l=opt.max_v_l,
|
231 |
+
ctx_mode=opt.ctx_mode,
|
232 |
+
data_ratio=opt.data_ratio,
|
233 |
+
normalize_v=not opt.no_norm_vfeat,
|
234 |
+
normalize_t=not opt.no_norm_tfeat,
|
235 |
+
clip_len=opt.clip_length,
|
236 |
+
max_windows=opt.max_windows,
|
237 |
+
span_loss_type=opt.span_loss_type,
|
238 |
+
txt_drop_ratio=opt.txt_drop_ratio,
|
239 |
+
use_cache=opt.use_cache,
|
240 |
+
add_easy_negative=opt.add_easy_negative,
|
241 |
+
easy_negative_only=opt.easy_negative_only
|
242 |
+
)
|
243 |
+
|
244 |
+
dataset_config["data_path"] = opt.train_path
|
245 |
+
train_dataset = DatasetVLP(**dataset_config)
|
246 |
+
|
247 |
+
if opt.eval_path is not None:
|
248 |
+
# perform zero-shot on qvhl.
|
249 |
+
dataset_config["data_path"] = opt.eval_path
|
250 |
+
dataset_config["txt_drop_ratio"] = 0
|
251 |
+
if len(dataset_config["v_feat_dirs"]) == 1:
|
252 |
+
dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_clip"]
|
253 |
+
elif len(dataset_config["v_feat_dirs"]) == 2:
|
254 |
+
dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_slowfast", "data/qvhighlights/vid_clip"]
|
255 |
+
else:
|
256 |
+
raise NotImplementedError
|
257 |
+
dataset_config["q_feat_dir"] = "data/qvhighlights/txt_clip"
|
258 |
+
dataset_config["data_ratio"] = 1
|
259 |
+
# dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
|
260 |
+
eval_dataset = DatasetMR(**dataset_config)
|
261 |
+
else:
|
262 |
+
eval_dataset = None
|
263 |
+
|
264 |
+
if opt.lr_warmup > 0:
|
265 |
+
# total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
|
266 |
+
total_steps = opt.n_epoch
|
267 |
+
warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
|
268 |
+
opt.lr_warmup = [warmup_steps, total_steps]
|
269 |
+
model, criterion, optimizer, lr_scheduler = setup_model(opt)
|
270 |
+
|
271 |
+
model.to(device)
|
272 |
+
logger.info(f"Using {torch.cuda.device_count()} GPUs.")
|
273 |
+
model = torch.nn.parallel.DistributedDataParallel(model,
|
274 |
+
device_ids=[local_rank],
|
275 |
+
output_device=local_rank,
|
276 |
+
find_unused_parameters=True)
|
277 |
+
|
278 |
+
if int(opt.local_rank) in [0, -1]:
|
279 |
+
logger.info(f"Model {model}")
|
280 |
+
count_parameters(model)
|
281 |
+
logger.info("Start Training...")
|
282 |
+
train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
|
283 |
+
# return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
|
284 |
+
return
|
285 |
+
|
286 |
+
if __name__ == '__main__':
|
287 |
+
# best_ckpt_path, eval_split_name, eval_path, debug = start_training()
|
288 |
+
start_training()
|
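Editor's note: the DDP variant above reads opt.local_rank, calls dist.init_process_group('nccl'), and feeds the loader a DistributedSampler, which matches the torch.distributed.launch style of launching (one process per GPU, --local_rank injected). It does not call set_epoch on the sampler; in stock PyTorch that call is what changes the shuffling order from epoch to epoch. A minimal, self-contained sketch of the usual pattern (illustrative only, not the author's code; batch size and dataset are arbitrary):

# Run under a distributed launcher, e.g.
#   python -m torch.distributed.launch --nproc_per_node=2 this_sketch.py
# (use backend="gloo" instead of "nccl" on CPU-only machines)
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")            # as in start_training() above
dataset = TensorDataset(torch.arange(1000).float())
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)   # reshuffles differently each epoch across ranks
    for (batch,) in loader:
        pass                   # forward/backward would go here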
model/base.py
ADDED
@@ -0,0 +1,449 @@
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
81 |
+
self.token_type_embeddings.apply(init_weights)
|
82 |
+
|
83 |
+
# Conv projector
|
84 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
85 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
86 |
+
|
87 |
+
self.use_txt_pos = use_txt_pos
|
88 |
+
self.n_input_proj = n_input_proj
|
89 |
+
relu_args = [True] * 3
|
90 |
+
relu_args[n_input_proj-1] = False
|
91 |
+
self.input_txt_proj = nn.Sequential(*[
|
92 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
93 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
95 |
+
][:n_input_proj])
|
96 |
+
self.input_vid_proj = nn.Sequential(*[
|
97 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
98 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
100 |
+
][:n_input_proj])
|
101 |
+
|
102 |
+
# MLP Projector
|
103 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
104 |
+
|
105 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
106 |
+
bs = src_vid.shape[0]
|
107 |
+
src_vid = self.input_vid_proj(src_vid)
|
108 |
+
src_txt = self.input_txt_proj(src_txt)
|
109 |
+
if src_cls is not None:
|
110 |
+
src_cls = self.input_txt_proj(src_cls)
|
111 |
+
|
112 |
+
# type token.
|
113 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
|
114 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
115 |
+
if src_cls is not None:
|
116 |
+
src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
|
117 |
+
|
118 |
+
src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
|
119 |
+
mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
|
120 |
+
|
121 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
|
122 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
|
123 |
+
pos = torch.cat([pos_vid, pos_txt], dim=1)
|
124 |
+
|
125 |
+
memory = self.transformer(src, ~mask, pos)
|
126 |
+
vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
|
127 |
+
|
128 |
+
outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
|
129 |
+
outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
|
130 |
+
|
131 |
+
if self.span_loss_type == "l1":
|
132 |
+
outputs_coord = outputs_coord.sigmoid()
|
133 |
+
idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
|
134 |
+
idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
|
135 |
+
outputs_coord = outputs_coord * idx_mask
|
136 |
+
else:
|
137 |
+
raise NotImplementedError
|
138 |
+
|
139 |
+
out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
|
140 |
+
'src_vid_mask': src_vid_mask}
|
141 |
+
|
142 |
+
vid_mem_proj = src_vid
|
143 |
+
|
144 |
+
# word-level -> sentence-level
|
145 |
+
txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
|
146 |
+
sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
|
147 |
+
|
148 |
+
out["vid_mem_proj"] = vid_mem_proj
|
149 |
+
out["txt_mem_proj"] = txt_mem_proj
|
150 |
+
if src_cls is not None:
|
151 |
+
cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
|
152 |
+
out["cls_mem_proj"] = cls_mem_proj
|
153 |
+
out["saliency_scores"] = sim
|
154 |
+
return out
|
155 |
+
|
156 |
+
class SetCriterion(nn.Module):
|
157 |
+
""" This class computes the loss for DETR.
|
158 |
+
The process happens in two steps:
|
159 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
160 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
164 |
+
saliency_margin=1):
|
165 |
+
""" Create the criterion.
|
166 |
+
Parameters:
|
167 |
+
matcher: module able to compute a matching between targets and proposals
|
168 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
169 |
+
eos_coef: relative classification weight applied to the no-object category
|
170 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
171 |
+
temperature: float, temperature for NCE loss
|
172 |
+
span_loss_type: str, [l1, ce]
|
173 |
+
max_v_l: int,
|
174 |
+
saliency_margin: float
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
self.matcher = matcher
|
178 |
+
self.weight_dict = weight_dict
|
179 |
+
self.losses = losses
|
180 |
+
self.temperature = temperature
|
181 |
+
self.span_loss_type = span_loss_type
|
182 |
+
self.max_v_l = max_v_l
|
183 |
+
self.saliency_margin = saliency_margin
|
184 |
+
self.temperature = 0.07
|
185 |
+
|
186 |
+
# foreground and background classification
|
187 |
+
self.foreground_label = 0
|
188 |
+
self.background_label = 1
|
189 |
+
self.eos_coef = eos_coef
|
190 |
+
empty_weight = torch.ones(2)
|
191 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
192 |
+
self.register_buffer('empty_weight', empty_weight)
|
193 |
+
|
194 |
+
def loss_spans(self, outputs, targets, indices):
|
195 |
+
assert 'pred_spans' in outputs
|
196 |
+
|
197 |
+
start_spans = targets['timestamp']
|
198 |
+
pred_spans = outputs['pred_spans']
|
199 |
+
src_spans = start_spans + pred_spans
|
200 |
+
gt_spans = targets['span_labels_nn']
|
201 |
+
|
202 |
+
mask = targets['timestamp_mask'].bool()
|
203 |
+
mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
|
204 |
+
mask_valid = targets['timestamp_window'].bool()
|
205 |
+
mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
|
206 |
+
|
207 |
+
loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
|
208 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
|
209 |
+
|
210 |
+
losses = {}
|
211 |
+
losses['loss_b'] = loss_span.sum() / mask_valid.sum()
|
212 |
+
losses['loss_g'] = loss_giou.mean()
|
213 |
+
return losses
|
214 |
+
|
215 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
216 |
+
src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
|
217 |
+
mask = targets['timestamp_mask'].bool()
|
218 |
+
mask_valid = targets['timestamp_window'].bool()
|
219 |
+
target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
220 |
+
target_classes[mask_valid] = 1
|
221 |
+
# target_classes = targets['timestamp_window'] # soft cls.
|
222 |
+
target_classes.float()
|
223 |
+
# pdb.set_trace()
|
224 |
+
|
225 |
+
weights = torch.zeros_like(target_classes).float()
|
226 |
+
weights[mask] = self.empty_weight[1]
|
227 |
+
weights[mask_valid] = self.empty_weight[0]
|
228 |
+
|
229 |
+
# pdb.set_trace()
|
230 |
+
loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
|
231 |
+
return {"loss_f": loss_ce.sum() / mask.sum()}
|
232 |
+
# return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
|
233 |
+
|
234 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
235 |
+
"""higher scores for positive clips"""
|
236 |
+
if "saliency_pos_labels" not in targets:
|
237 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
238 |
+
saliency_scores = targets["saliency_scores"]
|
239 |
+
if saliency_scores.sum() == 0:
|
240 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
241 |
+
|
242 |
+
# * inter-vid mode
|
243 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
244 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
245 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
246 |
+
|
247 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
248 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
249 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
250 |
+
|
251 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
252 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
253 |
+
|
254 |
+
# sum over positives
|
255 |
+
idiag = torch.diag(i_logsm)
|
256 |
+
jdiag = torch.diag(j_logsm)
|
257 |
+
loss_i = idiag.sum() / len(idiag)
|
258 |
+
loss_j = jdiag.sum() / len(jdiag)
|
259 |
+
|
260 |
+
loss_saliency_inter = - loss_i - loss_j
|
261 |
+
|
262 |
+
# * intra-vid mode
|
263 |
+
mask = targets['timestamp_mask']
|
264 |
+
selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
|
265 |
+
neg_indices_in = (saliency_scores < selected_scores)
|
266 |
+
neg_indices_in[batch_indices, pos_indices] = True
|
267 |
+
mask_invalid = neg_indices_in * mask.bool()
|
268 |
+
|
269 |
+
sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
|
270 |
+
sim_in = sim_in + (mask_invalid + 1e-45).log()
|
271 |
+
logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
|
272 |
+
logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
|
273 |
+
|
274 |
+
pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
|
275 |
+
pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
|
276 |
+
loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
|
277 |
+
loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
|
278 |
+
|
279 |
+
loss_saliency_intra = - loss_in_i - loss_in_j
|
280 |
+
|
281 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
282 |
+
|
283 |
+
def loss_saliency_cls(self, outputs, targets, indices, log=True):
|
284 |
+
"""higher scores for positive clips"""
|
285 |
+
if "saliency_pos_labels" not in targets:
|
286 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
287 |
+
saliency_scores = targets["saliency_scores"]
|
288 |
+
if saliency_scores.sum() == 0:
|
289 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
290 |
+
|
291 |
+
# * inter-vid mode
|
292 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
293 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
294 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
295 |
+
|
296 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
297 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
298 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
299 |
+
|
300 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
301 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
302 |
+
|
303 |
+
# sum over positives
|
304 |
+
idiag = torch.diag(i_logsm)
|
305 |
+
jdiag = torch.diag(j_logsm)
|
306 |
+
loss_i = idiag.sum() / len(idiag)
|
307 |
+
loss_j = jdiag.sum() / len(jdiag)
|
308 |
+
|
309 |
+
loss_saliency_inter = - loss_i - loss_j
|
310 |
+
|
311 |
+
# * intra-vid mode
|
312 |
+
if 'cls_idx' not in targets.keys(): # eval
|
313 |
+
return {"loss_s_inter": loss_saliency_inter}
|
314 |
+
|
315 |
+
cls_indices = targets['cls_idx'].bool()
|
316 |
+
cls_feats = outputs["cls_mem_proj"].squeeze(1)
|
317 |
+
sim_cls = sim_matrix(vid_feats, cls_feats)
|
318 |
+
|
319 |
+
i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
|
320 |
+
idiag_cls = i_logsm_cls[cls_indices]
|
321 |
+
loss_cls_i = idiag_cls.sum() / len(idiag_cls)
|
322 |
+
|
323 |
+
loss_saliency_intra = - loss_cls_i
|
324 |
+
|
325 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
326 |
+
|
327 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
328 |
+
loss_map = {
|
329 |
+
"spans": self.loss_spans,
|
330 |
+
"labels": self.loss_labels,
|
331 |
+
"saliency": self.loss_saliency,
|
332 |
+
"saliency_cls": self.loss_saliency_cls,
|
333 |
+
}
|
334 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
335 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
336 |
+
|
337 |
+
def forward(self, outputs, targets, hl_only=False):
|
338 |
+
""" This performs the loss computation.
|
339 |
+
Parameters:
|
340 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
341 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
342 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
343 |
+
"""
|
344 |
+
indices = None
|
345 |
+
# Compute all the requested losses
|
346 |
+
losses = {}
|
347 |
+
for loss in self.losses:
|
348 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
349 |
+
|
350 |
+
return losses
|
351 |
+
|
352 |
+
class MLP(nn.Module):
|
353 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
354 |
+
|
355 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
356 |
+
super().__init__()
|
357 |
+
self.num_layers = num_layers
|
358 |
+
h = [hidden_dim] * (num_layers - 1)
|
359 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
360 |
+
|
361 |
+
def forward(self, x):
|
362 |
+
for i, layer in enumerate(self.layers):
|
363 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
364 |
+
return x
|
365 |
+
|
366 |
+
class Conv(nn.Module):
|
367 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
368 |
+
|
369 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
|
370 |
+
super().__init__()
|
371 |
+
self.num_layers = num_layers
|
372 |
+
h = [hidden_dim] * (num_layers - 1)
|
373 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
374 |
+
self.layers = nn.ModuleList(
|
375 |
+
nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
|
376 |
+
for n, k in zip([input_dim] + h, h + [output_dim]))
|
377 |
+
def forward(self, x):
|
378 |
+
x = x.permute(0,2,1)
|
379 |
+
for i, layer in enumerate(self.layers):
|
380 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
381 |
+
return x.permute(0, 2, 1)
|
382 |
+
|
383 |
+
class LinearLayer(nn.Module):
|
384 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
385 |
+
|
386 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
387 |
+
super(LinearLayer, self).__init__()
|
388 |
+
self.relu = relu
|
389 |
+
self.layer_norm = layer_norm
|
390 |
+
if layer_norm:
|
391 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
392 |
+
layers = [
|
393 |
+
nn.Dropout(dropout),
|
394 |
+
nn.Linear(in_hsz, out_hsz)
|
395 |
+
]
|
396 |
+
self.net = nn.Sequential(*layers)
|
397 |
+
|
398 |
+
def forward(self, x):
|
399 |
+
"""(N, L, D)"""
|
400 |
+
if self.layer_norm:
|
401 |
+
x = self.LayerNorm(x)
|
402 |
+
x = self.net(x)
|
403 |
+
if self.relu:
|
404 |
+
x = F.relu(x, inplace=True)
|
405 |
+
return x # (N, L, D)
|
406 |
+
|
407 |
+
|
408 |
+
def build_model(args):
|
409 |
+
device = torch.device(args.device)
|
410 |
+
|
411 |
+
transformer = build_transformer(args)
|
412 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
413 |
+
|
414 |
+
model = Model(
|
415 |
+
transformer,
|
416 |
+
position_embedding,
|
417 |
+
txt_position_embedding,
|
418 |
+
txt_dim=args.t_feat_dim,
|
419 |
+
vid_dim=args.v_feat_dim,
|
420 |
+
input_dropout=args.input_dropout,
|
421 |
+
span_loss_type=args.span_loss_type,
|
422 |
+
use_txt_pos=args.use_txt_pos,
|
423 |
+
n_input_proj=args.n_input_proj,
|
424 |
+
)
|
425 |
+
|
426 |
+
matcher = build_matcher(args)
|
427 |
+
weight_dict = {"loss_b": args.b_loss_coef,
|
428 |
+
"loss_g": args.g_loss_coef,
|
429 |
+
"loss_f": args.f_loss_coef,
|
430 |
+
"loss_s_intra": args.s_loss_intra_coef,
|
431 |
+
"loss_s_inter": args.s_loss_inter_coef}
|
432 |
+
|
433 |
+
if args.dset_type in ['mr', 'vlp']:
|
434 |
+
if 'tal' not in args.train_path:
|
435 |
+
losses = ['spans', 'labels', 'saliency']
|
436 |
+
else:
|
437 |
+
losses = ['spans', 'labels', 'saliency_cls']
|
438 |
+
elif args.dset_type in ['hl', 'vs']:
|
439 |
+
losses = ['labels', 'saliency']
|
440 |
+
|
441 |
+
criterion = SetCriterion(
|
442 |
+
matcher=matcher,
|
443 |
+
weight_dict=weight_dict, losses=losses,
|
444 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
445 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
446 |
+
saliency_margin=args.saliency_margin,
|
447 |
+
)
|
448 |
+
criterion.to(device)
|
449 |
+
return model, criterion
|
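Editor's note: the saliency head in Model.forward above scores each clip by cosine similarity between the projected video tokens and the pooled sentence feature, then adds (src_vid_mask + 1e-45).log() so that padded clips receive a large negative offset (log(1e-45) is roughly -103.6) and effectively drop out of any subsequent softmax. A tiny standalone illustration of that masking trick, with random features rather than real model outputs:

# Standalone illustration of the masked cosine-similarity saliency score used in Model.forward above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, L_vid, d = 2, 5, 8
vid_mem_proj = torch.randn(bsz, L_vid, d)        # projected video clips
txt_mem_proj = torch.randn(bsz, 1, d)            # pooled sentence feature
src_vid_mask = torch.tensor([[1, 1, 1, 0, 0],    # last two clips of sample 0 are padding
                             [1, 1, 1, 1, 1]]).float()

sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
# real clips keep their cosine similarity (log(1 + 1e-45) is ~0);
# padded clips get ~ -103.6 added, so a softmax over clips ignores them.
print(sim)
print(sim.softmax(dim=-1))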
model/base_albef.py
ADDED
@@ -0,0 +1,478 @@
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder import build_transformer, Transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer_mm, transformer_v, transformer_t, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer_mm
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer_mm.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
81 |
+
self.token_type_embeddings.apply(init_weights)
|
82 |
+
|
83 |
+
# Conv projector
|
84 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
85 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
86 |
+
|
87 |
+
self.use_txt_pos = use_txt_pos
|
88 |
+
self.n_input_proj = n_input_proj
|
89 |
+
relu_args = [True] * 3
|
90 |
+
relu_args[n_input_proj-1] = False
|
91 |
+
self.input_txt_proj = nn.Sequential(*[
|
92 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
93 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
95 |
+
][:n_input_proj])
|
96 |
+
self.input_vid_proj = nn.Sequential(*[
|
97 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
98 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
100 |
+
][:n_input_proj])
|
101 |
+
|
102 |
+
self.transformer_v = transformer_v
|
103 |
+
self.transformer_t = transformer_t
|
104 |
+
|
105 |
+
# MLP Projector
|
106 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
107 |
+
|
108 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
109 |
+
bs = src_vid.shape[0]
|
110 |
+
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)

        # pos embed.
        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)

        src_vid = self.transformer_v(src_vid, ~src_vid_mask.bool(), pos_vid)
        src_txt = self.transformer_t(src_txt, ~src_txt_mask.bool(), pos_txt)

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(vid_mem)  # (#layers, bsz, #queries, 2 or max_v_l * 2)

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
            out["cls_mem_proj"] = cls_mem_proj
        out["saliency_scores"] = sim
        return out

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        assert 'pred_spans' in outputs

        start_spans = targets['timestamp']
        pred_spans = outputs['pred_spans']
        src_spans = start_spans + pred_spans
        gt_spans = targets['span_labels_nn']

        mask = targets['timestamp_mask'].bool()
        mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        losses = {}
        losses['loss_b'] = loss_span.sum() / mask_valid.sum()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        src_logits = outputs['pred_logits'].squeeze(-1)  # (batch_size, #queries, #classes=2)
        mask = targets['timestamp_mask'].bool()
        mask_valid = targets['timestamp_window'].bool()
        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[mask_valid] = 1
        # target_classes = targets['timestamp_window']  # soft cls.
        target_classes.float()
        # pdb.set_trace()

        weights = torch.zeros_like(target_classes).float()
        weights[mask] = self.empty_weight[1]
        weights[mask_valid] = self.empty_weight[0]

        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
        return {"loss_f": loss_ce.sum() / mask.sum()}

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        mask = targets['timestamp_mask']
        selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        neg_indices_in = (saliency_scores < selected_scores)
        neg_indices_in[batch_indices, pos_indices] = True
        mask_invalid = neg_indices_in * mask.bool()

        sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        sim_in = sim_in + (mask_invalid + 1e-45).log()
        logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
        loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)

        loss_saliency_intra = - loss_in_i - loss_in_j

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        if 'cls_idx' not in targets.keys():  # eval
            return {"loss_s_inter": loss_saliency_inter}

        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, hl_only=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        return losses

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class Conv(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        x = x.permute(0, 2, 1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)

class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    device = torch.device(args.device)

    transformer_mm = build_transformer(args)
    transformer_v = Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.sub_enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )
    transformer_t = Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.sub_enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )
    # pdb.set_trace()

    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer_mm,
        transformer_v,
        transformer_t,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr', 'vlp']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion
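Taken together, build_model(args) returns the model plus its SetCriterion, and the scalar training objective is the weight_dict-weighted sum of whatever loss terms the criterion emits. The lines below are only a minimal wiring sketch under that assumption; the names args, batch_inputs, and targets are placeholders, and the repo's real training loop lives in its training scripts, not in this file.

# Hypothetical usage sketch -- placeholder names, not the actual training loop.
model, criterion = build_model(args)
out = model(**batch_inputs)              # dict: pred_spans, pred_logits, saliency_scores, vid/txt_mem_proj
loss_dict = criterion(out, targets)      # dict: loss_b, loss_g, loss_f, loss_s_inter, loss_s_intra
total_loss = sum(loss_dict[k] * criterion.weight_dict[k]
                 for k in loss_dict if k in criterion.weight_dict)
total_loss.backward()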
model/base_droppath.py
ADDED
@@ -0,0 +1,449 @@
import pdb
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from model.transformer_encoder_droppath import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

def init_weights(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

def mask_logits(inputs, mask, mask_value=-1e30):
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value

def sim_matrix(a, b, eps=1e-8):
    """
    added eps for numerical stability
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

class WeightedPool(nn.Module):
    def __init__(self, dim):
        super(WeightedPool, self).__init__()
        weight = torch.empty(dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, mask):
        alpha = torch.tensordot(x, self.weight, dims=1)  # shape = (batch_size, seq_length, 1)
        alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
        alphas = nn.Softmax(dim=1)(alpha)
        pooled_x = torch.matmul(x.transpose(1, 2), alphas)  # (batch_size, dim, 1)
        pooled_x = pooled_x.squeeze(2)
        return pooled_x

class Model(nn.Module):
    """ This is the UniVTG module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 input_dropout, aux_loss=False,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        # Conv projector
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
        bs = src_vid.shape[0]
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(vid_mem)  # (#layers, bsz, #queries, 2 or max_v_l * 2)

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
            out["cls_mem_proj"] = cls_mem_proj
        out["saliency_scores"] = sim
        return out

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        assert 'pred_spans' in outputs

        start_spans = targets['timestamp']
        pred_spans = outputs['pred_spans']
        src_spans = start_spans + pred_spans
        gt_spans = targets['span_labels_nn']

        mask = targets['timestamp_mask'].bool()
        mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        losses = {}
        losses['loss_b'] = loss_span.sum() / mask_valid.sum()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        src_logits = outputs['pred_logits'].squeeze(-1)  # (batch_size, #queries, #classes=2)
        mask = targets['timestamp_mask'].bool()
        mask_valid = targets['timestamp_window'].bool()
        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[mask_valid] = 1
        # target_classes = targets['timestamp_window']  # soft cls.
        target_classes.float()
        # pdb.set_trace()

        weights = torch.zeros_like(target_classes).float()
        weights[mask] = self.empty_weight[1]
        weights[mask_valid] = self.empty_weight[0]

        # pdb.set_trace()
        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
        return {"loss_f": loss_ce.sum() / mask.sum()}
        # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        mask = targets['timestamp_mask']
        selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        neg_indices_in = (saliency_scores < selected_scores)
        neg_indices_in[batch_indices, pos_indices] = True
        mask_invalid = neg_indices_in * mask.bool()

        sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        sim_in = sim_in + (mask_invalid + 1e-45).log()
        logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
        loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)

        loss_saliency_intra = - loss_in_i - loss_in_j

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        if 'cls_idx' not in targets.keys():  # eval
            return {"loss_s_inter": loss_saliency_inter}

        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, hl_only=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        return losses

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class Conv(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        x = x.permute(0, 2, 1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)

class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr', 'vlp']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion
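The loss_saliency term above is a symmetric InfoNCE: sim_matrix builds a (batch x batch) cosine-similarity matrix between one pooled positive-clip feature per video and its paired sentence feature, and the diagonal log-softmax entries (temperature 0.07) act as the positives. Below is only a toy shape check with random tensors standing in for real features; it reuses the sim_matrix helper defined above and is illustrative, not part of the repo.

import torch
import torch.nn.functional as F

vid_feats = torch.randn(4, 256)                     # pooled positive-clip feature per video in a batch of 4
txt_feats = torch.randn(4, 256)                     # pooled sentence feature per video
sim = sim_matrix(vid_feats, txt_feats)              # (4, 4); diagonal entries are the matched pairs
i_logsm = F.log_softmax(sim / 0.07, dim=1)          # video -> text direction
j_logsm = F.log_softmax(sim.t() / 0.07, dim=1)      # text -> video direction
loss_inter = -(torch.diag(i_logsm).mean() + torch.diag(j_logsm).mean())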
model/base_droppath_ablation.py
ADDED
@@ -0,0 +1,474 @@
import pdb
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from model.transformer_encoder_droppath import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

def init_weights(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

def mask_logits(inputs, mask, mask_value=-1e30):
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value

def sim_matrix(a, b, eps=1e-8):
    """
    added eps for numerical stability
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

class WeightedPool(nn.Module):
    def __init__(self, dim):
        super(WeightedPool, self).__init__()
        weight = torch.empty(dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, mask):
        alpha = torch.tensordot(x, self.weight, dims=1)  # shape = (batch_size, seq_length, 1)
        alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
        alphas = nn.Softmax(dim=1)(alpha)
        pooled_x = torch.matmul(x.transpose(1, 2), alphas)  # (batch_size, dim, 1)
        pooled_x = pooled_x.squeeze(2)
        return pooled_x

class Model(nn.Module):
    """ This is the UniVTG module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 input_dropout, aux_loss=False,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        # Conv projector
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
        bs = src_vid.shape[0]
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(vid_mem)  # (#layers, bsz, #queries, 2 or max_v_l * 2)

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
            out["cls_mem_proj"] = cls_mem_proj
        out["saliency_scores"] = sim
        return out

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        assert 'pred_spans' in outputs

        start_spans = targets['timestamp']
        pred_spans = outputs['pred_spans']
        src_spans = start_spans + pred_spans
        gt_spans = targets['span_labels_nn']

        mask = targets['timestamp_mask'].bool()
        mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        weight_abalation_b = targets['weight_ablation'][:, 0].unsqueeze(-1)
        if weight_abalation_b.sum() == 0:
            return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}

        mask_valid = (mask_valid * weight_abalation_b).bool()
        mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        losses = {}
        losses['loss_b'] = loss_span.sum() / mask_valid.sum()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        src_logits = outputs['pred_logits'].squeeze(-1)  # (batch_size, #queries, #classes=2)
        mask = targets['timestamp_mask'].bool()
        mask_valid = targets['timestamp_window'].bool()
        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[mask_valid] = 1
        # target_classes = targets['timestamp_window']  # soft cls.
        target_classes.float()
        # pdb.set_trace()

        weights = torch.zeros_like(target_classes).float()
        weights[mask] = self.empty_weight[1]
        weights[mask_valid] = self.empty_weight[0]

        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask

        weight_abalation_f = targets['weight_ablation'][:, 2].unsqueeze(-1)
        if weight_abalation_f.sum() == 0:
            return {"loss_f": torch.tensor(0).cuda()}

        mask = mask * weight_abalation_f
        loss_ce = loss_ce * weight_abalation_f
        return {"loss_f": loss_ce.sum() / mask.sum()}
        # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)

        weight_abalation_s = targets['weight_ablation'][:, 3].bool()
        if weight_abalation_s.sum() == 0:
            return {"loss_s_inter": torch.tensor(0).cuda(),
                    "loss_s_intra": torch.tensor(0).cuda()}

        _idiag = idiag[weight_abalation_s]
        _jdiag = jdiag[weight_abalation_s]

        loss_i = _idiag.sum() / len(_idiag)
        loss_j = _jdiag.sum() / len(_jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        mask = targets['timestamp_mask']
        selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        neg_indices_in = (saliency_scores < selected_scores)
        neg_indices_in[batch_indices, pos_indices] = True
        mask_invalid = neg_indices_in * mask.bool()

        sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        sim_in = sim_in + (mask_invalid + 1e-45).log()
        logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
        _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]

        loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
        loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)

        loss_saliency_intra = - loss_in_i - loss_in_j

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        if 'cls_idx' not in targets.keys():  # eval
            return {"loss_s_inter": loss_saliency_inter}

        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, hl_only=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        return losses

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class Conv(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        x = x.permute(0, 2, 1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)

class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr', 'vlp']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion
model/base_droppath_qfvs.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder_droppath import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
81 |
+
self.token_type_embeddings.apply(init_weights)
|
82 |
+
|
83 |
+
# Conv projector
|
84 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
85 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
86 |
+
|
87 |
+
self.use_txt_pos = use_txt_pos
|
88 |
+
self.n_input_proj = n_input_proj
|
89 |
+
relu_args = [True] * 3
|
90 |
+
relu_args[n_input_proj-1] = False
|
91 |
+
self.input_txt_proj = nn.Sequential(*[
|
92 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
93 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
95 |
+
][:n_input_proj])
|
96 |
+
self.input_vid_proj = nn.Sequential(*[
|
97 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
98 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
100 |
+
][:n_input_proj])
|
101 |
+
|
102 |
+
# MLP Projector
|
103 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
104 |
+
|
105 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
106 |
+
bs = src_vid.shape[0]
|
107 |
+
src_vid = self.input_vid_proj(src_vid)
|
108 |
+
src_txt = self.input_txt_proj(src_txt)
|
109 |
+
if src_cls is not None:
|
110 |
+
src_cls = self.input_txt_proj(src_cls)
|
111 |
+
|
112 |
+
# type token.
|
113 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
|
114 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
115 |
+
if src_cls is not None:
|
116 |
+
src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
|
117 |
+
|
118 |
+
src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
|
119 |
+
mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
|
120 |
+
|
121 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
|
122 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
|
123 |
+
pos = torch.cat([pos_vid, pos_txt], dim=1)
|
124 |
+
|
125 |
+
memory = self.transformer(src, ~mask, pos)
|
126 |
+
vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
|
127 |
+
|
128 |
+
outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
|
129 |
+
outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
|
130 |
+
|
131 |
+
if self.span_loss_type == "l1":
|
132 |
+
outputs_coord = outputs_coord.sigmoid()
|
133 |
+
idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
|
134 |
+
idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
|
135 |
+
outputs_coord = outputs_coord * idx_mask
|
136 |
+
else:
|
137 |
+
raise NotImplementedError
|
138 |
+
|
139 |
+
out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
|
140 |
+
'src_vid_mask': src_vid_mask}
|
141 |
+
|
142 |
+
vid_mem_proj = src_vid
|
143 |
+
|
144 |
+
# word-level -> sentence-level
|
145 |
+
txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
|
146 |
+
sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
|
147 |
+
|
148 |
+
out["vid_mem_proj"] = vid_mem_proj
|
149 |
+
out["txt_mem_proj"] = txt_mem_proj
|
150 |
+
if src_cls is not None:
|
151 |
+
cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
|
152 |
+
out["cls_mem_proj"] = cls_mem_proj
|
153 |
+
out["saliency_scores"] = sim
|
154 |
+
return out
|
155 |
+
|
156 |
+
class SetCriterion(nn.Module):
|
157 |
+
""" This class computes the loss for DETR.
|
158 |
+
The process happens in two steps:
|
159 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
160 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
164 |
+
saliency_margin=1):
|
165 |
+
""" Create the criterion.
|
166 |
+
Parameters:
|
167 |
+
matcher: module able to compute a matching between targets and proposals
|
168 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
169 |
+
eos_coef: relative classification weight applied to the no-object category
|
170 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
171 |
+
temperature: float, temperature for NCE loss
|
172 |
+
span_loss_type: str, [l1, ce]
|
173 |
+
max_v_l: int,
|
174 |
+
saliency_margin: float
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
self.matcher = matcher
|
178 |
+
self.weight_dict = weight_dict
|
179 |
+
self.losses = losses
|
180 |
+
self.temperature = temperature
|
181 |
+
self.span_loss_type = span_loss_type
|
182 |
+
self.max_v_l = max_v_l
|
183 |
+
self.saliency_margin = saliency_margin
|
184 |
+
self.temperature = 0.07
|
185 |
+
|
186 |
+
# foreground and background classification
|
187 |
+
self.foreground_label = 0
|
188 |
+
self.background_label = 1
|
189 |
+
self.eos_coef = eos_coef
|
190 |
+
empty_weight = torch.ones(2)
|
191 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
192 |
+
self.register_buffer('empty_weight', empty_weight)
|
193 |
+
|
194 |
+
def loss_spans(self, outputs, targets, indices):
|
195 |
+
assert 'pred_spans' in outputs
|
196 |
+
|
197 |
+
start_spans = targets['timestamp']
|
198 |
+
pred_spans = outputs['pred_spans']
|
199 |
+
src_spans = start_spans + pred_spans
|
200 |
+
gt_spans = targets['span_labels_nn']
|
201 |
+
|
202 |
+
mask = targets['timestamp_mask'].bool()
|
203 |
+
mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
|
204 |
+
mask_valid = targets['timestamp_window'].bool()
|
205 |
+
mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
|
206 |
+
|
207 |
+
loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
|
208 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
|
209 |
+
|
210 |
+
losses = {}
|
211 |
+
losses['loss_b'] = loss_span.sum() / mask_valid.sum()
|
212 |
+
losses['loss_g'] = loss_giou.mean()
|
213 |
+
return losses
|
214 |
+
|
215 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
216 |
+
saliency_scores = targets["saliency_scores"]
|
217 |
+
if saliency_scores.sum() == 0:
|
218 |
+
return {"loss_f": 0.}
|
219 |
+
|
220 |
+
src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
|
221 |
+
target_classes = targets["saliency_scores"].squeeze()
|
222 |
+
|
223 |
+
weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
|
224 |
+
weights[target_classes.bool()] = self.empty_weight[0]
|
225 |
+
|
226 |
+
loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
|
227 |
+
return {"loss_f": loss_ce.sum() / target_classes.sum()}
|
228 |
+
# return {"loss_f": loss_ce.sum() / len(target_classes)}
|
229 |
+
|
230 |
+
# mask = targets['timestamp_mask'].bool()
|
231 |
+
# mask_valid = targets['timestamp_window'].bool()
|
232 |
+
# target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
233 |
+
# target_classes[mask_valid] = 1
|
234 |
+
# # target_classes = targets['timestamp_window'] # soft cls.
|
235 |
+
# target_classes.float()
|
236 |
+
# # pdb.set_trace()
|
237 |
+
|
238 |
+
# weights = torch.zeros_like(target_classes).float()
|
239 |
+
# weights[mask] = self.empty_weight[1]
|
240 |
+
# weights[mask_valid] = self.empty_weight[0]
|
241 |
+
|
242 |
+
# loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
|
243 |
+
# # return {"loss_f": loss_ce.sum() / mask.sum()}
|
244 |
+
# return {"loss_f": loss_ce.sum() / mask_valid.sum()}
|
245 |
+
|
246 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
247 |
+
"""higher scores for positive clips"""
|
248 |
+
if "saliency_pos_labels" not in targets:
|
249 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
250 |
+
saliency_scores = targets["saliency_scores"]
|
251 |
+
if saliency_scores.sum() == 0:
|
252 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
253 |
+
|
254 |
+
# * qfvs mil-nce mode
|
255 |
+
pos_indices = saliency_scores.squeeze() > 0
|
256 |
+
|
257 |
+
sim = outputs['saliency_scores']
|
258 |
+
sim_soft = F.softmax(sim / self.temperature, dim=0)
|
259 |
+
sim_log = torch.log(sim_soft[pos_indices])
|
260 |
+
loss_saliency_intra = -sim_log.sum() / len(sim_log)
|
261 |
+
return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
|
262 |
+
|
263 |
+
# * inter-vid mode
|
264 |
+
# vid_mem_proj = outputs["vid_mem_proj"]
|
265 |
+
# pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
266 |
+
# batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
267 |
+
|
268 |
+
# vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
269 |
+
# txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
270 |
+
# sim = sim_matrix(vid_feats, txt_feats)
|
271 |
+
|
272 |
+
# i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
273 |
+
# j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
274 |
+
|
275 |
+
# # sum over positives
|
276 |
+
# idiag = torch.diag(i_logsm)
|
277 |
+
# jdiag = torch.diag(j_logsm)
|
278 |
+
# loss_i = idiag.sum() / len(idiag)
|
279 |
+
# loss_j = jdiag.sum() / len(jdiag)
|
280 |
+
|
281 |
+
# loss_saliency_inter = - loss_i - loss_j
|
282 |
+
|
283 |
+
# # * intra-vid mode
|
284 |
+
# mask = targets['timestamp_mask']
|
285 |
+
# selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
|
286 |
+
# neg_indices_in = (saliency_scores < selected_scores)
|
287 |
+
# neg_indices_in[batch_indices, pos_indices] = True
|
288 |
+
# mask_invalid = neg_indices_in * mask.bool()
|
289 |
+
|
290 |
+
# sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
|
291 |
+
# sim_in = sim_in + (mask_invalid + 1e-45).log()
|
292 |
+
# logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
|
293 |
+
# logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
|
294 |
+
|
295 |
+
# pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
|
296 |
+
# pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
|
297 |
+
# loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
|
298 |
+
# loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
|
299 |
+
|
300 |
+
# loss_saliency_intra = - loss_in_i - loss_in_j
|
301 |
+
|
302 |
+
# return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
303 |
+
|
304 |
+
def loss_saliency_cls(self, outputs, targets, indices, log=True):
|
305 |
+
"""higher scores for positive clips"""
|
306 |
+
if "saliency_pos_labels" not in targets:
|
307 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
308 |
+
saliency_scores = targets["saliency_scores"]
|
309 |
+
if saliency_scores.sum() == 0:
|
310 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
311 |
+
|
312 |
+
# * inter-vid mode
|
313 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
314 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
315 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
316 |
+
|
317 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
318 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
319 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
320 |
+
|
321 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
322 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
323 |
+
|
324 |
+
# sum over positives
|
325 |
+
idiag = torch.diag(i_logsm)
|
326 |
+
jdiag = torch.diag(j_logsm)
|
327 |
+
loss_i = idiag.sum() / len(idiag)
|
328 |
+
loss_j = jdiag.sum() / len(jdiag)
|
329 |
+
|
330 |
+
loss_saliency_inter = - loss_i - loss_j
|
331 |
+
|
332 |
+
# * intra-vid mode
|
333 |
+
if 'cls_idx' not in targets.keys(): # eval
|
334 |
+
return {"loss_s_inter": loss_saliency_inter}
|
335 |
+
|
336 |
+
cls_indices = targets['cls_idx'].bool()
|
337 |
+
cls_feats = outputs["cls_mem_proj"].squeeze(1)
|
338 |
+
sim_cls = sim_matrix(vid_feats, cls_feats)
|
339 |
+
|
340 |
+
i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
|
341 |
+
idiag_cls = i_logsm_cls[cls_indices]
|
342 |
+
loss_cls_i = idiag_cls.sum() / len(idiag_cls)
|
343 |
+
|
344 |
+
loss_saliency_intra = - loss_cls_i
|
345 |
+
|
346 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
347 |
+
|
348 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
349 |
+
loss_map = {
|
350 |
+
"spans": self.loss_spans,
|
351 |
+
"labels": self.loss_labels,
|
352 |
+
"saliency": self.loss_saliency,
|
353 |
+
"saliency_cls": self.loss_saliency_cls,
|
354 |
+
}
|
355 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
356 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
357 |
+
|
358 |
+
def forward(self, outputs, targets, mask_GT=None):
|
359 |
+
""" This performs the loss computation.
|
360 |
+
Parameters:
|
361 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
362 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
363 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
364 |
+
"""
|
365 |
+
indices = None
|
366 |
+
# Compute all the requested losses
|
367 |
+
losses = {}
|
368 |
+
outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
|
369 |
+
count = mask_GT.sum()
|
370 |
+
outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
|
371 |
+
# targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
|
372 |
+
targets['saliency_scores'] = targets['saliency_scores'][0,:count]
|
373 |
+
|
374 |
+
for loss in self.losses:
|
375 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
376 |
+
|
377 |
+
return losses
|
378 |
+
|
379 |
+
class MLP(nn.Module):
|
380 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
381 |
+
|
382 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
383 |
+
super().__init__()
|
384 |
+
self.num_layers = num_layers
|
385 |
+
h = [hidden_dim] * (num_layers - 1)
|
386 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
for i, layer in enumerate(self.layers):
|
390 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
391 |
+
return x
|
392 |
+
|
393 |
+
class Conv(nn.Module):
|
394 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
395 |
+
|
396 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
|
397 |
+
super().__init__()
|
398 |
+
self.num_layers = num_layers
|
399 |
+
h = [hidden_dim] * (num_layers - 1)
|
400 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
401 |
+
self.layers = nn.ModuleList(
|
402 |
+
nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
|
403 |
+
for n, k in zip([input_dim] + h, h + [output_dim]))
|
404 |
+
def forward(self, x):
|
405 |
+
x = x.permute(0,2,1)
|
406 |
+
for i, layer in enumerate(self.layers):
|
407 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
408 |
+
return x.permute(0, 2, 1)
|
409 |
+
|
410 |
+
class LinearLayer(nn.Module):
|
411 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
412 |
+
|
413 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
414 |
+
super(LinearLayer, self).__init__()
|
415 |
+
self.relu = relu
|
416 |
+
self.layer_norm = layer_norm
|
417 |
+
if layer_norm:
|
418 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
419 |
+
layers = [
|
420 |
+
nn.Dropout(dropout),
|
421 |
+
nn.Linear(in_hsz, out_hsz)
|
422 |
+
]
|
423 |
+
self.net = nn.Sequential(*layers)
|
424 |
+
|
425 |
+
def forward(self, x):
|
426 |
+
"""(N, L, D)"""
|
427 |
+
if self.layer_norm:
|
428 |
+
x = self.LayerNorm(x)
|
429 |
+
x = self.net(x)
|
430 |
+
if self.relu:
|
431 |
+
x = F.relu(x, inplace=True)
|
432 |
+
return x # (N, L, D)
|
433 |
+
|
434 |
+
|
435 |
+
def build_model(args):
|
436 |
+
device = torch.device(args.device)
|
437 |
+
|
438 |
+
transformer = build_transformer(args)
|
439 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
440 |
+
|
441 |
+
model = Model(
|
442 |
+
transformer,
|
443 |
+
position_embedding,
|
444 |
+
txt_position_embedding,
|
445 |
+
txt_dim=args.t_feat_dim,
|
446 |
+
vid_dim=args.v_feat_dim,
|
447 |
+
input_dropout=args.input_dropout,
|
448 |
+
span_loss_type=args.span_loss_type,
|
449 |
+
use_txt_pos=args.use_txt_pos,
|
450 |
+
n_input_proj=args.n_input_proj,
|
451 |
+
)
|
452 |
+
|
453 |
+
matcher = build_matcher(args)
|
454 |
+
weight_dict = {"loss_b": args.b_loss_coef,
|
455 |
+
"loss_g": args.g_loss_coef,
|
456 |
+
"loss_f": args.f_loss_coef,
|
457 |
+
"loss_s_intra": args.s_loss_intra_coef,
|
458 |
+
"loss_s_inter": args.s_loss_inter_coef}
|
459 |
+
|
460 |
+
if args.dset_type in ['mr', 'vlp']:
|
461 |
+
if 'tal' not in args.train_path:
|
462 |
+
losses = ['spans', 'labels', 'saliency']
|
463 |
+
else:
|
464 |
+
losses = ['spans', 'labels', 'saliency_cls']
|
465 |
+
elif args.dset_type in ['hl', 'vs']:
|
466 |
+
losses = ['labels', 'saliency']
|
467 |
+
|
468 |
+
criterion = SetCriterion(
|
469 |
+
matcher=matcher,
|
470 |
+
weight_dict=weight_dict, losses=losses,
|
471 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
472 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
473 |
+
saliency_margin=args.saliency_margin,
|
474 |
+
)
|
475 |
+
criterion.to(device)
|
476 |
+
return model, criterion
|
model/base_prompt.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.prompt_learner = nn.Embedding(10, hidden_dim)
|
81 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
82 |
+
self.token_type_embeddings.apply(init_weights)
|
83 |
+
|
84 |
+
# Conv projector
|
85 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
86 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
87 |
+
|
88 |
+
self.use_txt_pos = use_txt_pos
|
89 |
+
self.n_input_proj = n_input_proj
|
90 |
+
relu_args = [True] * 3
|
91 |
+
relu_args[n_input_proj-1] = False
|
92 |
+
self.input_txt_proj = nn.Sequential(*[
|
93 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
95 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
96 |
+
][:n_input_proj])
|
97 |
+
self.input_vid_proj = nn.Sequential(*[
|
98 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
100 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
101 |
+
][:n_input_proj])
|
102 |
+
|
103 |
+
# MLP Projector
|
104 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
105 |
+
|
106 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
107 |
+
bs = src_vid.shape[0]
|
108 |
+
src_vid = self.input_vid_proj(src_vid)
|
109 |
+
src_txt = self.input_txt_proj(src_txt)
|
110 |
+
if src_cls is not None:
|
111 |
+
src_cls = self.input_txt_proj(src_cls)
|
112 |
+
|
113 |
+
src_prompt = self.prompt_learner.weight.unsqueeze(0).repeat(bs, 1, 1)
|
114 |
+
src_prompt_mask = torch.ones((bs, src_prompt.shape[1])).cuda()
|
115 |
+
|
116 |
+
if self.training:
|
117 |
+
# src_txt = src_prompt
|
118 |
+
# src_txt_mask = torch.ones_like(src_prompt).cuda()
|
119 |
+
src_txt = torch.cat([src_prompt, src_txt], dim=1)
|
120 |
+
src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
|
121 |
+
else:
|
122 |
+
src_txt = torch.cat([src_prompt, src_txt], dim=1)
|
123 |
+
src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
|
124 |
+
|
125 |
+
# type token.
|
126 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
|
127 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
128 |
+
if src_cls is not None:
|
129 |
+
src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
|
130 |
+
|
131 |
+
src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
|
132 |
+
mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
|
133 |
+
|
134 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
|
135 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
|
136 |
+
pos = torch.cat([pos_vid, pos_txt], dim=1)
|
137 |
+
|
138 |
+
memory = self.transformer(src, ~mask, pos)
|
139 |
+
vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
|
140 |
+
|
141 |
+
outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
|
142 |
+
outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
|
143 |
+
|
144 |
+
if self.span_loss_type == "l1":
|
145 |
+
outputs_coord = outputs_coord.sigmoid()
|
146 |
+
idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
|
147 |
+
idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
|
148 |
+
outputs_coord = outputs_coord * idx_mask
|
149 |
+
else:
|
150 |
+
raise NotImplementedError
|
151 |
+
|
152 |
+
out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
|
153 |
+
'src_vid_mask': src_vid_mask}
|
154 |
+
|
155 |
+
vid_mem_proj = src_vid
|
156 |
+
|
157 |
+
# word-level -> sentence-level
|
158 |
+
txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
|
159 |
+
sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
|
160 |
+
|
161 |
+
out["vid_mem_proj"] = vid_mem_proj
|
162 |
+
out["txt_mem_proj"] = txt_mem_proj
|
163 |
+
if src_cls is not None:
|
164 |
+
cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
|
165 |
+
out["cls_mem_proj"] = cls_mem_proj
|
166 |
+
out["saliency_scores"] = sim
|
167 |
+
return out
|
168 |
+
|
169 |
+
class SetCriterion(nn.Module):
|
170 |
+
""" This class computes the loss for DETR.
|
171 |
+
The process happens in two steps:
|
172 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
173 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
177 |
+
saliency_margin=1):
|
178 |
+
""" Create the criterion.
|
179 |
+
Parameters:
|
180 |
+
matcher: module able to compute a matching between targets and proposals
|
181 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
182 |
+
eos_coef: relative classification weight applied to the no-object category
|
183 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
184 |
+
temperature: float, temperature for NCE loss
|
185 |
+
span_loss_type: str, [l1, ce]
|
186 |
+
max_v_l: int,
|
187 |
+
saliency_margin: float
|
188 |
+
"""
|
189 |
+
super().__init__()
|
190 |
+
self.matcher = matcher
|
191 |
+
self.weight_dict = weight_dict
|
192 |
+
self.losses = losses
|
193 |
+
self.temperature = temperature
|
194 |
+
self.span_loss_type = span_loss_type
|
195 |
+
self.max_v_l = max_v_l
|
196 |
+
self.saliency_margin = saliency_margin
|
197 |
+
self.temperature = 0.07
|
198 |
+
|
199 |
+
# foreground and background classification
|
200 |
+
self.foreground_label = 0
|
201 |
+
self.background_label = 1
|
202 |
+
self.eos_coef = eos_coef
|
203 |
+
empty_weight = torch.ones(2)
|
204 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
205 |
+
self.register_buffer('empty_weight', empty_weight)
|
206 |
+
|
207 |
+
def loss_spans(self, outputs, targets, indices):
|
208 |
+
assert 'pred_spans' in outputs
|
209 |
+
|
210 |
+
start_spans = targets['timestamp']
|
211 |
+
pred_spans = outputs['pred_spans']
|
212 |
+
src_spans = start_spans + pred_spans
|
213 |
+
gt_spans = targets['span_labels_nn']
|
214 |
+
|
215 |
+
mask = targets['timestamp_mask'].bool()
|
216 |
+
mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
|
217 |
+
mask_valid = targets['timestamp_window'].bool()
|
218 |
+
mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
|
219 |
+
|
220 |
+
loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
|
221 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
|
222 |
+
|
223 |
+
losses = {}
|
224 |
+
losses['loss_b'] = loss_span.sum() / mask_valid.sum()
|
225 |
+
losses['loss_g'] = loss_giou.mean()
|
226 |
+
return losses
|
227 |
+
|
228 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
229 |
+
src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
|
230 |
+
mask = targets['timestamp_mask'].bool()
|
231 |
+
mask_valid = targets['timestamp_window'].bool()
|
232 |
+
target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
233 |
+
target_classes[mask_valid] = 1
|
234 |
+
# target_classes = targets['timestamp_window'] # soft cls.
|
235 |
+
target_classes.float()
|
236 |
+
# pdb.set_trace()
|
237 |
+
|
238 |
+
weights = torch.zeros_like(target_classes).float()
|
239 |
+
weights[mask] = self.empty_weight[1]
|
240 |
+
weights[mask_valid] = self.empty_weight[0]
|
241 |
+
|
242 |
+
loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
|
243 |
+
return {"loss_f": loss_ce.sum() / mask.sum()}
|
244 |
+
|
245 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
246 |
+
"""higher scores for positive clips"""
|
247 |
+
if "saliency_pos_labels" not in targets:
|
248 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
249 |
+
saliency_scores = targets["saliency_scores"]
|
250 |
+
if saliency_scores.sum() == 0:
|
251 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
252 |
+
|
253 |
+
# * inter-vid mode
|
254 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
255 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
256 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
257 |
+
|
258 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
259 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
260 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
261 |
+
|
262 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
263 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
264 |
+
|
265 |
+
# sum over positives
|
266 |
+
idiag = torch.diag(i_logsm)
|
267 |
+
jdiag = torch.diag(j_logsm)
|
268 |
+
loss_i = idiag.sum() / len(idiag)
|
269 |
+
loss_j = jdiag.sum() / len(jdiag)
|
270 |
+
|
271 |
+
loss_saliency_inter = - loss_i - loss_j
|
272 |
+
|
273 |
+
# * intra-vid mode
|
274 |
+
mask = targets['timestamp_mask']
|
275 |
+
selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
|
276 |
+
neg_indices_in = (saliency_scores < selected_scores)
|
277 |
+
neg_indices_in[batch_indices, pos_indices] = True
|
278 |
+
mask_invalid = neg_indices_in * mask.bool()
|
279 |
+
|
280 |
+
sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
|
281 |
+
sim_in = sim_in + (mask_invalid + 1e-45).log()
|
282 |
+
logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
|
283 |
+
logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
|
284 |
+
|
285 |
+
pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
|
286 |
+
pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
|
287 |
+
loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
|
288 |
+
loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
|
289 |
+
|
290 |
+
loss_saliency_intra = - loss_in_i - loss_in_j
|
291 |
+
|
292 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
293 |
+
|
294 |
+
def loss_saliency_cls(self, outputs, targets, indices, log=True):
|
295 |
+
"""higher scores for positive clips"""
|
296 |
+
if "saliency_pos_labels" not in targets:
|
297 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
298 |
+
saliency_scores = targets["saliency_scores"]
|
299 |
+
if saliency_scores.sum() == 0:
|
300 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
301 |
+
|
302 |
+
# * inter-vid mode
|
303 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
304 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
305 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
306 |
+
|
307 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
308 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
309 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
310 |
+
|
311 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
312 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
313 |
+
|
314 |
+
# sum over positives
|
315 |
+
idiag = torch.diag(i_logsm)
|
316 |
+
jdiag = torch.diag(j_logsm)
|
317 |
+
loss_i = idiag.sum() / len(idiag)
|
318 |
+
loss_j = jdiag.sum() / len(jdiag)
|
319 |
+
|
320 |
+
loss_saliency_inter = - loss_i - loss_j
|
321 |
+
|
322 |
+
# * intra-vid mode
|
323 |
+
if 'cls_idx' not in targets.keys(): # eval
|
324 |
+
return {"loss_s_inter": loss_saliency_inter}
|
325 |
+
|
326 |
+
cls_indices = targets['cls_idx'].bool()
|
327 |
+
cls_feats = outputs["cls_mem_proj"].squeeze(1)
|
328 |
+
sim_cls = sim_matrix(vid_feats, cls_feats)
|
329 |
+
|
330 |
+
i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
|
331 |
+
idiag_cls = i_logsm_cls[cls_indices]
|
332 |
+
loss_cls_i = idiag_cls.sum() / len(idiag_cls)
|
333 |
+
|
334 |
+
loss_saliency_intra = - loss_cls_i
|
335 |
+
|
336 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
337 |
+
|
338 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
339 |
+
loss_map = {
|
340 |
+
"spans": self.loss_spans,
|
341 |
+
"labels": self.loss_labels,
|
342 |
+
"saliency": self.loss_saliency,
|
343 |
+
"saliency_cls": self.loss_saliency_cls,
|
344 |
+
}
|
345 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
346 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
347 |
+
|
348 |
+
def forward(self, outputs, targets, hl_only=False):
|
349 |
+
""" This performs the loss computation.
|
350 |
+
Parameters:
|
351 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
352 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
353 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
354 |
+
"""
|
355 |
+
indices = None
|
356 |
+
# Compute all the requested losses
|
357 |
+
losses = {}
|
358 |
+
for loss in self.losses:
|
359 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
360 |
+
|
361 |
+
return losses
|
362 |
+
|
363 |
+
class MLP(nn.Module):
|
364 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
365 |
+
|
366 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
367 |
+
super().__init__()
|
368 |
+
self.num_layers = num_layers
|
369 |
+
h = [hidden_dim] * (num_layers - 1)
|
370 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
371 |
+
|
372 |
+
def forward(self, x):
|
373 |
+
for i, layer in enumerate(self.layers):
|
374 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
375 |
+
return x
|
376 |
+
|
377 |
+
class Conv(nn.Module):
|
378 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
379 |
+
|
380 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
|
381 |
+
super().__init__()
|
382 |
+
self.num_layers = num_layers
|
383 |
+
h = [hidden_dim] * (num_layers - 1)
|
384 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
385 |
+
self.layers = nn.ModuleList(
|
386 |
+
nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
|
387 |
+
for n, k in zip([input_dim] + h, h + [output_dim]))
|
388 |
+
def forward(self, x):
|
389 |
+
x = x.permute(0,2,1)
|
390 |
+
for i, layer in enumerate(self.layers):
|
391 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
392 |
+
return x.permute(0, 2, 1)
|
393 |
+
|
394 |
+
class LinearLayer(nn.Module):
|
395 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
396 |
+
|
397 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
398 |
+
super(LinearLayer, self).__init__()
|
399 |
+
self.relu = relu
|
400 |
+
self.layer_norm = layer_norm
|
401 |
+
if layer_norm:
|
402 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
403 |
+
layers = [
|
404 |
+
nn.Dropout(dropout),
|
405 |
+
nn.Linear(in_hsz, out_hsz)
|
406 |
+
]
|
407 |
+
self.net = nn.Sequential(*layers)
|
408 |
+
|
409 |
+
def forward(self, x):
|
410 |
+
"""(N, L, D)"""
|
411 |
+
if self.layer_norm:
|
412 |
+
x = self.LayerNorm(x)
|
413 |
+
x = self.net(x)
|
414 |
+
if self.relu:
|
415 |
+
x = F.relu(x, inplace=True)
|
416 |
+
return x # (N, L, D)
|
417 |
+
|
418 |
+
|
419 |
+
def build_model(args):
|
420 |
+
device = torch.device(args.device)
|
421 |
+
|
422 |
+
transformer = build_transformer(args)
|
423 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
424 |
+
|
425 |
+
model = Model(
|
426 |
+
transformer,
|
427 |
+
position_embedding,
|
428 |
+
txt_position_embedding,
|
429 |
+
txt_dim=args.t_feat_dim,
|
430 |
+
vid_dim=args.v_feat_dim,
|
431 |
+
input_dropout=args.input_dropout,
|
432 |
+
span_loss_type=args.span_loss_type,
|
433 |
+
use_txt_pos=args.use_txt_pos,
|
434 |
+
n_input_proj=args.n_input_proj,
|
435 |
+
)
|
436 |
+
|
437 |
+
matcher = build_matcher(args)
|
438 |
+
weight_dict = {"loss_b": args.b_loss_coef,
|
439 |
+
"loss_g": args.g_loss_coef,
|
440 |
+
"loss_f": args.f_loss_coef,
|
441 |
+
"loss_s_intra": args.s_loss_intra_coef,
|
442 |
+
"loss_s_inter": args.s_loss_inter_coef}
|
443 |
+
|
444 |
+
if args.dset_type in ['mr']:
|
445 |
+
if 'tal' not in args.train_path:
|
446 |
+
losses = ['spans', 'labels', 'saliency']
|
447 |
+
else:
|
448 |
+
losses = ['spans', 'labels', 'saliency_cls']
|
449 |
+
elif args.dset_type in ['hl', 'vs']:
|
450 |
+
losses = ['labels', 'saliency']
|
451 |
+
|
452 |
+
criterion = SetCriterion(
|
453 |
+
matcher=matcher,
|
454 |
+
weight_dict=weight_dict, losses=losses,
|
455 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
456 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
457 |
+
saliency_margin=args.saliency_margin,
|
458 |
+
)
|
459 |
+
criterion.to(device)
|
460 |
+
return model, criterion
|
model/base_qfvs.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder_droppath import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        # Conv projector
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
        bs = src_vid.shape[0]
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(vid_mem)  # (#layers, bsz, #queries, 2 or max_v_l * 2)

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
            out["cls_mem_proj"] = cls_mem_proj
        out["saliency_scores"] = sim
        return out


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        assert 'pred_spans' in outputs

        start_spans = targets['timestamp']
        pred_spans = outputs['pred_spans']
        src_spans = start_spans + pred_spans
        gt_spans = targets['span_labels_nn']

        mask = targets['timestamp_mask'].bool()
        mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        losses = {}
        losses['loss_b'] = loss_span.sum() / mask_valid.sum()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_f": 0.}

        src_logits = outputs['pred_logits'].squeeze(-1)  # (batch_size, #queries, #classes=2)
        target_classes = targets["saliency_scores"].squeeze()

        weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
        weights[target_classes.bool()] = self.empty_weight[0]

        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
        # pdb.set_trace()
        return {"loss_f": loss_ce.sum() / target_classes.sum()}
        # return {"loss_f": loss_ce.sum() / len(target_classes)}

        # mask = targets['timestamp_mask'].bool()
        # mask_valid = targets['timestamp_window'].bool()
        # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        # target_classes[mask_valid] = 1
        # # target_classes = targets['timestamp_window']  # soft cls.
        # target_classes.float()
        # # pdb.set_trace()

        # weights = torch.zeros_like(target_classes).float()
        # weights[mask] = self.empty_weight[1]
        # weights[mask_valid] = self.empty_weight[0]

        # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
        # # return {"loss_f": loss_ce.sum() / mask.sum()}
        # return {"loss_f": loss_ce.sum() / mask_valid.sum()}

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * qfvs mil-nce mode
        pos_indices = saliency_scores.squeeze() > 0

        sim = outputs['saliency_scores']
        sim_soft = F.softmax(sim / self.temperature, dim=0)
        sim_log = torch.log(sim_soft[pos_indices])
        loss_saliency_intra = -sim_log.sum() / len(sim_log)
        return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}

        # * inter-vid mode
        # vid_mem_proj = outputs["vid_mem_proj"]
        # pos_indices = targets["saliency_pos_labels"][:,0].long()  # (N, #pairs)
        # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        # vid_feats = vid_mem_proj[batch_indices, pos_indices]
        # txt_feats = outputs["txt_mem_proj"].squeeze(1)
        # sim = sim_matrix(vid_feats, txt_feats)

        # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)

        # # sum over positives
        # idiag = torch.diag(i_logsm)
        # jdiag = torch.diag(j_logsm)
        # loss_i = idiag.sum() / len(idiag)
        # loss_j = jdiag.sum() / len(jdiag)

        # loss_saliency_inter = - loss_i - loss_j

        # # * intra-vid mode
        # mask = targets['timestamp_mask']
        # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        # neg_indices_in = (saliency_scores < selected_scores)
        # neg_indices_in[batch_indices, pos_indices] = True
        # mask_invalid = neg_indices_in * mask.bool()

        # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        # sim_in = sim_in + (mask_invalid + 1e-45).log()
        # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
        # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)

        # loss_saliency_intra = - loss_in_i - loss_in_j

        # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:,0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        if 'cls_idx' not in targets.keys():  # eval
            return {"loss_s_inter": loss_saliency_inter}

        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, mask_GT=None):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        # pdb.set_trace()
        outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
        outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
        targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])

        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        return losses


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Conv(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))
    def forward(self, x):
        x = x.permute(0,2,1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)


class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr', 'vlp']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion
model/matcher.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
import torch.nn.functional as F
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
    def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1,
                 span_loss_type: str = "l1", max_v_l: int = 75):
        """Creates the matcher

        Params:
            cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the spans in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_span = cost_span
        self.cost_giou = cost_giou
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.foreground_label = 0
        assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates,
                    in normalized (cx, w) format
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are
                    in normalized (cx, w) format

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_spans)
        """
        bs, num_queries = outputs["pred_spans"].shape[:2]
        targets = targets["span_labels"]

        # Also concat the target labels and spans
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        tgt_spans = torch.cat([v["spans"] for v in targets])  # [num_target_spans in batch, 2]
        tgt_ids = torch.full([len(tgt_spans)], self.foreground_label)  # [total #spans in the batch]

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - prob[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]  # [batch_size * num_queries, total #spans in the batch]

        if self.span_loss_type == "l1":
            # We flatten to compute the cost matrices in a batch
            out_spans = outputs["pred_spans"].flatten(0, 1)  # [batch_size * num_queries, 2]

            # Compute the L1 cost between spans
            cost_span = torch.cdist(out_spans, tgt_spans, p=1)  # [batch_size * num_queries, total #spans in the batch]

            # Compute the giou cost between spans
            # [batch_size * num_queries, total #spans in the batch]
            cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans))
        else:
            pred_spans = outputs["pred_spans"]  # (bsz, #queries, max_v_l * 2)
            pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1)  # (bsz * #queries, 2, max_v_l)
            cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \
                pred_spans[:, 1][:, tgt_spans[:, 1]]  # (bsz * #queries, #spans)
            # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1)  # (bsz * #queries * #spans, max_v_l, 2)
            # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1)  # (bsz * #queries * #spans, 2)
            # cost_span = pred_spans[tgt_spans]
            # cost_span = cost_span.view(bs * num_queries, n_spans)

            # giou
            cost_giou = 0

        # Final cost matrix
        # import ipdb; ipdb.set_trace()
        C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["spans"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher(args):
    return HungarianMatcher(
        cost_span=args.set_cost_span, cost_giou=args.set_cost_giou,
        cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l
    )
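For intuition, the matcher above builds, per video, a (num_queries x num_targets) cost from the foreground probability, the L1 distance between normalized (cx, w) spans, and the generalized temporal IoU, then solves the assignment with linear_sum_assignment. A rough standalone sketch of the class + L1 part on toy tensors (the IoU term is omitted because it needs the repo's span_utils; all tensors here are placeholders):

    import torch
    from scipy.optimize import linear_sum_assignment

    torch.manual_seed(0)
    num_queries, num_targets = 10, 3
    pred_spans = torch.rand(num_queries, 2)      # toy normalized (cx, w) predictions for one video
    pred_logits = torch.randn(num_queries, 2)    # toy foreground/background logits
    tgt_spans = torch.rand(num_targets, 2)       # toy ground-truth spans

    out_prob = pred_logits.softmax(-1)
    cost_class = -out_prob[:, 0].unsqueeze(1).expand(-1, num_targets)  # foreground label is index 0
    cost_span = torch.cdist(pred_spans, tgt_spans, p=1)                # L1 cost, as in the matcher
    C = cost_span + cost_class                                         # giou term left out in this sketch
    row_idx, col_idx = linear_sum_assignment(C.numpy())
    print(list(zip(row_idx, col_idx)))                                 # query i matched to target j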
model/moment_detr.py
ADDED
@@ -0,0 +1,462 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn

from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

from model.transformer import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k
    output: (#items, #classes)
    target: int,
    """
    maxk = max(topk)
    num_items = output.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target)

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / num_items))
    return res


class Model(nn.Module):
    """ This is the Moment-DETR module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 num_queries, input_dropout, aux_loss=False,
                 contrastive_align_loss=False, contrastive_hdim=64,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
                         Moment-DETR can detect in a single video.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            contrastive_align_loss: If true, perform span - tokens contrastive learning
            contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
        self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
        self.class_embed = nn.Linear(hidden_dim, 2)  # 0: background, 1: foreground
        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        # self.foreground_thd = foreground_thd
        # self.background_thd = background_thd
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)

        self.saliency_proj = nn.Linear(hidden_dim, 1)
        self.aux_loss = aux_loss

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask):
        """The forward expects two tensors:
               - src_txt: [batch_size, L_txt, D_txt]
               - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels,
                    will convert to 1 as padding later for transformer
               - src_vid: [batch_size, L_vid, D_vid]
               - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels,
                    will convert to 1 as padding later for transformer

            It returns a dict with the following elements:
               - "pred_spans": The normalized boxes coordinates for all queries, represented as
                               (center_x, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
        """
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)
        # TODO should we remove or use different positional embeddings to the src_txt?
        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        # pos_txt = torch.zeros_like(src_txt)
        # pad zeros for txt positions
        pos = torch.cat([pos_vid, pos_txt], dim=1)
        # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d)
        hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos)
        outputs_class = self.class_embed(hs)  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(hs)  # (#layers, bsz, #queries, 2 or max_v_l * 2)
        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}

        txt_mem = memory[:, src_vid.shape[1]:]  # (bsz, L_txt, d)
        vid_mem = memory[:, :src_vid.shape[1]]  # (bsz, L_vid, d)
        if self.contrastive_align_loss:
            proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
            proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
            proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
            out.update(dict(
                proj_queries=proj_queries[-1],
                proj_txt_mem=proj_txt_mem,
                proj_vid_mem=proj_vid_mem
            ))

        out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1)  # (bsz, L_vid)

        if self.aux_loss:
            # assert proj_queries and proj_txt_mem
            out['aux_outputs'] = [
                {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
            if self.contrastive_align_loss:
                assert proj_queries is not None
                for idx, d in enumerate(proj_queries[:-1]):
                    out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
        return out

    # @torch.jit.unused
    # def _set_aux_loss(self, outputs_class, outputs_coord):
    #     # this is a workaround to make torchscript happy, as torchscript
    #     # doesn't support dictionary with non-homogeneous values, such
    #     # as a dict having both a Tensor and a list.
    #     return [{'pred_logits': a, 'pred_spans': b}
    #             for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
        The target spans are expected in format (center_x, w), normalized by the image size.
        """
        assert 'pred_spans' in outputs
        targets = targets["span_labels"]
        idx = self._get_src_permutation_idx(indices)
        src_spans = outputs['pred_spans'][idx]  # (#spans, max_v_l * 2)
        tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0)  # (#spans, 2)
        if self.span_loss_type == "l1":
            loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
            loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
        else:  # ce
            n_spans = src_spans.shape[0]
            src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
            loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')

            # giou
            # src_span_indices = src_spans.max(1)[1]  # (#spans, 2)
            # src_span_indices[:, 1] += 1  # ed non-inclusive [st, ed)
            #
            # tgt_span_indices = tgt_spans
            # tgt_span_indices[:, 1] += 1
            # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices))
            loss_giou = loss_span.new_zeros([1])

        losses = {}
        losses['loss_b'] = loss_span.mean()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        # TODO add foreground and background classifier. use all non-matched as background.
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # (batch_size, #queries, #classes=2)
        # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
        idx = self._get_src_permutation_idx(indices)
        target_classes = torch.full(src_logits.shape[:2], self.background_label,
                                    dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[idx] = self.foreground_label

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
        losses = {'loss_f': loss_ce.mean()}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
        return losses

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_intra": 0}
        saliency_scores = outputs["saliency_scores"]  # (N, L)
        pos_indices = targets["saliency_pos_labels"]  # (N, #pairs)
        neg_indices = targets["saliency_neg_labels"]  # (N, #pairs)
        num_pairs = pos_indices.shape[1]  # typically 2 or 4
        batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
        pos_scores = torch.stack(
            [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        neg_scores = torch.stack(
            [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
            / (len(pos_scores) * num_pairs) * 2  # * 2 to keep the loss the same scale
        return {"loss_s_intra": loss_saliency}

    def loss_contrastive_align(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text"""
        normalized_text_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d)  text tokens
        normalized_img_embed = outputs["proj_queries"]  # (bsz, #queries, d)
        logits = torch.einsum(
            "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed)  # (bsz, #queries, #tokens)
        logits = logits.sum(2) / self.temperature  # (bsz, #queries)
        idx = self._get_src_permutation_idx(indices)
        positive_map = torch.zeros_like(logits, dtype=torch.bool)
        positive_map[idx] = True
        positive_logits = logits.masked_fill(~positive_map, 0)

        pos_term = positive_logits.sum(1)  # (bsz, )
        num_pos = positive_map.sum(1)  # (bsz, )
        neg_term = logits.logsumexp(1)  # (bsz, )
        loss_nce = - pos_term / num_pos + neg_term  # (bsz, )
        losses = {"loss_contrastive_align": loss_nce.mean()}
        return losses

    def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text"""
        # TODO (1) align vid_mem and txt_mem;
        # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR
        normalized_text_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d)  text tokens
        normalized_img_embed = outputs["proj_queries"]  # (bsz, #queries, d)
        logits = torch.einsum(
            "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed)  # (bsz, #queries, #tokens)
        logits = logits.sum(2) / self.temperature  # (bsz, #queries)
        idx = self._get_src_permutation_idx(indices)
        positive_map = torch.zeros_like(logits, dtype=torch.bool)
        positive_map[idx] = True
        positive_logits = logits.masked_fill(~positive_map, 0)

        pos_term = positive_logits.sum(1)  # (bsz, )
        num_pos = positive_map.sum(1)  # (bsz, )
        neg_term = logits.logsumexp(1)  # (bsz, )
        loss_nce = - pos_term / num_pos + neg_term  # (bsz, )
        losses = {"loss_contrastive_align": loss_nce.mean()}
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx  # two 1D tensors of the same length

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "contrastive_align": self.loss_contrastive_align,
            "saliency": self.loss_saliency,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
        indices = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if "saliency" == loss:  # skip as it is only in the top layer
                        continue
                    kwargs = {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        num_queries=args.num_queries,
        input_dropout=args.input_dropout,
        aux_loss=args.aux_loss,
        # contrastive_align_loss=args.contrastive_align_loss,
        # contrastive_hdim=args.contrastive_hdim,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}
    # if args.contrastive_align_loss:
    #     weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
        weight_dict.update(aux_weight_dict)

    losses = ['spans', 'labels', 'saliency']
    # if args.contrastive_align_loss:
    #     losses += ["contrastive_align"]
    criterion = SetCriterion(
        matcher=matcher, weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin
    )
    criterion.to(device)
    return model, criterion
model/position_encoding.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
import numpy as np


def PositionalEncoding(n_position, d_hid):
    def get_position_angle_vec(position, d_hid):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i, d_hid) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table)  # shape:(1, maxLen(n_position), d_hid)


class TrainablePositionalEncoding(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
        super(TrainablePositionalEncoding, self).__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_feat):
        """
        Args:
            input_feat: (N, L, D)
        """
        bsz, seq_length = input_feat.shape[:2]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = self.LayerNorm(input_feat + position_embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask):
        """
        Args:
            x: torch.tensor, (batch_size, L, d)
            mask: torch.tensor, (batch_size, L), with 1 as valid

        Returns:

        """
        assert mask is not None
        x_embed = mask.cumsum(1, dtype=torch.float32)  # (bsz, L)
        if self.normalize:
            eps = 1e-6
            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        # import pdb; pdb.set_trace()
        # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2).int() / self.num_pos_feats)

        pos_x = x_embed[:, :, None] / dim_t  # (bsz, L, num_pos_feats)
        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)  # (bsz, L, num_pos_feats*2)
        # import ipdb; ipdb.set_trace()
        return pos_x  # .permute(0, 2, 1)  # (bsz, num_pos_feats*2, L)


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, x, mask):
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    # elif args.position_embedding in ('v3', 'learned'):
    #     position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    txt_pos_embed = TrainablePositionalEncoding(
        max_position_embeddings=args.max_q_l,
        hidden_size=args.hidden_dim, dropout=args.input_dropout)
    return position_embedding, txt_pos_embed
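PositionEmbeddingSine derives clip positions from the cumulative sum of the valid-frame mask, so padded positions share the last valid index. A short standalone sketch of the same math on a toy mask (hidden size 8 for readability; the mask values are made up):

    import math
    import torch

    num_pos_feats, temperature = 8, 10000
    mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.float32)  # one video, 4 valid clips, 2 padded

    x_embed = mask.cumsum(1)                                        # (1, L): positions 1..4, then stuck at 4
    x_embed = x_embed / (x_embed[:, -1:] + 1e-6) * (2 * math.pi)    # the normalize=True branch
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2).int() / num_pos_feats)
    pos_x = x_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    print(pos_x.shape)                                              # torch.Size([1, 6, 8])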
model/transformer.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        # TransformerEncoderLayerThin
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # TransformerDecoderLayerThin
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        """
        Args:
            src: (batch_size, L, d)
            mask: (batch_size, L)
            query_embed: (#queries, d)
            pos_embed: (batch_size, L, d) the same as src

        Returns:

        """
        # flatten NxCxHxW to HWxNxC
        bs, l, d = src.shape
        src = src.permute(1, 0, 2)  # (L, batch_size, d)
        pos_embed = pos_embed.permute(1, 0, 2)  # (L, batch_size, d)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # (#queries, batch_size, d)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)  # (L, batch_size, d)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)  # (#layers, #queries, batch_size, d)
        hs = hs.transpose(1, 2)  # (#layers, batch_size, #qeries, d)
        # memory = memory.permute(1, 2, 0)  # (batch_size, d, L)
        memory = memory.transpose(0, 1)  # (batch_size, L, d)
        return hs, memory


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        intermediate = []

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
            if self.return_intermediate:
                intermediate.append(output)

        if self.norm is not None:
            output = self.norm(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayerThin(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        # self.linear1 = nn.Linear(d_model, dim_feedforward)
        # self.dropout = nn.Dropout(dropout)
        # self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src2 = self.linear(src2)
        src = src + self.dropout(src2)
        src = self.norm(src)
        # src = src + self.dropout1(src2)
        # src = self.norm1(src)
        # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        # src = src + self.dropout2(src2)
        # src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """not used"""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


class TransformerDecoderLayerThin(nn.Module):
    """removed intermediate layer"""
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, d_model)
        # self.linear1 = nn.Linear(d_model, dim_feedforward)
        # self.dropout = nn.Dropout(dropout)
        # self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # self.dropout3 = nn.Dropout(dropout)

        # self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt2 = self.linear1(tgt2)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # tgt = tgt + self.dropout2(tgt2)
        # tgt = self.norm2(tgt)
        # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        # tgt = tgt + self.dropout3(tgt2)
        # tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
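Reading note (not part of the committed file): a minimal smoke test of the build_transformer above, written as a sketch; the hyperparameter values in the Namespace are illustrative assumptions, not the repo's configured defaults.

# Hypothetical usage sketch for model/transformer.py (assumed hyperparameters).
from argparse import Namespace
import torch

args = Namespace(hidden_dim=256, dropout=0.1, nheads=8, dim_feedforward=1024,
                 enc_layers=2, dec_layers=2, pre_norm=False)
model = build_transformer(args)

bs, L, d, n_queries = 2, 75, 256, 10
src = torch.randn(bs, L, d)                      # (batch_size, L, d)
mask = torch.zeros(bs, L, dtype=torch.bool)      # True marks padded positions (none here)
query_embed = torch.randn(n_queries, d)
pos_embed = torch.randn(bs, L, d)

hs, memory = model(src, mask, query_embed, pos_embed)
print(hs.shape)      # (#decoder layers, batch_size, #queries, d)
print(memory.shape)  # (batch_size, L, d)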
model/transformer_encoder.py
ADDED
@@ -0,0 +1,159 @@
import copy
import pdb
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

def mask_logits(inputs, mask, mask_value=-1e30):
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,  # False as default
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        """
        Args:
            src: (batch_size, L, d)
            mask: (batch_size, L)
            query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
            pos_embed: (batch_size, L, d) the same as src

        Returns:

        """
        # flatten NxCxHxW to HWxNxC
        src = src.permute(1, 0, 2)  # (L, batch_size, d)
        pos_embed = pos_embed.permute(1, 0, 2)  # (L, batch_size, d)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        memory = memory.transpose(0, 1)

        return memory


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        intermediate = []

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
            if self.return_intermediate:
                intermediate.append(output)

        if self.norm is not None:
            output = self.norm(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
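Reading note (not part of the committed file): the mask_logits helper defined at the top of this encoder-only variant pushes masked positions to a large negative value so a later softmax assigns them near-zero probability. A tiny self-contained check:

# Illustration of mask_logits semantics (sketch, not part of the commit).
import torch
scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])      # 1 = keep, 0 = mask out
masked = mask_logits(scores, mask)          # third logit becomes ~ -1e30
print(torch.softmax(masked, dim=-1))        # third position gets ~0 probability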
model/transformer_encoder_droppath.py
ADDED
@@ -0,0 +1,194 @@
import copy
import pdb
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

def mask_logits(inputs, mask, mask_value=-1e30):
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, droppath=0.1,
                 activation="gelu", normalize_before=False,  # False as default
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, droppath, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        """
        Args:
            src: (batch_size, L, d)
            mask: (batch_size, L)
            query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
            pos_embed: (batch_size, L, d) the same as src

        Returns:

        """
        # flatten NxCxHxW to HWxNxC
        src = src.permute(1, 0, 2)  # (L, batch_size, d)
        pos_embed = pos_embed.permute(1, 0, 2)  # (L, batch_size, d)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        memory = memory.transpose(0, 1)

        return memory


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        intermediate = []

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
            if self.return_intermediate:
                intermediate.append(output)

        if self.norm is not None:
            output = self.norm(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, droppath=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # self.dropout1 = nn.Dropout(dropout)
        # self.dropout2 = nn.Dropout(dropout)
        self.droppath1 = DropPath(droppath)
        self.droppath2 = DropPath(droppath)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        # src2 = self.self_attn_eff(q=q, k=k, v=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.droppath1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.activation(self.linear1(src)))
        # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.droppath2(src2)
        src = self.norm2(src)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        droppath=args.droppath,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )

def drop_path(x, drop_prob=0.0, training=False):
    """
    Stochastic Depth per sample.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()
    x = x.div(keep_prob) * mask

    return x


class DropPath(nn.Module):
    """
    Drop paths per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()

        self.drop_prob = drop_prob

    def forward(self, x):
        x = x.permute(1, 0, 2)
        res = drop_path(x, self.drop_prob, self.training)
        return res.permute(1, 0, 2)
        # return drop_path(x, self.drop_prob, self.training)

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
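Reading note (not part of the committed file): DropPath applies stochastic depth per sample; it permutes because the encoder layers pass tensors as (L, batch, d) while drop_path zeroes along the first dimension. During training a sample's residual branch is dropped with probability drop_prob and survivors are rescaled by 1/keep_prob; at eval time it is the identity. A small check, under those assumptions:

# Sketch of the stochastic-depth behaviour (not part of the commit).
import torch
torch.manual_seed(0)
x = torch.ones(8, 4, 16)                     # (L, batch, d), as seen inside the encoder layer
dp = DropPath(drop_prob=0.5)

dp.train()
y = dp(x)                                    # some samples zeroed, survivors scaled by 1/0.5
print(y.permute(1, 0, 2).reshape(4, -1).sum(dim=1))  # per-sample totals: 0.0 or 256.0

dp.eval()
print(torch.equal(dp(x), x))                 # True: identity at inference time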
model/univtg.py
ADDED
@@ -0,0 +1,450 @@
import pdb
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from model.transformer_encoder_droppath import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

def init_weights(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

def mask_logits(inputs, mask, mask_value=-1e30):
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value

def sim_matrix(a, b, eps=1e-8):
    """
    added eps for numerical stability
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

class WeightedPool(nn.Module):
    def __init__(self, dim):
        super(WeightedPool, self).__init__()
        weight = torch.empty(dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, mask):
        alpha = torch.tensordot(x, self.weight, dims=1)  # shape = (batch_size, seq_length, 1)
        alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
        alphas = nn.Softmax(dim=1)(alpha)
        pooled_x = torch.matmul(x.transpose(1, 2), alphas)  # (batch_size, dim, 1)
        pooled_x = pooled_x.squeeze(2)
        return pooled_x

class Model(nn.Module):
    """ This is the UniVTG module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 input_dropout, aux_loss=False,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        # Conv projector
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
        bs = src_vid.shape[0]
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)
        device_id = src_vid.device

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(vid_mem)  # (#layers, bsz, #queries, 2 or max_v_l * 2)

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).to(device_id)
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
            out["cls_mem_proj"] = cls_mem_proj
        out["saliency_scores"] = sim
        return out

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        assert 'pred_spans' in outputs

        start_spans = targets['timestamp']
        pred_spans = outputs['pred_spans']
        src_spans = start_spans + pred_spans
        gt_spans = targets['span_labels_nn']

        mask = targets['timestamp_mask'].bool()
        mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        losses = {}
        losses['loss_b'] = loss_span.sum() / mask_valid.sum()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        src_logits = outputs['pred_logits'].squeeze(-1)  # (batch_size, #queries, #classes=2)
        mask = targets['timestamp_mask'].bool()
        mask_valid = targets['timestamp_window'].bool()
        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[mask_valid] = 1
        # target_classes = targets['timestamp_window']  # soft cls.
        target_classes.float()
        # pdb.set_trace()

        weights = torch.zeros_like(target_classes).float()
        weights[mask] = self.empty_weight[1]
        weights[mask_valid] = self.empty_weight[0]

        # pdb.set_trace()
        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
        return {"loss_f": loss_ce.sum() / mask.sum()}
        # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:,0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        mask = targets['timestamp_mask']
        selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        neg_indices_in = (saliency_scores < selected_scores)
        neg_indices_in[batch_indices, pos_indices] = True
        mask_invalid = neg_indices_in * mask.bool()

        sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        sim_in = sim_in + (mask_invalid + 1e-45).log()
        logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
        loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)

        loss_saliency_intra = - loss_in_i - loss_in_j

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:,0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_saliency_inter = - loss_i - loss_j

        # * intra-vid mode
        if 'cls_idx' not in targets.keys():  # eval
            return {"loss_s_inter": loss_saliency_inter}

        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i

        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, hl_only=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        return losses

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class Conv(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))
    def forward(self, x):
        x = x.permute(0,2,1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)

class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr', 'vlp']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion
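Reading note (not part of the committed file): with span_loss_type "l1", Model.forward multiplies the sigmoid outputs by (-1, 1), so each video clip regresses a left/right offset, and SetCriterion.loss_spans recovers the predicted window as targets['timestamp'] + pred_spans. The numeric sketch below illustrates that arithmetic only; the exact layout and units of targets['timestamp'] are an assumption here.

# Illustrative sketch of the per-clip span convention (assumed tensor layout, not part of the commit).
import torch

clip_anchor = torch.tensor([[[2.0, 2.0], [6.0, 6.0]]])   # per-clip reference times, (bsz, L, 2)
raw = torch.tensor([[[0.5, 0.9], [0.2, 0.3]]])           # sigmoid outputs in [0, 1]
pred_offsets = raw * torch.tensor([-1.0, 1.0])           # (-left, +right), as in Model.forward
pred_spans = clip_anchor + pred_offsets                  # as in SetCriterion.loss_spans
print(pred_spans)   # clip anchored at 2.0 -> [1.5, 2.9]; clip anchored at 6.0 -> [5.8, 6.3]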
model/univtg_ablation.py
ADDED
@@ -0,0 +1,474 @@
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder_droppath import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
81 |
+
self.token_type_embeddings.apply(init_weights)
|
82 |
+
|
83 |
+
# Conv projector
|
84 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
85 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
86 |
+
|
87 |
+
self.use_txt_pos = use_txt_pos
|
88 |
+
self.n_input_proj = n_input_proj
|
89 |
+
relu_args = [True] * 3
|
90 |
+
relu_args[n_input_proj-1] = False
|
91 |
+
self.input_txt_proj = nn.Sequential(*[
|
92 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
93 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
95 |
+
][:n_input_proj])
|
96 |
+
self.input_vid_proj = nn.Sequential(*[
|
97 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
98 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
100 |
+
][:n_input_proj])
|
101 |
+
|
102 |
+
# MLP Projector
|
103 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
104 |
+
|
105 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
106 |
+
bs = src_vid.shape[0]
|
107 |
+
src_vid = self.input_vid_proj(src_vid)
|
108 |
+
src_txt = self.input_txt_proj(src_txt)
|
109 |
+
if src_cls is not None:
|
110 |
+
src_cls = self.input_txt_proj(src_cls)
|
111 |
+
|
112 |
+
# type token.
|
113 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
|
114 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
115 |
+
if src_cls is not None:
|
116 |
+
src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
|
117 |
+
|
118 |
+
src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
|
119 |
+
mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
|
120 |
+
|
121 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
|
122 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
|
123 |
+
pos = torch.cat([pos_vid, pos_txt], dim=1)
|
124 |
+
|
125 |
+
memory = self.transformer(src, ~mask, pos)
|
126 |
+
vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
|
127 |
+
|
128 |
+
outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
|
129 |
+
outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
|
130 |
+
|
131 |
+
if self.span_loss_type == "l1":
|
132 |
+
outputs_coord = outputs_coord.sigmoid()
|
133 |
+
idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
|
134 |
+
idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
|
135 |
+
outputs_coord = outputs_coord * idx_mask
|
136 |
+
else:
|
137 |
+
raise NotImplementedError
|
138 |
+
|
139 |
+
out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
|
140 |
+
'src_vid_mask': src_vid_mask}
|
141 |
+
|
142 |
+
vid_mem_proj = src_vid
|
143 |
+
|
144 |
+
# word-level -> sentence-level
|
145 |
+
txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
|
146 |
+
sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
|
147 |
+
|
148 |
+
out["vid_mem_proj"] = vid_mem_proj
|
149 |
+
out["txt_mem_proj"] = txt_mem_proj
|
150 |
+
if src_cls is not None:
|
151 |
+
cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
|
152 |
+
out["cls_mem_proj"] = cls_mem_proj
|
153 |
+
out["saliency_scores"] = sim
|
154 |
+
return out
|
155 |
+
|
156 |
+
class SetCriterion(nn.Module):
|
157 |
+
""" This class computes the loss for DETR.
|
158 |
+
The process happens in two steps:
|
159 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
160 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
164 |
+
saliency_margin=1):
|
165 |
+
""" Create the criterion.
|
166 |
+
Parameters:
|
167 |
+
matcher: module able to compute a matching between targets and proposals
|
168 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
169 |
+
eos_coef: relative classification weight applied to the no-object category
|
170 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
171 |
+
temperature: float, temperature for NCE loss
|
172 |
+
span_loss_type: str, [l1, ce]
|
173 |
+
max_v_l: int,
|
174 |
+
saliency_margin: float
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
self.matcher = matcher
|
178 |
+
self.weight_dict = weight_dict
|
179 |
+
self.losses = losses
|
180 |
+
self.temperature = temperature
|
181 |
+
self.span_loss_type = span_loss_type
|
182 |
+
self.max_v_l = max_v_l
|
183 |
+
self.saliency_margin = saliency_margin
|
184 |
+
self.temperature = 0.07
|
185 |
+
|
186 |
+
# foreground and background classification
|
187 |
+
self.foreground_label = 0
|
188 |
+
self.background_label = 1
|
189 |
+
self.eos_coef = eos_coef
|
190 |
+
empty_weight = torch.ones(2)
|
191 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
192 |
+
self.register_buffer('empty_weight', empty_weight)
|
193 |
+
|
194 |
+
def loss_spans(self, outputs, targets, indices):
|
195 |
+
assert 'pred_spans' in outputs
|
196 |
+
|
197 |
+
start_spans = targets['timestamp']
|
198 |
+
pred_spans = outputs['pred_spans']
|
199 |
+
src_spans = start_spans + pred_spans
|
200 |
+
gt_spans = targets['span_labels_nn']
|
201 |
+
|
202 |
+
mask = targets['timestamp_mask'].bool()
|
203 |
+
mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
|
204 |
+
mask_valid = targets['timestamp_window'].bool()
|
205 |
+
mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
|
206 |
+
|
207 |
+
weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1)
|
208 |
+
if weight_abalation_b.sum() == 0:
|
209 |
+
return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}
|
210 |
+
|
211 |
+
mask_valid = (mask_valid * weight_abalation_b).bool()
|
212 |
+
mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()
|
213 |
+
|
214 |
+
loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
|
215 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
|
216 |
+
|
217 |
+
losses = {}
|
218 |
+
losses['loss_b'] = loss_span.sum() / mask_valid.sum()
|
219 |
+
losses['loss_g'] = loss_giou.mean()
|
220 |
+
return losses
|
221 |
+
|
222 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
223 |
+
src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
|
224 |
+
mask = targets['timestamp_mask'].bool()
|
225 |
+
mask_valid = targets['timestamp_window'].bool()
|
226 |
+
target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
227 |
+
target_classes[mask_valid] = 1
|
228 |
+
# target_classes = targets['timestamp_window'] # soft cls.
|
229 |
+
target_classes.float()
|
230 |
+
# pdb.set_trace()
|
231 |
+
|
232 |
+
weights = torch.zeros_like(target_classes).float()
|
233 |
+
weights[mask] = self.empty_weight[1]
|
234 |
+
weights[mask_valid] = self.empty_weight[0]
|
235 |
+
|
236 |
+
loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
|
237 |
+
|
238 |
+
weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1)
|
239 |
+
if weight_abalation_f.sum() == 0:
|
240 |
+
return {"loss_f": torch.tensor(0).cuda()}
|
241 |
+
|
242 |
+
mask = mask * weight_abalation_f
|
243 |
+
loss_ce = loss_ce * weight_abalation_f
|
244 |
+
return {"loss_f": loss_ce.sum() / mask.sum()}
|
245 |
+
# return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
|
246 |
+
|
247 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
248 |
+
"""higher scores for positive clips"""
|
249 |
+
if "saliency_pos_labels" not in targets:
|
250 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
251 |
+
saliency_scores = targets["saliency_scores"]
|
252 |
+
if saliency_scores.sum() == 0:
|
253 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
254 |
+
|
255 |
+
# * inter-vid mode
|
256 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
257 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
258 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
259 |
+
|
260 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
261 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
262 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
263 |
+
|
264 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
265 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
266 |
+
|
267 |
+
# sum over positives
|
268 |
+
idiag = torch.diag(i_logsm)
|
269 |
+
jdiag = torch.diag(j_logsm)
|
270 |
+
|
271 |
+
weight_abalation_s = targets['weight_ablation'][:,3].bool()
|
272 |
+
if weight_abalation_s.sum() == 0:
|
273 |
+
return {"loss_s_inter": torch.tensor(0).cuda(),
|
274 |
+
"loss_s_intra": torch.tensor(0).cuda()}
|
275 |
+
|
276 |
+
_idiag = idiag[weight_abalation_s]
|
277 |
+
_jdiag = jdiag[weight_abalation_s]
|
278 |
+
|
279 |
+
loss_i = _idiag.sum() / len(_idiag)
|
280 |
+
loss_j = _jdiag.sum() / len(_jdiag)
|
281 |
+
|
282 |
+
loss_saliency_inter = - loss_i - loss_j
|
283 |
+
|
284 |
+
# * intra-vid mode
|
285 |
+
mask = targets['timestamp_mask']
|
286 |
+
selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
|
287 |
+
neg_indices_in = (saliency_scores < selected_scores)
|
288 |
+
neg_indices_in[batch_indices, pos_indices] = True
|
289 |
+
mask_invalid = neg_indices_in * mask.bool()
|
290 |
+
|
291 |
+
sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
|
292 |
+
sim_in = sim_in + (mask_invalid + 1e-45).log()
|
293 |
+
logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
|
294 |
+
logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
|
295 |
+
|
296 |
+
pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
|
297 |
+
pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
|
298 |
+
_pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
|
299 |
+
_pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]
|
300 |
+
|
301 |
+
loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
|
302 |
+
loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)
|
303 |
+
|
304 |
+
loss_saliency_intra = - loss_in_i - loss_in_j
|
305 |
+
|
306 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
307 |
+
|
308 |
+
def loss_saliency_cls(self, outputs, targets, indices, log=True):
|
309 |
+
"""higher scores for positive clips"""
|
310 |
+
if "saliency_pos_labels" not in targets:
|
311 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
312 |
+
saliency_scores = targets["saliency_scores"]
|
313 |
+
if saliency_scores.sum() == 0:
|
314 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
315 |
+
|
316 |
+
# * inter-vid mode
|
317 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
318 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
319 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
320 |
+
|
321 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
322 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
323 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
324 |
+
|
325 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
326 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
327 |
+
|
328 |
+
# sum over positives
|
329 |
+
idiag = torch.diag(i_logsm)
|
330 |
+
jdiag = torch.diag(j_logsm)
|
331 |
+
loss_i = idiag.sum() / len(idiag)
|
332 |
+
loss_j = jdiag.sum() / len(jdiag)
|
333 |
+
|
334 |
+
loss_saliency_inter = - loss_i - loss_j
|
335 |
+
|
336 |
+
# * intra-vid mode
|
337 |
+
if 'cls_idx' not in targets.keys(): # eval
|
338 |
+
return {"loss_s_inter": loss_saliency_inter}
|
339 |
+
|
340 |
+
cls_indices = targets['cls_idx'].bool()
|
341 |
+
cls_feats = outputs["cls_mem_proj"].squeeze(1)
|
342 |
+
sim_cls = sim_matrix(vid_feats, cls_feats)
|
343 |
+
|
344 |
+
i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
|
345 |
+
idiag_cls = i_logsm_cls[cls_indices]
|
346 |
+
loss_cls_i = idiag_cls.sum() / len(idiag_cls)
|
347 |
+
|
348 |
+
loss_saliency_intra = - loss_cls_i
|
349 |
+
|
350 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
351 |
+
|
352 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
353 |
+
loss_map = {
|
354 |
+
"spans": self.loss_spans,
|
355 |
+
"labels": self.loss_labels,
|
356 |
+
"saliency": self.loss_saliency,
|
357 |
+
"saliency_cls": self.loss_saliency_cls,
|
358 |
+
}
|
359 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
360 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
361 |
+
|
362 |
+
def forward(self, outputs, targets, hl_only=False):
|
363 |
+
""" This performs the loss computation.
|
364 |
+
Parameters:
|
365 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
366 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
367 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
368 |
+
"""
|
369 |
+
indices = None
|
370 |
+
# Compute all the requested losses
|
371 |
+
losses = {}
|
372 |
+
for loss in self.losses:
|
373 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
374 |
+
|
375 |
+
return losses
|
376 |
+
|
377 |
+
class MLP(nn.Module):
|
378 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
379 |
+
|
380 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
381 |
+
super().__init__()
|
382 |
+
self.num_layers = num_layers
|
383 |
+
h = [hidden_dim] * (num_layers - 1)
|
384 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
385 |
+
|
386 |
+
def forward(self, x):
|
387 |
+
for i, layer in enumerate(self.layers):
|
388 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
389 |
+
return x
|
390 |
+
|
391 |
+
class Conv(nn.Module):
|
392 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
393 |
+
|
394 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
|
395 |
+
super().__init__()
|
396 |
+
self.num_layers = num_layers
|
397 |
+
h = [hidden_dim] * (num_layers - 1)
|
398 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
399 |
+
self.layers = nn.ModuleList(
|
400 |
+
nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
|
401 |
+
for n, k in zip([input_dim] + h, h + [output_dim]))
|
402 |
+
def forward(self, x):
|
403 |
+
x = x.permute(0,2,1)
|
404 |
+
for i, layer in enumerate(self.layers):
|
405 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
406 |
+
return x.permute(0, 2, 1)
|
407 |
+
|
408 |
+
class LinearLayer(nn.Module):
|
409 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
410 |
+
|
411 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
412 |
+
super(LinearLayer, self).__init__()
|
413 |
+
self.relu = relu
|
414 |
+
self.layer_norm = layer_norm
|
415 |
+
if layer_norm:
|
416 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
417 |
+
layers = [
|
418 |
+
nn.Dropout(dropout),
|
419 |
+
nn.Linear(in_hsz, out_hsz)
|
420 |
+
]
|
421 |
+
self.net = nn.Sequential(*layers)
|
422 |
+
|
423 |
+
def forward(self, x):
|
424 |
+
"""(N, L, D)"""
|
425 |
+
if self.layer_norm:
|
426 |
+
x = self.LayerNorm(x)
|
427 |
+
x = self.net(x)
|
428 |
+
if self.relu:
|
429 |
+
x = F.relu(x, inplace=True)
|
430 |
+
return x # (N, L, D)
|
431 |
+
|
432 |
+
|
433 |
+
def build_model(args):
|
434 |
+
device = torch.device(args.device)
|
435 |
+
|
436 |
+
transformer = build_transformer(args)
|
437 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
438 |
+
|
439 |
+
model = Model(
|
440 |
+
transformer,
|
441 |
+
position_embedding,
|
442 |
+
txt_position_embedding,
|
443 |
+
txt_dim=args.t_feat_dim,
|
444 |
+
vid_dim=args.v_feat_dim,
|
445 |
+
input_dropout=args.input_dropout,
|
446 |
+
span_loss_type=args.span_loss_type,
|
447 |
+
use_txt_pos=args.use_txt_pos,
|
448 |
+
n_input_proj=args.n_input_proj,
|
449 |
+
)
|
450 |
+
|
451 |
+
matcher = build_matcher(args)
|
452 |
+
weight_dict = {"loss_b": args.b_loss_coef,
|
453 |
+
"loss_g": args.g_loss_coef,
|
454 |
+
"loss_f": args.f_loss_coef,
|
455 |
+
"loss_s_intra": args.s_loss_intra_coef,
|
456 |
+
"loss_s_inter": args.s_loss_inter_coef}
|
457 |
+
|
458 |
+
if args.dset_type in ['mr', 'vlp']:
|
459 |
+
if 'tal' not in args.train_path:
|
460 |
+
losses = ['spans', 'labels', 'saliency']
|
461 |
+
else:
|
462 |
+
losses = ['spans', 'labels', 'saliency_cls']
|
463 |
+
elif args.dset_type in ['hl', 'vs']:
|
464 |
+
losses = ['labels', 'saliency']
|
465 |
+
|
466 |
+
criterion = SetCriterion(
|
467 |
+
matcher=matcher,
|
468 |
+
weight_dict=weight_dict, losses=losses,
|
469 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
470 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
471 |
+
saliency_margin=args.saliency_margin,
|
472 |
+
)
|
473 |
+
criterion.to(device)
|
474 |
+
return model, criterion
|
model/univtg_qfvs.py
ADDED
@@ -0,0 +1,476 @@
1 |
+
import pdb
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from model.transformer_encoder_droppath import build_transformer
|
8 |
+
from model.matcher import build_matcher
|
9 |
+
from model.position_encoding import build_position_encoding
|
10 |
+
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
11 |
+
|
12 |
+
def init_weights(module):
|
13 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
14 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
20 |
+
module.bias.data.zero_()
|
21 |
+
|
22 |
+
def mask_logits(inputs, mask, mask_value=-1e30):
|
23 |
+
mask = mask.type(torch.float32)
|
24 |
+
return inputs + (1.0 - mask) * mask_value
|
25 |
+
|
26 |
+
def sim_matrix(a, b, eps=1e-8):
|
27 |
+
"""
|
28 |
+
added eps for numerical stability
|
29 |
+
"""
|
30 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
31 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
32 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
33 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
34 |
+
return sim_mt
|
35 |
+
|
36 |
+
class WeightedPool(nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super(WeightedPool, self).__init__()
|
39 |
+
weight = torch.empty(dim, 1)
|
40 |
+
nn.init.xavier_uniform_(weight)
|
41 |
+
self.weight = nn.Parameter(weight, requires_grad=True)
|
42 |
+
|
43 |
+
def forward(self, x, mask):
|
44 |
+
alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
|
45 |
+
alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
|
46 |
+
alphas = nn.Softmax(dim=1)(alpha)
|
47 |
+
pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
|
48 |
+
pooled_x = pooled_x.squeeze(2)
|
49 |
+
return pooled_x
|
50 |
+
|
51 |
+
class Model(nn.Module):
|
52 |
+
""" This is the UniVTG module that performs moment localization. """
|
53 |
+
|
54 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
55 |
+
input_dropout, aux_loss=False,
|
56 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
|
57 |
+
""" Initializes the model.
|
58 |
+
Parameters:
|
59 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
60 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
61 |
+
txt_position_embed: position_embedding for text
|
62 |
+
txt_dim: int, text query input dimension
|
63 |
+
vid_dim: int, video feature input dimension
|
64 |
+
max_v_l: int, maximum #clips in videos
|
65 |
+
span_loss_type: str, one of [l1, ce]
|
66 |
+
l1: (center-x, width) regression.
|
67 |
+
ce: (st_idx, ed_idx) classification.
|
68 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
69 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
self.transformer = transformer
|
73 |
+
self.position_embed = position_embed
|
74 |
+
self.txt_position_embed = txt_position_embed
|
75 |
+
hidden_dim = transformer.d_model
|
76 |
+
self.span_loss_type = span_loss_type
|
77 |
+
self.max_v_l = max_v_l
|
78 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
79 |
+
|
80 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
81 |
+
self.token_type_embeddings.apply(init_weights)
|
82 |
+
|
83 |
+
# Conv projector
|
84 |
+
self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
|
85 |
+
self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
|
86 |
+
|
87 |
+
self.use_txt_pos = use_txt_pos
|
88 |
+
self.n_input_proj = n_input_proj
|
89 |
+
relu_args = [True] * 3
|
90 |
+
relu_args[n_input_proj-1] = False
|
91 |
+
self.input_txt_proj = nn.Sequential(*[
|
92 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
93 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
94 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
95 |
+
][:n_input_proj])
|
96 |
+
self.input_vid_proj = nn.Sequential(*[
|
97 |
+
LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
98 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
99 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
100 |
+
][:n_input_proj])
|
101 |
+
|
102 |
+
# MLP Projector
|
103 |
+
self.weightedpool = WeightedPool(hidden_dim)
|
104 |
+
|
105 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
|
106 |
+
bs = src_vid.shape[0]
|
107 |
+
src_vid = self.input_vid_proj(src_vid)
|
108 |
+
src_txt = self.input_txt_proj(src_txt)
|
109 |
+
if src_cls is not None:
|
110 |
+
src_cls = self.input_txt_proj(src_cls)
|
111 |
+
|
112 |
+
# type token.
|
113 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
|
114 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
115 |
+
if src_cls is not None:
|
116 |
+
src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
|
117 |
+
|
118 |
+
src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
|
119 |
+
mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
|
120 |
+
|
121 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
|
122 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
|
123 |
+
pos = torch.cat([pos_vid, pos_txt], dim=1)
|
124 |
+
|
125 |
+
memory = self.transformer(src, ~mask, pos)
|
126 |
+
vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
|
127 |
+
|
128 |
+
outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
|
129 |
+
outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
|
130 |
+
|
131 |
+
if self.span_loss_type == "l1":
|
132 |
+
outputs_coord = outputs_coord.sigmoid()
|
133 |
+
idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
|
134 |
+
idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
|
135 |
+
outputs_coord = outputs_coord * idx_mask
|
136 |
+
else:
|
137 |
+
raise NotImplementedError
|
138 |
+
|
139 |
+
out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
|
140 |
+
'src_vid_mask': src_vid_mask}
|
141 |
+
|
142 |
+
vid_mem_proj = src_vid
|
143 |
+
|
144 |
+
# word-level -> sentence-level
|
145 |
+
txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
|
146 |
+
sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
|
147 |
+
|
148 |
+
out["vid_mem_proj"] = vid_mem_proj
|
149 |
+
out["txt_mem_proj"] = txt_mem_proj
|
150 |
+
if src_cls is not None:
|
151 |
+
cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
|
152 |
+
out["cls_mem_proj"] = cls_mem_proj
|
153 |
+
out["saliency_scores"] = sim
|
154 |
+
return out
|
155 |
+
|
156 |
+
class SetCriterion(nn.Module):
|
157 |
+
""" This class computes the loss for DETR.
|
158 |
+
The process happens in two steps:
|
159 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
160 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
164 |
+
saliency_margin=1):
|
165 |
+
""" Create the criterion.
|
166 |
+
Parameters:
|
167 |
+
matcher: module able to compute a matching between targets and proposals
|
168 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
169 |
+
eos_coef: relative classification weight applied to the no-object category
|
170 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
171 |
+
temperature: float, temperature for NCE loss
|
172 |
+
span_loss_type: str, [l1, ce]
|
173 |
+
max_v_l: int,
|
174 |
+
saliency_margin: float
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
self.matcher = matcher
|
178 |
+
self.weight_dict = weight_dict
|
179 |
+
self.losses = losses
|
180 |
+
self.temperature = temperature
|
181 |
+
self.span_loss_type = span_loss_type
|
182 |
+
self.max_v_l = max_v_l
|
183 |
+
self.saliency_margin = saliency_margin
|
184 |
+
self.temperature = 0.07
|
185 |
+
|
186 |
+
# foreground and background classification
|
187 |
+
self.foreground_label = 0
|
188 |
+
self.background_label = 1
|
189 |
+
self.eos_coef = eos_coef
|
190 |
+
empty_weight = torch.ones(2)
|
191 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
192 |
+
self.register_buffer('empty_weight', empty_weight)
|
193 |
+
|
194 |
+
def loss_spans(self, outputs, targets, indices):
|
195 |
+
assert 'pred_spans' in outputs
|
196 |
+
|
197 |
+
start_spans = targets['timestamp']
|
198 |
+
pred_spans = outputs['pred_spans']
|
199 |
+
src_spans = start_spans + pred_spans
|
200 |
+
gt_spans = targets['span_labels_nn']
|
201 |
+
|
202 |
+
mask = targets['timestamp_mask'].bool()
|
203 |
+
mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
|
204 |
+
mask_valid = targets['timestamp_window'].bool()
|
205 |
+
mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
|
206 |
+
|
207 |
+
loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
|
208 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
|
209 |
+
|
210 |
+
losses = {}
|
211 |
+
losses['loss_b'] = loss_span.sum() / mask_valid.sum()
|
212 |
+
losses['loss_g'] = loss_giou.mean()
|
213 |
+
return losses
|
214 |
+
|
215 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
216 |
+
saliency_scores = targets["saliency_scores"]
|
217 |
+
if saliency_scores.sum() == 0:
|
218 |
+
return {"loss_f": 0.}
|
219 |
+
|
220 |
+
src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
|
221 |
+
target_classes = targets["saliency_scores"].squeeze()
|
222 |
+
|
223 |
+
weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
|
224 |
+
weights[target_classes.bool()] = self.empty_weight[0]
|
225 |
+
|
226 |
+
loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
|
227 |
+
return {"loss_f": loss_ce.sum() / target_classes.sum()}
|
228 |
+
# return {"loss_f": loss_ce.sum() / len(target_classes)}
|
229 |
+
|
230 |
+
# mask = targets['timestamp_mask'].bool()
|
231 |
+
# mask_valid = targets['timestamp_window'].bool()
|
232 |
+
# target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
233 |
+
# target_classes[mask_valid] = 1
|
234 |
+
# # target_classes = targets['timestamp_window'] # soft cls.
|
235 |
+
# target_classes.float()
|
236 |
+
# # pdb.set_trace()
|
237 |
+
|
238 |
+
# weights = torch.zeros_like(target_classes).float()
|
239 |
+
# weights[mask] = self.empty_weight[1]
|
240 |
+
# weights[mask_valid] = self.empty_weight[0]
|
241 |
+
|
242 |
+
# loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
|
243 |
+
# # return {"loss_f": loss_ce.sum() / mask.sum()}
|
244 |
+
# return {"loss_f": loss_ce.sum() / mask_valid.sum()}
|
245 |
+
|
246 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
247 |
+
"""higher scores for positive clips"""
|
248 |
+
if "saliency_pos_labels" not in targets:
|
249 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
250 |
+
saliency_scores = targets["saliency_scores"]
|
251 |
+
if saliency_scores.sum() == 0:
|
252 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
253 |
+
|
254 |
+
# * qfvs mil-nce mode
|
255 |
+
pos_indices = saliency_scores.squeeze() > 0
|
256 |
+
|
257 |
+
sim = outputs['saliency_scores']
|
258 |
+
sim_soft = F.softmax(sim / self.temperature, dim=0)
|
259 |
+
sim_log = torch.log(sim_soft[pos_indices])
|
260 |
+
loss_saliency_intra = -sim_log.sum() / len(sim_log)
|
261 |
+
return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
|
262 |
+
|
263 |
+
# * inter-vid mode
|
264 |
+
# vid_mem_proj = outputs["vid_mem_proj"]
|
265 |
+
# pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
266 |
+
# batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
267 |
+
|
268 |
+
# vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
269 |
+
# txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
270 |
+
# sim = sim_matrix(vid_feats, txt_feats)
|
271 |
+
|
272 |
+
# i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
273 |
+
# j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
274 |
+
|
275 |
+
# # sum over positives
|
276 |
+
# idiag = torch.diag(i_logsm)
|
277 |
+
# jdiag = torch.diag(j_logsm)
|
278 |
+
# loss_i = idiag.sum() / len(idiag)
|
279 |
+
# loss_j = jdiag.sum() / len(jdiag)
|
280 |
+
|
281 |
+
# loss_saliency_inter = - loss_i - loss_j
|
282 |
+
|
283 |
+
# # * intra-vid mode
|
284 |
+
# mask = targets['timestamp_mask']
|
285 |
+
# selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
|
286 |
+
# neg_indices_in = (saliency_scores < selected_scores)
|
287 |
+
# neg_indices_in[batch_indices, pos_indices] = True
|
288 |
+
# mask_invalid = neg_indices_in * mask.bool()
|
289 |
+
|
290 |
+
# sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
|
291 |
+
# sim_in = sim_in + (mask_invalid + 1e-45).log()
|
292 |
+
# logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
|
293 |
+
# logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
|
294 |
+
|
295 |
+
# pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
|
296 |
+
# pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
|
297 |
+
# loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
|
298 |
+
# loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
|
299 |
+
|
300 |
+
# loss_saliency_intra = - loss_in_i - loss_in_j
|
301 |
+
|
302 |
+
# return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
303 |
+
|
304 |
+
def loss_saliency_cls(self, outputs, targets, indices, log=True):
|
305 |
+
"""higher scores for positive clips"""
|
306 |
+
if "saliency_pos_labels" not in targets:
|
307 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
308 |
+
saliency_scores = targets["saliency_scores"]
|
309 |
+
if saliency_scores.sum() == 0:
|
310 |
+
return {"loss_s_inter": 0., "loss_s_intra": 0.}
|
311 |
+
|
312 |
+
# * inter-vid mode
|
313 |
+
vid_mem_proj = outputs["vid_mem_proj"]
|
314 |
+
pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
|
315 |
+
batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
|
316 |
+
|
317 |
+
vid_feats = vid_mem_proj[batch_indices, pos_indices]
|
318 |
+
txt_feats = outputs["txt_mem_proj"].squeeze(1)
|
319 |
+
sim = sim_matrix(vid_feats, txt_feats)
|
320 |
+
|
321 |
+
i_logsm = F.log_softmax(sim / self.temperature, dim=1)
|
322 |
+
j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
|
323 |
+
|
324 |
+
# sum over positives
|
325 |
+
idiag = torch.diag(i_logsm)
|
326 |
+
jdiag = torch.diag(j_logsm)
|
327 |
+
loss_i = idiag.sum() / len(idiag)
|
328 |
+
loss_j = jdiag.sum() / len(jdiag)
|
329 |
+
|
330 |
+
loss_saliency_inter = - loss_i - loss_j
|
331 |
+
|
332 |
+
# * intra-vid mode
|
333 |
+
if 'cls_idx' not in targets.keys(): # eval
|
334 |
+
return {"loss_s_inter": loss_saliency_inter}
|
335 |
+
|
336 |
+
cls_indices = targets['cls_idx'].bool()
|
337 |
+
cls_feats = outputs["cls_mem_proj"].squeeze(1)
|
338 |
+
sim_cls = sim_matrix(vid_feats, cls_feats)
|
339 |
+
|
340 |
+
i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
|
341 |
+
idiag_cls = i_logsm_cls[cls_indices]
|
342 |
+
loss_cls_i = idiag_cls.sum() / len(idiag_cls)
|
343 |
+
|
344 |
+
loss_saliency_intra = - loss_cls_i
|
345 |
+
|
346 |
+
return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
|
347 |
+
|
348 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
349 |
+
loss_map = {
|
350 |
+
"spans": self.loss_spans,
|
351 |
+
"labels": self.loss_labels,
|
352 |
+
"saliency": self.loss_saliency,
|
353 |
+
"saliency_cls": self.loss_saliency_cls,
|
354 |
+
}
|
355 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
356 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
357 |
+
|
358 |
+
def forward(self, outputs, targets, mask_GT=None):
|
359 |
+
""" This performs the loss computation.
|
360 |
+
Parameters:
|
361 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
362 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
363 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
364 |
+
"""
|
365 |
+
indices = None
|
366 |
+
# Compute all the requested losses
|
367 |
+
losses = {}
|
368 |
+
outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
|
369 |
+
count = mask_GT.sum()
|
370 |
+
outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
|
371 |
+
# targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
|
372 |
+
targets['saliency_scores'] = targets['saliency_scores'][0,:count]
|
373 |
+
|
374 |
+
for loss in self.losses:
|
375 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
376 |
+
|
377 |
+
return losses
|
378 |
+
|
379 |
+
class MLP(nn.Module):
|
380 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
381 |
+
|
382 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
383 |
+
super().__init__()
|
384 |
+
self.num_layers = num_layers
|
385 |
+
h = [hidden_dim] * (num_layers - 1)
|
386 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
for i, layer in enumerate(self.layers):
|
390 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
391 |
+
return x
|
392 |
+
|
393 |
+
class Conv(nn.Module):
|
394 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
395 |
+
|
396 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
|
397 |
+
super().__init__()
|
398 |
+
self.num_layers = num_layers
|
399 |
+
h = [hidden_dim] * (num_layers - 1)
|
400 |
+
# self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
401 |
+
self.layers = nn.ModuleList(
|
402 |
+
nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
|
403 |
+
for n, k in zip([input_dim] + h, h + [output_dim]))
|
404 |
+
def forward(self, x):
|
405 |
+
x = x.permute(0,2,1)
|
406 |
+
for i, layer in enumerate(self.layers):
|
407 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
408 |
+
return x.permute(0, 2, 1)
|
409 |
+
|
410 |
+
class LinearLayer(nn.Module):
|
411 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
412 |
+
|
413 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
414 |
+
super(LinearLayer, self).__init__()
|
415 |
+
self.relu = relu
|
416 |
+
self.layer_norm = layer_norm
|
417 |
+
if layer_norm:
|
418 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
419 |
+
layers = [
|
420 |
+
nn.Dropout(dropout),
|
421 |
+
nn.Linear(in_hsz, out_hsz)
|
422 |
+
]
|
423 |
+
self.net = nn.Sequential(*layers)
|
424 |
+
|
425 |
+
def forward(self, x):
|
426 |
+
"""(N, L, D)"""
|
427 |
+
if self.layer_norm:
|
428 |
+
x = self.LayerNorm(x)
|
429 |
+
x = self.net(x)
|
430 |
+
if self.relu:
|
431 |
+
x = F.relu(x, inplace=True)
|
432 |
+
return x # (N, L, D)
|
433 |
+
|
434 |
+
|
435 |
+
def build_model(args):
|
436 |
+
device = torch.device(args.device)
|
437 |
+
|
438 |
+
transformer = build_transformer(args)
|
439 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
440 |
+
|
441 |
+
model = Model(
|
442 |
+
transformer,
|
443 |
+
position_embedding,
|
444 |
+
txt_position_embedding,
|
445 |
+
txt_dim=args.t_feat_dim,
|
446 |
+
vid_dim=args.v_feat_dim,
|
447 |
+
input_dropout=args.input_dropout,
|
448 |
+
span_loss_type=args.span_loss_type,
|
449 |
+
use_txt_pos=args.use_txt_pos,
|
450 |
+
n_input_proj=args.n_input_proj,
|
451 |
+
)
|
452 |
+
|
453 |
+
matcher = build_matcher(args)
|
454 |
+
weight_dict = {"loss_b": args.b_loss_coef,
|
455 |
+
"loss_g": args.g_loss_coef,
|
456 |
+
"loss_f": args.f_loss_coef,
|
457 |
+
"loss_s_intra": args.s_loss_intra_coef,
|
458 |
+
"loss_s_inter": args.s_loss_inter_coef}
|
459 |
+
|
460 |
+
if args.dset_type in ['mr', 'vlp']:
|
461 |
+
if 'tal' not in args.train_path:
|
462 |
+
losses = ['spans', 'labels', 'saliency']
|
463 |
+
else:
|
464 |
+
losses = ['spans', 'labels', 'saliency_cls']
|
465 |
+
elif args.dset_type in ['hl', 'vs']:
|
466 |
+
losses = ['labels', 'saliency']
|
467 |
+
|
468 |
+
criterion = SetCriterion(
|
469 |
+
matcher=matcher,
|
470 |
+
weight_dict=weight_dict, losses=losses,
|
471 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
472 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
473 |
+
saliency_margin=args.saliency_margin,
|
474 |
+
)
|
475 |
+
criterion.to(device)
|
476 |
+
return model, criterion
|
requirements.txt
ADDED
@@ -0,0 +1,355 @@
absl-py==1.2.0
accelerate==0.19.0
aiodns==3.0.0
aiofiles==23.1.0
aiohttp==3.8.3
aiohttp-socks==0.7.1
aiosignal==1.3.1
altair==5.0.1
antiorm==1.2.1
antlr4-python3-runtime==4.9.3
anyio==3.7.0
appdirs==1.4.4
argilla==1.8.0
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.0.7
async-timeout==4.0.2
attrs==22.1.0
Babel==2.12.1
backcall==0.2.0
backoff==2.2.1
beautifulsoup4==4.11.1
bert-score==0.3.13
black==22.3.0
bleach==5.0.1
blis==0.7.9
boto3==1.24.84
botocore==1.27.84
Brotli==1.0.9
brotlipy==0.7.0
cachetools==5.2.0
catalogue==2.0.8
cchardet==2.1.7
certifi==2023.5.7
cffi==1.15.1
chardet==5.1.0
charset-normalizer==2.1.1
cinemagoer==2023.5.1
click==8.1.3
cloudpickle==2.2.0
cmake==3.26.3
coloredlogs==15.0.1
colorlog==6.7.0
commonmark==0.9.1
confection==0.0.4
contourpy==1.0.6
cryptography==37.0.1
cycler==0.11.0
cymem==2.0.7
dataclasses==0.6
dataclasses-json==0.5.7
dataflow==0.9.5
db==0.1.1
db-sqlite3==0.0.1
debugpy==1.6.3
decoder==0.5
decorator==4.4.2
decord==0.6.0
defusedxml==0.7.1
Deprecated==1.2.14
detectron2==0.6
docker==6.0.0
docker-pycreds==0.4.0
easydict==1.9
ego4d==1.2.5
einops==0.6.0
elastic-transport==8.4.0
elasticsearch==8.5.0
entrypoints==0.4
et-xmlfile==1.1.0
exceptiongroup==1.1.1
executing==0.10.0
fairscale==0.4.12
fake-useragent==0.1.14
fastapi==0.98.0
fastjsonschema==2.16.1
ffmpeg==1.4
ffmpeg-python==0.2.0
ffmpy==0.3.0
ffprobe==0.5
filelock==3.7.1
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2023.5.0
ftfy==6.1.1
future==0.18.2
fvcore==0.1.5.post20220512
gdown==4.7.1
gensim==4.2.0
geographiclib==2.0
geopy==2.3.0
gitdb==4.0.10
GitPython==3.1.31
glide-text2im==0.0.0
google-api-core==2.11.1
google-api-python-client==2.95.0
google-auth==2.22.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud==0.34.0
google-cloud-vision==3.4.4
google-measurement-protocol==1.1.0
googleapis-common-protos==1.59.1
googletransx==2.4.2
gradio==3.23.0
greenlet==2.0.2
grpcio==1.56.2
grpcio-status==1.56.2
h11==0.14.0
h5py==3.7.0
httpcore==0.16.3
httplib2==0.22.0
httpx==0.23.3
huggingface-hub==0.15.1
humanfriendly==10.0
hydra-core==1.2.0
idna==3.3
imageio==2.31.0
imageio-ffmpeg==0.4.7
importlib-metadata==4.12.0
importlib-resources==5.9.0
iopath==0.1.9
ipdb==0.13.11
ipykernel==6.15.3
ipython==8.4.0
ipython-genutils==0.2.0
ipywidgets==8.0.2
jedi==0.18.1
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.1.0
jsonlines==3.1.0
jsonschema==4.16.0
jupyter==1.0.0
jupyter_client==7.3.5
jupyter-console==6.4.4
jupyter-core==4.11.1
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.3
kiwisolver==1.4.4
langchain==0.0.191
langcodes==3.3.0
language-evaluation==0.1.0
lazy_loader==0.2
linkify-it-py==2.0.2
lit==16.0.5.post0
lxml==4.9.1
Markdown==3.4.1
markdown-it-py==2.2.0
markdown2==2.4.9
MarkupSafe==2.1.1
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.6.2
matplotlib-inline==0.1.3
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune==2.0.4
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
monotonic==1.6
more-itertools==9.1.0
moviepy==1.0.3
mpmath==1.3.0
msg-parser==1.2.0
msgpack==1.0.4
msgpack-numpy==0.4.8
multidict==6.0.4
murmurhash==1.0.9
mutagen==1.46.0
mypy-extensions==0.4.3
nbclient==0.6.8
nbconvert==7.0.0
nbformat==5.5.0
nest-asyncio==1.5.5
networkx==2.8.7
nh3==0.2.13
nltk==3.7
nms-1d-cpu==0.0.0
nncore==0.3.6
notebook==6.4.12
numexpr==2.8.4
numpy==1.23.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauthlib==3.2.0
olefile==0.46
omegaconf==2.2.3
openai==0.27.7
openapi-schema-pydantic==1.2.4
opencv-python==4.5.4.58
openpyxl==3.1.2
orjson==3.9.1
ortools==9.4.1874
packaging==21.3
pandas==1.5.2
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.10.1
pathtools==0.1.2
pathy==0.10.1
pdfminer.six==20221105
peft==0.3.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
pip==22.2.2
pkgutil_resolve_name==1.3.10
platformdirs==2.5.2
portalocker==2.5.1
preshed==3.0.8
prices==1.1.1
proglog==0.1.10
prometheus-client==0.14.1
prompt-toolkit==3.0.30
proto-plus==1.22.3
protobuf==3.20.1
psutil==5.9.2
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycares==4.2.2
pycipher==0.5.2
pycocoevalcap==1.2
pycocotools==2.0.5
pycparser==2.21
pycryptodomex==3.18.0
pydantic==1.10.8
pydot==1.4.2
pydub==0.25.1
pyfiglet==0.8.post1
Pygments==2.12.0
pynvml==11.5.0
pyOpenSSL==22.0.0
pypandoc==1.11
pyparsing==3.0.9
pyrsistent==0.18.1
PySocks==1.7.1
python-dateutil==2.8.2
python-docx==0.8.11
python-hostlist==1.21
python-magic==0.4.27
python-multipart==0.0.6
python-pptx==0.6.21
python-socks==2.0.3
pytz==2022.7
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==23.2.1
qtconsole==5.3.2
QtPy==2.2.0
regex==2022.7.25
requests==2.28.1
requests-oauthlib==1.3.1
rfc3986==1.5.0
rich==13.0.1
rouge-score==0.1.2
rsa==4.9
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.7
s3transfer==0.6.0
sacremoses==0.0.53
safetensors==0.3.1
schedule==1.1.0
scikit-image==0.21.0
scikit-learn==1.1.2
scipy==1.9.3
seaborn==0.12.0
semantic-version==2.10.0
Send2Trash==1.8.0
sentencepiece==0.1.99
sentry-sdk==1.26.0
setproctitle==1.3.2
setuptools==59.5.0
shortuuid==1.0.11
simplejson==3.17.6
six==1.16.0
smart-open==6.2.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.3.2.post1
spacy==3.5.3
spacy-legacy==3.0.12
spacy-loggers==1.0.4
SQLAlchemy==2.0.15
srsly==2.4.6
stack-data==0.4.0
starlette==0.27.0
svgwrite==1.4.3
sympy==1.12
tabulate==0.8.10
tenacity==8.2.2
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
termcolor==1.1.0
terminado==0.15.0
terminaltables==3.1.10
thinc==8.1.10
threadpoolctl==3.1.0
tifffile==2023.4.12
timm==0.4.12
tinycss2==1.1.1
tokenizers==0.13.2
tomli==2.0.1
toolz==0.12.0
torch==2.0.1
torchaudio==0.9.0a0+33b2469
torchdata==0.6.1
torchtext==0.15.2
torchvision==0.10.0a0
tornado==6.2
tqdm==4.64.1
traitlets==5.3.0
transformers==4.28.1
triton==2.0.0
twint==2.1.21
typer==0.7.0
typing_extensions==4.3.0
typing-inspect==0.9.0
uc-micro-py==1.0.2
unstructured==0.7.1
uritemplate==4.1.1
urllib3==1.26.12
uvicorn==0.22.0
wandb==0.15.4
warmup-scheduler==0.3
wasabi==1.1.2
wavedrom==2.0.3.post3
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.1
websockets==11.0.3
Werkzeug==2.2.1
wheel==0.37.1
widgetsnbextension==4.0.3
wrapt==1.14.1
xlrd==2.0.1
XlsxWriter==3.1.2
yacs==0.1.8
yarl==1.9.2
youtube-dl==2021.12.17
yt-dlp==2023.3.4
zipp==3.8.1
results/omni/opt.json
ADDED
@@ -0,0 +1,111 @@
{
    "dset_type": "vlp",
    "dset_name": "vlp",
    "domain_name": null,
    "model_id": "univtg",
    "exp_id": "omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1",
    "device": 0,
    "gpu_id": 0,
    "debug": false,
    "seed": 2018,
    "local_rank": 0,
    "eval_split_name": "val",
    "data_ratio": 1.0,
    "results_root": "results",
    "num_workers": 8,
    "no_pin_memory": false,
    "bsz": 64,
    "n_epoch": 100,
    "max_es_cnt": 200,
    "lr": 0.0001,
    "lr_drop": 200,
    "lr_gamma": 0.1,
    "lr_warmup": 10.0,
    "wd": 0.0001,
    "grad_clip": 0.1,
    "span_loss_type": "l1",
    "b_loss_coef": 10.0,
    "g_loss_coef": 1.0,
    "eos_coef": 0.1,
    "f_loss_coef": 10.0,
    "s_loss_intra_coef": 0.1,
    "s_loss_inter_coef": 0.1,
    "main_metric": "MR-full-R1@0.3-key",
    "eval_mode": null,
    "eval_bsz": 32,
    "eval_epoch": 5,
    "eval_init": true,
    "save_interval": 5,
    "resume": "/data/home/qinghonglin/univtg/results/vlp-vlp/aio_unified_mini-clip-clip-2023_05_27_00/model_e0003.ckpt",
    "resume_dir": null,
    "resume_all": false,
    "start_epoch": null,
    "no_sort_results": false,
    "max_before_nms": 1000,
    "max_after_nms": 10,
    "conf_thd": 0.0,
    "nms_thd": 0.7,
    "use_cache": -1,
    "max_q_l": 75,
    "max_v_l": 75,
    "clip_length": 2.0,
    "clip_len_list": null,
    "max_windows": 5,
    "add_easy_negative": 1,
    "easy_negative_only": 1,
    "round_multiple": 1,
    "train_path": [
        "data/qvhighlights/metadata/qvhighlights_train.jsonl",
        "data/charades/metadata/charades_train.jsonl",
        "data/ego4d/metadata/nlq_train.jsonl",
        "data/tacos/metadata/train.jsonl",
        "data/anet/metadata/train.jsonl",
        "data/didemo/metadata/train.jsonl"
    ],
    "eval_path": "data/qvhighlights/metadata/qvhighlights_val.jsonl",
    "train_path_list": null,
    "eval_path_list": null,
    "feat_root_list": null,
    "no_norm_vfeat": false,
    "no_norm_tfeat": false,
    "v_feat_dirs": [
        "vid_clip"
    ],
    "t_feat_dir": "txt_clip",
    "v_feat_dim": 512,
    "t_feat_dim": 512,
    "ctx_mode": "video_tef",
    "v_feat_types": "clip",
    "t_feat_type": "clip",
    "position_embedding": "sine",
    "n_input_proj": 2,
    "temperature": 0.07,
    "enc_layers": 4,
    "sub_enc_layers": 2,
    "dec_layers": 2,
    "dim_feedforward": 1024,
    "hidden_dim": 512,
    "input_dropout": 0.5,
    "dropout": 0.0,
    "droppath": 0.1,
    "txt_drop_ratio": 0,
    "use_txt_pos": false,
    "nheads": 8,
    "num_queries": 10,
    "pre_norm": false,
    "set_cost_span": 10,
    "set_cost_giou": 1,
    "set_cost_class": 4,
    "saliency_margin": 0.2,
    "aux_loss": false,
    "max_segment_num": 20,
    "max_frame_num": 200,
    "top_percent": 0.02,
    "qfvs_vid_feature": "fps1",
    "qfvs_txt_feature": "query",
    "qfvs_dense_shot": -1,
    "qfvs_score_ensemble": -1,
    "qfvs_score_gather": -1,
    "qfvs_loss_gather": -1,
    "results_dir": "results/vlp-vlp/omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1-clip-clip-2023_05_31_06"
}
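The opt.json dump above records the exact options the released checkpoint was trained with. A minimal sketch of reading it back for inference, using only the standard library; the load_opt name and the Namespace wrapper are illustrative, not part of the repo:

import json
from argparse import Namespace

def load_opt(path="results/omni/opt.json"):
    # Read the saved options back into an attribute-style object.
    with open(path, "r") as f:
        return Namespace(**json.load(f))

opt = load_opt()
print(opt.v_feat_dim, opt.clip_length)  # 512 and 2.0, per the dump above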
run_on_video/__init__.py
ADDED
@@ -0,0 +1 @@
from run_on_video.video_extractor import vid2clip, txt2clip
run_on_video/clip/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import *
run_on_video/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
run_on_video/clip/clip.py
ADDED
@@ -0,0 +1,195 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, max_valid_length: int = 32) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    max_valid_length:

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text)[:max_valid_length-2] + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
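A short usage sketch of the load/tokenize helpers defined in clip.py above, assuming the run_on_video package is importable and a hypothetical image file; note that tokenize truncates each query to max_valid_length BPE tokens inside the fixed 77-token context:

import torch
from PIL import Image
from run_on_video import clip

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: GPU optional
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

image = preprocess(Image.open("frame.jpg")).unsqueeze(0).to(device)  # hypothetical image path
with torch.no_grad():
    image_feat = model.encode_image(image)  # (1, 512) for ViT-B/32
tokens = clip.tokenize("a boy is drinking.", max_valid_length=32)    # (1, 77) LongTensor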
run_on_video/clip/model.py
ADDED
@@ -0,0 +1,432 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisualTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eos_x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return dict(last_hidden_state=x, pooler_output=eos_x)

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
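Unlike upstream CLIP, encode_text here returns both the per-token sequence and the EOT-pooled vector, which is what the feature extractors further below rely on. A small sketch of consuming that dict (shapes assume the ViT-B/32 configuration; CPU is used only for illustration):

import torch
from run_on_video import clip

model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
text = clip.tokenize(["person opens the fridge"])
with torch.no_grad():
    enc = model.encode_text(text)
token_feats = enc["last_hidden_state"]  # (1, 77, 512): one feature per BPE position
query_feat = enc["pooler_output"]       # (1, 512): feature at the EOT token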
run_on_video/clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
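A round-trip sketch for the BPE tokenizer above; no download is needed because the vocabulary ships as bpe_simple_vocab_16e6.txt.gz next to the module. The decoded string is approximate since the '</w>' end-of-word markers become spaces:

from run_on_video.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A person plays guitar on stage.")
print(ids)              # BPE ids; input is lower-cased and whitespace-normalised first
print(tok.decode(ids))  # roughly "a person plays guitar on stage . "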
run_on_video/clip_feature_extractor.py
ADDED
@@ -0,0 +1,101 @@
import pdb
import torch as th
import math
import numpy as np
import torch
from video_loader import VideoLoader
from torch.utils.data import DataLoader
import argparse
from preprocessing import Preprocessing
import torch.nn.functional as F
from tqdm import tqdm
import os
import sys
from feature_extractor import clip
import argparse

#################################
model_version = "ViT-B/32"
output_feat_size = 512
clip_len = 2
overwrite = True
num_decoding_thread = 4
half_precision = False

@torch.no_grad()
def extractor(vid_path, text, output_file):
    dataset = VideoLoader(
        vid_path,
        framerate=1/clip_len,
        size=224,
        centercrop=True,
        overwrite=overwrite,
        model_version=model_version
    )
    n_dataset = len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_decoding_thread,
        sampler=sampler if n_dataset > 10 else None,
    )
    preprocess = Preprocessing()
    model, _ = clip.load(model_version, device="cuda", jit=False)

    encoded_texts = clip.tokenize(text).to('cuda')
    text_feature = model.encode_text(encoded_texts)['last_hidden_state']
    valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
    text_feature = text_feature[0, :valid_lengths].cpu().numpy()
    np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)

    totatl_num_frames = 0
    with th.no_grad():
        for k, data in enumerate(tqdm(loader)):
            input_file = data['input'][0]
            if os.path.isfile(output_file):
                # print(f'Video {input_file} already processed.')
                continue
            elif not os.path.isfile(input_file):
                print(f'{input_file}, does not exist.\n')
            elif len(data['video'].shape) > 4:
                video = data['video'].squeeze(0)
                if len(video.shape) == 4:
                    video = preprocess(video)
                    n_chunk = len(video)
                    vid_features = th.cuda.FloatTensor(
                        n_chunk, output_feat_size).fill_(0)
                    n_iter = int(math.ceil(n_chunk))
                    for i in range(n_iter):
                        min_ind = i
                        max_ind = (i + 1)
                        video_batch = video[min_ind:max_ind].cuda()
                        batch_features = model.encode_image(video_batch)
                        vid_features[min_ind:max_ind] = batch_features
                    vid_features = vid_features.cpu().numpy()
                    if half_precision:
                        vid_features = vid_features.astype('float16')
                    totatl_num_frames += vid_features.shape[0]
                    # safeguard output path before saving
                    dirname = os.path.dirname(output_file)
                    if not os.path.exists(dirname):
                        print(f"Output directory {dirname} does not exists, creating...")
                        os.makedirs(dirname)
                    np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
            else:
                print(f'{input_file}, failed at ffprobe.\n')

    print(f"Total number of frames: {totatl_num_frames}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
    parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
    parser.add_argument('--save_dir', type=str, default='./tmp')
    args = parser.parse_args()

    query = ' '.join(args.text)

    print(args.vid_path)
    print(query)
    extractor(args.vid_path, [query], args.save_dir)
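Both extractors in this folder save their outputs as .npz archives (txt.npz and vid.npz) under --save_dir. A sketch of reading those files back, assuming the script was run with the default ./tmp:

import numpy as np

vid = np.load("tmp/vid.npz")["features"]  # (n_clips, 512), one row per 2-second clip
txt = np.load("tmp/txt.npz")["features"]  # (n_tokens, 512), one row per BPE token
print(vid.shape, txt.shape)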
run_on_video/data_utils.py
ADDED
@@ -0,0 +1,170 @@
import torch
import os
import numpy as np
import ffmpeg
import math
from run_on_video import clip


class ClipFeatureExtractor:
    def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
        self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop)
        print("Loading CLIP models")
        self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
        self.tokenizer = clip.tokenize
        self.video_preprocessor = Preprocessing()
        self.device = device

    @torch.no_grad()
    def encode_video(self, video_path: str, bsz=60):
        video_frames = self.video_loader.read_video_from_file(video_path)  # (T, H, W, 3)
        video_frames = self.video_preprocessor(video_frames)
        n_frames = len(video_frames)
        n_batch = int(math.ceil(n_frames / bsz))
        video_features = []
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i+1) * bsz
            _video_frames = video_frames[st_idx:ed_idx].to(self.device)
            _video_features = self.clip_extractor.encode_image(_video_frames)
            video_features.append(_video_features)
        video_features = torch.cat(video_features, dim=0)
        return video_features  # (T=#frames, d) torch tensor

    @torch.no_grad()
    def encode_text(self, text_list, bsz=60):
        n_text = len(text_list)
        n_batch = int(math.ceil(n_text / bsz))
        text_features = []
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i+1) * bsz
            encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device)
            output = self.clip_extractor.encode_text(encoded_texts)
            valid_lengths = (encoded_texts != 0).sum(1).tolist()
            batch_last_hidden_states = output["last_hidden_state"]
            for j, valid_len in enumerate(valid_lengths):
                text_features.append(batch_last_hidden_states[j, :valid_len])
        return text_features  # List([L_j, d]) torch tensor


def convert_to_float(frac_str):
    try:
        return float(frac_str)
    except ValueError:
        try:
            num, denom = frac_str.split('/')
        except ValueError:
            return None
        try:
            leading, num = num.split(' ')
        except ValueError:
            return float(num) / float(denom)
        if float(leading) < 0:
            sign_mult = -1
        else:
            sign_mult = 1
        return float(leading) + sign_mult * (float(num) / float(denom))


class Normalize(object):

    def __init__(self, mean, std):
        self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = torch.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        return tensor


class Preprocessing(object):

    def __init__(self):
        self.norm = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def __call__(self, tensor):
        tensor = tensor / 255.0
        tensor = self.norm(tensor)
        return tensor


class VideoLoader:
    """Pytorch video loader.
    Copied and modified from:
    https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py
    """
    def __init__(
            self,
            framerate=1/2,
            size=224,
            centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate

    def _get_video_info(self, video_path):
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
        try:
            frames_length = int(video_stream['nb_frames'])
            duration = float(video_stream['duration'])
        except Exception:
            frames_length, duration = -1, -1
        info = {"duration": duration, "frames_length": frames_length,
                "fps": fps, "height": height, "width": width}
        return info

    def _get_output_dim(self, h, w):
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def read_video_from_file(self, video_path):
        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            print('ffprobe failed at: {}'.format(video_path))
            return {'video': torch.zeros(1), 'input': video_path,
                    'info': {}}
        height, width = self._get_output_dim(h, w)
        try:
            duration = info["duration"]
            fps = self.framerate
            if duration > 0 and duration < 1/fps+0.1:
                fps = 2/max(int(duration), 1)
                print(duration, fps)
        except Exception:
            fps = self.framerate
        cmd = (
            ffmpeg
            .input(video_path)
            .filter('fps', fps=fps)
            .filter('scale', width, height)
        )
        if self.centercrop:
            x = int((width - self.size) / 2.0)
            y = int((height - self.size) / 2.0)
            cmd = cmd.crop(x, y, self.size, self.size)
        out, _ = (
            cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True, quiet=True)
        )
        if self.centercrop and isinstance(self.size, int):
            height, width = self.size, self.size
        video = np.frombuffer(out, np.uint8).reshape(
            [-1, height, width, 3])
        video = torch.from_numpy(video.astype('float32'))
        video = video.permute(0, 3, 1, 2)
        return video
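A sketch of driving ClipFeatureExtractor directly; it assumes a CUDA device and that ffmpeg is installed, since VideoLoader decodes through ffmpeg-python. The example video ships with this Space:

from run_on_video.data_utils import ClipFeatureExtractor

extractor = ClipFeatureExtractor(framerate=1/2, size=224, centercrop=True,
                                 model_name_or_path="ViT-B/32", device="cuda")
video_feats = extractor.encode_video("examples/charades.mp4")  # (T, 512), one row per 2 s clip
text_feats = extractor.encode_text(["a boy is drinking."])     # list of (L_j, 512) tensors
print(video_feats.shape, text_feats[0].shape)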
run_on_video/preprocessing.py
ADDED
@@ -0,0 +1,25 @@
import torch as th


class Normalize(object):

    def __init__(self, mean, std):
        self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = th.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        return tensor


class Preprocessing(object):

    def __init__(self):
        self.norm = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def __call__(self, tensor):
        tensor = tensor / 255.0
        tensor = self.norm(tensor)
        return tensor
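The same Normalize/Preprocessing pair is duplicated inside data_utils.py; the constants are CLIP's RGB mean and std. A tiny sketch of what it does to a batch of decoded frames (the random tensor is a stand-in for real frames):

import torch as th
from run_on_video.preprocessing import Preprocessing

frames = th.randint(0, 256, (4, 3, 224, 224)).float()  # stand-in for 4 decoded RGB frames
normed = Preprocessing()(frames)                        # scale to [0, 1], then mean/std normalise
print(normed.shape, float(normed.mean()))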
run_on_video/text_extractor.py
ADDED
@@ -0,0 +1,36 @@
import pdb
import sys
import json
import torch
import numpy as np
from run_on_video.data_utils import ClipFeatureExtractor
import torch.nn.functional as F
import tqdm
import os

query_list = []
qid_list = []
dataset = 'charades'
split = 'test'

save_dir = f''

with open(f"data/{dataset}/metadata/{dataset}_{split}.jsonl", 'r') as f:
    while True:
        line = f.readline()
        if not line:
            break
        js = json.loads(line)
        query_list.append(js['query'])
        qid_list.append(str(js['qid']))

# clip
feature_extractor = ClipFeatureExtractor(
    framerate=1 / 2, size=224, centercrop=True,
    model_name_or_path="ViT-B/32", device='cuda'
)
# pdb.set_trace()
query_feats = feature_extractor.encode_text(query_list)

for i in tqdm.tqdm(range(len(query_feats))):
    np.savez(save_dir + '/' + qid_list[i], last_hidden_state=query_feats[i].cpu().numpy())
run_on_video/video_extractor.py
ADDED
@@ -0,0 +1,94 @@
import pdb
import torch as th
import math
import numpy as np
import torch
from run_on_video.video_loader import VideoLoader
from torch.utils.data import DataLoader
import argparse
from run_on_video.preprocessing import Preprocessing
import torch.nn.functional as F
from tqdm import tqdm
import os
import sys
from run_on_video import clip
import argparse

#################################
@torch.no_grad()
def vid2clip(model, vid_path, output_file,
             model_version="ViT-B/32", output_feat_size=512,
             clip_len=2, overwrite=True, num_decoding_thread=4, half_precision=False):
    dataset = VideoLoader(
        vid_path,
        framerate=1/clip_len,
        size=224,
        centercrop=True,
        overwrite=overwrite,
        model_version=model_version
    )
    n_dataset = len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_decoding_thread,
        sampler=None,
    )
    preprocess = Preprocessing()
    device_id = next(model.parameters()).device

    totatl_num_frames = 0
    with th.no_grad():
        for k, data in enumerate(tqdm(loader)):
            input_file = data['input'][0]
            if os.path.isfile(output_file):
                # print(f'Video {input_file} already processed.')
                continue
            elif not os.path.isfile(input_file):
                print(f'{input_file}, does not exist.\n')
            elif len(data['video'].shape) > 4:
                video = data['video'].squeeze(0)
                if len(video.shape) == 4:
                    video = preprocess(video)
                    n_chunk = len(video)
                    vid_features = th.cuda.FloatTensor(
                        n_chunk, output_feat_size).fill_(0)
                    n_iter = int(math.ceil(n_chunk))
                    for i in range(n_iter):
                        min_ind = i
                        max_ind = (i + 1)
                        video_batch = video[min_ind:max_ind].to(device_id)
                        batch_features = model.encode_image(video_batch)
                        vid_features[min_ind:max_ind] = batch_features
                    vid_features = vid_features.cpu().numpy()
                    if half_precision:
                        vid_features = vid_features.astype('float16')
                    totatl_num_frames += vid_features.shape[0]
                    # safeguard output path before saving
                    dirname = os.path.dirname(output_file)
                    if not os.path.exists(dirname):
                        print(f"Output directory {dirname} does not exists, creating...")
                        os.makedirs(dirname)
                    np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
            else:
                print(f'{input_file}, failed at ffprobe.\n')
    print(f"Total number of frames: {totatl_num_frames}")
    return vid_features

def txt2clip(model, text, output_file):
    device_id = next(model.parameters()).device
    encoded_texts = clip.tokenize(text).to(device_id)
    text_feature = model.encode_text(encoded_texts)['last_hidden_state']
    valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
    text_feature = text_feature[0, :valid_lengths].detach().cpu().numpy()

    np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
    return text_feature

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
    parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
    parser.add_argument('--save_dir', type=str, default='./tmp')
    args = parser.parse_args()
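vid2clip and txt2clip are the entry points the Gradio demo reaches through run_on_video/__init__.py. A hedged sketch of that flow, assuming a CUDA device (vid2clip allocates its buffer with th.cuda.FloatTensor) and an existing ./tmp output directory:

from run_on_video import clip, vid2clip, txt2clip

model, _ = clip.load("ViT-B/32", device="cuda", jit=False)
vid_feats = vid2clip(model, "examples/charades.mp4", "./tmp")  # writes ./tmp/vid.npz
txt_feats = txt2clip(model, "a boy is drinking.", "./tmp")     # writes ./tmp/txt.npz
print(vid_feats.shape, txt_feats.shape)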