import random
import torch
import gradio as gr
from diffusers import (
    FluxPipeline,
    DPMSolverMultistepScheduler,
    DPMSolverSDEScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    DDIMScheduler,
)

# -----------------------------------------------------------------------------
# Pipeline loading helpers
# -----------------------------------------------------------------------------

MODEL_ID = "black-forest-labs/FLUX.1-dev"
LORA_REPO = "kudzueye/boreal-flux-dev-v2"
LORA_WEIGHT_NAME = "boreal-v2.safetensors"
LORA_SCALE = 0.8


def _default_dtype() -> torch.dtype:
    """Return bfloat16 when the CUDA device supports it, otherwise float16.

    FLUX is published and documented for bfloat16; float16 (the previous
    hard-coded choice) is kept as the fallback for GPUs without bf16 support.
    """
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def _load_pipe(
    hf_token: str | None = None,
    torch_dtype: torch.dtype | None = None,
) -> FluxPipeline:
    """Load the FLUX pipeline, fuse the Boreal LoRA and enable memory savers.

    Args:
        hf_token: Optional Hugging Face token if the model is gated/private.
            Empty strings are treated as "no token".
        torch_dtype: Weight dtype. Defaults to bfloat16 when supported by the
            current CUDA device, otherwise float16.

    Returns:
        A fully-initialised FluxPipeline with the LoRA fused at
        ``LORA_SCALE`` and sequential CPU offload enabled.
    """
    if torch_dtype is None:
        torch_dtype = _default_dtype()

    pipe = FluxPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        # ``token`` is the current keyword; the deprecated ``use_auth_token``
        # may be dropped as an unexpected kwarg, so a gated repo would fail.
        token=hf_token or None,
    )

    # LoRA --------------------------------------------------------------------
    # Load and fuse before enabling sequential offload, so fusing operates on
    # ordinary in-memory weights rather than on offload-hooked modules.
    pipe.load_lora_weights(LORA_REPO, weight_name=LORA_WEIGHT_NAME)
    pipe.fuse_lora(lora_scale=LORA_SCALE)

    # Memory optimisations ----------------------------------------------------
    pipe.enable_sequential_cpu_offload()
    # NOTE(review): attention slicing may be a no-op for FLUX's transformer —
    # confirm against the installed diffusers version.
    pipe.enable_attention_slicing()
    return pipe


# Keep a single global instance to avoid re-loading on every request.
_pipe: FluxPipeline | None = None


def _get_pipe(hf_token: str | None = None) -> FluxPipeline:
    """Return the cached pipeline, loading it on first use.

    ``hf_token`` only matters for the call that actually loads the model;
    once a pipeline is cached, later calls return it and ignore the token.
    If loading raises, nothing is cached and the next call retries.
    """
    global _pipe
    if _pipe is None:
        _pipe = _load_pipe(hf_token)
    return _pipe


# -----------------------------------------------------------------------------
# Scheduler mapping
# -----------------------------------------------------------------------------

# UI label -> scheduler class; the key order defines the radio-button order.
# NOTE(review): the "Karras" entries are plain classes, so Karras sigmas are
# not actually enabled. FLUX ships a flow-matching scheduler, and these
# classes may not be able to drive it — callers should check compatibility
# before swapping.
SCHED_MAP = {
    "DPM++ 2M Karras": DPMSolverMultistepScheduler,
    "DPM++ SDE Karras": DPMSolverSDEScheduler,
    "Euler": EulerDiscreteScheduler,
    "Euler a": EulerAncestralDiscreteScheduler,
    "Heun": HeunDiscreteScheduler,
    "DDIM": DDIMScheduler,
}
# -----------------------------------------------------------------------------
# Inference function
# -----------------------------------------------------------------------------


def _scheduler_fits_flux(scheduler_cls: type) -> bool:
    """Return True if ``scheduler_cls.set_timesteps`` takes FLUX's arguments.

    In current diffusers releases FluxPipeline calls ``set_timesteps`` with a
    custom ``sigmas`` schedule and a resolution-dependent shift ``mu``. A
    scheduler without a named ``sigmas`` parameter, or without either ``mu``
    or ``**kwargs``, makes the pipeline call fail.
    """
    import inspect

    params = inspect.signature(scheduler_cls.set_timesteps).parameters
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    return "sigmas" in params and ("mu" in params or takes_var_kwargs)


def _call_accepts(pipe: FluxPipeline, name: str) -> bool:
    """Return True if ``pipe.__call__`` declares a parameter called ``name``."""
    import inspect

    return name in inspect.signature(pipe.__call__).parameters


