File size: 4,777 Bytes
375ee53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os

import gradio as gr
import rembg
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from src.utils.infer_util import (remove_background, resize_foreground)


# Load the SDXL base pipeline once at import time, in fp16 on the GPU.
# The Lightning distilled UNet weights are swapped in later (see
# generate_image), so only the base repo is materialized here.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")


def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Compose the text-to-image prompt for a full-body 3D cartoon render.

    Every argument is a plain string interpolated verbatim into a fixed
    prompt template; the assembled prompt is returned as one string.
    """
    return (
        f"A 3D cartoon render of {subject}, featuring the entire body and "
        f"shape, on a transparent background. The style should be {style}, "
        f"with {color_scheme} colors, emphasizing the essential features and "
        f"lines. The pose should clearly showcase the full form of the "
        f"{subject} from a {angle} perspective. Lighting is {lighting_type}, "
        f"highlighting the volume and depth of the subject. "
        f"{additional_details}. Output as a high-resolution PNG with no "
        f"background."
    )


@spaces.GPU
def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Run an 8-step SDXL-Lightning inference on the composed prompt.

    Reconfigures the shared `pipe`'s scheduler for Lightning's trailing
    timestep spacing, ensures the distilled 8-step UNet weights are loaded,
    and generates one image without classifier-free guidance.

    All parameters are strings forwarded to generate_prompt().

    Returns:
        PIL.Image.Image: the first (and only) generated image.
    """
    checkpoint = "sdxl_lightning_8step_unet.safetensors"
    num_inference_steps = 8

    # Lightning checkpoints require "trailing" timestep spacing; reapplying
    # this config is cheap and idempotent.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing")

    # Fix: the original re-downloaded and re-loaded the multi-GB UNet
    # checkpoint into pipe.unet on every call. The weights never change, so
    # load them once and tag the pipeline to skip subsequent loads.
    if not getattr(pipe, "_lightning_unet_loaded", False):
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, checkpoint), device="cuda"))
        pipe._lightning_unet_loaded = True

    prompt = generate_prompt(subject, style, color_scheme,
                             angle, lighting_type, additional_details)
    # guidance_scale=0: Lightning models are distilled to run without CFG.
    results = pipe(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    return results.images[0]


def check_input_image(input_image):
    """Raise a user-visible Gradio error when no input image was supplied."""
    if input_image is not None:
        return
    raise gr.Error("No image selected!")


def preprocess(input_image):
    """Strip the background and scale the foreground to 85% of the frame.

    A fresh rembg session is created per call; the matted image is then
    resized via the project's resize_foreground helper and returned.
    """
    session = rembg.new_session()
    matted = remove_background(input_image, session)
    return resize_foreground(matted, 0.85)


def image_generation_ui():
    """Build the prompt-driven image-generation UI inside the current Blocks.

    Lays out a row of prompt controls (subject, style, color scheme, angle,
    lighting, extra details) with a generate button, then an input/processed
    image pair with a background-removal button and an examples gallery.
    Wires: generate -> input_image -> preprocess, and the manual
    "Remove Background" path with an input check first.

    Returns:
        tuple: (input_image, processed_image) Gradio components, so the
        caller can chain further events onto them.
    """
    # Prompt-parameter controls; choices feed generate_prompt() verbatim.
    with gr.Row():
        subject = gr.Textbox(label='Subject', scale=2)
        style = gr.Dropdown(
            label='Style',
            choices=['Pixar-like', 'Disney-esque', 'Anime-inspired'],
            value='Pixar-like',
            multiselect=False,
            scale=2
        )
        color_scheme = gr.Dropdown(
            label='Color Scheme',
            choices=['Vibrant', 'Pastel', 'Monochromatic', 'Black and White'],
            value='Vibrant',
            multiselect=False,
            scale=2
        )
        angle = gr.Dropdown(
            label='Angle',
            choices=['Front', 'Side', 'Three-quarter'],
            value='Front',
            multiselect=False,
            scale=2
        )
        lighting_type = gr.Dropdown(
            label='Lighting Type',
            choices=['Bright and Even', 'Dramatic Shadows', 'Soft and Warm'],
            value='Bright and Even',
            multiselect=False,
            scale=2
        )
        additional_details = gr.Textbox(label='Additional Details', scale=2)
        submit_prompt = gr.Button('Generate Image', scale=1, variant='primary')

    # Image panel: upload/generated input on the left, matted result on the
    # right (read-only).
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(
                    label="Processed Image",
                    image_mode="RGBA",
                    type="pil",
                    interactive=False
                )
            with gr.Row():
                submit_process = gr.Button(
                    "Remove Background", elem_id="process", variant="primary")
            with gr.Row(variant="panel"):
                # Clicking an example populates input_image only; assumes an
                # "examples" directory exists next to the working directory —
                # os.listdir raises if it is missing. TODO confirm at deploy.
                gr.Examples(
                    examples=[os.path.join("examples", img_name)
                              for img_name in sorted(os.listdir("examples"))],
                    inputs=[input_image],
                    label="Examples",
                    cache_examples=False,
                    examples_per_page=16
                )

    # Generate flows straight into preprocessing; .success() only fires the
    # background removal if generation did not raise.
    submit_prompt.click(fn=generate_image, inputs=[subject, style, color_scheme, angle, lighting_type, additional_details], outputs=input_image).success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image]
    )
    # Manual path: validate an image is present before preprocessing.
    submit_process.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image],
    )

    return input_image, processed_image