jaketae committed on
Commit
83d94a8
•
1 Parent(s): 2e45025

feature: replace comma separated input w/ counter ui

Files changed (1)
  1. image2text.py +41 -26
image2text.py CHANGED
@@ -15,7 +15,7 @@ def app(model_name):
     st.title("Zero-shot Image Classification")
     st.markdown(
         """
-    This demonstration explores capability of KoCLIP in the field of Zero-Shot Prediction. This demo takes a set of image and captions from, and predicts the most likely label among the different captions given.
+    This demonstration explores the capability of KoCLIP in the field of zero-shot prediction. This demo takes an image and a set of captions from the user, and predicts the most likely label among the given captions.
 
     KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder. Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
         """
@@ -27,32 +27,47 @@ def app(model_name):
     )
     query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
 
-    captions = st.text_input(
-        "Enter candidate captions in comma-separated form.",
-        value="귀여운 고양이,멋있는 강아지,포동포동한 햄스터",
-    )
+    col1, col2 = st.beta_columns([3, 1])
+
+    with col2:
+        captions_count = st.selectbox(
+            "Number of labels", options=range(1, 6), index=2
+        )
+        compute = st.button("Classify")
 
-    if st.button("질문 (Query)"):
+    with col1:
+        captions = []
+        defaults = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]
+        for idx in range(captions_count):
+            value = defaults[idx] if idx < len(defaults) else ""
+            captions.append(st.text_input(f"Insert label {idx+1}", value=value))
+
+    if compute:
         if not any([query1, query2]):
             st.error("Please upload an image or paste an image URL.")
         else:
-            image_data = (
-                query2 if query2 is not None else requests.get(query1, stream=True).raw
-            )
-            image = Image.open(image_data)
-            st.image(image)
-            # captions = [caption.strip() for caption in captions.split(",")]
-            captions = [f"이것은 {caption.strip()}이다." for caption in captions.split(",")]
-            inputs = processor(
-                text=captions, images=image, return_tensors="jax", padding=True
-            )
-            inputs["pixel_values"] = jnp.transpose(
-                inputs["pixel_values"], axes=[0, 2, 3, 1]
-            )
-            outputs = model(**inputs)
-            probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
-            score_dict = {captions[idx]: prob for idx, prob in enumerate(*probs)}
-            df = pd.DataFrame(score_dict.values(), index=score_dict.keys())
-            st.bar_chart(df)
-            # for idx, prob in sorted(enumerate(*probs), key=lambda x: x[1], reverse=True):
-            #     st.text(f"Score: `{prob}`, {captions[idx]}")
+            st.markdown("""---""")
+            with st.spinner("Computing..."):
+                image_data = (
+                    query2 if query2 is not None else requests.get(query1, stream=True).raw
+                )
+                image = Image.open(image_data)
+
+                # captions = [caption.strip() for caption in captions.split(",")]
+                captions = [f"이것은 {caption.strip()}이다." for caption in captions]
+                inputs = processor(
+                    text=captions, images=image, return_tensors="jax", padding=True
+                )
+                inputs["pixel_values"] = jnp.transpose(
+                    inputs["pixel_values"], axes=[0, 2, 3, 1]
+                )
+                outputs = model(**inputs)
+                probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
+                chart_data = pd.Series(probs[0], index=captions)
+
+                col1, col2 = st.beta_columns(2)
+                with col1:
+                    st.image(image)
+                with col2:
+                    st.bar_chart(chart_data)
+
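For reference, below is a minimal standalone sketch of the counter-style input this commit introduces: a selectbox fixes how many `st.text_input` fields are rendered, replacing the single comma-separated field. The `MAX_LABELS` constant and widget labels here are illustrative, not part of the diff, and newer Streamlit releases expose `st.columns` in place of the `st.beta_columns` API used above.

```python
# Minimal sketch of the counter-based label input (illustrative names, not the app's exact code).
import streamlit as st

MAX_LABELS = 5  # assumed upper bound, mirroring options=range(1, 6) in the diff
DEFAULTS = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]  # "cute cat", "cool dog", "chubby hamster"

# The selectbox acts as the counter: it decides how many text inputs get rendered.
num_labels = st.selectbox("Number of labels", options=range(1, MAX_LABELS + 1), index=2)

# Render one text input per label, pre-filled with a default value where available.
labels = []
for idx in range(num_labels):
    value = DEFAULTS[idx] if idx < len(DEFAULTS) else ""
    labels.append(st.text_input(f"Insert label {idx + 1}", value=value))

if st.button("Classify"):
    st.write([label.strip() for label in labels if label.strip()])
```

Separately, the `jnp.transpose(..., axes=[0, 2, 3, 1])` call carried over from the old code reorders the processor's channels-first `pixel_values` into a channels-last layout, presumably because the Flax image encoder in this hybrid setup expects channels-last input. A quick sketch of the shape change, with the 3-channel 224×224 shape assumed for illustration:

```python
import jax.numpy as jnp

x = jnp.zeros((1, 3, 224, 224))               # processor output: (batch, channels, height, width)
x_nhwc = jnp.transpose(x, axes=[0, 2, 3, 1])  # (batch, height, width, channels), i.e. (1, 224, 224, 3)
```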