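"""Open Generative Fill V1 (Gradio Space).

Pipeline: a small chat LLM (Qwen1.5-0.5B-Chat) parses the edit prompt into
(object to replace, replacement), BLIP captions the input image, OWLv2 detects
the object to replace, SAM segments it, and an SDXL inpainting pipeline fills
the masked region using the edited caption as the prompt. Each model is moved
to the CPU and deleted after use to keep GPU memory free.
"""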
import spaces
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    Owlv2ForObjectDetection,
    Owlv2Processor,
    SamModel,
    SamProcessor,
)


def delete_model(model):
    # Move the model to the CPU, drop the reference, and free cached GPU memory.
    model.to("cpu")
    del model
    torch.cuda.empty_cache()

@spaces.GPU()
def run_language_model(edit_prompt, device):
    # Few-shot chat prompt: the model answers with "<object to replace>, <replacement>".
    language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
    language_model = AutoModelForCausalLM.from_pretrained(
        language_model_id, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(language_model_id)
    messages = [
        {"role": "system", "content": "Follow the examples and return the expected output"},
        {"role": "user", "content": "swap mountain and lion"},  # example 1
        {"role": "assistant", "content": "mountain, lion"},  # example 1
        {"role": "user", "content": "change the dog with cat"},  # example 2
        {"role": "assistant", "content": "dog, cat"},  # example 2
        {"role": "user", "content": "change the cat with a dog"},  # example 3
        {"role": "assistant", "content": "cat, dog"},  # example 3
        {"role": "user", "content": "replace the human with a boat"},  # example 4
        {"role": "assistant", "content": "human, boat"},  # example 4
        {"role": "user", "content": "in the above example change the background to the alps"},  # example 5
        {"role": "assistant", "content": "background, alps"},  # example 5
        {"role": "user", "content": "edit the house into a mansion"},  # example 6
        {"role": "assistant", "content": "house, a mansion"},  # example 6
        {"role": "user", "content": edit_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = language_model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            do_sample=False,
        )
    # Strip the prompt tokens from the generated sequence before decoding.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    to_replace, replace_with = response.split(", ")

    delete_model(language_model)
    return (to_replace, replace_with)

@spaces.GPU()
def run_image_captioner(image, device):
    caption_model_id = "Salesforce/blip-image-captioning-base"
    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
        device
    )
    caption_processor = BlipProcessor.from_pretrained(caption_model_id)

    inputs = caption_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = caption_model.generate(**inputs, max_new_tokens=200)
    caption = caption_processor.decode(outputs[0], skip_special_tokens=True)

    delete_model(caption_model)
    return caption

@spaces.GPU()
def run_segmentation(image, object_to_segment, device):
    # OWL-V2 for object detection
    owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
    processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
    od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)

    text_queries = [object_to_segment]
    inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = od_model(**inputs)
    target_sizes = torch.tensor([image.size]).to(device)
    results = processor.post_process_object_detection(
        outputs, threshold=0.1, target_sizes=target_sizes
    )[0]
    boxes = results["boxes"].tolist()

    delete_model(od_model)

    # SAM for image segmentation
    sam_model_id = "facebook/sam-vit-base"
    seg_model = SamModel.from_pretrained(sam_model_id).to(device)
    processor = SamProcessor.from_pretrained(sam_model_id)
    input_boxes = [boxes]
    inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    # Merge the masks
    masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values

    delete_model(seg_model)
    return masks

@spaces.GPU()
def run_inpainting(image, replaced_caption, masks, device):
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    prompt = replaced_caption
    negative_prompt = """lowres, bad anatomy, bad hands,
    text, error, missing fingers, extra digit, fewer digits,
    cropped, worst quality, low quality"""

    output = pipeline(
        prompt=prompt,
        image=image,
        mask_image=Image.fromarray(masks.numpy()),
        negative_prompt=negative_prompt,
        guidance_scale=7.5,
        strength=1.0,
    ).images[0]

    delete_model(pipeline)
    return output

def run_open_gen_fill(image, edit_prompt):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize the image to (512, 512)
    image = image.resize((512, 512))

    # Run the language model to extract the objects to be swapped
    # from the edit prompt
    to_replace, replace_with = run_language_model(
        edit_prompt=edit_prompt, device=device
    )

    # Caption the input image
    caption = run_image_captioner(image, device=device)

    # Replace the object in the caption with the new object
    replaced_caption = caption.replace(to_replace, replace_with)

    # Segment the `to_replace` object from the input image
    masks = run_segmentation(image, to_replace, device=device)

    # Diffusion pipeline for inpainting
    output = run_inpainting(
        image=image, replaced_caption=replaced_caption, masks=masks, device=device
    )

    return (
        to_replace,
        replace_with,
        caption,
        replaced_caption,
        Image.fromarray(masks.numpy()),
        output,
    )

def setup_gradio_interface():
    block = gr.Blocks()

    with block:
        gr.Markdown("<h1><center>Open Generative Fill V1</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image_placeholder = gr.Image(type="pil", label="Input Image")
                edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
                run_button_placeholder = gr.Button(value="Run")
            with gr.Column():
                with gr.Row():
                    to_replace_placeholder = gr.Textbox(label="to_replace")
                    replace_with_placeholder = gr.Textbox(label="replace_with")
                image_caption_placeholder = gr.Textbox(label="Image Caption")
                replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
                # object_detection_placeholder = gr.Image(type="pil", label="Object Detection")
                segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
                output_image_placeholder = gr.Image(type="pil", label="Output Image")

        run_button_placeholder.click(
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                replace_with_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                # object_detection_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
        )

        gr.Examples(
            examples=[["dog.jpeg", "replace the dog with a tiger"]],
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                replace_with_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                # object_detection_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            cache_examples=True,
            label="Try this example input!",
        )

    return block

if __name__ == "__main__":
    gradio_interface = setup_gradio_interface()
    # gradio_interface.queue(max_size=10)
    gradio_interface.launch(share=False, show_api=False, show_error=True)