import os
import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
import spaces
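# Gradio app for a Hugging Face Space running on ZeroGPU: the user pulls SDXL
# checkpoints (and optional LoRAs) from CivitAI by ID, loads them into a
# diffusers pipeline, and generates images from a prompt.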
# Ensure directories exist
os.makedirs('models', exist_ok=True)
os.makedirs('loras', exist_ok=True)

models_list = []        # display names for the Model dropdown
loras_list = ["None"]   # display names for the LoRA dropdown
models = {}             # model name -> loaded StableDiffusionXLPipeline
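# Stream a file to disk in 1 KiB chunks; the tqdm bar doubles as the Gradio
# progress indicator via gr.Progress(track_tqdm=True).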
def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print(f"ERROR: incomplete download of {filename}")
def get_civitai_model_info(model_id):
    url = f"https://civitai.com/api/v1/models/{model_id}"
    response = requests.get(url)
    if response.status_code != 200:
        return None
    return response.json()
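# Pick the first file in the latest model version whose name matches the
# requested extension; None if nothing matches.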
def find_download_url(data, file_extension):
    versions = data.get('modelVersions') or [{}]
    for file in versions[0].get('files', []):
        if file['name'].endswith(file_extension):
            return file['downloadUrl']
    return None
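# End-to-end download flow: fetch metadata, prefer a .ckpt file and fall back
# to .safetensors, optionally grab a LoRA, register the names for the
# dropdowns, then load the pipeline immediately.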
def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
    try:
        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."

        model_name = model_data['name']
        model_ckpt_url = find_download_url(model_data, '.ckpt')
        model_safetensors_url = find_download_url(model_data, '.safetensors')

        model_url = model_ckpt_url or model_safetensors_url
        if not model_url:
            return f"Error: No suitable file found for model {model_name}."

        file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
        model_filename = os.path.join('models', f"{model_name}{file_extension}")
        download_file(model_url, model_filename)

        if lora_id:
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."

            lora_name = lora_data['name']
            lora_safetensors_url = find_download_url(lora_data, '.safetensors')
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."

            lora_filename = os.path.join('loras', f"{lora_name}.safetensors")
            download_file(lora_safetensors_url, lora_filename)

            if lora_name not in loras_list:
                loras_list.append(lora_name)
        else:
            lora_name = "None"

        if model_name not in models_list:
            models_list.append(model_name)

        # Load model after downloading
        load_result = load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
        return f"Model/LoRA downloaded and loaded! {load_result}"
    except Exception as e:
        return f"Error downloading model or LoRA: {e}"
def refresh_dropdowns():
    return gr.update(choices=models_list), gr.update(choices=loras_list)
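# Build an SDXL pipeline from a single checkpoint file, swapping in the
# madebyollin/sdxl-vae-fp16-fix VAE to keep decoding numerically stable at
# float16, and optionally attach LoRA weights.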
def load_model(model, lora="", use_lora=False, progress=gr.Progress(track_tqdm=True)):
    try:
        print(f"\n\nLoading {model}...")
        gr.Info(f"Loading {model}, it may take a while.")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipeline = StableDiffusionXLPipeline.from_single_file(
            model,
            vae=vae,
            torch_dtype=torch.float16,
        )
        if use_lora and lora not in ("", "None"):
            lora_path = os.path.join('loras', lora + '.safetensors')
            pipeline.load_lora_weights(lora_path)
        pipeline.to("cuda")
        # Key the cache by display name so generate_images can look the
        # pipeline up via the dropdown value (the name, not the file path).
        model_key = os.path.splitext(os.path.basename(model))[0]
        models[model_key] = pipeline
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"
@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    if prompt is None or prompt.strip() == "":
        gr.Warning("Prompt empty!")
        return []
    pipe = models.get(model_name)
    if pipe is None:
        gr.Warning(f"Model {model_name} is not loaded yet.")
        return []
    outputs = []
    for _ in range(num_images):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )["images"][0]
        outputs.append(output)
    return outputs
# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    refresh_btn = gr.Button("Refresh Dropdowns")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            refresh_btn.click(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])
            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
                download_output = gr.Textbox(label="Download Output")
            download_button.click(download_and_load_civitai_model, inputs=[model_id, lora_id], outputs=download_output)

demo.launch()