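# LoRA Fusion: a Gradio Space that fuses two custom SDXL LoRA adapters
# into the base model at request time and generates a 1024x1024 image.
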
import gradio as gr
from huggingface_hub import login
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
import torch
import copy
import os
import spaces
import random
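
# Authenticate with the Hugging Face Hub so gated or private LoRA repos
# can be downloaded.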
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
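
# Load the SDXL base pipeline once at startup; each request below works on
# copies of the mutable parts so LoRA fusion never alters these weights.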
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
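
# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call.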
@spaces.GPU
def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
    # Work on deep copies: fuse_lora() bakes the LoRA weights into the
    # modules it touches, so the originals must stay pristine between runs.
    unet = copy.deepcopy(original_pipe.unet)
    text_encoder = copy.deepcopy(original_pipe.text_encoder)
    text_encoder_2 = copy.deepcopy(original_pipe.text_encoder_2)

    # Rebuild a pipeline that shares the frozen parts (VAE, scheduler,
    # tokenizers) and owns the copied, mutable ones.
    pipe = StableDiffusionXLPipeline(
        vae=original_pipe.vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        scheduler=original_pipe.scheduler,
        tokenizer=original_pipe.tokenizer,
        tokenizer_2=original_pipe.tokenizer_2,
        unet=unet
    )
    pipe.to("cuda")

    # Load and fuse the first LoRA at its chosen scale.
    pipe.load_lora_weights(
        lora_1_id,
        weight_name=lora_1_sfts,
        low_cpu_mem_usage=True,
        use_auth_token=True
    )
    pipe.fuse_lora(lora_1_scale)

    # Load and fuse the second LoRA on top of the first.
    pipe.load_lora_weights(
        lora_2_id,
        weight_name=lora_2_sfts,
        low_cpu_mem_usage=True,
        use_auth_token=True
    )
    pipe.fuse_lora(lora_2_scale)

    # An empty negative prompt means "none"; a negative seed means "random".
    if negative_prompt == "":
        negative_prompt = None
    if seed < 0:
        seed = random.randint(0, 423538377342)

    # Gradio sliders can deliver floats; manual_seed() requires an int.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=25,
        width=1024,
        height=1024,
        generator=generator
    ).images[0]

    return image, seed
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        title = gr.HTML(
            '''
            <h1 style="text-align: center;">LoRA Fusion</h1>
            <p style="text-align: center;">Fuse 2 custom LoRA models</p>
            '''
        )

        # PART 1 • MODELS
        with gr.Row():
            with gr.Column():
                lora_1_id = gr.Textbox(
                    label="LoRA 1 ID",
                    placeholder="username/model_id"
                )
                lora_1_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors"
                )
            with gr.Column():
                lora_2_id = gr.Textbox(
                    label="LoRA 2 ID",
                    placeholder="username/model_id"
                )
                lora_2_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors"
                )
        # PART 2 • INFERENCE
        with gr.Row():
            prompt = gr.Textbox(
                label="Your prompt",
                info="Work your trigger words into a coherent prompt",
                placeholder="e.g.: a triggerWordOne portrait in triggerWordTwo style"
            )
            run_btn = gr.Button("Run")

        output_image = gr.Image(
            label="Output"
        )
        # Advanced Settings
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                lora_1_scale = gr.Slider(
                    label="LoRA 1 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7
                )
                lora_2_scale = gr.Slider(
                    label="LoRA 2 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7
                )
            negative_prompt = gr.Textbox(
                label="Negative prompt"
            )
            seed = gr.Slider(
                label="Seed",
                info="-1 denotes a random seed",
                minimum=-1,
                maximum=423538377342,
                value=-1
            )

        last_used_seed = gr.Number(
            label="Last used seed",
            info="the seed used in the last generation"
        )
        # ACTIONS
        run_btn.click(
            fn=infer,
            inputs=[
                lora_1_id,
                lora_1_sfts,
                lora_2_id,
                lora_2_sfts,
                prompt,
                negative_prompt,
                lora_1_scale,
                lora_2_scale,
                seed
            ],
            outputs=[
                output_image,
                last_used_seed
            ]
        )
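
# Queue incoming requests (two concurrent workers) before launching.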
demo.queue(concurrency_count=2).launch()