Spaces:
Sleeping
Sleeping
import os
import re
import subprocess
import urllib.parse

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers.models import AutoencoderKL
#import spaces
# Names of downloaded checkpoints: download_civitai_model appends to this list
# and it seeds the choices of the "Model" dropdown.
models_list: list = []
# LoRA names for the "LoRA" dropdown. "None" is the first (default) entry and
# generate_images treats it as "no LoRA".
loras_list: list = ["None"]
# Pipelines built by load_model, keyed by model name.
models: dict = {}
def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream the resource at *url* into the local file *filename*.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; an existing file is overwritten.
        progress: Gradio progress tracker. Accepted for interface
            compatibility; not used by this function.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never saved in place of the requested file.
        requests.RequestException: on connection problems or timeouts.
        OSError: if the body is shorter than the advertised Content-Length.
    """
    block_size = 1024 * 1024  # 1 MiB: model files are large, tiny chunks waste time.
    written = 0
    # The context manager closes the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        with open(filename, "wb") as file:
            for data in response.iter_content(block_size):
                file.write(data)
                written += len(data)
    # Content-Length of 0 means "unknown", so only a known size is checked.
    if total_size_in_bytes != 0 and written != total_size_in_bytes:
        raise OSError(
            f"Incomplete download of {url}: got {written} of "
            f"{total_size_in_bytes} bytes"
        )
# Metadata endpoint: returns JSON describing a model version and its files.
_CIVITAI_VERSION_API = "https://civitai.com/api/v1/model-versions/{}"
# Direct file download endpoint for a model version.
_CIVITAI_DOWNLOAD_API = "https://civitai.com/api/download/models/{}"


def _civitai_version_id(ref):
    """Return the CivitAI model-version ID referenced by *ref*, or None.

    *ref* may be a bare ID, a ``.../api/download/models/<id>`` URL, or any URL
    carrying a ``modelVersionId`` query parameter. None means *ref* is a URL
    without a recognisable version ID.
    """
    ref = ref.strip()
    if not ref.startswith("http"):
        return ref
    parsed = urllib.parse.urlparse(ref)
    match = re.search(r"/api/download/models/(\d+)", parsed.path)
    if match:
        return match.group(1)
    version_ids = urllib.parse.parse_qs(parsed.query).get("modelVersionId")
    return version_ids[0] if version_ids else None


def _resolve_civitai_download(ref):
    """Map a CivitAI ID or URL to ``(download_url, local_filename)``.

    Raises:
        ValueError: if *ref* is not a usable ID/URL or the version has no file.
        requests.RequestException: if the metadata lookup fails.
    """
    version_id = _civitai_version_id(ref)
    if version_id is None:
        # Not a CivitAI version reference: treat it as a direct file URL and
        # name the download after the last path segment.
        url = ref.strip()
        filename = os.path.basename(urllib.parse.urlparse(url).path)
        if not filename:
            raise ValueError(f"Cannot derive a file name from {url!r}")
        return url, filename
    if not version_id.isdigit():
        raise ValueError(f"Not a CivitAI model version ID: {ref!r}")
    response = requests.get(_CIVITAI_VERSION_API.format(version_id), timeout=30)
    response.raise_for_status()
    data = response.json()
    files = data.get("files") or []
    if not files:
        raise ValueError(f"CivitAI model version {version_id} has no files")
    # Prefer the file CivitAI flags as primary; otherwise take the first one.
    chosen = next((f for f in files if f.get("primary")), files[0])
    # basename() keeps a server-supplied name such as "../x" inside the cwd.
    filename = os.path.basename(chosen.get("name") or "")
    if not filename:
        raise ValueError(f"CivitAI model version {version_id} has no file name")
    url = chosen.get("downloadUrl") or _CIVITAI_DOWNLOAD_API.format(version_id)
    return url, filename


def download_civitai_model(model_id, lora_id=""):
    """Download a CivitAI checkpoint and, optionally, a LoRA into the cwd.

    Each successfully downloaded file name is added once to ``models_list``
    or ``loras_list`` respectively.

    Args:
        model_id: CivitAI model-version ID or URL for the checkpoint.
        lora_id: Optional model-version ID or URL for a LoRA; blank or None
            skips the LoRA download.

    Returns:
        A status message for the "Download Output" textbox.

    Raises:
        gr.Error: when no model ID is given or any lookup/download fails, so
            Gradio shows the reason to the user.
    """
    if not model_id or not model_id.strip():
        raise gr.Error("Please enter a CivitAI model ID.")
    try:
        model_url, model_name = _resolve_civitai_download(model_id)
        download_file(model_url, model_name)
        if model_name not in models_list:
            models_list.append(model_name)
        if lora_id and lora_id.strip():
            lora_url, lora_name = _resolve_civitai_download(lora_id)
            download_file(lora_url, lora_name)
            if lora_name not in loras_list:
                loras_list.append(lora_name)
    except (OSError, ValueError) as err:
        # requests.RequestException subclasses OSError, and a malformed JSON
        # body raises a ValueError subclass, so both are covered here.
        raise gr.Error(f"Download failed: {err}") from err
    return "Model/LoRA Downloaded!"
def load_model(model, lora="", use_lora=False):
    """Build an SDXL pipeline for *model*, move it to CUDA and store it in ``models``.

    Args:
        model: Path of a local single-file checkpoint (as saved by
            download_civitai_model), or a Hugging Face Hub repo id / local
            diffusers directory.
        lora: LoRA weights passed to ``load_lora_weights``.
        use_lora: Apply *lora* only when True and *lora* is non-empty.

    Returns:
        A status message.

    Raises:
        gr.Error: if anything fails while loading, so Gradio surfaces it.
    """
    try:
        print(f"\n\nLoading {model}...")
        # Drop previously built pipelines first so their memory can be
        # reclaimed instead of coexisting with the one being loaded.
        models.clear()
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        if os.path.isfile(model):
            # A single checkpoint file cannot be read by from_pretrained,
            # which expects a repo id or a diffusers-format directory.
            pipe = StableDiffusionXLPipeline.from_single_file(
                model,
                vae=vae,
                torch_dtype=torch.float16,
                add_watermarker=False,
            )
        else:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model,
                vae=vae,
                torch_dtype=torch.float16,
                custom_pipeline="lpw_stable_diffusion_xl",
                add_watermarker=False,
            )
        if use_lora and lora:
            pipe.load_lora_weights(lora)
        pipe.to("cuda")
        # Publish only a fully prepared pipeline.
        models[model] = pipe
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        print(f"Error loading model {model}: {e}")
        # gr.Error only shows in the UI when raised, not when merely built.
        raise gr.Error(f"Error loading model {model}: {e}") from e
#@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for ``prompt`` with the selected model.

    The pipeline is (re)loaded through load_model on every call. The LoRA is
    applied only when ``lora_name`` is a known entry of ``loras_list`` other
    than the "None" sentinel; otherwise the model is loaded without a LoRA.

    Returns:
        A list of generated images for the gallery; [] when the prompt is
        empty, no model is selected, or no pipeline is available.
    """
    if prompt is None or prompt.strip() == "":
        gr.Warning("Prompt empty!")
        return []
    if not model_name:
        gr.Warning("No model selected!")
        return []
    use_lora = lora_name != "None" and lora_name in loras_list
    load_model(model_name, lora_name if use_lora else "", use_lora)
    pipe = models.get(model_name)
    if pipe is None:
        return []
    outputs = []
    # Sliders may deliver floats; range() and the SDXL size/step arguments
    # need ints.
    for _ in range(int(num_images)):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            height=int(height),
            width=int(width),
        )["images"][0]
        outputs.append(output)
    return outputs
def _refresh_dropdown_choices():
    """Return Dropdown updates listing the current models_list and loras_list."""
    # Snapshot the lists so later appends don't alias the sent choices.
    return gr.update(choices=list(models_list)), gr.update(choices=list(loras_list))


# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                modelId = gr.Textbox(label="CivitAI Model ID")
                loraId = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
                download_output = gr.Textbox(label="Download Output")
            # The dropdowns copy their choices at build time, so push the
            # updated lists to them after every download attempt.
            download_button.click(
                download_civitai_model, inputs=[modelId, loraId], outputs=download_output
            ).then(_refresh_dropdown_choices, outputs=[model_dropdown, lora_dropdown])

if __name__ == "__main__":
    demo.launch()