| import gradio as gr |
| import torch |
| import spaces |
| import os |
| from diffusers import FluxPipeline |
| from safetensors.torch import load_file |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Optional Hugging Face access token read from the environment (None when
# HF_TOKEN is unset). It is forwarded to every Hub download below.
hf_token = os.environ.get("HF_TOKEN")

# Base FLUX.1-dev pipeline in bfloat16, loaded from safetensors and moved to
# the GPU once, at import time.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    token=hf_token,
).to("cuda")

# Swap the base transformer weights for Tencent's SRPO checkpoint.
# load_state_dict() is strict by default, so a key mismatch between the
# checkpoint and pipe.transformer raises instead of silently half-loading.
srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors",
    token=hf_token,
)
state_dict = load_file(srpo_path)
pipe.transformer.load_state_dict(state_dict)
# load_state_dict() copies the tensors into the module's existing parameters,
# so the host-side dict is no longer needed. Dropping this module-level
# reference lets that memory be reclaimed instead of living for the whole
# process lifetime.
del state_dict
|
|
@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image from *prompt* with the SRPO-weighted FLUX pipeline.

    Args:
        prompt: Text description of the desired image.
        width: Output width in pixels (coerced to int).
        height: Output height in pixels (coerced to int).
        guidance_scale: Guidance strength forwarded to the pipeline
            (coerced to float).
        num_inference_steps: Number of denoising steps (coerced to int).
        seed: RNG seed. -1, or None (e.g. a cleared number field), draws a
            random seed in [0, 2**32).

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the integer seed actually used, so the result can be reproduced.
    """
    # A cleared gr.Number arrives as None; manual_seed(None) would raise, so
    # treat it the same as the explicit "random" sentinel -1.
    if seed is None or seed == -1:
        seed = torch.randint(0, 2**32, (1,)).item()
    # manual_seed() requires an int; UI/API values may arrive as floats.
    seed = int(seed)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Widget and API values can be floats (e.g. 1024.0), but the pipeline
    # needs integer image sizes and step counts.
    image = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]

    return image, seed
|
|
# Prompts offered as clickable examples; each one fills the `prompt` box.
_EXAMPLE_PROMPTS = (
    "The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere",
    "A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic",
    "Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed",
    "Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background",
    "Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting",
)


def _dimension_slider(label):
    """Return a pixel-size slider: 256-2048 in steps of 64, defaulting to 1024."""
    return gr.Slider(label=label, minimum=256, maximum=2048, step=64, value=1024)


_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)

with gr.Blocks(title="FLUX SRPO Text-to-Image", theme=_theme) as demo:
    # Page header.
    gr.Markdown("# Flux SRPO")
    gr.Markdown("Generate images using FLUX model enhanced with Tencent's [SRPO](https://github.com/Tencent-Hunyuan/SRPO) technique")
    gr.Markdown("Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)")

    # Result first, then the prompt box and the trigger button.
    output_image = gr.Image(type="pil", label="Generated Image")

    prompt = gr.Textbox(
        lines=3,
        label="Prompt",
        placeholder="Describe the image you want to generate...",
    )

    generate_btn = gr.Button("Generate Image", size="lg", variant="primary")

    # NOTE(review): nesting reconstructed from an indentation-stripped copy;
    # confirm `seed` belongs inside the accordion and `used_seed` below it.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = _dimension_slider("Width")
            height = _dimension_slider("Height")

        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=20.0,
                step=0.5,
                value=3.5,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=100,
                step=5,
                value=50,
            )

        # -1 is the sentinel generate_image treats as "pick a random seed".
        seed = gr.Number(precision=0, value=-1, label="Seed (-1 for random)")

    # Receives the seed returned by generate_image.
    used_seed = gr.Number(precision=0, label="Seed Used")

    gr.Examples(
        label="Example Prompts",
        inputs=prompt,
        examples=[[text] for text in _EXAMPLE_PROMPTS],
    )

    # Argument order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed],
    )
|
|
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module, for example to reuse `demo`, does not block on launch.
if __name__ == "__main__":
    demo.launch()