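"""Gradio demo for FLUX [dev] image generation via the Replicate API.

Users paste their own Replicate API token and a prompt; an optional Hugging Face
LoRA path (passed as ``hf_lora``) and the usual diffusion settings (aspect ratio,
seed, guidance scale, inference steps) are exposed under Advanced Settings.
"""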
import os
import random

import gradio as gr
import numpy as np
import replicate


MAX_SEED = np.iinfo(np.int32).max


def predict(replicate_api, prompt, lora_id, lora_scale=0.95, aspect_ratio="1:1", seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):

    # Validate API key and prompt
    if not replicate_api or not prompt:
        raise gr.Error("Please provide both a Replicate API token and a prompt.")
    
    # Set the seed if randomize_seed is True
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Set the Replicate API token in the environment variable
    os.environ["REPLICATE_API_TOKEN"] = replicate_api

    # Construct the input for the replicate model
    input_params = {
        "prompt": prompt,
        "output_format": "jpg",
        "aspect_ratio": aspect_ratio,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "disable_safety_checker": True
    }

    # If lora_id is provided, include it in the input
    if lora_id and lora_id.strip():
        input_params["hf_lora"] = lora_id.strip()
        input_params["lora_scale"] = lora_scale

    try:
        # Run the model using the user's API token from the environment variable
        output = replicate.run(
            "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d",
            input=input_params
        )
        print(output, prompt)
        # str() handles both plain URL strings and FileOutput objects returned by the replicate client
        return str(output[0]), seed, seed  # image URL, seed for the slider, seed for the "Seed Used" box

    except Exception as e:
        # Surface failures (e.g. invalid API token or insufficient credits) in the UI
        raise gr.Error(f"Replicate request failed: {e}")

    finally:
        # Always remove the API key from the environment
        if "REPLICATE_API_TOKEN" in os.environ:
            del os.environ["REPLICATE_API_TOKEN"]

    


css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

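# Blocks UI: API token and prompt inputs, an Advanced Settings accordion, and image/seed outputs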
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX Dev with Replicate API")
        
        replicate_api = gr.Text(label="Replicate API Key", type="password", show_label=True, max_lines=1, placeholder="Enter your Replicate API token", container=True)
        prompt = gr.Text(label="Prompt", show_label=True, lines=2, max_lines=4, show_copy_button=True, placeholder="Enter your prompt", container=True)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path (optional)", placeholder="multimodalart/vintage-ads-flux")
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.95,
                )
            aspect_ratio = gr.Radio(label="Aspect ratio", value="1:1", choices=["1:1", "4:5", "2:3", "3:4", "9:16", "4:3", "16:9"])
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        submit = gr.Button("Generate Image", variant="primary", scale=1)

        output = gr.Image(label="Output Image", show_label=True)

        seed_used = gr.Textbox(label="Seed Used", show_copy_button=True)
        

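        # Example prompts; clicking one just fills the prompt box (no generation is triggered)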
        gr.Examples(
            examples=examples,
            inputs=[prompt]
        )
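        # Generate on button click or when the user presses Enter in the prompt box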
        gr.on(
            triggers=[submit.click, prompt.submit],
            fn=predict,
            inputs=[replicate_api, prompt, custom_lora, lora_scale, aspect_ratio, seed, randomize_seed, guidance_scale, num_inference_steps],
            outputs=[output, seed, seed_used]
        )

demo.launch()