Commit 2881ba6 by hysts (HF staff)
Parent: 639c25d

Use diffusers implementation

Files changed (9):
  1. .gitmodules +0 -3
  2. .pre-commit-config.yaml +0 -1
  3. .vscode/settings.json +18 -0
  4. Dockerfile +2 -2
  5. README.md +2 -0
  6. app.py +46 -16
  7. model.py +74 -502
  8. requirements.txt +8 -12
  9. unidiffuser +0 -1
.gitmodules DELETED
@@ -1,3 +0,0 @@
-[submodule "unidiffuser"]
-	path = unidiffuser
-	url = https://github.com/thu-ml/unidiffuser
.pre-commit-config.yaml CHANGED
@@ -1,4 +1,3 @@
-exclude: patch
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.2.0
.vscode/settings.json ADDED
@@ -0,0 +1,18 @@
+{
+    "python.linting.enabled": true,
+    "python.linting.flake8Enabled": true,
+    "python.linting.pylintEnabled": false,
+    "python.linting.lintOnSave": true,
+    "python.formatting.provider": "yapf",
+    "python.formatting.yapfArgs": [
+        "--style={based_on_style: pep8, indent_width: 4, blank_line_before_nested_class_or_def: false, spaces_before_comment: 2, split_before_logical_operator: true}"
+    ],
+    "[python]": {
+        "editor.formatOnType": true,
+        "editor.codeActionsOnSave": {
+            "source.organizeImports": true
+        }
+    },
+    "editor.formatOnSave": true,
+    "files.insertFinalNewline": true
+}
Dockerfile CHANGED
@@ -32,13 +32,13 @@ WORKDIR ${HOME}/app

 RUN curl https://pyenv.run | bash
 ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
-ARG PYTHON_VERSION=3.10.10
+ARG PYTHON_VERSION=3.10.11
 RUN pyenv install ${PYTHON_VERSION} && \
     pyenv global ${PYTHON_VERSION} && \
     pyenv rehash && \
     pip install --no-cache-dir -U pip setuptools wheel

-RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
+RUN pip install --no-cache-dir -U torch==2.0.1 torchvision==0.15.2
 COPY --chown=1000 requirements.txt /tmp/requirements.txt
 RUN pip install --no-cache-dir -U -r /tmp/requirements.txt

README.md CHANGED
@@ -10,3 +10,5 @@ license: other
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+https://arxiv.org/abs/2303.06555
app.py CHANGED
@@ -3,8 +3,11 @@
 from __future__ import annotations

 import os
+import random

 import gradio as gr
+import numpy as np
+import torch

 from model import Model

@@ -13,9 +16,19 @@ DESCRIPTION = '# [UniDiffuser](https://github.com/thu-ml/unidiffuser)'
 SPACE_ID = os.getenv('SPACE_ID')
 if SPACE_ID is not None:
     DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
+if not torch.cuda.is_available():
+    DESCRIPTION += '\n<p>Running on CPU 🥶</p>'

 model = Model()

+MAX_SEED = np.iinfo(np.int32).max
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+

 def create_demo(mode_name: str) -> gr.Blocks:
     with gr.Blocks() as demo:
@@ -28,7 +41,7 @@ def create_demo(mode_name: str) -> gr.Blocks:
                     'joint',
                     'i',
                     't',
-                    'i2ti2',
+                    'i2t2i',
                     't2i2t',
                 ],
                 value=mode_name,
@@ -37,28 +50,26 @@
                                  max_lines=1,
                                  visible=mode_name in ['t2i', 't2i2t'])
                 image = gr.Image(label='Input image',
-                                 type='filepath',
+                                 type='pil',
                                  visible=mode_name in ['i2t', 'i2t2i'])
                 run_button = gr.Button('Run')
                 with gr.Accordion('Advanced options', open=False):
-                    seed = gr.Slider(
-                        label='Seed',
-                        minimum=-1,
-                        maximum=1000000,
-                        step=1,
-                        value=-1,
-                        info=
-                        'If set to -1, a different seed will be used each time.'
-                    )
+                    seed = gr.Slider(label='Seed',
+                                     minimum=0,
+                                     maximum=MAX_SEED,
+                                     step=1,
+                                     value=0)
+                    randomize_seed = gr.Checkbox(label='Randomize seed',
+                                                 value=True)
                     num_steps = gr.Slider(label='Steps',
                                           minimum=1,
                                           maximum=100,
-                                          value=50,
+                                          value=20,
                                           step=1)
                     guidance_scale = gr.Slider(label='Guidance Scale',
                                                minimum=0.1,
                                                maximum=30.0,
-                                               value=7.0,
+                                               value=8.0,
                                                step=0.1)
             with gr.Column():
                 result_image = gr.Image(label='Generated image',
@@ -80,8 +91,27 @@
             result_text,
         ]

-        prompt.submit(fn=model.run, inputs=inputs, outputs=outputs)
-        run_button.click(fn=model.run, inputs=inputs, outputs=outputs)
+        prompt.submit(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+        ).then(
+            fn=model.run,
+            inputs=inputs,
+            outputs=outputs,
+        )
+        run_button.click(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+        ).then(
+            fn=model.run,
+            inputs=inputs,
+            outputs=outputs,
+            api_name=f'run_{mode_name}',
+        )
     return demo


@@ -102,4 +132,4 @@ with gr.Blocks(css='style.css') as demo:
             create_demo('t')
         with gr.TabItem('text variation'):
             create_demo('t2i2t')
-demo.queue(api_open=False).launch()
+demo.queue(max_size=15).launch()
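
The rewired events above follow a two-step pattern: an un-queued call to randomize_seed_fn refreshes the seed slider first, and .then() hands the updated inputs to model.run. Below is a minimal self-contained sketch of that same pattern; the generate stub and the reduced set of components are illustrative stand-ins, not part of the Space code.

import random

import gradio as gr

MAX_SEED = 2**31 - 1  # matches np.iinfo(np.int32).max used in app.py


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Pick a fresh seed when the checkbox is ticked, otherwise keep the slider value.
    return random.randint(0, MAX_SEED) if randomize_seed else seed


def generate(prompt: str, seed: int) -> str:
    # Placeholder for model.run so the sketch stays self-contained.
    return f'would generate "{prompt}" with seed {seed}'


with gr.Blocks() as demo:
    prompt = gr.Text(label='Prompt')
    seed = gr.Slider(label='Seed', minimum=0, maximum=MAX_SEED, step=1, value=0)
    randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
    result = gr.Text(label='Result')
    run_button = gr.Button('Run')
    # Update the seed first (queue=False so it is instant), then run inference
    # with the refreshed slider value.
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
    ).then(
        fn=generate,
        inputs=[prompt, seed],
        outputs=result,
    )

demo.queue(max_size=15).launch()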
model.py CHANGED
@@ -1,515 +1,87 @@
 from __future__ import annotations

-import pathlib
-import random
-import sys
-from typing import Callable
-
-import clip
-import einops
-import numpy as np
 import PIL.Image
 import torch
-from huggingface_hub import snapshot_download
-
-repo_dir = pathlib.Path(__file__).parent
-submodule_dir = repo_dir / 'unidiffuser'
-sys.path.append(submodule_dir.as_posix())
-
-import utils
-from configs.sample_unidiffuser_v1 import get_config
-from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
-from libs.autoencoder import FrozenAutoencoderKL
-from libs.autoencoder import get_model as get_autoencoder
-from libs.caption_decoder import CaptionDecoder
-from libs.clip import FrozenCLIPEmbedder
-
-model_dir = repo_dir / 'models'
-if not model_dir.exists():
-    snapshot_download('thu-ml/unidiffuser-v1',
-                      repo_type='model',
-                      local_dir=model_dir)
-
-
-def stable_diffusion_beta_schedule(linear_start=0.00085,
-                                   linear_end=0.0120,
-                                   n_timestep=1000):
-    _betas = (torch.linspace(linear_start**0.5,
-                             linear_end**0.5,
-                             n_timestep,
-                             dtype=torch.float64)**2)
-    return _betas.numpy()
+from diffusers import UniDiffuserPipeline


 class Model:
     def __init__(self):
         self.device = torch.device(
             'cuda:0' if torch.cuda.is_available() else 'cpu')
-        self.config = get_config()
-
-        self.nnet = self.load_model()
-        self.caption_decoder = CaptionDecoder(device=self.device,
-                                              **self.config.caption_decoder)
-        self.clip_text_model = self.load_clip_text_model()
-        self.autoencoder = self.load_autoencoder()
-
-        self.clip_img_model, self.clip_img_model_preprocess = clip.load(
-            'ViT-B/32', device=self.device, jit=False)
-        self.empty_context = self.clip_text_model.encode([''])[0]
-
-        self.betas = stable_diffusion_beta_schedule()
-        self.N = len(self.betas)
-
-    @property
-    def use_caption_decoder(self) -> bool:
-        return (self.config.text_dim < self.config.clip_text_dim
-                or self.config.mode != 't2i')
-
-    def load_model(self,
-                   model_path: str = 'models/uvit_v1.pth') -> torch.nn.Module:
-        model = utils.get_nnet(**self.config.nnet)
-        model.load_state_dict(torch.load(model_path, map_location='cpu'))
-        model.to(self.device)
-        model.eval()
-        return model
-
-    def load_clip_text_model(self) -> FrozenCLIPEmbedder:
-        clip_text_model = FrozenCLIPEmbedder(device=self.device)
-        clip_text_model.to(self.device)
-        clip_text_model.eval()
-        return clip_text_model
-
-    def load_autoencoder(self) -> FrozenAutoencoderKL:
-        autoencoder = get_autoencoder(**self.config.autoencoder)
-        autoencoder.to(self.device)
-        return autoencoder
-
-    def split(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        C, H, W = self.config.z_shape
-        z_dim = C * H * W
-        z, clip_img = x.split([z_dim, self.config.clip_img_dim], dim=1)
-        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
-        clip_img = einops.rearrange(clip_img,
-                                    'B (L D) -> B L D',
-                                    L=1,
-                                    D=self.config.clip_img_dim)
-        return z, clip_img
-
-    @staticmethod
-    def combine(z, clip_img):
-        z = einops.rearrange(z, 'B C H W -> B (C H W)')
-        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
-        return torch.concat([z, clip_img], dim=-1)
-
-    def t2i_nnet(
-        self, x, timesteps, text
-    ):  # text is the low dimension version of the text clip embedding
-        """
-        1. calculate the conditional model output
-        2. calculate unconditional model output
-        config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
-        config.sample.t2i_cfg_mode == 'true_uncond: using the unconditional model learned by our method
-        3. return linear combination of conditional output and unconditional output
-        """
-        z, clip_img = self.split(x)
-
-        t_text = torch.zeros(timesteps.size(0),
-                             dtype=torch.int,
-                             device=self.device)
-
-        z_out, clip_img_out, text_out = self.nnet(
-            z,
-            clip_img,
-            text=text,
-            t_img=timesteps,
-            t_text=t_text,
-            data_type=torch.zeros_like(
-                t_text, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-        x_out = self.combine(z_out, clip_img_out)
-
-        if self.config.sample.scale == 0.:
-            return x_out
-
-        if self.config.sample.t2i_cfg_mode == 'empty_token':
-            _empty_context = einops.repeat(self.empty_context,
-                                           'L D -> B L D',
-                                           B=x.size(0))
-            if self.use_caption_decoder:
-                _empty_context = self.caption_decoder.encode_prefix(
-                    _empty_context)
-            z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
-                z,
-                clip_img,
-                text=_empty_context,
-                t_img=timesteps,
-                t_text=t_text,
-                data_type=torch.zeros_like(
-                    t_text, device=self.device, dtype=torch.int) +
-                self.config.data_type)
-            x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
-        elif self.config.sample.t2i_cfg_mode == 'true_uncond':
-            text_N = torch.randn_like(text)  # 3 other possible choices
-            z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
-                z,
-                clip_img,
-                text=text_N,
-                t_img=timesteps,
-                t_text=torch.ones_like(timesteps) * self.N,
-                data_type=torch.zeros_like(
-                    t_text, device=self.device, dtype=torch.int) +
-                self.config.data_type)
-            x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
+        if self.device.type == 'cuda':
+            self.pipe = UniDiffuserPipeline.from_pretrained(
+                'thu-ml/unidiffuser-v1', torch_dtype=torch.float16)
+            self.pipe.to(self.device)
         else:
-            raise NotImplementedError
-
-        return x_out + self.config.sample.scale * (x_out - x_out_uncond)
-
-    def i_nnet(self, x, timesteps):
-        z, clip_img = self.split(x)
-        text = torch.randn(x.size(0),
-                           77,
-                           self.config.text_dim,
-                           device=self.device)
-        t_text = torch.ones_like(timesteps) * self.N
-        z_out, clip_img_out, text_out = self.nnet(
-            z,
-            clip_img,
-            text=text,
-            t_img=timesteps,
-            t_text=t_text,
-            data_type=torch.zeros_like(
-                t_text, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-        x_out = self.combine(z_out, clip_img_out)
-        return x_out
-
-    def t_nnet(self, x, timesteps):
-        z = torch.randn(x.size(0), *self.config.z_shape, device=self.device)
-        clip_img = torch.randn(x.size(0),
-                               1,
-                               self.config.clip_img_dim,
-                               device=self.device)
-        z_out, clip_img_out, text_out = self.nnet(
-            z,
-            clip_img,
-            text=x,
-            t_img=torch.ones_like(timesteps) * self.N,
-            t_text=timesteps,
-            data_type=torch.zeros_like(
-                timesteps, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-        return text_out
-
-    def i2t_nnet(self, x, timesteps, z, clip_img):
-        """
-        1. calculate the conditional model output
-        2. calculate unconditional model output
-        3. return linear combination of conditional output and unconditional output
-        """
-        t_img = torch.zeros(timesteps.size(0),
-                            dtype=torch.int,
-                            device=self.device)
-
-        z_out, clip_img_out, text_out = self.nnet(
-            z,
-            clip_img,
-            text=x,
-            t_img=t_img,
-            t_text=timesteps,
-            data_type=torch.zeros_like(
-                t_img, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-
-        if self.config.sample.scale == 0.:
-            return text_out
-
-        z_N = torch.randn_like(z)  # 3 other possible choices
-        clip_img_N = torch.randn_like(clip_img)
-        z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
-            z_N,
-            clip_img_N,
-            text=x,
-            t_img=torch.ones_like(timesteps) * self.N,
-            t_text=timesteps,
-            data_type=torch.zeros_like(
-                timesteps, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-
-        return text_out + self.config.sample.scale * (text_out -
-                                                       text_out_uncond)
-
-    def split_joint(self, x):
-        C, H, W = self.config.z_shape
-        z_dim = C * H * W
-        z, clip_img, text = x.split(
-            [z_dim, self.config.clip_img_dim, 77 * self.config.text_dim],
-            dim=1)
-        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
-        clip_img = einops.rearrange(clip_img,
-                                    'B (L D) -> B L D',
-                                    L=1,
-                                    D=self.config.clip_img_dim)
-        text = einops.rearrange(text,
-                                'B (L D) -> B L D',
-                                L=77,
-                                D=self.config.text_dim)
-        return z, clip_img, text
-
-    @staticmethod
-    def combine_joint(z: torch.Tensor, clip_img: torch.Tensor,
-                      text: torch.Tensor) -> torch.Tensor:
-        z = einops.rearrange(z, 'B C H W -> B (C H W)')
-        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
-        text = einops.rearrange(text, 'B L D -> B (L D)')
-        return torch.concat([z, clip_img, text], dim=-1)
-
-    def joint_nnet(self, x, timesteps):
-        z, clip_img, text = self.split_joint(x)
-        z_out, clip_img_out, text_out = self.nnet(
-            z,
-            clip_img,
-            text=text,
-            t_img=timesteps,
-            t_text=timesteps,
-            data_type=torch.zeros_like(
-                timesteps, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-        x_out = self.combine_joint(z_out, clip_img_out, text_out)
-
-        if self.config.sample.scale == 0.:
-            return x_out
-
-        z_noise = torch.randn(x.size(0),
-                              *self.config.z_shape,
-                              device=self.device)
-        clip_img_noise = torch.randn(x.size(0),
-                                     1,
-                                     self.config.clip_img_dim,
-                                     device=self.device)
-        text_noise = torch.randn(x.size(0),
-                                 77,
-                                 self.config.text_dim,
-                                 device=self.device)
-
-        _, _, text_out_uncond = self.nnet(
-            z_noise,
-            clip_img_noise,
-            text=text,
-            t_img=torch.ones_like(timesteps) * self.N,
-            t_text=timesteps,
-            data_type=torch.zeros_like(
-                timesteps, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-        z_out_uncond, clip_img_out_uncond, _ = self.nnet(
-            z,
-            clip_img,
-            text=text_noise,
-            t_img=timesteps,
-            t_text=torch.ones_like(timesteps) * self.N,
-            data_type=torch.zeros_like(
-                timesteps, device=self.device, dtype=torch.int) +
-            self.config.data_type)
-
-        x_out_uncond = self.combine_joint(z_out_uncond, clip_img_out_uncond,
-                                          text_out_uncond)
-
-        return x_out + self.config.sample.scale * (x_out - x_out_uncond)
-
-    @torch.cuda.amp.autocast()
-    def encode(self, _batch):
-        return self.autoencoder.encode(_batch)
-
-    @torch.cuda.amp.autocast()
-    def decode(self, _batch):
-        return self.autoencoder.decode(_batch)
-
-    def prepare_contexts(
-            self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        resolution = self.config.z_shape[-1] * 8
-
-        contexts = torch.randn(self.config.n_samples, 77,
-                               self.config.clip_text_dim).to(self.device)
-        img_contexts = torch.randn(self.config.n_samples,
-                                   2 * self.config.z_shape[0],
-                                   self.config.z_shape[1],
-                                   self.config.z_shape[2])
-        clip_imgs = torch.randn(self.config.n_samples, 1,
-                                self.config.clip_img_dim)
-
-        if self.config.mode in ['t2i', 't2i2t']:
-            prompts = [self.config.prompt] * self.config.n_samples
-            contexts = self.clip_text_model.encode(prompts)
-
-        elif self.config.mode in ['i2t', 'i2t2i']:
-            img_contexts = []
-            clip_imgs = []
-
-            def get_img_feature(image):
-                image = np.array(image).astype(np.uint8)
-                image = utils.center_crop(resolution, resolution, image)
-                clip_img_feature = self.clip_img_model.encode_image(
-                    self.clip_img_model_preprocess(
-                        PIL.Image.fromarray(image)).unsqueeze(0).to(
-                            self.device))
-
-                image = (image / 127.5 - 1.0).astype(np.float32)
-                image = einops.rearrange(image, 'h w c -> 1 c h w')
-                image = torch.tensor(image, device=self.device)
-                moments = self.autoencoder.encode_moments(image)
-
-                return clip_img_feature, moments
-
-            image = PIL.Image.open(self.config.img).convert('RGB')
-            clip_img, img_context = get_img_feature(image)
-
-            img_contexts.append(img_context)
-            clip_imgs.append(clip_img)
-            img_contexts = img_contexts * self.config.n_samples
-            clip_imgs = clip_imgs * self.config.n_samples
-
-            img_contexts = torch.concat(img_contexts, dim=0)
-            clip_imgs = torch.stack(clip_imgs, dim=0)
-
-        return contexts, img_contexts, clip_imgs
-
-    @staticmethod
-    def unpreprocess(v: torch.Tensor) -> torch.Tensor:  # to B C H W and [0, 1]
-        v = 0.5 * (v + 1.)
-        v.clamp_(0., 1.)
-        return v
-
-    def get_sample_fn(self, _n_samples: int) -> Callable:
-        def sample_fn(mode: str, **kwargs):
-            _z_init = torch.randn(_n_samples,
-                                  *self.config.z_shape,
-                                  device=self.device)
-            _clip_img_init = torch.randn(_n_samples,
-                                         1,
-                                         self.config.clip_img_dim,
-                                         device=self.device)
-            _text_init = torch.randn(_n_samples,
-                                     77,
-                                     self.config.text_dim,
-                                     device=self.device)
-            if mode == 'joint':
-                _x_init = self.combine_joint(_z_init, _clip_img_init,
-                                             _text_init)
-            elif mode in ['t2i', 'i']:
-                _x_init = self.combine(_z_init, _clip_img_init)
-            elif mode in ['i2t', 't']:
-                _x_init = _text_init
-            noise_schedule = NoiseScheduleVP(schedule='discrete',
-                                             betas=torch.tensor(
-                                                 self.betas,
-                                                 device=self.device).float())
-
-            def model_fn(x, t_continuous):
-                t = t_continuous * self.N
-                if mode == 'joint':
-                    return self.joint_nnet(x, t)
-                elif mode == 't2i':
-                    return self.t2i_nnet(x, t, **kwargs)
-                elif mode == 'i2t':
-                    return self.i2t_nnet(x, t, **kwargs)
-                elif mode == 'i':
-                    return self.i_nnet(x, t)
-                elif mode == 't':
-                    return self.t_nnet(x, t)
-
-            dpm_solver = DPM_Solver(model_fn,
-                                    noise_schedule,
-                                    predict_x0=True,
-                                    thresholding=False)
-            with torch.inference_mode(), torch.autocast(
-                    device_type=self.device.type):
-                x = dpm_solver.sample(_x_init,
-                                      steps=self.config.sample.sample_steps,
-                                      eps=1. / self.N,
-                                      T=1.)
-
-            if mode == 'joint':
-                _z, _clip_img, _text = self.split_joint(x)
-                return _z, _clip_img, _text
-            elif mode in ['t2i', 'i']:
-                _z, _clip_img = self.split(x)
-                return _z, _clip_img
-            elif mode in ['i2t', 't']:
-                return x
-
-        return sample_fn
-
-    @staticmethod
-    def to_pil(tensor: torch.Tensor) -> PIL.Image.Image:
-        return PIL.Image.fromarray(
-            tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
-                'cpu', torch.uint8).numpy())
-
-    def run(self, mode: str, prompt: str, image_path: str, seed: int,
-            num_steps: int,
-            guidance_scale: float) -> tuple[PIL.Image.Image | None, str]:
-        self.config.mode = mode
-        self.config.prompt = prompt
-        self.config.img = image_path
-        self.config.seed = seed
-        self.config.sample.sample_steps = num_steps
-        self.config.sample.scale = guidance_scale
-        self.config.n_samples = 1
-
-        #set_seed(self.config.seed)
-        if seed == -1:
-            seed = random.randint(0, 1000000)
-        torch.manual_seed(seed)
-
-        contexts, img_contexts, clip_imgs = self.prepare_contexts()
-        if self.use_caption_decoder:
-            contexts_low_dim = self.caption_decoder.encode_prefix(contexts)
-        else:
-            contexts_low_dim = contexts
-        z_img = self.autoencoder.sample(img_contexts)
-
-        if self.config.mode in ['t2i', 't2i2t']:
-            _n_samples = contexts_low_dim.size(0)
-        elif self.config.mode in ['i2t', 'i2t2i']:
-            _n_samples = img_contexts.size(0)
-        else:
-            _n_samples = self.config.n_samples
-        sample_fn = self.get_sample_fn(_n_samples)
-
-        if self.config.mode == 'joint':
-            _z, _clip_img, _text = sample_fn(self.config.mode)
-            samples = self.unpreprocess(self.decode(_z))
-            samples = [self.to_pil(tensor) for tensor in samples]
-            prompts = self.caption_decoder.generate_captions(_text)
-            return samples[0], prompts[0]
-
-        elif self.config.mode in ['t2i', 'i', 'i2t2i']:
-            if self.config.mode == 't2i':
-                _z, _clip_img = sample_fn(
-                    self.config.mode,
-                    text=contexts_low_dim)  # conditioned on the text embedding
-            elif self.config.mode == 'i':
-                _z, _clip_img = sample_fn(self.config.mode)
-            elif self.config.mode == 'i2t2i':
-                _text = sample_fn(
-                    'i2t', z=z_img,
-                    clip_img=clip_imgs)  # conditioned on the image embedding
-                _z, _clip_img = sample_fn('t2i', text=_text)
-            samples = self.unpreprocess(self.decode(_z))
-            samples = [self.to_pil(tensor) for tensor in samples]
-            return samples[0], ''
-
-        elif self.config.mode in ['i2t', 't', 't2i2t']:
-            if self.config.mode == 'i2t':
-                _text = sample_fn(
-                    self.config.mode, z=z_img,
-                    clip_img=clip_imgs)  # conditioned on the image embedding
-            elif self.config.mode == 't':
-                _text = sample_fn(self.config.mode)
-            elif self.config.mode == 't2i2t':
-                _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
-                _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
-            prompts = self.caption_decoder.generate_captions(_text)
-            return None, prompts[0]
+            self.pipe = UniDiffuserPipeline.from_pretrained(
+                'thu-ml/unidiffuser-v1')
+
+    def run(
+        self,
+        mode: str,
+        prompt: str,
+        image: PIL.Image.Image | None,
+        seed: int = 0,
+        num_steps: int = 20,
+        guidance_scale: float = 8.0,
+    ) -> tuple[PIL.Image.Image | None, str]:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        if mode == 't2i':
+            self.pipe.set_text_to_image_mode()
+            sample = self.pipe(prompt=prompt,
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return sample.images[0], ''
+        elif mode == 'i2t':
+            self.pipe.set_image_to_text_mode()
+            sample = self.pipe(image=image,
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return None, sample.text[0]
+        elif mode == 'joint':
+            self.pipe.set_joint_mode()
+            sample = self.pipe(num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return sample.images[0], sample.text[0]
+        elif mode == 'i':
+            self.pipe.set_image_mode()
+            sample = self.pipe(num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return sample.images[0], ''
+        elif mode == 't':
+            self.pipe.set_text_mode()
+            sample = self.pipe(num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return None, sample.text[0]
+        elif mode == 'i2t2i':
+            self.pipe.set_image_to_text_mode()
+            sample = self.pipe(image=image,
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            self.pipe.set_text_to_image_mode()
+            sample = self.pipe(prompt=sample.text[0],
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return sample.images[0], ''
+        elif mode == 't2i2t':
+            self.pipe.set_text_to_image_mode()
+            sample = self.pipe(prompt=prompt,
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            self.pipe.set_image_to_text_mode()
+            sample = self.pipe(image=sample.images[0],
+                               num_inference_steps=num_steps,
+                               guidance_scale=guidance_scale,
+                               generator=generator)
+            return None, sample.text[0]
         else:
             raise ValueError
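
The rewritten model.py delegates all sampling to a single diffusers UniDiffuserPipeline and switches its mode per request. As a rough standalone sketch of that usage, assuming diffusers 0.17 with the thu-ml/unidiffuser-v1 weights; the prompt, step count, and the explicit float32 fallback on CPU are illustrative choices, not taken from the commit:

import torch
from diffusers import UniDiffuserPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# fp16 on GPU mirrors the new model.py; fp32 is an illustrative CPU fallback.
dtype = torch.float16 if device == 'cuda' else torch.float32
pipe = UniDiffuserPipeline.from_pretrained('thu-ml/unidiffuser-v1',
                                           torch_dtype=dtype)
pipe.to(device)
generator = torch.Generator(device=device).manual_seed(0)

# Text-to-image ('t2i' in the Space): select the mode, then call the pipeline.
pipe.set_text_to_image_mode()
image = pipe(prompt='an elephant under the sea',  # illustrative prompt
             num_inference_steps=20,
             guidance_scale=8.0,
             generator=generator).images[0]

# Image-to-text ('i2t'): same pipeline object, different mode.
pipe.set_image_to_text_mode()
caption = pipe(image=image,
               num_inference_steps=20,
               guidance_scale=8.0,
               generator=generator).text[0]
print(caption)

The round-trip modes in the Space ('i2t2i', 't2i2t') simply chain these two calls, feeding one output into the next call as shown in the diff above.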
requirements.txt CHANGED
@@ -1,13 +1,9 @@
-absl-py==1.4.0
-accelerate==0.12.0
-einops==0.6.0
-ftfy==6.1.1
-git+https://github.com/openai/CLIP.git@a9b1bf5
-gradio==3.21.0
-huggingface-hub==0.13.2
-ml-collections==0.1.1
-torch==1.13.1
-torchvision==0.14.1
-transformers==4.23.1
+accelerate==0.20.3
+diffusers==0.17.0
+gradio==3.34.0
+huggingface-hub==0.15.1
+torch==2.0.1
+torchvision==0.15.2
+transformers==4.30.0
 triton==2.0.0
-xformers==0.0.16
+xformers==0.0.20
unidiffuser DELETED
@@ -1 +0,0 @@
-Subproject commit 390368777ce0a6102f50361ab6dae8e0991447a8