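# Page residue from the hosting site (author "kratadata", commit ee136a1 "init",
# raw / history / blame links, 3.62 kB), kept as a comment so the file parses as Python.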
import replicate
import os
import json
import gradio as gr
import requests
# Optional Fish API key read from the environment; None when the variable is unset.
# NOTE(review): not referenced by the code shown here — confirm it is still needed.
fish_api_key = os.environ.get("FISH_API_KEY")
def load_config(config_path="config.json"):
    """Load configuration from a JSON file and return the parsed object.

    Args:
        config_path: Path to the JSON file. Relative paths are resolved
            against the current working directory.

    Returns:
        Whatever ``json.load`` produces for the file (a dict for an object).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: without it, open() uses the locale's preferred encoding
    # (e.g. cp1252 on Windows), which can mis-decode non-ASCII config text.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
# Loaded once at import time from ./config.json; generate() reads model entries
# from config["image_model"].
config = load_config()
# NOTE(review): api_token is not referenced by the code shown here; the replicate
# client presumably reads REPLICATE_API_TOKEN from the environment on its own —
# confirm before removing this binding.
api_token = os.getenv("REPLICATE_API_TOKEN")
def save_image(output, file_path="generated_image.png"):
    """Download the first image URL in ``output`` and write it to ``file_path``.

    Args:
        output: Result of ``replicate.run``; only ``output[0]`` is used, and it
            is passed to ``requests.get`` as the image URL.
        file_path: Destination file. Defaults to ``"generated_image.png"`` in the
            working directory, so successive calls overwrite the same file.

    Returns:
        ``file_path``, for Gradio to display.

    Raises:
        requests.HTTPError: If the download responds with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    image_url = output[0]
    # A timeout keeps a stalled download from hanging the request handler forever.
    response = requests.get(image_url, timeout=60)
    # Fail loudly instead of writing an HTML/JSON error body into a .png file.
    response.raise_for_status()
    with open(file_path, "wb") as file:
        file.write(response.content)
    return file_path
def _build_prompt(trigger_word, prompt):
    """Return ``"<trigger_word> <prompt> in the style of <trigger_word>"``."""
    return f"{trigger_word} {prompt} in the style of {trigger_word}"


def generate(lora_model, prompt, aspect_ratio, num_inference_steps, guidance_scale, seed, lora_scale):
    """Run the selected LoRA model on Replicate and return the saved image path.

    Args:
        lora_model: Key into ``config["image_model"]`` (the dropdown value).
        prompt: User prompt; the model's trigger word is added before and after it.
        aspect_ratio: Aspect-ratio string passed to the model, e.g. ``"1:1"``.
        num_inference_steps: Number of inference steps; coerced to ``int``.
        guidance_scale: Passed to the model unchanged.
        seed: Seed from the UI. ``None`` or any negative value (the UI uses -1)
            means "random": the ``seed`` field is left out of the request.
        lora_scale: Sent as both ``lora_scale`` and ``extra_lora_scale``.

    Returns:
        Local path of the downloaded image, as returned by ``save_image``.

    Raises:
        gr.Error: If ``lora_model`` is not a configured model (e.g. nothing selected).
    """
    models = config["image_model"]
    if lora_model not in models:
        raise gr.Error(f"Unknown image model {lora_model!r}; pick one from the dropdown.")
    model_cfg = models[lora_model]
    selected_lora_model = model_cfg["persona"]
    trigger_word = model_cfg["trigger_word"]

    model_input = {
        "model": "dev",
        "prompt": _build_prompt(trigger_word, prompt),
        "go_fast": False,
        "lora_scale": lora_scale,
        "num_outputs": 1,
        "aspect_ratio": aspect_ratio,
        "output_format": "png",
        "guidance_scale": guidance_scale,
        "extra_lora_scale": lora_scale,
        # Sliders may deliver floats; the step count must be an integer.
        "num_inference_steps": int(num_inference_steps),
    }
    # "-1 for random" in the UI: omit the field rather than sending -1 verbatim,
    # so the model falls back to its own default (presumably a random seed —
    # confirm against the model's input schema).
    if seed is not None and int(seed) >= 0:
        model_input["seed"] = int(seed)

    output = replicate.run(selected_lora_model, input=model_input)
    return save_image(output)
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI that drives ``generate``.

    The left column holds the model, prompt and sampling controls plus a
    Generate button. The right column shows the resulting image. The button
    passes the controls to ``generate`` positionally, in its parameter order.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    model_choices = ["group1", "group2", "group3", "group4", "group5"]
    with gr.Blocks(title="Image Generator") as interface:
        gr.Markdown("# LoRA Image Generator")
        with gr.Row():
            with gr.Column():
                lora_model = gr.Dropdown(
                    choices=model_choices,
                    # Preselect a model: with no value the dropdown sends None,
                    # which is not a key in config["image_model"].
                    value=model_choices[0],
                    label="Choose your Image Model",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                )
                aspect_ratio = gr.Dropdown(
                    choices=["1:1", "16:9", "9:16"],
                    label="Aspect Ratio",
                    value="1:1",
                )
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=28,
                    step=1,
                    label="Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Guidance Scale",
                )
                seed = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    # precision=0 makes Gradio round the entry and deliver an int.
                    precision=0,
                )
                lora_scale = gr.Slider(
                    minimum=-1.0,
                    maximum=3.0,
                    value=1.0,
                    step=0.05,
                    label="LoRA Scale",
                )
                generate_btn = gr.Button("Generate Image")
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
        generate_btn.click(
            fn=generate,
            inputs=[lora_model, prompt, aspect_ratio, num_inference_steps,
                    guidance_scale, seed, lora_scale],
            outputs=output_image,
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then start the Gradio server.
    interface = create_interface()
    interface.launch()