4rtemi5 commited on
Commit
6004a14
1 Parent(s): 20c970c

update localization

Browse files
Files changed (1) hide show
  1. localization.py +8 -8
localization.py CHANGED
@@ -127,22 +127,22 @@ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
127
  # if full_sim is None:
128
  # full_sim = sim
129
  # sim = jax.nn.relu(sim - full_sim)
130
- emb = jnp.expand_dims(e, axis=0) * n
131
 
132
  if v:
133
  vm = jnp.any(m, axis=0)
134
- vertical_scores = vertical_scores + (emb * vm) #/ jnp.mean(vm)
135
- vertical_masks = vertical_masks + vm #/ jnp.mean(vm)
136
  if h:
137
  hm = jnp.any(m, axis=1)
138
- horizontal_scores = horizontal_scores + (emb * hm) #/ jnp.mean(hm)
139
- horizontal_masks = horizontal_masks + hm #/ jnp.mean(hm)
140
 
141
 
142
  embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
143
  embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
144
  full_embs = jnp.minimum(embs_1, embs_2)
145
- mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
146
  full_embs = (full_embs / mask_sum)
147
 
148
  orig_shape = full_embs.shape
@@ -196,9 +196,9 @@ def app():
196
  col1, col2 = st.columns([0.75, 0.25])
197
 
198
  with col2:
199
- pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
200
 
201
- iterations = st.selectbox("Refinement Steps", options=range(1, 6, 1), index=0)
202
 
203
  compute = st.button("LOCATE")
204
 
127
  # if full_sim is None:
128
  # full_sim = sim
129
  # sim = jax.nn.relu(sim - full_sim)
130
+ emb = jnp.expand_dims(e, axis=0) #* n
131
 
132
  if v:
133
  vm = jnp.any(m, axis=0)
134
+ vertical_scores = vertical_scores + (emb * vm) / jnp.mean(vm)
135
+ vertical_masks = vertical_masks + vm / jnp.mean(vm)
136
  if h:
137
  hm = jnp.any(m, axis=1)
138
+ horizontal_scores = horizontal_scores + (emb * hm) / jnp.mean(hm)
139
+ horizontal_masks = horizontal_masks + hm / jnp.mean(hm)
140
 
141
 
142
  embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
143
  embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
144
  full_embs = jnp.minimum(embs_1, embs_2)
145
+ mask_sum = jnp.expand_dims(vertical_masks + 1, axis=0) * jnp.expand_dims(horizontal_masks + 1, axis=1)
146
  full_embs = (full_embs / mask_sum)
147
 
148
  orig_shape = full_embs.shape
196
  col1, col2 = st.columns([0.75, 0.25])
197
 
198
  with col2:
199
+ pixel_size = st.selectbox("Pixel Size", options=range(5, 26, 5), index=3)
200
 
201
+ iterations = st.selectbox("Refinement Steps", options=range(0, 6, 1), index=0)
202
 
203
  compute = st.button("LOCATE")
204