Stoltz committed on
Commit
1aa42ec
1 Parent(s): cb8f4ad

Added a new custom model downloader

Browse files
Files changed (2) hide show
  1. app.py +133 -30
  2. requirements.txt +3 -1
app.py CHANGED
@@ -2,19 +2,98 @@ import torch
2
  from diffusers.models import AutoencoderKL
3
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
4
  import gradio as gr
5
- import os
6
- import spaces
 
7
 
8
- # Load the models outside of the generate_images function
9
- model_list = [model.strip() for model in os.environ.get("MODELS").split(",")]
10
- lora_list = [model.strip() for model in os.environ.get("LORAS").split(",")]
11
 
12
- print(f"Detected {len(model_list)} on models and {len(lora_list)} LoRAs.")
 
13
 
14
- models = {}
15
- for model_name in model_list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  try:
17
- print(f"\n\nLoading {model_name}...")
18
  vae = AutoencoderKL.from_pretrained(
19
  "madebyollin/sdxl-vae-fp16-fix",
20
  torch_dtype=torch.float16,
@@ -24,21 +103,28 @@ for model_name in model_list:
24
  StableDiffusionXLPipeline.from_pretrained
25
  )
26
 
27
- models[model_name] = pipeline(
28
- model_name,
29
  vae=vae,
30
  torch_dtype=torch.float16,
31
  custom_pipeline="lpw_stable_diffusion_xl",
32
  add_watermarker=False,
33
  )
 
 
 
34
 
35
- models[model_name].to("cuda")
 
 
36
  except Exception as e:
37
- print(f"Error loading model {model_name}: {e}")
 
38
 
