# Hugging Face Spaces — status: Running
import os | |
import io | |
import random | |
import requests | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import replicate | |
# Inclusive upper bound for generation seeds: the largest signed 32-bit
# integer (identical to np.iinfo(np.int32).max). Used both for random seed
# draws and as the maximum of the seed slider.
MAX_SEED = 2**31 - 1
def predict(
    replicate_api,
    prompt,
    lora_id,
    lora_scale=0.95,
    aspect_ratio="1:1",
    seed=-1,
    randomize_seed=True,
    guidance_scale=3.5,
    num_inference_steps=28,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with FLUX-dev (optionally + a LoRA) on Replicate.

    Args:
        replicate_api: The visitor's Replicate API token. It is used only for
            this request and is never written to the process environment.
        prompt: Text prompt for the image.
        lora_id: Optional Hugging Face LoRA path (e.g.
            "multimodalart/vintage-ads-flux"). Blank or None means no LoRA.
        lora_scale: LoRA strength; only sent when ``lora_id`` is given.
        aspect_ratio: Aspect-ratio string such as "1:1" or "16:9".
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: When True (or when ``seed`` is None), draw a fresh
            seed uniformly from [0, MAX_SEED].
        guidance_scale: Guidance scale forwarded to the model.
        num_inference_steps: Number of inference steps forwarded to the model.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(image_url, seed, seed)``: the first model output, plus the seed
        actually used. The seed appears twice, once for the seed slider and
        once for the "Seed Used" textbox.

    Raises:
        gr.Error: If the token or prompt is missing, if the Replicate call
            fails, or if the model returns no output. Gradio displays the
            message to the user instead of pushing a string into the image
            output.
    """
    if not replicate_api or not prompt:
        raise gr.Error("Missing necessary inputs: enter a Replicate API token and a prompt.")

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders can deliver floats; the model expects integer seeds/steps.
    seed = int(seed)

    input_params = {
        "prompt": prompt,
        "output_format": "jpg",
        "aspect_ratio": aspect_ratio,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "seed": seed,
        "disable_safety_checker": True,
    }

    # Only attach LoRA parameters when a non-blank LoRA path was supplied.
    lora = (lora_id or "").strip()
    if lora:
        input_params["hf_lora"] = lora
        input_params["lora_scale"] = float(lora_scale)

    try:
        # A per-request client keeps the token out of os.environ. Setting and
        # deleting REPLICATE_API_TOKEN is process-global, so concurrent
        # requests could run with (or delete) another visitor's token.
        client = replicate.Client(api_token=replicate_api.strip())
        output = client.run(
            "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d",
            input=input_params,
        )
    except Exception as e:
        # e.g. invalid API token, insufficient credit, or input validation.
        raise gr.Error(f"Replicate request failed: {e}") from e

    print(output, prompt)
    if not output:
        raise gr.Error("The model returned no image.")
    # NOTE(review): depending on the replicate client version, items are URL
    # strings or FileOutput objects; str() yields the file URL for both.
    return str(output[0]), seed, seed
# Layout CSS: centre the app column (elem_id "col-container") and cap its
# width so the form stays readable on wide screens.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Prompt presets offered under the form; selecting one fills the prompt box.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]
# Build the single-page UI: credentials + prompt, optional advanced settings,
# a generate button, and the resulting image with the seed that produced it.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX Dev with Replicate API")
        # Each visitor supplies their own token; it is masked in the UI and
        # only forwarded to predict().
        replicate_api = gr.Text(
            label="Replicate API Key",
            type="password",
            show_label=True,
            max_lines=1,
            placeholder="Enter your Replicate API token",
            container=True,
        )
        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            lines=2,
            max_lines=4,
            show_copy_button=True,
            placeholder="Enter your prompt",
            container=True,
        )
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                custom_lora = gr.Textbox(
                    label="Custom LoRA",
                    info="LoRA Hugging Face path (optional)",
                    placeholder="multimodalart/vintage-ads-flux",
                )
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.95,
                )
            aspect_ratio = gr.Radio(
                label="Aspect ratio",
                value="1:1",
                choices=["1:1", "4:5", "2:3", "3:4", "9:16", "4:3", "16:9"],
            )
            # Also used as an output: predict() writes back the seed it used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        submit = gr.Button("Generate Image", variant="primary", scale=1)
        output = gr.Image(label="Output Image", show_label=True)
        seed_used = gr.Textbox(label="Seed Used", show_copy_button=True)
        # Examples only fill in the prompt. No fn/outputs are passed: results
        # cannot be pre-computed because predict() needs each visitor's own
        # API key, and Gradio only calls an Examples fn for caching/run-on-click.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            cache_examples=False,
        )

    # Generate on button click or when Enter is pressed in the prompt box.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[
            replicate_api,
            prompt,
            custom_lora,
            lora_scale,
            aspect_ratio,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output, seed, seed_used],
    )

demo.launch()