# Scraped page header (not part of the program): Spaces - Runtime error
import os | |
import random | |
import re | |
import requests | |
import torch | |
import numpy as np | |
import gradio as gr | |
import spaces | |
from diffusers import FluxPipeline | |
from translatepy import Translator | |
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Largest signed 32-bit integer; used as the upper bound for random seeds.
_INT32_MAX = int(np.iinfo(np.int32).max)

# Central settings read by the pipeline setup, the image generator, and the UI.
config = {
    # Base model and the LoRA adapter loaded at start-up.
    "model_id": "black-forest-labs/FLUX.1-dev",
    "default_lora": "nftnik/BR_ohwx_V1",
    "default_weight_name": "BR_ohwx.safetensors",
    # Seed range and page styling.
    "max_seed": _INT32_MAX,
    "css": "footer { visibility: hidden; }",
    # Initial values for the generation controls.
    "default_width": 896,
    "default_height": 1152,
    "default_guidance_scale": 3.5,
    "default_steps": 35,
    "default_loRa_scale": 1.0,
}
# -----------------------------------------------------------------------------
# Environment and device setup
# -----------------------------------------------------------------------------
# Ask the Hugging Face hub client to use the hf_transfer download backend.
# NOTE(review): diffusers is already imported by the time this runs; if the hub
# client reads this flag at import time the assignment comes too late - confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Single translator instance reused for every prompt.
translator = Translator()

