adymaharana commited on
Commit
6b98e7a
1 Parent(s): 96db007

Removed clip dependency

Browse files
Files changed (2) hide show
  1. app.py +0 -1
  2. dalle/utils/utils.py +1 -24
app.py CHANGED
@@ -5,7 +5,6 @@ import torchvision.transforms as transforms
5
  from dalle.models import StoryDalle
6
  import argparse
7
  from PIL import Image
8
- import numpy as np
9
  from torchvision.utils import save_image
10
  import tensorflow_hub as hub
11
  import gdown
 
5
  from dalle.models import StoryDalle
6
  import argparse
7
  from PIL import Image
 
8
  from torchvision.utils import save_image
9
  import tensorflow_hub as hub
10
  import gdown
dalle/utils/utils.py CHANGED
@@ -10,12 +10,8 @@ import urllib
10
  import hashlib
11
  import tarfile
12
  import torch
13
- import clip
14
  import numpy as np
15
- from PIL import Image
16
- from torch.nn import functional as F
17
  from tqdm import tqdm
18
- import torchvision.utils as vutils
19
  import matplotlib.pyplot as plt
20
 
21
 
@@ -26,25 +22,6 @@ def set_seed(seed: int):
26
  torch.cuda.manual_seed_all(seed)
27
 
28
 
29
- @torch.no_grad()
30
- def clip_score(prompt: str,
31
- images: np.ndarray,
32
- model_clip: torch.nn.Module,
33
- preprocess_clip,
34
- device: str) -> np.ndarray:
35
- images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
36
- images = torch.stack(images, dim=0).to(device=device)
37
- texts = clip.tokenize(prompt).to(device=device)
38
- texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
39
-
40
- image_features = model_clip.encode_image(images)
41
- text_features = model_clip.encode_text(texts)
42
-
43
- scores = F.cosine_similarity(image_features, text_features).squeeze()
44
- rank = torch.argsort(scores, descending=True).cpu().numpy()
45
- return rank
46
-
47
-
48
  def download(url: str, root: str) -> str:
49
  os.makedirs(root, exist_ok=True)
50
  filename = os.path.basename(url)
@@ -128,4 +105,4 @@ def save_image(ground_truth, images, out_dir, batch_idx):
128
  # fid.write(texts[i][idx] + '\n')
129
  # fid.write('\n\n')
130
  # fid.close()
131
- return
 
10
  import hashlib
11
  import tarfile
12
  import torch
 
13
  import numpy as np
 
 
14
  from tqdm import tqdm
 
15
  import matplotlib.pyplot as plt
16
 
17
 
 
22
  torch.cuda.manual_seed_all(seed)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def download(url: str, root: str) -> str:
26
  os.makedirs(root, exist_ok=True)
27
  filename = os.path.basename(url)
 
105
  # fid.write(texts[i][idx] + '\n')
106
  # fid.write('\n\n')
107
  # fid.close()
108
+ return