huy-ha commited on
Commit
19d656f
·
1 Parent(s): 9f4163a

return results in gallery

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -21,8 +21,8 @@ def generate_relevancy(
21
  labels = labels.split(",")
22
  prompts = [prompt]
23
  assert img.dtype == np.uint8
 
24
  h, w, c = img.shape
25
- start = time()
26
  grads = ClipWrapper.get_clip_saliency(
27
  img=img,
28
  text_labels=np.array(labels),
@@ -32,16 +32,13 @@ def generate_relevancy(
32
  if subtract_mean:
33
  grads -= grads.mean(axis=0)
34
  grads = grads.cpu().numpy()
35
- num_axes = int(np.ceil(np.sqrt(len(labels))))
36
- fig, axes = plt.subplots(num_axes, num_axes)
37
- if num_axes == 1:
38
- axes = [axes]
39
- else:
40
- axes = axes.flatten()
41
  vmin = 0.002
42
  cmap = plt.get_cmap("jet")
43
  vmax = 0.008
44
- for ax, label_grad, label in zip(axes, grads, labels):
 
 
 
45
  ax.axis("off")
46
  ax.imshow(img)
47
  ax.set_title(label, fontsize=12)
@@ -50,10 +47,10 @@ def generate_relevancy(
50
  grad = 1 - grad
51
  colored_grad[..., -1] = grad * 0.7
52
  ax.imshow(colored_grad)
53
- plt.tight_layout(pad=0)
54
- img = plot_to_png(fig)
55
- plt.close(fig)
56
- return img
57
 
58
 
59
  iface = gr.Interface(
@@ -69,7 +66,7 @@ iface = gr.Interface(
69
  ),
70
  gr.Checkbox(value=True, label="subtract mean"),
71
  ],
72
- outputs=gr.Image(type="numpy"),
73
  examples=[
74
  [
75
  "https://semantic-abstraction.cs.columbia.edu/downloads/matterport.png",
 
21
  labels = labels.split(",")
22
  prompts = [prompt]
23
  assert img.dtype == np.uint8
24
+ img = Image.fromarray(img).resize((244 * 2, 244 * 2))
25
  h, w, c = img.shape
 
26
  grads = ClipWrapper.get_clip_saliency(
27
  img=img,
28
  text_labels=np.array(labels),
 
32
  if subtract_mean:
33
  grads -= grads.mean(axis=0)
34
  grads = grads.cpu().numpy()
 
 
 
 
 
 
35
  vmin = 0.002
36
  cmap = plt.get_cmap("jet")
37
  vmax = 0.008
38
+
39
+ returns = []
40
+ for label_grad, label in zip(grads, labels):
41
+ fig, ax = plt.subplots(1, 1)
42
  ax.axis("off")
43
  ax.imshow(img)
44
  ax.set_title(label, fontsize=12)
 
47
  grad = 1 - grad
48
  colored_grad[..., -1] = grad * 0.7
49
  ax.imshow(colored_grad)
50
+ plt.tight_layout(pad=0)
51
+ returns.append(plot_to_png(fig))
52
+ plt.close(fig)
53
+ return returns
54
 
55
 
56
  iface = gr.Interface(
 
66
  ),
67
  gr.Checkbox(value=True, label="subtract mean"),
68
  ],
69
+ outputs=gr.Gallery(label="Relevancy Maps", type="numpy"),
70
  examples=[
71
  [
72
  "https://semantic-abstraction.cs.columbia.edu/downloads/matterport.png",