File size: 6,370 Bytes
6da2189
49c69e8
 
 
 
90d82c3
49c69e8
 
 
2b7f523
e602685
7138b0a
 
 
e602685
 
 
6da2189
e602685
 
49c69e8
 
 
 
e602685
 
49c69e8
e602685
49c69e8
 
e602685
 
06c55e8
 
 
 
 
 
 
 
 
e602685
06c55e8
49c69e8
6da2189
49c69e8
 
6da2189
 
49c69e8
 
6da2189
c9928a7
49c69e8
 
 
 
c9928a7
49c69e8
6da2189
49c69e8
 
06c55e8
af0ed3b
49c69e8
c9928a7
49c69e8
 
 
 
 
 
 
e602685
49c69e8
 
4527b8b
49c69e8
4527b8b
e602685
 
 
49c69e8
e602685
 
 
57320b0
 
 
 
49c69e8
e602685
 
 
49c69e8
e602685
4347f54
e602685
49c69e8
e602685
 
 
49c69e8
4527b8b
 
06c55e8
57320b0
49c69e8
e602685
 
49c69e8
 
e602685
 
af0ed3b
 
 
 
e602685
af0ed3b
e602685
49c69e8
 
06c55e8
8278e6d
5a0d186
49c69e8
 
 
6da2189
49c69e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from diffusers import CycleDiffusionPipeline, DDIMScheduler
import gradio as gr
import torch
from PIL import Image
import utils
import streamlit as st

is_colab = utils.is_google_colab()

# Load the CycleDiffusion pipeline once at import time; every request reuses `pipe`.
# NOTE(review): `use_auth_token=` and `DDIMScheduler.from_config(<repo id>, ...)` are
# older diffusers spellings (newer releases prefer `token=` / `from_pretrained`) —
# confirm against the pinned diffusers version before changing them.
model_id_or_path = "CompVis/stable-diffusion-v1-4"
# Hugging Face access token, read once from Streamlit's secrets store.
_auth_token = st.secrets["USER_TOKEN"]
scheduler = DDIMScheduler.from_config(model_id_or_path,
                                      use_auth_token=_auth_token,
                                      subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
                                              use_auth_token=_auth_token,
                                              scheduler=scheduler)

if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# Hardware label shown in the page header ("Running on <device>").
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"


