jykoh commited on
Commit
b6f5818
β€’
1 Parent(s): ec9b1ca

Initial commit.

Browse files
Files changed (11) hide show
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. Dockerfile +18 -0
  4. README.md +10 -8
  5. app.py +218 -0
  6. cc3m_embeddings_urls.npy +3 -0
  7. gill/layers.py +54 -0
  8. gill/models.py +909 -0
  9. gill/utils.py +249 -0
  10. requirements.txt +36 -0
  11. share_btn.py +107 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ cc3m_embeddings_urls.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .DS_Store
2
+ venv/
3
+ __pycache__
4
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime as base
2
+
3
+ RUN apt-get update && apt-get -y install git
4
+
5
+
6
+ ENV HOME=/exp/fromage
7
+
8
+
9
+
10
+ WORKDIR /exp/fromage
11
+ COPY ./requirements.txt ./requirements.txt
12
+ RUN python -m pip install -r ./requirements.txt
13
+ RUN python -m pip install gradio
14
+
15
+ COPY . .
16
+ RUN chmod -R a+rwX .
17
+
18
+ CMD ["uvicorn", "app:main", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,13 @@
1
  ---
2
- title: Gill
3
- emoji: 🐨
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: docker
7
- pinned: false
8
- license: apache-2.0
 
 
 
 
 
 
9
  ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: GILL
3
+ emoji: 🐟
 
 
4
  sdk: docker
5
+ app_file: app.py
6
+ colorFrom: blue
7
+ colorTo: red
8
+ pinned: true
9
+ tags:
10
+ - multimodal
11
+ - computer-vision
12
+ - nlp
13
  ---
 
 
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ from share_btn import community_icon_html, loading_icon_html, share_js, save_js
3
+ import huggingface_hub
4
+ import gradio as gr
5
+ from gill import utils
6
+ from gill import models
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import torch
10
+ import numpy as np
11
+ import os
12
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
13
+
14
+
15
+ css = """
16
+ #chatbot { min-height: 300px; }
17
+ #save-btn {
18
+ background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
19
+ }
20
+ #save-btn:hover {
21
+ background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
22
+ }
23
+ #share-btn {
24
+ background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
25
+ }
26
+ #share-btn:hover {
27
+ background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
28
+ }
29
+ #gallery { z-index: 999999; }
30
+ #gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
31
+ #gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
32
+ @media (hover: none) {
33
+ #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; 0;}
34
+ }
35
+ """
36
+
37
+ examples = [
38
+ 'examples/sparrow.png',
39
+ 'examples/beaver.png',
40
+ 'examples/couch.png',
41
+ 'examples/guac.png',
42
+ 'examples/scraped_knee.png'
43
+ ]
44
+
45
+ # Download model from HF Hub.
46
+ ckpt_path = huggingface_hub.hf_hub_download(
47
+ repo_id='jykoh/gill', filename='pretrained_ckpt.pth.tar')
48
+ decision_model_path = huggingface_hub.hf_hub_download(
49
+ repo_id='jykoh/gill', filename='decision_model.pth.tar')
50
+ args_path = huggingface_hub.hf_hub_download(
51
+ repo_id='jykoh/gill', filename='model_args.json')
52
+ model = models.load_gill('./', args_path, ckpt_path, decision_model_path)
53
+
54
+
55
+ def upload_image(state, image_input):
56
+ conversation = state[0]
57
+ chat_history = state[1]
58
+ input_image = Image.open(image_input.name).resize(
59
+ (224, 224)).convert('RGB')
60
+ input_image.save(image_input.name) # Overwrite with smaller image.
61
+ conversation += [(f'<img src="/file={image_input.name}" style="display: inline-block;">', "")]
62
+ return [conversation, chat_history + [input_image, ""]], conversation
63
+
64
+
65
+ def reset():
66
+ return [[], []], []
67
+
68
+
69
+ def reset_last(state):
70
+ conversation = state[0][:-1]
71
+ chat_history = state[1][:-2]
72
+ return [conversation, chat_history], conversation
73
+
74
+
75
+ def save_image_to_local(image: Image.Image):
76
+ # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
77
+ filename = next(tempfile._get_candidate_names()) + '.png'
78
+ image.save(filename)
79
+ return filename
80
+
81
+
82
+ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
83
+ # Ignore empty inputs.
84
+ if len(input_text) == 0:
85
+ return state, state[0], gr.update(visible=True)
86
+
87
+ input_prompt = 'Q: ' + input_text + '\nA:'
88
+ conversation = state[0]
89
+ chat_history = state[1]
90
+ print('Generating for', chat_history, flush=True)
91
+
92
+ # If an image was uploaded, prepend it to the model.
93
+ model_inputs = chat_history
94
+ model_inputs.append(input_prompt)
95
+
96
+ top_p = 1.0
97
+ if temperature != 0.0:
98
+ top_p = 0.95
99
+
100
+ print('Running model.generate_for_images_and_texts with',
101
+ model_inputs, flush=True)
102
+ model_outputs = model.generate_for_images_and_texts(model_inputs,
103
+ num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
104
+ temperature=temperature, max_num_rets=1,
105
+ num_inference_steps=1)
106
+ print('model_outputs', model_outputs, ret_scale_factor, flush=True)
107
+
108
+ im_names = []
109
+ response = ''
110
+ text_outputs = []
111
+ for output_i, p in enumerate(model_outputs):
112
+ if type(p) == str:
113
+ if output_i > 0:
114
+ response += '<br/>'
115
+ # Remove the image tokens for output.
116
+ text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
117
+ response += p
118
+ if len(model_outputs) > 1:
119
+ response += '<br/>'
120
+ elif type(p) == dict:
121
+ # Decide whether to generate or retrieve.
122
+ if p['decision'] is not None and p['decision'][0] == 'gen':
123
+ image = p['gen'][0][0].resize((512, 512))
124
+ filename = save_image_to_local(image)
125
+ response += f'<img src="/file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Generated)</p>'
126
+ else:
127
+ image = p['ret'][0][0].resize((512, 512))
128
+ filename = save_image_to_local(image)
129
+ response += f'<img src="/file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Retrieved)</p>'
130
+
131
+
132
+ chat_history = model_inputs + \
133
+ [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
134
+ # Remove [RET] from outputs.
135
+ conversation.append((input_text, response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')))
136
+
137
+ # Set input image to None.
138
+ print('state', state, flush=True)
139
+ print('updated state', [conversation, chat_history], flush=True)
140
+ return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
141
+
142
+
143
+ with gr.Blocks(css=css) as demo:
144
+ gr.HTML("""
145
+ <h1>πŸ§€ FROMAGe</h1>
146
+ <p>This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p>
147
+
148
+ <strong>Paper:</strong> <a href="https://arxiv.org/abs/2301.13823" target="_blank">Grounding Language Models to Images for Multimodal Generation</a>
149
+ <br/>
150
+ <strong>Project Website:</strong> <a href="https://jykoh.com/fromage" target="_blank">FROMAGe Website</a>
151
+ <br/>
152
+ <strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/fromage" target="_blank">GitHub</a>
153
+ <br/>
154
+ <br/>
155
+
156
+ <strong>Tips:</strong>
157
+ <ul>
158
+ <li>Start by inputting either image or text prompts (or both) and chat with FROMAGe to get image-and-text replies.</li>
159
+ <li>Tweak the level of sensitivity to images and text using the parameters on the right.</li>
160
+ <li>FROMAGe <i>retrieves</i> images from a database, and doesn't generate novel images, and will not be able to return images outside those in Conceptual Captions.</li>
161
+ <li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li>
162
+ <li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/fromage?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></li>
163
+ </ul>
164
+ """)
165
+
166
+ gr_state = gr.State([[], []]) # conversation, chat_history
167
+
168
+ with gr.Row():
169
+ with gr.Column(scale=0.7, min_width=500):
170
+ with gr.Row():
171
+ chatbot = gr.Chatbot(elem_id="chatbot", label="πŸ§€ FROMAGe Chatbot")
172
+ with gr.Row():
173
+ image_btn = gr.UploadButton("πŸ–ΌοΈ Upload Image", file_types=["image"])
174
+
175
+ text_input = gr.Textbox(label="Message", placeholder="Type a message")
176
+
177
+ with gr.Column():
178
+ submit_btn = gr.Button(
179
+ "Submit", interactive=True, variant="primary")
180
+ clear_last_btn = gr.Button("Undo")
181
+ clear_btn = gr.Button("Reset All")
182
+ with gr.Row(visible=False) as save_group:
183
+ save_button = gr.Button("πŸ’Ύ Save Conversation as .png", elem_id="save-btn")
184
+
185
+ with gr.Row(visible=False) as share_group:
186
+ share_button = gr.Button("πŸ€— Share to Community (opens new window)", elem_id="share-btn")
187
+
188
+ with gr.Column(scale=0.3, min_width=400):
189
+ ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
190
+ label="Frequency multiplier for returning images (higher means more frequent)")
191
+ # max_ret_images = gr.Number(
192
+ # minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
193
+ gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
194
+ step=1, interactive=True, label="Max # of words")
195
+ gr_temperature = gr.Slider(
196
+ minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")
197
+
198
+ gallery = gr.Gallery(
199
+ value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
200
+ ).style(grid=[2], height="auto")
201
+
202
+ text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
203
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
204
+ text_input.submit(lambda: "", None, text_input) # Reset chatbox.
205
+ submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
206
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
207
+ submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
208
+
209
+ image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
210
+ clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
211
+ clear_btn.click(reset, [], [gr_state, chatbot])
212
+ share_button.click(None, [], [], _js=share_js)
213
+ save_button.click(None, [], [], _js=save_js)
214
+
215
+
216
+ demo.queue(concurrency_count=1, api_open=False, max_size=16)
217
+ # demo.launch(debug=True, server_name="0.0.0.0")
218
+ demo.launch(debug=True, server_name="127.0.0.1")
cc3m_embeddings_urls.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:797e2ab9d46f103106bbf111352c762c5969630e9a13ccdc1f56a51c63fc39a3
3
+ size 2887526287
gill/layers.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class TextFcLayer(nn.Module):
6
+ """Layers used in mapping text embeddings to visual outputs."""
7
+
8
+ def __init__(self, in_dim: int, out_dim: int, num_input_tokens: int = 1, num_output_tokens: int = 1, mode: str = 'linear'):
9
+ super().__init__()
10
+
11
+ self.num_input_tokens = num_input_tokens
12
+ self.num_output_tokens = num_output_tokens
13
+ self.mode = mode
14
+
15
+ if mode == 'linear':
16
+ self.model = nn.Linear(in_dim, out_dim)
17
+ elif mode == 'gill_mapper': # TODO(jykoh): Rename to GILLMapper
18
+ hidden_dim = 512
19
+ self.fc = nn.Linear(in_dim, hidden_dim)
20
+ self.tfm = nn.Transformer(batch_first=True, norm_first=True,
21
+ d_model=hidden_dim, num_encoder_layers=4, num_decoder_layers=4,
22
+ dim_feedforward=hidden_dim * 4, dropout=0.0, nhead=4)
23
+ self.model = nn.Linear(hidden_dim, out_dim)
24
+ self.query_embs = nn.Parameter(torch.randn(1, num_output_tokens, hidden_dim))
25
+ else:
26
+ raise NotImplementedError(mode)
27
+
28
+ def forward(self, x: torch.Tensor, input_embs: torch.Tensor) -> torch.Tensor:
29
+ outputs = None
30
+
31
+ if self.mode == 'gill_mapper':
32
+ x = x + input_embs
33
+
34
+ if isinstance(self.model, nn.ModuleList):
35
+ assert len(self.model) == x.shape[1] == self.num_input_tokens, (len(self.model), x.shape, self.num_input_tokens)
36
+ outputs = []
37
+ for i in range(self.num_input_tokens):
38
+ outputs.append(self.model[i](x[:, i, :])) # (N, D)
39
+ outputs = torch.stack(outputs, dim=1) # (N, T, D)
40
+ else:
41
+ if self.mode == 'gill_mapper':
42
+ x = self.fc(x)
43
+ x = self.tfm(x, self.query_embs.repeat(x.shape[0], 1, 1))
44
+ outputs = self.model(x)
45
+
46
+ if outputs.shape[1] != self.num_output_tokens and self.mode == 'linear':
47
+ if self.mode == 'linear':
48
+ outputs = outputs[:, :self.num_output_tokens, :]
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ assert outputs.shape[1] == 1 or (outputs.shape[1] * outputs.shape[2] == self.num_output_tokens * 768), (outputs.shape, self.num_output_tokens)
53
+ return outputs # (N, T, D)
54
+
gill/models.py ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from collections import namedtuple
3
+ from diffusers import StableDiffusionPipeline
4
+ import json
5
+ import numpy as np
6
+ import os
7
+ import glob
8
+ import torch
9
+ from torch import Tensor
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import pickle as pkl
13
+ from PIL import Image, UnidentifiedImageError
14
+ from requests.exceptions import ConnectionError
15
+
16
+ from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, OPTForCausalLM
17
+ from gill import utils
18
+ from gill import layers
19
+
20
+
21
+ class GILLArgs:
22
+ freeze_lm: bool = True
23
+ freeze_vm: bool = True
24
+ opt_version: str = 'facebook/opt-6.7b'
25
+ visual_encoder: str = 'openai/clip-vit-large-patch14'
26
+ n_visual_tokens: int = 1
27
+ task: str = 'captioning'
28
+ ret_emb_dim: Optional[int] = 256
29
+ gen_emb_dim: Optional[int] = 256
30
+ text_emb_layers: List[int] = [-1]
31
+ gen_token_idx: List[int] = [0]
32
+ retrieval_token_idx: List[int] = [0]
33
+ text_fc_mode: str = 'gill_mapper'
34
+ ret_text_fc_mode: str = 'linear'
35
+ num_tokens: int = 8
36
+ num_clip_tokens: int = 77
37
+
38
+
39
+ class GILLModel(nn.Module):
40
+ def __init__(self, tokenizer, args: GILLArgs = GILLArgs()):
41
+ super().__init__()
42
+ self.tokenizer = tokenizer
43
+ self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
44
+ self.image_token = self.tokenizer.cls_token_id
45
+ assert args.text_emb_layers != set(args.text_emb_layers), 'text_emb_layers not unique'
46
+ self.args = args
47
+ self.num_tokens = args.num_tokens
48
+ self.num_clip_tokens = args.num_clip_tokens
49
+
50
+ opt_version = args.opt_version
51
+ visual_encoder = args.visual_encoder
52
+ n_visual_tokens = args.n_visual_tokens
53
+ print(f"Using {opt_version} for the language model.")
54
+ print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")
55
+
56
+ if 'facebook/opt' in opt_version:
57
+ self.lm = OPTForCausalLM.from_pretrained(opt_version)
58
+ else:
59
+ raise NotImplementedError
60
+
61
+ self.opt_version = opt_version
62
+
63
+ if self.args.freeze_lm:
64
+ self.lm.eval()
65
+ print("Freezing the LM.")
66
+ for param in self.lm.parameters():
67
+ param.requires_grad = False
68
+ else:
69
+ self.lm.train()
70
+
71
+ self.retrieval_token_idx = args.retrieval_token_idx
72
+ self.gen_token_idx = args.gen_token_idx
73
+ self.lm.resize_token_embeddings(len(tokenizer))
74
+
75
+ self.input_embeddings = self.lm.get_input_embeddings()
76
+
77
+ print("Restoring pretrained weights for the visual model.")
78
+ if 'clip' in visual_encoder:
79
+ self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
80
+ else:
81
+ self.visual_model = AutoModel.from_pretrained(visual_encoder)
82
+
83
+ if 'clip' in visual_encoder:
84
+ hidden_size = self.visual_model.config.hidden_size
85
+ else:
86
+ raise NotImplementedError
87
+
88
+ if self.args.freeze_vm:
89
+ print("Freezing the VM.")
90
+ self.visual_model.eval()
91
+ for param in self.visual_model.parameters():
92
+ param.requires_grad = False
93
+ else:
94
+ self.visual_model.train()
95
+
96
+ self.visual_model_name = visual_encoder
97
+
98
+ embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
99
+ self.ret_text_hidden_fcs = nn.ModuleList([])
100
+ self.gen_text_hidden_fcs = nn.ModuleList([])
101
+
102
+ for layer_idx in self.args.text_emb_layers:
103
+ if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
104
+ if 'opt' in opt_version: # OPT models
105
+ in_dim = self.lm.config.word_embed_proj_dim
106
+ else:
107
+ raise NotImplementedError
108
+
109
+ self.ret_text_hidden_fcs.append(
110
+ layers.TextFcLayer(in_dim, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens,
111
+ num_output_tokens=1, mode=self.args.ret_text_fc_mode))
112
+ self.gen_text_hidden_fcs.append(
113
+ layers.TextFcLayer(in_dim, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens,
114
+ num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
115
+
116
+ elif layer_idx < self.lm.config.num_hidden_layers:
117
+ self.ret_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=1, mode=self.args.ret_text_fc_mode))
118
+ self.gen_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
119
+ else:
120
+ raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')
121
+
122
+ self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
123
+
124
+ # Retrieval image FC layer.
125
+ self.visual_fc = nn.Linear(hidden_size, self.args.ret_emb_dim)
126
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
127
+
128
+
129
+ def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
130
+ if mode not in ['captioning', 'retrieval', 'generation']:
131
+ raise ValueError(f"mode should be one of ['captioning', 'retrieval', 'generation'], got {mode} instead.")
132
+
133
+ # Extract visual embeddings from the vision encoder.
134
+ if 'clip' in self.visual_model_name:
135
+ outputs = self.visual_model(pixel_values)
136
+ encoder_outputs = outputs.pooler_output
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ # Use the correct fc based on function argument.
141
+ if mode == 'captioning':
142
+ visual_embs = self.visual_embeddings(encoder_outputs) # (2, D * n_visual_tokens)
143
+ visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
144
+ elif mode == 'retrieval':
145
+ visual_embs = self.visual_fc(encoder_outputs) # (2, D * n_visual_tokens)
146
+ visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
147
+ elif mode == 'generation':
148
+ visual_embs = torch.zeros((pixel_values.shape[0], 1, 768), device=pixel_values.device)
149
+ else:
150
+ raise NotImplementedError
151
+
152
+ return visual_embs
153
+
154
+
155
+ def train(self, mode=True):
156
+ super(GILLModel, self).train(mode=mode)
157
+ # Overwrite train() to ensure frozen models remain frozen.
158
+ if self.args.freeze_lm:
159
+ self.lm.eval()
160
+ if self.args.freeze_vm:
161
+ self.visual_model.eval()
162
+
163
+
164
+ def forward(
165
+ self,
166
+ pixel_values: torch.FloatTensor,
167
+ labels: Optional[torch.LongTensor] = None,
168
+ caption_len: Optional[torch.LongTensor] = None,
169
+ mode: str = 'captioning',
170
+ concat_captions: bool = False,
171
+ input_prefix: Optional[str] = None,
172
+ ):
173
+ visual_embs = self.get_visual_embs(pixel_values, mode)
174
+
175
+ batch_size, vis_seq_len, _ = visual_embs.shape # vis_seq_len = n_visual_tokens
176
+ if labels is not None:
177
+ assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)
178
+ visual_embs_norm = ((visual_embs ** 2).sum(dim=-1) ** 0.5).mean()
179
+
180
+ input_embs = self.input_embeddings(labels) # (N, T, D)
181
+ input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
182
+
183
+ last_embedding_idx = caption_len - 1 # -1 to retrieve the token before the eos token
184
+
185
+ if input_prefix is not None:
186
+ prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
187
+ prompt_ids = prompt_ids.to(visual_embs.device)
188
+ prompt_embs = self.input_embeddings(prompt_ids)
189
+ prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
190
+ assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
191
+ assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
192
+ assert len(prompt_embs.shape) == 3, prompt_embs.shape
193
+
194
+ if mode == 'captioning':
195
+ # Concat to text embeddings.
196
+ condition_seq_len = 0
197
+ if input_prefix is None:
198
+ # Just add visual embeddings.
199
+ input_embs = torch.cat([visual_embs, input_embs], axis=1)
200
+ last_embedding_idx += vis_seq_len
201
+ condition_seq_len += vis_seq_len
202
+ full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
203
+ else:
204
+ print(f'Adding prefix "{input_prefix}" to captioning.')
205
+ # Add visual and prompt embeddings.
206
+ prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
207
+ input_embs = torch.cat([prefix_embs, input_embs], axis=1)
208
+
209
+ last_embedding_idx += prefix_embs.shape[1]
210
+ condition_seq_len += prefix_embs.shape[1]
211
+ full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
212
+
213
+ # Mask out embedding tokens in the labels.
214
+ full_labels = torch.cat([full_labels, labels], axis=1)
215
+
216
+ pad_idx = []
217
+
218
+ for label in full_labels:
219
+ for k, token in enumerate(label):
220
+ # Mask out retrieval/gen tokens if they exist.
221
+ if token in [self.tokenizer.pad_token_id] + self.retrieval_token_idx + self.gen_token_idx:
222
+ label[k:] = -100
223
+ pad_idx.append(k)
224
+ break
225
+ if k == len(label) - 1: # No padding found.
226
+ pad_idx.append(k + 1)
227
+ assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
228
+
229
+ bs, seq_len, embs_dim = input_embs.shape
230
+ if concat_captions:
231
+ print('Concatenating examples for captioning!')
232
+ assert len(input_embs.shape) == 3, input_embs
233
+ assert len(full_labels.shape) == 2, full_labels
234
+ assert batch_size % 2 == 0
235
+ all_concat_input_embs = []
236
+ all_concat_labels = []
237
+
238
+ # Rearrange embeddings and labels (and their padding) to concatenate captions.
239
+ for i in range(batch_size // 2):
240
+ first_idx = i * 2
241
+ second_idx = first_idx + 1
242
+ first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
243
+ first_labels = full_labels[first_idx, :pad_idx[first_idx]]
244
+ first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
245
+ first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
246
+
247
+ second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
248
+ second_labels = full_labels[second_idx, :pad_idx[second_idx]]
249
+ second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
250
+ second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
251
+ bos_idx = visual_embs.shape[1]
252
+
253
+ assert torch.all(first_labels_padding == -100), first_labels_padding
254
+ assert torch.all(second_labels_padding == -100), second_labels_padding
255
+ assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
256
+
257
+ # Remove BOS token of the second caption.
258
+ second_labels = torch.cat([second_labels[:bos_idx], second_labels[bos_idx + 1:]], axis=0)
259
+ second_emb = torch.cat([second_emb[:bos_idx, :], second_emb[bos_idx + 1:, :]], axis=0)
260
+
261
+ concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768)
262
+ concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2, 768)
263
+ all_concat_input_embs.append(concat_input_embs)
264
+ all_concat_labels.append(concat_labels)
265
+
266
+ # Pad to max length.
267
+ input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768)
268
+ full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2, 768)
269
+ print("Concatenated full_labels:", full_labels[0, ...])
270
+ assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
271
+ assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
272
+
273
+ output = self.lm(inputs_embeds=input_embs,
274
+ labels=full_labels,
275
+ output_hidden_states=True)
276
+ elif mode in ['retrieval', 'generation']:
277
+ full_labels = torch.clone(labels)
278
+ if input_prefix is not None:
279
+ print(f'Adding prefix "{input_prefix}" to retrieval.')
280
+ # Add prompt embeddings.
281
+ prefix_embs = prompt_embs
282
+ input_embs = torch.cat([prefix_embs, input_embs], axis=1)
283
+ last_embedding_idx += prefix_embs.shape[1]
284
+ full_labels = torch.cat([
285
+ torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
286
+ full_labels
287
+ ], axis=1)
288
+
289
+ pad_idx = []
290
+ for label in full_labels:
291
+ for k, token in enumerate(label):
292
+ if (token == self.tokenizer.pad_token_id):
293
+ label[k:] = -100
294
+ pad_idx.append(k)
295
+ break
296
+ if k == len(label) - 1: # No padding found.
297
+ pad_idx.append(k + 1)
298
+ assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
299
+
300
+ bs, seq_len, embs_dim = input_embs.shape
301
+ # Concatenate examples for captioning, if specified.
302
+ if concat_captions:
303
+ print(f'Concatenating examples for {mode}!')
304
+ assert len(input_embs.shape) == 3, input_embs
305
+ assert len(full_labels.shape) == 2, full_labels
306
+ assert batch_size % 2 == 0
307
+ all_concat_input_embs = []
308
+ all_concat_labels = []
309
+ all_last_embedding_idx = []
310
+
311
+ # Rearrange embeddings and labels (and their padding) to concatenate captions.
312
+ for i in range(batch_size // 2):
313
+ first_idx = i * 2
314
+ second_idx = first_idx + 1
315
+ first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
316
+ first_labels = full_labels[first_idx, :pad_idx[first_idx]]
317
+ first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
318
+ first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
319
+
320
+ second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
321
+ second_labels = full_labels[second_idx, :pad_idx[second_idx]]
322
+ second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
323
+ second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
324
+
325
+ bos_idx = 0
326
+ assert torch.all(first_labels_padding == -100), first_labels_padding
327
+ assert torch.all(second_labels_padding == -100), second_labels_padding
328
+ assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
329
+
330
+ # Remove BOS token of second caption.
331
+ second_labels = second_labels[bos_idx + 1:]
332
+ second_emb = second_emb[bos_idx + 1:, :]
333
+ last_embedding_idx[second_idx] = last_embedding_idx[second_idx] - 1
334
+
335
+ concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768)
336
+ concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2, 768)
337
+ all_concat_input_embs.append(concat_input_embs)
338
+ all_concat_labels.append(concat_labels)
339
+
340
+ all_last_embedding_idx.append((last_embedding_idx[first_idx], first_emb.shape[0] + last_embedding_idx[second_idx]))
341
+
342
+ if mode == 'retrieval':
343
+ assert concat_labels[all_last_embedding_idx[-1][0]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][0])
344
+ assert concat_labels[all_last_embedding_idx[-1][1]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][1])
345
+ elif mode == 'generation':
346
+ # Check that the last n tokens are GEN tokens.
347
+ for gen_i in range(len(self.gen_token_idx)):
348
+ assert concat_labels[all_last_embedding_idx[-1][0]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][0]-gen_i, self.gen_token_idx[-gen_i-1])
349
+ assert concat_labels[all_last_embedding_idx[-1][1]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][1]-gen_i, self.gen_token_idx[-gen_i-1])
350
+
351
+ # Pad to max length.
352
+ input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768)
353
+ full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2, 768)
354
+ assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
355
+ assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
356
+
357
+ # Update labels to pad non-first tokens.
358
+ for label in full_labels:
359
+ for k, token in enumerate(label):
360
+ if (token == self.tokenizer.pad_token_id) or (token in (self.retrieval_token_idx[1:] + self.gen_token_idx[1:])):
361
+ label[k:] = -100
362
+ break
363
+ output = self.lm(inputs_embeds=input_embs,
364
+ labels=full_labels,
365
+ output_hidden_states=True)
366
+ else:
367
+ raise NotImplementedError
368
+
369
+ last_embedding = None
370
+ last_output_logit = None
371
+ hidden_states = []
372
+ llm_hidden_states = []
373
+
374
+ if mode in ['retrieval', 'generation']:
375
+ num_tokens = self.num_tokens
376
+ if mode == 'retrieval':
377
+ text_hidden_fcs = self.ret_text_hidden_fcs
378
+ else:
379
+ text_hidden_fcs = self.gen_text_hidden_fcs
380
+
381
+ # Concatenate captions for retrieval / generation, if specified.
382
+ if not concat_captions:
383
+ for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
384
+ input_hidden_state = torch.stack([output.hidden_states[idx][i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
385
+ input_embedding = torch.stack([input_embs[i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
386
+ llm_hidden_states.append(input_hidden_state)
387
+ hidden_states.append(fc_layer(input_hidden_state, input_embedding)) # (N, seq_len, 2048)
388
+ else:
389
+ for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
390
+ all_last_embedding = []
391
+ all_input_embedding = []
392
+ all_last_output_logit = []
393
+ for i in range(batch_size // 2):
394
+ first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
395
+ first_last_embedding = output.hidden_states[idx][i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
396
+ second_last_embedding = output.hidden_states[idx][i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
397
+ all_last_embedding.append(first_last_embedding)
398
+ all_last_embedding.append(second_last_embedding)
399
+
400
+ first_input_embs = input_embs[i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
401
+ second_input_embs = input_embs[i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
402
+ all_input_embedding.append(first_input_embs)
403
+ all_input_embedding.append(second_input_embs)
404
+
405
+ first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :] # (N, D)
406
+ second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :] # (N, D)
407
+ all_last_output_logit.append(first_last_output_logit)
408
+ all_last_output_logit.append(second_last_output_logit)
409
+
410
+ last_embedding = torch.stack(all_last_embedding, axis=0)
411
+ input_embedding = torch.stack(all_input_embedding, axis=0)
412
+ last_output_logit = torch.stack(all_last_output_logit, axis=0)
413
+ llm_hidden_states.append(last_embedding)
414
+ hidden_states.append(fc_layer(last_embedding, input_embedding)) # (N, seq_len, 2048)
415
+
416
+ if not concat_captions:
417
+ # Add hidden states together.
418
+ last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1) #torch.stack([last_hidden_state[i, :, :] for i in range(batch_size)], axis=0) # (N, T, D)
419
+ last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0) # (N, D)
420
+ else:
421
+ # Add hidden states together.
422
+ last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1)
423
+
424
+ # Compute retrieval loss.
425
+ if mode == 'retrieval':
426
+ assert visual_embs.shape[1] == 1, visual_embs.shape
427
+ assert last_embedding.shape[1] == 1, last_embedding.shape
428
+ visual_embs = visual_embs[:, 0, :]
429
+ visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
430
+ last_embedding = last_embedding[:, 0, :]
431
+ last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
432
+
433
+ # cosine similarity as logits
434
+ logit_scale = self.logit_scale.exp()
435
+ visual_embs = logit_scale * visual_embs
436
+ elif mode == 'captioning':
437
+ pass
438
+ else:
439
+ raise NotImplementedError
440
+
441
+ return output, full_labels, last_embedding, last_output_logit, visual_embs, visual_embs_norm, input_embs_norm, llm_hidden_states
442
+
443
+ def generate(self, embeddings = torch.FloatTensor, max_len: int = 32,
444
+ temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
445
+ ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
446
+ filter_value: float = -float('Inf')):
447
+ """Runs greedy decoding and returns generated captions.
448
+
449
+ Args:
450
+ min_word_tokens: Minimum number of words to generate before allowing a [IMG] output.
451
+ filter_value: Value to assign to tokens that should never be generated.
452
+ Outputs:
453
+ out: (N, T) int32 sequence of output tokens.
454
+ output_embeddings: (N, T, 256) sequence of text output embeddings.
455
+ """
456
+ self.lm.eval()
457
+
458
+ with torch.no_grad(): # no tracking history
459
+ # init output with image tokens
460
+ out = None
461
+ output_embeddings = []
462
+ output_logits = []
463
+
464
+ for i in range(max_len):
465
+ output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
466
+
467
+ for idx in self.args.text_emb_layers:
468
+ output_embeddings.append(output.hidden_states[idx])
469
+
470
+ logits = output.logits[:, -1, :] # (N, vocab_size)
471
+ if top_p == 1.0:
472
+ logits = logits.cpu()
473
+ output_logits.append(logits)
474
+
475
+ # Prevent the model from generating the [IMG1..n] tokens.
476
+ logits[:, self.retrieval_token_idx[1:]] = filter_value
477
+ logits[:, self.gen_token_idx[1:]] = filter_value
478
+
479
+ if (self.retrieval_token_idx or self.gen_token_idx) and self.retrieval_token_idx[0] != -1 and self.gen_token_idx[0] != -1:
480
+ if i < min_word_tokens:
481
+ # Eliminate probability of generating [IMG] if this is earlier than min_word_tokens.
482
+ logits[:, self.retrieval_token_idx] = filter_value
483
+ logits[:, self.gen_token_idx] = filter_value
484
+ else:
485
+ # Multiply by scaling factor.
486
+ if ret_scale_factor > 1:
487
+ logits[:, self.retrieval_token_idx[0]] = logits[:, self.retrieval_token_idx[0]].abs() * ret_scale_factor
488
+ if gen_scale_factor > 1:
489
+ logits[:, self.gen_token_idx[0]] = logits[:, self.gen_token_idx[0]].abs() * gen_scale_factor
490
+
491
+ if temperature == 0.0:
492
+ if top_p != 1.0:
493
+ raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
494
+ next_token = torch.argmax(logits, keepdim=True, dim=-1) # (N, 1)
495
+ else:
496
+ logits = logits / temperature
497
+
498
+ # Apply top-p filtering.
499
+ if top_p < 1.0:
500
+ assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
501
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (N, D) and (N, D)
502
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (N, D)
503
+
504
+ # Remove tokens with cumulative probability above the threshold
505
+ sorted_indices_to_remove = cumulative_probs > top_p
506
+ # Shift the indices to the right to keep also the first token above the threshold
507
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
508
+ sorted_indices_to_remove[..., 0] = 0
509
+
510
+ for j in range(sorted_indices.shape[0]):
511
+ indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
512
+ logits[j, indices_to_remove] = filter_value
513
+
514
+ token_weights = logits.exp() # (N, vocab_size)
515
+ next_token = torch.multinomial(token_weights, 1) # (N, 1)
516
+
517
+ # Force generation of the remaining [IMG] tokens if [IMG0] is generated.
518
+ if next_token.shape[0] == 1 and next_token.item() == self.retrieval_token_idx[0]:
519
+ assert self.retrieval_token_idx == self.gen_token_idx, (self.retrieval_token_idx, self.gen_token_idx)
520
+ next_token = torch.tensor(self.retrieval_token_idx)[None, :].long().to(embeddings.device) # (1, num_tokens)
521
+ else:
522
+ next_token = next_token.long().to(embeddings.device)
523
+
524
+ if out is not None:
525
+ out = torch.cat([out, next_token], dim=-1)
526
+ else:
527
+ out = next_token
528
+
529
+ next_embedding = self.input_embeddings(next_token)
530
+ embeddings = torch.cat([embeddings, next_embedding], dim=1)
531
+
532
+ return out, output_embeddings, output_logits
533
+
534
+
535
+ class GILL(nn.Module):
536
+ def __init__(self, tokenizer, model_args: Optional[GILLArgs] = None,
537
+ path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None,
538
+ load_sd: bool = False, num_gen_images: int = 1, decision_model_path: Optional[str] = None):
539
+ super().__init__()
540
+ self.model = GILLModel(tokenizer, model_args)
541
+ self.path_array = path_array
542
+ self.emb_matrix = emb_matrix
543
+ self.load_sd = load_sd
544
+ self.num_gen_images = num_gen_images
545
+ self.idx2dec = {0: 'gen', 1: 'ret', 2: 'same'}
546
+ self.decision_model = None
547
+
548
+ # Load the Stable Diffusion model.
549
+ if load_sd:
550
+ model_id = "runwayml/stable-diffusion-v1-5"
551
+ if torch.cuda.is_available():
552
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
553
+ else:
554
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
555
+
556
+ if decision_model_path is not None:
557
+ print('Loading decision model...')
558
+ self.decision_model = nn.Sequential(*[
559
+ nn.Dropout(0.5),
560
+ nn.Linear(4097, 2),
561
+ ])
562
+
563
+ if torch.cuda.is_available():
564
+ mlp_checkpoint = torch.load(decision_model_path)
565
+ else:
566
+ mlp_checkpoint = torch.load(decision_model_path, map_location=torch.device('cpu'))
567
+
568
+ self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=True)
569
+ self.decision_model.eval()
570
+
571
+ def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
572
+ generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
573
+ ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
574
+ min_word_tokens: int = 0, mode: str = 'captioning', concat_captions: bool = False,
575
+ input_prefix: Optional[str] = None) -> Tensor:
576
+ if generate:
577
+ return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
578
+ min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor,
579
+ gen_scale_factor=gen_scale_factor)
580
+ else:
581
+ output = self.model(
582
+ pixel_values = images,
583
+ labels = tgt_tokens,
584
+ caption_len = caption_len,
585
+ mode = mode,
586
+ concat_captions = concat_captions,
587
+ input_prefix = input_prefix)
588
+ return output
589
+
590
+ def generate_for_images_and_texts(
591
+ self, prompts: List, num_words: int = 0, min_word_tokens: int = 0, ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
592
+ top_p: float = 1.0, temperature: float = 0.0, max_num_rets: int = 1, generator=None,
593
+ always_add_bos : bool = False, guidance_scale: float = 7.5, num_inference_steps: int = 50):
594
+ """
595
+ Encode prompts into embeddings, and generates text and image outputs accordingly.
596
+
597
+ Args:
598
+ prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
599
+ num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
600
+ min_word_tokens: Minimum number of actual words before generating an image.
601
+ ret_scale_factor: Proportion to scale [IMG] token logits by. A higher value may increase the probability of the model generating [IMG] outputs.
602
+ top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
603
+ temperature: Used to modulate logit distribution.
604
+ max_num_rets: Maximum number of images to return in one generation pass.
605
+ Returns:
606
+ return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
607
+ """
608
+ input_embs = []
609
+ input_ids = []
610
+ add_bos = True
611
+
612
+ with torch.no_grad():
613
+ for p in prompts:
614
+ if type(p) == Image.Image:
615
+ # Encode as image.
616
+ pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
617
+ pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
618
+ pixel_values = pixel_values[None, ...]
619
+
620
+ visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
621
+ input_embs.append(visual_embs)
622
+ elif type(p) == str:
623
+ text_ids = self.model.tokenizer(p, add_special_tokens=add_bos, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
624
+ # Only add <bos> once unless the flag is set.
625
+ if not always_add_bos:
626
+ add_bos = False
627
+
628
+ text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
629
+ input_embs.append(text_embs)
630
+ input_ids.append(text_ids)
631
+ else:
632
+ raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
633
+ input_embs = torch.cat(input_embs, dim=1)
634
+ input_ids = torch.cat(input_ids, dim=1)
635
+
636
+ if num_words == 0:
637
+ raise NotImplementedError('Generation not implemented for num_words=0.')
638
+ elif num_words > 0:
639
+ generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words, min_word_tokens=min_word_tokens,
640
+ temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor, gen_scale_factor=gen_scale_factor)
641
+ embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
642
+
643
+ # Truncate to newline.
644
+ newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
645
+ trunc_idx = 0
646
+ for j in range(generated_ids.shape[1]):
647
+ if generated_ids[0, j] == newline_token_id:
648
+ trunc_idx = j
649
+ break
650
+ if trunc_idx > 0:
651
+ generated_ids = generated_ids[:, :trunc_idx]
652
+ embeddings = embeddings[:, :trunc_idx]
653
+ else:
654
+ raise ValueError
655
+
656
+ # Save outputs as an interleaved list.
657
+ return_outputs = []
658
+ # Find up to max_num_rets [IMG] tokens, and their corresponding scores.
659
+ all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx[0]) if x][:max_num_rets]
660
+ seen_image_idx = [] # Avoid showing the same image multiple times.
661
+
662
+ last_ret_idx = 0
663
+ if len(all_ret_idx) == 0:
664
+ # No [IMG] tokens.
665
+ caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
666
+ return_outputs.append(utils.truncate_caption(caption))
667
+ else:
668
+ for ret_idx in all_ret_idx:
669
+ assert generated_ids[0, ret_idx:ret_idx+self.model.num_tokens].cpu().detach().numpy().tolist() == self.model.retrieval_token_idx, (generated_ids[0, ret_idx:ret_idx+self.model.num_tokens], self.model.retrieval_token_idx)
670
+ raw_emb = embeddings[:, ret_idx:ret_idx+self.model.num_tokens, :] # (1, 8, 4096)
671
+ assert len(self.model.args.text_emb_layers) == 1
672
+
673
+ image_outputs = {
674
+ 'gen': [],
675
+ 'ret': [],
676
+ 'decision': None,
677
+ }
678
+
679
+ if self.emb_matrix is not None:
680
+ # Produce retrieval embedding.
681
+ ret_emb = self.model.ret_text_hidden_fcs[0](raw_emb, None)[:, 0, :] # (1, 256)
682
+ ret_emb = ret_emb / ret_emb.norm(dim=-1, keepdim=True)
683
+ ret_emb = ret_emb.type(self.emb_matrix.dtype) # (1, 256)
684
+ scores = self.emb_matrix @ ret_emb.T
685
+
686
+ # Downweight seen images.
687
+ for seen_idx in seen_image_idx:
688
+ scores[seen_idx, :] -= 1000
689
+
690
+ # Get the top 3 images for each image.
691
+ _, top_image_idx = scores.squeeze().topk(3)
692
+ for img_idx in top_image_idx:
693
+ # Find the first image that does not error out.
694
+ try:
695
+ seen_image_idx.append(img_idx)
696
+ img = utils.get_image_from_url(self.path_array[img_idx])
697
+ image_outputs['ret'].append((img, 'ret', scores[img_idx].item()))
698
+ if len(image_outputs) == max_num_rets:
699
+ break
700
+ except (UnidentifiedImageError, ConnectionError):
701
+ pass
702
+
703
+ # Make decision with MLP.
704
+ if self.decision_model is not None:
705
+ decision_emb = raw_emb[:, 0, :] # (1, 4096)
706
+ assert decision_emb.shape[1] == 4096, decision_emb.shape
707
+ max_ret_score = scores.max().reshape((1, 1)).clone().detach().to(device=decision_emb.device, dtype=decision_emb.dtype)
708
+ decision_logits = self.decision_model(torch.cat([decision_emb, max_ret_score], dim=-1))
709
+ probs = decision_logits.softmax(dim=-1).cpu().float().numpy().tolist()
710
+ image_outputs['decision'] = [self.idx2dec[decision_logits.argmax().item()]] + probs
711
+ else:
712
+ # If no embedding matrix is provided, generate instead.
713
+ image_outputs['decision'] = ['gen', [0, 1]]
714
+
715
+ # Produce generation embedding.
716
+ gen_prefix = ' '.join([f'[IMG{i}]' for i in range(self.model.args.num_tokens)])
717
+ gen_prefx_ids = self.model.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
718
+ gen_prefix_embs = self.model.input_embeddings(gen_prefx_ids) # (1, T, D)
719
+ gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs) # (1, 77, 768)
720
+
721
+ if gen_emb.shape[1] != 77:
722
+ print(f"Padding {gen_emb.shape} with zeros")
723
+ bs = gen_emb.shape[0]
724
+ clip_emb = 768
725
+ gen_emb = gen_emb.reshape(bs, -1, clip_emb) # (bs, T, 768)
726
+ seq_len = gen_emb.shape[1]
727
+ gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1)
728
+ print('Padded to', gen_emb.shape)
729
+
730
+ gen_emb = gen_emb.repeat(self.num_gen_images, 1, 1) # (self.num_gen_images, 77, 768)
731
+
732
+ # OPTIM(jykoh): Only generate if scores are low.
733
+ if self.load_sd:
734
+ # If num_gen_images > 8, split into multiple batches (for GPU memory reasons).
735
+ gen_max_bs = 8
736
+ gen_images = []
737
+ for i in range(0, self.num_gen_images, gen_max_bs):
738
+ gen_images.extend(
739
+ self.sd_pipe(prompt_embeds=gen_emb[i:i+gen_max_bs], generator=generator,
740
+ guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images)
741
+
742
+ all_gen_pixels = []
743
+ for img in gen_images:
744
+ pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, img.resize((224, 224)).convert('RGB'))
745
+ pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
746
+ all_gen_pixels.append(pixel_values)
747
+
748
+ if self.emb_matrix is not None:
749
+ all_gen_pixels = torch.stack(all_gen_pixels, dim=0)
750
+ gen_visual_embs = self.model.get_visual_embs(all_gen_pixels, mode='retrieval') # (1, D)
751
+ gen_visual_embs = gen_visual_embs / gen_visual_embs.norm(dim=-1, keepdim=True)
752
+ gen_visual_embs = gen_visual_embs.type(self.emb_matrix.dtype)
753
+ gen_rank_scores = (gen_visual_embs @ ret_emb.T).squeeze()
754
+ sorted_score_idx = torch.argsort(-gen_rank_scores)
755
+
756
+ # Rank images by retrieval score.
757
+ if self.num_gen_images > 1:
758
+ image_outputs['gen'] = [(gen_images[idx], gen_rank_scores[idx].item()) for idx in sorted_score_idx]
759
+ else:
760
+ image_outputs['gen'] = [(gen_images[0], gen_rank_scores.item())]
761
+ else:
762
+ image_outputs['gen'] = [(gen_images[0], 0)]
763
+ else:
764
+ image_outputs['gen'] = [gen_emb]
765
+
766
+ caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
767
+ last_ret_idx = ret_idx + 1
768
+ return_outputs.append(utils.truncate_caption(caption) + f' {gen_prefix}')
769
+ return_outputs.append(image_outputs)
770
+
771
+ return return_outputs
772
+
773
+ def get_log_likelihood_scores(
774
+ self, prompts: List):
775
+ """
776
+ Output the log likelihood of the given interleaved prompts.
777
+
778
+ Args:
779
+ prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
780
+ Returns:
781
+ Log likelihood score of the prompt sequence.
782
+ """
783
+ input_embs = []
784
+ input_ids = []
785
+ add_bos = True
786
+
787
+ for p in prompts:
788
+ if type(p) == Image.Image:
789
+ # Encode as image.
790
+ pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
791
+ pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
792
+ pixel_values = pixel_values[None, ...]
793
+
794
+ visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
795
+ input_embs.append(visual_embs)
796
+ id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
797
+ input_ids.append(id_)
798
+ elif type(p) == str:
799
+ text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
800
+ if not add_bos:
801
+ # Remove <bos> tag.
802
+ text_ids = text_ids[:, 1:]
803
+ else:
804
+ # Only add <bos> once.
805
+ add_bos = False
806
+
807
+ text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
808
+ input_embs.append(text_embs)
809
+ input_ids.append(text_ids)
810
+ else:
811
+ raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
812
+ input_embs = torch.cat(input_embs, dim=1)
813
+ input_ids = torch.cat(input_ids, dim=1)
814
+
815
+ outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True)
816
+ return -outputs.loss.item()
817
+
818
+
819
+ def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, decision_model_path: str) -> GILL:
820
+ embs_paths = [s for s in glob.glob(os.path.join(embeddings_dir, 'cc3m*.npy'))]
821
+
822
+ if not os.path.exists(model_args_path):
823
+ raise ValueError(f'model_args.json does not exist at {model_args_path}.')
824
+ if not os.path.exists(model_ckpt_path):
825
+ raise ValueError(f'pretrained_ckpt.pth.tar does not exist at {model_ckpt_path}.')
826
+ if len(embs_paths) == 0:
827
+ print(f'cc3m*.npy files do not exist in {embeddings_dir}. Running the model without retrieval.')
828
+ path_array, emb_matrix = None, None
829
+ else:
830
+ # Load embeddings.
831
+ # Construct embedding matrix for nearest neighbor lookup.
832
+ path_array = []
833
+ emb_matrix = []
834
+
835
+ # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
836
+ for p in embs_paths:
837
+ with open(p, 'rb') as wf:
838
+ train_embs_data = pkl.load(wf)
839
+ path_array.extend(train_embs_data['paths'])
840
+ emb_matrix.extend(train_embs_data['embeddings'])
841
+ emb_matrix = np.stack(emb_matrix, axis=0)
842
+
843
+ # Number of paths should be equal to number of embeddings.
844
+ assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape)
845
+
846
+ with open(model_args_path, 'r') as f:
847
+ model_kwargs = json.load(f)
848
+
849
+ # Initialize tokenizer.
850
+ tokenizer = AutoTokenizer.from_pretrained(model_kwargs['opt_version'], use_fast=False)
851
+ if tokenizer.pad_token is None:
852
+ tokenizer.pad_token_id = tokenizer.eos_token_id
853
+ # Add an image token for loss masking (and visualization) purposes.
854
+ tokenizer.add_special_tokens({"cls_token": "<|image|>"}) # add special image token to tokenizer
855
+
856
+ # Add [IMG] tokens to the vocabulary.
857
+ model_kwargs['retrieval_token_idx'] = []
858
+ for i in range(model_kwargs['num_tokens']):
859
+ print(f'Adding [IMG{i}] token to vocabulary.')
860
+ print(f'Before adding new token, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
861
+ num_added_tokens = tokenizer.add_tokens(f'[IMG{i}]')
862
+ print(f'After adding {num_added_tokens} new tokens, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
863
+ ret_token_idx = tokenizer(f'[IMG{i}]', add_special_tokens=False).input_ids
864
+ assert len(ret_token_idx) == 1, ret_token_idx
865
+ model_kwargs['retrieval_token_idx'].append(ret_token_idx[0])
866
+ # Use the same RET tokens for generation.
867
+ model_kwargs['gen_token_idx'] = model_kwargs['retrieval_token_idx']
868
+
869
+ debug = False
870
+ if debug:
871
+ model_kwargs['opt_version'] = 'facebook/opt-125m'
872
+ model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
873
+ decision_model_path = None
874
+
875
+ args = namedtuple('args', model_kwargs)(**model_kwargs)
876
+
877
+ # Initialize model for inference.
878
+ model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
879
+ load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
880
+ model = model.eval()
881
+ if not debug:
882
+ model = model.bfloat16()
883
+ model = model.cuda()
884
+
885
+ # Load pretrained linear mappings and [IMG] embeddings.
886
+ checkpoint = torch.load(model_ckpt_path)
887
+ state_dict = {}
888
+ # This is needed if we train with DDP.
889
+ for k, v in checkpoint['state_dict'].items():
890
+ state_dict[k.replace('module.', '')] = v
891
+ img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()
892
+ del state_dict['model.input_embeddings.weight']
893
+
894
+ model.load_state_dict(state_dict, strict=False)
895
+ # Copy over the embeddings of the [IMG] tokens (while loading the others from the pretrained LLM).
896
+ with torch.no_grad():
897
+ if 'share_ret_gen' in model_kwargs:
898
+ assert model_kwargs['share_ret_gen'], 'Model loading only supports share_ret_gen=True for now.'
899
+ model.model.input_embeddings.weight[-model_kwargs['num_tokens']:, :].copy_(img_token_embeddings)
900
+
901
+ if len(embs_paths) > 0:
902
+ logit_scale = model.model.logit_scale.exp()
903
+ emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
904
+ emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
905
+ emb_matrix = logit_scale * emb_matrix
906
+ model.emb_matrix = emb_matrix
907
+
908
+ return model
909
+
gill/utils.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import subprocess
3
+ import sys
4
+ import shutil
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torchvision.transforms import functional as F
8
+ from torchvision import transforms as T
9
+ from transformers import AutoFeatureExtractor
10
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
11
+ import random
12
+ import requests
13
+ from io import BytesIO
14
+
15
+
16
+ def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
17
+ """Logs git status to stdout."""
18
+ subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
19
+ subprocess.call('echo', shell=True, stdout=out_file)
20
+ exclude_string = ''
21
+ subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)
22
+
23
+
24
+ def get_image_from_url(url: str):
25
+ response = requests.get(url)
26
+ img = Image.open(BytesIO(response.content))
27
+ img = img.resize((224, 224))
28
+ img = img.convert('RGB')
29
+ return img
30
+
31
+
32
+ def truncate_caption(caption: str) -> str:
33
+ """Truncate captions at periods and newlines."""
34
+ caption = caption.strip('\n')
35
+ trunc_index = caption.find('\n') + 1
36
+ if trunc_index <= 0:
37
+ trunc_index = caption.find('.') + 1
38
+ if trunc_index > 0:
39
+ caption = caption[:trunc_index]
40
+ return caption
41
+
42
+
43
+ def pad_to_size(x, size=256):
44
+ delta_w = size - x.size[0]
45
+ delta_h = size - x.size[1]
46
+ padding = (
47
+ delta_w // 2,
48
+ delta_h // 2,
49
+ delta_w - (delta_w // 2),
50
+ delta_h - (delta_h // 2),
51
+ )
52
+ new_im = ImageOps.expand(x, padding)
53
+ return new_im
54
+
55
+
56
+ class RandCropResize(object):
57
+
58
+ """
59
+ Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
60
+ """
61
+
62
+ def __init__(self, target_size):
63
+ self.target_size = target_size
64
+
65
+ def __call__(self, img):
66
+ img = pad_to_size(img, self.target_size)
67
+ d_min = min(img.size)
68
+ img = T.RandomCrop(size=d_min)(img)
69
+ t_min = min(d_min, round(9 / 8 * self.target_size))
70
+ t_max = min(d_min, round(12 / 8 * self.target_size))
71
+ t = random.randint(t_min, t_max + 1)
72
+ img = T.Resize(t)(img)
73
+ if min(img.size) < 256:
74
+ img = T.Resize(256)(img)
75
+ return T.RandomCrop(size=self.target_size)(img)
76
+
77
+
78
+ class SquarePad(object):
79
+ """Pads image to square.
80
+ From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
81
+ """
82
+ def __call__(self, image):
83
+ max_wh = max(image.size)
84
+ p_left, p_top = [(max_wh - s) // 2 for s in image.size]
85
+ p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
86
+ padding = (p_left, p_top, p_right, p_bottom)
87
+ return F.pad(image, padding, 0, 'constant')
88
+
89
+
90
+ def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
91
+ """Creates a (3, nrows * 14, width) image of text.
92
+
93
+ Returns:
94
+ cap_img: (3, 14 * nrows, width) image of wrapped text.
95
+ """
96
+ height = 12
97
+ padding = 5
98
+ effective_width = width - 2 * padding
99
+ # Create a black image to draw text on.
100
+ cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0))
101
+ draw = ImageDraw.Draw(cap_img)
102
+ draw.text((0, 0), text, color, font=font or ImageFont.load_default())
103
+ cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32) # (3, height, W * nrows)
104
+ cap_img = torch.split(cap_img, effective_width, dim=-1) # List of nrow elements of shape (3, height, W)
105
+ cap_img = torch.cat(cap_img, dim=1) # (3, height * nrows, W)
106
+ # Add zero padding.
107
+ cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
108
+ return cap_img
109
+
110
+
111
+ def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
112
+ print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
113
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
114
+ return feature_extractor
115
+
116
+
117
+ def get_pixel_values_for_model(feature_extractor, img: Image.Image):
118
+ pixel_values = feature_extractor(img.convert('RGB'), return_tensors="pt").pixel_values[0, ...] # (3, H, W)
119
+ return pixel_values
120
+
121
+
122
+ def save_checkpoint(state, is_best, filename='checkpoint'):
123
+ torch.save(state, filename + '.pth.tar')
124
+ if is_best:
125
+ shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar')
126
+
127
+
128
+ def accuracy(output, target, padding, topk=(1,)):
129
+ """Computes the accuracy over the k top predictions for the specified values of k"""
130
+ with torch.no_grad():
131
+ maxk = max(topk)
132
+ if output.shape[-1] < maxk:
133
+ print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
134
+
135
+ maxk = min(maxk, output.shape[-1])
136
+ batch_size = target.size(0)
137
+
138
+ # Take topk along the last dimension.
139
+ _, pred = output.topk(maxk, -1, True, True) # (N, T, topk)
140
+
141
+ mask = (target != padding).type(target.dtype)
142
+ target_expand = target[..., None].expand_as(pred)
143
+ correct = pred.eq(target_expand)
144
+ correct = correct * mask[..., None].expand_as(correct)
145
+
146
+ res = []
147
+ for k in topk:
148
+ correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
149
+ res.append(correct_k.mul_(100.0 / mask.sum()))
150
+ return res
151
+
152
+
153
+ def get_params_count(model, max_name_len: int = 60):
154
+ params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()]
155
+ total_trainable_params = sum([x[1] for x in params if x[-1]])
156
+ total_nontrainable_params = sum([x[1] for x in params if not x[-1]])
157
+ return params, total_trainable_params, total_nontrainable_params
158
+
159
+
160
+ def get_params_count_str(model, max_name_len: int = 60):
161
+ padding = 70 # Hardcoded depending on desired amount of padding and separators.
162
+ params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)
163
+ param_counts_text = ''
164
+ param_counts_text += '=' * (max_name_len + padding) + '\n'
165
+ param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n'
166
+ param_counts_text += '-' * (max_name_len + padding) + '\n'
167
+ for name, param_count, shape, trainable in params:
168
+ param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
169
+ param_counts_text += '-' * (max_name_len + padding) + '\n'
170
+ param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n'
171
+ param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n'
172
+ param_counts_text += '=' * (max_name_len + padding) + '\n'
173
+ return param_counts_text
174
+
175
+
176
+ class Summary(Enum):
177
+ NONE = 0
178
+ AVERAGE = 1
179
+ SUM = 2
180
+ COUNT = 3
181
+
182
+
183
+ class ProgressMeter(object):
184
+ def __init__(self, num_batches, meters, prefix=""):
185
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
186
+ self.meters = meters
187
+ self.prefix = prefix
188
+
189
+ def display(self, batch):
190
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
191
+ entries += [str(meter) for meter in self.meters]
192
+ print('\t'.join(entries))
193
+
194
+ def display_summary(self):
195
+ entries = [" *"]
196
+ entries += [meter.summary() for meter in self.meters]
197
+ print(' '.join(entries))
198
+
199
+ def _get_batch_fmtstr(self, num_batches):
200
+ num_digits = len(str(num_batches // 1))
201
+ fmt = '{:' + str(num_digits) + 'd}'
202
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
203
+
204
+
205
+ class AverageMeter(object):
206
+ """Computes and stores the average and current value"""
207
+ def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
208
+ self.name = name
209
+ self.fmt = fmt
210
+ self.summary_type = summary_type
211
+ self.reset()
212
+
213
+ def reset(self):
214
+ self.val = 0
215
+ self.avg = 0
216
+ self.sum = 0
217
+ self.count = 0
218
+
219
+ def update(self, val, n=1):
220
+ self.val = val
221
+ self.sum += val * n
222
+ self.count += n
223
+ self.avg = self.sum / self.count
224
+
225
+ def all_reduce(self):
226
+ device = "cuda" if torch.cuda.is_available() else "cpu"
227
+ total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
228
+ dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
229
+ self.sum, self.count = total.tolist()
230
+ self.avg = self.sum / self.count
231
+
232
+ def __str__(self):
233
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
234
+ return fmtstr.format(**self.__dict__)
235
+
236
+ def summary(self):
237
+ fmtstr = ''
238
+ if self.summary_type is Summary.NONE:
239
+ fmtstr = ''
240
+ elif self.summary_type is Summary.AVERAGE:
241
+ fmtstr = '{name} {avg:.3f}'
242
+ elif self.summary_type is Summary.SUM:
243
+ fmtstr = '{name} {sum:.3f}'
244
+ elif self.summary_type is Summary.COUNT:
245
+ fmtstr = '{name} {count:.3f}'
246
+ else:
247
+ raise ValueError('invalid summary type %r' % self.summary_type)
248
+
249
+ return fmtstr.format(**self.__dict__)
requirements.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ attrs==22.2.0
2
+ certifi==2022.12.7
3
+ charset-normalizer==3.0.1
4
+ contourpy==1.0.7
5
+ cycler==0.11.0
6
+ einops==0.4.1
7
+ exceptiongroup==1.1.0
8
+ filelock==3.9.0
9
+ fonttools==4.38.0
10
+ huggingface-hub==0.12.0
11
+ idna==3.4
12
+ iniconfig==2.0.0
13
+ kiwisolver==1.4.4
14
+ matplotlib==3.6.3
15
+ numpy==1.24.2
16
+ packaging==23.0
17
+ Pillow==9.4.0
18
+ pluggy==1.0.0
19
+ pyparsing==3.0.9
20
+ pytest==7.2.1
21
+ python-dateutil==2.8.2
22
+ PyYAML==6.0
23
+ regex==2022.10.31
24
+ requests==2.28.2
25
+ six==1.16.0
26
+ tokenizers==0.12.1
27
+ tomli==2.0.1
28
+ torch==1.11.0
29
+ torchaudio==0.11.0
30
+ torchmetrics==0.9.3
31
+ torchvision==0.12.0
32
+ tqdm==4.64.1
33
+ transformers==4.21.3
34
+ typing_extensions==4.4.0
35
+ urllib3==1.26.14
36
+ warmup-scheduler==0.3.0
share_btn.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py
2
+
3
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
4
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
5
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
6
+ </svg>"""
7
+
8
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
9
+ style="color: #ffffff;
10
+ "
11
+ xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
12
+
13
+ share_js = """
14
+ async () => {
15
+ const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
16
+ async function uploadFile(file) {
17
+ console.log(file.type)
18
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
19
+ const response = await fetch(UPLOAD_URL, {
20
+ method: 'POST',
21
+ headers: {
22
+ 'Content-Type': file.type,
23
+ 'X-Requested-With': 'XMLHttpRequest',
24
+ },
25
+ body: file, /// <- File inherits from Blob
26
+ });
27
+ const url = await response.text();
28
+ return url;
29
+ }
30
+ async function getImageFile(div) {
31
+ return new Promise((resolve, reject) =>
32
+ html2canvas(div)
33
+ .then((canvas) => {
34
+ const imageBlob = canvas.toBlob((blob) => {
35
+ const imageId = Date.now();
36
+ const fileName = "FROMAGe-" + imageId + ".jpg";
37
+ resolve(new File([blob], fileName, { type: 'image/jpeg' }));
38
+ }, 'image/jpeg', 0.95);
39
+ })
40
+
41
+ )
42
+ }
43
+
44
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
45
+ const chatbotEl = gradioEl.querySelector('#chatbot')
46
+ const imageFile = await getImageFile(chatbotEl);
47
+ console.log(imageFile);
48
+ const urlChatbotImage = await uploadFile(imageFile);
49
+ console.log(urlChatbotImage);
50
+ let titleTxt = `FROMAGe Example`;
51
+
52
+ //const shareBtnEl = gradioEl.querySelector('#share-btn');
53
+ //shareBtnEl.style.pointerEvents = 'none';
54
+ const descriptionMd = `
55
+
56
+ <img src='${urlChatbotImage}'>
57
+ `;
58
+ const params = new URLSearchParams({
59
+ title: titleTxt,
60
+ description: descriptionMd,
61
+ });
62
+ const paramsStr = params.toString();
63
+ window.open(`https://huggingface.co/spaces/jykoh/fromage/discussions/new?${paramsStr}`, '_blank');
64
+ //shareBtnEl.style.removeProperty('pointer-events');
65
+ }
66
+ """
67
+
68
+ save_js = """
69
+ async () => {
70
+ const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
71
+
72
+ function saveAs(uri, filename) {
73
+ var link = document.createElement('a');
74
+ if (typeof link.download === 'string') {
75
+ link.href = uri;
76
+ link.download = filename;
77
+
78
+ //Firefox requires the link to be in the body
79
+ document.body.appendChild(link);
80
+
81
+ //simulate click
82
+ link.click();
83
+
84
+ //remove the link when done
85
+ document.body.removeChild(link);
86
+ } else {
87
+ window.open(uri);
88
+ }
89
+ }
90
+
91
+ async function getImageFile(div) {
92
+ return new Promise((resolve, reject) =>
93
+ html2canvas(div)
94
+ .then((canvas) => {
95
+ const imageId = Date.now();
96
+ const fileName = "FROMAGe-" + imageId + ".png";
97
+ saveAs(canvas.toDataURL(), fileName);
98
+ })
99
+
100
+ )
101
+ }
102
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
103
+ const chatbotEl = gradioEl.querySelector('#chatbot')
104
+ const imageFile = await getImageFile(chatbotEl);
105
+ console.log(imageFile);
106
+ }
107
+ """