import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
#import spaces

models_list = []
loras_list = ["None"]
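# Cache of loaded pipelines, keyed by model name.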
models = {}

def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    # Stream the download so large checkpoints are never held fully in memory.
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KiB
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print(f"ERROR: download of {filename} was incomplete")

def download_civitai_model(model_id, lora_id=""):
    def fetch_and_download(id_or_url):
        # Accept either a full download URL or a bare CivitAI model-version ID.
        if id_or_url.startswith("http"):
            url = id_or_url
        else:
            url = f"https://civitai.com/api/download/models/{id_or_url}"

        response = requests.get(url, headers={"Content-Type": "application/json"})
        # The endpoint is expected to return JSON metadata whose 'name' field
        # gives the filename to save the checkpoint under.
        data = response.json()
        name = data['name']
        download_file(url, name)
        return name

    model_name = fetch_and_download(model_id)
    models_list.append(model_name)

    if lora_id and lora_id.strip():
        lora_name = fetch_and_download(lora_id)
        loras_list.append(lora_name)

    return "Model/LoRA Downloaded!"

def load_model(model, lora="", use_lora=False):
    try:
        print(f"\n\nLoading {model}...")
        # The fp16-fix VAE avoids black/NaN images when decoding in float16.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

        # Note: from_pretrained expects a diffusers-format repo or directory;
        # a single-file .safetensors checkpoint downloaded from CivitAI may
        # instead need StableDiffusionXLPipeline.from_single_file.
        models[model] = StableDiffusionXLPipeline.from_pretrained(
            model,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            add_watermarker=False,
        )

        if use_lora and lora != "":
            models[model].load_lora_weights(lora)

        models[model].to("cuda")

        return "Model/LoRA loaded successfully!"
    except Exception as e:
        print(f"Error loading model {model}: {e}")
        raise gr.Error(f"Error loading model {model}: {e}")

#@spaces.GPU
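# The decorator above (together with the matching `import spaces` at the top)
# is only needed when running on Hugging Face ZeroGPU Spaces, so both are
# left commented out here.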
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    if prompt is None or prompt.strip() == "":
        gr.Warning("Prompt empty!")
        return []

    # load_model stores the freshly loaded pipeline in the `models` dict.
    if lora_name == "None":
        load_model(model_name, "", False)
    elif lora_name in loras_list:
        load_model(model_name, lora_name, True)

    pipe = models.get(model_name)
    if pipe is None:
        return []

    outputs = []
    for _ in range(num_images):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )["images"][0]
        outputs.append(output)

    return outputs

# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            
            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)

        with gr.Tab("Download Custom Model"):
            with gr.Group():
                modelId = gr.Textbox(label="CivitAI Model ID or Download URL")
                loraId = gr.Textbox(label="CivitAI LoRA ID or Download URL (Optional)")
                download_button = gr.Button("Download Model")
            
            download_output = gr.Textbox(label="Download Output")
            
            download_button.click(download_civitai_model, inputs=[modelId, loraId], outputs=download_output)

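# Note: demo.queue().launch() would route requests through Gradio's queue,
# which helps long generations avoid HTTP timeouts on shared hardware.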
demo.launch()