import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
from PIL import Image
import importlib.util
import os
import pages.color_pref as color_pref
from text_overlay.overlay_ui import get_overlay_ui  # imported but not yet wired into the UI below
# Safety checker toggle
SAFETY_CHECKER = True
# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
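# Map UI label -> (checkpoint filename in the ByteDance/SDXL-Lightning repo,
# inference step count). The 1-step checkpoint is the x0-prediction variant,
# which is why generate_image() switches prediction_type below.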
checkpoints = {
"1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
"2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
"4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
"8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
}
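# Tracks which step-count checkpoint is currently loaded into the UNet.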
loaded = None
# Load dropdown options from pages
def load_dropdown_options(filename):
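    """Dynamically import pages/<filename> and return its get_options() list."""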
file_path = os.path.join("pages", filename)
spec = importlib.util.spec_from_file_location(filename[:-3], file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.get_options()
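# Each module under pages/ is expected to expose a get_options() function
# returning the dropdown choices, e.g. (hypothetical pages/tone.py):
#
#     def get_options():
#         return ["Professional", "Playful", "Minimal"]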
# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
base,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
variant="fp16" if device == "cuda" else None
).to(device)
# Safety checker setup
if SAFETY_CHECKER:
from safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
).to(device)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32"
)
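    # NOTE: the Space-local safety_checker module is expected to return the
    # per-image NSFW flags directly (a list of bools), unlike the diffusers
    # class of the same name, which returns (images, flags).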
def check_nsfw_images(images):
safety_input = feature_extractor(images, return_tensors="pt").to(device)
has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_input.pixel_values)
return images, has_nsfw_concepts
# Generate image function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
global loaded
checkpoint, steps = checkpoints[ckpt]
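    # Lazily swap the scheduler config and UNet weights only when the selected
    # step count changes between calls.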
if loaded != steps:
pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
prediction_type="sample" if steps == 1 else "epsilon"
)
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))
loaded = steps
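    # SDXL-Lightning is distilled to sample without classifier-free guidance,
    # hence guidance_scale=0.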
result = pipe(prompt, num_inference_steps=steps, guidance_scale=0)
image = result.images[0]
if SAFETY_CHECKER:
images, flags = check_nsfw_images([image])
if any(flags):
gr.Warning("⚠️ NSFW content detected. Output has been blocked.")
            # Return a blank 1024x1024 placeholder (SDXL's default output size).
            return Image.new("RGB", (1024, 1024)), gr.update(visible=False)
return image, gr.update(visible=True)
return image, gr.update(visible=True)
# Gradio UI
with gr.Blocks() as demo:
    gr.HTML("<h1><center>Text-to-Image Model</center></h1>")
    gr.HTML("<p><center>Text overlay on the generated image</center></p>")
# Load dropdown values for user preferences
audience_options = load_dropdown_options("audience.py")
color_palette_options = load_dropdown_options("color_pref.py")
image_size_options = load_dropdown_options("image_size.py")
orientation_options = load_dropdown_options("orientation.py")
tone_options = load_dropdown_options("tone.py")
# Left column for inputs (layout split)
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(label="Enter your prompt (English)")
audience = gr.Dropdown(label="Select Audience", choices=audience_options, value=audience_options[0], interactive=True)
color_palette = gr.Dropdown(label="Select Color Palette", choices=color_palette_options, value=color_palette_options[0], interactive=True)
color_preview = gr.HTML(value=color_pref.generate_color_preview(color_palette_options[0]), label="Palette Preview")
            color_palette.change(fn=color_pref.generate_color_preview, inputs=color_palette, outputs=color_preview)
image_size = gr.Dropdown(label="Select Image Size", choices=image_size_options, value=image_size_options[0], interactive=True)
orientation = gr.Dropdown(label="Select Orientation", choices=orientation_options, value=orientation_options[0], interactive=True)
tone = gr.Dropdown(label="Select Tone", choices=tone_options, value=tone_options[0], interactive=True)
ckpt = gr.Dropdown(label="Select Inference Steps", choices=list(checkpoints.keys()), value="4-Step", interactive=True)
submit = gr.Button("Generate Image", variant="primary")
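            # Fold the dropdown preferences into the text prompt. SDXL-Lightning has no
            # structured conditioning for these fields, so they ride along as plain
            # prompt text (size/orientation hints do not change the actual output
            # resolution).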
def create_prompt(aud, pal, size, ori, t, p):
return f"{p}, Audience: {aud}, Color Palette: {pal}, Image Size: {size}, Orientation: {ori}, Tone: {t}"
            # Canvas and image components for output. Created with render=False so
            # they can be placed in the right-hand column below via .render().
            generated_image = gr.Image(label="Generated Image", render=False)
            copied_image = gr.Image(label="Canvas Image", visible=False, render=False)
add_logo = gr.File(label="Upload Logo")
add_text = gr.Textbox(label="Add Text", placeholder="Type your text here")
text_position = gr.Dropdown(label="Text Position", choices=["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right", "Center"], value="Center")
font_size = gr.Dropdown(label="Font Size", choices=[str(i) for i in range(1, 51)], value="20", interactive=True)
text_color = gr.ColorPicker(label="Text Color", value="#000000")
bg_color = gr.ColorPicker(label="Background Color", value="#FFFFFF")
border_color = gr.ColorPicker(label="Border Color", value="#000000")
border_width = gr.Dropdown(label="Border Width", choices=[str(i) for i in range(1, 51)], value="5", interactive=True)
# This button will copy the image to canvas
            confirm_button = gr.Button(
                "If the image is to your liking, copy it to the canvas for text input",
                render=False,
            )
# Right column for generated image and canvas interaction
with gr.Column(scale=1):
with gr.Row():
gr.HTML("<b>Generated Image</b>")
                generated_image.render()
                confirm_button.render()  # copy button below the image
with gr.Row():
gr.HTML("<b>Canvas Image</b>")
                copied_image.render()
# Connect generation button
submit.click(
fn=lambda aud, pal, size, ori, t, p, c: generate_image(create_prompt(aud, pal, size, ori, t, p), c),
inputs=[audience, color_palette, image_size, orientation, tone, prompt, ckpt],
outputs=[generated_image, copied_image]
)
    # Copy the generated image into the canvas component.
    def confirm_copy_to_canvas(image):
        return image
# Button to copy image to canvas after confirmation
confirm_button.click(
fn=confirm_copy_to_canvas,
inputs=[generated_image],
outputs=[copied_image]
)
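    # NOTE: the logo/text/color controls above are not wired to any handler in
    # this file (get_overlay_ui is imported but unused). A minimal sketch of an
    # "apply text" step, assuming copied_image yields a PIL image
    # (gr.Image(type="pil")) and using a hypothetical apply_text helper:
    #
    #     from PIL import ImageDraw, ImageFont
    #
    #     def apply_text(img, text, color, size):
    #         img = img.copy()
    #         font = ImageFont.load_default(size=int(size))  # Pillow >= 10.1
    #         ImageDraw.Draw(img).text((20, 20), text, fill=color, font=font)
    #         return img
    #
    #     apply_button = gr.Button("Apply Text")
    #     apply_button.click(apply_text,
    #                        inputs=[copied_image, add_text, text_color, font_size],
    #                        outputs=[copied_image])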
# Launch app
demo.queue().launch()