# trainloraf / app.py
# LHRuig's picture
# Upload 3 files
# d41cf19 verified
import os
import subprocess
import gradio as gr
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# ===== 1. Initialize BLIP-2 for Auto-Captioning =====
def load_blip_model():
    """Load the BLIP-2 captioning processor and model.

    The target device is "cuda" when ``torch.cuda.is_available()`` is true
    and "cpu" otherwise; the model is loaded with ``torch.float16`` weights
    and moved onto that device.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the
        string ``"cuda"`` or ``"cpu"``.
    """
    checkpoint = "Salesforce/blip2-opt-2.7b"

    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    blip_processor = Blip2Processor.from_pretrained(checkpoint)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.float16
    )
    blip_model = blip_model.to(target)
    return blip_processor, blip_model, target
# Loaded once at import time; generate_caption() reads these module-level globals.
processor, model, device = load_blip_model()
def generate_caption(image_path, trigger_word):
    """Caption one image with BLIP-2 and prefix it with the trigger word.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path of the image file to caption.
        trigger_word: Token placed in square brackets at the start of the
            returned caption.

    Returns:
        A string of the form ``"a photo of [<trigger_word>], <caption>"``.
    """
    # Open through a context manager so the file handle is released as soon
    # as the pixels are read; convert() forces the load before the close.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Cast the inputs to whatever dtype the model was actually loaded with
    # instead of assuming float16, so the two can never disagree.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return f"a photo of [{trigger_word}], {caption}"
# ===== 2. Install Kohya_SS Manually =====
def _install_kohya_ss(repo_dir="kohya_ss"):
    """Clone Kohya_SS into *repo_dir* and pip-install it.

    Every step runs as an argument list (no shell), and pip is invoked
    through the interpreter running this app so the packages land in the
    same environment. A failing step is reported on stdout and the remaining
    steps still run, so a broken install does not abort start-up.
    """
    import sys  # local import: only this one-off bootstrap needs it

    print("⬇️ Installing Kohya_SS...")
    pip_install = [sys.executable, "-m", "pip", "install"]
    steps = (
        (["git", "clone", "https://github.com/bmaltais/kohya_ss", repo_dir], None),
        (pip_install + ["-r", "requirements.txt"], repo_dir),
        (pip_install + ["."], repo_dir),
    )
    for argv, cwd in steps:
        try:
            status = subprocess.run(argv, cwd=cwd, check=False).returncode
        except OSError as exc:
            # e.g. git is not installed, or the clone failed so cwd is missing.
            print(f"WARNING: could not run {' '.join(argv)}: {exc}")
            continue
        if status != 0:
            print(f"WARNING: {' '.join(argv)} exited with status {status}")


# Only bootstrap when the checkout is not already present.
if not os.path.exists("kohya_ss"):
    _install_kohya_ss()
# ===== 3. Training Function =====
def _load_rgb_image(upload):
    """Return *upload* as an RGB PIL image, whatever form it arrived in."""
    if isinstance(upload, Image.Image):
        return upload.convert("RGB")
    # Depending on the Gradio version, an uploaded file arrives either as a
    # path string or as an object exposing the path through ``.name``.
    path = getattr(upload, "name", upload)
    with Image.open(path) as opened:
        # convert() loads the pixels, so the result outlives the file handle.
        return opened.convert("RGB")


def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train a LoRA with the Kohya trainer.

    Each upload is written to ``train/img_<i>.jpg`` with a matching
    ``train/img_<i>.txt`` caption file, then ``kohya_ss/train_network.py``
    is run as a subprocess.

    Args:
        images: Uploads from the ``gr.Files`` input. Items may be path
            strings, objects exposing the path as ``.name``, or PIL images.
        trigger_word: Token embedded in every generated caption.
        progress: Gradio progress tracker.

    Returns:
        Path of the trained weights, ``"output/lora.safetensors"``.

    Raises:
        gr.Error: If no images were uploaded, or the trainer finished
            without producing the expected output file.
        subprocess.CalledProcessError: If the training process exits with a
            non-zero status.
    """
    import sys  # local import: needed only to launch the trainer

    if not images:
        raise gr.Error("Upload at least one image before training.")

    progress(0.1, desc="Preparing data...")

    # Save images + auto-caption
    os.makedirs("train", exist_ok=True)
    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        img_path = f"train/img_{i}.jpg"
        # JPEG cannot store alpha/palette modes, hence the RGB conversion.
        _load_rgb_image(upload).save(img_path)
        caption = generate_caption(img_path, trigger_word)
        with open(f"train/img_{i}.txt", "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (optimized for HF Spaces T4 GPU)
    # NOTE(review): the base model is SDXL, which sd-scripts normally trains
    # with sdxl_train_network.py, and its folder-based datasets expect images
    # inside "<repeats>_<name>" sub-folders of train_data_dir. Confirm both
    # against the installed Kohya version.
    output_name = "lora"
    cmd = [
        sys.executable,
        "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--output_dir=output",
        # Without an explicit name the trainer picks its own default file
        # name, which would not match the path returned below.
        f"--output_name={output_name}",
        # Captions are written as .txt above; tell the trainer to read them.
        "--caption_extension=.txt",
        "--network_module=networks.lora",
        "--resolution=512",
        "--network_dim=32",
        # The trainer's option is --learning_rate; --lr is not accepted.
        "--learning_rate=1e-4",
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]

    progress(0.8, desc="Training LoRA...")
    # Argument list, no shell: nothing here needs shell parsing.
    subprocess.run(cmd, check=True)

    lora_path = f"output/{output_name}.safetensors"
    if not os.path.exists(lora_path):
        raise gr.Error(f"Training finished but {lora_path} was not created.")
    return lora_path
# ===== 4. Gradio UI =====
# Page layout: inputs on the left, results on the right, one button that
# runs train_lora() on the uploaded files and the trigger word.
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    gr.Markdown(
        "\n"
        "## 🎨 Weights.gg-Style LoRA Trainer\n"
        "Upload 30 images + set a trigger word to train a custom LoRA.\n"
    )

    with gr.Row():
        # Left column: training inputs.
        with gr.Column():
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)"
            )
            # Label previously showed a mis-decoded rocket emoji.
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")

        # Right column: training outputs.
        with gr.Column():
            output = gr.File(label="Download LoRA")
            # Not referenced by any event handler visible here.
            gallery = gr.Gallery(label="Training Preview")

    # Clicking the button trains and offers the resulting file for download;
    # the handler is also registered under the API name "train".
    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train"
    )
if __name__ == "__main__":
    # Start the app only when run as a script: host 0.0.0.0, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)