import spaces
import os
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files, login
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
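# Optional Hugging Face login (an HF_TOKEN is needed to pull the gated FLUX.1-dev checkpoint).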
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
print("downloading OmniConsistency base LoRA …")
omni_consistency_path = hf_hub_download(
repo_id="showlab/OmniConsistency",
filename="OmniConsistency.safetensors",
local_dir="./Model"
)
print("loading base pipeline …")
pipe = FluxPipeline.from_pretrained(
    BASE_PATH, torch_dtype=torch.bfloat16
).to("cuda")
set_single_lora(pipe.transformer, omni_consistency_path,
                lora_weights=[1], cond_size=512)
lora_names = [
"3D_Chibi", "American_Cartoon", "Macaron",
"Pixel", "Poly", "Van_Gogh"
]
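# Pre-download all built-in style LoRAs at startup so the first generation
# request does not block on network I/O.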
def download_all_loras():
    for name in lora_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{name}_rank128_bf16.safetensors",
            local_dir=LOCAL_LORA_DIR,
        )
download_all_loras()
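# Register every built-in LoRA as a named adapter. hf_hub_download preserves the
# repo's "LoRAs/" prefix under local_dir, so the files live in ./LoRAs/LoRAs/.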
def reload_all_loras():
    pipe.unload_lora_weights()
    for name in lora_names:
        pipe.load_lora_weights(
            f"{LOCAL_LORA_DIR}/LoRAs",
            weight_name=f"{name}_rank128_bf16.safetensors",
            adapter_name=name,
        )
reload_all_loras()
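# The OmniConsistency attention processors cache key/value banks during a run;
# clear them after each generation so state does not leak between requests.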
def clear_cache(transformer):
    for _, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()
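# On ZeroGPU Spaces, @spaces.GPU reserves a GPU for the call, here for up to 30 seconds.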
@spaces.GPU(duration=30)
def generate_image(
    lora_name,
    prompt,
    uploaded_image,
    guidance_scale,
    num_inference_steps,
    seed,
):
    # Rescale the input so its longest side is 1024 px, preserving aspect ratio.
    width, height = uploaded_image.size
    max_size = 1024
    factor = max_size / max(width, height)
    width = int(width * factor)
    height = int(height * factor)

    generator = torch.Generator("cpu").manual_seed(seed)
    pipe.set_adapters(lora_name)

    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start = time.time()
    out_img = pipe(
        prompt,
        height=(height // 8) * 8,
        width=(width // 8) * 8,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=generator,
        spatial_images=spatial_image,
        subject_images=subject_images,
        cond_size=512,
    ).images[0]
    print(f"inference time: {time.time() - start:.2f}s")

    clear_cache(pipe.transformer)
    return uploaded_image, out_img
# =============== Gradio UI ===============
def create_interface():
    # When a different LoRA is selected, strip any existing style trigger phrase
    # from the prompt and prepend the trigger of the newly selected style.
    def update_trigger_word(lora_name, prompt):
        for name in lora_names:
            trigger = " ".join(name.split("_")) + " style,"
            prompt = prompt.replace(trigger, "")
        new_trigger = " ".join(lora_name.split("_")) + " style,"
        return new_trigger + prompt
header = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/ariXv-2505.18445-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/showlab/OmniConsistency"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        gr.HTML(header)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value="3D Chibi style,",
                    info="Remember to include the necessary trigger words if you're using a custom LoRA.",
                )
                lora_dropdown = gr.Dropdown(
                    lora_names, label="Select built-in LoRA")
                gen_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")

        # Height/Width boxes are informational only; the output size is derived
        # from the uploaded image inside generate_image.
        with gr.Accordion("Advanced Options", open=False):
            height_box = gr.Textbox(value="1024", label="Height")
            width_box = gr.Textbox(value="1024", label="Width")
            guidance_slider = gr.Slider(
                0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
            steps_slider = gr.Slider(
                1, 50, value=25, step=1, label="Inference Steps")
            seed_slider = gr.Slider(
                1, 2_147_483_647, value=42, step=1, label="Seed")

        lora_dropdown.select(
            fn=update_trigger_word,
            inputs=[lora_dropdown, prompt_box],
            outputs=prompt_box,
        )
        gen_btn.click(
            fn=generate_image,
            inputs=[lora_dropdown, prompt_box, image_input,
                    guidance_slider, steps_slider, seed_slider],
            outputs=output_image,
        )

    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(ssr_mode=False)