File size: 5,939 Bytes
262b155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9659375
262b155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Authors: Hui Ren (rhfeiyang.github.io)
import os

import gradio as gr
from diffusers import DiffusionPipeline
import matplotlib.pyplot as plt
import torch
from PIL import Image



# Pick the compute device once at import time: CUDA when torch reports it
# as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained pipeline by repo id, then move it onto the chosen device.
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1")
pipe = pipe.to(device)

from inference import get_lora_network, inference, get_validation_dataloader
# (display name, adapter directory name) pairs. The display name is the key
# looked up by demo_inference_gen; the directory name is substituted into the
# "data/Art_adapters/<dir>/..." checkpoint path. The "None" entry maps to the
# literal string "None", which demo_inference_gen treats as "no adapter path".
_ADAPTER_SUBSETS = (
    ("None", "None"),
    ("Andre Derain", "andre-derain_subset1"),
    ("Vincent van Gogh", "van_gogh_subset1"),
    ("Andy Warhol", "andy_subset1"),
    ("Walter Battiss", "walter-battiss_subset2"),
    ("Camille Corot", "camille-corot_subset1"),
    ("Claude Monet", "monet_subset2"),
    ("Pablo Picasso", "picasso_subset1"),
    ("Jackson Pollock", "jackson-pollock_subset1"),
    ("Gerhard Richter", "gerhard-richter_subset1"),
    ("M.C. Escher", "m.c.-escher_subset1"),
    ("Albert Gleizes", "albert-gleizes_subset1"),
    ("Hokusai", "katsushika-hokusai_subset1"),
    ("Wassily Kandinsky", "kandinsky_subset1"),
    ("Gustav Klimt", "klimt_subset3"),
    ("Roy Lichtenstein", "roy-lichtenstein_subset1"),
    ("Henri Matisse", "henri-matisse_subset1"),
    ("Joan Miro", "joan-miro_subset2"),
)

# Insertion order of the pairs above is preserved by dict().
lora_map = dict(_ADAPTER_SUBSETS)

def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
    """Run from-scratch generation for one prompt, optionally with an art adapter.

    Args:
        adapter_choice: Key into ``lora_map`` selecting the adapter.
        prompt: Text prompt; it is repeated ``samples`` times to form the
            prompt list handed to ``get_validation_dataloader``.
        samples: Number of copies of ``prompt`` to generate for.
        seed: Forwarded to ``inference`` as ``seed``.
        steps: Forwarded to ``inference`` as ``steps``.
        guidance_scale: Forwarded to ``inference`` as ``guidance_scale``.

    Returns:
        ``inference(...)[0][1.0]``. Presumably the images produced at
        scale 1.0 -- TODO confirm against ``inference``.

    Raises:
        KeyError: If ``adapter_choice`` is not a key of ``lora_map``.
    """
    adapter_name = lora_map[adapter_choice]
    if adapter_name is None or adapter_name == "None":
        # No adapter selected: hand the sentinel through unchanged.
        adapter_path = adapter_name
    else:
        adapter_path = f"data/Art_adapters/{adapter_name}/adapter_alpha1.0_rank1_all_up_1000steps.pt"

    loader = get_validation_dataloader([prompt] * samples)
    lora = get_lora_network(pipe.unet, adapter_path)["network"]
    results = inference(
        lora,
        pipe.tokenizer,
        pipe.text_encoder,
        pipe.vae,
        pipe.unet,
        pipe.scheduler,
        loader,
        height=512,
        width=512,
        scales=[1.0],
        save_dir=None,
        seed=seed,
        steps=steps,
        guidance_scale=guidance_scale,
        start_noise=-1,
        show=False,
        style_prompt="sks art",
        no_load=True,
        from_scratch=True,
    )
    # Only a single scale (1.0) was requested above; pick that entry out.
    return results[0][1.0]

def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
    """Run ``inference`` on given images with an adapter at scales 0.0 and 1.0.

    Args:
        adapter_path: Passed to ``get_lora_network`` together with the
            ``"all_up"`` argument.
        prompts: Prompts passed to ``get_validation_dataloader``.
        image: Images passed to ``get_validation_dataloader`` alongside
            ``prompts``.
        start_noise: Forwarded to ``inference`` as ``start_noise``.
        seed: Forwarded to ``inference`` as ``seed``.

    Returns:
        Whatever ``inference`` returns, unmodified.
    """
    loader = get_validation_dataloader(prompts, image)
    lora = get_lora_network(pipe.unet, adapter_path, "all_up")["network"]
    # Fixed settings for this mode: 20 steps, guidance 7.5, not from scratch.
    return inference(
        lora,
        pipe.tokenizer,
        pipe.text_encoder,
        pipe.vae,
        pipe.unet,
        pipe.scheduler,
        loader,
        height=512,
        width=512,
        scales=[0., 1.],
        save_dir=None,
        seed=seed,
        steps=20,
        guidance_scale=7.5,
        start_noise=start_noise,
        show=True,
        style_prompt="sks art",
        no_load=True,
        from_scratch=False,
    )

# def infer(prompt, samples, steps, scale, seed):
#     generator = torch.Generator(device=device).manual_seed(seed)
#     images_list = pipe(  # type: ignore
#         [prompt] * samples,
#         num_inference_steps=steps,
#         guidance_scale=scale,
#         generator=generator,
#     )
#     images = []
#     safe_image = Image.open(r"data/unsafe.png")
#     print(images_list)
#     for i, image in enumerate(images_list["images"]):  # type: ignore
#         if images_list["nsfw_content_detected"][i]:  # type: ignore
#             images.append(safe_image)
#         else:
#             images.append(image)
#     return images




# Gradio UI for direct text-to-image inference: a prompt box and Run button,
# a gallery for the results, and a row of generation options. Submitting the
# prompt or clicking Run calls demo_inference_gen.
block = gr.Blocks()
with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
                value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
            )

            btn = gr.Button("Run", scale=0)
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[2],
        )

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

        with gr.Row(elem_id="advanced-options"):
            adapter_choice = gr.Dropdown(
                label="Choose adapter",
                # Take the choices from lora_map itself so the dropdown can
                # never offer a name that demo_inference_gen cannot look up.
                choices=list(lora_map),
                value="None"
            )

            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        # Input order must match demo_inference_gen's positional parameters:
        # (adapter_choice, prompt, samples, seed, steps, guidance_scale).
        gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
        # NOTE(review): this listener has no function and no inputs, so it does
        # not toggle the "advanced-options" row. Presumably a JS toggle was
        # intended here -- confirm and either wire it up or drop the button.
        advanced_button.click(
            None,
            [],
            text,
        )



block.launch()