Spaces:
Runtime error
Runtime error
File size: 4,777 Bytes
375ee53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import gradio as gr
import rembg
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from src.utils.infer_util import (remove_background, resize_foreground)
# Load StableDiffusionXL model
# Base SDXL 1.0 pipeline in fp16, placed on CUDA at import time.
# `repo` is the ByteDance SDXL-Lightning repo whose distilled UNet weights
# are swapped into this pipeline later, inside generate_image().
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16").to("cuda")
def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Assemble the text-to-image prompt for a full-body 3D cartoon render.

    All six arguments are interpolated verbatim into a fixed prompt
    template (the subject appears twice: once up front and once in the
    pose sentence). Returns the finished prompt string.
    """
    template = (
        "A 3D cartoon render of {subject}, featuring the entire body and "
        "shape, on a transparent background. The style should be {style}, "
        "with {color_scheme} colors, emphasizing the essential features and "
        "lines. The pose should clearly showcase the full form of the "
        "{subject} from a {angle} perspective. Lighting is {lighting_type}, "
        "highlighting the volume and depth of the subject. "
        "{additional_details}. Output as a high-resolution PNG with no "
        "background."
    )
    return template.format(
        subject=subject,
        style=style,
        color_scheme=color_scheme,
        angle=angle,
        lighting_type=lighting_type,
        additional_details=additional_details,
    )
@spaces.GPU
def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Generate a 3D-cartoon-style image with the SDXL-Lightning 8-step UNet.

    Builds the prompt via generate_prompt() and runs the module-level
    `pipe`. The original implementation re-downloaded and re-loaded the
    Lightning checkpoint and rebuilt the scheduler on EVERY call, which is
    pure overhead since the weights never change; here that one-time setup
    is cached behind a function attribute.

    Args mirror generate_prompt(); returns the first generated image
    (a PIL image, per diffusers pipeline output).
    """
    num_inference_steps = 8  # must match the "8step" distilled checkpoint below
    # One-shot setup: swap in the trailing-timestep scheduler and the
    # Lightning UNet weights only on the first invocation.
    if not getattr(generate_image, "_lightning_ready", False):
        checkpoint = "sdxl_lightning_8step_unet.safetensors"
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing")
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, checkpoint), device="cuda"))
        generate_image._lightning_ready = True
    prompt = generate_prompt(subject, style, color_scheme,
                             angle, lighting_type, additional_details)
    # guidance_scale=0 disables classifier-free guidance, as in the
    # original configuration for this distilled checkpoint.
    results = pipe(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    return results.images[0]
def check_input_image(input_image):
    """Fail fast when the user has not supplied an image.

    Raises gr.Error (surfaced as a toast in the UI) for a missing image;
    otherwise returns None so the chained .success() handler can run.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected!")
def preprocess(input_image):
    """Remove the background and scale the subject to fill the frame.

    Runs rembg background removal with a fresh session, then resizes the
    remaining foreground to the 0.85 ratio used throughout this app.
    Returns the processed image.
    """
    session = rembg.new_session()
    foreground = remove_background(input_image, session)
    return resize_foreground(foreground, 0.85)
def image_generation_ui():
    """Build the prompt-driven generation and background-removal UI.

    Lays out the prompt controls, an input/processed image pair, and an
    examples gallery, then wires the two buttons:

    * "Generate Image" runs generate_image() into ``input_image`` and, on
      success, chains preprocess() to fill ``processed_image``.
    * "Remove Background" validates ``input_image`` via check_input_image()
      and, on success, runs preprocess() on it.

    Returns:
        tuple: the ``(input_image, processed_image)`` Gradio components so
        a caller can attach further steps to them.
    """
    # --- prompt controls: one row of text/dropdown inputs plus the button ---
    with gr.Row():
        subject = gr.Textbox(label='Subject', scale=2)
        style = gr.Dropdown(
            label='Style',
            choices=['Pixar-like', 'Disney-esque', 'Anime-inspired'],
            value='Pixar-like',
            multiselect=False,
            scale=2
        )
        color_scheme = gr.Dropdown(
            label='Color Scheme',
            choices=['Vibrant', 'Pastel', 'Monochromatic', 'Black and White'],
            value='Vibrant',
            multiselect=False,
            scale=2
        )
        angle = gr.Dropdown(
            label='Angle',
            choices=['Front', 'Side', 'Three-quarter'],
            value='Front',
            multiselect=False,
            scale=2
        )
        lighting_type = gr.Dropdown(
            label='Lighting Type',
            choices=['Bright and Even', 'Dramatic Shadows', 'Soft and Warm'],
            value='Bright and Even',
            multiselect=False,
            scale=2
        )
        additional_details = gr.Textbox(label='Additional Details', scale=2)
        submit_prompt = gr.Button('Generate Image', scale=1, variant='primary')
    # --- image panel: generated/uploaded input next to the processed result ---
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                # Read-only: filled by the preprocess() handlers below.
                processed_image = gr.Image(
                    label="Processed Image",
                    image_mode="RGBA",
                    type="pil",
                    interactive=False
                )
            with gr.Row():
                submit_process = gr.Button(
                    "Remove Background", elem_id="process", variant="primary")
            # Examples gallery sourced from the local "examples" directory.
            # NOTE(review): raises at build time if that directory is missing.
            with gr.Row(variant="panel"):
                gr.Examples(
                    examples=[os.path.join("examples", img_name)
                              for img_name in sorted(os.listdir("examples"))],
                    inputs=[input_image],
                    label="Examples",
                    cache_examples=False,
                    examples_per_page=16
                )
    # Generate, then (only on success) strip the background of the result.
    submit_prompt.click(fn=generate_image, inputs=[subject, style, color_scheme, angle, lighting_type, additional_details], outputs=input_image).success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image]
    )
    # Validate the uploaded image first; preprocess only if validation passed.
    submit_process.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image],
    )
    return input_image, processed_image
|