import spaces
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    Owlv2ForObjectDetection,
    Owlv2Processor,
    SamModel,
    SamProcessor,
)
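
# Pipeline overview: a small chat model parses the edit prompt into a
# (to_replace, replace_with) pair, BLIP captions the input image, OWLv2 + SAM
# localize and segment the object to be replaced, and an SDXL inpainting
# pipeline repaints the masked region using the edited caption as the prompt.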


def delete_model(model):
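    """Move a model to the CPU, drop the reference, and clear the CUDA cache."""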
    model.to("cpu")
    del model
    torch.cuda.empty_cache()


def run_language_model(edit_prompt, device):
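    """Split the edit prompt into the object to replace and the object to
    replace it with, using a small chat model prompted with few-shot examples."""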
    language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
    language_model = AutoModelForCausalLM.from_pretrained(
        language_model_id, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(language_model_id)

    # Few-shot examples teach the model to answer with "<to_replace>, <replace_with>".
    messages = [
        {"role": "system", "content": "Follow the examples and return the expected output"},
        {"role": "user", "content": "swap mountain and lion"},  # example 1
        {"role": "assistant", "content": "mountain, lion"},  # example 1
        {"role": "user", "content": "change the dog with cat"},  # example 2
        {"role": "assistant", "content": "dog, cat"},  # example 2
        {"role": "user", "content": "change the cat with a dog"},  # example 3
        {"role": "assistant", "content": "cat, dog"},  # example 3
        {"role": "user", "content": "replace the human with a boat"},  # example 4
        {"role": "assistant", "content": "human, boat"},  # example 4
        {"role": "user", "content": "in the above example change the background to the alps"},  # example 5
        {"role": "assistant", "content": "background, alps"},  # example 5
        {"role": "user", "content": "edit the house into a mansion"},  # example 6
        {"role": "assistant", "content": "house, mansion"},  # example 6
        {"role": "user", "content": edit_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    with torch.no_grad():
        # Greedy decoding keeps the extraction deterministic.
        generated_ids = language_model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            do_sample=False,
        )
    # Keep only the newly generated tokens, not the echoed prompt.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    to_replace, replace_with = [part.strip() for part in response.split(",", 1)]

    delete_model(language_model)
    return (to_replace, replace_with)


def run_image_captioner(image, device):
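    """Caption the input image with BLIP; the caption is later rewritten to
    describe the edited scene and used as the inpainting prompt."""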
    caption_model_id = "Salesforce/blip-image-captioning-base"
    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
        device
    )
    caption_processor = BlipProcessor.from_pretrained(caption_model_id)
    inputs = caption_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = caption_model.generate(**inputs, max_new_tokens=200)
    caption = caption_processor.decode(outputs[0], skip_special_tokens=True)

    delete_model(caption_model)
    return caption


def run_segmentation(image, object_to_segment, device):
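    """Detect `object_to_segment` with OWLv2, then turn the detected boxes into
    a single binary mask with SAM."""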
    # OWLv2 for open-vocabulary object detection
    owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
    processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
    od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)
    text_queries = [object_to_segment]
    inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = od_model(**inputs)
    # The post-processor expects (height, width); PIL's `size` is (width, height).
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    results = processor.post_process_object_detection(
        outputs, threshold=0.1, target_sizes=target_sizes
    )[0]
    boxes = results["boxes"].tolist()

    delete_model(od_model)

    # SAM for segmentation, prompted with the detected boxes
    sam_model_id = "facebook/sam-vit-base"
    seg_model = SamModel.from_pretrained(sam_model_id).to(device)
    processor = SamProcessor.from_pretrained(sam_model_id)
    input_boxes = [boxes]
    inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]
    # Take the first predicted mask per box and merge all boxes into one mask.
    masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values

    delete_model(seg_model)
    return masks


def run_inpainting(image, replaced_caption, masks, device):
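    """Inpaint the masked region with SDXL inpainting, prompting it with the
    caption in which the original object has been swapped for the new one."""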
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    prompt = replaced_caption
    negative_prompt = """lowres, bad anatomy, bad hands,
    text, error, missing fingers, extra digit, fewer digits,
    cropped, worst quality, low quality"""

    # Convert the boolean mask to an 8-bit grayscale image (white = inpaint).
    mask_image = Image.fromarray((masks.numpy() * 255).astype("uint8"))
    output = pipeline(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        negative_prompt=negative_prompt,
        guidance_scale=7.5,
        strength=1.0,
    ).images[0]

    delete_model(pipeline)
    return output
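

# Note: `spaces.GPU` below is the ZeroGPU decorator implied by `import spaces`;
# it requests GPU hardware for the duration of the call and is a no-op when the
# app runs on regular hardware.
@spaces.GPU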
def run_open_gen_fill(image, edit_prompt):
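    """Run the full editing pipeline: parse the prompt, caption the image,
    segment the object to replace, and inpaint it with the new object."""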
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize the image to (512, 512)
    image = image.resize((512, 512))

    # Run the language model to extract the objects to be swapped from
    # the edit prompt
    to_replace, replace_with = run_language_model(
        edit_prompt=edit_prompt, device=device
    )

    # Caption the input image
    caption = run_image_captioner(image, device=device)

    # Replace the object in the caption with the new object
    replaced_caption = caption.replace(to_replace, replace_with)

    # Segment the `to_replace` object from the input image
    masks = run_segmentation(image, to_replace, device=device)

    # Diffusion pipeline for inpainting
    output = run_inpainting(
        image=image, replaced_caption=replaced_caption, masks=masks, device=device
    )

    return (
        to_replace,
        replace_with,
        caption,
        replaced_caption,
        Image.fromarray((masks.numpy() * 255).astype("uint8")),
        output,
    )


def setup_gradio_interface():
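    """Build the Gradio Blocks UI: an input image and edit prompt on the left,
    the intermediate outputs and the edited image on the right."""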
    block = gr.Blocks()

    with block:
        gr.Markdown("<h1><center>Open Generative Fill V1</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image_placeholder = gr.Image(type="pil", label="Input Image")
                edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
                run_button_placeholder = gr.Button(value="Run")
            with gr.Column():
                with gr.Row():
                    to_replace_placeholder = gr.Textbox(label="to_replace")
                    replace_with_placeholder = gr.Textbox(label="replace_with")
                image_caption_placeholder = gr.Textbox(label="Image Caption")
                replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
                # object_detection_placeholder = gr.Image(type="pil", label="Object Detection")
                segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
                output_image_placeholder = gr.Image(type="pil", label="Output Image")

        run_button_placeholder.click(
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                replace_with_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                # object_detection_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
        )

        gr.Examples(
            examples=[["dog.jpeg", "replace the dog with a tiger"]],
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                replace_with_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                # object_detection_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            cache_examples=True,
            label="Try this example input!",
        )

    return block


if __name__ == "__main__":
    gradio_interface = setup_gradio_interface()
    # gradio_interface.queue(max_size=10)
    gradio_interface.launch(share=False, show_api=False, show_error=True)