jaketae committed on
Commit
b4968e4
1 Parent(s): b1f513f

feature: add image feature extraction script

Browse files
Files changed (1) hide show
  1. embed.py +50 -0
embed.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import os
4
+
5
+ from PIL import Image
6
+
7
+ from utils import load_model
8
+
9
+
10
def main(args):
    """Extract image embeddings with each KoCLIP model and save them as TSV.

    The files listed in ``args.image_path`` are opened as images, pushed
    through the model in batches of ``args.batch_size``, and written to
    ``<args.out_path>/<model_name>.tsv`` with one row per image:
    ``file name <TAB> comma-separated feature values``.

    Args:
        args: Namespace providing ``image_path``, ``out_path`` and
            ``batch_size`` attributes.

    Raises:
        ValueError: If ``args.batch_size`` is smaller than 1.
    """
    # Coerce here as well: argparse hands over a str when the option is
    # declared without ``type=int`` and supplied on the command line.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    root = args.image_path
    # os.listdir order is arbitrary; sort so the output rows are reproducible.
    files = sorted(os.listdir(root))
    os.makedirs(args.out_path, exist_ok=True)

    # NOTE(review): the names are combined with the "koclip/" prefix below, so
    # they must not carry that prefix themselves (it would also put a path
    # separator into the output file name). Confirm these are the intended
    # checkpoint names.
    for model_name in ["koclip", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # Open once per model: reopening in write mode for every batch would
        # truncate the file and keep only the last batch.
        with open(out_file, "w", newline="") as f:
            writer = csv.writer(f, delimiter="\t")
            # Slicing covers every file exactly once, including the final
            # partial batch.
            for start in range(0, len(files), batch_size):
                image_ids = files[start:start + batch_size]
                images = []
                for image_id in image_ids:
                    # convert() loads the pixel data into a new image, so the
                    # underlying file handle can be released right away.
                    with Image.open(os.path.join(root, image_id)) as image:
                        images.append(image.convert("RGB"))
                inputs = processor(
                    text=[""], images=images, return_tensors="jax", padding=True
                )
                features = model(**inputs).image_embeds
                for image_id, feature in zip(image_ids, features):
                    # str.join only accepts strings; stringify each value.
                    values = ",".join(str(float(value)) for value in feature)
                    writer.writerow([image_id, values])
34
+
35
+
36
+
37
def parse_args(argv=None):
    """Parse the command-line options of the feature extraction script.

    Args:
        argv: Optional list of argument strings; when ``None`` argparse
            reads ``sys.argv[1:]``.

    Returns:
        Namespace with ``batch_size`` (int), ``image_path`` and ``out_path``.
    """
    parser = argparse.ArgumentParser()
    # type=int: without it a value given on the command line stays a str.
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--image_path", default="images")
    parser.add_argument("--out_path", default="features")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    main(args)
44
+
45
+
46
+
47
+
48
+
49
+
50
+