devtrent commited on
Commit
8e37dd1
1 Parent(s): c6d1483

Embed.py implementation

Browse files
Files changed (2) hide show
  1. embed.py +26 -25
  2. requirements.txt +1 -0
embed.py CHANGED
@@ -5,33 +5,41 @@ import os
5
  from PIL import Image
6
 
7
  from utils import load_model
 
 
 
 
8
 
9
 
10
  def main(args):
11
  root = args.image_path
12
  files = list(os.listdir(root))
13
- for model_name in ["koclip", "koclip/koclip-large"]:
14
- counter = 0
15
- images = []
16
- image_ids = []
17
  model, processor = load_model(f"koclip/{model_name}")
18
- while counter < len(files):
19
- if counter != 0 and counter % args.batch_size == 0:
20
- inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  features = model(**inputs).image_embeds
22
- with open(os.path.join(args.out_path, f"{model_name}.tsv", "w+")) as f:
23
  writer = csv.writer(f, delimiter="\t")
24
  for image_id, feature in zip(image_ids, features):
25
- writer.writerow([image_id, ",".join(feature)])
26
- images = []
27
- image_ids = []
28
- else:
29
- file_ = files[counter]
30
- image = Image.open(os.path.join(root, file_))
31
- images.append(image)
32
- image_ids.append(file_)
33
- counter += 1
34
-
35
 
36
 
37
  if __name__ == "__main__":
@@ -41,10 +49,3 @@ if __name__ == "__main__":
41
  parser.add_argument("--out_path", default="features")
42
  args = parser.parse_args()
43
  main(args)
44
-
45
-
46
-
47
-
48
-
49
-
50
-
5
  from PIL import Image
6
 
7
  from utils import load_model
8
+ import jax.numpy as jnp
9
+ from jax import jit
10
+
11
+ from tqdm import tqdm
12
 
13
 
14
  def main(args):
15
  root = args.image_path
16
  files = list(os.listdir(root))
17
+ for f in files:
18
+ assert(f[-4:] == ".jpg")
19
+ for model_name in ["koclip", "koclip-large"]:
 
20
  model, processor = load_model(f"koclip/{model_name}")
21
+ with tqdm(total=len(files)) as pbar:
22
+ for counter in range(0, len(files), args.batch_size):
23
+ images = []
24
+ image_ids = []
25
+ for idx in range(counter, min(len(files), counter + args.batch_size)):
26
+ file_ = files[idx]
27
+ image = Image.open(os.path.join(root, file_)).convert('RGB')
28
+ images.append(image)
29
+ image_ids.append(file_)
30
+
31
+ pbar.update(args.batch_size)
32
+ try:
33
+ inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
34
+ except:
35
+ print(image_ids)
36
+ break
37
+ inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
38
  features = model(**inputs).image_embeds
39
+ with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
40
  writer = csv.writer(f, delimiter="\t")
41
  for image_id, feature in zip(image_ids, features):
42
+ writer.writerow([image_id, ",".join(map(lambda x: str(x), feature))])
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  if __name__ == "__main__":
49
  parser.add_argument("--out_path", default="features")
50
  args = parser.parse_args()
51
  main(args)
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -3,3 +3,4 @@ jaxlib
3
  flax
4
  transformers
5
  streamlit
 
3
  flax
4
  transformers
5
  streamlit
6
+ tqdm