| | import gradio as gr |
| | import torch |
| | import spaces |
| | from PIL import Image, ImageDraw, ImageFont |
| | from src.condition import Condition |
| | from diffusers.pipelines import FluxPipeline |
| | import numpy as np |
| |
|
| | from src.generate import seed_everything, generate |
| |
|
# Load the FLUX.1-schnell pipeline in bfloat16 once at import time and move it
# to the CUDA device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Register both subject LoRA checkpoints from the same repository, each under
# its own adapter name.
for _weight_name, _adapter_name in (
    ("omini/subject_512.safetensors", "subject_512"),
    ("omini/subject_1024_beta.safetensors", "subject_1024"),
):
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name=_weight_name,
        adapter_name=_adapter_name,
    )
| |
|
| |
|
@spaces.GPU
def process_image_and_text(image, resolution, text):
    """Generate a subject-driven image from a condition image and a prompt.

    The condition image is center-cropped to a square, resized to 512x512 and
    wrapped in ``Condition("subject", ...)``. The result is passed to
    ``generate`` together with the stripped prompt, using 8 inference steps
    and a square output of ``resolution`` pixels per side.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when the
            user has not supplied one.
        resolution: Output side length in pixels. The UI offers 512 and 1024.
        text: Text prompt; surrounding whitespace is stripped.

    Returns:
        The first image of the ``generate`` result.

    Raises:
        gr.Error: If no condition image was provided.
    """
    if image is None:
        # Without this guard the handler dies with an AttributeError on
        # ``image.size``; gr.Error gives the user a readable message instead.
        raise gr.Error("Please provide a condition image.")

    # Center-crop to the largest square that fits inside the image.
    width, height = image.size
    side = min(width, height)
    square = image.crop(
        (
            (width - side) // 2,
            (height - side) // 2,
            (width + side) // 2,
            (height + side) // 2,
        )
    )
    # The condition image is always 512x512, independent of the output size.
    condition = Condition("subject", square.resize((512, 512)))

    # Two LoRA adapters are registered on the pipeline at import time, one per
    # supported resolution; activate the one that matches the request.
    # NOTE(review): for any other resolution the adapter selection is left as
    # it was — confirm whether such values should be rejected instead.
    adapter_name = {512: "subject_512", 1024: "subject_1024"}.get(resolution)
    if adapter_name is not None:
        pipe.set_adapters(adapter_name)

    result = generate(
        pipe,
        prompt=(text or "").strip(),
        conditions=[condition],
        num_inference_steps=8,
        height=resolution,
        width=resolution,
    )
    return result.images[0]
| |
|
| |
|
def get_samples():
    """Build the example rows offered by the subject-driven interface.

    Returns:
        A list of ``[image, resolution, text]`` rows, one per bundled sample.
        ``image`` is the asset file opened with PIL and resized to 512x512.
    """
    # (asset path, output resolution, prompt)
    samples = (
        (
            "assets/oranges.jpg",
            512,
            "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
        ),
        (
            "assets/penguin.jpg",
            512,
            "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
        ),
        (
            "assets/rc_car.jpg",
            1024,
            "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
        ),
        (
            "assets/clock.jpg",
            1024,
            "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
        ),
    )

    rows = []
    for asset_path, resolution, prompt in samples:
        preview = Image.open(asset_path).resize((512, 512))
        rows.append([preview, resolution, prompt])
    return rows
| |
|
| |
|
# Markdown/HTML shown at the top of the app: a title followed by badge links
# to the paper, the model page and the code repository.
header = """
# 🌍 OminiControl / FLUX

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
| |
|
| |
|
def create_app():
    """Assemble the Gradio Blocks UI and return it without launching it.

    The page shows the header markdown followed by four tabs. The
    "Subject-driven" tab hosts an interface wired to
    ``process_image_and_text``; the remaining tabs only display a
    "Coming soon" placeholder.
    """
    with gr.Blocks() as app:
        gr.Markdown(header)
        with gr.Tabs():
            with gr.Tab("Subject-driven"):
                # Components are created in the same order, and inside the
                # same tab context, as when they are written inline as
                # arguments of gr.Interface.
                condition_input = gr.Image(
                    type="pil", label="Condition Image", width=300
                )
                resolution_input = gr.Radio(
                    [("512", 512), ("1024(beta)", 1024)],
                    label="Resolution",
                    value=512,
                )
                prompt_input = gr.Textbox(lines=2, label="Text Prompt")
                result_output = gr.Image(type="pil")
                gr.Interface(
                    fn=process_image_and_text,
                    inputs=[condition_input, resolution_input, prompt_input],
                    outputs=result_output,
                    examples=get_samples(),
                )
            # Placeholder tabs for tasks that are not available yet.
            for tab_title in ("Fill", "Canny", "Depth"):
                with gr.Tab(tab_title):
                    gr.Markdown("Coming soon")
    return app
| |
|
| |
|
if __name__ == "__main__":
    # Build the UI only when run as a script, then serve it with debug output
    # enabled and server-side rendering disabled.
    demo = create_app()
    demo.launch(debug=True, ssr_mode=False)
| |
|