File size: 3,394 Bytes
a660631
 
 
 
 
 
77d5c02
f92529f
a660631
 
 
 
 
 
 
f521e88
cce8954
 
a660631
6df5354
a660631
53b39de
f521e88
a660631
f521e88
a660631
554ac76
f521e88
a660631
f521e88
a6d82aa
 
 
f521e88
208f8fb
a660631
f521e88
84448a9
8fad46e
 
 
 
 
 
 
 
a660631
f521e88
a660631
d6252d0
f521e88
d6252d0
f521e88
a660631
d6252d0
 
f521e88
d6252d0
f521e88
 
 
 
d6252d0
f521e88
a660631
 
f521e88
a660631
 
681c919
 
 
 
f521e88
681c919
3f0b2d3
 
681c919
 
 
 
 
a660631
074f54b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
"""Entry-point script for the ControlNet++ Gradio demo.

Builds a tabbed UI (one tab per conditioning task) around a single shared
`Model` instance and exposes controls for inspecting/changing the base model.
"""

from __future__ import annotations

import gradio as gr
import torch

# HACK: replace torch.jit.script with a no-op identity decorator so that any
# library code decorated with @torch.jit.script runs as plain Python instead of
# being compiled. Presumably a workaround for a scripting failure in this
# environment — TODO confirm which dependency requires it.
torch.jit.script = lambda f: f
# Imported for the Hugging Face Spaces runtime (e.g. ZeroGPU support); the name
# is not referenced directly — the import itself is assumed to have the needed
# side effects. NOTE(review): verify this import is still required.
import spaces

# Each app_* module contributes one tab's UI; create_demo wires the given
# processing callable into that tab's components.
from app_canny import create_demo as create_demo_canny
from app_depth import create_demo as create_demo_depth
from app_lineart import create_demo as create_demo_lineart
from app_segmentation import create_demo as create_demo_segmentation
from app_softedge import create_demo as create_demo_softedge
from model import Model
from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
from transformers.utils.hub import move_cache

# Migrate any old-layout transformers cache to the current cache layout at
# startup, before any model weights are loaded.
move_cache()

# Markdown shown at the top of the page: paper link plus a legend for the
# three output rows produced by each task.
DESCRIPTION = "# [ControlNet++: Improving Conditional Controls with Efficient Consistency Feedback](https://arxiv.org/abs/2404.07987) \n ### The first row in outputs is the input conditions. The second row is the images generated by ControlNet++. The third row is the conditions extracted from our generated images."

# Warn visitors when no GPU is available — the demo is not usable on CPU.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Single shared model instance used by every tab; it starts on the "Canny"
# task and switches tasks internally via its process_* methods.
model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")


# Assemble the UI. All components must be created inside this context manager;
# the resulting `demo` object is also what the Spaces runtime serves.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Shown only when the Space wants to advertise duplication (e.g. because
    # base-model changes are disabled on the public instance).
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=SHOW_DUPLICATE_BUTTON,
    )

    # One tab per conditioning task; each create_demo_* builds that tab's
    # inputs/outputs and binds them to the corresponding model method.
    with gr.Tabs():
        with gr.TabItem("Lineart"):
            create_demo_lineart(model.process_lineart)
        with gr.TabItem("Depth"):
            create_demo_depth(model.process_depth)
        with gr.TabItem("Segmentation"):
            create_demo_segmentation(model.process_segmentation)
        with gr.TabItem("SoftEdge"):
            create_demo_softedge(model.process_softedge)
        with gr.TabItem("Canny"):
            create_demo_canny(model.process_canny)

    # Collapsed panel for inspecting and (optionally) swapping the Stable
    # Diffusion base model that the ControlNet weights are attached to.
    with gr.Accordion(label="Base model", open=False):
        with gr.Row():
            with gr.Column(scale=5):
                current_base_model = gr.Text(label="Current base model")
            with gr.Column(scale=1):
                check_base_model_button = gr.Button("Check current base model")
        with gr.Row():
            with gr.Column(scale=5):
                new_base_model_id = gr.Text(
                    label="New base model",
                    max_lines=1,
                    placeholder="runwayml/stable-diffusion-v1-5",
                    info="The base model must be compatible with Stable Diffusion v1.5.",
                    # Input is read-only when base-model changes are disabled.
                    interactive=ALLOW_CHANGING_BASE_MODEL,
                )
            with gr.Column(scale=1):
                change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
        if not ALLOW_CHANGING_BASE_MODEL:
            gr.Markdown(
                """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
            )

    # Reading the current base model is cheap, so it bypasses the queue and is
    # exposed as a named API endpoint.
    check_base_model_button.click(
        fn=lambda: model.base_model_id,
        outputs=current_base_model,
        queue=False,
        api_name="check_base_model",
    )
    # Changing the base model can be triggered either by pressing Enter in the
    # text box or by clicking the button; it is kept off the public API
    # (api_name=False) — presumably to avoid programmatic model swaps.
    gr.on(
        triggers=[new_base_model_id.submit, change_base_model_button.click],
        fn=model.set_base_model,
        inputs=new_base_model_id,
        outputs=current_base_model,
        api_name=False,
    )

if __name__ == "__main__":
    # Bound queue so the GPU isn't oversubscribed when run as a script.
    demo.queue(max_size=20).launch()