jaketae committed on
Commit
6525b03
1 Parent(s): 84c806e

fix: discard remaining pixels

Browse files
Files changed (1) hide show
  1. most_relevant_part.py +27 -25
most_relevant_part.py CHANGED
@@ -1,22 +1,23 @@
1
  import os
2
- import requests
3
- import streamlit as st
4
- from PIL import Image
5
 
6
  import jax
7
  import jax.numpy as jnp
8
  import numpy as np
 
 
 
9
 
10
  from utils import load_model
11
 
12
- def split_image(im):
 
13
  im = np.array(im)
14
- M = im.shape[0] // 3
15
- N = im.shape[1] // 3
16
  tiles = [
17
- im[x:x + M, y:y + N]
18
- for x in range(0, im.shape[0], M)
19
- for y in range(0, im.shape[1], N)
20
  ]
21
  return tiles
22
 
@@ -36,19 +37,21 @@ def app(model_name):
36
  model, processor = load_model(f"koclip/{model_name}")
37
 
38
  st.title("Most Relevant Part of Image")
39
- st.markdown("""
 
40
  Given a piece of text, the CLIP model finds the part of an image that best explains the text.
41
  To try it out, you can
42
  1) Upload an image
43
  2) Explain a part of the image in text
44
  Which will yield the most relevant image tile from a 3x3 grid of the image
45
- """)
 
46
 
47
  query1 = st.text_input(
48
  "Enter a URL to an image...",
49
- value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg")
50
- query2 = st.file_uploader("or upload an image...",
51
- type=["jpg", "jpeg", "png"])
52
  captions = st.text_input(
53
  "Enter query to find most relevant part of image ",
54
  value="이건 서울의 경복궁 사진이다.",
@@ -58,23 +61,22 @@ def app(model_name):
58
  if not any([query1, query2]):
59
  st.error("Please upload an image or paste an image URL.")
60
  else:
61
- image_data = (query2 if query2 is not None else requests.get(
62
- query1, stream=True).raw)
 
63
  image = Image.open(image_data)
64
  st.image(image)
65
 
66
  images = split_image(image)
67
 
68
- inputs = processor(text=captions,
69
- images=images,
70
- return_tensors="jax",
71
- padding=True)
72
- inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"],
73
- axes=[0, 2, 3, 1])
74
  outputs = model(**inputs)
75
  probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
76
- for idx, prob in sorted(enumerate(probs),
77
- key=lambda x: x[1],
78
- reverse=True):
79
  st.text(f"Score: {prob[0]:.3f}")
80
  st.image(images[idx])
1
  import os
 
 
 
2
 
3
  import jax
4
  import jax.numpy as jnp
5
  import numpy as np
6
+ import requests
7
+ import streamlit as st
8
+ from PIL import Image
9
 
10
  from utils import load_model
11
 
12
+
13
def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a ``num_rows`` x ``num_cols`` grid of equal tiles.

    The tile height and width are the floor of the image height/width divided
    by the grid size. Pixels left over on the bottom and right edges when the
    image does not divide evenly are discarded, so exactly
    ``num_rows * num_cols`` tiles of identical shape are returned.

    Args:
        im: Anything ``np.array`` can convert to an array with at least two
            dimensions (rows first, columns second), e.g. a PIL image.
        num_rows: Number of tiles along the first (vertical) axis.
        num_cols: Number of tiles along the second (horizontal) axis.

    Returns:
        A list of array slices in row-major order (left to right, then top to
        bottom).

    Raises:
        ValueError: If the grid dimensions are not positive, or if the image
            is smaller than the requested grid along either axis.
    """
    if num_rows < 1 or num_cols < 1:
        raise ValueError("num_rows and num_cols must be positive integers")

    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    if row_size == 0 or col_size == 0:
        # A zero-sized tile would otherwise surface as a cryptic zero-step
        # range() error (or empty tiles); fail with an explicit message.
        raise ValueError(
            f"image of shape {im.shape[:2]} is too small for a "
            f"{num_rows}x{num_cols} grid"
        )

    # Slice with row_size/col_size (the previous body referenced the stale
    # names M and N, which no longer exist and raised NameError).
    tiles = [
        im[
            row * row_size : (row + 1) * row_size,
            col * col_size : (col + 1) * col_size,
        ]
        for row in range(num_rows)
        for col in range(num_cols)
    ]
    return tiles
23
 
37
  model, processor = load_model(f"koclip/{model_name}")
38
 
39
  st.title("Most Relevant Part of Image")
40
+ st.markdown(
41
+ """
42
  Given a piece of text, the CLIP model finds the part of an image that best explains the text.
43
  To try it out, you can
44
  1) Upload an image
45
  2) Explain a part of the image in text
46
  Which will yield the most relevant image tile from a 3x3 grid of the image
47
+ """
48
+ )
49
 
50
  query1 = st.text_input(
51
  "Enter a URL to an image...",
52
+ value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
53
+ )
54
+ query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
55
  captions = st.text_input(
56
  "Enter query to find most relevant part of image ",
57
  value="이건 서울의 경복궁 사진이다.",
61
  if not any([query1, query2]):
62
  st.error("Please upload an image or paste an image URL.")
63
  else:
64
+ image_data = (
65
+ query2 if query2 is not None else requests.get(query1, stream=True).raw
66
+ )
67
  image = Image.open(image_data)
68
  st.image(image)
69
 
70
  images = split_image(image)
71
 
72
+ inputs = processor(
73
+ text=captions, images=images, return_tensors="jax", padding=True
74
+ )
75
+ inputs["pixel_values"] = jnp.transpose(
76
+ inputs["pixel_values"], axes=[0, 2, 3, 1]
77
+ )
78
  outputs = model(**inputs)
79
  probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
80
+ for idx, prob in sorted(enumerate(probs), key=lambda x: x[1], reverse=True):
 
 
81
  st.text(f"Score: {prob[0]:.3f}")
82
  st.image(images[idx])