Peng Shiya commited on
Commit
c4af616
1 Parent(s): 38277a1

feature: local feedback

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +59 -17
  3. app_configs.py +2 -3
  4. feedback.py +54 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  __pycache__/
2
  model/
3
- flagged/
 
 
1
  __pycache__/
2
  model/
3
+ flagged/
4
+ data/
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import app_configs as configs
 
3
  import service
4
  import gradio as gr
5
  import numpy as np
@@ -7,6 +8,7 @@ import cv2
7
  from PIL import Image
8
  import logging
9
  from huggingface_hub import hf_hub_download
 
10
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger()
@@ -23,7 +25,10 @@ def load_sam_instance():
23
  chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
24
  else:
25
  chkpt_path = configs.model_ckpt_path
26
- sam = service.get_sam(configs.model_type, chkpt_path, configs.device)
 
 
 
27
  return sam
28
 
29
  block = gr.Blocks()
@@ -38,21 +43,30 @@ with block:
38
  point_labels = gr.State(point_labels_empty)
39
  masks = gr.State()
40
  cutout_idx = gr.State(set())
 
41
 
42
  # UI
43
  with gr.Column():
44
  with gr.Row():
45
  input_image = gr.Image(label='Input', height=512, type='pil')
46
- masks_annotated_image = gr.AnnotatedImage(label='Segments')
 
47
  with gr.Row():
48
- point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
49
- reset_btn = gr.Button('Reset')
50
- run_btn = gr.Button('Run', variant = 'primary')
51
- cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain')
52
-
 
 
 
 
 
 
53
  # components
54
  components = {
55
  point_coords, point_labels, raw_image, masks, cutout_idx,
 
56
  input_image, point_label_radio, reset_btn, run_btn, masks_annotated_image}
57
 
58
  # event - init coords
@@ -82,26 +96,54 @@ with block:
82
  image = inputs[raw_image]
83
  if len(inputs[point_coords]) == 0:
84
  if configs.enable_segment_all:
85
- masks, _ = service.predict_all(sam, image)
86
  else:
87
  raise gr.Error('Segment-all disabled, set point label(s) before running')
88
  else:
89
- masks, _ = service.predict_conditioned(sam,
90
- image,
91
- point_coords=np.array(inputs[point_coords]),
92
- point_labels=np.array(inputs[point_labels]))
93
- annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
94
- return annotated, masks, set()
95
- run_btn.click(on_run_btn_click, components, [masks_annotated_image, masks, cutout_idx], queue=True)
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # event - get cutout
98
  def on_masks_annotated_image_select(inputs, evt:gr.SelectData):
99
  inputs[cutout_idx].add(evt.index)
100
  cutouts = [service.cutout(inputs[raw_image], inputs[masks][idx]) for idx in list(inputs[cutout_idx])]
101
  tight_cutouts = [service.crop_empty(cutout) for cutout in cutouts]
102
- return inputs[cutout_idx], tight_cutouts
103
- masks_annotated_image.select(on_masks_annotated_image_select, components, [cutout_idx, cutout_galary])
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if __name__ == '__main__':
106
  block.queue()
107
  block.launch()
 
1
  import os
2
  import app_configs as configs
3
+ from feedback import Feedback
4
  import service
5
  import gradio as gr
6
  import numpy as np
 
8
  from PIL import Image
9
  import logging
10
  from huggingface_hub import hf_hub_download
11
+ import torch
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger()
 
25
  chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
26
  else:
27
  chkpt_path = configs.model_ckpt_path
28
+ device = configs.device
29
+ if device is None:
30
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
+ sam = service.get_sam(configs.model_type, chkpt_path, device)
32
  return sam
33
 
34
  block = gr.Blocks()
 
43
  point_labels = gr.State(point_labels_empty)
44
  masks = gr.State()
45
  cutout_idx = gr.State(set())
46
+ feedback = gr.State(lambda : Feedback())
47
 
48
  # UI
49
  with gr.Column():
50
  with gr.Row():
51
  input_image = gr.Image(label='Input', height=512, type='pil')
52
+ masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
53
+ cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
54
  with gr.Row():
55
+ with gr.Column(scale=1):
56
+ point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
57
+ reset_btn = gr.Button('Reset')
58
+ run_btn = gr.Button('Run', variant = 'primary')
59
+ with gr.Column(scale=2):
60
+ with gr.Accordion('Provide Feedback'):
61
+ with gr.Row():
62
+ upvote_button = gr.Button('Upvote')
63
+ downvote_button = gr.Button('Downvote')
64
+ feedback_textbox = gr.Textbox(lines=3, show_label=False)
65
+ feedback_submit_button = gr.Button('Submit')
66
  # components
67
  components = {
68
  point_coords, point_labels, raw_image, masks, cutout_idx,
69
+ feedback, upvote_button, downvote_button, feedback_textbox, feedback_submit_button,
70
  input_image, point_label_radio, reset_btn, run_btn, masks_annotated_image}
71
 
72
  # event - init coords
 
96
  image = inputs[raw_image]
97
  if len(inputs[point_coords]) == 0:
