|
import json
import os

import gradio as gr
import requests

from caption_anything import CaptionAnything, parse_augment
|
|
def download_checkpoint(url, folder, filename):
    """Download a model checkpoint to `folder/filename` if it is not already present."""
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early on a bad download instead of writing an empty file
        with open(filepath, "wb") as f:
            # stream the checkpoint in 8 KiB chunks to avoid holding it all in memory
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

    return filepath
|
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
filename = "sam_vit_h_4b8939.pth"
download_checkpoint(checkpoint_url, folder, filename)  # fetch the SAM checkpoint up front
|
|
title = """<h1 align="center">Caption-Anything</h1>"""
description = """Gradio demo for Caption-Anything: image-to-dense-captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
<br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'>ttengwang/Caption-Anything</a>
"""
|
|
examples = [
    ["test_img/img2.jpg", "[[1000, 700, 1]]"]
]

args = parse_augment()
|
|
def get_prompt(chat_input, click_state):
    """Parse a JSON list of clicks ([x, y, label]) into a SAM-style click prompt."""
    points = click_state[0]
    labels = click_state[1]
    inputs = json.loads(chat_input)
    for point in inputs:  # `point`, not `input`, which would shadow the built-in
        points.append(point[:2])
        labels.append(point[2])

    prompt = {
        "prompt_type": ["click"],
        "input_point": points,
        "input_label": labels,
        "multimask_output": "True",
    }
    return prompt
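# A usage sketch (hypothetical values), starting from an empty click_state:
# chat_input '[[100, 200, 1], [250, 300, 0]]' yields
#   {"prompt_type": ["click"], "input_point": [[100, 200], [250, 300]],
#    "input_label": [1, 0], "multimask_output": "True"}
# Label 1 marks a positive (include) click; 0 marks a negative (exclude) click.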
|
|
def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}
    prompt = get_prompt(chat_input, click_state)
    print('prompt: ', prompt, 'controls: ', controls)
    out = model.inference(image_input, prompt, controls)
    state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
    for k, v in out['generated_captions'].items():
        state = state + [(f'{k}: {v}', None)]
    click_state[2].append(out['generated_captions']['raw_caption'])
    image_output_mask = out['mask_save_path']
    image_output_crop = out['crop_save_path']
    return state, state, click_state, image_output_mask, image_output_crop
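# The 5-tuple return order above must line up with the outputs lists wired to this
# handler in the UI below: [chatbot, state, click_state, image_output_mask, image_output_crop].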
|
|
def upload_callback(image_input, state):
    state = state + [('Image size: ' + str(image_input.size), None)]
    return state
|
|
def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
    print("point_prompt: ", point_prompt)
    # encode the clicked pixel as [x, y, label]: label 1 for a positive point, 0 for a negative one
    if point_prompt == 'Positive Point':
        coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
    else:
        coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
    return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
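# For example (hypothetical click): selecting pixel (1000, 700) while "Positive Point"
# is active yields coordinate '[[1000, 700, 1]]', the same format used in `examples`.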
|
|
def chat_with_points(chat_input, click_state, state):
    points, labels, captions = click_state

    point_chat_prompt = "I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
    # summarize the positive clicks together with the most recent caption as visual context
    pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
    if captions:
        prev_visual_context = ', '.join(pos_points) + ': ' + captions[-1] + '\n'
    else:
        prev_visual_context = ''
    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
    response = model.text_refiner.llm(chat_prompt)
    state = state + [(chat_input, response)]
    return state, state
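# Sketch of the assembled context (hypothetical state): points == [[1000, 700]],
# labels == [1], captions == ["a red car"] gives
# points_with_caps == "1000, 700: a red car\n".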
|
|
def init_openai_api_key(api_key):
    # build the captioning pipeline once the user submits an OpenAI API key
    global model
    model = CaptionAnything(args, api_key)
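# Assumption (not in the original script): predeclare the global assigned by
# init_openai_api_key, so a click before an API key is entered fails with an
# explicit AttributeError on `model` rather than a bare NameError.
model = None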
|
|
css = '''
#image_upload{min-height:200px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
'''
|
|
with gr.Blocks(css=css) as iface:
    state = gr.State([])
    # click_state holds three parallel lists: clicked points, their 0/1 labels, and raw captions
    click_state = gr.State([[], [], []])
    caption_state = gr.State([[]])
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        openai_api_key = gr.Textbox(
            placeholder="Input your OpenAI API key and press Enter",
            show_label=False,
            lines=1,
            type="password",
        )
        openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])

    with gr.Row():
        with gr.Column(scale=0.7):
            image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260, scale=1.0)
|
            with gr.Row(scale=0.7):
                point_prompt = gr.Radio(
                    choices=["Positive Point", "Negative Point"],
                    value="Positive Point",
                    label="Points",
                    interactive=True,
                )

            language = gr.Radio(
                choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese", "Cantonese"],
                value="English",
                label="Language",
                interactive=True,
            )
            sentiment = gr.Radio(
                choices=["Positive", "Natural", "Negative"],
                value="Natural",
                label="Sentiment",
                interactive=True,
            )
            factuality = gr.Radio(
                choices=["Factual", "Imagination"],
                value="Factual",
                label="Factuality",
                interactive=True,
            )
            length = gr.Slider(
                minimum=5,
                maximum=100,
                value=10,
                step=1,
                interactive=True,
                label="Length",
            )
|
        with gr.Column(scale=1.5):
            with gr.Row():
                image_output_mask = gr.Image(type="pil", interactive=False, label="Mask").style(height=260, scale=1.0)
                image_output_crop = gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260, scale=1.0)
            chatbot = gr.Chatbot(label="Chat Output").style(height=450, scale=0.5)
|
    with gr.Row():
        with gr.Column(scale=0.7):
            prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like: [[100, 200, 1], [200, 300, 0]])")
            prompt_input.submit(
                inference_seg_cap,
                [
                    image_input,
                    prompt_input,
                    language,
                    sentiment,
                    factuality,
                    length,
                    state,
                    click_state
                ],
                [chatbot, state, click_state, image_output_mask, image_output_crop],
                show_progress=False
            )

            image_input.upload(
                upload_callback,
                [image_input, state],
                [chatbot]
            )
|
            with gr.Row():
                # reset only the click history and the two derived images
                clear_click_button = gr.Button(value="Clear Click", interactive=True)
                clear_click_button.click(
                    lambda: ("", [[], [], []], None, None),
                    [],
                    [prompt_input, click_state, image_output_mask, image_output_crop],
                    queue=False,
                    show_progress=False
                )

                # reset the chat history and conversation state as well
                clear_all_button = gr.Button(value="Clear", interactive=True)
                clear_all_button.click(
                    lambda: ("", [], [], [[], [], []], None, None),
                    [],
                    [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
                    queue=False,
                    show_progress=False
                )
|
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_seg_cap,
                    [
                        image_input,
                        prompt_input,
                        language,
                        sentiment,
                        factuality,
                        length,
                        state,
                        click_state
                    ],
                    [chatbot, state, click_state, image_output_mask, image_output_crop],
                    show_progress=False
                )
|
|
            # a click on the image itself drives the same captioning pipeline
            image_input.select(
                get_select_coords,
                inputs=[image_input, point_prompt, language, sentiment, factuality, length, state, click_state],
                outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
                show_progress=False
            )

            # a new image resets the chat and the accumulated clicks
            image_input.change(
                lambda: ("", [], [[], [], []]),
                [],
                [chatbot, state, click_state],
                queue=False,
            )
|
        with gr.Column(scale=1.5):
            chat_input = gr.Textbox(lines=1, label="Chat Input")
            chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
|
    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, prompt_input],
    )
|
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)