File size: 3,315 Bytes
ef187eb
 
e7915f0
0cffd40
 
ef187eb
11fa80e
63b6eaf
2b0f02c
11fa80e
0cffd40
8b1e96d
0cffd40
8b1e96d
 
0cccf69
 
8b1e96d
 
 
ec35e66
 
 
 
4efab5c
 
 
ec35e66
 
4efab5c
 
 
 
 
 
 
 
 
8b1e96d
 
e7915f0
 
8b1e96d
 
 
f286ae5
4b64a91
8b1e96d
9b38787
3a2b9b2
8b1e96d
9b38787
11fa80e
8b1e96d
 
 
 
3494613
6380dba
8b1e96d
3819ced
67399b5
1462211
fee8445
1462211
ef187eb
1462211
8b1e96d
0cffd40
 
8b3ca8d
 
 
 
 
 
 
 
 
 
0cffd40
8b1e96d
0cffd40
4efab5c
3639a4a
9b38787
8b1e96d
0cffd40
9b38787
fe16630
8b1e96d
8b3ca8d
 
 
 
 
 
fe16630
8b3ca8d
8b1e96d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
import requests
from translatepy import Translator

# Module-wide translatepy client; generate_image passes every prompt through
# it so that non-English prompts are converted to English before generation.
translator = Translator()

# Constants
# Hub id of the SDXL base model; its pipeline and UNet config are reused.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repo hosting the distilled DMD2 UNet weight files.
repo = "tianweiy/DMD2"
# UI label -> [weight filename inside `repo`, number of inference steps].
checkpoints = dict(
    (label, [filename, steps])
    for label, filename, steps in (
        ("1-Step", "dmd2_sdxl_1step_unet_fp16.bin", 1),
        ("4-Step", "dmd2_sdxl_4step_unet_fp16.bin", 4),
    )
)
# Step count of the checkpoint currently loaded into the pipeline's UNet;
# None until generate_image loads DMD2 weights for the first time.
loaded = None

CSS = """
.gradio-container {
  max-width: 690px !important;
}
footer {
    visibility: hidden;
}
"""

JS = """function () {
  gradioURL = window.location.href
  if (!gradioURL.endsWith('?__theme=dark')) {
    window.location.replace(gradioURL + '?__theme=dark');
  }
}"""



# Build the SDXL pipeline once at import time, only when CUDA is available.
# NOTE(review): when CUDA is unavailable `pipe` is never defined, so
# generate_image would fail with NameError — confirm this is acceptable.
if torch.cuda.is_available():
    # UNet built from the SDXL base *config* only; generate_image later loads
    # the DMD2 checkpoint weights into it via pipe.unet.load_state_dict.
    unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
    # Pass the UNet in explicitly: without `unet=unet` the pipeline loads its
    # own base SDXL UNet as well, leaving this one unused on the GPU.
    pipe = DiffusionPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")


# Inference
@spaces.GPU()
def generate_image(prompt, ckpt="4-Step"):
    """Translate *prompt* to English and render it with a DMD2 checkpoint.

    Args:
        prompt: Text prompt in any language; it is translated to English with
            the module-level ``translator`` before generation.
        ckpt: Key into ``checkpoints`` ("1-Step" or "4-Step") selecting the
            DMD2 UNet weight file and the number of denoising steps.

    Returns:
        The first image of the pipeline output (``results.images[0]``).

    Raises:
        KeyError: if *ckpt* is not a key of ``checkpoints``.
    """
    global loaded

    prompt = str(translator.translate(prompt, 'English'))
    print(prompt)

    checkpoint, num_inference_steps = checkpoints[ckpt]

    # Swap UNet weights only when a different checkpoint is requested;
    # `loaded` remembers the step count of the weights currently in pipe.unet.
    if loaded != num_inference_steps:
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        # The file is only consumed by load_state_dict, so restrict unpickling
        # to tensors/containers instead of allowing arbitrary pickled objects.
        state_dict = torch.load(
            hf_hub_download(repo, checkpoint), map_location="cuda", weights_only=True
        )
        pipe.unet.load_state_dict(state_dict)
        loaded = num_inference_steps

    # Fixed timestep schedules: a single step at t=399, or four steps.
    timesteps = [399] if num_inference_steps == 1 else [999, 749, 499, 249]

    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0, timesteps=timesteps)
    return results.images[0]


# Sample prompts shown under the input via gr.Examples (prompt input only).
examples = list(
    (
        "a cat eating a piece of cheese",
        "a ROBOT riding a BLUE horse on Mars, photorealistic",
        "Ironman VS Hulk, ultrarealistic",
        "a CUTE robot artist painting on an easel",
        "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
        "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
        "Kids going to school, Anime style",
    )
)


# Gradio Interface

with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center><br><center>Multi-Languages, 4-step is higher quality & 2X slower</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter Your Prompt', scale=8)
            # Offer exactly the keys generate_image looks up in `checkpoints`,
            # so the UI cannot drift out of sync with the available weights.
            ckpt = gr.Dropdown(label='Steps', choices=list(checkpoints), value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')
    # Examples feed only the prompt, so generate_image runs with its default
    # ckpt ("4-Step") for them.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=img,
        fn=generate_image,
        cache_examples="lazy",
    )

    # Pressing Enter in the textbox and clicking the button share one handler.
    prompt.submit(fn=generate_image,
                  inputs=[prompt, ckpt],
                  outputs=img,
                  )
    submit.click(fn=generate_image,
                 inputs=[prompt, ckpt],
                 outputs=img,
                 )

demo.queue().launch()