ttengwang committed
Commit 10240e0
Parent: 36b57c3

update latest

Files changed (11)
  1. .gitignore +135 -0
  2. README.md +45 -11
  3. app.py +28 -23
  4. app_huggingface.py +268 -0
  5. app_old.py +5 -5
  6. caption_anything.py +114 -0
  7. captioner/base_captioner.py +3 -2
  8. env.sh +1 -1
  9. image_editing_utils.py +2 -2
  10. requirements.txt +1 -0
  11. tools.py +7 -1
.gitignore ADDED
@@ -0,0 +1,135 @@
+ result/
+ model_cache/
+ *.pth
+ teng_grad_start.sh
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ result/
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
README.md CHANGED
@@ -1,13 +1,47 @@
+ # Caption-Anything
+ <!-- ![](./Image/title.svg) -->
+ **Caption-Anything** is a versatile image processing tool that combines the capabilities of [Segment Anything](https://github.com/facebookresearch/segment-anything), Visual Captioning, and [ChatGPT](https://openai.com/blog/chatgpt). Our solution generates descriptive captions for any object within an image, offering a range of language styles to accommodate diverse user preferences. **Caption-Anything** supports visual controls (mouse clicks) and language controls (length, sentiment, factuality, and language).
+ * Visual controls and language controls for text generation
+ * Chat about the selected object for a detailed understanding
+ * Interactive demo
+ ![](./Image/UI.png)
+
+ <!-- <a src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" href="https://huggingface.co/spaces/wybertwang/Caption-Anything">
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" alt="Open in Spaces">
+ </a> -->
+
+ <!-- <a src="https://colab.research.google.com/assets/colab-badge.svg" href="">
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
+ </a> -->
+
+ ### Demo
+ Explore the interactive demo of Caption-Anything, which showcases its powerful capabilities in generating captions for various objects within an image. The demo allows users to control visual aspects by clicking on objects, as well as to adjust textual properties such as length, sentiment, factuality, and language.
+ ![](./Image/demo1.png)
+
  ---
- title: Caption Anything
- emoji: 📚
- colorFrom: green
- colorTo: green
- sdk: gradio
- sdk_version: 3.24.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ![](./Image/demo2.png)
+
+ ### Getting Started
+
+
+ * Clone the repository:
+ ```bash
+ git clone https://github.com/ttengwang/caption-anything.git
+ ```
+ * Install dependencies:
+ ```bash
+ cd caption-anything
+ pip install -r requirements.txt
+ ```
+ * Download the [SAM checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and place it at `./segmenter/sam_vit_h_4b8939.pth`.
+
+ * Run the Caption-Anything gradio demo.
+ ```bash
+ # Configure the OpenAI API key used by ChatGPT
+ export OPENAI_API_KEY={Your_Private_Openai_Key}
+ python app.py --regular_box --captioner blip2 --port 6086
+ ```
+
+ ## Acknowledgement
+ The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), BLIP/BLIP-2, and [ChatGPT](https://openai.com/blog/chatgpt). Thanks to the authors for their efforts.
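Editor's note: for readers following the new README, a minimal programmatic sketch of the same pipeline, mirroring the `__main__` block added in `caption_anything.py` below. The image path and click coordinate are illustrative values, not part of this commit.

```python
# A minimal sketch, assuming the module layout introduced in this commit
# (caption_anything.py exposing CaptionAnything and parse_augment).
from caption_anything import CaptionAnything, parse_augment

args = parse_augment()           # CLI defaults: --captioner blip, --segmenter base, --device cuda:0
model = CaptionAnything(args)

# A single positive click at (w, h) = (500, 300); values are illustrative.
prompt = {
    "prompt_type": ["click"],
    "input_point": [[500, 300]],
    "input_label": [1],
    "multimask_output": "True",
}
controls = {"length": "30", "sentiment": "positive", "imagination": "False", "language": "English"}

out = model.inference("test_img/img13.jpg", prompt, controls)
print(out["generated_captions"]["raw_caption"])
```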
app.py CHANGED
@@ -2,12 +2,12 @@ from io import BytesIO
  import string
  import gradio as gr
  import requests
- from caas import CaptionAnything
+ from caption_anything import CaptionAnything
  import torch
  import json
  import sys
  import argparse
- from caas import parse_augment
+ from caption_anything import parse_augment
  import numpy as np
  import PIL.ImageDraw as ImageDraw
  from image_editing_utils import create_bubble_frame
@@ -47,6 +47,9 @@ examples = [
  ]
 
  args = parse_augment()
+ args.captioner = 'blip2'
+ args.seg_crop_mode = 'wo_bg'
+ args.regular_box = True
  # args.device = 'cuda:5'
  # args.disable_gpt = False
  # args.enable_reduce_tokens = True
@@ -81,9 +84,9 @@ def chat_with_points(chat_input, click_state, state):
  return state, state
 
  points, labels, captions = click_state
- point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
- # "The image is of width {width} and height {height}."
-
+ # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
+ # # "The image is of width {width} and height {height}."
+ point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
  prev_visual_context = ""
  pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
  if len(captions):
@@ -114,9 +117,10 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
 
  out = model.inference(image_input, prompt, controls)
  state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
- for k, v in out['generated_captions'].items():
- state = state + [(f'{k}: {v}', None)]
-
+ # for k, v in out['generated_captions'].items():
+ # state = state + [(f'{k}: {v}', None)]
+ state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
+ wiki = out['generated_captions'].get('wiki', "")
  click_state[2].append(out['generated_captions']['raw_caption'])
 
  text = out['generated_captions']['raw_caption']
@@ -127,12 +131,13 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
  origin_image_input = image_input
  image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
 
- yield state, state, click_state, chat_input, image_input
+ yield state, state, click_state, chat_input, image_input, wiki
  if not args.disable_gpt and hasattr(model, "text_refiner"):
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
- new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+ # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+ new_cap = refined_caption['caption']
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
- yield state, state, click_state, chat_input, refined_image_input
+ yield state, state, click_state, chat_input, refined_image_input, wiki
 
 
  def upload_callback(image_input, state):
@@ -195,28 +200,29 @@ with gr.Blocks(
  with gr.Column(scale=0.5):
  openai_api_key = gr.Textbox(
  placeholder="Input your openAI API key and press Enter",
- show_label=True,
+ show_label=False,
  label = "OpenAI API Key",
  lines=1,
  type="password"
  )
  openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=620,scale=0.5)
+ wiki_output = gr.Textbox(lines=6, label="Wiki")
+ chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
  chat_input = gr.Textbox(lines=1, label="Chat Input")
  with gr.Row():
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
  clear_button_clike.click(
- lambda x: ([[], [], []], x),
+ lambda x: ([[], [], []], x, ""),
  [origin_image],
- [click_state, image_input],
+ [click_state, image_input, wiki_output],
  queue=False,
  show_progress=False
  )
  clear_button_image.click(
- lambda: (None, [], [], [[], [], []]),
+ lambda: (None, [], [], [[], [], []], ""),
  [],
- [image_input, chatbot, state, click_state],
+ [image_input, chatbot, state, click_state, wiki_output],
  queue=False,
  show_progress=False
  )
@@ -228,9 +234,9 @@
  show_progress=False
  )
  image_input.clear(
- lambda: (None, [], [], [[], [], []]),
+ lambda: (None, [], [], [[], [], []], ""),
  [],
- [image_input, chatbot, state, click_state],
+ [image_input, chatbot, state, click_state, wiki_output],
  queue=False,
  show_progress=False
  )
@@ -255,9 +261,8 @@
  state,
  click_state
  ],
-
- outputs=[chatbot, state, click_state, chat_input, image_input],
- show_progress=False, queue=True)
+ outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
+ show_progress=False, queue=True)
 
  iface.queue(concurrency_count=5, api_open=False, max_size=10)
- iface.launch(server_name="0.0.0.0", enable_queue=True)
+ iface.launch(server_name="0.0.0.0", enable_queue=True)
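Editor's note: the click handler above is a generator, it first yields the raw caption and overlay, then yields again after the optional ChatGPT refinement, so the bound outputs update twice per click. A self-contained sketch of that pattern, assuming the Gradio 3.x build pinned in requirements.txt; the component and function names here are illustrative, not from the repository.

```python
# Minimal illustration of streaming two updates from one Gradio event handler.
import time
import gradio as gr

def caption_then_refine(text):
    yield f"raw caption: {text}"                  # first update is shown immediately
    time.sleep(1)                                 # stand-in for the slower LLM refinement call
    yield f"refined caption: {text.strip().capitalize()}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Query")
    out = gr.Textbox(label="Caption")
    inp.submit(caption_then_refine, inputs=[inp], outputs=[out])

demo.queue()                                      # generator handlers require queuing in Gradio 3.x
demo.launch()
```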
app_huggingface.py ADDED
@@ -0,0 +1,268 @@
+ from io import BytesIO
+ import string
+ import gradio as gr
+ import requests
+ from caption_anything import CaptionAnything
+ import torch
+ import json
+ import sys
+ import argparse
+ from caption_anything import parse_augment
+ import numpy as np
+ import PIL.ImageDraw as ImageDraw
+ from image_editing_utils import create_bubble_frame
+ import copy
+ from tools import mask_painter
+ from PIL import Image
+ import os
+
+ def download_checkpoint(url, folder, filename):
+ os.makedirs(folder, exist_ok=True)
+ filepath = os.path.join(folder, filename)
+
+ if not os.path.exists(filepath):
+ response = requests.get(url, stream=True)
+ with open(filepath, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if chunk:
+ f.write(chunk)
+
+ return filepath
+ checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+ folder = "segmenter"
+ filename = "sam_vit_h_4b8939.pth"
+
+ download_checkpoint(checkpoint_url, folder, filename)
+
+
+ title = """<h1 align="center">Caption-Anything</h1>"""
+ description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
+ """
+
+ examples = [
+ ["test_img/img2.jpg"],
+ ["test_img/img5.jpg"],
+ ["test_img/img12.jpg"],
+ ["test_img/img14.jpg"],
+ ]
+
+ args = parse_augment()
+ args.captioner = 'blip2'
+ args.seg_crop_mode = 'wo_bg'
+ args.regular_box = True
+ # args.device = 'cuda:5'
+ # args.disable_gpt = False
+ # args.enable_reduce_tokens = True
+ # args.port=20322
+ model = CaptionAnything(args)
+
+ def init_openai_api_key(api_key):
+ os.environ['OPENAI_API_KEY'] = api_key
+ model.init_refiner()
+
+
+ def get_prompt(chat_input, click_state):
+ points = click_state[0]
+ labels = click_state[1]
+ inputs = json.loads(chat_input)
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+
+ prompt = {
+ "prompt_type":["click"],
+ "input_point":points,
+ "input_label":labels,
+ "multimask_output":"True",
+ }
+ return prompt
+
+ def chat_with_points(chat_input, click_state, state):
+ if not hasattr(model, "text_refiner"):
+ response = "Text refiner is not initialized, please input openai api key."
+ state = state + [(chat_input, response)]
+ return state, state
+
+ points, labels, captions = click_state
+ # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
+ # # "The image is of width {width} and height {height}."
+ point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
+ prev_visual_context = ""
+ pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
+ if len(captions):
+ prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
+ else:
+ prev_visual_context = 'no point exists.'
+ chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
+ response = model.text_refiner.llm(chat_prompt)
+ state = state + [(chat_input, response)]
+ return state, state
+
+ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
+
+ if point_prompt == 'Positive':
+ coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
+ else:
+ coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
+
+ controls = {'length': length,
+ 'sentiment': sentiment,
+ 'factuality': factuality,
+ 'language': language}
+
+ # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
+ # chat_input = click_coordinate
+ prompt = get_prompt(coordinate, click_state)
+ print('prompt: ', prompt, 'controls: ', controls)
+
+ out = model.inference(image_input, prompt, controls)
+ state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
+ # for k, v in out['generated_captions'].items():
+ # state = state + [(f'{k}: {v}', None)]
+ state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
+ wiki = out['generated_captions'].get('wiki', "")
+ click_state[2].append(out['generated_captions']['raw_caption'])
+
+ text = out['generated_captions']['raw_caption']
+ # draw = ImageDraw.Draw(image_input)
+ # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
+ input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
+ image_input = mask_painter(np.array(image_input), input_mask)
+ origin_image_input = image_input
+ image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
+
+ yield state, state, click_state, chat_input, image_input, wiki
+ if not args.disable_gpt and hasattr(model, "text_refiner"):
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
+ # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+ new_cap = refined_caption['caption']
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
+ yield state, state, click_state, chat_input, refined_image_input, wiki
+
+
+ def upload_callback(image_input, state):
+ state = [] + [('Image size: ' + str(image_input.size), None)]
+ click_state = [[], [], []]
+ model.segmenter.image = None
+ model.segmenter.image_embedding = None
+ model.segmenter.set_image(image_input)
+ return state, image_input, click_state
+
+ with gr.Blocks(
+ css='''
+ #image_upload{min-height:400px}
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
+ '''
+ ) as iface:
+ state = gr.State([])
+ click_state = gr.State([[],[],[]])
+ origin_image = gr.State(None)
+
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
+ with gr.Row(scale=1.0):
+ point_prompt = gr.Radio(
+ choices=["Positive", "Negative"],
+ value="Positive",
+ label="Point Prompt",
+ interactive=True)
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
+ with gr.Row(scale=1.0):
+ language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+
+ sentiment = gr.Radio(
+ choices=["Positive", "Natural", "Negative"],
+ value="Natural",
+ label="Sentiment",
+ interactive=True,
+ )
+ with gr.Row(scale=1.0):
+ factuality = gr.Radio(
+ choices=["Factual", "Imagination"],
+ value="Factual",
+ label="Factuality",
+ interactive=True,
+ )
+ length = gr.Slider(
+ minimum=10,
+ maximum=80,
+ value=10,
+ step=1,
+ interactive=True,
+ label="Length",
+ )
+
+ with gr.Column(scale=0.5):
+ openai_api_key = gr.Textbox(
+ placeholder="Input your openAI API key and press Enter",
+ show_label=False,
+ label = "OpenAI API Key",
+ lines=1,
+ type="password"
+ )
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
+ wiki_output = gr.Textbox(lines=6, label="Wiki")
+ chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
+ with gr.Row():
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+ clear_button_clike.click(
+ lambda x: ([[], [], []], x, ""),
+ [origin_image],
+ [click_state, image_input, wiki_output],
+ queue=False,
+ show_progress=False
+ )
+ clear_button_image.click(
+ lambda: (None, [], [], [[], [], []], ""),
+ [],
+ [image_input, chatbot, state, click_state, wiki_output],
+ queue=False,
+ show_progress=False
+ )
+ clear_button_text.click(
+ lambda: ([], [], [[], [], []]),
+ [],
+ [chatbot, state, click_state],
+ queue=False,
+ show_progress=False
+ )
+ image_input.clear(
+ lambda: (None, [], [], [[], [], []], ""),
+ [],
+ [image_input, chatbot, state, click_state, wiki_output],
+ queue=False,
+ show_progress=False
+ )
+
+ examples = gr.Examples(
+ examples=examples,
+ inputs=[image_input],
+ )
+
+ image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
+ chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
+
+ # select coordinate
+ image_input.select(inference_seg_cap,
+ inputs=[
+ origin_image,
+ point_prompt,
+ language,
+ sentiment,
+ factuality,
+ length,
+ state,
+ click_state
+ ],
+ outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
+ show_progress=False, queue=True)
+
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
+ iface.launch(server_name="0.0.0.0", enable_queue=True)
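Editor's note: `download_checkpoint` above streams the large SAM weight file in 8 KB chunks and skips the download when the file already exists. A hedged sketch of the same routine with an optional integrity check against the server's Content-Length header; this is a hardening idea under my own naming, not part of the commit.

```python
# Streaming download with a size check; names are illustrative.
import os
import requests

def download_checkpoint_verified(url, folder, filename, chunk_size=8192):
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            expected = int(response.headers.get("Content-Length", 0))
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
        # Compare the bytes on disk with what the server advertised, if anything.
        if expected and os.path.getsize(filepath) != expected:
            raise IOError(f"Incomplete download: {filepath}")
    return filepath
```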
app_old.py CHANGED
@@ -2,12 +2,12 @@ from io import BytesIO
  import string
  import gradio as gr
  import requests
- from caas import CaptionAnything
+ from caption_anything import CaptionAnything
  import torch
  import json
  import sys
  import argparse
- from caas import parse_augment
+ from caption_anything import parse_augment
  import os
 
  # download sam checkpoint if not downloaded
@@ -83,12 +83,12 @@ def get_select_coords(image_input, point_prompt, language, sentiment, factuality
  else:
  coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
  return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
-
+
  def chat_with_points(chat_input, click_state, state):
  points, labels, captions = click_state
- point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
+ # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
  # "The image is of width {width} and height {height}."
-
+ point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
  prev_visual_context = ""
  pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
  prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
caption_anything.py ADDED
@@ -0,0 +1,114 @@
+ from captioner import build_captioner, BaseCaptioner
+ from segmenter import build_segmenter
+ from text_refiner import build_text_refiner
+ import os
+ import argparse
+ import pdb
+ import time
+ from PIL import Image
+
+ class CaptionAnything():
+ def __init__(self, args):
+ self.args = args
+ self.captioner = build_captioner(args.captioner, args.device, args)
+ self.segmenter = build_segmenter(args.segmenter, args.device, args)
+ if not args.disable_gpt:
+ self.init_refiner()
+
+
+ def init_refiner(self):
+ if os.environ.get('OPENAI_API_KEY', None):
+ self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
+
+ def inference(self, image, prompt, controls, disable_gpt=False):
+ # segment with prompt
+ print("CA prompt: ", prompt, "CA controls",controls)
+ seg_mask = self.segmenter.inference(image, prompt)[0, ...]
+ mask_save_path = f'result/mask_{time.time()}.png'
+ if not os.path.exists(os.path.dirname(mask_save_path)):
+ os.makedirs(os.path.dirname(mask_save_path))
+ new_p = Image.fromarray(seg_mask.astype('int') * 255.)
+ if new_p.mode != 'RGB':
+ new_p = new_p.convert('RGB')
+ new_p.save(mask_save_path)
+ print('seg_mask path: ', mask_save_path)
+ print("seg_mask.shape: ", seg_mask.shape)
+ # captioning with mask
+ if self.args.enable_reduce_tokens:
+ caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
+ else:
+ caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
+ # refining with TextRefiner
+ context_captions = []
+ if self.args.context_captions:
+ context_captions.append(self.captioner.inference(image))
+ if not disable_gpt and hasattr(self, "text_refiner"):
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
+ else:
+ refined_caption = {'raw_caption': caption}
+ out = {'generated_captions': refined_caption,
+ 'crop_save_path': crop_save_path,
+ 'mask_save_path': mask_save_path,
+ 'context_captions': context_captions}
+ return out
+
+ def parse_augment():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--captioner', type=str, default="blip")
+ parser.add_argument('--segmenter', type=str, default="base")
+ parser.add_argument('--text_refiner', type=str, default="base")
+ parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
+ parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
+ parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
+ parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
+ parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
+ parser.add_argument('--device', type=str, default="cuda:0")
+ parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
+ parser.add_argument('--debug', action="store_true")
+ parser.add_argument('--gradio_share', action="store_true")
+ parser.add_argument('--disable_gpt', action="store_true")
+ parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
+ parser.add_argument('--disable_reuse_features', action="store_true", default=False)
+ args = parser.parse_args()
+
+ if args.debug:
+ print(args)
+ return args
+
+ if __name__ == "__main__":
+ args = parse_augment()
+ # image_path = 'test_img/img3.jpg'
+ image_path = 'test_img/img13.jpg'
+ prompts = [
+ {
+ "prompt_type":["click"],
+ "input_point":[[500, 300], [1000, 500]],
+ "input_label":[1, 0],
+ "multimask_output":"True",
+ },
+ {
+ "prompt_type":["click"],
+ "input_point":[[900, 800]],
+ "input_label":[1],
+ "multimask_output":"True",
+ }
+ ]
+ controls = {
+ "length": "30",
+ "sentiment": "positive",
+ # "imagination": "True",
+ "imagination": "False",
+ "language": "English",
+ }
+
+ model = CaptionAnything(args)
+ for prompt in prompts:
+ print('*'*30)
+ print('Image path: ', image_path)
+ image = Image.open(image_path)
+ print(image)
+ print('Visual controls (SAM prompt):\n', prompt)
+ print('Language controls:\n', controls)
+ out = model.inference(image_path, prompt, controls)
+
+
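Editor's note: the dictionary returned by `inference()` is what the Gradio apps above consume. A short sketch of reading it back, with the mask reloaded the same way `app_huggingface.py` does; `model`, `prompt`, and `controls` are assumed to be set up as in the `__main__` block.

```python
# Sketch: keys come from the `out` dict built in CaptionAnything.inference().
import numpy as np
from PIL import Image

out = model.inference("test_img/img13.jpg", prompt, controls)
caption = out["generated_captions"]["raw_caption"]   # always present, refined or not
mask = np.array(Image.open(out["mask_save_path"]))   # saved segmentation mask, white = object
print(caption, mask.shape, out["crop_save_path"])
```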
captioner/base_captioner.py CHANGED
@@ -146,7 +146,8 @@ class BaseCaptioner:
  seg_mask = np.array(seg_mask) > 0
 
  if crop_mode=="wo_bg":
- image = np.array(image) * seg_mask[:,:,np.newaxis]
+ image = np.array(image) * seg_mask[:,:,np.newaxis] + (1 - seg_mask[:,:,np.newaxis]) * 255
+ image = np.uint8(image)
  else:
  image = np.array(image)
 
@@ -168,7 +169,7 @@ class BaseCaptioner:
  seg_mask = np.array(seg_mask) > 0
 
  if crop_mode=="wo_bg":
- image = np.array(image) * seg_mask[:,:,np.newaxis]
+ image = np.array(image) * seg_mask[:,:,np.newaxis] + (1- seg_mask[:,:,np.newaxis]) * 255
  else:
  image = np.array(image)
 
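Editor's note: the change above switches the "wo_bg" crop from a black background to a white one before captioning. A toy, self-contained illustration of that masking step, using dummy arrays rather than the project's data.

```python
# Mask an image and fill the background with white (255) instead of black.
import numpy as np

image = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)   # dummy RGB image
seg_mask = np.zeros((4, 4), dtype=bool)
seg_mask[1:3, 1:3] = True                                      # dummy foreground region

masked = image * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
masked = np.uint8(masked)                                      # background pixels are now white
```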
env.sh CHANGED
@@ -1,6 +1,6 @@
  conda create -n caption_anything python=3.8 -y
  source activate caption_anything
- pip install -r requirement.txt
+ pip install -r requirements.txt
  cd segmenter
  wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
 
image_editing_utils.py CHANGED
@@ -17,7 +17,7 @@ def wrap_text(text, font, max_width):
  lines.append(current_line)
  return lines
 
- def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.033):
+ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.025):
  # Load the image
  if type(image) == np.ndarray:
  image = Image.fromarray(image)
@@ -27,7 +27,7 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
 
  # Calculate max_text_width and font_size based on image dimensions and total number of characters
  total_chars = len(text)
- max_text_width = int(0.33 * width)
+ max_text_width = int(0.4 * width)
  font_size = int(height * font_size_ratio)
 
  # Load the font
requirements.txt CHANGED
@@ -16,3 +16,4 @@ matplotlib
  onnxruntime
  onnx
  https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
+ accelerate
tools.py CHANGED
@@ -1,6 +1,7 @@
  import cv2
  import numpy as np
  from PIL import Image
+ import copy
 
 
  def colormap(rgb=True):
@@ -145,6 +146,11 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
  assert input_image.shape[:2] == input_mask.shape, 'different shape'
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
 
+ width, height = input_image.shape[0], input_image.shape[1]
+ res = 1024
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
  # 0: background, 1: foreground
  input_mask[input_mask>0] = 255
 
@@ -157,7 +163,7 @@
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
  contour_mask = cv2.dilate(contour_mask, kernel)
  painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
-
+ painted_image = cv2.resize(painted_image, (height, width))
  return painted_image
 
 
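Editor's note: the new lines in `mask_painter` cap the working resolution at 1024 px on the longer side before painting and resize the result back afterwards. A standalone sketch of that pattern; the function and variable names below are my own, not from tools.py.

```python
# Scale an image so its longer side is at most `res`, then restore the original size.
import cv2
import numpy as np

def bounded_resize(img: np.ndarray, res: int = 1024):
    h, w = img.shape[:2]
    ratio = min(1.0 * res / max(h, w), 1.0)        # never upscale
    small = cv2.resize(img, (int(w * ratio), int(h * ratio)))
    return small, (w, h)                           # cv2.resize takes (width, height)

img = np.zeros((2048, 1536, 3), dtype=np.uint8)    # dummy oversized image
small, orig_size = bounded_resize(img)
# ... do the (cheaper) painting on `small` here ...
restored = cv2.resize(small, orig_size)            # back to the original resolution
```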