# Scraped page header: ritwikraha · "fixes" · commit 2257008 · raw · history · blame · 3.25 kB
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from PIL import Image
import spaces
# Load the VAE and the diffusion pipeline once, at import time, so the
# GPU-decorated generation function can reuse them on every request.
_weights_dtype = torch.float16

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=_weights_dtype,
)

# Options passed to the SDXL base pipeline; the VAE loaded above replaces
# the pipeline's own VAE.
_pipeline_options = {
    "vae": vae,
    "torch_dtype": _weights_dtype,
    "variant": "fp16",
    "use_safetensors": True,
}
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    **_pipeline_options,
)

# Attach the LoRA weights on top of the base pipeline.
pipe.load_lora_weights('ritwikraha/khabib_sketch_LoRA')

# Move the pipeline to the GPU only when CUDA is reported as available;
# the return value of `.to()` is deliberately discarded.
if torch.cuda.is_available():
    _ = pipe.to("cuda")
# Define the image generation function
@spaces.GPU(enable_queue=True)
def generate_sketch(
    prompt: str,
    negative_prompt: str = "ugly face, multiple bodies, bad anatomy, disfigured, extra fingers",
    guidance_scale: float = 3,
    num_inference_steps: int = 50,
):
    """Generate a sketch image from a text prompt with the module-level pipeline.

    Args:
        prompt: Description of the image to generate.
        negative_prompt: Features the generation should steer away from.
        guidance_scale: Guidance strength forwarded to the pipeline.
            Fractional values are accepted.
        num_inference_steps: Number of diffusion steps. UI sliders may
            deliver this as a float (e.g. ``50.0`` or ``50.3``), so it is
            rounded to the nearest integer before being forwarded.

    Returns:
        The first image produced by the pipeline, converted to RGB mode.
    """
    # The pipeline needs an integral step count; a slider-supplied float
    # would otherwise reach the scheduler unchanged.
    step_count = int(round(num_inference_steps))

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=step_count,
    )
    # Only the first generated image is returned, normalised to RGB.
    return result.images[0].convert("RGB")
# Gradio Interface
description = """
This demo utilizes the SDXL model LoRA adaption weights for stabilityai/stable-diffusion-xl-base-1.0. The weights were trained on sketches of Khabib by ritwikraha using DreamBooth and can be found here: https://huggingface.co/ritwikraha/khabib_sketch_LoRA
"""

# Setup Gradio interface: inputs in the left column, generated image on the right.
with gr.Blocks() as demo:
    # Title emoji restored from its mis-decoded byte sequence.
    gr.HTML("<h1><center>Khabib Sketch Maker 🥋</center></h1>")
    gr.Markdown(description)
    gr.HTML("<ul><li>Remember to prompt with the identifier 'TOK', e.g., A sketch of TOK Khabib.</li><li>Sketches work best.</li><li>Lower guidance scale is preferred for simpler prompts.</li><li>Negative prompt needs specifying for good image generation.</li></ul>")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your image prompt",
                value="a sketch of TOK khabib dancing, monochrome, pen sketch",
                scale=8,
            )
            negative_prompt_input = gr.Textbox(
                label="Enter negative prompt",
                value="ugly face, multiple bodies, bad anatomy, disfigured, extra fingers",
                lines=2,
            )
            guidance_scale_slider = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=5,
                value=3,
            )
            # step=1 keeps the slider on whole numbers; without it the
            # control offers fractional step counts.
            steps_slider = gr.Slider(
                label="Number of Inference Steps",
                minimum=20,
                maximum=100,
                value=50,
                step=1,
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            output_image = gr.Image(label="Generated Sketch")

    # Wire the button to the generator; input order matches the positional
    # parameters of generate_sketch.
    submit_button.click(
        fn=generate_sketch,
        inputs=[prompt_input, negative_prompt_input, guidance_scale_slider, steps_slider],
        outputs=output_image,
    )

demo.launch()