39
- @spaces.GPU
40
  def generate_images(
41
  model_name,
 
42
  prompt,
43
  negative_prompt,
44
  num_inference_steps,
@@ -49,11 +135,15 @@ def generate_images(
49
  progress=gr.Progress(track_tqdm=True)
50
  ):
51
  if prompt is not None and prompt.strip() != "":
 
 
 
 
 
52
  pipe = models.get(model_name)
53
  if pipe is None:
54
  return []
55
 
56
- print(f"Prompt is: [ {prompt} ]")
57
  outputs = []
58
 
59
  for _ in range(num_images):
@@ -76,21 +166,34 @@ def generate_images(
76
  # Create the Gradio blocks
77
  with gr.Blocks(theme='ParityError/Interstellar') as demo:
78
  with gr.Row(equal_height=False):
79
- with gr.Column(elem_id="input_column"):
80
- with gr.Group(elem_id="input_group"):
81
- model_dropdown = gr.Dropdown(choices=list(models.keys()), value=model_list[0] if model_list else None, label="Model", elem_id="model_dropdown")
82
- prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
83
- generate_btn = gr.Button("Generate Image", elem_id="generate_button")
84
- with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
85
- negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
86
- num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
87
- guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
88
- height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
89
- width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
90
- num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
91
- with gr.Column(elem_id="output_column"):
92
- output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
 
 
 
 
 
93
 
94
- generate_btn.click(generate_images, inputs=[model_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
 
 
 
 
 
 
 
 
95
 
96
  demo.launch()
 
2
  from diffusers.models import AutoencoderKL
3
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
4
  import gradio as gr
5
+ import subprocess
6
+ import requests
7
+ #import spaces
8
 
9
# Model/LoRA registries shown in the UI dropdowns; populated at runtime
# by download_civitai_model().
models_list = []
# "None" sentinel entry lets the user generate without applying any LoRA.
loras_list = [ "None" ]
# Cache of loaded pipelines, keyed by model name (filled by load_model()).
models = {}
12
 
13
def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream the resource at *url* into the local file *filename*.

    Downloads in 1 KiB chunks so large model checkpoints are never held
    fully in memory.  If the server advertises a Content-Length, the byte
    count is verified after the download and a mismatch is reported.

    Args:
        url: Direct download URL.
        filename: Destination path on disk.
        progress: Gradio progress tracker (shared default instance kept for
            interface compatibility; gr.Progress acts as a stateless proxy).
    """
    response = requests.get(url, stream=True)
    # Fail fast on HTTP errors instead of silently writing an error page to disk.
    response.raise_for_status()

    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte per chunk

    # The original used an un-imported `tqdm` here, which raised NameError at
    # runtime; count the downloaded bytes manually instead.
    bytes_written = 0
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            bytes_written += len(data)
            file.write(data)

    if total_size_in_bytes != 0 and bytes_written != total_size_in_bytes:
        print("ERROR, something went wrong")
27
+
28
def _civitai_download_url(ref):
    """Return a direct download URL for *ref*.

    *ref* may already be a full URL (passed through unchanged) or a bare
    CivitAI model-version id, which is expanded into the public download
    endpoint.  The original built this with a plain (non-f) string, so the
    literal text "{model_id}" was sent to the server.
    """
    if ref.startswith("http"):
        return ref
    return f"https://civitai.com/api/download/models/{ref}"


def _fetch_remote_name(url):
    """Return the 'name' field from *url*'s JSON response.

    NOTE(review): this GETs the *download* endpoint and assumes it answers
    with JSON metadata containing a 'name' key, exactly as the original code
    did — confirm against the CivitAI API, since the download endpoint
    normally streams the file itself rather than metadata.
    """
    headers = {
        "Content-Type": "application/json"
    }
    response = requests.get(url, headers=headers)
    data = response.json()
    return data['name']


def download_civitai_model(model_id, lora_id = ""):
    """Download a CivitAI model (and optionally a LoRA) into the working dir.

    Registers the downloaded names in the module-level ``models_list`` /
    ``loras_list`` so the UI dropdowns can offer them.

    Args:
        model_id: CivitAI model-version id, or a full download URL.
        lora_id: Optional LoRA id/URL; empty string means "no LoRA".

    Returns:
        Status string shown in the "Download Output" textbox.

    Fixes over the original: the LoRA branch queried ``model_id`` instead of
    ``lora_id`` and passed an undefined ``lora_name`` to download_file();
    ``lora_id != None or ""`` was always truthy for non-None values; and
    ``loras_list.append(lora_name)`` ran even when no LoRA was requested,
    raising NameError.
    """
    model_url = _civitai_download_url(model_id)
    model_name = _fetch_remote_name(model_url)
    download_file(model_url, model_name)
    models_list.append(model_name)

    # Only touch LoRA state when the user actually supplied a LoRA id/URL.
    if lora_id:
        lora_url = _civitai_download_url(lora_id)
        lora_name = _fetch_remote_name(lora_url)
        download_file(lora_url, lora_name)
        loras_list.append(lora_name)

    return "Model/LoRA Downloaded!"
93
+
94
+ def load_model(model, lora = "", use_lora = False):
95
  try:
96
+ print(f"\n\nLoading {model}...")
97
  vae = AutoencoderKL.from_pretrained(
98
  "madebyollin/sdxl-vae-fp16-fix",
99
  torch_dtype=torch.float16,
 
103
  StableDiffusionXLPipeline.from_pretrained
104
  )
105
 
106
+ models[model] = pipeline(
107
+ model,
108
  vae=vae,
109
  torch_dtype=torch.float16,
110
  custom_pipeline="lpw_stable_diffusion_xl",
111
  add_watermarker=False,
112
  )
113
+
114
+ if use_lora and lora != "":
115
+ models[model].load_lora_weights(lora)
116
 
117
+ models[model].to("cuda")
118
+
119
+ return "Model/LoRA downloaded successfully!"
120
  except Exception as e:
121
+ gr.Error(f"Error loading model {model}: {e}")
122
+ print(f"Error loading model {model}: {e}")
123
 
124
+ #@spaces.GPU
125
  def generate_images(
126
  model_name,
127
+ lora_name,
128
  prompt,
129
  negative_prompt,
130
  num_inference_steps,
 
135
  progress=gr.Progress(track_tqdm=True)
136
  ):
137
  if prompt is not None and prompt.strip() != "":
138
+ if lora_name == "None":
139
+ load_model(model_name, "", False)
140
+ elif lora_name in loras_list and lora_name != "None":
141
+ load_model(model_name, lora_name, True)
142
+
143
  pipe = models.get(model_name)
144
  if pipe is None:
145
  return []
146
 
 
147
  outputs = []
148
 
149
  for _ in range(num_images):
 
166
# --- Gradio UI ---------------------------------------------------------------
# Two tabs: "Generate" (prompting + advanced sliders + gallery) and
# "Download Custom Model" (CivitAI fetcher that feeds the dropdowns).
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dd = gr.Dropdown(
                        choices=models_list,
                        value=models_list[0] if models_list else None,
                        label="Model",
                        elem_id="model_dropdown",
                    )
                    lora_dd = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt_tb = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    run_btn = gr.Button("Generate Image", elem_id="generate_button")
                    with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                        neg_prompt_tb = gr.Textbox(
                            label="Negative Prompt",
                            value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
                            elem_id="negative_prompt_textbox",
                        )
                        steps_sl = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                        cfg_sl = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                        height_sl = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                        width_sl = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                        count_sl = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")

            # Wire the button to the generator; results land in the gallery.
            run_btn.click(
                generate_images,
                inputs=[model_dd, lora_dd, prompt_tb, neg_prompt_tb, steps_sl, cfg_sl, height_sl, width_sl, count_sl],
                outputs=gallery,
            )

        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id_tb = gr.Textbox(label="CivitAI Model ID")
                lora_id_tb = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                dl_btn = gr.Button("Download Model")

            dl_status_tb = gr.Textbox(label="Download Output")

            dl_btn.click(download_civitai_model, inputs=[model_id_tb, lora_id_tb], outputs=dl_status_tb)

demo.launch()
requirements.txt CHANGED
@@ -2,4 +2,6 @@ diffusers
2
  transformers
3
  accelerate
4
  torch
5
- bs4
 
 
 
2
  transformers
3
  accelerate
4
  torch
5
+ bs4
6
+ gradio
7
+ wget
8
+ requests