"""Gradio app: download SDXL checkpoints/LoRAs from CivitAI and generate images.

A checkpoint (and optional LoRA) is fetched by CivitAI model ID, loaded into a
``StableDiffusionXLPipeline`` stored in the module-level ``models`` dict, and
then used by the "Generate" tab.
"""

import os
import re

import gradio as gr
import requests
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL
from tqdm import tqdm

# Names shown in the model dropdown (CivitAI names of successfully loaded models).
models_list = []
# Names shown in the LoRA dropdown; "None" is the no-LoRA sentinel.
loras_list = ["None"]
# Loaded pipelines, keyed by the name shown in the model dropdown.
models = {}

# (connect, read) timeouts in seconds for HTTP calls. The read timeout applies
# to each socket read, not to the whole (possibly multi-GB) download.
REQUEST_TIMEOUT = (10, 60)
# Checkpoint files that must be loaded with ``from_single_file``;
# ``from_pretrained`` expects a hub repo id or a diffusers-format directory.
SINGLE_FILE_EXTENSIONS = (".safetensors", ".ckpt")


def _safe_filename(name):
    """Return *name* reduced to characters that are safe in one path component.

    CivitAI names are remote, untrusted input; without this a name containing
    ``/`` or ``..`` would write outside the working directory or fail to open.
    """
    cleaned = re.sub(r"[^\w.\- ]+", "_", name).strip(" .")
    return cleaned or "model"


def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream *url* into *filename*, reporting progress through tqdm.

    Args:
        url: Direct download URL.
        filename: Local path to write.
        progress: Unused here; kept for interface compatibility. Progress is
            surfaced via tqdm, which the calling Gradio handler tracks.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        IOError: If fewer/more bytes arrived than ``Content-Length`` announced.
    """
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as response:
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        # With a content-encoding, iter_content yields decoded bytes, so the
        # byte count is not comparable to Content-Length.
        check_size = total_size_in_bytes != 0 and not response.headers.get('content-encoding')
        block_size = 1024 * 1024  # 1 MiB; 1 KiB chunks are needlessly slow for GB files
        with open(filename, 'wb') as file, tqdm(
            total=total_size_in_bytes, unit='iB', unit_scale=True
        ) as progress_bar:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
            written = progress_bar.n
    if check_size and written != total_size_in_bytes:
        raise IOError(
            f"Incomplete download of {filename}: got {written} of {total_size_in_bytes} bytes"
        )


def get_civitai_model_info(model_id):
    """Fetch CivitAI metadata for *model_id*; return the parsed JSON or None on non-200."""
    url = f"https://civitai.com/api/v1/models/{model_id}"
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        return None
    return response.json()


def find_download_url(data, file_extension):
    """Return the download URL of the first file ending in *file_extension*.

    Only the first entry of ``modelVersions`` is searched (NOTE(review):
    presumably the newest version — confirm against the CivitAI API). Returns
    None when there are no versions or no matching file.
    """
    versions = data.get('modelVersions') or [{}]
    for file in versions[0].get('files', []):
        if file.get('name', '').endswith(file_extension):
            return file.get('downloadUrl')
    return None


def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
    """Download a CivitAI checkpoint (and optional LoRA) and load it for generation.

    Args:
        model_id: CivitAI model ID of the SDXL checkpoint.
        lora_id: Optional CivitAI model ID of a LoRA (``.safetensors`` only).
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the download bars.

    Returns:
        A status message for the UI; errors are reported as strings, never raised.
    """
    try:
        model_id = (model_id or "").strip()
        lora_id = (lora_id or "").strip()
        if not model_id:
            return "Error: Please enter a CivitAI model ID."

        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."
        model_name = model_data['name']
        model_ckpt_url = find_download_url(model_data, '.ckpt')
        model_safetensors_url = find_download_url(model_data, '.safetensors')
        model_url = model_ckpt_url or model_safetensors_url
        if not model_url:
            return f"Error: No suitable file found for model {model_name}."
        file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
        model_filename = f"{_safe_filename(model_name)}{file_extension}"
        download_file(model_url, model_filename)

        lora_filename = ""
        if lora_id:
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."
            lora_name = lora_data['name']
            lora_safetensors_url = find_download_url(lora_data, '.safetensors')
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."
            lora_filename = f"{_safe_filename(lora_name)}.safetensors"
            download_file(lora_safetensors_url, lora_filename)
            if lora_name not in loras_list:
                loras_list.append(lora_name)

        # Key the pipeline by the dropdown name so generate_images can find it.
        load_result = load_model(
            model_filename, lora_filename, use_lora=bool(lora_filename), key=model_name
        )
        if load_result.startswith("Error"):
            return load_result
        if model_name not in models_list:
            models_list.append(model_name)
        return f"Model/LoRA Downloaded and Loaded! {load_result}"
    except Exception as e:
        return f"Error downloading model or LoRA: {e}"


def refresh_dropdowns():
    """Return Gradio updates that refresh the model and LoRA dropdown choices."""
    return gr.update(choices=models_list), gr.update(choices=loras_list)


def _resolve_lora_path(lora):
    """Map a LoRA name to its downloaded ``.safetensors`` file when one exists.

    Falls back to *lora* unchanged (a path or hub id for ``load_lora_weights``).
    """
    if not os.path.isfile(lora):
        candidate = f"{_safe_filename(lora)}.safetensors"
        if os.path.isfile(candidate):
            return candidate
    return lora


def load_model(model, lora="", use_lora=False, key=None):
    """Build an fp16 SDXL pipeline for *model*, move it to CUDA and register it.

    Args:
        model: Checkpoint file (``.safetensors``/``.ckpt``), hub repo id, or
            diffusers directory.
        lora: LoRA file path or name; ignored unless *use_lora* is true.
        use_lora: Whether to load *lora* into the pipeline.
        key: Name under which to store the pipeline in ``models``; defaults to
            *model* (the original behavior).

    Returns:
        A status string; errors are returned as ``"Error loading model ..."``.
    """
    try:
        print(f"\n\nLoading {model}...")
        # fp16-safe SDXL VAE replacement.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        if model.lower().endswith(SINGLE_FILE_EXTENSIONS):
            pipeline = StableDiffusionXLPipeline.from_single_file(
                model,
                vae=vae,
                torch_dtype=torch.float16,
            )
        else:
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                model,
                vae=vae,
                torch_dtype=torch.float16,
            )
        if use_lora and lora and lora != "None":
            pipeline.load_lora_weights(_resolve_lora_path(lora))
        pipeline.to("cuda")
        models[key or model] = pipeline
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"


@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate *num_images* images with the pipeline registered as *model_name*.

    ``lora_name`` is accepted for the UI wiring but is currently unused: any
    LoRA is applied to the pipeline when it is loaded by ``load_model``.

    Returns:
        A list of images for the gallery; empty (with a UI warning) when the
        prompt is blank or the model is not loaded.
    """
    if prompt is None or not prompt.strip():
        # gr.Warning shows a toast when called; it is not a valid gallery value.
        gr.Warning("Prompt empty!")
        return []
    pipe = models.get(model_name)
    if pipe is None:
        gr.Warning(f"Model {model_name} is not loaded.")
        return []
    outputs = []
    # Slider values may arrive as floats; the pipeline and range() need ints.
    for _ in range(int(num_images)):
        result = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            height=int(height),
            width=int(width),
        )
        outputs.append(result.images[0])
    return outputs


# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(
                        choices=models_list,
                        value=models_list[0] if models_list else None,
                        label="Model",
                        elem_id="model_dropdown",
                    )
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    refresh_btn = gr.Button("Refresh Dropdowns")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
                        elem_id="negative_prompt_textbox",
                    )
                    num_inference_steps = gr.Slider(
                        minimum=10, maximum=50, step=1, value=25,
                        label="Number of Inference Steps", elem_id="num_inference_steps_slider",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1, maximum=20, step=0.5, value=7.5,
                        label="Guidance Scale", elem_id="guidance_scale_slider",
                    )
                    height = gr.Slider(
                        minimum=1024, maximum=2048, step=256, value=1024,
                        label="Height", elem_id="height_slider",
                    )
                    width = gr.Slider(
                        minimum=1024, maximum=2048, step=256, value=1024,
                        label="Width", elem_id="width_slider",
                    )
                    num_images = gr.Slider(
                        minimum=1, maximum=4, step=1, value=4,
                        label="Number of Images", elem_id="num_images_slider",
                    )
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            refresh_btn.click(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])
            generate_btn.click(
                generate_images,
                inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps,
                        guidance_scale, height, width, num_images],
                outputs=output_gallery,
            )
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
            download_output = gr.Textbox(label="Download Output")
            download_button.click(download_and_load_civitai_model, inputs=[model_id, lora_id], outputs=download_output)

if __name__ == "__main__":
    demo.launch()