SkalskiP commited on
Commit
03b9405
1 Parent(s): 3e01228

Update Dockerfile and app.py for improved functionality

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. Dockerfile +5 -1
  3. app.py +31 -82
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .idea/
2
- venv/
 
 
1
  .idea/
2
+ venv/
3
+ weights/
Dockerfile CHANGED
@@ -31,12 +31,16 @@ WORKDIR $HOME/app
31
  RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
32
 
33
  # Install dependencies
34
- RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision
35
 
36
  # Install SAM and Detectron2
37
  RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
38
  RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
39
 
 
 
 
 
40
  COPY app.py .
41
 
42
  RUN find $HOME/app
 
31
  RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
32
 
33
  # Install dependencies
34
+ RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision pillow
35
 
36
  # Install SAM and Detectron2
37
  RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
38
  RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
39
 
40
+ # Download weights
41
+ RUN mkdir -p $HOME/app/weigths
42
+ RUN wget -c -O $HOME/app/weigths/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
43
+
44
  COPY app.py .
45
 
46
  RUN find $HOME/app
app.py CHANGED
@@ -1,98 +1,47 @@
1
- import gradio as gr
2
-
3
- from detectron2.data import MetadataCatalog
4
- from segment_anything import SamAutomaticMaskGenerator
5
 
 
 
 
6
 
7
- metadata = MetadataCatalog.get('coco_2017_train_panoptic')
8
- print(metadata)
9
 
 
10
 
11
- class ImageMask(gr.components.Image):
12
- """
13
- Sets: source="canvas", tool="sketch"
14
- """
15
 
16
- is_template = True
 
 
 
 
 
 
 
 
17
 
18
- def __init__(self, **kwargs):
19
- super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
20
 
21
- def preprocess(self, x):
22
- return super().preprocess(x)
23
 
 
 
24
 
25
- demo = gr.Blocks()
26
- image = ImageMask(
27
- label="Input",
28
- type="pil",
29
- brush_radius=20.0,
30
- brush_color="#FFFFFF")
31
- slider = gr.Slider(
32
- minimum=1,
33
- maximum=3,
34
- value=2,
35
- label="Granularity",
36
- info="Choose in [1, 1.5), [1.5, 2.5), [2.5, 3] for [seem, semantic-sam (multi-level), sam]")
37
- mode = gr.Radio(
38
- choices=['Automatic', 'Interactive', ],
39
- value='Automatic',
40
- label="Segmentation Mode")
41
- image_out = gr.Image(label="Auto generation", type="pil")
42
- slider_alpha = gr.Slider(
43
- minimum=0,
44
- maximum=1,
45
- value=0.1,
46
- label="Mask Alpha",
47
- info="Choose in [0, 1]")
48
- label_mode = gr.Radio(
49
- choices=['Number', 'Alphabet'],
50
- value='Number',
51
- label="Mark Mode")
52
- anno_mode = gr.CheckboxGroup(
53
- choices=["Mask", "Box", "Mark"],
54
- value=['Mask', 'Mark'],
55
- label="Annotation Mode")
56
- runBtn = gr.Button("Run")
57
 
58
- title = "Set-of-Mark (SoM) Prompting for Visual Grounding in GPT-4V"
59
- description = "This is a demo for SoM Prompting to unleash extraordinary visual grounding in GPT-4V. Please upload an image and them click the 'Run' button to get the image with marks. Then try it on <a href='https://chat.openai.com/'>GPT-4V<a>!"
 
60
 
61
- with demo:
62
- gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
63
- gr.Markdown("<h3 style='text-align: center; margin-bottom: 1rem'>project: <a href='https://som-gpt4v.github.io/'>link</a>, arXiv: <a href='https://arxiv.org/abs/2310.11441'>link</a>, code: <a href='https://github.com/microsoft/SoM'>link</a></h3>")
64
- gr.Markdown(f"<h3 style='margin-bottom: 1rem'>{description}</h3>")
65
  with gr.Row():
66
  with gr.Column():
67
- image.render()
68
- slider.render()
69
- with gr.Row():
70
- mode.render()
71
- anno_mode.render()
72
- with gr.Row():
73
- slider_alpha.render()
74
- label_mode.render()
75
  with gr.Column():
76
- image_out.render()
77
- runBtn.render()
78
- # with gr.Row():
79
- # example = gr.Examples(
80
- # examples=[
81
- # ["examples/ironing_man.jpg"],
82
- # ],
83
- # inputs=image,
84
- # cache_examples=False,
85
- # )
86
- # example = gr.Examples(
87
- # examples=[
88
- # ["examples/ironing_man_som.png"],
89
- # ],
90
- # inputs=image,
91
- # cache_examples=False,
92
- # label='Marked Examples',
93
- # )
94
 
95
- # runBtn.click(inference, inputs=[image, slider, mode, slider_alpha, label_mode, anno_mode],
96
- # outputs = image_out)
97
 
98
- demo.queue().launch()
 
1
+ import torch
 
 
 
2
 
3
+ import gradio as gr
4
+ import numpy as np
5
+ import supervision as sv
6
 
7
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
 
8
 
9
+ DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10
 
11
+ SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
12
+ SAM_MODEL_TYPE = "vit_h"
 
 
13
 
14
+ MARKDOWN = """
15
+ <h1 style='text-align: center'>
16
+ <img
17
+ src='https://som-gpt4v.github.io/website/img/som_logo.png'
18
+ style='height:50px; display:inline-block'
19
+ />
20
+ Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
21
+ </h1>
22
+ """
23
 
24
+ sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
25
+ mask_generator = SamAutomaticMaskGenerator(sam)
26
 
 
 
27
 
28
+ def inference(image: np.ndarray) -> np.ndarray:
29
+ return image
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ image_input = gr.Image(label="Input", type="numpy")
33
+ image_output = gr.Image(label="SoM Visual Prompt", type="numpy", height=512)
34
+ run_button = gr.Button("Run")
35
 
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown(MARKDOWN)
 
 
38
  with gr.Row():
39
  with gr.Column():
40
+ image_input.render()
 
 
 
 
 
 
 
41
  with gr.Column():
42
+ image_output.render()
43
+ run_button.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ run_button.click(inference, inputs=[image_input], outputs=image_output)
 
46
 
47
+ demo.queue().launch(debug=False, show_error=True)