File size: 3,599 Bytes
8ccf632
819389a
 
8ccf632
819389a
81b26b5
06f0278
819389a
 
8ccf632
 
819389a
8f879f7
 
819389a
 
 
 
8f879f7
819389a
 
 
 
 
 
 
 
6177b55
01787f6
 
6177b55
80a2167
 
6177b55
819389a
80a2167
 
01787f6
819389a
 
 
 
 
 
 
 
 
8f879f7
 
 
819389a
 
 
 
 
 
 
 
6177b55
 
819389a
 
 
 
 
 
 
 
 
 
 
 
6177b55
819389a
 
 
 
 
 
 
 
 
 
6177b55
 
819389a
8f879f7
819389a
ca256e8
6177b55
 
8ccf632
819389a
 
8f879f7
819389a
 
 
 
 
 
 
 
ca256e8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import numpy as np
import random
import torch
import time
from diffusers import DiffusionPipeline

# Set the device first so the dtype can be chosen to suit it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only used on a GPU. Diffusers warns that pipelines loaded in
# float16 cannot run on the CPU (many CPU kernels lack float16 support), so fall
# back to float32 there. On CUDA, prefer bfloat16 when the hardware supports it
# (the diffusers FLUX examples load the model in bfloat16); otherwise use float16.
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32

# Load the diffusion pipeline once at import time and move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe = pipe.to(device)

# Upper bounds shared by infer() and the UI sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, guidance_scale=7.5, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from *prompt* with the module-level diffusion pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the torch generator; replaced by a random value when
            *randomize_seed* is true.
        randomize_seed: If true, draw a fresh seed uniformly from [0, MAX_SEED].
        width: Output width in pixels; must be in 1..MAX_IMAGE_SIZE.
        height: Output height in pixels; must be in 1..MAX_IMAGE_SIZE.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Forwarded to the pipeline unchanged.
            NOTE(review): FLUX.1-schnell is guidance-distilled, so this value
            likely has no effect on the output — confirm against diffusers.
        progress: Gradio progress tracker (track_tqdm=True lets Gradio mirror
            tqdm progress bars in the UI).

    Returns:
        A ``(image, seed, message)`` tuple:
        - ``image`` is the first image from the pipeline, or None if generation failed.
        - ``seed`` is the integer seed actually used.
        - ``message`` is None on a clean success, or a human-readable status or error string.

    Raises:
        ValueError: If width or height is larger than MAX_IMAGE_SIZE or not positive.
    """
    time_budget_s = 60
    # Monotonic clock: immune to wall-clock adjustments during a long generation.
    start_time = time.monotonic()

    # UI sliders may deliver floats; the scheduler and torch.Generator need ints.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
        raise ValueError("Image size exceeds the maximum allowed dimensions.")
    if width <= 0 or height <= 0:
        raise ValueError("Image dimensions must be positive.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the handler.
        print(f"Error generating image: {e}")
        return None, seed, f"Error: {str(e)}"

    elapsed = time.monotonic() - start_time
    if elapsed > time_budget_s:
        # This check runs after generation has already completed, so nothing
        # was cancelled. Keep the finished image and report the overrun.
        return image, seed, f"Image generation took {elapsed:.0f}s, exceeding the {time_budget_s}s budget."

    return image, seed, None

# One-click example rows for gr.Examples: each row holds one value per input
# component wired to it (here a single prompt string).
examples = [
    [text]
    for text in (
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    )
]

# Build the Gradio UI. infer() returns (image, seed, message), so every event
# wired to it declares three output components in that order.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Custom Image Creator
    12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
    [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt",
                lines=3
            )
            run_button = gr.Button("Generate Image", variant="primary")

        with gr.Column(scale=2):
            result = gr.Image(label="Generated Image")
            seed_output = gr.Number(label="Seed Used")
            # Shows infer()'s third return value (errors / status). Without a
            # component for it, the message was silently dropped.
            status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

        with gr.Row():
            num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.5)

    # cache_examples=True runs infer() on each example when the app is built;
    # only the prompt is supplied, so infer()'s defaults fill the other arguments.
    gr.Examples(
        examples=examples,
        inputs=[prompt],
        outputs=[result, seed_output, status_output],
        fn=infer,
        cache_examples=True
    )

    run_button.click(
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output, status_output]
    )

    gr.Markdown("""
    ## Save Your Image
    Right-click on the generated image and select 'Save image as' to download it.
    """)

if __name__ == "__main__":
    demo.launch()