DiffModels / app.py
SemaSci's picture
Update app.py
1fd63c5 verified
raw
history blame
10.9 kB
# это файл только с LoRA, без ControlNet и IpAdapter
import gradio as gr
import numpy as np
import random
# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
from peft import PeftModel, LoraConfig
import os
# Global cache of already-built pipelines, so repeated requests reuse the
# loaded weights instead of reloading them on every call.
pipe_cache = dict()
def get_lora_sd_pipeline(
    ckpt_dir='./lora_logos',
    base_model_name_or_path=None,
    dtype=torch.float16,
    adapter_name="default"
):
    """Build a diffusers pipeline with PEFT LoRA adapters loaded from ``ckpt_dir``.

    ``ckpt_dir`` must contain a ``unet`` sub-directory with a PEFT adapter and
    may contain a ``text_encoder`` sub-directory with a second adapter.

    Args:
        ckpt_dir: Directory holding the ``unet`` (and optional ``text_encoder``)
            adapter sub-directories.
        base_model_name_or_path: Base model id or path. When None, it is read
            from the text-encoder adapter's ``LoraConfig`` (if that exists).
        dtype: torch dtype used to load the pipeline; for float16/bfloat16 the
            wrapped UNet and text encoder are also cast to it.
        adapter_name: Name under which the adapters are registered.

    Returns:
        The pipeline (not moved to any device) whose ``unet`` and, when an
        adapter exists for it, ``text_encoder`` are wrapped in ``PeftModel``.

    Raises:
        ValueError: If no base model is given and none can be read from the
            text-encoder adapter config.
    """
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
    has_text_encoder_lora = os.path.exists(text_encoder_sub_dir)

    if base_model_name_or_path is None and has_text_encoder_lora:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path
    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    pipe = DiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)

    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
    pipe.unet.set_adapter(adapter_name)
    # Sanity check that adapter weights were actually injected. PEFT modifies
    # the UNet in place, so comparing "before"/"after" parameter iterators of
    # the same module cannot detect a change; counting LoRA weights can.
    lora_param_count = sum(
        param.numel() for name, param in pipe.unet.named_parameters() if "lora_" in name
    )
    print("UNet LoRA parameters loaded:", lora_param_count)

    if has_text_encoder_lora:
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
        )

    if dtype in (torch.float16, torch.bfloat16):
        # Adapter weights may be created in a different precision than the base
        # weights; cast the wrapped modules to the requested dtype. `.to(dtype)`
        # (not `.half()`) so that bfloat16 is honoured instead of becoming float16.
        pipe.unet.to(dtype)
        if getattr(pipe, "text_encoder", None) is not None:
            pipe.text_encoder.to(dtype)
    return pipe
def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
    """Encode a prompt of arbitrary length by running the encoder over windows.

    The prompt is tokenized without truncation and split into windows of at
    most ``max_length`` tokens. Each window goes through ``text_encoder`` under
    ``torch.no_grad()``, and the first output element of every call is
    concatenated along dim 1.

    When the prompt does not fit in one window and the tokenizer exposes
    ``bos_token_id`` / ``eos_token_id``, every window is rebuilt as
    ``[BOS] + body + [EOS]``. Otherwise only the first window would start with
    BOS and only the last would end with EOS. Prompts that fit in one window
    are encoded exactly as tokenized.

    Args:
        prompt: Text to encode.
        tokenizer: Callable returning a mapping with a tensor ``"input_ids"``
            when called with ``return_tensors="pt"``.
        text_encoder: Callable module exposing ``.device``; element ``[0]`` of
            its output is the tensor that gets concatenated.
        max_length: Maximum number of tokens passed to the encoder per call.

    Returns:
        The concatenated encoder outputs, presumably
        (batch, total_tokens, hidden) — confirm for your encoder.
    """
    input_ids = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
    bos_id = getattr(tokenizer, "bos_token_id", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)

    fits_one_window = input_ids.shape[1] <= max_length
    can_rewrap = bos_id is not None and eos_id is not None and max_length >= 3
    if fits_one_window or not can_rewrap:
        chunks = [input_ids[:, i:i + max_length] for i in range(0, input_ids.shape[1], max_length)]
    else:
        # Strip the single leading BOS / trailing EOS, then re-add them per window.
        body = input_ids
        if bool((body[:, 0] == bos_id).all()):
            body = body[:, 1:]
        if body.shape[1] > 0 and bool((body[:, -1] == eos_id).all()):
            body = body[:, :-1]
        step = max_length - 2  # leave room for BOS and EOS in every window
        bos_col = torch.full((body.shape[0], 1), bos_id, dtype=body.dtype)
        eos_col = torch.full((body.shape[0], 1), eos_id, dtype=body.dtype)
        chunks = [
            torch.cat([bos_col, body[:, i:i + step], eos_col], dim=1)
            for i in range(0, body.shape[1], step)
        ]

    with torch.no_grad():
        embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
    return torch.cat(embeds, dim=1)
