import os

import gradio as gr
import torch

from clip_interrogator import Config, Interrogator
from diffusers import StableDiffusionPipeline
from transformers import file_utils

from ditail import DitailDemo, seed_everything

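# Display names shown in the UI mapped to Hugging Face model repo ids.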
BASE_MODEL = {
    'sd1.5': 'runwayml/stable-diffusion-v1-5',
    'realistic vision': 'stablediffusionapi/realistic-vision-v51',
    'pastel mix (anime)': 'stablediffusionapi/pastel-mix-stylized-anime',
    # 'chaos (abstract)': 'MAPS-research/Chaos3.0',
}

# LoRA trigger words
LORA_TRIGGER_WORD = {
    'none': [],
    'film': ['film overlay', 'film grain'],
    'snow': ['snow'],
    'flat': ['sdh', 'flat illustration'],
    'minecraft': ['minecraft square style', 'cg, computer graphics'],
    'animeoutline': ['lineart', 'monochrome'],
    'impressionism': ['impressionist', 'in the style of Monet'],
    'pop': ['POP ART'],
    'shinkai_makoto': ['shinkai makoto', 'kimi no na wa.', 'tenki no ko', 'kotonoha no niwa'],
}

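# Run arguments echoed back to the user in the metadata panel of the UI.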
METADATA_TO_SHOW = ['inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps', 'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega']


class WebApp:
    def __init__(self, debug_mode=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

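        # Default Ditail arguments; values from the UI components override these per run.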
        self.args_base = {
            "seed": 42,
            "device": self.device,
            "output_dir": "output_demo",
            "caption_model_name": "blip-large",
            "clip_model_name": "ViT-L-14/openai",
            "inv_model": "stablediffusionapi/realistic-vision-v51",
            "spl_model": "runwayml/stable-diffusion-v1-5",
            "inv_steps": 50,
            "spl_steps": 50,
            "img": None,
            "pos_prompt": '',
            "neg_prompt": 'worst quality, blurry, NSFW',
            "alpha": 3.0,
            "beta": 0.5,
            "omega": 15,
            "mask": None,
            "lora": "none",
            "lora_dir": "./ditail/lora",
            "lora_scale": 0.7,
            "no_injection": False,
        }

        self.args_input = {}  # gr components only; insertion order must match the positional values in run_ditail
        self.gr_loras = list(LORA_TRIGGER_WORD.keys())

        self.gtag = os.environ.get('GTag')

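        # Google Analytics: ga_script is injected into the page head, ga_load configures gtag on app load.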
        self.ga_script = f"""
            <script async src="https://www.googletagmanager.com/gtag/js?id={self.gtag}"></script>
            """
        self.ga_load = f"""
            function() {{
                window.dataLayer = window.dataLayer || [];
                function gtag(){{dataLayer.push(arguments);}}
                gtag('js', new Date());

                gtag('config', '{self.gtag}');
            }}
            """
        
        # # pre-download base models for a smoother first-run experience
        # self._preload_pipeline()

        self.debug_mode = debug_mode # skip the CLIP interrogator in debug mode for faster startup
        if not self.debug_mode and self.device == "cuda":
            self.init_interrogator()


    def init_interrogator(self):
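        # Set up the CLIP Interrogator used to auto-caption uploaded images.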
        cache_path = os.environ.get('HF_HOME')
        # print(f"Intended cache dir: {cache_path}")
        config = Config()
        config.cache_path = cache_path
        config.clip_model_path = cache_path
        config.clip_model_name = self.args_base['clip_model_name']
        config.caption_model_name = self.args_base['caption_model_name']
        self.ci = Interrogator(config)
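        # Larger chunk sizes for ViT-L-14; smaller values presumably bound memory usage for other CLIP models.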
        self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
        self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024

        # print(f"HF cache dir: {file_utils.default_cache_path}")

    def _preload_pipeline(self):
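        # Warm the local cache so a user's first request doesn't block on model downloads.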
        for model in BASE_MODEL.values():
            pipe = StableDiffusionPipeline.from_pretrained(
                model, torch_dtype=torch.float16
            ).to(self.args_base['device'])
        pipe = None  # drop the reference so the pipelines can be garbage collected


    def title(self):
        gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <div>
                    <h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
                    <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
                        <a class="flex-item" href="https://arxiv.org/abs/2312.08873" target="_blank">
                            <img src="https://img.shields.io/badge/arXiv-Paper-darkred.svg" alt="arXiv Paper">
                        </a>                      
                        <a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
                            <img src="https://img.shields.io/badge/Website-Ditail-yellow.svg" alt="Project Page">
                        </a>
                        <a class="flex-item" href="https://github.com/MAPS-research/Ditail" target="_blank">
                            <img src="https://img.shields.io/badge/Github-Code-green.svg" alt="GitHub Code">
                        </a>
                    </div>
                </div>
                </div>
                """
                )
    

    def device_requirements(self):
        gr.Markdown(
            """
            <center>
            <h2>
            Attention: this demo does not work on a CPU-only Space. \
            Please duplicate it and upgrade to a private "T4 medium" GPU.
            </h2>
            </center>
            """
        )
        gr.DuplicateButton(size='lg', scale=1, variant='primary')

    def get_image(self):
        self.args_input['img'] = gr.Image(label='content image', type='pil', show_share_button=False, elem_classes="input_image")
    
    def get_prompts(self):
        generate_prompt = gr.Checkbox(label='generate prompt with clip', value=True)
        self.args_input['pos_prompt'] = gr.Textbox(label='prompt')
             
        # event listeners: re-caption on image upload and when the checkbox is toggled
        self.args_input['img'].upload(self._interrogate_image, inputs=[self.args_input['img'], generate_prompt], outputs=[self.args_input['pos_prompt']])
        generate_prompt.change(self._interrogate_image, inputs=[self.args_input['img'], generate_prompt], outputs=[self.args_input['pos_prompt']])


    def _interrogate_image(self, image, generate_prompt):
        if hasattr(self, 'ci') and image is not None and generate_prompt:
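            # Keep only the first comma-separated clause and strip the common BLIP caption artifact 'arafed'.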
            return self.ci.interrogate_fast(image).split(',')[0].replace('arafed', '')
        else:
            return ''
        

    def get_base_model(self):
        self.args_input['spl_model'] = gr.Radio(choices=list(BASE_MODEL.keys()), value=list(BASE_MODEL.keys())[2], label='target base model')

    def get_lora(self, num_cols=3):
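        # gr.Gallery is not a form input, so the selected LoRA key is tracked in a gr.State instead.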
        self.args_input['lora'] = gr.State('none')
        self.lora_gallery = gr.Gallery(label='target LoRA (optional)', columns=num_cols, value=[(os.path.join(self.args_base['lora_dir'], f"{lora}.jpeg"), lora) for lora in self.gr_loras], allow_preview=False, show_share_button=False)
        self.lora_gallery.select(self._update_lora_selection, inputs=[], outputs=[self.args_input['lora']])
    
    def _update_lora_selection(self, selected_state: gr.SelectData):
        return self.gr_loras[selected_state.index]

    def get_params(self):
        with gr.Row():
            with gr.Column():
                self.args_input['inv_model'] = gr.Radio(choices=list(BASE_MODEL.keys()), value=list(BASE_MODEL.keys())[1], label='inversion base model')
                self.args_input['neg_prompt'] = gr.Textbox(label='negative prompt', value=self.args_base['neg_prompt'])
                self.args_input['alpha'] = gr.Number(label='positive prompt scaling weight (alpha)', value=self.args_base['alpha'], interactive=True)
                self.args_input['beta'] = gr.Number(label='negative prompt scaling weight (beta)', value=self.args_base['beta'], interactive=True)

            with gr.Column():
                self.args_input['omega'] = gr.Slider(label='cfg', value=self.args_base['omega'], maximum=25, interactive=True)
                
                self.args_input['inv_steps'] = gr.Slider(minimum=1, maximum=100, label='edit steps', interactive=True, value=self.args_base['inv_steps'], step=1)
                self.args_input['spl_steps'] = gr.Slider(minimum=1, maximum=100, label='sample steps', interactive=False, value=self.args_base['spl_steps'], step=1, visible=False)
                # sync inv_steps with spl_steps
                self.args_input['inv_steps'].change(lambda x: x, inputs=self.args_input['inv_steps'], outputs=self.args_input['spl_steps'])

                self.args_input['lora_scale'] = gr.Slider(minimum=0, maximum=1, label='LoRA scale', value=0.7)
                self.args_input['seed'] = gr.Number(label='seed', value=self.args_base['seed'], interactive=True, precision=0, step=1)

    def run_ditail(self, *values):
        gr_args = self.args_base.copy()
        # print(self.args_input.keys())
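        # Component values arrive positionally from submit_btn.click; pair them with keys in insertion order.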
        for k, v in zip(list(self.args_input.keys()), values):
            gr_args[k] = v
        # fall back to 'none' when an example passes lora as a non-string state
        gr_args['lora'] = 'none' if not isinstance(gr_args['lora'], str) else gr_args['lora']
        print('selected lora: ', gr_args['lora'])
        # prepend the LoRA trigger words to the positive prompt
        gr_args['pos_prompt'] = ', '.join(LORA_TRIGGER_WORD.get(gr_args['lora'], []) + [gr_args['pos_prompt']])
        # map display names to Hugging Face repo ids
        gr_args['inv_model'] = BASE_MODEL[gr_args['inv_model']]
        gr_args['spl_model'] = BASE_MODEL[gr_args['spl_model']]
        print('selected model: ', gr_args['inv_model'], gr_args['spl_model'])

        seed_everything(gr_args['seed'])
        ditail = DitailDemo(gr_args)
        

        args_to_show = {}
        for key in METADATA_TO_SHOW:
            args_to_show[key] = gr_args[key]

        img = ditail.run_ditail()

        # release the pipeline to free memory
        ditail = None

        return img, args_to_show

    def run_example(self, *values):
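        # Examples bypass the full pipeline: map the example fields onto the defaults and return a pre-rendered image.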
        gr_args = self.args_base.copy()
        for k, v in zip(['img', 'pos_prompt', 'inv_model', 'spl_model', 'lora'], values):
            gr_args[k] = v
        args_to_show = {}
        for key in METADATA_TO_SHOW:
            args_to_show[key] = gr_args[key]
        img = os.path.join(os.path.dirname(__file__), "example", "Cocktail_impression.jpg")
        # self.lora_gallery.selected_index = self.gr_loras.index(gr_args['lora'])
        return img, args_to_show
        

    def show_credits(self):
        gr.Markdown(
            """
            ### Model Credits
            * Diffusion Models are downloaded from [huggingface](https://huggingface.co): [stable diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [realistic vision](https://huggingface.co/stablediffusionapi/realistic-vision-v51), [pastel mix](https://huggingface.co/stablediffusionapi/pastel-mix-stylized-anime)
            * LoRA Models are downloaded from [civitai](https://civitai.com) and [liblib](https://www.liblib.art): [film](https://civitai.com/models/90393/japan-vibes-film-color), [snow](https://www.liblib.art/modelinfo/f732b23b02f041bdb7f8f3f8a256ca8b), [flat](https://www.liblib.art/modelinfo/76dcb8b59d814960b0244849f2747a15), [minecraft](https://civitai.com/models/113741/minecraft-square-style), [animeoutline](https://civitai.com/models/16014/anime-lineart-manga-like-style), [impressionism](https://civitai.com/models/113383/y5-impressionism-style), [pop](https://civitai.com/models/161450?modelVersionId=188417), [shinkai_makoto](https://civitai.com/models/10626?modelVersionId=12610) 
            """
        )


    def ui(self):
        with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:

            self.title()

            if self.device == "cpu":
                self.device_requirements()

            with gr.Row():
                self.get_image()

                with gr.Column():
                    self.get_prompts()
                    self.get_base_model()
                    self.get_lora(num_cols=3)
                    submit_btn = gr.Button("Generate", variant='primary')
                    if self.device == 'cpu':
                        submit_btn.variant = 'secondary'

            with gr.Accordion("advanced options", open=False):
                self.get_params()   
            
            with gr.Row():
                with gr.Column():
                    output_image = gr.Image(label="output image")
                metadata = gr.JSON(label='metadata')

                submit_btn.click(self.run_ditail,
                                inputs=list(self.args_input.values()),
                                outputs=[output_image, metadata],
                                scroll_to_output=True,
                                )

            with gr.Row():
                cache_examples = not self.debug_mode
                gr.Examples(
                    examples=[[os.path.join(os.path.dirname(__file__), "example", "Cocktail.jpg"), 'a glass of a cocktail with a lime wedge on it', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[1], 'impressionism']],
                    inputs=[self.args_input['img'], self.args_input['pos_prompt'], self.args_input['inv_model'], self.args_input['spl_model'], gr.Textbox(label='LoRA', visible=False), ],
                    fn=self.run_example,
                    outputs=[output_image, metadata],
                    run_on_click=True,
                    # cache_examples=cache_examples,
                )

            self.show_credits()
        
            demo.load(None, js=self.ga_load)
        return demo


app = WebApp(debug_mode=False)
demo = app.ui()


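# A minimal sketch for running this locally without a public link (assumes default
# Gradio networking; the filename `app.py` is hypothetical):
#   $ python app.py
# or, in code, replace the launch call below with e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860, share=False)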
if __name__ == "__main__":
    demo.launch(share=True)