import importlib.util
import os

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
from PIL import Image

import pages.color_pref as color_pref
from text_overlay.overlay_ui import get_overlay_ui


# When True, every generated image is screened by the safety checker before
# it is returned to the UI.
SAFETY_CHECKER = True

# Model id passed to StableDiffusionXLPipeline.from_pretrained().
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repo from which the UNet weight files listed below are downloaded.
repo = "ByteDance/SDXL-Lightning"
# UI label -> [UNet weights file inside `repo`, number of inference steps].
checkpoints = {
    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
}
# Step count whose UNet weights are currently loaded into the pipeline;
# stays None until generate_image() loads a checkpoint for the first time.
loaded = None


def load_dropdown_options(filename):
    """Load dropdown choices from a module file in the ``pages`` directory.

    ``pages/<filename>`` is executed as a standalone module and the result of
    its ``get_options()`` function is returned.

    Args:
        filename: Name of a Python source file inside ``pages``
            (for example ``"tone.py"``).

    Returns:
        Whatever the loaded module's ``get_options()`` returns.

    Raises:
        ImportError: If no import spec or loader can be built for the file,
            e.g. because it does not have a Python source suffix.
    """
    file_path = os.path.join("pages", filename)
    # Strip the real extension instead of blindly dropping three characters.
    module_name = os.path.splitext(filename)[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location() returns None when it cannot pick a loader;
    # fail with a clear message rather than an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load dropdown options from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.get_options()


# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Half precision and the "fp16" weight variant are requested only on CUDA;
# on CPU the pipeline is loaded in float32 with no variant.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    variant="fp16" if device == "cuda" else None
).to(device)


if SAFETY_CHECKER:
    # Imported here so these modules are only needed when screening is enabled.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    # NOTE(review): the original indentation was lost; the helper is kept
    # inside the guard because it depends on the two objects created above.
    def check_nsfw_images(images):
        """Screen ``images`` with the safety checker.

        Args:
            images: Images to screen; ``generate_image`` passes a
                one-element list.

        Returns:
            ``(images, has_nsfw_concepts)``: the input unchanged, plus the
            value returned by the safety checker (the caller treats it as an
            iterable of truthy/falsy flags).
        """
        # Convert the images to tensors on the same device as the checker.
        safety_input = feature_extractor(images, return_tensors="pt").to(device)
        has_nsfw_concepts = safety_checker(
            images=[images], clip_input=safety_input.pixel_values
        )
        return images, has_nsfw_concepts


