jaketae commited on
Commit
7326e2c
1 Parent(s): 8ff0261

feature: show prob scores as bar chart

Browse files
Files changed (2) hide show
  1. image2text.py +18 -13
  2. requirements.txt +3 -1
image2text.py CHANGED
@@ -1,8 +1,10 @@
1
  import streamlit as st
 
2
  import numpy as np
3
  import jax
4
  import jax.numpy as jnp
5
  from PIL import Image
 
6
 
7
  from utils import load_model
8
 
@@ -17,20 +19,21 @@ def app(model_name):
17
  """
18
  )
19
 
20
- query = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
21
- captions = st.text_input("사용하실 캡션을 쉼표 단위로 구분해서 적어주세요", value="고양이,강아지,느티나무...")
 
 
 
 
 
22
 
23
  if st.button("질문 (Query)"):
24
- if query is None:
25
- st.error("Please upload an image query.")
26
  else:
27
- image = Image.open(query)
 
28
  st.image(image)
29
- # pixel_values = processor(
30
- # text=[""], images=image, return_tensors="jax", padding=True
31
- # ).pixel_values
32
- # pixel_values = jnp.transpose(pixel_values, axes=[0, 2, 3, 1])
33
- # vec = np.asarray(model.get_image_features(pixel_values))
34
  captions = captions.split(",")
35
  inputs = processor(text=captions, images=image, return_tensors="jax", padding=True)
36
  inputs["pixel_values"] = jnp.transpose(
@@ -38,8 +41,10 @@ def app(model_name):
38
  )
39
  outputs = model(**inputs)
40
  probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
41
-
42
- for idx, prob in sorted(enumerate(*probs), key=lambda x: x[1], reverse=True):
43
- st.text(f"Score: `{prob}`, {captions[idx]}")
 
 
44
 
45
 
1
  import streamlit as st
2
+ import requests
3
  import numpy as np
4
  import jax
5
  import jax.numpy as jnp
6
  from PIL import Image
7
+ import pandas as pd
8
 
9
  from utils import load_model
10
 
19
  """
20
  )
21
 
22
+ query1 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
23
+ query2 = st.text_input("or a URL to an image...")
24
+
25
+ captions = st.text_input(
26
+ "Enter candidate captions in comma-separated form.",
27
+ value="귀여운 고양이,멋있는 강아지,트랜스포머"
28
+ )
29
 
30
  if st.button("질문 (Query)"):
31
+ if not any([query1, query2]):
32
+ st.error("Please upload an image or paste an image URL.")
33
  else:
34
+ image_data = query1 if query1 is not None else requests.get(query2, stream=True).raw
35
+ image = Image.open(image_data)
36
  st.image(image)
 
 
 
 
 
37
  captions = captions.split(",")
38
  inputs = processor(text=captions, images=image, return_tensors="jax", padding=True)
39
  inputs["pixel_values"] = jnp.transpose(
41
  )
42
  outputs = model(**inputs)
43
  probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
44
+ score_dict = {captions[idx]: prob for idx, prob in enumerate(*probs)}
45
+ df = pd.DataFrame(score_dict.values(), index=score_dict.keys())
46
+ st.bar_chart(df)
47
+ # for idx, prob in sorted(enumerate(*probs), key=lambda x: x[1], reverse=True):
48
+ # st.text(f"Score: `{prob}`, {captions[idx]}")
49
 
50
 
requirements.txt CHANGED
@@ -5,4 +5,6 @@ transformers
5
  streamlit
6
  tqdm
7
  nmslib
8
- matplotlib
 
 
5
  streamlit
6
  tqdm
7
  nmslib
8
+ matplotlib
9
+ pandas
10
+ requests