Spaces:
Sleeping
Sleeping
import torch | |
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
from tqdm.auto import tqdm | |
from huggingface_hub import hf_hub_url, login, HfApi, create_repo | |
import os | |
import traceback | |
from peft import PeftModel | |
import gradio as gr | |
def display_image(image):
    """Identity passthrough: hand the given image back untouched."""
    shown = image
    return shown
def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name):
    """Load a diffusers pipeline and attach a PEFT LoRA adapter to its UNet.

    The adapter is attached with ``PeftModel.from_pretrained``; its weights are
    NOT folded into the UNet here (no ``merge_and_unload`` call), so
    ``pipe.unet`` ends up being the PEFT-wrapped model.

    Args:
        base_model_id: Hub id or local path of the base diffusers pipeline.
        lora_id: Hub id or local path of the PEFT adapter to load.
        lora_adapter_name: Name under which the adapter is registered.

    Returns:
        The pipeline with its UNet replaced by the PEFT-wrapped UNet, or None
        on any failure (the full traceback is written to ``errors.txt``).
    """
    try:
        # Half-precision kernels are not reliably available on CPU (diffusers
        # warns that fp16 pipelines moved to CPU will fail), so only use fp16
        # when a CUDA device is present and fall back to float32 on CPU.
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            device, dtype = "cpu", torch.float32

        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True,
        ).to(device)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )

        # Wrap the pipeline's UNet with the PEFT adapter and swap it back in.
        unet = PeftModel.from_pretrained(
            pipe.unet,
            lora_id,
            torch_dtype=dtype,
            adapter_name=lora_adapter_name,
        )
        pipe.unet = unet
        print("LoRA adapter loaded successfully!")
        return pipe
    except Exception as e:
        # Best-effort boundary: report, persist the traceback, signal with None.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")
        with open("errors.txt", "w", encoding="utf-8") as f:
            f.write(error_msg)
        return None
def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None,
                      repo_id=None):
    """Save the pipeline to disk and optionally upload it to the Hugging Face Hub.

    If ``pipe.unet`` exposes ``merge_and_unload`` (i.e. it is a PEFT wrapper),
    the adapter weights are folded into the base UNet first. Without this step
    ``save_pretrained`` would write only the adapter files instead of a
    complete, standalone merged model.

    Args:
        pipe: A diffusers pipeline (any object with ``save_pretrained``).
        save_path: Local directory the model is written to.
        push_to_hub: When true, also upload ``save_path`` to the Hub.
        hf_token: Hub write token. ``None`` or an empty string falls back to
            an interactive prompt.
        repo_id: Target repository (``user/name``). ``None`` or empty falls
            back to an interactive prompt.

    Returns:
        True if every requested step succeeded, False otherwise. Errors are
        printed rather than raised.
    """
    try:
        unet = getattr(pipe, "unet", None)
        merge = getattr(unet, "merge_and_unload", None)
        if merge is not None:
            pipe.unet = merge()

        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")
        if not push_to_hub:
            return True

        # Gradio passes "" for an empty textbox, so test truthiness, not None.
        if not hf_token:
            hf_token = input("Enter your Hugging Face write token: ")
        if not repo_id:
            repo_id = input("Enter the Hugging Face repository name "
                            "(e.g., your_username/your_model_name): ")

        # The token is passed explicitly to each call instead of using login():
        # login() stores the token on disk as the process-wide default, which in
        # a shared app would hand one user's credentials to later requests.
        create_repo(repo_id, token=hf_token, exist_ok=True)
        api = HfApi()
        api.upload_folder(
            folder_path=save_path,
            repo_id=repo_id,
            token=hf_token,
            repo_type="model",
        )
        print(f"Model pushed successfully to Hugging Face Hub: {repo_id}")
        return True
    except Exception as e:
        print(f"Error saving/pushing the merged model: {e}")
        return False
def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt,
                      lora_scale, save_path, push_to_hub, hf_token):
    """Gradio callback: load base model + LoRA, render one image, save the model.

    The image is written to ``generated_image.png``, then the pipeline is
    handed to ``save_merged_model``.

    Returns:
        An ``(image, status)`` pair matching the interface's two outputs. If
        the pipeline could not be loaded, the image is ``None`` and the status
        says so, so Gradio always receives two values.
    """
    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
    if pipe is None:
        return None, ("Failed to load the base model or LoRA. "
                      "See errors.txt for details.")

    # A dedicated generator seeded with 0 yields the same noise as
    # torch.manual_seed(0) without reseeding torch's global RNG.
    generator = torch.Generator(device="cpu").manual_seed(0)
    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": float(lora_scale)},
        generator=generator,
    ).images[0]
    image.save("generated_image.png")
    print("Image saved to: generated_image.png")

    saved = save_merged_model(pipe, save_path, push_to_hub, hf_token)
    # Only an explicit False means failure; a None result keeps the old message.
    if saved is False:
        return image, "Image generated, but saving/pushing the model failed."
    return image, "Image generated and model saved/pushed (if selected)."
# Gradio UI: the order of input_components must match the positional
# parameters of generate_and_save; output_components match its (image, status)
# return pair.
input_components = [
    gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
    gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
    gr.Textbox(label="LoRA Adapter Name"),
    gr.Textbox(label="Prompt"),
    gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
    gr.Textbox(label="Save Path"),
    gr.Checkbox(label="Push to Hugging Face Hub"),
    gr.Textbox(label="Hugging Face Write Token", type="password"),
]
output_components = [
    gr.Image(label="Generated Image"),
    gr.Textbox(label="Status"),
]

iface = gr.Interface(
    fn=generate_and_save,
    inputs=input_components,
    outputs=output_components,
    title="LoRA Merger and Image Generator",
    description="Merge a LoRA with a base Stable Diffusion model and generate images.",
)
iface.launch()