ttengwang committed
Commit 3cb3d90
Parent: ecb53e5
Files changed (3)
  1. app.py +263 -0
  2. env.sh +6 -0
  3. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,263 @@
+ from io import BytesIO
+ import string
+ import gradio as gr
+ import requests
+ from caas import CaptionAnything
+ import torch
+ import json
+ import sys
+ import argparse
+ from caas import parse_augment
+ import numpy as np
+ import PIL.ImageDraw as ImageDraw
+ from image_editing_utils import create_bubble_frame
+ import copy
+ from tools import mask_painter
+ from PIL import Image
+ import os
+
+ def download_checkpoint(url, folder, filename):
+     os.makedirs(folder, exist_ok=True)
+     filepath = os.path.join(folder, filename)
+
+     if not os.path.exists(filepath):
+         response = requests.get(url, stream=True)
+         with open(filepath, "wb") as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 if chunk:
+                     f.write(chunk)
+
+     return filepath
+
+ checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+ folder = "segmenter"
+ filename = "sam_vit_h_4b8939.pth"
+
+ download_checkpoint(checkpoint_url, folder, filename)
+
+
+ title = """<h1 align="center">Caption-Anything</h1>"""
+ description = """Gradio demo for Caption Anything, image-to-dense-captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
+ """
+
+ examples = [
+     ["test_img/img2.jpg"],
+     ["test_img/img5.jpg"],
+     ["test_img/img12.jpg"],
+     ["test_img/img14.jpg"],
+ ]
+
+ args = parse_augment()
+ # args.device = 'cuda:5'
+ # args.disable_gpt = False
+ # args.enable_reduce_tokens = True
+ # args.port = 20322
+ model = CaptionAnything(args)
+
+ def init_openai_api_key(api_key):
+     os.environ['OPENAI_API_KEY'] = api_key
+     model.init_refiner()
+
+
+ def get_prompt(chat_input, click_state):
+     points = click_state[0]
+     labels = click_state[1]
+     inputs = json.loads(chat_input)
+     for click in inputs:
+         points.append(click[:2])
+         labels.append(click[2])
+
+     prompt = {
+         "prompt_type": ["click"],
+         "input_point": points,
+         "input_label": labels,
+         "multimask_output": "True",
+     }
+     return prompt
+
+ def chat_with_points(chat_input, click_state, state):
+     if not hasattr(model, "text_refiner"):
+         response = "Text refiner is not initialized, please input your OpenAI API key."
+         state = state + [(chat_input, response)]
+         return state, state
+
+     points, labels, captions = click_state
+     point_chat_prompt = "I want you to act as a chatbot for an image. I will give you some points (w, h) in the image and tell you what is at each point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers to the height. You should chat with me based on the facts in the image instead of imagination. Now I will tell you the points with their visual descriptions:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
+     # "The image is of width {width} and height {height}."
+
+     prev_visual_context = ""
+     pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
+     if len(captions):
+         prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
+     else:
+         prev_visual_context = 'no point exists.'
+     chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
+     response = model.text_refiner.llm(chat_prompt)
+     state = state + [(chat_input, response)]
+     return state, state
+
+ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
+
+     if point_prompt == 'Positive':
+         coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
+     else:
+         coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
+
+     controls = {'length': length,
+                 'sentiment': sentiment,
+                 'factuality': factuality,
+                 'language': language}
+
+     # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
+     # chat_input = click_coordinate
+     prompt = get_prompt(coordinate, click_state)
+     print('prompt: ', prompt, 'controls: ', controls)
+
+     out = model.inference(image_input, prompt, controls)
+     state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
+     for k, v in out['generated_captions'].items():
+         state = state + [(f'{k}: {v}', None)]
+
+     click_state[2].append(out['generated_captions']['raw_caption'])
+
+     text = out['generated_captions']['raw_caption']
+     # draw = ImageDraw.Draw(image_input)
+     # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0, 0, 255), text_size=120)
+     input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
+     image_input = mask_painter(np.array(image_input), input_mask)
+     origin_image_input = image_input
+     image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
+
+     yield state, state, click_state, chat_input, image_input
+     if not args.disable_gpt and hasattr(model, "text_refiner"):
+         refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
+         new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+         refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
+         yield state, state, click_state, chat_input, refined_image_input
+
+
+ def upload_callback(image_input, state):
+     state = [('Image size: ' + str(image_input.size), None)]
+     click_state = [[], [], []]
+     model.segmenter.image = None
+     model.segmenter.image_embedding = None
+     model.segmenter.set_image(image_input)
+     return state, image_input, click_state
+
+ with gr.Blocks(
+     css='''
+     #image_upload{min-height:400px}
+     #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
+     '''
+ ) as iface:
+     state = gr.State([])
+     click_state = gr.State([[], [], []])
+     origin_image = gr.State(None)
+
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column(scale=1.0):
+             image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
+             with gr.Row(scale=1.0):
+                 point_prompt = gr.Radio(
+                     choices=["Positive", "Negative"],
+                     value="Positive",
+                     label="Point Prompt",
+                     interactive=True)
+                 clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
+                 clear_button_image = gr.Button(value="Clear Image", interactive=True)
+             with gr.Row(scale=1.0):
+                 language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+
+                 sentiment = gr.Radio(
+                     choices=["Positive", "Natural", "Negative"],
+                     value="Natural",
+                     label="Sentiment",
+                     interactive=True,
+                 )
+             with gr.Row(scale=1.0):
+                 factuality = gr.Radio(
+                     choices=["Factual", "Imagination"],
+                     value="Factual",
+                     label="Factuality",
+                     interactive=True,
+                 )
+                 length = gr.Slider(
+                     minimum=10,
+                     maximum=80,
+                     value=10,
+                     step=1,
+                     interactive=True,
+                     label="Length",
+                 )
+
+         with gr.Column(scale=0.5):
+             openai_api_key = gr.Textbox(
+                 placeholder="Input your OpenAI API key and press Enter",
+                 show_label=True,
+                 label="OpenAI API Key",
+                 lines=1,
+                 type="password"
+             )
+             openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
+             chatbot = gr.Chatbot(label="Chat about Selected Object").style(height=620, scale=0.5)
+             chat_input = gr.Textbox(lines=1, label="Chat Input")
+             with gr.Row():
+                 clear_button_text = gr.Button(value="Clear Text", interactive=True)
+                 submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+     clear_button_click.click(
+         lambda x: ([[], [], []], x),
+         [origin_image],
+         [click_state, image_input],
+         queue=False,
+         show_progress=False
+     )
+     clear_button_image.click(
+         lambda: (None, [], [], [[], [], []]),
+         [],
+         [image_input, chatbot, state, click_state],
+         queue=False,
+         show_progress=False
+     )
+     clear_button_text.click(
+         lambda: ([], [], [[], [], []]),
+         [],
+         [chatbot, state, click_state],
+         queue=False,
+         show_progress=False
+     )
+     image_input.clear(
+         lambda: (None, [], [], [[], [], []]),
+         [],
+         [image_input, chatbot, state, click_state],
+         queue=False,
+         show_progress=False
+     )
+
+     examples = gr.Examples(
+         examples=examples,
+         inputs=[image_input],
+     )
+
+     image_input.upload(upload_callback, [image_input, state], [state, origin_image, click_state])
+     chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
+
+     # select coordinate
+     image_input.select(inference_seg_cap,
+                        inputs=[
+                            origin_image,
+                            point_prompt,
+                            language,
+                            sentiment,
+                            factuality,
+                            length,
+                            state,
+                            click_state
+                        ],
+                        outputs=[chatbot, state, click_state, chat_input, image_input],
+                        show_progress=False, queue=True)
+
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
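For quick experiments without the UI, the same click-to-caption sequence that upload_callback and inference_seg_cap perform can be driven from a short script. The sketch below is illustrative only: the image path, click coordinate, and control values are placeholders, and it assumes caas.CaptionAnything and parse_augment behave exactly as app.py uses them above.

# Minimal sketch of the demo's click-to-caption flow, outside Gradio.
# Image path, click coordinate, and control values are placeholders.
from PIL import Image
from caas import CaptionAnything, parse_augment

args = parse_augment()
model = CaptionAnything(args)

image = Image.open("test_img/img2.jpg")      # one of the bundled example images
model.segmenter.set_image(image)             # same preparation step as upload_callback

prompt = {
    "prompt_type": ["click"],
    "input_point": [[200, 150]],             # one (w, h) click; values are arbitrary here
    "input_label": [1],                      # 1 = positive point, 0 = negative point
    "multimask_output": "True",
}
controls = {"length": 10, "sentiment": "Natural",
            "factuality": "Factual", "language": "English"}

out = model.inference(image, prompt, controls)
print(out["generated_captions"]["raw_caption"])

The controls keys mirror the Gradio widgets above, so the same dictionary can be reused when scripting batch runs.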
env.sh ADDED
@@ -0,0 +1,6 @@
+ conda create -n caption_anything python=3.8 -y
+ source activate caption_anything
+ pip install -r requirements.txt
+ cd segmenter
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ torch==1.10.1
+ torchvision==0.11.2
+ torchaudio==0.10.1
+ openai
+ pillow
+ langchain==0.0.101
+ git+https://github.com/huggingface/transformers.git
+ ftfy
+ regex
+ tqdm
+ git+https://github.com/openai/CLIP.git
+ git+https://github.com/facebookresearch/segment-anything.git
+ opencv-python
+ pycocotools
+ matplotlib
+ onnxruntime
+ onnx
+ https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
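As a quick sanity check after installation, the pinned versions above can be verified from Python. The snippet below is only a suggestion, not part of the repository.

# Optional post-install check: confirm the pinned versions resolved.
import torch, torchvision, gradio

print(torch.__version__)        # expected: 1.10.1
print(torchvision.__version__)  # expected: 0.11.2
print(gradio.__version__)       # expected: 3.24.1 (installed from the pinned wheel above)
print("CUDA available:", torch.cuda.is_available())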