4rtemi5 committed on
Commit
e8aa0cd
1 Parent(s): 8515d3b

Push localization with load management to master

Browse files
Files changed (1) hide show
  1. localization.py +20 -24
localization.py CHANGED
@@ -13,14 +13,10 @@ import jax
13
  import gc
14
 
15
 
16
- preprocess = transforms.Compose(
17
- [
18
- transforms.ToTensor(),
19
- transforms.Normalize(
20
- (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
21
- ),
22
- ]
23
- )
24
 
25
 
26
  def pad_to_square(image, size=224):
@@ -54,19 +50,19 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
54
  masks.append(mask)
55
 
56
  for i in range(0, n_pixels):
57
- for j in range(i + 1, n_pixels):
58
  m = mask.copy()
59
- m[: min(i * pixel_size, image_size) + 1, :] = 0
60
- m[min(j * pixel_size, image_size) + 1 :, :] = 0
61
  neg_m = 1 - m
62
  image_batch.append(image * m + gray * neg_m)
63
  masks.append(m)
64
 
65
- for i in range(0, n_pixels + 1):
66
- for j in range(i + 1, n_pixels + 1):
67
  m = mask.copy()
68
- m[:, : min(i * pixel_size + 1, image_size)] = 0
69
- m[:, min(j * pixel_size + 1, image_size) :] = 0
70
  neg_m = 1 - m
71
  image_batch.append(image * m + gray * neg_m)
72
  masks.append(m)
@@ -79,9 +75,7 @@ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
79
  model = get_model()
80
  image_size = model.config.vision_config.image_size
81
  text_embedding = text_encoder(text, model, tokenizer)
82
- images, masks = gen_image_batch(
83
- image_url, image_size=image_size, pixel_size=pixel_size
84
- )
85
 
86
  input_image = images[0].copy()
87
  images = np.stack([preprocess(image) for image in images], axis=0)
@@ -118,8 +112,6 @@ def app():
118
 
119
  For example, try typing "gatto" (cat) or "cane" (dog) in the space for label and click "locate"!
120
 
121
- *Depending on the server load, the computation time may vary. With normal load and pixel size 10, it can take up to two minutes.
122
- *
123
  """
124
  )
125
 
@@ -133,9 +125,13 @@ def app():
133
  col1, col2 = st.beta_columns([3, 1])
134
 
135
  with col2:
136
- pixel_size = st.selectbox("Pixel Size", options=range(10, 21, 5), index=1)
 
 
137
 
138
- iterations = st.selectbox("Refinement Steps", options=range(3, 30, 3), index=0)
 
 
139
 
140
  compute = st.button("LOCATE")
141
 
@@ -152,7 +148,7 @@ def app():
152
 
153
 
154
  if not caption or not image_url:
155
- st.error("Please specify an image URL and a label")
156
  else:
157
  with st.spinner("Computing..."):
158
  heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)
@@ -164,7 +160,7 @@ def app():
164
  gc.collect()
165
 
166
  elif image_url:
167
- image_raw = requests.get(image_url, stream=True,).raw
168
  image = Image.open(image_raw).convert("RGB")
169
  with col1:
170
  st.image(image)
 
13
  import gc
14
 
15
 
16
+ preprocess = transforms.Compose([
17
+ transforms.ToTensor(),
18
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
19
+ ])
 
 
 
 
20
 
21
 
22
  def pad_to_square(image, size=224):
 
50
  masks.append(mask)
51
 
52
  for i in range(0, n_pixels):
53
+ for j in range(i+1, n_pixels):
54
  m = mask.copy()
55
+ m[:min(i*pixel_size, image_size) + 1, :] = 0
56
+ m[min(j*pixel_size, image_size) + 1:, :] = 0
57
  neg_m = 1 - m
58
  image_batch.append(image * m + gray * neg_m)
59
  masks.append(m)
60
 
61
+ for i in range(0, n_pixels+1):
62
+ for j in range(i+1, n_pixels+1):
63
  m = mask.copy()
64
+ m[:, :min(i*pixel_size + 1, image_size)] = 0
65
+ m[:, min(j*pixel_size + 1, image_size):] = 0
66
  neg_m = 1 - m
67
  image_batch.append(image * m + gray * neg_m)
68
  masks.append(m)
 
75
  model = get_model()
76
  image_size = model.config.vision_config.image_size
77
  text_embedding = text_encoder(text, model, tokenizer)
78
+ images, masks = gen_image_batch(image_url, image_size=image_size, pixel_size=pixel_size)
 
 
79
 
80
  input_image = images[0].copy()
81
  images = np.stack([preprocess(image) for image in images], axis=0)
 
112
 
113
  For example, try typing "gatto" (cat) or "cane" (dog) in the space for label and click "locate"!
114
 
 
 
115
  """
116
  )
117
 
 
125
  col1, col2 = st.beta_columns([3, 1])
126
 
127
  with col2:
128
+ pixel_size = st.selectbox(
129
+ "Pixel Size", options=range(10, 21, 5), index=0
130
+ )
131
 
132
+ iterations = st.selectbox(
133
+ "Refinement Steps", options=range(3, 30, 3), index=0
134
+ )
135
 
136
  compute = st.button("LOCATE")
137
 
 
148
 
149
 
150
  if not caption or not image_url:
151
+ st.error("Please choose one image and at least one label")
152
  else:
153
  with st.spinner("Computing..."):
154
  heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)
 
160
  gc.collect()
161
 
162
  elif image_url:
163
+ image_raw = requests.get(image_url, stream=True, ).raw
164
  image = Image.open(image_raw).convert("RGB")
165
  with col1:
166
  st.image(image)