import torch
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import gradio as gr
import requests
from tqdm import tqdm  # progress bar used by download_file
#import spaces

models_list = []
loras_list = ["None"]
models = {}
def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print("ERROR, something went wrong")
def download_civitai_model(model_id, lora_id=""):
    headers = {
        "Content-Type": "application/json"
    }
    # Accept either a full download URL or a bare CivitAI model version ID.
    if model_id.startswith("http"):
        model_url = model_id
    else:
        model_url = f"https://civitai.com/api/download/models/{model_id}"
    # As in the original code, this assumes the endpoint answers with JSON
    # metadata that includes a 'name' field.
    response = requests.get(model_url, headers=headers)
    data = response.json()
    model_name = data['name']
    download_file(model_url, model_name)
    models_list.append(model_name)
    # The LoRA is optional; only fetch it when an ID or URL was given.
    if lora_id:
        if lora_id.startswith("http"):
            lora_url = lora_id
        else:
            lora_url = f"https://civitai.com/api/download/models/{lora_id}"
        response = requests.get(lora_url, headers=headers)
        data = response.json()
        lora_name = data['name']
        download_file(lora_url, lora_name)
        loras_list.append(lora_name)
    return "Model/LoRA Downloaded!"
def load_model(model, lora="", use_lora=False):
    try:
        print(f"\n\nLoading {model}...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        models[model] = StableDiffusionXLPipeline.from_pretrained(
            model,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            add_watermarker=False,
        )
        if use_lora and lora != "":
            models[model].load_lora_weights(lora)
        models[model].to("cuda")
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        gr.Error(f"Error loading model {model}: {e}")
        print(f"Error loading model {model}: {e}")
#@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    if prompt is not None and prompt.strip() != "":
        if lora_name == "None":
            load_model(model_name, "", False)
        elif lora_name in loras_list and lora_name != "None":
            load_model(model_name, lora_name, True)
        pipe = models.get(model_name)
        if pipe is None:
            return []
        outputs = []
        for _ in range(num_images):
            output = pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width
            )["images"][0]
            outputs.append(output)
        return outputs
    else:
        gr.Warning("Prompt empty!")
# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                modelId = gr.Textbox(label="CivitAI Model ID")
                loraId = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
                download_output = gr.Textbox(label="Download Output")
            download_button.click(download_civitai_model, inputs=[modelId, loraId], outputs=download_output)

demo.launch()
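# When running outside Hugging Face Spaces, Gradio can also expose a public
# link and queue concurrent requests, e.g. (a sketch, not part of the original
# app):
# demo.queue().launch(share=True)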