ariG23498's picture
Update app.py
f153f53 verified
import spaces
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image, ImageFilter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BlipForConditionalGeneration,
BlipProcessor,
Owlv2ForObjectDetection,
Owlv2Processor,
SamModel,
SamProcessor,
)
def delete_model(model):
    """Move *model* off the accelerator and release cached GPU memory.

    ``model`` is anything exposing ``.to(device)``. In this file that means
    transformers models and a diffusers pipeline.

    Moving the model to the CPU lets PyTorch drop the GPU copies of its
    weights, provided nothing else references them. ``gc.collect()`` then
    reclaims objects kept alive only by reference cycles. After that,
    ``torch.cuda.empty_cache()`` can return the freed cached blocks.

    Note: a function cannot delete its caller's reference. A ``del`` here
    would only unbind the local parameter name, so the model object itself
    stays alive until the caller drops it.
    """
    import gc

    model.to("cpu")
    gc.collect()
    torch.cuda.empty_cache()
def _parse_edit_response(response):
    """Split the language model's reply into ``(to_replace, replaced_caption)``.

    The few-shot prompt in ``run_language_model`` shows the model replies
    made of two labelled lines::

        A: <object to replace>
        B: <caption describing the edited image>

    The first ``A:`` line and the first ``B:`` line are used wherever they
    appear. Blank lines or extra text around them therefore no longer break
    parsing, which a plain two-way ``split("\\n")`` did. Only the first
    colon separates label from value, so values may themselves contain
    colons.

    Raises:
        ValueError: If either labelled line is missing or has an empty value.
    """
    fields = {}
    for line in response.splitlines():
        label, sep, value = line.strip().partition(":")
        label = label.strip().upper()
        if sep and label in ("A", "B") and label not in fields:
            fields[label] = value.strip()
    to_replace = fields.get("A")
    replaced_caption = fields.get("B")
    if not to_replace or not replaced_caption:
        raise ValueError(
            "Could not parse the language model response into 'A:' and 'B:' "
            f"lines: {response!r}"
        )
    return to_replace, replaced_caption


@spaces.GPU()
def run_language_model(edit_prompt, caption, device):
    """Turn an image caption plus an edit request into segmentation/inpaint text.

    A small chat LLM (Qwen1.5-0.5B-Chat) is prompted with few-shot examples.
    It returns the object to segment and a rewritten caption describing the
    edited image. The model is loaded on every call and released again.

    Args:
        edit_prompt: User instruction, e.g. "replace the dog with a tiger".
        caption: Caption of the input image.
        device: Device the tokenized prompt is moved to.

    Returns:
        Tuple ``(to_replace, replaced_caption)`` of strings.

    Raises:
        ValueError: If the model's reply lacks the expected ``A:``/``B:`` lines.
    """
    language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
    tokenizer = AutoTokenizer.from_pretrained(language_model_id)
    messages = [
        {"role": "system", "content": "Follow the examples and return the expected output"},
        {"role": "user", "content": "Caption: a blue sky with fluffy clouds\nQuery: Make the sky stormy"},
        {"role": "assistant", "content": "A: sky\nB: a stormy sky with heavy gray clouds, torrential rain, gloomy, overcast"},
        {"role": "user", "content": "Caption: a cat sleeping on a sofa\nQuery: Change the cat to a dog"},
        {"role": "assistant", "content": "A: cat\nB: a dog sleeping on a sofa, cozy and comfortable, snuggled up in a warm blanket, peaceful"},
        {"role": "user", "content": "Caption: a snowy mountain peak\nQuery: Replace the snow with greenery"},
        {"role": "assistant", "content": "A: snow\nB: a lush green mountain peak in summer, clear blue skies, birds flying overhead, serene and majestic"},
        {"role": "user", "content": "Caption: a vintage car parked by the roadside\nQuery: Change the car to a modern electric vehicle"},
        {"role": "assistant", "content": "A: car\nB: a sleek modern electric vehicle parked by the roadside, cutting-edge design, environmentally friendly, silent and powerful"},
        {"role": "user", "content": "Caption: a wooden bridge over a river\nQuery: Make the bridge stone"},
        {"role": "assistant", "content": "A: bridge\nB: an ancient stone bridge over a river, moss-covered, sturdy and timeless, with clear waters flowing beneath"},
        {"role": "user", "content": "Caption: a bowl of salad on the table\nQuery: Replace salad with soup"},
        {"role": "assistant", "content": "A: bowl\nB: a bowl of steaming hot soup on the table, scrumptious, with garnishing"},
        {"role": "user", "content": "Caption: a book on a desk surrounded by stationery\nQuery: Remove all stationery, add a laptop"},
        {"role": "assistant", "content": "A: stationery\nB: a book on a desk with a laptop next to it, modern study setup, focused and productive, technology and education combined"},
        {"role": "user", "content": "Caption: a cup of coffee on a wooden table\nQuery: Change coffee to tea"},
        {"role": "assistant", "content": "A: cup\nB: a steaming cup of tea on a wooden table, calming and aromatic, with a slice of lemon on the side, inviting"},
        {"role": "user", "content": "Caption: a small pen on a white table\nQuery: Change the pen to an elaborate fountain pen"},
        {"role": "assistant", "content": "A: pen\nB: an elaborate fountain pen on a white table, sleek and elegant, with intricate designs, ready for writing"},
        {"role": "user", "content": "Caption: a plain notebook on a desk\nQuery: Replace the notebook with a journal"},
        {"role": "assistant", "content": "A: notebook\nB: an artistically decorated journal on a desk, vibrant cover, filled with creativity, inspiring and personalized"},
        {"role": "user", "content": f"Caption: {caption}\nQuery: {edit_prompt}"},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    language_model = AutoModelForCausalLM.from_pretrained(
        language_model_id, device_map="auto"
    )
    try:
        with torch.no_grad():
            # Greedy decoding. `temperature` is omitted because it only
            # applies when sampling. Passing **model_inputs also forwards the
            # attention mask.
            generated_ids = language_model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=False,
            )
    finally:
        # Release the model even if generation fails.
        delete_model(language_model)

    # generate() echoes the prompt; keep only the newly generated tokens.
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    return _parse_edit_response(response)
@spaces.GPU()
def run_image_captioner(image, device):
    """Return a BLIP-generated caption for *image*.

    The BLIP base captioning checkpoint is loaded on every call. It runs a
    single generate pass (at most 200 new tokens) and is released with
    ``delete_model`` before the caption is returned.

    Args:
        image: Image to caption, as accepted by ``BlipProcessor``.
        device: Device the model and inputs are moved to.

    Returns:
        The decoded caption string, with special tokens removed.
    """
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    blip = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)
    blip = blip.to(device)
    blip_processor = BlipProcessor.from_pretrained(blip_checkpoint)

    pixel_batch = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        token_ids = blip.generate(**pixel_batch, max_new_tokens=200)
    first_sequence = token_ids[0]
    text = blip_processor.decode(first_sequence, skip_special_tokens=True)

    delete_model(blip)
    return text
