Spaces:
Build error
Build error
adymaharana
commited on
Commit
•
6b98e7a
1
Parent(s):
96db007
Removed clip dependency
Browse files- app.py +0 -1
- dalle/utils/utils.py +1 -24
app.py
CHANGED
@@ -5,7 +5,6 @@ import torchvision.transforms as transforms
|
|
5 |
from dalle.models import StoryDalle
|
6 |
import argparse
|
7 |
from PIL import Image
|
8 |
-
import numpy as np
|
9 |
from torchvision.utils import save_image
|
10 |
import tensorflow_hub as hub
|
11 |
import gdown
|
|
|
5 |
from dalle.models import StoryDalle
|
6 |
import argparse
|
7 |
from PIL import Image
|
|
|
8 |
from torchvision.utils import save_image
|
9 |
import tensorflow_hub as hub
|
10 |
import gdown
|
dalle/utils/utils.py
CHANGED
@@ -10,12 +10,8 @@ import urllib
|
|
10 |
import hashlib
|
11 |
import tarfile
|
12 |
import torch
|
13 |
-
import clip
|
14 |
import numpy as np
|
15 |
-
from PIL import Image
|
16 |
-
from torch.nn import functional as F
|
17 |
from tqdm import tqdm
|
18 |
-
import torchvision.utils as vutils
|
19 |
import matplotlib.pyplot as plt
|
20 |
|
21 |
|
@@ -26,25 +22,6 @@ def set_seed(seed: int):
|
|
26 |
torch.cuda.manual_seed_all(seed)
|
27 |
|
28 |
|
29 |
-
@torch.no_grad()
|
30 |
-
def clip_score(prompt: str,
|
31 |
-
images: np.ndarray,
|
32 |
-
model_clip: torch.nn.Module,
|
33 |
-
preprocess_clip,
|
34 |
-
device: str) -> np.ndarray:
|
35 |
-
images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
|
36 |
-
images = torch.stack(images, dim=0).to(device=device)
|
37 |
-
texts = clip.tokenize(prompt).to(device=device)
|
38 |
-
texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
|
39 |
-
|
40 |
-
image_features = model_clip.encode_image(images)
|
41 |
-
text_features = model_clip.encode_text(texts)
|
42 |
-
|
43 |
-
scores = F.cosine_similarity(image_features, text_features).squeeze()
|
44 |
-
rank = torch.argsort(scores, descending=True).cpu().numpy()
|
45 |
-
return rank
|
46 |
-
|
47 |
-
|
48 |
def download(url: str, root: str) -> str:
|
49 |
os.makedirs(root, exist_ok=True)
|
50 |
filename = os.path.basename(url)
|
@@ -128,4 +105,4 @@ def save_image(ground_truth, images, out_dir, batch_idx):
|
|
128 |
# fid.write(texts[i][idx] + '\n')
|
129 |
# fid.write('\n\n')
|
130 |
# fid.close()
|
131 |
-
return
|
|
|
10 |
import hashlib
|
11 |
import tarfile
|
12 |
import torch
|
|
|
13 |
import numpy as np
|
|
|
|
|
14 |
from tqdm import tqdm
|
|
|
15 |
import matplotlib.pyplot as plt
|
16 |
|
17 |
|
|
|
22 |
torch.cuda.manual_seed_all(seed)
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def download(url: str, root: str) -> str:
|
26 |
os.makedirs(root, exist_ok=True)
|
27 |
filename = os.path.basename(url)
|
|
|
105 |
# fid.write(texts[i][idx] + '\n')
|
106 |
# fid.write('\n\n')
|
107 |
# fid.close()
|
108 |
+
return
|