|
import gradio as gr |
|
import torch |
|
import spaces |
|
import os |
|
import numpy as np |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
from omegaconf import OmegaConf |
|
|
|
from image_datasets.dataset import image_resize |
|
def tensor_to_pil_image(in_image): |
|
tensor = in_image.squeeze(0) |
|
tensor = (tensor + 1) / 2 |
|
tensor = tensor * 255 |
|
numpy_array = tensor.permute(1, 2, 0).byte().numpy() |
|
pil_image = Image.fromarray(numpy_array) |
|
return pil_image |
|
|
|
# Inference settings loaded once at import time. `generate` reads
# sample_width/sample_height, sample_steps, cfg_scale, seed, ip_scale, use_ip
# and use_spatial_condition from this object.
args = OmegaConf.load("inference_configs/inference.yaml")

# Precision and device for the condition-image tensor built in `generate`.
dtype = torch.bfloat16
device = torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
|
def _prepare_source_image(image: Image.Image) -> torch.Tensor:
    """Turn the uploaded condition image into the sampler's source tensor.

    The image goes through ``image_resize(image, 512)`` (project helper from
    ``image_datasets.dataset``). It is then resized so both sides are
    multiples of 32, never below 32, which would otherwise give a zero-sized
    image. Finally it is converted to RGB and rescaled linearly from
    ``[0, 255]`` to ``[-1, 1]``.

    Returns:
        Tensor of shape ``(1, 3, H, W)`` on the module-level ``device`` with
        the module-level ``dtype``.
    """
    resized = image_resize(image, 512)
    w, h = resized.size
    # Snap each side down to a multiple of 32 (presumably required by the
    # model's latent/patch grid — TODO confirm), but keep at least 32 px.
    snapped = resized.resize((max(32, (w // 32) * 32), max(32, (h // 32) * 32)))
    # Force 3 channels so RGBA / grayscale inputs cannot break the
    # (H, W, C) -> (C, H, W) permute below.
    pixels = np.asarray(snapped.convert("RGB"), dtype=np.float32) / 127.5 - 1.0
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    return tensor.to(device, dtype=dtype)


@spaces.GPU
def generate(image: Image.Image, edit_prompt: str):
    """Apply ``edit_prompt`` to ``image`` with the ByteMorph FLUX sampler.

    This is the click handler wired to the "Run" button in ``create_app``.

    Args:
        image: Condition image from the UI; ``None`` when nothing was uploaded.
        edit_prompt: Text instruction describing the desired edit.

    Returns:
        The generated image as a PIL image (via ``tensor_to_pil_image``).

    Raises:
        gr.Error: If no image was uploaded or the prompt is empty. The message
            is shown in the UI instead of a traceback from the preprocessing.
    """
    if image is None:
        raise gr.Error("Please upload a condition image.")
    if not edit_prompt or not edit_prompt.strip():
        raise gr.Error("Please enter an edit prompt.")

    # Cheap preprocessing first, so bad inputs fail before the sampler is built.
    img = _prepare_source_image(image)

    # Imported lazily inside the GPU-decorated call, as in the original
    # (presumably to defer loading the model code — TODO confirm).
    from src.flux.xflux_pipeline import XFluxSampler

    # NOTE(review): a new sampler is constructed on every request, and it is
    # built with spatial_condition=False even when args.use_spatial_condition
    # passes a source image below. Confirm both are intended.
    sampler = XFluxSampler(
        device=device,
        ip_loaded=False,
        spatial_condition=False,
        clip_image_processor=None,
        image_encoder=None,
        improj=None,
    )

    # NOTE(review): the output size comes from the config, not from the size
    # of `img`. Confirm non-square inputs are handled as intended.
    result = sampler(
        prompt=edit_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        image_prompt=None,
        true_gs=args.cfg_scale,
        seed=args.seed,
        ip_scale=args.ip_scale if args.use_ip else 1.0,
        source_image=img if args.use_spatial_condition else None,
    )

    return tensor_to_pil_image(result)
|
|
|
def get_samples():
    """Return the ``gr.Examples`` rows as ``[image, edit_prompt]`` pairs.

    Each example image is read from a path relative to the working directory
    and resized to 512x512.

    Returns:
        A list of ``[PIL.Image.Image, str]`` rows, one per sample.
    """
    sample_list = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]

    rows = []
    for sample in sample_list:
        # `Image.open` keeps the file handle open until the image is closed.
        # `resize` returns a new, fully loaded image, so the source file can
        # be closed as soon as the resized copy exists.
        with Image.open(sample["image"]) as source:
            rows.append([source.resize((512, 512)), sample["edit_prompt"]])
    return rows
|
|
|
# Markdown/HTML shown at the top of the demo (rendered with gr.Markdown).
# TODO: the arXiv badge's href is empty; fill in the paper URL once available.
header = """
# ByteMorph

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href=""><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Dataset-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
|
|
|
def create_app():
    """Assemble the Gradio Blocks UI for the ByteMorph editing demo.

    The page has:
    - the header;
    - an input panel (condition image, edit prompt, Run button) beside an
      output panel;
    - an examples row;
    - a footer crediting the OminiControl template.

    Pressing Run calls ``generate`` with the image and prompt.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """
    with gr.Blocks() as demo:
        gr.Markdown(header, elem_id="header")

        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                condition_image = gr.Image(
                    type="pil",
                    label="Condition Image",
                    width=300,
                    elem_id="input",
                )
                prompt_box = gr.Textbox(
                    lines=2,
                    label="Edit Prompt",
                    elem_id="edit_prompt",
                )
                run_button = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                result_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            # No `fn` is given, so picking an example just fills the inputs.
            gr.Examples(
                examples=get_samples(),
                inputs=[condition_image, prompt_box],
                label="Examples",
            )

        run_button.click(
            fn=generate,
            inputs=[condition_image, prompt_box],
            outputs=result_image,
        )

        gr.HTML(
            """
<div style="text-align: center;">
* This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
</div>
"""
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Launch locally: no public share link and server-side rendering disabled.
    demo_app = create_app()
    demo_app.launch(debug=False, share=False, ssr_mode=False)
|
|