def align_embeddings(prompt_embeds, negative_prompt_embeds):
    """Zero-pad two embedding tensors to the same length and return both.

    The target length is the larger ``shape[1]`` of the two inputs. Padding is
    appended at the end of the second-to-last axis via ``F.pad(..., (0, 0, 0, k))``.
    This assumes 3-D (batch, seq, dim) inputs, where that axis is dim 1 —
    TODO confirm for other ranks.
    """
    target_len = max(t.shape[1] for t in (prompt_embeds, negative_prompt_embeds))

    def _pad_to_target(tensor):
        missing = target_len - tensor.shape[1]
        return torch.nn.functional.pad(tensor, (0, 0, 0, missing))

    padded_prompt = _pad_to_target(prompt_embeds)
    padded_negative = _pad_to_target(negative_prompt_embeds)
    return padded_prompt, padded_negative
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model the LoRA adapters are applied to, and the models offered in the UI.
model_id_default = "sd-legacy/stable-diffusion-v1-5"
model_dropdown = [
    'stabilityai/sdxl-turbo',
    'CompVis/stable-diffusion-v1-4',
    'sd-legacy/stable-diffusion-v1-5',
]

# LoRA checkpoint directory names offered in the UI.
model_lora_default = "lora_pussinboots_logos"
model_lora_dropdown = [
    'lora_lady_and_cats_logos',
    'lora_pussinboots_logos',
]

# Half precision on the GPU, full precision on the CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Largest seed the UI allows (max of a signed 32-bit int), and the image-size cap.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def _get_cached_pipeline(model_repo_id, model_lora_id=None):
    """Return the pipeline for ``model_repo_id``, loading and caching it on first use.

    With ``model_lora_id`` set, the pipeline is built by ``get_lora_sd_pipeline``
    from the checkpoint directory ``./<model_lora_id>``. Otherwise a plain
    ``DiffusionPipeline`` is loaded. Either way it is moved to ``device`` and
    stored in ``pipe_cache``.
    """
    # Plain pipelines are keyed on the model alone (LoRA id None), so changing
    # the LoRA dropdown while a non-LoRA model is selected reuses the same copy.
    cache_key = (model_repo_id, model_lora_id)
    pipe = pipe_cache.get(cache_key)
    if pipe is None:
        if model_lora_id is None:
            pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
        else:
            pipe = get_lora_sd_pipeline(
                ckpt_dir='./' + model_lora_id,
                base_model_name_or_path=model_repo_id,
                dtype=torch_dtype,
            )
        pipe = pipe.to(device)
        pipe_cache[cache_key] = pipe
    return pipe


def _set_lora_scale(pipe, lora_scale):
    """Set the strength of every LoRA adapter in the pipeline's UNet and text encoder.

    Uses PEFT's ``LoraLayer.set_scale``. It sets the scaling relative to the
    adapter's configured alpha/rank rather than multiplying the current value,
    so calling this on every request does not compound.
    """
    from peft.tuners.lora import LoraLayer

    for component in (getattr(pipe, "unet", None), getattr(pipe, "text_encoder", None)):
        if component is None:
            continue
        for module in component.modules():
            if isinstance(module, LoraLayer):
                for adapter in list(module.scaling):
                    module.set_scale(adapter, lora_scale)
    print(f"LoRA scale applied: {lora_scale}")


