Spaces: Running

Mathis Petrovich committed • Commit 83f52e6 • 1 Parent(s): efba2f3

First commit
Browse files
- amass-annotations/amass_to_babel.json +0 -0
- amass-annotations/humanml3d.json +0 -0
- app.py +265 -0
- load.py +53 -0
- model.py +129 -0
amass-annotations/amass_to_babel.json
ADDED
The diff for this file is too large to render.
See raw diff
amass-annotations/humanml3d.json
ADDED
The diff for this file is too large to render.
See raw diff
app.py
ADDED
@@ -0,0 +1,265 @@
from functools import partial
import os

import torch
import numpy as np
import gradio as gr
import gdown

from load import load_model, load_json
from load import load_unit_motion_embs_splits, load_keyids_splits


EXAMPLES = [
    "A person is walking in a circle",
    "A person is jumping rope",
    "Someone is doing a backflip",
    "A person is doing a moonwalk",
    "A person walks forward and then turns back",
    "Picking up an object",
    "A person is swimming in the sea",
    "A human is squatting",
    "Someone is jumping with one foot",
    "A person is chopping vegetables",
    "Someone walks backward",
    "Somebody is ascending a staircase",
    "A person is sitting down",
    "A person is taking the stairs",
    "Someone is doing jumping jacks",
    "The person walked forward and is picking up his toolbox",
    "The person angrily punching the air."
]

# Show closest text in the training


# css to make videos look nice
CSS = """
video {
    position: relative;
    margin: 0;
    box-shadow: var(--block-shadow);
    border-width: var(--block-border-width);
    border-color: var(--block-border-color);
    border-radius: var(--block-radius);
    background: var(--block-background-fill);
    width: 100%;
    line-height: var(--line-sm);
}
"""


def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
    # Don't show the mirrored version of HumanML3D
    if "M" in keyid:
        return None

    dico = h3d_index[keyid]
    path = dico["path"]

    # HumanAct12 motions are not rendered online
    # so we skip them for now
    if "humanact12" in path:
        return None

    # This motion is not rendered in BABEL
    # so we skip it for now
    if path not in amass_to_babel:
        return None

    babel_id = amass_to_babel[path].zfill(6)
    url = f"https://babel-renders.s3.eu-central-1.amazonaws.com/{babel_id}.mp4"

    # For the demo, we retrieve from the first annotation only
    ann = dico["annotations"][0]
    start = ann["start"]
    end = ann["end"]
    text = ann["text"]

    data = {
        "url": url,
        "start": start,
        "end": end,
        "text": text,
        "keyid": keyid,
        "babel_id": babel_id
    }

    return data

def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8):
    unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
    keyids = np.concatenate([all_keyids[s] for s in splits])

    scores = model.compute_scores(text, unit_embs=unit_motion_embs)

    sorted_idxs = np.argsort(-scores)
    best_keyids = keyids[sorted_idxs]
    best_scores = scores[sorted_idxs]

    datas = []
    for keyid, score in zip(best_keyids, best_scores):
        if len(datas) == nmax:
            break

        data = keyid_to_url(keyid)
        if data is None:
            continue
        data["score"] = round(float(score), 2)
        datas.append(data)
    return datas

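# Note on scoring: compute_scores returns one similarity per motion of the selected
# splits, already mapped to [0, 1]; np.argsort(-scores) ranks the key ids from best to
# worst, and key ids without a rendered BABEL video are skipped by the loop above until
# nmax results have been collected.
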
# HTML component
def get_video_html(url, video_id, start=None, end=None, score=None, width=350, height=350):
    trim = ""
    if start is not None:
        if end is not None:
            trim = f"#t={start},{end}"
        else:
            trim = f"#t={start}"

    score_t = ""
    if score is not None:
        score_t = f'title="Score = {score}"'

    video_html = f'''
<video preload="auto" muted playsinline onpause="this.load()"
autoplay loop disablepictureinpicture id="{video_id}" width="{width}" height="{height}" {score_t}>
  <source src="{url}{trim}" type="video/mp4">
  Your browser does not support the video tag.
</video>
'''
    return video_html

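# Note on trimming: the "#t={start},{end}" suffix appended to the source URL is a
# standard temporal media fragment, so the browser plays only the annotated segment of
# the BABEL render, e.g. (babel_id and times made up for illustration):
#   https://babel-renders.s3.eu-central-1.amazonaws.com/001234.mp4#t=1.5,4.0
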
def retrive_component(retrieve_function, text, splits, nvids, n_component=16):
    # cannot produce more than n_component videos
    nvids = min(nvids, n_component)
    if not splits:
        return [None for _ in range(n_component)]

    splits_l = [x.lower() for x in splits]
    datas = retrieve_function(text, splits=splits_l, nmax=nvids)
    htmls = [
        get_video_html(
            url["url"], idx, start=url["start"],
            end=url["end"], score=url["score"]
        )
        for idx, url in enumerate(datas)
    ]
    # always return n_component outputs:
    # pad with empty blocks when fewer videos are requested
    htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
    return htmls

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # LOADING
    model = load_model(device)
    splits = ["train", "val", "test"]
    all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
    all_keyids = load_keyids_splits(splits)

    h3d_index = load_json("amass-annotations/humanml3d.json")
    amass_to_babel = load_json("amass-annotations/amass_to_babel.json")

    keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
    retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)

    # DEMO
    theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
    retrive_and_show = partial(retrive_component, retrieve_function)

    default_text = "A person is "

    with gr.Blocks(css=CSS, theme=theme) as demo:
        title = "<h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis </h1>"
        gr.Markdown(title)

        authors = """
        <h2 style='text-align: center'>
        <a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a>
        <a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a>
        <a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>Gül Varol</nobr></a>
        </h2>
        """
        gr.Markdown(authors)

        conf = """
        <h2 style='text-align: center'>
        <nobr>arXiv 2023</nobr>
        </h2>
        """
        gr.Markdown(conf)

        videos = []

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
                    text = gr.Textbox(placeholder="Type in natural language, the motion to retrieve",
                                      show_label=True, label="Text prompt", value=default_text)
                with gr.Column(scale=1):
                    btn = gr.Button("Retrieve", variant='primary')
                    clear = gr.Button("Clear", variant='secondary')

        with gr.Row():
            with gr.Column(scale=1):
                splits = gr.Dropdown(["Train", "Val", "Test"],
                                     value=["Test"], multiselect=True, label="Splits",
                                     info="HumanML3D data used for the motion database")
            with gr.Column(scale=1):
                nvideo_slider = gr.Slider(minimum=4, maximum=16, step=4, value=8, label="Number of videos")
            with gr.Column(scale=2):
                examples = gr.Examples(examples=EXAMPLES, inputs=text, examples_per_page=15)

        i = -1
        # should indent
        for _ in range(4):
            with gr.Row():
                for _ in range(4):
                    i += 1
                    with gr.Column():
                        video = gr.HTML()
                        videos.append(video)

        def check_error(splits):
            if not splits:
                raise gr.Error("At least one split should be selected!")
            return splits

        btn.click(fn=retrive_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
            fn=check_error, inputs=splits
        )

        text.submit(fn=retrive_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
            fn=check_error, inputs=splits
        )

        def keep_test(splits):
            if len(splits) == 0:
                return ["Test"]
            return splits

        def clear_videos():
            return [None for x in range(16)] + [default_text]

        clear.click(fn=clear_videos, outputs=videos + [text])

    demo.launch()

def prepare():
    if not os.path.exists("data"):
        gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08", use_cookies=False)


if __name__ == "__main__":
    prepare()
    main()

# new
# A person is walking slowly
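The retrieval path above can also be exercised without the Gradio UI. The snippet below is a minimal sketch, not part of the commit: it assumes the data/ folder has already been fetched by prepare() and reuses the loaders from load.py.

# Headless retrieval sketch (assumes data/ has been downloaded).
import torch
from load import load_model, load_unit_motion_embs_splits, load_keyids_splits

device = torch.device("cpu")
model = load_model(device)
unit_embs = load_unit_motion_embs_splits(["test"], device)
keyids = load_keyids_splits(["test"])

# One score in [0, 1] per motion of the test split; the best match is the argmax.
scores = model.compute_scores("A person is walking in a circle", unit_embs=unit_embs["test"])
print(keyids["test"][scores.argmax()], float(scores.max()))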
load.py
ADDED
@@ -0,0 +1,53 @@
import os
import orjson
import torch
import numpy as np
from model import TMR_textencoder

EMBS = "data/unit_motion_embs"


def load_json(path):
    with open(path, "rb") as ff:
        return orjson.loads(ff.read())


def load_keyids(split):
    path = os.path.join(EMBS, f"{split}.keyids")
    with open(path) as ff:
        keyids = np.array([x.strip() for x in ff.readlines()])
    return keyids


def load_keyids_splits(splits):
    return {
        split: load_keyids(split)
        for split in splits
    }


def load_unit_motion_embs(split, device):
    path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    tensor = torch.from_numpy(np.load(path)).to(device)
    return tensor


def load_unit_motion_embs_splits(splits, device):
    return {
        split: load_unit_motion_embs(split, device)
        for split in splits
    }


def load_model(device):
    text_params = {
        'latent_dim': 256, 'ff_size': 1024, 'num_layers': 6, 'num_heads': 4,
        'activation': 'gelu', 'modelpath': 'distilbert-base-uncased'
    }
    model = TMR_textencoder(**text_params)
    state_dict = torch.load("data/textencoder.pt", map_location=device)
    # load values for the transformer only
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model
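These loaders expect a local data/ folder (fetched by prepare() in app.py) containing textencoder.pt and, under data/unit_motion_embs, one <split>.keyids text file plus one <split>_motion_embs_unit.npy array per split. A quick sanity check of those files could look like the sketch below (not part of the commit; the printed shape depends on the downloaded data):

# Sanity-check the pre-computed unit motion embeddings for one split.
import numpy as np

embs = np.load("data/unit_motion_embs/test_motion_embs_unit.npy")
with open("data/unit_motion_embs/test.keyids") as ff:
    keyids = [x.strip() for x in ff.readlines()]

assert len(keyids) == len(embs)              # one embedding per key id
norms = np.linalg.norm(embs, axis=1)
print(embs.shape, norms.min(), norms.max())  # "unit" embeddings: norms close to 1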
model.py
ADDED
@@ -0,0 +1,129 @@
from typing import List, Union
import torch.nn as nn
import os

import torch
import numpy as np
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from transformers import logging
from torch.nn.functional import normalize


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        return x + self.pe[:x.shape[0], :]


class TMR_textencoder(nn.Module):
    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
        super().__init__()

        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)
        # Then configure the model
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text-outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim)
        )

        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))
        self.sequence_pos_encoding = PositionalEncoding(latent_dim)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
                                                             nhead=num_heads,
                                                             dim_feedforward=ff_size,
                                                             dropout=0.0,
                                                             activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False
                              ) -> Union[Tensor, tuple[Tensor, Tensor]]:
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device))
        if not return_mask:
            return output.last_hidden_state
        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)

    def forward(self, texts: List[str]) -> Tensor:
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)

        x = self.projection(text_encoded)
        bs, nframes, _ = x.shape
        # bs, nframes, totjoints, nfeats = x.shape
        # Switch sequence and batch_size because the input of
        # Pytorch Transformer is [Sequence, Batch size, ...]
        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]

        mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
        logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)

        # adding the distribution tokens for all sequences
        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)

        # create a bigger mask, to allow attend to mu and logvar
        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # only mu for inference
        mu = final[0]
        return mu

    # compute score for retrieval
    def compute_scores(self, texts, unit_embs=None, embs=None):
        # not both empty
        assert not (unit_embs is None and embs is None)
        # not both filled
        assert not (unit_embs is not None and embs is not None)

        output_str = False
        # if one input, squeeze the output
        if isinstance(texts, str):
            texts = [texts]
            output_str = True

        # compute unit_embs from embs if not given
        if embs is not None:
            unit_embs = normalize(embs)

        with torch.no_grad():
            latent_unit_texts = normalize(self(texts))
            # compute cosine similarity between 0 and 1
            scores = (unit_embs @ latent_unit_texts.T).T/2 + 0.5
        scores = scores.cpu().numpy()

        if output_str:
            scores = scores[0]

        return scores
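TMR_textencoder can be scored against any set of unit-normalized motion embeddings, which makes it easy to smoke-test on its own. The sketch below is illustrative only: it uses randomly initialized weights (no data/textencoder.pt) and random placeholder embeddings instead of the real motion database, so the scores are meaningless beyond their shape and [0, 1] range; it also downloads the distilbert-base-uncased weights on first use.

# Stand-alone smoke test of compute_scores with placeholder embeddings.
import torch
from torch.nn.functional import normalize
from model import TMR_textencoder

model = TMR_textencoder(
    modelpath="distilbert-base-uncased", latent_dim=256,
    ff_size=1024, num_layers=6, num_heads=4, activation="gelu",
).eval()

unit_embs = normalize(torch.randn(100, 256))   # stand-in for real motion embeddings
scores = model.compute_scores(["A person is walking in a circle"], unit_embs=unit_embs)
print(scores.shape)                            # (1, 100): cosine similarities mapped to [0, 1]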