# Tech-Meld's picture
# Update app.py
# f25f274 verified
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from tqdm.auto import tqdm
from huggingface_hub import hf_hub_url, login, HfApi, create_repo
import os
import traceback
from peft import PeftModel
import gradio as gr
def display_image(image):
    """Hand the supplied image straight back to the caller.

    This is an identity pass-through: the argument is returned as-is, with
    no copying or conversion.
    """
    shown = image
    return shown
def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name):
    """Load a base diffusion pipeline and attach a LoRA adapter to its UNet.

    The pipeline is loaded in float16 (``fp16`` variant, safetensors) onto the
    CPU, its scheduler is swapped for a ``DPMSolverMultistepScheduler`` built
    from the existing scheduler config, and its UNet is wrapped with
    ``PeftModel.from_pretrained``.

    NOTE(review): despite the function name and the success message, no
    ``merge_and_unload`` call is made here; the UNet is left wrapped by the
    PEFT model. Confirm whether downstream saving is expected to contain
    merged weights.

    Args:
        base_model_id: Model ID or path given to
            ``DiffusionPipeline.from_pretrained``.
        lora_id: Adapter ID or path given to ``PeftModel.from_pretrained``.
        lora_adapter_name: Name to register the adapter under. A blank or
            missing value falls back to ``"default"``.

    Returns:
        The pipeline with its UNet replaced by the PEFT-wrapped UNet, or
        ``None`` if any step raised. On failure the full traceback is written
        to ``errors.txt`` in the current working directory.
    """
    # A blank text field arrives as "" rather than None. A blank name is not
    # usable as an adapter key, so fall back to PEFT's own default name.
    adapter_name = lora_adapter_name
    if isinstance(adapter_name, str):
        adapter_name = adapter_name.strip()
    if not adapter_name:
        adapter_name = "default"

    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cpu")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        # Wrap the pipeline's UNet with the adapter and put the wrapped model
        # back in place of the original one.
        pipe.unet = PeftModel.from_pretrained(
            pipe.unet,
            lora_id,
            torch_dtype=torch.float16,
            adapter_name=adapter_name,
        )
        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        # Best-effort boundary: report briefly on stdout, keep the full
        # traceback on disk, and signal failure to the caller with None.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")
        # Explicit encoding so non-ASCII text in the traceback (paths, model
        # names) is written the same way regardless of the platform locale.
        with open("errors.txt", "w", encoding="utf-8") as f:
            f.write(error_msg)
        return None
def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None):
    """Save the pipeline locally and optionally upload it to the Hugging Face Hub.

    Args:
        pipe: Object exposing ``save_pretrained(path)``.
        save_path: Directory the pipeline is written to, and the folder that
            is uploaded when ``push_to_hub`` is true.
        push_to_hub: When true, log in, create the target repository if
            needed, and upload ``save_path`` to it.
        hf_token: Hugging Face write token. When it is ``None``, empty, or
            only whitespace, the token is requested on stdin instead.

    Returns:
        None. Any exception raised while saving or uploading is caught and
        reported on stdout rather than propagated.
    """
    try:
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")

        if not push_to_hub:
            return

        # A blank text field yields "" rather than None; a blank token is as
        # unusable as a missing one, so prompt in both cases.
        if hf_token is None or not str(hf_token).strip():
            hf_token = input("Enter your Hugging Face write token: ")
        login(token=hf_token)

        # The target repository is always asked for interactively.
        repo_name = input("Enter the Hugging Face repository name "
                          "(e.g., your_username/your_model_name): ")

        # Create the repository if it doesn't exist
        create_repo(repo_name, token=hf_token, exist_ok=True)

        api = HfApi()
        api.upload_folder(
            folder_path=save_path,
            repo_id=repo_name,
            token=hf_token,
            repo_type="model",
        )
        print(f"Model pushed successfully to Hugging Face Hub: {repo_name}")
    except Exception as e:
        # Deliberate best-effort: a failed save/push is reported, not raised.
        print(f"Error saving/pushing the merged model: {e}")
def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt, lora_scale, save_path, push_to_hub, hf_token):
    """Load the model with its LoRA, generate one image, then save/push the model.

    Args:
        base_model_id: Base model ID forwarded to ``load_and_merge_lora``.
        lora_id: LoRA ID forwarded to ``load_and_merge_lora``.
        lora_adapter_name: Adapter name forwarded to ``load_and_merge_lora``.
        prompt: Text prompt passed to the pipeline.
        lora_scale: LoRA scale; converted with ``float()`` and passed as
            ``cross_attention_kwargs={"scale": ...}``.
        save_path: Directory forwarded to ``save_merged_model``.
        push_to_hub: Forwarded to ``save_merged_model``.
        hf_token: Forwarded to ``save_merged_model``.

    Returns:
        A ``(image, status)`` pair. ``image`` is the first generated image,
        or ``None`` when the pipeline could not be loaded; ``status`` is a
        human-readable message.
    """
    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
    if not pipe:
        # Always return one value per output (image, status). Falling off the
        # end here would hand back a bare None instead of a pair.
        return None, "Failed to load the base model or LoRA; see errors.txt for details."

    scale = float(lora_scale)
    result = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": scale},
        # Fixed seed so repeated runs with the same inputs are reproducible.
        generator=torch.manual_seed(0),
    )
    image = result.images[0]
    image.save("generated_image.png")
    print("Image saved to: generated_image.png")

    save_merged_model(pipe, save_path, push_to_hub, hf_token)
    return image, "Image generated and model saved/pushed (if selected)."
# Input widgets, listed in the same order as generate_and_save's parameters.
_input_components = [
    gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
    gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
    gr.Textbox(label="LoRA Adapter Name"),
    gr.Textbox(label="Prompt"),
    gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
    gr.Textbox(label="Save Path"),
    gr.Checkbox(label="Push to Hugging Face Hub"),
    gr.Textbox(label="Hugging Face Write Token", type="password"),
]

# Output widgets: the generated image followed by a status message.
_output_components = [
    gr.Image(label="Generated Image"),
    gr.Textbox(label="Status"),
]

iface = gr.Interface(
    fn=generate_and_save,
    inputs=_input_components,
    outputs=_output_components,
    title="LoRA Merger and Image Generator",
    description="Merge a LoRA with a base Stable Diffusion model and generate images.",
)

# Start the web UI.
iface.launch()