98
  if configs.enable_segment_all:
99
+ generated_masks, _ = service.predict_all(sam, image)
100
  else:
101
  raise gr.Error('Segment-all disabled, set point label(s) before running')
102
  else:
103
+ generated_masks, _ = service.predict_conditioned(sam,
104
+ image,
105
+ point_coords=np.array(inputs[point_coords]),
106
+ point_labels=np.array(inputs[point_labels]))
107
+ annotated = (image, [(generated_masks[i], f'Mask {i}') for i in range(len(generated_masks))])
108
+ inputs[feedback].save_inference(
109
+ pt_coords=inputs[point_coords],
110
+ pt_labels=inputs[point_labels],
111
+ image=inputs[raw_image],
112
+ mask=generated_masks,
113
+ )
114
+ return {
115
+ masks_annotated_image:annotated,
116
+ masks: generated_masks,
117
+ cutout_idx: set(),
118
+ feedback: inputs[feedback],
119
+ }
120
+ run_btn.click(on_run_btn_click, components, [masks_annotated_image, masks, cutout_idx, feedback], queue=True)
121
 
122
  # event - get cutout
123
  def on_masks_annotated_image_select(inputs, evt:gr.SelectData):
124
  inputs[cutout_idx].add(evt.index)
125
  cutouts = [service.cutout(inputs[raw_image], inputs[masks][idx]) for idx in list(inputs[cutout_idx])]
126
  tight_cutouts = [service.crop_empty(cutout) for cutout in cutouts]
127
+ inputs[feedback].save_feedback(cutout_idx=evt.index)
128
+ return inputs[cutout_idx], tight_cutouts, inputs[feedback]
129
+ masks_annotated_image.select(on_masks_annotated_image_select, components, [cutout_idx, cutout_galary, feedback], queue=False)
130
 
131
+ # event - feedback
132
+ def on_feedback_submit_button_click(inputs):
133
+ inputs[feedback].save_feedback(feedback_str=inputs[feedback_textbox])
134
+ gr.Info('Thanks for your feedback')
135
+ return inputs[feedback], None
136
+ feedback_submit_button.click(on_feedback_submit_button_click, {feedback, feedback_textbox}, [feedback, feedback_textbox], queue=False)
137
+ def on_upvote_button_click(inputs):
138
+ inputs[feedback].save_feedback(like=1)
139
+ gr.Info('Thanks for your feedback')
140
+ return {feedback:inputs[feedback]}
141
+ upvote_button.click(on_upvote_button_click,components,[feedback, downvote_button], queue=False)
142
+ def on_downvote_button_click(inputs):
143
+ inputs[feedback].save_feedback(like=-1)
144
+ gr.Info('Thanks for your feedback')
145
+ return {feedback:inputs[feedback]}
146
+ downvote_button.click(on_downvote_button_click,components,[feedback, upvote_button], queue=False)
147
  if __name__ == '__main__':
148
  block.queue()
149
  block.launch()
app_configs.py CHANGED
@@ -1,6 +1,5 @@
1
  model_type = r'vit_b'
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
- device = 'cpu'
5
- enable_segment_all = False
6
- flagging_dir = r'.\flagged'
 
1
  model_type = r'vit_b'
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
+ device = None
5
+ enable_segment_all = True
 
feedback.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List
3
+ import uuid
4
+ import csv
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ def write_row(filepath:str, row: Dict):
10
+ new_file = not os.path.isfile(filepath)
11
+ with open(filepath, mode="a", newline="") as file:
12
+ fieldnames = row.keys()
13
+ writer = csv.DictWriter(file, fieldnames=fieldnames)
14
+ if new_file:
15
+ writer.writeheader() # Write header if new file
16
+ writer.writerow(row) # Write the row
17
+
18
+ class Feedback():
19
+ def __init__(self,
20
+ image_dir = './data/input',
21
+ mask_dir = './data/mask',
22
+ inference_csv = './data/inference.csv',
23
+ feedback_csv = './data/feedback.csv',
24
+ ):
25
+ self.image_dir = image_dir
26
+ self.mask_dir = mask_dir
27
+ self.inference_csv = inference_csv
28
+ self.feedback_csv = feedback_csv
29
+
30
+ def save_inference(self, pt_coords:List, pt_labels:List, image: Image.Image, mask: np.ndarray):
31
+ self.inference_id = uuid.uuid4()
32
+ write_row(
33
+ filepath=self.inference_csv,
34
+ row = {
35
+ "inference_id": self.inference_id,
36
+ "image": image.tobytes(),
37
+ "mask": mask.tobytes(),
38
+ "pt_coords": str(pt_coords),
39
+ "pt_labels": str(pt_labels),
40
+ }
41
+ )
42
+
43
+ def save_feedback(self, cutout_idx:int=None, feedback_str:str=None, like:int=None):
44
+ write_row(
45
+ filepath=self.feedback_csv,
46
+ row = {
47
+ "inference_id": self.inference_id,
48
+ "cutout_idx": cutout_idx,
49
+ "feedback_str": feedback_str,
50
+ "like": like,
51
+ }
52
+ )
53
+
54
+