hpoghos committed on
Commit f3a7d15
1 Parent(s): ef10431
app.py CHANGED
@@ -5,7 +5,7 @@ import argparse
 import datetime
 from pathlib import Path
 import torch
-import spaces
+# import spaces
 import gradio as gr
 import tempfile
 import yaml
@@ -54,7 +54,7 @@ inference_generator = torch.Generator(device="cuda")
 # -------------------------
 # ----- Functionality -----
 # -------------------------
-@spaces.GPU
+# @spaces.GPU
 def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol):
     now = datetime.datetime.now()
     name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
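
The removed `@spaces.GPU` decorator is specific to Hugging Face ZeroGPU Spaces: it asks the platform to attach a GPU for the duration of the decorated call. Commenting it out together with `import spaces` lets `generate` run where the `spaces` package is unavailable, e.g. on a dedicated-GPU Space or a local machine. Below is a minimal, hedged sketch of an alternative that keeps both paths working; the fallback decorator and the placeholder function body are illustrative assumptions, not part of this commit.

# Sketch only: keep ZeroGPU support when `spaces` exists, no-op otherwise.
import torch

try:
    import spaces                  # present only on ZeroGPU Spaces
    gpu_decorator = spaces.GPU     # real decorator: requests a GPU per call
except ImportError:
    def gpu_decorator(fn):         # no-op fallback for local / dedicated-GPU runs
        return fn

@gpu_decorator
def generate(prompt, seed=33):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(seed)
    # placeholder tensor standing in for the actual video-generation pipeline
    return torch.randn(4, 3, 64, 64, generator=generator, device=device)
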
t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Mapping
 import torch
 import torch.nn as nn
 import kornia
-# import open_clip
+from open_clip import create_model_and_transforms
 from transformers import AutoImageProcessor, AutoModel
 from transformers.models.bit.image_processing_bit import BitImageProcessor
 from einops import rearrange, repeat
@@ -52,160 +52,160 @@ class AbstractEncoder(nn.Module):
 
 
 
-# class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
-#     """
-#     Uses the OpenCLIP vision transformer encoder for images
-#     """
-
-#     def __init__(
-#         self,
-#         arch="ViT-H-14",
-#         version="laion2b_s32b_b79k",
-#         device="cuda",
-#         max_length=77,
-#         freeze=True,
-#         antialias=True,
-#         ucg_rate=0.0,
-#         unsqueeze_dim=False,
-#         repeat_to_max_len=False,
-#         num_image_crops=0,
-#         output_tokens=False,
-#     ):
-#         super().__init__()
-#         model, _, _ = open_clip.create_model_and_transforms(
-#             arch,
-#             device=torch.device("cpu"),
-#             pretrained=version,
-#         )
-#         del model.transformer
-#         self.model = model
-#         self.max_crops = num_image_crops
-#         self.pad_to_max_len = self.max_crops > 0
-#         self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
-#         self.device = device
-#         self.max_length = max_length
-#         if freeze:
-#             self.freeze()
-
-#         self.antialias = antialias
-
-#         self.register_buffer(
-#             "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
-#         )
-#         self.register_buffer(
-#             "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
-#         )
-#         self.ucg_rate = ucg_rate
-#         self.unsqueeze_dim = unsqueeze_dim
-#         self.stored_batch = None
-#         self.model.visual.output_tokens = output_tokens
-#         self.output_tokens = output_tokens
-
-#     def preprocess(self, x):
-#         # normalize to [0,1]
-#         x = kornia.geometry.resize(
-#             x,
-#             (224, 224),
-#             interpolation="bicubic",
-#             align_corners=True,
-#             antialias=self.antialias,
-#         )
-#         x = (x + 1.0) / 2.0
-#         # renormalize according to clip
-#         x = kornia.enhance.normalize(x, self.mean, self.std)
-#         return x
-
-#     def freeze(self):
-#         self.model = self.model.eval()
-#         for param in self.parameters():
-#             param.requires_grad = False
-
-#     def forward(self, image, no_dropout=False):
-#         z = self.encode_with_vision_transformer(image)
-#         tokens = None
-#         if self.output_tokens:
-#             z, tokens = z[0], z[1]
-#         z = z.to(image.dtype)
-#         if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
-#             z = (
-#                 torch.bernoulli(
-#                     (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
-#                 )[:, None]
-#                 * z
-#             )
-#             if tokens is not None:
-#                 tokens = (
-#                     expand_dims_like(
-#                         torch.bernoulli(
-#                             (1.0 - self.ucg_rate)
-#                             * torch.ones(tokens.shape[0], device=tokens.device)
-#                         ),
-#                         tokens,
-#                     )
-#                     * tokens
-#                 )
-#         if self.unsqueeze_dim:
-#             z = z[:, None, :]
-#         if self.output_tokens:
-#             assert not self.repeat_to_max_len
-#             assert not self.pad_to_max_len
-#             return tokens, z
-#         if self.repeat_to_max_len:
-#             if z.dim() == 2:
-#                 z_ = z[:, None, :]
-#             else:
-#                 z_ = z
-#             return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
-#         elif self.pad_to_max_len:
-#             assert z.dim() == 3
-#             z_pad = torch.cat(
-#                 (
-#                     z,
-#                     torch.zeros(
-#                         z.shape[0],
-#                         self.max_length - z.shape[1],
-#                         z.shape[2],
-#                         device=z.device,
-#                     ),
-#                 ),
-#                 1,
-#             )
-#             return z_pad, z_pad[:, 0, ...]
-#         return z
-
-#     def encode_with_vision_transformer(self, img):
-#         # if self.max_crops > 0:
-#         #     img = self.preprocess_by_cropping(img)
-#         if img.dim() == 5:
-#             assert self.max_crops == img.shape[1]
-#             img = rearrange(img, "b n c h w -> (b n) c h w")
-#         img = self.preprocess(img)
-#         if not self.output_tokens:
-#             assert not self.model.visual.output_tokens
-#             x = self.model.visual(img)
-#             tokens = None
-#         else:
-#             assert self.model.visual.output_tokens
-#             x, tokens = self.model.visual(img)
-#         if self.max_crops > 0:
-#             x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
-#             # drop out between 0 and all along the sequence axis
-#             x = (
-#                 torch.bernoulli(
-#                     (1.0 - self.ucg_rate)
-#                     * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
-#                 )
-#                 * x
-#             )
-#             if tokens is not None:
-#                 tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-#                 print(
-#                     f"You are running very experimental token-concat in {self.__class__.__name__}. "
-#                     f"Check what you are doing, and then remove this message."
-#                 )
-#         if self.output_tokens:
-#             return x, tokens
-#         return x
-
-#     def encode(self, text):
-#         return self(text)
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+
+    def __init__(
+        self,
+        arch="ViT-H-14",
+        version="laion2b_s32b_b79k",
+        device="cuda",
+        max_length=77,
+        freeze=True,
+        antialias=True,
+        ucg_rate=0.0,
+        unsqueeze_dim=False,
+        repeat_to_max_len=False,
+        num_image_crops=0,
+        output_tokens=False,
+    ):
+        super().__init__()
+        model, _, _ = create_model_and_transforms(
+            arch,
+            device=torch.device("cpu"),
+            pretrained=version,
+        )
+        del model.transformer
+        self.model = model
+        self.max_crops = num_image_crops
+        self.pad_to_max_len = self.max_crops > 0
+        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+
+        self.antialias = antialias
+
+        self.register_buffer(
+            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+        )
+        self.register_buffer(
+            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+        )
+        self.ucg_rate = ucg_rate
+        self.unsqueeze_dim = unsqueeze_dim
+        self.stored_batch = None
+        self.model.visual.output_tokens = output_tokens
+        self.output_tokens = output_tokens
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(
+            x,
+            (224, 224),
+            interpolation="bicubic",
+            align_corners=True,
+            antialias=self.antialias,
+        )
+        x = (x + 1.0) / 2.0
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, image, no_dropout=False):
+        z = self.encode_with_vision_transformer(image)
+        tokens = None
+        if self.output_tokens:
+            z, tokens = z[0], z[1]
+        z = z.to(image.dtype)
+        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+            z = (
+                torch.bernoulli(
+                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+                )[:, None]
+                * z
+            )
+            if tokens is not None:
+                tokens = (
+                    expand_dims_like(
+                        torch.bernoulli(
+                            (1.0 - self.ucg_rate)
+                            * torch.ones(tokens.shape[0], device=tokens.device)
+                        ),
+                        tokens,
+                    )
+                    * tokens
+                )
+        if self.unsqueeze_dim:
+            z = z[:, None, :]
+        if self.output_tokens:
+            assert not self.repeat_to_max_len
+            assert not self.pad_to_max_len
+            return tokens, z
+        if self.repeat_to_max_len:
+            if z.dim() == 2:
+                z_ = z[:, None, :]
+            else:
+                z_ = z
+            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+        elif self.pad_to_max_len:
+            assert z.dim() == 3
+            z_pad = torch.cat(
+                (
+                    z,
+                    torch.zeros(
+                        z.shape[0],
+                        self.max_length - z.shape[1],
+                        z.shape[2],
+                        device=z.device,
+                    ),
+                ),
+                1,
+            )
+            return z_pad, z_pad[:, 0, ...]
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        # if self.max_crops > 0:
+        #     img = self.preprocess_by_cropping(img)
+        if img.dim() == 5:
+            assert self.max_crops == img.shape[1]
+            img = rearrange(img, "b n c h w -> (b n) c h w")
+        img = self.preprocess(img)
+        if not self.output_tokens:
+            assert not self.model.visual.output_tokens
+            x = self.model.visual(img)
+            tokens = None
+        else:
+            assert self.model.visual.output_tokens
+            x, tokens = self.model.visual(img)
+        if self.max_crops > 0:
+            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+            # drop out between 0 and all along the sequence axis
+            x = (
+                torch.bernoulli(
+                    (1.0 - self.ucg_rate)
+                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+                )
+                * x
+            )
+            if tokens is not None:
+                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+                print(
+                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
+                    f"Check what you are doing, and then remove this message."
+                )
+        if self.output_tokens:
+            return x, tokens
+        return x
+
+    def encode(self, text):
+        return self(text)
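
For reference, the re-enabled FrozenOpenCLIPImageEmbedder wraps the OpenCLIP ViT-H/14 visual tower and expects images scaled to [-1, 1]; preprocess() resizes them to 224x224 and applies the CLIP mean/std normalization before the visual transformer. A hedged usage sketch follows, assuming the repository root is on PYTHONPATH, open_clip_torch is installed, and the laion2b_s32b_b79k weights can be downloaded; the random input batch is purely illustrative.

# Usage sketch (assumptions: repo root on PYTHONPATH, open_clip_torch installed,
# pretrained weights reachable; the input tensor is random, illustrative only).
import torch
from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import (
    FrozenOpenCLIPImageEmbedder,
)

embedder = FrozenOpenCLIPImageEmbedder(
    arch="ViT-H-14",
    version="laion2b_s32b_b79k",
    device="cpu",          # keep the example CPU-only
    freeze=True,           # eval mode, gradients disabled
    output_tokens=False,   # return pooled embeddings only
)

images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0   # fake batch in [-1, 1]
with torch.no_grad():
    z = embedder(images)   # pooled CLIP image embeddings, (2, 1024) for ViT-H/14
print(z.shape)
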