jaketae committed on
Commit
2cf3514
•
1 Parent(s): 9928549

style: run linter

Files changed (6)
  1. app.py +0 -1
  2. embed_captions.py +4 -2
  3. embed_images.py +14 -9
  4. image2text.py +4 -2
  5. text2image.py +10 -8
  6. utils.py +12 -7
app.py CHANGED
@@ -3,7 +3,6 @@ import streamlit as st
 import image2text
 import text2image
 
-
 PAGES = {"Text to Image": text2image, "Image to Text": image2text}
 
 st.sidebar.title("Navigation")
embed_captions.py CHANGED
@@ -1,13 +1,15 @@
-import csv
 import argparse
+import csv
+
 from utils import load_model
 
+
 def main(args):
     caption_txt_path = args.text_path
     f = open(caption_txt_path)
     captions = [sent.strip() for sent in f.readlines()]
 
-    for model_name in ["koclip-base", "koclip-large"]:
+    for model_name in ["koclip-base", "koclip-large"]:
         model, processor = load_model(f"koclip/{model_name}")
         captions_processed = [processor(sent,images=None,return_tensors='jax') for sent in captions]
         vec = [np.asarray(model.get_text_features(**c)) for c in captions_processed]
embed_images.py CHANGED
@@ -2,20 +2,19 @@ import argparse
 import csv
 import os
 
-from PIL import Image
-
-from utils import load_model
 import jax.numpy as jnp
 from jax import jit
-
+from PIL import Image
 from tqdm import tqdm
 
+from utils import load_model
+
 
 def main(args):
     root = args.image_path
     files = list(os.listdir(root))
     for f in files:
-        assert(f[-4:] == ".jpg")
+        assert f[-4:] == ".jpg"
     for model_name in ["koclip-base", "koclip-large"]:
         model, processor = load_model(f"koclip/{model_name}")
         with tqdm(total=len(files)) as pbar:
@@ -24,22 +23,28 @@ def main(args):
                 image_ids = []
                 for idx in range(counter, min(len(files), counter + args.batch_size)):
                     file_ = files[idx]
-                    image = Image.open(os.path.join(root, file_)).convert('RGB')
+                    image = Image.open(os.path.join(root, file_)).convert("RGB")
                     images.append(image)
                     image_ids.append(file_)
 
                 pbar.update(args.batch_size)
                 try:
-                    inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
+                    inputs = processor(
+                        text=[""], images=images, return_tensors="jax", padding=True
+                    )
                 except:
                     print(image_ids)
                     break
-                inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
+                inputs["pixel_values"] = jnp.transpose(
+                    inputs["pixel_values"], axes=[0, 2, 3, 1]
+                )
                 features = model(**inputs).image_embeds
                 with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
                     writer = csv.writer(f, delimiter="\t")
                     for image_id, feature in zip(image_ids, features):
-                        writer.writerow([image_id, ",".join(map(lambda x: str(x), feature))])
+                        writer.writerow(
+                            [image_id, ",".join(map(lambda x: str(x), feature))]
+                        )
 
 
 if __name__ == "__main__":
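For reference, each row that embed_images.py appends to its per-model TSV has the form image_id, a tab, then the comma-joined embedding floats; this is the two-column layout that load_index in utils.py later parses back into a feature matrix. A minimal read-back sketch, assuming a features/val2017/koclip-base.tsv file produced by this script (that path is an assumption, not something fixed by the script itself):

import csv

# Hypothetical path; embed_images.py writes "{model_name}.tsv" under args.out_path.
with open("features/val2017/koclip-base.tsv") as f:
    reader = csv.reader(f, delimiter="\t")
    for image_id, feature_str in reader:
        # Second column is a comma-joined list of floats (one embedding per image).
        feature = [float(x) for x in feature_str.split(",")]
        print(image_id, len(feature))
        break  # inspect only the first row

Keeping the embeddings as comma-joined text in a plain two-column TSV means batches can simply be appended ("a+" mode above) and re-read later without any binary format dependencies.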
image2text.py CHANGED
@@ -7,6 +7,8 @@ def app(model_name):
     model, processor = load_model(model_name)
 
     st.title("Image to Text")
-    st.markdown("""
+    st.markdown(
+        """
     Some text goes in here.
-    """)
+    """
+    )
text2image.py CHANGED
@@ -1,21 +1,22 @@
 import os
 
+import matplotlib.pyplot as plt
+import numpy as np
 import streamlit as st
 
-from utils import load_model, load_index
-import numpy as np
-import matplotlib.pyplot as plt
+from utils import load_index, load_model
 
 
 def app(model_name):
-    images_directory = 'images/val2017'
-    features_directory = f'features/val2017/{model_name}.tsv'
+    images_directory = "images/val2017"
+    features_directory = f"features/val2017/{model_name}.tsv"
 
     files, index = load_index(features_directory)
-    model, processor = load_model(f'koclip/{model_name}')
+    model, processor = load_model(f"koclip/{model_name}")
 
     st.title("Text to Image Search Engine")
-    st.markdown("""
+    st.markdown(
+        """
     This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
     5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
    vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
@@ -27,7 +28,8 @@ def app(model_name):
     Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
 
     Example Queries : 컴퓨터하는 고양이(Cat playing on a computer), 길 위에서 달리는 자동차(Car running on the road),
-    """)
+    """
+    )
 
     query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="아파트")
    if st.button("질문 (Query)"):
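The app's markdown above describes the retrieval flow: the Korean query is embedded with the KoCLIP text encoder, then ranked against the precomputed image embeddings by cosine similarity through the nmslib index. A minimal sketch of that query path outside Streamlit, reusing load_model and load_index from utils.py; the processor call mirrors embed_captions.py, and knnQuery with k=10 is an assumption based on the "top 10 images" wording, not the app's actual query code:

import numpy as np

from utils import load_index, load_model

model_name = "koclip-base"  # or "koclip-large"
files, index = load_index(f"features/val2017/{model_name}.tsv")
model, processor = load_model(f"koclip/{model_name}")

# Embed the Korean text query the same way embed_captions.py embeds captions.
query = "컴퓨터하는 고양이"  # "Cat playing on a computer"
inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
query_vec = np.asarray(model.get_text_features(**inputs))[0]

# The index is built with space="cosinesimil" in utils.load_index,
# so smaller returned distances correspond to higher cosine similarity.
ids, distances = index.knnQuery(query_vec, k=10)
print([files[i] for i in ids])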
utils.py CHANGED
@@ -1,26 +1,28 @@
 import nmslib
-import streamlit as st
-from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
 import numpy as np
+import streamlit as st
+from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
 
 from koclip import FlaxHybridCLIP
 
+
 @st.cache(allow_output_mutation=True)
 def load_index(img_file):
     filenames, embeddings = [], []
     lines = open(img_file, "r")
     for line in lines:
-        cols = line.strip().split('\t')
+        cols = line.strip().split("\t")
         filename = cols[0]
-        embedding = np.array([float(x) for x in cols[1].split(',')])
+        embedding = [float(x) for x in cols[1].split(",")]
         filenames.append(filename)
         embeddings.append(embedding)
     embeddings = np.array(embeddings)
-    index = nmslib.init(method='hnsw', space='cosinesimil')
+    index = nmslib.init(method="hnsw", space="cosinesimil")
     index.addDataPointBatch(embeddings)
-    index.createIndex({'post': 2}, print_progress=True)
+    index.createIndex({"post": 2}, print_progress=True)
     return filenames, index
 
+
 @st.cache(allow_output_mutation=True)
 def load_model(model_name="koclip/koclip-base"):
     assert model_name in {"koclip/koclip-base", "koclip/koclip-large"}
@@ -28,9 +30,12 @@ def load_model(model_name="koclip/koclip-base"):
     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
     processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
     if model_name == "koclip/koclip-large":
-        processor.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-large-patch16-224")
+        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
+            "google/vit-large-patch16-224"
+        )
     return model, processor
 
+
 @st.cache(allow_output_mutation=True)
 def load_model_v2(model_name="koclip/koclip"):
     model = FlaxHybridCLIP.from_pretrained(model_name)