Spaces: 4RiZ4 (Build error)

Commit 9e2bdc0 (0 parents), committed by 4RiZ4 and Mathux

Duplicate from Mathux/TMR

Co-authored-by: Mathis Petrovich <Mathux@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: TMR
+ emoji: 🐨
+ colorFrom: yellow
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ duplicated_from: Mathux/TMR
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
amass-annotations/amass_to_babel.json ADDED
The diff for this file is too large to render. See raw diff
 
amass-annotations/humanml3d.json ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,313 @@
+ from functools import partial
+ import os
+
+ import torch
+ import numpy as np
+ import gradio as gr
+ import gdown
+
+ from load import load_model, load_json
+ from load import load_unit_motion_embs_splits, load_keyids_splits
+
+
+ WEBSITE = """
+ <div class="embed_hidden">
+ <h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis</h1>
+
+ <h2 style='text-align: center'>
+ <a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a> &emsp;
+ <a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a> &emsp;
+ <a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>G&uuml;l Varol</nobr></a>
+ </h2>
+
+ <h2 style='text-align: center'>
+ <nobr>arXiv 2023</nobr>
+ </h2>
+
+ <h3 style="text-align:center;">
+ <a target="_blank" href="https://arxiv.org/abs/2305.00976"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
+ <a target="_blank" href="https://github.com/Mathux/TMR"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
+ <a target="_blank" href="https://mathis.petrovich.fr/tmr"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
+ <a target="_blank" href="https://mathis.petrovich.fr/tmr/tmr.bib"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
+ </h3>
+
+ <h3> Description </h3>
+ <p>
+ This space illustrates <a href='https://mathis.petrovich.fr/tmr/' target='_blank'><b>TMR</b></a>, a method for text-to-motion retrieval. Given a gallery of 3D human motions (which can be unseen during training) and a text query, the goal is to search for motions that are close to the text query.
+ </p>
+ </div>
+ """
+
+ EXAMPLES = [
+     "A person is walking slowly",
+     "A person is walking in a circle",
+     "A person is jumping rope",
+     "Someone is doing a backflip",
+     "A person is doing a moonwalk",
+     "A person walks forward and then turns back",
+     "Picking up an object",
+     "A person is swimming in the sea",
+     "A human is squatting",
+     "Someone is jumping with one foot",
+     "A person is chopping vegetables",
+     "Someone walks backward",
+     "Somebody is ascending a staircase",
+     "A person is sitting down",
+     "A person is taking the stairs",
+     "Someone is doing jumping jacks",
+     "The person walked forward and is picking up his toolbox",
+     "The person angrily punching the air"
+ ]
+
+ # Show closest text in the training
+
+
+ # CSS to make the videos look nice
+ # var(--block-border-color);
+ CSS = """
+ .retrieved_video {
+     position: relative;
+     margin: 0;
+     box-shadow: var(--block-shadow);
+     border-width: var(--block-border-width);
+     border-color: #000000;
+     border-radius: var(--block-radius);
+     background: var(--block-background-fill);
+     width: 100%;
+     line-height: var(--line-sm);
+ }
+
+ .contour_video {
+     display: flex;
+     flex-direction: column;
+     justify-content: center;
+     align-items: center;
+     z-index: var(--layer-5);
+     border-radius: var(--block-radius);
+     background: var(--background-fill-primary);
+     padding: 0 var(--size-6);
+     max-height: var(--size-screen-h);
+     overflow: hidden;
+ }
+ """
+
+
+ DEFAULT_TEXT = "A person is "
+
+ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
+     # Don't show the mirrored version of HumanML3D
+     if "M" in keyid:
+         return None
+
+     dico = h3d_index[keyid]
+     path = dico["path"]
+
+     # HumanAct12 motions are not rendered online,
+     # so we skip them for now
+     if "humanact12" in path:
+         return None
+
+     # This motion is not rendered in BABEL,
+     # so we skip it for now
+     if path not in amass_to_babel:
+         return None
+
+     babel_id = amass_to_babel[path].zfill(6)
+     url = f"https://babel-renders.s3.eu-central-1.amazonaws.com/{babel_id}.mp4"
+
+     # For the demo, we retrieve from the first annotation only
+     ann = dico["annotations"][0]
+     start = ann["start"]
+     end = ann["end"]
+     text = ann["text"]
+
+     data = {
+         "url": url,
+         "start": start,
+         "end": end,
+         "text": text,
+         "keyid": keyid,
+         "babel_id": babel_id,
+         "path": path
+     }
+
+     return data
+
+
+ def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8):
+     unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
+     keyids = np.concatenate([all_keyids[s] for s in splits])
+
+     scores = model.compute_scores(text, unit_embs=unit_motion_embs)
+
+     sorted_idxs = np.argsort(-scores)
+     best_keyids = keyids[sorted_idxs]
+     best_scores = scores[sorted_idxs]
+
+     datas = []
+     for keyid, score in zip(best_keyids, best_scores):
+         if len(datas) == nmax:
+             break
+
+         data = keyid_to_url(keyid)
+         if data is None:
+             continue
+         data["score"] = round(float(score), 2)
+         datas.append(data)
+     return datas
+
+
+ # HTML component
+ def get_video_html(data, video_id, width=700, height=700):
+     url = data["url"]
+     start = data["start"]
+     end = data["end"]
+     score = data["score"]
+     text = data["text"]
+     keyid = data["keyid"]
+     babel_id = data["babel_id"]
+     path = data["path"]
+
+     trim = f"#t={start},{end}"
+     title = f'''Score = {score}
+
+ Corresponding text: {text}
+
+ HumanML3D keyid: {keyid}
+
+ BABEL keyid: {babel_id}
+
+ AMASS path: {path}'''
+
+     # class="wrap default svelte-gjihhp hide"
+     # <div class="contour_video" style="position: absolute; padding: 10px;">
+     # width="{width}" height="{height}"
+     video_html = f'''
+ <video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
+ autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
+ <source src="{url}{trim}" type="video/mp4">
+ Your browser does not support the video tag.
+ </video>
+ '''
+     return video_html
+
+
+ def retrieve_component(retrieve_function, text, splits_choice, nvids, n_component=24):
+     if text == DEFAULT_TEXT or text == "" or text is None:
+         return [None for _ in range(n_component)]
+
+     # cannot produce more than n_component videos
+     nvids = min(nvids, n_component)
+
+     if "Unseen" in splits_choice:
+         splits = ["test"]
+     else:
+         splits = ["train", "val", "test"]
+
+     datas = retrieve_function(text, splits=splits, nmax=nvids)
+     htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
+     # always return exactly n_component outputs:
+     # pad with dummy (None) blocks when fewer videos are requested
+     htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
+     return htmls
+
+
+ if not os.path.exists("data"):
+     gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
+                           use_cookies=False)
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # LOADING
+ model = load_model(device)
+ splits = ["train", "val", "test"]
+ all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
+ all_keyids = load_keyids_splits(splits)
+
+ h3d_index = load_json("amass-annotations/humanml3d.json")
+ amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
+
+ keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
+ retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
+
+ # DEMO
+ theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
+ retrieve_and_show = partial(retrieve_component, retrieve_function)
+
+ with gr.Blocks(css=CSS, theme=theme) as demo:
+     gr.Markdown(WEBSITE)
+     videos = []
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             with gr.Column(scale=2):
+                 text = gr.Textbox(placeholder="Type the motion you want to search with a sentence",
+                                   show_label=True, label="Text prompt", value=DEFAULT_TEXT)
+             with gr.Column(scale=1):
+                 btn = gr.Button("Retrieve", variant='primary')
+                 clear = gr.Button("Clear", variant='secondary')
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     splits_choice = gr.Radio(["All motions", "Unseen motions"], label="Gallery of motion",
+                                              value="All motions",
+                                              info="The motion gallery comes from HumanML3D")
+
+                 with gr.Column(scale=1):
+                     # nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
+                     nvideo_slider = gr.Radio([4, 8, 12, 16, 24], label="Videos",
+                                              value=8,
+                                              info="Number of videos to display")
+
+         with gr.Column(scale=2):
+             def retrieve_example(text, splits_choice, nvideo_slider):
+                 return retrieve_and_show(text, splits_choice, nvideo_slider)
+
+             examples = gr.Examples(examples=[[x, None, None] for x in EXAMPLES],
+                                    inputs=[text, splits_choice, nvideo_slider],
+                                    examples_per_page=20,
+                                    run_on_click=False, cache_examples=False,
+                                    fn=retrieve_example, outputs=[])
+
+     i = -1
+     # grid of 6 rows x 4 columns of video placeholders
+     for _ in range(6):
+         with gr.Row():
+             for _ in range(4):
+                 i += 1
+                 video = gr.HTML()
+                 videos.append(video)
+
+     # connect the examples to the output
+     # a bit hacky
+     examples.outputs = videos
+
+     def load_example(example_id):
+         processed_example = examples.non_none_processed_examples[example_id]
+         return gr.utils.resolve_singleton(processed_example)
+
+     examples.dataset.click(
+         load_example,
+         inputs=[examples.dataset],
+         outputs=examples.inputs_with_examples,  # type: ignore
+         show_progress=False,
+         postprocess=False,
+         queue=False,
+     ).then(
+         fn=retrieve_example,
+         inputs=examples.inputs,
+         outputs=videos
+     )
+
+     btn.click(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
+     text.submit(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
+     splits_choice.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
+     nvideo_slider.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
+
+     def clear_videos():
+         return [None for _ in range(24)] + [DEFAULT_TEXT]
+
+     clear.click(fn=clear_videos, outputs=videos + [text])
+
+ demo.launch()
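
The description in app.py above summarizes the setup: the HumanML3D motions are embedded once into unit vectors, and a text query is scored against the whole gallery. As a minimal sketch (not part of this commit, and assuming the data/ folder linked above has already been downloaded), the same retrieval step can be run without the Gradio interface:

import numpy as np
import torch

from load import load_model, load_unit_motion_embs_splits, load_keyids_splits

device = torch.device("cpu")
model = load_model(device)

# "test" corresponds to the "Unseen motions" gallery in the demo
splits = ["test"]
all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
all_keyids = load_keyids_splits(splits)
unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
keyids = np.concatenate([all_keyids[s] for s in splits])

# score one text query against every motion embedding, then rank
scores = model.compute_scores("A person is walking in a circle", unit_embs=unit_motion_embs)
for idx in np.argsort(-scores)[:5]:
    print(keyids[idx], round(float(scores[idx]), 2))
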
load.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ import orjson
+ import torch
+ import numpy as np
+ from model import TMR_textencoder
+
+ EMBS = "data/unit_motion_embs"
+
+
+ def load_json(path):
+     with open(path, "rb") as ff:
+         return orjson.loads(ff.read())
+
+
+ def load_keyids(split):
+     path = os.path.join(EMBS, f"{split}.keyids")
+     with open(path) as ff:
+         keyids = np.array([x.strip() for x in ff.readlines()])
+     return keyids
+
+
+ def load_keyids_splits(splits):
+     return {
+         split: load_keyids(split)
+         for split in splits
+     }
+
+
+ def load_unit_motion_embs(split, device):
+     path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
+     tensor = torch.from_numpy(np.load(path)).to(device)
+     return tensor
+
+
+ def load_unit_motion_embs_splits(splits, device):
+     return {
+         split: load_unit_motion_embs(split, device)
+         for split in splits
+     }
+
+
+ def load_model(device):
+     text_params = {
+         'latent_dim': 256, 'ff_size': 1024, 'num_layers': 6, 'num_heads': 4,
+         'activation': 'gelu', 'modelpath': 'distilbert-base-uncased'
+     }
+     "unit_motion_embs"
+     model = TMR_textencoder(**text_params)
+     state_dict = torch.load("data/textencoder.pt", map_location=device)
+     # load values for the transformer only
+     model.load_state_dict(state_dict, strict=False)
+     model = model.eval()
+     return model
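
For reference, the loaders above read everything from the downloaded data/ folder. A small sanity check (a sketch, not part of the commit) listing the files that load.py expects might look like this:

import os

EMBS = "data/unit_motion_embs"

# files read by load_model, load_keyids and load_unit_motion_embs
expected = ["data/textencoder.pt"]
for split in ["train", "val", "test"]:
    expected.append(os.path.join(EMBS, f"{split}.keyids"))
    expected.append(os.path.join(EMBS, f"{split}_motion_embs_unit.npy"))

for path in expected:
    print(("ok      " if os.path.exists(path) else "missing ") + path)
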
model.py ADDED
@@ -0,0 +1,128 @@
+ from typing import List
+ import torch.nn as nn
+ import os
+
+ import torch
+ import numpy as np
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ from transformers import logging
+ from torch.nn.functional import normalize
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+
+         self.register_buffer('pe', pe, persistent=False)
+
+     def forward(self, x):
+         return x + self.pe[:x.shape[0], :]
+
+
+ class TMR_textencoder(nn.Module):
+     def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
+                  num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
+         super().__init__()
+
+         logging.set_verbosity_error()
+
+         # Tokenizer
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+         self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
+
+         # Text model
+         self.text_model = AutoModel.from_pretrained(modelpath)
+         # Then configure the model
+         self.text_encoded_dim = self.text_model.config.hidden_size
+
+         # Projection of the text outputs into the latent space
+         self.projection = nn.Sequential(
+             nn.ReLU(),
+             nn.Linear(self.text_encoded_dim, latent_dim)
+         )
+
+         self.mu_token = nn.Parameter(torch.randn(latent_dim))
+         self.logvar_token = nn.Parameter(torch.randn(latent_dim))
+         self.sequence_pos_encoding = PositionalEncoding(latent_dim)
+
+         seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
+                                                              nhead=num_heads,
+                                                              dim_feedforward=ff_size,
+                                                              dropout=0.0,
+                                                              activation=activation)
+         self.seqTransEncoder = nn.TransformerEncoder(
+             seq_trans_encoder_layer,
+             num_layers=num_layers
+         )
+
+     def get_last_hidden_state(self, texts: List[str],
+                               return_mask: bool = False):
+         encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
+         output = self.text_model(**encoded_inputs.to(self.text_model.device))
+         if not return_mask:
+             return output.last_hidden_state
+         return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
+
+     def forward(self, texts: List[str]) -> Tensor:
+         text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
+
+         x = self.projection(text_encoded)
+         bs, nframes, _ = x.shape
+         # bs, nframes, totjoints, nfeats = x.shape
+         # Switch sequence and batch size because the input of
+         # the PyTorch Transformer is [Sequence, Batch size, ...]
+         x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]
+
+         mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
+         logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
+
+         # add the distribution tokens to all sequences
+         xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
+
+         # create a bigger mask, to allow attending to mu and logvar
+         token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
+         aug_mask = torch.cat((token_mask, mask), 1)
+
+         # add positional encoding
+         xseq = self.sequence_pos_encoding(xseq)
+         final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
+
+         # only mu is used for inference
+         mu = final[0]
+         return mu
+
+     # compute scores for retrieval
+     def compute_scores(self, texts, unit_embs=None, embs=None):
+         # not both empty
+         assert not (unit_embs is None and embs is None)
+         # not both filled
+         assert not (unit_embs is not None and embs is not None)
+
+         output_str = False
+         # if a single text is given, squeeze the output
+         if isinstance(texts, str):
+             texts = [texts]
+             output_str = True
+
+         # compute unit_embs from embs if not given
+         if embs is not None:
+             unit_embs = normalize(embs)
+
+         with torch.no_grad():
+             latent_unit_texts = normalize(self(texts))
+             # cosine similarity, rescaled from [-1, 1] to [0, 1]
+             scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
+             scores = scores.cpu().numpy()
+
+         if output_str:
+             scores = scores[0]
+
+         return scores
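
To illustrate the compute_scores interface above (a sketch only, using randomly initialized weights instead of the data/textencoder.pt checkpoint that load.py would restore), the motion embeddings are passed as unit-normalized vectors of size latent_dim, and the output is one row of scores in [0, 1] per text query:

import torch
from torch.nn.functional import normalize
from model import TMR_textencoder

model = TMR_textencoder(modelpath="distilbert-base-uncased", latent_dim=256,
                        ff_size=1024, num_layers=6, num_heads=4, activation="gelu").eval()

unit_embs = normalize(torch.randn(10, 256))  # stand-ins for 10 precomputed motion embeddings
scores = model.compute_scores(["A person is walking slowly",
                               "Someone is doing a backflip"], unit_embs=unit_embs)
print(scores.shape)  # (2, 10): one score per (text, motion) pair
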
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ orjson
+ numpy
+ gdown
+ transformers