versae committed on
Commit
0649211
1 Parent(s): e001acb

Upload image_vectorizer_clip.py

Browse files
Files changed (1) hide show
  1. image_vectorizer_clip.py +76 -0
image_vectorizer_clip.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import os
4
+ import time
5
+ import urllib.request
6
+ import pandas as pd
7
+ import numpy as np
8
+ from pathlib import Path
9
+ from multiprocessing.dummy import Pool
10
+ from tqdm import tqdm
11
+ from transformers import CLIPProcessor, CLIPModel
12
+ from PIL import Image, ImageFile
13
# Allow Pillow to load image files whose data is truncated instead of
# raising an error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# The preprocessor and the model are loaded once, at import time, from the
# same pretrained CLIP checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
17
+
18
+
19
def compute_image_embeddings(list_of_images):
    """Return CLIP image features for a batch of images.

    Args:
        list_of_images: Images to embed. They are handed unchanged to the
            module-level ``processor`` as its ``images`` argument.

    Returns:
        Whatever ``model.get_image_features`` returns for the processed
        batch. Because the processor is asked for ``"pt"`` tensors this is
        a PyTorch tensor (presumably one row per input image -- confirm
        against the model documentation). It is computed with autograd
        disabled, so it carries no gradient history.
    """
    # Local import: the file does not import torch at module level, but
    # requesting "pt" tensors from the processor already depends on it.
    import torch

    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
    # The features are only used for inference; skipping autograd avoids
    # building and retaining a computation graph on every call.
    with torch.no_grad():
        return model.get_image_features(**inputs)
23
+
24
+
25
def load_image(path, same_height=False):
    """Open an image, force it to RGB, and scale it towards 224 pixels.

    Args:
        path: Location of the image file, as accepted by ``Image.open``.
        same_height: When true, scale so that the height becomes 224
            pixels. Otherwise scale so that the shorter side becomes 224
            pixels.

    Returns:
        The resized RGB image produced by ``Image.resize``. The aspect
        ratio is kept up to integer truncation, and each side is at least
        one pixel long.
    """
    target = 224
    # The context manager closes the underlying file as soon as the resized
    # copy exists, instead of leaving one open handle per processed image.
    with Image.open(path) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        width, height = rgb.size
        reference = height if same_height else min(width, height)
        ratio = target / reference
        # Truncation can reach 0 for extreme aspect ratios, which is not a
        # valid size to resize to; clamp each side to one pixel.
        new_size = (
            max(1, int(width * ratio)),
            max(1, int(height * ratio)),
        )
        return rgb.resize(new_size)
34
+
35
+
36
def _record_path_for(image_path, vectors_root, records_dir):
    """Map an image under *vectors_root* to its JSON record under *records_dir*."""
    relative = image_path.relative_to(vectors_root)
    return (Path(records_dir) / relative).with_suffix(".json")


def _write_outputs(rows, embeddings, csv_path, npy_path):
    """Write the metadata rows to *csv_path* and the embeddings to *npy_path*."""
    # newline="" is what the csv module asks for so that it controls line
    # endings itself; the explicit encoding keeps non-ASCII titles portable.
    with open(csv_path, "w", newline="", encoding="utf-8") as clip_file:
        writer = csv.writer(clip_file)
        for row in tqdm(rows, desc="Writing rows and embeddings"):
            writer.writerow(row)
    if embeddings:
        # Always 2-D (one row per image), including when there is one image.
        matrix = np.stack(embeddings)
    else:
        # No images: store an empty array rather than a pickled ``None``.
        matrix = np.empty((0, 0), dtype=np.float32)
    np.save(npy_path, matrix)


def main(
    vectors_dir="./vectors_20211011",
    records_dir="./records_20211011",
    csv_path="clip.csv",
    npy_path="clip.npy",
    checkpoint_every=1000,
):
    """Embed every ``.jpg`` under *vectors_dir* and save embeddings and metadata.

    For each image, the JSON record at the same relative path under
    *records_dir* (with a ``.json`` suffix) is read. Its
    ``["metadata"]["title"]`` and ``["_links"]["thumbnail_large"]["href"]``
    values are stored next to the image's file stem in a CSV with the header
    ``id, label, thumbnail``. The embeddings are saved as one array, in the
    same order as the CSV data rows.

    Args:
        vectors_dir: Directory searched recursively for ``*.jpg`` files.
        records_dir: Directory holding the matching JSON records.
        csv_path: Destination of the metadata CSV.
        npy_path: Destination of the embeddings array.
        checkpoint_every: Both outputs are rewritten whenever the zero-based
            image index is a non-zero multiple of this value; a falsy value
            disables intermediate saves. Both outputs are always written
            once more after the last image.

    Raises:
        OSError: If an image or its JSON record cannot be opened.
        KeyError: If a record lacks one of the expected keys.
    """
    vectors_root = Path(vectors_dir)
    # Walk the tree once; rglob already searches recursively, so no "**/"
    # prefix is needed. Sorting makes the output order reproducible.
    image_paths = sorted(vectors_root.rglob("*.jpg"))

    rows = [["id", "label", "thumbnail"]]
    # Collected in a list and stacked on save: re-stacking the whole matrix
    # for every image would be quadratic in the number of images.
    embeddings = []
    for index, image_path in enumerate(tqdm(image_paths)):
        record_path = _record_path_for(image_path, vectors_root, records_dir)
        with open(record_path, encoding="utf-8") as record_file:
            record = json.load(record_file)
        row = [
            image_path.stem,
            record["metadata"]["title"],
            record["_links"]["thumbnail_large"]["href"],
        ]
        features = compute_image_embeddings([load_image(image_path)])
        # Append both together so rows and embeddings stay aligned.
        embeddings.append(features.detach().numpy()[0])
        rows.append(row)
        if checkpoint_every and index and index % checkpoint_every == 0:
            _write_outputs(rows, embeddings, csv_path, npy_path)

    _write_outputs(rows, embeddings, csv_path, npy_path)
73
+
74
+
75
# Run the vectorization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()