import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline.
# The weights are a single-file `.safetensors` checkpoint. `from_pretrained`
# expects a Hub repo id or a directory in the diffusers multi-folder layout,
# so a lone checkpoint file has to be loaded with `from_single_file` instead.
# The pipeline is moved to the GPU when CUDA is available, otherwise it stays
# on the CPU.
pipeline = StableDiffusionPipeline.from_single_file(
    './Traditional_Korean_Painting_Model_2.safetensors',
).to("cuda" if torch.cuda.is_available() else "cpu")


# Example of calling the pipeline directly (without the Gradio UI):
# prompt = "a photo of an astronaut riding a horse on mars"
# image = pipeline(prompt).images[0]

def generate_image(prompt):
    """Generate one image for *prompt* using the module-level ``pipeline``.

    Args:
        prompt: Text description passed straight to the pipeline call.

    Returns:
        The first entry of the pipeline result's ``images`` list.

    Raises:
        gr.Error: If the pipeline call raises. The original exception is
            chained as the cause and its message is printed to stdout.
    """
    try:
        # The loaded model object is named `pipeline` at module level; there
        # is no `pipe` in this file.
        return pipeline(prompt).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        # This function feeds a `gr.Image` output, for which a plain message
        # string is not a valid image value. Raising `gr.Error` lets Gradio
        # show the failure to the user instead.
        raise gr.Error("Error generating image") from e

# Gradio interface setup: a two-column layout with the prompt box and the
# "Generate" button on the left and the resulting image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Traditional Korean Painting Generator")
    gr.Markdown("Enter a prompt to generate a traditional Korean painting.")

    with gr.Row():
        # Left column: user input.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the scene...",
            )
            generate_btn = gr.Button("Generate")
        # Right column: generated result.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
            )

    # Clicking the button sends the textbox value to `generate_image` and
    # routes its return value to the image component.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt],
        outputs=[output_image],
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()