# Access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device.upper()}")
# -----------------------------------------------------------------------------
# Initialize the Flux pipeline and load default LoRA
# -----------------------------------------------------------------------------
# Build the FLUX pipeline in bfloat16, then move it to the selected device.
pipe = FluxPipeline.from_pretrained(
    config["model_id"],
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(device)

# Attach the default LoRA weights to the freshly created pipeline.
pipe.load_lora_weights(
    config["default_lora"],
    weight_name=config["default_weight_name"],
)
# -----------------------------------------------------------------------------
# Function to load a new LoRA model
# -----------------------------------------------------------------------------
def enable_lora(lora_add: str):
    """Replace the pipeline's LoRA weights with those named by ``lora_add``.

    The LoRA weights currently attached to the module-level ``pipe`` are
    always unloaded first, so an empty value simply leaves the pipeline
    without a LoRA.

    Args:
        lora_add: Identifier passed to ``pipe.load_lora_weights``. Surrounding
            whitespace is ignored; an empty or missing value only unloads.

    Returns:
        A ``gr.update`` for the LoRA textbox: the value cleared when nothing
        was requested, or the label set to "LoRA Loaded Now" on success.

    Raises:
        gr.Error: If loading the weights fails. The original exception is
            chained as the cause.
    """
    pipe.unload_lora_weights()

    # Tolerate None and stray whitespace from a pasted identifier.
    lora_add = (lora_add or "").strip()
    if not lora_add:
        return gr.update(value="")

    try:
        pipe.load_lora_weights(lora_add)
    except Exception as e:
        # UI boundary: surface any loader failure as a Gradio error while
        # keeping the underlying exception attached for debugging.
        raise gr.Error(f"Failed to load {lora_add}: {e}") from e
    return gr.update(label="LoRA Loaded Now")
# -----------------------------------------------------------------------------
# Function to generate an image from a prompt
# -----------------------------------------------------------------------------
def generate_image(
    prompt: str,
    lora_word: str,
    lora_scale: float = config["default_loRa_scale"],
    width: int = config["default_width"],
    height: int = config["default_height"],
    guidance_scale: float = config["default_guidance_scale"],
    steps: int = config["default_steps"],
    seed: int = -1,
    nums: int = 1,
):
    """Generate images for ``prompt`` with the module-level Flux pipeline.

    The prompt is translated to English, the LoRA trigger word is appended,
    and the pipeline is run with a generator seeded from ``seed``.

    Args:
        prompt: Text prompt in any language the translator handles.
        lora_word: Trigger word appended to the translated prompt.
        lora_scale: Passed to the pipeline as ``joint_attention_kwargs["scale"]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps.
        seed: Seed for the generator; ``-1`` picks a random seed in
            ``[0, config["max_seed"]]``.
        nums: Number of images to produce for the prompt.

    Returns:
        A ``(images, seed)`` tuple: the pipeline's ``images`` result (requested
        with ``output_type="pil"``) and the seed that was actually used.
    """
    pipe.to(device)

    seed = random.randint(0, config["max_seed"]) if seed == -1 else int(seed)

    # Translation is a best-effort preprocessing step: if the translation
    # service fails, generate from the prompt as typed instead of aborting.
    try:
        prompt_english = str(translator.translate(prompt, "English"))
    except Exception as exc:
        print(f"Prompt translation failed, using the original prompt: {exc}")
        prompt_english = prompt

    full_prompt = f"{prompt_english} {lora_word}"
    generator = torch.Generator().manual_seed(seed)

    # Integer-valued settings are coerced like the seed above, so values that
    # arrive from the UI as floats (e.g. 896.0) are still accepted.
    result = pipe(
        prompt=full_prompt,
        height=int(height),
        width=int(width),
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=int(steps),
        num_images_per_prompt=int(nums),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    )
    return result.images, seed
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# All examples share one trigger word and one LoRA scale; only the prompt text
# varies, so the rows are assembled from the list of texts.
_EXAMPLE_TRIGGER_WORD = "ohwx"
_EXAMPLE_LORA_SCALE = 0.9
_EXAMPLE_PROMPT_TEXTS = (
    "Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.",
    "ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.",
    "full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.",
    "ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.",
)

# Each row is [prompt, trigger word, LoRA scale].
example_prompts = [
    [text, _EXAMPLE_TRIGGER_WORD, _EXAMPLE_LORA_SCALE]
    for text in _EXAMPLE_PROMPT_TEXTS
]
# NOTE(review): the nesting below is reconstructed - the original indentation
# was not available - so confirm the layout against the intended design.
with gr.Blocks(css=config["css"]) as demo:
    gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
    # Status indicator updated before and after each generation run.
    processing_status = gr.Markdown("**🟢 Ready**", visible=True)

    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(
                label="Flux Generated Image", columns=1, preview=True, height=600
            )
            prompt_input = gr.Textbox(
                label="Enter Your Prompt", lines=2, placeholder="Enter prompt..."
            )
            generate_btn = gr.Button(variant="primary")

            with gr.Accordion("Advanced Options", open=True):
                width_slider = gr.Slider(
                    label="Width", minimum=512, maximum=1920, step=8,
                    value=config["default_width"],
                )
                height_slider = gr.Slider(
                    label="Height", minimum=512, maximum=1920, step=8,
                    value=config["default_height"],
                )
                guidance_slider = gr.Slider(
                    label="Guidance Scale", minimum=3.5, maximum=7, step=0.1,
                    value=config["default_guidance_scale"],
                )
                steps_slider = gr.Slider(
                    label="Steps", minimum=1, maximum=100, step=1,
                    value=config["default_steps"],
                )
                # -1 asks generate_image to pick a random seed.
                seed_slider = gr.Slider(
                    label="Seed", minimum=-1, maximum=config["max_seed"], step=1,
                    value=-1,
                )
                nums_slider = gr.Slider(
                    label="Image Count", minimum=1, maximum=2, step=1, value=1
                )
                lora_scale_slider = gr.Slider(
                    label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1,
                    value=config["default_loRa_scale"],
                )
                lora_add_text = gr.Textbox(
                    label="Flux LoRA", lines=1, value=config["default_lora"]
                )
                lora_word_text = gr.Textbox(
                    label="Flux LoRA Trigger Word", lines=1, value="ohwx"
                )
                load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")

    # Each example row fills the prompt, trigger word, and LoRA scale inputs.
    gr.Examples(
        examples=example_prompts,
        inputs=[prompt_input, lora_word_text, lora_scale_slider],
        cache_examples=False,
        examples_per_page=4,
    )

    def update_status():
        """Return the status text shown while an image is being generated."""
        return "**⏳ Processing...**"

    # Three chained steps: show the processing status, run the generation
    # (which also writes the used seed back to the seed slider), then mark done.
    generate_btn.click(
        fn=update_status, inputs=[], outputs=[processing_status]
    ).then(
        fn=generate_image,
        inputs=[
            prompt_input,
            lora_word_text,
            lora_scale_slider,
            width_slider,
            height_slider,
            guidance_slider,
            steps_slider,
            seed_slider,
            nums_slider,
        ],
        outputs=[gallery, seed_slider],
    ).then(
        fn=lambda: "**✅ Done!**",
        inputs=[],
        outputs=[processing_status],
    )

    # enable_lora returns an update for the same textbox it reads from.
    load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)

demo.queue().launch()