import os
import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
import spaces
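# Gradio app for a Hugging Face Space: download SDXL checkpoints and LoRAs
# from CivitAI by ID, load them with diffusers, and generate images on GPU.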
# Ensure directories exist
os.makedirs('models', exist_ok=True)
os.makedirs('loras', exist_ok=True)
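# In-memory registries: dropdown choices for models/LoRAs and a cache of
# loaded pipelines, keyed by model name.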
models_list = []
loras_list = ["None"]
models = {}
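# Stream a remote file to disk in 1 KiB chunks with a tqdm progress bar.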
def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail fast on HTTP errors instead of writing an error page to disk
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(filename, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print(f"ERROR: incomplete download of {url} ({progress_bar.n}/{total_size_in_bytes} bytes)")
def get_civitai_model_info(model_id):
url = f"https://civitai.com/api/v1/models/{model_id}"
response = requests.get(url)
if response.status_code != 200:
return None
return response.json()
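# Scan the latest model version's files for one with the given extension
# and return its download URL, or None if nothing matches.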
def find_download_url(data, file_extension):
    for file_info in data.get('modelVersions', [{}])[0].get('files', []):
        if file_info['name'].endswith(file_extension):
            return file_info['downloadUrl']
    return None
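# Resolve CivitAI IDs to downloadable files, fetch the checkpoint (and the
# optional LoRA), register them in the dropdown lists, and load the pipeline.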
def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
try:
model_data = get_civitai_model_info(model_id)
if model_data is None:
return f"Error: Model with ID {model_id} not found."
model_name = model_data['name']
model_ckpt_url = find_download_url(model_data, '.ckpt')
model_safetensors_url = find_download_url(model_data, '.safetensors')
model_url = model_ckpt_url or model_safetensors_url
if not model_url:
return f"Error: No suitable file found for model {model_name}."
file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
model_filename = os.path.join('models', f"{model_name}{file_extension}")
download_file(model_url, model_filename)
if lora_id:
lora_data = get_civitai_model_info(lora_id)
if lora_data is None:
return f"Error: LoRA with ID {lora_id} not found."
lora_name = lora_data['name']
lora_safetensors_url = find_download_url(lora_data, '.safetensors')
if not lora_safetensors_url:
return f"Error: No suitable file found for LoRA {lora_name}."
lora_filename = os.path.join('loras', f"{lora_name}.safetensors")
download_file(lora_safetensors_url, lora_filename)
if lora_name not in loras_list:
loras_list.append(lora_name)
else:
lora_name = "None"
if model_name not in models_list:
models_list.append(model_name)
# Load model after downloading
load_result = load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
return f"Model/LoRA Downloaded and Loaded! {load_result}"
except Exception as e:
return f"Error downloading model or LoRA: {e}"
def refresh_dropdowns():
return gr.update(choices=models_list), gr.update(choices=loras_list)
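# Build an SDXL pipeline from a single checkpoint file. The fp16-fix VAE
# (madebyollin/sdxl-vae-fp16-fix) avoids the NaN/black-image issues the
# stock SDXL VAE can hit in float16.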
def load_model(model, lora="", use_lora=False, progress=gr.Progress(track_tqdm=True)):
try:
print(f"\n\nLoading {model}...")
gr.Info(f"Loading {model}, it may take a while.")
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix",
torch_dtype=torch.float16,
)
pipeline = StableDiffusionXLPipeline.from_single_file(
model,
vae=vae,
torch_dtype=torch.float16,
)
if use_lora and lora != "":
lora_path = os.path.join('loras', lora + '.safetensors')
pipeline.load_lora_weights(lora_path)
pipeline.to("cuda")
        # Cache under the model's display name so lookups by the dropdown
        # value in generate_images() succeed.
        models[os.path.splitext(os.path.basename(model))[0]] = pipeline
return "Model/LoRA loaded successfully!"
except Exception as e:
return f"Error loading model {model}: {e}"
@spaces.GPU
def generate_images(
model_name,
lora_name,
prompt,
negative_prompt,
num_inference_steps,
guidance_scale,
height,
width,
num_images=4,
progress=gr.Progress(track_tqdm=True)
):
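    """Generate `num_images` images with the cached pipeline for `model_name`.

    `lora_name` is informational here: LoRA weights are applied when the
    pipeline is loaded, not per call.
    """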
    # Reject empty prompts up front with a toast instead of returning a
    # non-image value to the gallery.
    if prompt is None or prompt.strip() == "":
        gr.Warning("Prompt empty!")
        return []
    pipe = models.get(model_name)
    if pipe is None:
        gr.Warning(f"Model {model_name} is not loaded. Download or load it first.")
        return []
    outputs = []
    for _ in range(num_images):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )["images"][0]
        outputs.append(output)
    return outputs
# Create the Gradio blocks
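# Two tabs: "Generate" (model/LoRA pickers, prompt, sampler settings, gallery)
# and "Download Custom Model" (CivitAI IDs -> download and load).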
with gr.Blocks(theme='ParityError/Interstellar') as demo:
with gr.Row(equal_height=False):
with gr.Tab("Generate"):
with gr.Column(elem_id="input_column"):
with gr.Group(elem_id="input_group"):
model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
refresh_btn = gr.Button("Refresh Dropdowns")
prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
generate_btn = gr.Button("Generate Image", elem_id="generate_button")
with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
with gr.Column(elem_id="output_column"):
output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
refresh_btn.click(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])
generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
with gr.Tab("Download Custom Model"):
with gr.Group():
model_id = gr.Textbox(label="CivitAI Model ID")
lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
download_button = gr.Button("Download Model")
download_output = gr.Textbox(label="Download Output")
download_button.click(download_and_load_civitai_model, inputs=[model_id, lora_id], outputs=download_output)
demo.launch()