|
import replicate |
|
import os |
|
import json |
|
import gradio as gr |
|
import requests |
|
|
|
# Fish API key taken from the environment; None when FISH_API_KEY is unset.
fish_api_key = os.environ.get("FISH_API_KEY")
|
|
|
def load_config(config_path="config.json"):
    """Load and return the parsed contents of a JSON configuration file.

    Args:
        config_path: Path to the JSON file (str or path-like). Defaults to
            ``config.json`` resolved against the current working directory.

    Returns:
        The decoded JSON value. For this app it is expected to be a dict with
        an ``"image_model"`` mapping of model key -> settings.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
    # (e.g. cp1252 on Windows would mangle non-ASCII trigger words).
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
# Settings read once at import time: the parsed config.json and the
# Replicate API token from the environment (None when unset).
config = load_config()
api_token = os.environ.get("REPLICATE_API_TOKEN")
|
|
|
|
|
|
|
|
|
def save_image(output, file_path="generated_image.png", timeout=60):
    """Download the first image of a Replicate result and write it to disk.

    Args:
        output: Result of ``replicate.run``; only ``output[0]`` is used and it
            is passed to ``requests.get`` as the image URL.
        file_path: Destination path for the downloaded bytes.
        timeout: Seconds to wait on the HTTP download before giving up.

    Returns:
        ``file_path``, the location the image was written to.

    Raises:
        ValueError: if ``output`` is empty.
        requests.HTTPError: if the download responds with an error status.
        requests.Timeout: if the download exceeds ``timeout``.
    """
    if not output:
        raise ValueError("Replicate returned no output to download")

    image_url = output[0]
    # Without a timeout a stalled connection would block this handler forever.
    response = requests.get(image_url, timeout=timeout)
    # Fail loudly instead of saving an HTML/JSON error body as a .png file.
    response.raise_for_status()

    with open(file_path, "wb") as file:
        file.write(response.content)

    return file_path
|
|
|
def generate(lora_model, prompt, aspect_ratio, num_inference_steps, guidance_scale, seed, lora_scale):
    """Generate one image with the selected LoRA model on Replicate.

    Args:
        lora_model: Key into ``config["image_model"]`` selecting the model
            entry (its ``"persona"`` model ref and ``"trigger_word"``).
        prompt: User prompt text.
        aspect_ratio: Aspect ratio string forwarded to the model, e.g. ``"1:1"``.
        num_inference_steps: Denoising step count; coerced to ``int``.
        guidance_scale: Guidance scale, forwarded unchanged.
        seed: Integer seed. ``None`` or a negative value (the UI uses -1)
            means no seed is sent, leaving the choice to the model.
        lora_scale: Used for both ``lora_scale`` and ``extra_lora_scale``.

    Returns:
        Local path of the saved image, as returned by ``save_image``.

    Raises:
        gr.Error: if ``lora_model`` is not a configured model (for example
            when nothing was picked in the dropdown).
    """
    models = config["image_model"]
    if lora_model not in models:
        raise gr.Error(f"Unknown image model: {lora_model!r}. Please pick one from the list.")

    model_cfg = models[lora_model]
    selected_lora_model = model_cfg["persona"]
    trigger_word = model_cfg["trigger_word"]

    payload = {
        "model": "dev",
        # Plain space-separated prompt bracketed by the model's trigger word.
        "prompt": f"{trigger_word} {prompt} in the style of {trigger_word}",
        "go_fast": False,
        "lora_scale": lora_scale,
        "num_outputs": 1,
        "aspect_ratio": aspect_ratio,
        "output_format": "png",
        "guidance_scale": guidance_scale,
        "extra_lora_scale": lora_scale,
        # Gradio sliders may deliver floats; a step count must be an integer.
        "num_inference_steps": int(num_inference_steps),
    }

    # Negative seeds (UI default -1) mean "random": leave the key out instead
    # of forwarding -1 as a literal seed value.
    if seed is not None and int(seed) >= 0:
        payload["seed"] = int(seed)

    output = replicate.run(selected_lora_model, input=payload)

    return save_image(output)
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI whose Generate button calls ``generate``.

    The model dropdown lists the keys of ``config["image_model"]``, so every
    selectable option is one ``generate`` can look up.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    # Derive choices from config so the dropdown cannot offer a key that
    # generate() would fail to find.
    model_choices = list(config["image_model"])

    with gr.Blocks(title="Image Generator") as interface:
        gr.Markdown("# LoRA Image Generator")

        with gr.Row():
            with gr.Column():
                lora_model = gr.Dropdown(
                    choices=model_choices,
                    # Preselect a model so Generate never sends None.
                    value=model_choices[0] if model_choices else None,
                    label="Choose your Image Model",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                )
                aspect_ratio = gr.Dropdown(
                    choices=["1:1", "16:9", "9:16"],
                    label="Aspect Ratio",
                    value="1:1",
                )
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=28,
                    step=1,
                    label="Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Guidance Scale",
                )
                seed = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1.0,
                )
                lora_scale = gr.Slider(
                    minimum=-1.0,
                    maximum=3.0,
                    value=1.0,
                    step=0.05,
                    label="LoRA Scale",
                )
                generate_btn = gr.Button("Generate Image")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")

        # Input order must match generate()'s positional parameters.
        generate_btn.click(
            fn=generate,
            inputs=[lora_model, prompt, aspect_ratio, num_inference_steps,
                    guidance_scale, seed, lora_scale],
            outputs=output_image,
        )

    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the local Gradio server.
    create_interface().launch()