# app.py from rafaaa2105's Hugging Face Space
# Commit 303d4ef (verified): "Update app.py" — file size 7.9 kB
import os
import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
import spaces
# Registry of loaded pipelines, keyed by whatever name they were stored under.
models = {}
# Choices shown in the UI dropdowns. "None" is the LoRA placeholder meaning
# "no LoRA selected".
loras_list = ["None"]
models_list = []
# Downloaded checkpoints and LoRAs are written into these folders, so make
# sure they exist before any download starts.
os.makedirs('loras', exist_ok=True)
os.makedirs('models', exist_ok=True)
def download_file(url, filename, progress=gr.Progress(track_tqdm=True), timeout=60):
    """Stream *url* to *filename*, showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to download.
        filename: Destination path; it is overwritten if it already exists.
        progress: Gradio progress tracker (``track_tqdm=True`` mirrors the
            tqdm bar in the UI).
        timeout: Seconds to wait for the connection / between received bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        OSError: if fewer bytes arrived than the server announced.

    A partially written file is deleted whenever the download does not
    complete, so later code never loads a truncated checkpoint.
    """
    import contextlib

    # Larger chunks keep per-iteration Python overhead low for big files.
    block_size = 1024 * 1024  # 1 MiB
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this, an HTML/JSON error page would be saved as the model.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        downloaded = 0
        completed = False
        try:
            with open(filename, 'wb') as file, tqdm(
                total=total_size_in_bytes, unit='iB', unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                    downloaded += len(data)
            # iter_content transparently decompresses encoded bodies, which can
            # only make the count *larger* than Content-Length; a smaller count
            # means the transfer was cut short.
            if total_size_in_bytes and downloaded < total_size_in_bytes:
                raise OSError(
                    f"Incomplete download of {url}: received {downloaded} of "
                    f"{total_size_in_bytes} bytes"
                )
            completed = True
        finally:
            if not completed:
                with contextlib.suppress(OSError):
                    os.remove(filename)
def get_civitai_model_info(model_id, timeout=30):
    """Fetch a model's metadata from the CivitAI REST API.

    Args:
        model_id: CivitAI model ID. It is converted to ``str`` and stripped of
            surrounding whitespace, since IDs come from a free-text box.
        timeout: Seconds to wait for the server. Without a timeout, a stalled
            connection would block the caller indefinitely.

    Returns:
        The decoded JSON body, or ``None`` when the ID is blank or the server
        does not answer with HTTP 200.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    from urllib.parse import quote

    model_id = str(model_id).strip()
    if not model_id:
        # A blank ID would hit the model *listing* endpoint instead.
        return None
    # Percent-encode so user input cannot alter the URL path (e.g. "1/../x").
    url = f"https://civitai.com/api/v1/models/{quote(model_id, safe='')}"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None
    return response.json()
def find_download_url(data, file_extension):
    """Return the download URL of the first file whose name ends with *file_extension*.

    Only the first entry of ``data["modelVersions"]`` is searched. Every file
    in that version is examined before giving up.

    Args:
        data: Decoded CivitAI model JSON (a dict).
        file_extension: Suffix to match, e.g. ``".safetensors"``.

    Returns:
        The matching file's ``downloadUrl``. Returns ``None`` if no file
        matches, or if the version list or the files list is missing or empty.
    """
    versions = data.get('modelVersions') or [{}]
    for file in versions[0].get('files') or []:
        name = file.get('name') or ''
        url = file.get('downloadUrl')
        if name.endswith(file_extension) and url:
            return url
    return None
def _safe_filename(name):
    """Return *name* reduced to a single safe path component.

    Characters other than word characters, spaces, dots and hyphens become
    "_". Surrounding whitespace and leading dots are removed, so a remote
    display name like "A/B" or "../x" cannot point outside the target
    directory. Returns "unnamed" if nothing is left.
    """
    import re

    cleaned = re.sub(r"[^\w .-]", "_", str(name)).strip().lstrip(".")
    return cleaned or "unnamed"


def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
    """Download a CivitAI checkpoint (and optionally a LoRA), then load it.

    Args:
        model_id: CivitAI model ID of the checkpoint.
        lora_id: Optional CivitAI model ID of a LoRA. A blank value means no LoRA.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        A human-readable status string for the UI. Errors are reported in the
        string rather than raised.

    Side effects:
        Writes files under ``models/`` and ``loras/``. Appends new names to
        ``loras_list`` and, only after a successful load, to ``models_list``.
        Registers the loaded pipeline in ``models`` under the same name that
        is added to ``models_list``.
    """
    try:
        model_id = str(model_id).strip() if model_id is not None else ""
        if not model_id:
            return "Error: Please enter a CivitAI model ID."
        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."
        # Names come from the remote API and may contain "/" or other
        # characters that are not valid in a file name.
        model_name = _safe_filename(model_data['name'])

        # NOTE(review): .ckpt checkpoints are usually pickle-based, and
        # unpickling untrusted data can execute arbitrary code. Prefer
        # .safetensors and only fall back to .ckpt when nothing else exists.
        model_safetensors_url = find_download_url(model_data, '.safetensors')
        model_ckpt_url = find_download_url(model_data, '.ckpt')
        model_url = model_safetensors_url or model_ckpt_url
        if not model_url:
            return f"Error: No suitable file found for model {model_name}."
        file_extension = '.safetensors' if model_safetensors_url else '.ckpt'
        model_filename = os.path.join('models', f"{model_name}{file_extension}")
        download_file(model_url, model_filename)

        lora_id = str(lora_id).strip() if lora_id else ""
        lora_name = "None"  # placeholder meaning "no LoRA"
        if lora_id:
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."
            lora_name = _safe_filename(lora_data['name'])
            lora_safetensors_url = find_download_url(lora_data, '.safetensors')
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."
            lora_filename = os.path.join('loras', f"{lora_name}.safetensors")
            download_file(lora_safetensors_url, lora_filename)
            if lora_name not in loras_list:
                loras_list.append(lora_name)

        # load_model reports failure through its return string instead of
        # raising. Success is detected by checking that a *new* pipeline was
        # stored under the checkpoint path.
        previous = models.get(model_filename)
        load_result = load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
        loaded = models.get(model_filename)
        if loaded is None or loaded is previous:
            return f"Model/LoRA downloaded, but loading failed. {load_result}"

        # generate_images looks pipelines up by the dropdown value (the model
        # name), so register the pipeline under that name as well.
        models[model_name] = loaded
        if model_name not in models_list:
            models_list.append(model_name)
        return f"Model/LoRA Downloaded and Loaded! {load_result}"
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return f"Error downloading model or LoRA: {e}"
def refresh_dropdowns():
    """Build Gradio updates that refresh the model and LoRA dropdown choices.

    Returns:
        A 2-tuple of ``gr.update`` payloads: the first for the model dropdown,
        the second for the LoRA dropdown. Each carries the current contents of
        the module-level ``models_list`` / ``loras_list``.
    """
    model_choices_update = gr.update(choices=models_list)
    lora_choices_update = gr.update(choices=loras_list)
    return model_choices_update, lora_choices_update
def load_model(model, lora="", use_lora=False, progress=gr.Progress(track_tqdm=True)):
    """Load a single-file SDXL checkpoint, optionally apply a LoRA, and register it.

    Args:
        model: Path to the checkpoint file, e.g. ``models/<name>.safetensors``.
        lora: LoRA name. The weights are read from ``loras/<lora>.safetensors``.
            Both "" and the dropdown placeholder "None" mean "no LoRA".
        use_lora: Whether to apply ``lora``.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        A status string. Failures are reported in the string, not raised.

    Side effects:
        Stores the pipeline in ``models`` under the path *and* under the bare
        file name without extension.
    """
    try:
        print(f"\n\nLoading {model}...")
        gr.Info(f"Loading {model}, it may take a while.")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipeline = StableDiffusionXLPipeline.from_single_file(
            model,
            vae=vae,
            torch_dtype=torch.float16,
        )
        # "None" is the placeholder used by the LoRA dropdown; never try to
        # load a file called loras/None.safetensors.
        if use_lora and lora not in ("", "None"):
            lora_path = os.path.join('loras', lora + '.safetensors')
            pipeline.load_lora_weights(lora_path)
        pipeline.to("cuda")
        # Keep the path key for existing callers. Also register the bare name,
        # which is the value the model dropdown passes to generate_images.
        models[model] = pipeline
        models[os.path.splitext(os.path.basename(model))[0]] = pipeline
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"
@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate ``num_images`` images for *prompt* with a previously loaded pipeline.

    Args:
        model_name: Key into ``models`` (the model dropdown value).
        lora_name: Unused here. Any LoRA is applied to the pipeline when the
            model is loaded.
        prompt: Text prompt. If it is empty or whitespace-only, a warning
            toast is shown and nothing is generated.
        negative_prompt: Negative text prompt.
        num_inference_steps: Denoising steps (cast to ``int``).
        guidance_scale: Classifier-free guidance scale.
        height: Output height in pixels (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        num_images: How many images to generate, one pipeline call each.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        A list of generated images, or ``[]`` when the prompt is empty or the
        model is not loaded.
    """
    if prompt is None or not prompt.strip():
        # gr.Warning shows a toast; it is not a meaningful gallery value.
        gr.Warning("Prompt empty!")
        return []

    pipe = models.get(model_name)
    if pipe is None:
        gr.Warning(f"Model {model_name!r} is not loaded. Download and load it first.")
        return []

    outputs = []
    # Slider values may arrive as floats; range() and the pipeline's size
    # arithmetic need ints.
    for _ in range(int(num_images)):
        result = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            height=int(height),
            width=int(width),
        )
        outputs.append(result["images"][0])
    return outputs
# Create the Gradio blocks: a "Generate" tab and a "Download Custom Model" tab.
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(
                        choices=models_list,
                        value=models_list[0] if models_list else None,
                        label="Model",
                        elem_id="model_dropdown",
                    )
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    refresh_btn = gr.Button("Refresh Dropdowns")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
                        elem_id="negative_prompt_textbox",
                    )
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            refresh_btn.click(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])
            generate_btn.click(
                generate_images,
                inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images],
                outputs=output_gallery,
            )
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
                download_output = gr.Textbox(label="Download Output")
            # Once the download handler returns, refresh the Generate tab's
            # dropdowns so a newly downloaded model/LoRA is selectable without
            # pressing "Refresh Dropdowns".
            download_button.click(
                download_and_load_civitai_model,
                inputs=[model_id, lora_id],
                outputs=download_output,
            ).then(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])

demo.launch()