@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` with the selected checkpoint.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: Key of ``checkpoints`` (for example ``"4-Step"``) selecting the
            UNet weights file and the number of inference steps.

    Returns:
        An ``(image, update)`` pair. ``image`` is the first image returned by
        the pipeline, or a blank 512x512 RGB placeholder when the safety
        checker flags it. ``update`` is ``gr.update(visible=True)`` normally
        and ``gr.update(visible=False)`` when the output was blocked.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded
    checkpoint, steps = checkpoints[ckpt]

    # Swap the scheduler and UNet weights only when the requested step count
    # differs from the one already loaded.
    if loaded != steps:
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            # Only the 1-step checkpoint is configured for "sample" prediction.
            prediction_type="sample" if steps == 1 else "epsilon",
        )
        weights_path = hf_hub_download(repo, checkpoint)
        pipe.unet.load_state_dict(load_file(weights_path, device=device))
        loaded = steps

    result = pipe(prompt, num_inference_steps=steps, guidance_scale=0)
    image = result.images[0]

    if SAFETY_CHECKER:
        _, flags = check_nsfw_images([image])
        if any(flags):
            gr.Warning("⚠️ NSFW content detected. Output has been blocked.")
            return Image.new("RGB", (512, 512)), gr.update(visible=False)

    # Single success exit; the original repeated this return in both branches.
    return image, gr.update(visible=True)


with gr.Blocks() as demo:
    gr.HTML("<h1><center>Text to Image model</center></h1>")
    # NOTE(review): "LText" looks like a typo for "Text" — confirm before changing.
    gr.HTML("<p><center>LText overlay on generated Image</center></p>")

    # Dropdown choices are provided by get_options() in each pages/ module.
    audience_options = load_dropdown_options("audience.py")
    color_palette_options = load_dropdown_options("color_pref.py")
    image_size_options = load_dropdown_options("image_size.py")
    orientation_options = load_dropdown_options("orientation.py")
    tone_options = load_dropdown_options("tone.py")

    def create_prompt(aud, pal, size, ori, t, p):
        """Append the selected dropdown values to the user's prompt text."""
        return f"{p}, Audience: {aud}, Color Palette: {pal}, Image Size: {size}, Orientation: {ori}, Tone: {t}"

    def confirm_copy_to_canvas(generated_image):
        """Return the generated image unchanged so it can fill the canvas."""
        return generated_image

    # A bare reference to a component (as the original layout code did) is a
    # no-op: Gradio renders a component where it is created. These three are
    # therefore created unrendered and placed explicitly with .render() in
    # the right-hand column below.
    generated_image = gr.Image(label="Generated Image", render=False)
    copied_image = gr.Image(label="Canvas Image", visible=False, render=False)
    confirm_button = gr.Button(
        "If the image is to your liking, copy it to the canvas for text input",
        render=False,
    )

    with gr.Row():
        # Left column: prompt, generation options and overlay controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Enter your prompt (English)")
            audience = gr.Dropdown(label="Select Audience", choices=audience_options, value=audience_options[0], interactive=True)

            color_palette = gr.Dropdown(label="Select Color Palette", choices=color_palette_options, value=color_palette_options[0], interactive=True)
            color_preview = gr.HTML(value=color_pref.generate_color_preview(color_palette_options[0]), label="Palette Preview")
            # Refresh the preview whenever a different palette is selected.
            color_palette.change(fn=color_pref.generate_color_preview, inputs=color_palette, outputs=color_preview)

            image_size = gr.Dropdown(label="Select Image Size", choices=image_size_options, value=image_size_options[0], interactive=True)
            orientation = gr.Dropdown(label="Select Orientation", choices=orientation_options, value=orientation_options[0], interactive=True)
            tone = gr.Dropdown(label="Select Tone", choices=tone_options, value=tone_options[0], interactive=True)
            ckpt = gr.Dropdown(label="Select Inference Steps", choices=list(checkpoints.keys()), value="4-Step", interactive=True)
            submit = gr.Button("Generate Image", variant="primary")

            # Overlay controls; no event handler in this block reads them yet.
            add_logo = gr.File(label="Upload Logo")
            add_text = gr.Textbox(label="Add Text", placeholder="Type your text here")
            text_position = gr.Dropdown(label="Text Position", choices=["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right", "Center"], value="Center")
            font_size = gr.Dropdown(label="Font Size", choices=[str(i) for i in range(1, 51)], value="20", interactive=True)
            text_color = gr.ColorPicker(label="Text Color", value="#000000")
            bg_color = gr.ColorPicker(label="Background Color", value="#FFFFFF")
            border_color = gr.ColorPicker(label="Border Color", value="#000000")
            border_width = gr.Dropdown(label="Border Width", choices=[str(i) for i in range(1, 51)], value="5", interactive=True)

        # Right column: generated image, copy button, then the canvas image.
        with gr.Column(scale=1):
            with gr.Row():
                gr.HTML("<b>Generated Image</b>")
                generated_image.render()

            confirm_button.render()

            with gr.Row():
                gr.HTML("<b>Canvas Image</b>")
                copied_image.render()

    # generate_image returns (image, visibility update); the update is applied
    # to the canvas image component.
    submit.click(
        fn=lambda aud, pal, size, ori, t, p, c: generate_image(create_prompt(aud, pal, size, ori, t, p), c),
        inputs=[audience, color_palette, image_size, orientation, tone, prompt, ckpt],
        outputs=[generated_image, copied_image],
    )

    confirm_button.click(
        fn=confirm_copy_to_canvas,
        inputs=[generated_image],
        outputs=[copied_image],
    )


# Enable request queuing, then start the Gradio server.
demo.queue().launch()