|
import json
import os

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from diffusers import StableDiffusionInpaintPipeline

from caption_anything import CaptionAnything, parse_augment
from image_editing_utils import create_bubble_frame
from tools import mask_painter
|
|
|
def download_checkpoint(url, folder, filename):
    # Download a checkpoint to folder/filename unless it is already cached locally.
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

    return filepath


checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
filename = "sam_vit_h_4b8939.pth"

download_checkpoint(checkpoint_url, folder, filename)
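# The ViT-H SAM checkpoint is large (over 2 GB), so the first launch may take a
# while; later launches reuse the cached file in segmenter/.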
|
|
title = """<h1 align="center">Edit Anything</h1>""" |
|
description = """Gradio demo for Segment Anything, image to dense Segment generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. |
|
""" |
|
|
|
examples = [
    ["test_img/img35.webp"],
    ["test_img/img2.jpg"],
    ["test_img/img5.jpg"],
    ["test_img/img12.jpg"],
    ["test_img/img14.jpg"],
    ["test_img/img0.png"],
    ["test_img/img1.jpg"],
]
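# The example paths are relative to the repository root and assume a test_img/
# directory ships with the repo.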
|
|
|
args = parse_augment()

model = CaptionAnything(args)
|
|
|
def init_openai_api_key(api_key):
    # Initialise the GPT-based text refiner; the GPT-only controls stay hidden
    # when no refiner is available.
    model.init_refiner(api_key)
    openai_available = model.text_refiner is not None
    return (gr.update(visible=openai_available),
            gr.update(visible=openai_available),
            gr.update(visible=openai_available),
            gr.update(visible=True),
            gr.update(visible=True))
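# Note: init_openai_api_key is not wired to any component in this file. A typical
# hookup (with a hypothetical `openai_key` textbox) would be
#   openai_key.submit(init_openai_api_key, [openai_key], [...])
# where the five outputs are the component groups whose visibility it toggles.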
|
|
|
def get_prompt(chat_input, click_state):
    # Accumulate the new click(s) into the running click history, then build the
    # segmenter prompt in the "click" format.
    points = click_state[0]
    labels = click_state[1]
    inputs = json.loads(chat_input)
    for point in inputs:
        points.append(point[:2])
        labels.append(point[2])

    prompt = {
        "prompt_type": ["click"],
        "input_point": points,
        "input_label": labels,
        "multimask_output": "True",
    }
    return prompt
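# Example with hypothetical values: starting from click_state == [[], [], []],
# get_prompt('[[100, 200, 1]]', click_state) returns
# {'prompt_type': ['click'], 'input_point': [[100, 200]], 'input_label': [1],
#  'multimask_output': 'True'}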
|
|
|
# Lazily-created Stable Diffusion inpainting pipeline, shared across calls so the
# weights are only loaded once instead of on every edit.
inpaint_pipe = None


def chat_with_points(chat_input, click_state, state, mask, image_input):
    global inpaint_pipe
    if inpaint_pipe is None:
        inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float32,
        )

    image_input = np.array(image_input)
    h, w = image_input.shape[:2]

    # The inpainting pipeline works at 512x512; resize the image and mask down,
    # then scale the edited result back to the original resolution.
    image = cv2.resize(image_input, (512, 512))
    mask = cv2.resize(mask, (512, 512)).astype(np.uint8)
    print("chat_input:", chat_input)
    image = inpaint_pipe(prompt=chat_input, image=image, mask_image=mask).images[0]
    image = image.resize((w, h))

    return state, state, image
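# A minimal performance note: on a CUDA machine the pipeline could be loaded with
# torch_dtype=torch.float16 and moved with inpaint_pipe.to("cuda"); float32 on CPU
# is kept above as the safe default.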
|
|
|
def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
    # Encode the clicked pixel as "[[x, y, label]]", where label 1 marks a positive
    # (foreground) point and 0 a negative (background) point.
    label = 1 if point_prompt == 'Positive' else 0
    coordinate = "[[{}, {}, {}]]".format(evt.index[0], evt.index[1], label)

    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}

    prompt = get_prompt(coordinate, click_state)
    print('prompt: ', prompt, 'controls: ', controls)

    out = model.inference(image_input, prompt, controls)
    state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]

    # Overlay the predicted mask on the image and draw a bubble at the click point.
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)
    image_input = create_bubble_frame(image_input, "edit", (evt.index[0], evt.index[1]))

    yield state, state, click_state, image_input, input_mask
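# inference_seg_cap is written as a generator (yield instead of return); Gradio
# treats generator callbacks as streaming handlers, so extra intermediate yields
# could later be added to show progress while the model runs.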
|
|
def upload_callback(image_input, state):
    state = [('Image size: ' + str(image_input.size), None)]
    click_state = [[], [], []]

    # Keep the longest side at no more than 1024 pixels before embedding.
    res = 1024
    width, height = image_input.size
    ratio = min(1.0 * res / max(width, height), 1.0)
    if ratio < 1.0:
        image_input = image_input.resize((int(width * ratio), int(height * ratio)))
        print('Scaling input image to {}'.format(image_input.size))

    # Invalidate any cached image/embedding, then precompute the SAM embedding
    # for the new image so later clicks segment quickly.
    model.segmenter.image = None
    model.segmenter.image_embedding = None
    model.segmenter.set_image(image_input)
    return state, image_input, click_state, image_input
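# SAM already resizes its input so the longest side is 1024 px, so the downscale
# above mainly bounds memory use and keeps the painted overlays at a manageable size.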
|
|
|
with gr.Blocks(
    css='''
    #image_upload{min-height:400px}
    #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
    '''
) as iface:
    state = gr.State([])
    click_state = gr.State([[], [], []])
    origin_image = gr.State(None)
    mask_save_path = gr.State(None)

    gr.Markdown(title)
    gr.Markdown(description)
|
|
|
    with gr.Row():
        with gr.Column(scale=1.0):
            with gr.Column(visible=True) as modules_not_need_gpt:
                image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                example_image = gr.Image(type="pil", interactive=False, visible=False)
                with gr.Row(scale=1.0):
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point Prompt",
                        interactive=True)
                    clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
            with gr.Column(visible=True) as modules_need_gpt:
                with gr.Row(scale=1.0):
                    language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
                    sentiment = gr.Radio(
                        choices=["Positive", "Natural", "Negative"],
                        value="Natural",
                        label="Sentiment",
                        interactive=True,
                    )
                with gr.Row(scale=1.0):
                    factuality = gr.Radio(
                        choices=["Factual", "Imagination"],
                        value="Factual",
                        label="Factuality",
                        interactive=True,
                    )
                    length = gr.Slider(
                        minimum=10,
                        maximum=80,
                        value=10,
                        step=1,
                        interactive=True,
                        label="Length",
                    )
|
|
|
        with gr.Column(scale=0.5):
            with gr.Column(visible=True) as modules_not_need_gpt2:
                chatbot = gr.Chatbot(label="History").style(height=450, scale=0.5)
            with gr.Column(visible=True) as modules_need_gpt3:
                chat_input = gr.Textbox(lines=1, label="Edit Prompt")
                with gr.Row():
                    clear_button_text = gr.Button(value="Clear Text", interactive=True)
                    submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
|
    # "Clear Clicks": reset the click history and restore the unpainted image.
    clear_button_click.click(
        lambda x: ([[], [], []], x),
        [origin_image],
        [click_state, image_input],
        queue=False,
        show_progress=False
    )
|
|
|
    # "Clear Image": reset every image-dependent piece of state.
    clear_button_image.click(
        lambda: (None, [], [], [[], [], []], None),
        [],
        [image_input, chatbot, state, click_state, origin_image],
        queue=False,
        show_progress=False
    )
    clear_button_text.click(
        lambda: ([], [], [[], [], []]),
        [],
        [chatbot, state, click_state],
        queue=False,
        show_progress=False
    )
|
|
|
|
|
    # Clearing the image component behaves like the "Clear Image" button.
    image_input.clear(
        lambda: (None, [], [], [[], [], []], None),
        [],
        [image_input, chatbot, state, click_state, origin_image],
        queue=False,
        show_progress=False
    )
|
|
|
    def example_callback(x):
        # Invalidate the cached embedding when an example image is selected.
        model.image_embedding = None
        return x

    gr.Examples(
        examples=examples,
        inputs=[example_image],
    )
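    # Selecting an example populates the hidden example_image component; its
    # .change handler (wired below) then runs upload_callback just as if the
    # user had uploaded the file.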
|
|
|
    submit_button_text.click(
        chat_with_points,
        [chat_input, click_state, state, mask_save_path, image_input],
        [chatbot, state, image_input]
    )

    image_input.upload(upload_callback, [image_input, state], [state, origin_image, click_state, image_input])
    chat_input.submit(chat_with_points, [chat_input, click_state, state, mask_save_path, image_input], [chatbot, state, image_input])
    example_image.change(upload_callback, [example_image, state], [state, origin_image, click_state, image_input])
|
|
    image_input.select(
        inference_seg_cap,
        inputs=[
            origin_image,
            point_prompt,
            language,
            sentiment,
            factuality,
            length,
            state,
            click_state
        ],
        outputs=[chatbot, state, click_state, image_input, mask_save_path],
        show_progress=False, queue=True)
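    # .select fires when the user clicks inside the image; the gr.SelectData event
    # passed to inference_seg_cap carries the click coordinates in evt.index.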
|
|
|
iface.queue(concurrency_count=3, api_open=False, max_size=10)
iface.launch(server_name="0.0.0.0", enable_queue=True)