Mehdi Cherti committed · Commit 7060b15 · Parent: a45817e

update models

Files changed:
- clip_encoder.py (+21, -0)
- run.py (+14, -1)
clip_encoder.py

@@ -62,3 +62,24 @@ class CLIPImageEncoder(nn.Module):
 
 
 
+class OpenCLIPImageEncoder(nn.Module):
+
+    def __init__(self, model="ViT-B/32", pretrained="openai"):
+        super().__init__()
+        model, _, preprocess = open_clip.create_model_and_transforms(model, pretrained=pretrained)
+        self.tokenizer = open_clip.get_tokenizer(model)
+        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+        mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
+        std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+
+    def forward_image(self, x):
+        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
+        x = (x - self.mean) / self.std
+        return self.model.encode_image(x)
+
+    def forward_text(self, texts):
+        toks = self.tokenizer.tokenize(texts, truncate=True).to(self.mean.device)
+        return self.model.encode_text(toks)
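A note on the new class as committed: __init__ rebinds the local name model to the loaded network but never assigns self.model, so forward_image and forward_text would raise AttributeError; get_tokenizer is also handed the loaded model object rather than the architecture name, and preprocess is unused. Below is a minimal corrected sketch, untested against this Space. It assumes a recent open_clip API, where model names are spelled with dashes ("ViT-B-32") and the object returned by open_clip.get_tokenizer is called directly; the class name OpenCLIPImageEncoderFixed is mine, not from the repo.

import torch
import torch.nn as nn
import open_clip

class OpenCLIPImageEncoderFixed(nn.Module):

    def __init__(self, model="ViT-B-32", pretrained="openai"):
        super().__init__()
        # Store the network on self so encode_image/encode_text are reachable;
        # the name string stays available for get_tokenizer.
        self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
        self.tokenizer = open_clip.get_tokenizer(model)
        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
        self.register_buffer("mean", torch.tensor(CLIP_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(CLIP_STD).view(1, 3, 1, 1))

    def forward_image(self, x):
        # Resize to the ViT input resolution, then apply CLIP normalization.
        x = torch.nn.functional.interpolate(x, mode="bicubic", size=(224, 224))
        x = (x - self.mean) / self.std
        return self.model.encode_image(x)

    def forward_text(self, texts):
        # open_clip tokenizers are called directly and pad/truncate to the
        # model's context length (77 for CLIP).
        toks = self.tokenizer(texts).to(self.mean.device)
        return self.model.encode_text(toks)

Usage would then look like enc = OpenCLIPImageEncoderFixed() followed by enc.forward_text(["a photo of a dog"]).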
run.py

@@ -237,7 +237,7 @@ def ddgan_laion2b_v2():
     return cfg
 
 def ddgan_ddb_v1():
-    cfg =
+    cfg = ddgan_sd_v10()
     return cfg
 
 def ddgan_sd_v11():
@@ -245,6 +245,17 @@ def ddgan_sd_v11():
     cfg['model']['image_size'] = 512
     return cfg
 
+def ddgan_ddb_v2():
+    cfg = ddgan_ddb_v1()
+    cfg['model']['num_timesteps'] = 1
+    return cfg
+
+def ddgan_ddb_v3():
+    cfg = ddgan_ddb_v1()
+    cfg['model']['num_channels_dae'] = 192
+    cfg['model']['num_timesteps'] = 2
+    return cfg
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -286,6 +297,8 @@ models = [
     ddgan_sd_v11,
     ddgan_laion2b_v2,
     ddgan_ddb_v1,
+    ddgan_ddb_v2,
+    ddgan_ddb_v3
 ]
 
 def get_model(model_name):
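These hunks follow the file's config-inheritance pattern: each function calls a base config and overrides a few keys (the previously empty ddgan_ddb_v1 now inherits from ddgan_sd_v10; ddgan_ddb_v2 drops to a single denoising timestep; ddgan_ddb_v3 widens the DAE to 192 channels with 2 timesteps). The body of get_model lies outside these hunks; the sketch below is a plausible minimal resolver under the assumption that configs are looked up by function name against the models registry. It is hypothetical, not the repo's actual code.

# Hypothetical resolver: matches a requested name against the registered
# config functions and builds the config dict on demand.
def get_model(model_name):
    for build_cfg in models:
        if build_cfg.__name__ == model_name:
            return build_cfg()
    raise ValueError(f"Unknown model: {model_name}")

# e.g. get_model("ddgan_ddb_v3") would return ddgan_ddb_v1()'s config with
# num_channels_dae=192 and num_timesteps=2 overridden.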