def query(
    prompt: str,
    negative_prompt: str,
    steps: int,
    cfg_scale: float,
    sampler: str,
    seed: int,
    strength: float,  # kept for future img2img support (currently unused)
    hf_token: str,
):
    """Run the generation and return a PIL image + the seed actually used.

    Args:
        prompt: Text prompt.
        negative_prompt: Passed through only when the installed
            ``FluxPipeline.__call__`` accepts ``negative_prompt``.
        steps: Number of inference steps (cast to int).
        cfg_scale: Forwarded as ``guidance_scale`` (cast to float).
        sampler: Key into ``SCHED_MAP``. The scheduler is swapped only if the
            class is compatible with FLUX; otherwise the pipeline keeps its
            current scheduler and a UI warning is shown.
        seed: Any negative value picks a random seed in [0, 1e9].
        strength: Unused; reserved for img2img.
        hf_token: Hugging Face token; only used when the model is first loaded.

    Returns:
        ``(image, seed_as_string)``.
    """
    pipe = _get_pipe(hf_token or None)

    # Replace the scheduler only if the selected class can actually drive FLUX.
    scheduler_cls = SCHED_MAP.get(sampler, DPMSolverMultistepScheduler)
    if not isinstance(pipe.scheduler, scheduler_cls):
        if _scheduler_fits_flux(scheduler_cls):
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
        else:
            gr.Warning(
                f"Sampler '{sampler}' is not compatible with FLUX; "
                f"using {type(pipe.scheduler).__name__} instead."
            )

    # Gradio sliders may deliver floats; manual_seed needs an int, and the
    # returned seed string should not read "123.0".
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 1_000_000_000)
    # A CPU generator is valid regardless of where offloaded weights live
    # (pipe.device may report an offload/meta device under sequential offload).
    generator = torch.Generator(device="cpu").manual_seed(seed)

    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": int(steps),
        "guidance_scale": float(cfg_scale),
        "generator": generator,
        "height": 512,
        "width": 512,
    }
    # Older FluxPipeline versions reject ``negative_prompt`` with a TypeError.
    if _call_accepts(pipe, "negative_prompt"):
        call_kwargs["negative_prompt"] = negative_prompt
    elif negative_prompt:
        gr.Warning(
            "This FluxPipeline version does not support negative prompts; "
            "ignoring it."
        )

    with torch.no_grad():
        result = pipe(**call_kwargs)

    return result.images[0], str(seed)


# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------

CSS = """
#app-container { max-width: 600px; margin-left: auto; margin-right: auto; }
#title-container { display: flex; align-items: center; justify-content: center; }
#title-icon { width: 32px; height: auto; margin-right: 10px; }
#title-text { font-size: 24px; font-weight: bold; }
"""

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=CSS) as app:
    gr.HTML(
        """

Text-to-Image Generator App

"""
    )
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    txt_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
        with gr.Row():
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="What should not be in the image",
                    value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
                    lines=3,
                    elem_id="negative-prompt-text-input",
                )
                steps_in = gr.Slider(
                    label="Sampling steps", value=35, minimum=1, maximum=100, step=1
                )
                cfg_in = gr.Slider(
                    label="CFG Scale", value=7, minimum=1, maximum=20, step=1
                )
                sampler_in = gr.Radio(
                    label="Sampling method",
                    value="DPM++ 2M Karras",
                    choices=list(SCHED_MAP.keys()),
                )
                strength_in = gr.Slider(
                    label="Strength", value=0.7, minimum=0, maximum=1, step=0.001
                )
                seed_in = gr.Slider(
                    label="Seed", value=-1, minimum=-1, maximum=1_000_000_000, step=1
                )
                api_key_in = gr.Textbox(
                    label="Hugging Face API Key (required for private models)",
                    placeholder="Enter your Hugging Face API Key here",
                    type="password",
                    elem_id="api-key",
                )
        with gr.Row():
            run_button = gr.Button("Run", variant="primary", elem_id="gen-button")
        with gr.Row():
            img_out = gr.Image(type="pil", label="Image Output", elem_id="gallery")
        seed_out = gr.Textbox(label="Seed Used", elem_id="seed-output")

    # Input order must match query()'s positional parameters.
    run_button.click(
        fn=query,
        inputs=[
            txt_prompt,
            neg_prompt,
            steps_in,
            cfg_in,
            sampler_in,
            seed_in,
            strength_in,
            api_key_in,
        ],
        outputs=[img_out, seed_out],
    )

app.launch(show_api=True, share=False)