vinid commited on
Commit
d05a223
β€’
1 Parent(s): 997927f

fixing image2text

Browse files
Files changed (2) hide show
  1. image2text.py +18 -8
  2. requirements.txt +2 -1
image2text.py CHANGED
@@ -4,7 +4,8 @@ from utils import text_encoder, image_encoder
4
  from PIL import Image
5
  from jax import numpy as jnp
6
  import pandas as pd
7
-
 
8
 
9
  def app():
10
  st.title("From Image to Text")
@@ -17,23 +18,31 @@ def app():
17
  image classification task!
18
 
19
  🀌 Italian mode on! 🀌
 
 
 
20
 
21
  """
22
  )
23
 
24
- filename = st.file_uploader(
25
- "Choose an image from your computer", type=["jpg", "jpeg", "png"]
 
 
 
 
26
  )
27
 
 
28
  MAX_CAP = 4
29
 
30
  col1, col2 = st.beta_columns([3, 1])
31
 
32
  with col2:
33
  captions_count = st.selectbox(
34
- "Number of labels", options=range(1, MAX_CAP + 1)
35
  )
36
- compute = st.button("Compute")
37
 
38
  with col1:
39
  captions = list()
@@ -43,7 +52,7 @@ def app():
43
  if compute:
44
  captions = [c for c in captions if c != ""]
45
 
46
- if not captions or not filename:
47
  st.error("Please choose one image and at least one label")
48
  else:
49
  with st.spinner("Computing..."):
@@ -55,13 +64,14 @@ def app():
55
  text_embeds.extend(text_encoder(c, model, tokenizer))
56
 
57
  text_embeds = jnp.array(text_embeds)
 
58
 
59
- image = Image.open(filename).convert("RGB")
60
  transform = get_image_transform(model.config.vision_config.image_size)
61
  image_embed = image_encoder(transform(image), model)
62
 
63
  # we could have a softmax here
64
- cos_similarities = jnp.matmul(image_embed, text_embeds.T)
65
 
66
  chart_data = pd.Series(cos_similarities[0], index=captions)
67
 
 
4
  from PIL import Image
5
  from jax import numpy as jnp
6
  import pandas as pd
7
+ import requests
8
+ import jax
9
 
10
  def app():
11
  st.title("From Image to Text")
 
18
  image classification task!
19
 
20
  🀌 Italian mode on! 🀌
21
+
22
+ For example, try to write "cat" in the space for label1 and "dog" in the space for label2 and the run
23
+ "classify"!
24
 
25
  """
26
  )
27
 
28
+ image_url = st.text_input(
29
+
30
+ "You can input the URL of an image",
31
+
32
+ value="https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/Domestic_Cat_Face_Shot.jpg/1280px-Domestic_Cat_Face_Shot.jpg",
33
+
34
  )
35
 
36
+
37
  MAX_CAP = 4
38
 
39
  col1, col2 = st.beta_columns([3, 1])
40
 
41
  with col2:
42
  captions_count = st.selectbox(
43
+ "Number of labels", options=range(1, MAX_CAP + 1), index=1
44
  )
45
+ compute = st.button("Classify")
46
 
47
  with col1:
48
  captions = list()
 
52
  if compute:
53
  captions = [c for c in captions if c != ""]
54
 
55
+ if not captions or not image_url:
56
  st.error("Please choose one image and at least one label")
57
  else:
58
  with st.spinner("Computing..."):
 
64
  text_embeds.extend(text_encoder(c, model, tokenizer))
65
 
66
  text_embeds = jnp.array(text_embeds)
67
+ image_raw = requests.get(image_url, stream=True).raw
68
 
69
+ image = Image.open(image_raw).convert("RGB")
70
  transform = get_image_transform(model.config.vision_config.image_size)
71
  image_embed = image_encoder(transform(image), model)
72
 
73
  # we could have a softmax here
74
+ cos_similarities = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))
75
 
76
  chart_data = pd.Series(cos_similarities[0], index=captions)
77
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ torch
5
  torchvision
6
  natsort
7
  stqdm
8
- pandas
 
 
5
  torchvision
6
  natsort
7
  stqdm
8
+ pandas
9
+ requests