import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
from diffusers import FluxKontextPipeline
MAX_SEED = np.iinfo(np.int32).max

# Load the Kontext pipeline once at startup (bfloat16 to reduce memory)
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load LoRA data from our custom JSON file
with open("kontext_loras.json", "r") as file:
data = json.load(file)
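# Each entry in kontext_loras.json is expected to look roughly like this
# (an assumption based on the keys read below; only "image", "title" and
# "repo" are strictly required, the rest fall back to defaults):
# {
#   "image": "https://huggingface.co/.../preview.png",
#   "title": "Ghibli",
#   "repo": "Kontext-Style/Ghibli_lora",
#   "weights": "pytorch_lora_weights.safetensors",
#   "trigger_word": "",
#   "lora_scale": 1.0,
#   "prompt_placeholder": "Describe the subject..."
# }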
# Add default values for keys that might be missing, to prevent errors
flux_loras_raw = [
{
"image": item["image"],
"title": item["title"],
"repo": item["repo"],
"weights": item.get("weights", "pytorch_lora_weights.safetensors"),
# The following keys are kept for compatibility with the original demo structure,
# but our simplified logic doesn't heavily rely on them.
"trigger_word": item.get("trigger_word", ""),
"lora_type": item.get("lora_type", "flux"),
"lora_scale_config": item.get("lora_scale", 1.0), # Default scale set to 1.0
"prompt_placeholder": item.get("prompt_placeholder", "Describe the subject..."),
}
for item in data
]
print(f"Loaded {len(flux_loras_raw)} LoRAs from kontext_loras.json")
def update_selection(selected_state: gr.SelectData, flux_loras):
"""Update UI when a LoRA is selected"""
if selected_state.index >= len(flux_loras):
return "### No LoRA selected", gr.update(), None, gr.update()
lora_repo = flux_loras[selected_state.index]["repo"]
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
config_placeholder = flux_loras[selected_state.index]["prompt_placeholder"]
optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.0)
print("Selected Style: ", flux_loras[selected_state.index]['title'])
print("Optimal Scale: ", optimal_scale)
return updated_text, gr.update(placeholder=config_placeholder), selected_state.index, optimal_scale
# This wrapper is kept for compatibility with the Gradio event triggers
def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
"""Wrapper function to handle state serialization"""
# The 'custom_lora' and 'lora_state' arguments are no longer used but kept in the signature
return infer_with_lora(input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
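# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call;
# the pipeline was moved to "cuda" at startup and stays resident between calls.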
@spaces.GPU
def infer_with_lora(input_image, prompt, selected_index, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
"""Generate image with selected LoRA"""
global pipe
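    # NOTE: portrait_mode is accepted for UI compatibility but is not
    # currently used anywhere in this function.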
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# Unload any previous LoRA to ensure a clean state
if "selected_lora" in pipe.get_active_adapters():
pipe.unload_lora_weights()
# Determine which LoRA to use from our gallery
lora_to_use = None
if selected_index is not None and flux_loras and selected_index < len(flux_loras):
lora_to_use = flux_loras[selected_index]
if lora_to_use:
print(f"Applying LoRA: {lora_to_use['title']}")
try:
# Load LoRA directly from the Hugging Face Hub
pipe.load_lora_weights(
lora_to_use["repo"],
weight_name=lora_to_use["weights"],
adapter_name="selected_lora"
)
pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
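            # set_adapters activates the named adapter and scales its effect:
            # ~1.0 applies the LoRA at trained strength, higher values exaggerate it.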
print(f"Loaded {lora_to_use['repo']} with scale {lora_scale}")
# Simplified and direct prompt construction
style_name = lora_to_use['title']
if prompt:
final_prompt = f"Turn this image of {prompt} into {style_name} style."
else:
final_prompt = f"Turn this image into {style_name} style."
print(f"Using prompt: {final_prompt}")
except Exception as e:
print(f"Error loading LoRA: {e}")
final_prompt = prompt # Fallback to user prompt if LoRA fails
else:
# No LoRA selected, just use the original prompt
final_prompt = prompt
    # Guard against a missing upload before calling .convert()
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    input_image = input_image.convert("RGB")
try:
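        # A CPU torch.Generator seeded here keeps results reproducible; width and
        # height are taken from the input image so the edit preserves its resolution.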
image = pipe(
image=input_image,
width=input_image.size[0],
height=input_image.size[1],
prompt=final_prompt,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed)
).images[0]
return image, seed, gr.update(visible=True), lora_scale
except Exception as e:
print(f"Error during inference: {e}")
return None, seed, gr.update(visible=False), lora_scale
# CSS styling
css = """
#main_app {
display: flex;
gap: 20px;
}
#box_column {
min-width: 400px;
}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#selected_lora {
color: #2563eb;
font-weight: bold;
}
#prompt {
flex-grow: 1;
}
#run_button {
background: linear-gradient(45deg, #2563eb, #3b82f6);
color: white;
border: none;
padding: 8px 16px;
border-radius: 6px;
font-weight: bold;
}
.custom_lora_card {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 12px;
margin: 8px 0;
}
#gallery{
overflow: scroll !important
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
gr_flux_loras = gr.State(value=flux_loras_raw)
title = gr.HTML(
"""<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA"> Kontext-Style LoRA Explorer</h1>""",
elem_id="title",
)
gr.Markdown("A demo for the style LoRAs from the [Kontext-Style Collection](https://huggingface.co/Kontext-Style) 🤗")
selected_state = gr.State(value=None)
# The following states are no longer used by the simplified logic but kept for component structure
custom_loaded_lora = gr.State(value=None)
lora_state = gr.State(value=1.0)
with gr.Row(elem_id="main_app"):
with gr.Column(scale=4, elem_id="box_column"):
with gr.Group(elem_id="gallery_box"):
input_image = gr.Image(label="Upload a picture of yourself", type="pil", height=300)
                portrait_mode = gr.Checkbox(label="Portrait mode", value=True)
gallery = gr.Gallery(
label="Pick a LoRA",
allow_preview=False,
columns=3,
elem_id="gallery",
show_share_button=False,
height=400
)
custom_model = gr.Textbox(
label="Or enter a custom HuggingFace FLUX LoRA",
placeholder="e.g., username/lora-name",
visible=False
)
custom_model_card = gr.HTML(visible=False)
custom_model_button = gr.Button("Remove custom LoRA", visible=False)
with gr.Column(scale=5):
with gr.Row():
prompt = gr.Textbox(
label="Editing Prompt",
show_label=False,
lines=1,
max_lines=1,
placeholder="opt - describe the person/subject, e.g. 'a man with glasses and a beard'",
elem_id="prompt"
)
run_button = gr.Button("Generate", elem_id="run_button")
result = gr.Image(label="Generated Image", interactive=False)
reuse_button = gr.Button("Reuse this image", visible=False)
with gr.Accordion("Advanced Settings", open=False):
lora_scale = gr.Slider(
label="LoRA Scale",
minimum=0,
maximum=2,
step=0.1,
value=1.5,
info="Controls the strength of the LoRA effect"
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=10,
step=0.1,
value=2.5,
)
prompt_title = gr.Markdown(
value="### Click on a LoRA in the gallery to select it",
visible=True,
elem_id="selected_lora",
)
# Event handlers
# The custom model inputs are no longer needed as we've hidden them.
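    # Gradio passes the gr.SelectData event payload implicitly because the
    # handler's first parameter is annotated with gr.SelectData, so only
    # gr_flux_loras needs to be listed in `inputs`.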
gallery.select(
fn=update_selection,
inputs=[gr_flux_loras],
outputs=[prompt_title, prompt, selected_state, lora_scale],
show_progress=False
)
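    # gr.on binds multiple triggers (button click and Enter in the prompt box)
    # to the same handler.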
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer_with_lora_wrapper,
inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
outputs=[result, seed, reuse_button, lora_state]
)
reuse_button.click(
fn=lambda image: image,
inputs=[result],
outputs=[input_image]
)
    # Initialize gallery; gr.Gallery expects (image, caption) tuples rather
    # than the raw dicts, so unpack them for display.
    demo.load(
        fn=lambda: ([(item["image"], item["title"]) for item in flux_loras_raw], flux_loras_raw),
        outputs=[gallery, gr_flux_loras]
    )
demo.queue(default_concurrency_limit=None)
demo.launch()