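"""Open Generative Fill V1 - a language-driven image editing demo.

Pipeline: BLIP captions the input image; a small chat language model turns the
caption and the user's edit prompt into an object to replace plus a replacement
caption; OWLv2 detects the object; SAM segments it; and an SDXL inpainting
pipeline repaints the masked region from the replacement caption.
"""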
import spaces
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image, ImageFilter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BlipForConditionalGeneration,
BlipProcessor,
Owlv2ForObjectDetection,
Owlv2Processor,
SamModel,
SamProcessor,
)


def delete_model(model):
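    """Move a model off the GPU, drop the reference, and free cached memory."""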
model.to("cpu")
del model
torch.cuda.empty_cache()


@spaces.GPU()
def run_language_model(edit_prompt, caption, device):
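    """Few-shot prompt a small chat model to parse the edit request.

    Given the image caption and the user's edit prompt, the model should answer
    on two lines: "A: <object to replace>" and "B: <replacement caption>".
    """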
language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
language_model = AutoModelForCausalLM.from_pretrained(
language_model_id, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(language_model_id)
messages = [
{"role": "system", "content": "Follow the examples and return the expected output"},
{"role": "user", "content": "Caption: a blue sky with fluffy clouds\nQuery: Make the sky stormy"},
{"role": "assistant", "content": "A: sky\nB: a stormy sky with heavy gray clouds, torrential rain, gloomy, overcast"},
{"role": "user", "content": "Caption: a cat sleeping on a sofa\nQuery: Change the cat to a dog"},
{"role": "assistant", "content": "A: cat\nB: a dog sleeping on a sofa, cozy and comfortable, snuggled up in a warm blanket, peaceful"},
{"role": "user", "content": "Caption: a snowy mountain peak\nQuery: Replace the snow with greenery"},
{"role": "assistant", "content": "A: snow\nB: a lush green mountain peak in summer, clear blue skies, birds flying overhead, serene and majestic"},
{"role": "user", "content": "Caption: a vintage car parked by the roadside\nQuery: Change the car to a modern electric vehicle"},
{"role": "assistant", "content": "A: car\nB: a sleek modern electric vehicle parked by the roadside, cutting-edge design, environmentally friendly, silent and powerful"},
{"role": "user", "content": "Caption: a wooden bridge over a river\nQuery: Make the bridge stone"},
{"role": "assistant", "content": "A: bridge\nB: an ancient stone bridge over a river, moss-covered, sturdy and timeless, with clear waters flowing beneath"},
{"role": "user", "content": "Caption: a bowl of salad on the table\nQuery: Replace salad with soup"},
{"role": "assistant", "content": "A: bowl\nB: a bowl of steaming hot soup on the table, scrumptious, with garnishing"},
{"role": "user", "content": "Caption: a book on a desk surrounded by stationery\nQuery: Remove all stationery, add a laptop"},
{"role": "assistant", "content": "A: stationery\nB: a book on a desk with a laptop next to it, modern study setup, focused and productive, technology and education combined"},
{"role": "user", "content": "Caption: a cup of coffee on a wooden table\nQuery: Change coffee to tea"},
{"role": "assistant", "content": "A: cup\nB: a steaming cup of tea on a wooden table, calming and aromatic, with a slice of lemon on the side, inviting"},
{"role": "user", "content": "Caption: a small pen on a white table\nQuery: Change the pen to an elaborate fountain pen"},
{"role": "assistant", "content": "A: pen\nB: an elaborate fountain pen on a white table, sleek and elegant, with intricate designs, ready for writing"},
{"role": "user", "content": "Caption: a plain notebook on a desk\nQuery: Replace the notebook with a journal"},
{"role": "assistant", "content": "A: notebook\nB: an artistically decorated journal on a desk, vibrant cover, filled with creativity, inspiring and personalized"},
{"role": "user", "content": f"Caption: {caption}\nQuery: {edit_prompt}"},
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = language_model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
        )
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Expect two lines, "A: <object>" and "B: <replacement caption>"; keep only those
    output_generation_a, output_generation_b = response.strip().split("\n")[:2]
    # Drop the leading "A:" / "B:" labels
    to_replace = output_generation_a[2:].strip()
    replaced_caption = output_generation_b[2:].strip()
delete_model(language_model)
return (to_replace, replaced_caption)


@spaces.GPU()
def run_image_captioner(image, device):
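    """Caption the input image with BLIP; the caption seeds the language model."""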
caption_model_id = "Salesforce/blip-image-captioning-base"
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
device
)
caption_processor = BlipProcessor.from_pretrained(caption_model_id)
inputs = caption_processor(image, return_tensors="pt").to(device)
with torch.no_grad():
outputs = caption_model.generate(**inputs, max_new_tokens=200)
caption = caption_processor.decode(outputs[0], skip_special_tokens=True)
delete_model(caption_model)
return caption


@spaces.GPU()
def run_segmentation(image, object_to_segment, device):
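    """Produce a binary mask covering `object_to_segment` in the image.

    OWLv2 performs open-vocabulary detection to find candidate boxes, then SAM
    converts those boxes into masks, which are merged into a single mask.
    """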
# OWL-V2 for object detection
owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)
text_queries = [object_to_segment]
inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)
with torch.no_grad():
outputs = od_model(**inputs)
    # PIL's Image.size is (width, height); the processor expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
results = processor.post_process_object_detection(
outputs, threshold=0.1, target_sizes=target_sizes
)[0]
    boxes = results["boxes"].tolist()
    delete_model(od_model)
    # Fail fast with a readable error if nothing was detected
    if not boxes:
        raise gr.Error(f"No '{object_to_segment}' detected in the image.")
# SAM for image segmentation
sam_model_id = "facebook/sam-vit-base"
seg_model = SamModel.from_pretrained(sam_model_id).to(device)
processor = SamProcessor.from_pretrained(sam_model_id)
    # SAM expects boxes nested per image: [[box_1, box_2, ...]]
    input_boxes = [boxes]
inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
with torch.no_grad():
outputs = seg_model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0]
    # Union SAM's first mask proposal for each box into a single binary mask
    masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values
delete_model(seg_model)
return masks


@spaces.GPU()
def run_inpainting(image, replaced_caption, masks, generator, device):
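    """Inpaint the masked region with SDXL, guided by the replacement caption."""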
pipeline = AutoPipelineForInpainting.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
torch_dtype=torch.float16,
variant="fp16",
).to(device)
    # Convert the boolean mask tensor to an 8-bit grayscale PIL image
    masks = Image.fromarray((masks.numpy() * 255).astype("uint8"))
# dilation_image = masks.filter(ImageFilter.MaxFilter(3))
prompt = replaced_caption
    negative_prompt = (
        "lowres, bad anatomy, bad hands, text, error, missing fingers, "
        "extra digit, fewer digits, cropped, worst quality, low quality"
    )
output = pipeline(
prompt=prompt,
image=image,
# mask_image=dilation_image,
mask_image=masks,
negative_prompt=negative_prompt,
guidance_scale=7.5,
strength=1.0,
generator=generator,
).images[0]
delete_model(pipeline)
return output


def run_open_gen_fill(image, edit_prompt):
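    """Run the full caption -> parse -> segment -> inpaint pipeline on one image.

    Returns a tuple of (object to replace, original caption, replacement
    caption, segmentation mask as a PIL image, inpainted output image).

    Illustrative usage (file name taken from the bundled Gradio example):

        >>> image = Image.open("dog.jpeg")
        >>> results = run_open_gen_fill(image, "replace the dog with a tiger")
    """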
device = "cuda" if torch.cuda.is_available() else "cpu"
# Resize the image to (512, 512)
image = image.resize((512, 512))
# Caption the input image
caption = run_image_captioner(image, device=device)
    # Run the language model to extract the object for segmentation
# and get the replaced caption
to_replace, replaced_caption = run_language_model(
edit_prompt=edit_prompt, caption=caption, device=device
)
# Segment the `to_replace` object from the input image
masks = run_segmentation(image, to_replace, device=device)
# Diffusion pipeline for inpainting
    generator = torch.Generator(device=device).manual_seed(17)
output = run_inpainting(
image=image, replaced_caption=replaced_caption, masks=masks, generator=generator, device=device
)
return (
to_replace,
caption,
replaced_caption,
        Image.fromarray((masks.numpy() * 255).astype("uint8")),
output,
)


def setup_gradio_interface():
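    """Build the Gradio Blocks UI and wire the inputs to `run_open_gen_fill`."""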
block = gr.Blocks()
with block:
        gr.Markdown("<h1><center>Open Generative Fill V1</center></h1>")
with gr.Row():
with gr.Column():
input_image_placeholder = gr.Image(type="pil", label="Input Image")
edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
run_button_placeholder = gr.Button(value="Run")
with gr.Column():
to_replace_placeholder = gr.Textbox(label="to_replace")
image_caption_placeholder = gr.Textbox(label="Image Caption")
replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
output_image_placeholder = gr.Image(type="pil", label="Output Image")
run_button_placeholder.click(
fn=lambda image, edit_prompt: run_open_gen_fill(
image=image,
edit_prompt=edit_prompt,
),
inputs=[input_image_placeholder, edit_prompt_placeholder],
outputs=[
to_replace_placeholder,
image_caption_placeholder,
replace_caption_placeholder,
segmentation_placeholder,
output_image_placeholder,
],
)
gr.Examples(
examples=[["dog.jpeg", "replace the dog with a tiger"]],
inputs=[input_image_placeholder, edit_prompt_placeholder],
outputs=[
to_replace_placeholder,
image_caption_placeholder,
replace_caption_placeholder,
segmentation_placeholder,
output_image_placeholder,
],
fn=lambda image, edit_prompt: run_open_gen_fill(
image=image,
edit_prompt=edit_prompt,
),
cache_examples=True,
label="Try this example input!",
)
return block


if __name__ == "__main__":
gradio_interface = setup_gradio_interface()
# gradio_interface.queue(max_size=10)
gradio_interface.launch(share=False, show_api=False, show_error=True)