def _prompt_kwargs(pipe, prompt, negative_prompt):
    """Return the prompt-related keyword arguments for calling ``pipe``."""
    if getattr(pipe, "text_encoder_2", None) is not None:
        # Dual-encoder pipelines (e.g. SDXL) also require pooled embeddings,
        # which process_prompt does not produce; let the pipeline encode.
        return {"prompt": prompt, "negative_prompt": negative_prompt}
    # Encode on every call (not only when the pipeline is first loaded) so
    # that cached pipelines use the current prompt.
    prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
    negative_prompt_embeds = process_prompt(negative_prompt or "", pipe.tokenizer, pipe.text_encoder)
    prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
    return {"prompt_embeds": prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds}


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    randomize_seed,
    width=512,
    height=512,
    model_repo_id=model_id_default,
    seed=42,
    guidance_scale=7,
    num_inference_steps=20,
    model_lora_id=model_lora_default,
    lora_scale=0.5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image and return it together with the seed that was used.

    When ``model_repo_id`` is the default base model, the pipeline carries the
    LoRA checkpoint ``./<model_lora_id>`` and its strength is set to
    ``lora_scale``. Any other model is loaded as a plain ``DiffusionPipeline``.
    Loaded pipelines are reused through ``pipe_cache``.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text (None is treated as "").
        randomize_seed: If true, ``seed`` is replaced by a random value.
        width, height: Output size in pixels.
        model_repo_id: Diffusers model id to generate with.
        seed: Seed for the torch generator.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        model_lora_id: LoRA checkpoint directory name (used only with the default model).
        lora_scale: LoRA strength multiplier (used only with the default model).
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        Tuple ``(image, seed)`` with the first generated image and the seed used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; torch generators and latent shapes need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    use_lora = model_repo_id == model_id_default
    pipe = _get_cached_pipeline(model_repo_id, model_lora_id if use_lora else None)
    if use_lora:
        _set_lora_scale(pipe, lora_scale)

    params = _prompt_kwargs(pipe, prompt, negative_prompt)
    params.update(
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    )
    return pipe(**params).images[0], seed
# Sample prompts offered under the prompt box (passed to gr.Examples).
examples = [
    "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
    # Single-quoted so the embedded double quotes survive; the former
    # ""About the Cat"" was implicit string concatenation that dropped them.
    'A cat is playing a song called "About the Cat" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.',
    "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
    "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# Stylesheet for gr.Blocks: center the main column and cap its width.
css = "\n".join(
    [
        "",
        "#col-container {",
        "margin: 0 auto;",
        "max-width: 640px;",
        "}",
        "",
    ]
)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image SemaSci Template")

        # Prompt box with the Run button beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                placeholder="Enter your prompt",
                show_label=False,
                container=False,
                max_lines=1,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # Custom values are allowed so any model id can be typed in.
            model_repo_id = gr.Dropdown(
                label="Model Id",
                info="Choose model",
                choices=model_dropdown,
                value=model_id_default,
                allow_custom_value=True,
                visible=True,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                max_lines=1,
                visible=True,
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

            with gr.Row():
                # Width slider first, then height. Replace value=256 with
                # defaults that work for your model.
                width, height = (
                    gr.Slider(
                        label=axis_label,
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=256,
                    )
                    for axis_label in ("Width", "Height")
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,  # Replace with defaults that work for your model
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,  # Replace with defaults that work for your model
                )

            with gr.Row():
                model_lora_id = gr.Dropdown(
                    label="Lora Id",
                    info="Choose LoRA model",
                    choices=model_lora_dropdown,
                    value=model_lora_default,
                    allow_custom_value=True,
                    visible=True,
                )
                lora_scale = gr.Slider(
                    label="LoRA scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.5,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Values are passed to infer() positionally, so this order must match its signature.
    infer_inputs = [
        prompt,
        negative_prompt,
        randomize_seed,
        width,
        height,
        model_repo_id,
        seed,
        guidance_scale,
        num_inference_steps,
        model_lora_id,
        lora_scale,
    ]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()