import os
import random
import requests
import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
config = {
    "model_id": "black-forest-labs/FLUX.1-dev",
    "default_lora": "nftnik/BR_ohwx_V1",
    "default_weight_name": "BR_ohwx.safetensors",
    "max_seed": int(np.iinfo(np.int32).max),
    "css": "footer { visibility: hidden; }",
    "default_width": 896,
    "default_height": 1152,
    "default_guidance_scale": 3.5,
    "default_steps": 35,
    "default_loRa_scale": 1.0,
}
# -----------------------------------------------------------------------------
# Environment and device setup
# -----------------------------------------------------------------------------
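# Opt in to faster Hub downloads via the optional hf_transfer backend.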
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
translator = Translator()
HF_TOKEN = os.environ.get("HF_TOKEN", None)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device.upper()}")
# -----------------------------------------------------------------------------
# Initialize the Flux pipeline and load default LoRA
# -----------------------------------------------------------------------------
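# FLUX.1-dev is a gated repository, so the HF token (if one is set) is passed when downloading.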
pipe = FluxPipeline.from_pretrained(
    config["model_id"], torch_dtype=torch.bfloat16, token=HF_TOKEN
).to(device)
pipe.load_lora_weights(config["default_lora"], weight_name=config["default_weight_name"])
# -----------------------------------------------------------------------------
# Function to load a new LoRA model
# -----------------------------------------------------------------------------
def enable_lora(lora_add: str):
    """Swap the active LoRA for the repository named in the textbox."""
    pipe.unload_lora_weights()
    if not lora_add:
        return gr.update(value="")
    # Quick sanity check that the repository exists on the Hub before loading it.
    url = f"https://huggingface.co/{lora_add}/tree/main"
    if requests.get(url, timeout=10).status_code == 404:
        raise gr.Error(f"LoRA repository not found (or not public): {lora_add}")
    try:
        pipe.load_lora_weights(lora_add)
        return gr.update(label="LoRA Loaded Now")
    except Exception as e:
        raise gr.Error(f"Failed to load {lora_add}: {e}")
# -----------------------------------------------------------------------------
# Function to generate an image from a prompt
# -----------------------------------------------------------------------------
@spaces.GPU()
def generate_image(
    prompt: str, lora_word: str, lora_scale: float = config["default_loRa_scale"],
    width: int = config["default_width"], height: int = config["default_height"],
    guidance_scale: float = config["default_guidance_scale"], steps: int = config["default_steps"],
    seed: int = -1, nums: int = 1
):
    """Translate the prompt to English, append the LoRA trigger word, and run the Flux pipeline."""
    pipe.to(device)
    # A seed of -1 means "random": draw one so it can be reported back to the UI.
    seed = random.randint(0, config["max_seed"]) if seed == -1 else int(seed)
    prompt_english = str(translator.translate(prompt, "English"))
    full_prompt = f"{prompt_english} {lora_word}"
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=full_prompt, height=height, width=width, guidance_scale=guidance_scale,
        output_type="pil", num_inference_steps=steps, num_images_per_prompt=nums,
        generator=generator, joint_attention_kwargs={"scale": lora_scale},
    )
    return result.images, seed
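# A minimal sketch of calling generate_image directly (e.g. for local testing without the UI);
# the prompt, seed, and output filename below are illustrative only, not values used by the Space:
#   images, used_seed = generate_image(
#       prompt="ohwx blue alien in a neon-lit alley",
#       lora_word="ohwx",
#       lora_scale=0.9,
#       seed=42,
#       nums=1,
#   )
#   images[0].save("avatar.png")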
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
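# Example rows pre-fill the prompt box, the LoRA trigger word, and the LoRA scale.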
example_prompts = [
    ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
    ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
    ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
    ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
]
with gr.Blocks(css=config["css"]) as demo:
    gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
    processing_status = gr.Markdown("**🟒 Ready**", visible=True)  # Status indicator
    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
            prompt_input = gr.Textbox(label="Enter Your Prompt", lines=2, placeholder="Enter prompt...")
            generate_btn = gr.Button(value="Generate", variant="primary")
            with gr.Accordion("Advanced Options", open=True):
                width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=config["default_width"])
                height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=config["default_height"])
                guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=config["default_guidance_scale"])
                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=config["default_steps"])
                seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=config["max_seed"], step=1, value=-1)
                nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=2, step=1, value=1)
                lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=config["default_loRa_scale"])
                lora_add_text = gr.Textbox(label="Flux LoRA", lines=1, value=config["default_lora"])
                lora_word_text = gr.Textbox(label="Flux LoRA Trigger Word", lines=1, value="ohwx")
                load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
    gr.Examples(examples=example_prompts, inputs=[prompt_input, lora_word_text, lora_scale_slider], cache_examples=False, examples_per_page=4)
    # Flip the status to "Processing" on click, run generation, then mark it done.
    def update_status():
        return "**⏳ Processing...**"
    generate_btn.click(fn=update_status, inputs=[], outputs=[processing_status]).then(
        fn=generate_image,
        inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
        outputs=[gallery, seed_slider]
    ).then(
        fn=lambda: "**βœ… Done!**",
        inputs=[],
        outputs=[processing_status]
    )
    load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
demo.queue().launch()