ampehta commited on
Commit
503acf7
โ€ข
1 Parent(s): 98b26e2

Revert "Merge branch 'main' of https://huggingface.co/spaces/flax-community/koclip into main"

Browse files

This reverts commit 98b26e2c6b300d6deec2c1f5a119ad6089b11224, reversing
changes made to 699df87c9ee261cf6dfc69f6a6276d3e99bfbc3e.

Files changed (5) hide show
  1. app.py +1 -0
  2. embed.py +11 -16
  3. image2text.py +2 -4
  4. text2image.py +8 -11
  5. utils.py +13 -11
app.py CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
3
  import image2text
4
  import text2image
5
 
 
6
  PAGES = {"Text to Image": text2image, "Image to Text": image2text}
7
 
8
  st.sidebar.title("Navigation")
3
  import image2text
4
  import text2image
5
 
6
+
7
  PAGES = {"Text to Image": text2image, "Image to Text": image2text}
8
 
9
  st.sidebar.title("Navigation")
embed.py CHANGED
@@ -2,20 +2,21 @@ import argparse
2
  import csv
3
  import os
4
 
5
- import jax.numpy as jnp
6
  from PIL import Image
7
- from tqdm import tqdm
8
 
9
  from utils import load_model
 
 
 
 
10
 
11
 
12
  def main(args):
13
  root = args.image_path
14
  files = list(os.listdir(root))
15
  for f in files:
16
- assert f[-4:] == ".jpg"
17
  for model_name in ["koclip-base", "koclip-large"]:
18
- # for model_name in ["koclip-large"]:
19
  model, processor = load_model(f"koclip/{model_name}")
20
  with tqdm(total=len(files)) as pbar:
21
  for counter in range(0, len(files), args.batch_size):
@@ -23,34 +24,28 @@ def main(args):
23
  image_ids = []
24
  for idx in range(counter, min(len(files), counter + args.batch_size)):
25
  file_ = files[idx]
26
- image = Image.open(os.path.join(root, file_)).convert("RGB")
27
  images.append(image)
28
  image_ids.append(file_)
29
 
30
  pbar.update(args.batch_size)
31
  try:
32
- inputs = processor(
33
- text=[""], images=images, return_tensors="jax", padding=True
34
- )
35
  except:
36
  print(image_ids)
37
  break
38
- inputs["pixel_values"] = jnp.transpose(
39
- inputs["pixel_values"], axes=[0, 2, 3, 1]
40
- )
41
  features = model(**inputs).image_embeds
42
  with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
43
  writer = csv.writer(f, delimiter="\t")
44
  for image_id, feature in zip(image_ids, features):
45
- writer.writerow(
46
- [image_id, ",".join(map(lambda x: str(x), feature))]
47
- )
48
 
49
 
50
  if __name__ == "__main__":
51
  parser = argparse.ArgumentParser()
52
  parser.add_argument("--batch_size", default=16)
53
- parser.add_argument("--image_path", default="images/val2017")
54
- parser.add_argument("--out_path", default="features/val2017")
55
  args = parser.parse_args()
56
  main(args)
2
  import csv
3
  import os
4
 
 
5
  from PIL import Image
 
6
 
7
  from utils import load_model
8
+ import jax.numpy as jnp
9
+ from jax import jit
10
+
11
+ from tqdm import tqdm
12
 
13
 
14
  def main(args):
15
  root = args.image_path
16
  files = list(os.listdir(root))
17
  for f in files:
18
+ assert(f[-4:] == ".jpg")
19
  for model_name in ["koclip-base", "koclip-large"]:
 
20
  model, processor = load_model(f"koclip/{model_name}")
21
  with tqdm(total=len(files)) as pbar:
22
  for counter in range(0, len(files), args.batch_size):
24
  image_ids = []
25
  for idx in range(counter, min(len(files), counter + args.batch_size)):
26
  file_ = files[idx]
27
+ image = Image.open(os.path.join(root, file_)).convert('RGB')
28
  images.append(image)
29
  image_ids.append(file_)
30
 
31
  pbar.update(args.batch_size)
32
  try:
33
+ inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
 
 
34
  except:
35
  print(image_ids)
36
  break
37
+ inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
 
 
38
  features = model(**inputs).image_embeds
39
  with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
40
  writer = csv.writer(f, delimiter="\t")
41
  for image_id, feature in zip(image_ids, features):
42
+ writer.writerow([image_id, ",".join(map(lambda x: str(x), feature))])
 
 
43
 
44
 
45
  if __name__ == "__main__":
46
  parser = argparse.ArgumentParser()
47
  parser.add_argument("--batch_size", default=16)
48
+ parser.add_argument("--image_path", default="images")
49
+ parser.add_argument("--out_path", default="features")
50
  args = parser.parse_args()
51
  main(args)
image2text.py CHANGED
@@ -7,8 +7,6 @@ def app(model_name):
7
  model, processor = load_model(model_name)
8
 
9
  st.title("Image to Text")
10
- st.markdown(
11
- """
12
  Some text goes in here.
13
- """
14
- )
7
  model, processor = load_model(model_name)
8
 
9
  st.title("Image to Text")
10
+ st.markdown("""
 
11
  Some text goes in here.
12
+ """)
 