def inference(source_prompt, target_prompt, source_guidance_scale=1, guidance_scale=5, num_inference_steps=100,
              width=512, height=512, seed=0, img=None, strength=0.7):
    """Edit ``img`` from ``source_prompt`` towards ``target_prompt`` with CycleDiffusion.

    The input image is rescaled, keeping its aspect ratio, so that it fits inside a
    ``width`` x ``height`` box. It is then passed to the module-level ``pipe``.

    Args:
        source_prompt: Text describing the input image.
        target_prompt: Text describing the desired output image.
        source_guidance_scale: Guidance scale applied to the source prompt.
        guidance_scale: Guidance scale applied to the target prompt.
        num_inference_steps: Number of inference steps passed to the pipeline.
        width: Maximum width of the rescaled input image, in pixels.
        height: Maximum height of the rescaled input image, in pixels.
        seed: Seed for torch's global RNG, so repeated runs are reproducible.
        img: Input PIL image. Required despite the ``None`` default.
        strength: Passed through to the pipeline as ``strength``.

    Returns:
        The first generated image, as returned by ``replace_nsfw_images``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img is None:
        # Show a readable UI message instead of an AttributeError on ``img.height``.
        raise gr.Error("Please upload an input image.")

    torch.manual_seed(seed)

    # Fit the image inside width x height while preserving its aspect ratio.
    ratio = min(height / img.height, width / img.width)
    img = img.resize((int(img.width * ratio), int(img.height * ratio)))

    results = pipe(prompt=target_prompt,
                   source_prompt=source_prompt,
                   init_image=img,
                   num_inference_steps=num_inference_steps,
                   eta=0.1,  # fixed for this demo; not exposed in the UI
                   strength=strength,
                   guidance_scale=guidance_scale,
                   source_guidance_scale=source_guidance_scale,
                   )

    return replace_nsfw_images(results)


def replace_nsfw_images(results):
    """Swap safety-flagged images in ``results`` for a placeholder; return the first image.

    ``results.images`` is modified in place. Each image whose matching entry in
    ``results.nsfw_content_detected`` is truthy is replaced by ``nsfw.png``; that path is
    resolved relative to the current working directory. When
    ``nsfw_content_detected`` is ``None`` there are no per-image flags, and the images
    are returned untouched.

    Returns:
        ``results.images[0]`` after any replacement.
    """
    flags = results.nsfw_content_detected
    if flags is None:
        # NOTE(review): presumably None when the pipeline runs without a safety
        # checker — confirm for the pinned diffusers version. Indexing None would raise.
        return results.images[0]

    # Pair each image slot with its flag; stop at whichever sequence is shorter.
    for i, flagged in zip(range(len(results.images)), flags):
        if flagged:
            results.images[i] = Image.open("nsfw.png")
    return results.images[0]


# Stylesheet for the header banner (the ``cycle-diffusion-div`` HTML) and the tab strip.
# One rule per fragment; adjacent literals concatenate into a single string.
css = (
    ".cycle-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}"
    ".cycle-diffusion-div div h1{font-weight:900;margin-bottom:7px}"
    ".cycle-diffusion-div p{margin-bottom:10px;font-size:94%}"
    ".cycle-diffusion-div p a{text-decoration:underline}"
    ".tabs{margin-top:0;margin-bottom:0}"
    "#gallery{min-height:20rem}"
    "\n"
)
with gr.Blocks(css=css) as demo:
    # Header banner: title, documentation/paper links, Colab link and hardware line.
    gr.HTML(
        f"""
            <div class="cycle-diffusion-div">
              <div>
                <h1>CycleDiffusion with Stable Diffusion</h1>
              </div>
              <p>
                Demo for CycleDiffusion with Stable Diffusion. <br>
                <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/cycle_diffusion">🧨 Pipeline doc</a> | <a href="https://arxiv.org/abs/2210.05559">📄 Paper link</a>
              </p>
              <p>You can skip the queue in the colab: <a href="https://colab.research.google.com/gist/ChenWu98/0aa4fe7be80f6b45d3d055df9f14353a/copy-of-fine-tuned-diffusion-gradio.ipynb"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
              <p>
               Running on <b>{device}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
              </p>
            </div>
        """
    )
    with gr.Row():

        # Left column: the image to edit and the edited result.
        with gr.Column(scale=55):
            with gr.Group():

                img = gr.Image(label="Input image", height=512, tool="editor", type="pil")

                image_out = gr.Image(label="Output image", height=512)

        # Right column: prompts and generation options.
        with gr.Column(scale=45):
            with gr.Tab("Options"):
                with gr.Group():
                    with gr.Row():
                        source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
                    with gr.Row():
                        target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")

                    with gr.Row():
                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                        guidance_scale = gr.Slider(label="Target guidance scale", value=5, minimum=1, maximum=10)

                    with gr.Row():
                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                        strength = gr.Slider(label="Strength", value=0.7, minimum=0.5, maximum=1, step=0.01)

                    with gr.Row():
                        width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
                        height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)

                    with gr.Row():
                        seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
                    with gr.Row():
                        generate = gr.Button(value="Edit")

    # Order must match inference's positional parameters.
    inputs = [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
              width, height, seed, img, strength]
    generate.click(inference, inputs=inputs, outputs=image_out)

    def _run_example(src_prompt, tgt_prompt, src_scale, tgt_scale, steps, example_seed, example_img,
                     example_strength):
        """Adapt an example row, which has no width/height, to ``inference``'s signature."""
        # Forward by keyword so width/height keep inference's defaults instead of
        # silently receiving the seed and the image.
        return inference(src_prompt, tgt_prompt, src_scale, tgt_scale, steps,
                         seed=example_seed, img=example_img, strength=example_strength)

    ex = gr.Examples(
        [
            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100, 0, "images/astronaut_horse.png", 0.8],
            ["A black colored car.", "A blue colored car.", 1, 2, 100, 0, "images/black_car.png", 0.85],
            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, 0, "images/mausoleum.png", 0.9],
            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, 0, "images/apple_bag.png", 0.9],
        ],
        [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps, seed, img, strength],
        image_out, _run_example, cache_examples=False)

    gr.Markdown('''
      Space built with Diffusers 🧨 by HuggingFace 🤗.
      [![Twitter Follow](https://img.shields.io/twitter/follow/ChenHenryWu?style=social)](https://twitter.com/ChenHenryWu) 
      ![visitors](https://visitor-badge.glitch.me/badge?page_id=ChenWu98.CycleDiffusion)
    ''')

if is_colab:
    # Colab: launch directly, without enabling the request queue.
    demo.launch(debug=is_colab, share=is_colab)
else:
    # Everywhere else: enable the queue with concurrency_count=1 before launching.
    demo.queue(concurrency_count=1)
    demo.launch(debug=is_colab, share=is_colab)