@spaces.GPU()
def run_segmentation(image, object_to_segment, device, detection_threshold=0.1):
    """Segment every detected instance of *object_to_segment* in *image*.

    OWLv2 proposes bounding boxes for the text query. SAM turns each box into
    masks, and the per-box masks are merged with an element-wise max. Both
    models are loaded per call and released after use.

    Args:
        image: PIL image to segment.
        object_to_segment: Free-text name of the object (the OWLv2 query).
        device: Device the models and inputs are moved to.
        detection_threshold: Minimum OWLv2 score for a box to be kept.
            Defaults to 0.1, the previously hard-coded value.

    Returns:
        A single merged mask tensor on the CPU. It is presumably (height,
        width) and boolean, since SAM post-processing binarizes by default;
        confirm against the transformers version in use.

    Raises:
        ValueError: If OWLv2 finds no box for the query above the threshold.
    """
    # Stage 1: OWLv2 open-vocabulary detection -> bounding boxes.
    owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
    detector_processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
    od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)
    inputs = detector_processor(
        text=[object_to_segment], images=image, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = od_model(**inputs)
    # post_process_object_detection takes target sizes as (height, width).
    # PIL's `image.size` is (width, height), so reverse it.
    # NOTE(review): OWLv2 pads inputs to a square. On some transformers
    # versions, non-square images need (max_side, max_side) here instead.
    # The app resizes inputs to 512x512, where these all coincide.
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    results = detector_processor.post_process_object_detection(
        outputs, threshold=detection_threshold, target_sizes=target_sizes
    )[0]
    boxes = results["boxes"].tolist()
    delete_model(od_model)

    if not boxes:
        # An empty box list would otherwise fail inside SAM preprocessing with
        # an unhelpful error; report the actual cause instead.
        raise ValueError(
            f"Could not find {object_to_segment!r} in the image "
            f"(OWLv2 detection threshold {detection_threshold})."
        )

    # Stage 2: SAM box-prompted segmentation.
    sam_model_id = "facebook/sam-vit-base"
    seg_model = SamModel.from_pretrained(sam_model_id).to(device)
    sam_processor = SamProcessor.from_pretrained(sam_model_id)
    inputs = sam_processor(image, input_boxes=[boxes], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]
    # Keep mask channel 0 for every box, then merge the boxes with a max.
    masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values
    delete_model(seg_model)
    return masks
@spaces.GPU()
def run_inpainting(image, replaced_caption, masks, generator, device):
    """Inpaint the masked region of *image* so it matches *replaced_caption*.

    Loads the SDXL inpainting 0.1 checkpoint in fp16 on every call, runs one
    generation, and releases the pipeline with ``delete_model``.

    Args:
        image: Image to edit.
        replaced_caption: Prompt describing the desired edited image.
        masks: Mask tensor; converted with ``.numpy()`` then
            ``Image.fromarray``, so it must be a CPU tensor.
        generator: ``torch.Generator`` controlling the diffusion noise.
        device: Device the pipeline is moved to.

    Returns:
        The first image produced by the pipeline.
    """
    inpaint_pipe = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    mask_image = Image.fromarray(masks.numpy())
    negative_prompt = (
        "lowres, bad anatomy, bad hands,\n"
        "text, error, missing fingers, extra digit, fewer digits,\n"
        "cropped, worst quality, low quality"
    )

    result = inpaint_pipe(
        prompt=replaced_caption,
        negative_prompt=negative_prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=7.5,
        strength=1.0,
        generator=generator,
    )
    inpainted = result.images[0]

    delete_model(inpaint_pipe)
    return inpainted
def run_open_gen_fill(image, edit_prompt, seed=17):
    """Apply a free-text edit to *image*; this is the Gradio entry point.

    The stages run in order:

    1. Caption the image.
    2. Ask the language model which object to replace and how to re-caption
       the image.
    3. Segment that object.
    4. Inpaint the mask using the new caption.

    Args:
        image: Input PIL image. It is converted to RGB and resized to 512x512.
        edit_prompt: Editing instruction, e.g. "replace the dog with a tiger".
        seed: Seed for the inpainting generator. Defaults to 17, the
            previously hard-coded value.

    Returns:
        Tuple ``(to_replace, caption, replaced_caption, mask_image, output)``.
        The order matches the five output components of the UI.

    Raises:
        gr.Error: If no image was given or the prompt is blank.
    """
    # Fail early with a user-facing message instead of an AttributeError
    # deep inside the pipeline.
    if image is None:
        raise gr.Error("Please upload an input image.")
    if not edit_prompt or not edit_prompt.strip():
        raise gr.Error("Please enter an editing prompt.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Work at a fixed 512x512 resolution. convert("RGB") normalizes
    # RGBA/grayscale/palette uploads to three channels.
    image = image.convert("RGB").resize((512, 512))

    # Caption the input image.
    caption = run_image_captioner(image, device=device)

    # Run the language model to extract the object to segment and the
    # rewritten caption.
    to_replace, replaced_caption = run_language_model(
        edit_prompt=edit_prompt, caption=caption, device=device
    )

    # Segment the `to_replace` object from the input image.
    masks = run_segmentation(image, to_replace, device=device)

    # Diffusion pipeline for inpainting.
    generator = torch.Generator(device).manual_seed(seed)
    output = run_inpainting(
        image=image,
        replaced_caption=replaced_caption,
        masks=masks,
        generator=generator,
        device=device,
    )

    return (
        to_replace,
        caption,
        replaced_caption,
        Image.fromarray(masks.numpy()),
        output,
    )
def setup_gradio_interface():
    """Build the Open Generative Fill Gradio UI.

    The left column holds the input image, the prompt box and the Run
    button. The right column shows the intermediate results and the final
    image. Both the Run button and the cached example call
    ``run_open_gen_fill`` with (image, edit_prompt).

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as block:
        # Properly closed HTML; the previous title re-opened <h1><center>.
        gr.Markdown("<h1><center>Open Generative Fill V1</center></h1>")

        with gr.Row():
            with gr.Column():
                input_image_placeholder = gr.Image(type="pil", label="Input Image")
                edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
                run_button_placeholder = gr.Button(value="Run")

            with gr.Column():
                to_replace_placeholder = gr.Textbox(label="to_replace")
                image_caption_placeholder = gr.Textbox(label="Image Caption")
                replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
                segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
                output_image_placeholder = gr.Image(type="pil", label="Output Image")

        # Shared by the button and the examples. The order must match the
        # tuple returned by run_open_gen_fill.
        inputs = [input_image_placeholder, edit_prompt_placeholder]
        outputs = [
            to_replace_placeholder,
            image_caption_placeholder,
            replace_caption_placeholder,
            segmentation_placeholder,
            output_image_placeholder,
        ]

        run_button_placeholder.click(
            fn=run_open_gen_fill,
            inputs=inputs,
            outputs=outputs,
        )

        gr.Examples(
            examples=[["dog.jpeg", "replace the dog with a tiger"]],
            inputs=inputs,
            outputs=outputs,
            fn=run_open_gen_fill,
            cache_examples=True,
            label="Try this example input!",
        )

    return block
if __name__ == "__main__":
    # Build the UI and serve it. share=False: no public share link.
    # show_api=False: no API docs page. show_error=True: errors are shown
    # in the browser.
    app = setup_gradio_interface()
    app.launch(show_error=True, show_api=False, share=False)