text2image.py CHANGED
@@ -1,22 +1,21 @@
1
  import os
2
 
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
  import streamlit as st
6
 
7
- from utils import load_index, load_model
 
 
8
 
9
 
10
  def app(model_name):
11
- images_directory = "images/val2017"
12
- features_directory = f"features/val2017/{model_name}.tsv"
13
 
14
  files, index = load_index(features_directory)
15
- model, processor = load_model(f"koclip/{model_name}")
16
 
17
  st.title("Text to Image Search Engine")
18
- st.markdown(
19
- """
20
  This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
21
  5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
22
  vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
@@ -28,11 +27,9 @@ def app(model_name):
28
  Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
29
 
30
  Example Queries : ์•„ํŒŒํŠธ(Apartment), ์ž๋™์ฐจ(Car), ์ปดํ“จํ„ฐ(Computer)
31
- """
32
- )
33
 
34
  query = st.text_input("ํ•œ๊ธ€ ์งˆ๋ฌธ์„ ์ ์–ด์ฃผ์„ธ์š” (Korean Text Query) :", value="์•„ํŒŒํŠธ")
35
-
36
  if st.button("์งˆ๋ฌธ (Query)"):
37
  proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
38
  vec = np.asarray(model.get_text_features(**proc))
1
  import os
2
 
 
 
3
  import streamlit as st
4
 
5
+ from utils import load_model, load_index
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
 
10
  def app(model_name):
11
+ images_directory = 'images/val2017'
12
+ features_directory = f'features/val2017/{model_name}.tsv'
13
 
14
  files, index = load_index(features_directory)
15
+ model, processor = load_model(f'koclip/{model_name}')
16
 
17
  st.title("Text to Image Search Engine")
18
+ st.markdown("""
 
19
  This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
20
  5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
21
  vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
27
  Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
28
 
29
  Example Queries : ์•„ํŒŒํŠธ(Apartment), ์ž๋™์ฐจ(Car), ์ปดํ“จํ„ฐ(Computer)
30
+ """)
 
31
 
32
  query = st.text_input("ํ•œ๊ธ€ ์งˆ๋ฌธ์„ ์ ์–ด์ฃผ์„ธ์š” (Korean Text Query) :", value="์•„ํŒŒํŠธ")
 
33
  if st.button("์งˆ๋ฌธ (Query)"):
34
  proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
35
  vec = np.asarray(model.get_text_features(**proc))
utils.py CHANGED
@@ -1,28 +1,26 @@
1
  import nmslib
2
- import numpy as np
3
  import streamlit as st
4
- from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
 
5
 
6
  from koclip import FlaxHybridCLIP
7
 
8
-
9
  @st.cache(allow_output_mutation=True)
10
  def load_index(img_file):
11
  filenames, embeddings = [], []
12
  lines = open(img_file, "r")
13
  for line in lines:
14
- cols = line.strip().split("\t")
15
  filename = cols[0]
16
- embedding = [float(x) for x in cols[1].split(",")]
17
  filenames.append(filename)
18
  embeddings.append(embedding)
19
  embeddings = np.array(embeddings)
20
- index = nmslib.init(method="hnsw", space="cosinesimil")
21
  index.addDataPointBatch(embeddings)
22
- index.createIndex({"post": 2}, print_progress=True)
23
  return filenames, index
24
 
25
-
26
  @st.cache(allow_output_mutation=True)
27
  def load_model(model_name="koclip/koclip-base"):
28
  assert model_name in {"koclip/koclip-base", "koclip/koclip-large"}
@@ -30,7 +28,11 @@ def load_model(model_name="koclip/koclip-base"):
30
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
31
  processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
32
  if model_name == "koclip/koclip-large":
33
- processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
34
- "google/vit-large-patch16-224"
35
- )
 
 
 
 
36
  return model, processor
1
  import nmslib
 
2
  import streamlit as st
3
+ from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
4
+ import numpy as np
5
 
6
  from koclip import FlaxHybridCLIP
7
 
 
8
  @st.cache(allow_output_mutation=True)
9
  def load_index(img_file):
10
  filenames, embeddings = [], []
11
  lines = open(img_file, "r")
12
  for line in lines:
13
+ cols = line.strip().split('\t')
14
  filename = cols[0]
15
+ embedding = np.array([float(x) for x in cols[1].split(',')])
16
  filenames.append(filename)
17
  embeddings.append(embedding)
18
  embeddings = np.array(embeddings)
19
+ index = nmslib.init(method='hnsw', space='cosinesimil')
20
  index.addDataPointBatch(embeddings)
21
+ index.createIndex({'post': 2}, print_progress=True)
22
  return filenames, index
23
 
 
24
  @st.cache(allow_output_mutation=True)
25
  def load_model(model_name="koclip/koclip-base"):
26
  assert model_name in {"koclip/koclip-base", "koclip/koclip-large"}
28
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
29
  processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
30
  if model_name == "koclip/koclip-large":
31
+ processor.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-large-patch16-224")
32
+ return model, processor
33
+
34
+ @st.cache(allow_output_mutation=True)
35
+ def load_model_v2(model_name="koclip/koclip"):
36
+ model = FlaxHybridCLIP.from_pretrained(model_name)
37
+ processor = CLIPProcessor.from_pretrained(model_name)
38
  return model, processor