# ===== 0. Imports =====
import os
import subprocess

import gradio as gr
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration


# ===== 1. Initialize BLIP-2 for Auto-Captioning =====
def load_blip_model():
    """Load the BLIP-2 captioning model onto the best available device.

    Returns:
        tuple: (processor, model, device) — the HF processor, the
        conditional-generation model, and the device string
        ("cuda" or "cpu").
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 halves GPU memory, but many fp16 ops are unsupported on CPU —
    # the original unconditionally used float16 and would break CPU-only runs.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
    ).to(device)
    return processor, model, device


processor, model, device = load_blip_model()


def generate_caption(image_path, trigger_word):
    """Caption an image with BLIP-2, prefixed with the LoRA trigger word.

    Args:
        image_path: path to the image file on disk.
        trigger_word: token to embed in the caption, e.g. "my_char".

    Returns:
        str: "a photo of [<trigger_word>], <generated caption>".
    """
    # Normalize to RGB: BLIP-2 expects 3 channels; uploads may be RGBA or
    # grayscale (original passed the raw mode through).
    image = Image.open(image_path).convert("RGB")
    # Match the model's actual dtype instead of hard-coding float16.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return f"a photo of [{trigger_word}], {caption}"


# ===== 2. Install Kohya_SS Manually =====
if not os.path.exists("kohya_ss"):
    print("⬇️ Installing Kohya_SS...")
    # Argument lists + check=True instead of os.system shell strings:
    # no shell quoting pitfalls, and a failed install aborts loudly
    # instead of being silently ignored.
    subprocess.run(
        ["git", "clone", "https://github.com/bmaltais/kohya_ss"], check=True
    )
    subprocess.run(
        ["pip", "install", "-r", "requirements.txt"], cwd="kohya_ss", check=True
    )
    subprocess.run(["pip", "install", "."], cwd="kohya_ss", check=True)
# ===== 3. Training Function =====
def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train an SDXL LoRA via Kohya_SS.

    Args:
        images: list of uploads from gr.Files — file objects with a `.name`
            temp path, or plain path strings depending on Gradio version.
        trigger_word: token embedded in every caption, e.g. "my_char".
        progress: Gradio progress tracker (injected by the UI; the mutable
            default is the documented Gradio pattern).

    Returns:
        str: path to the trained LoRA weights file.

    Raises:
        subprocess.CalledProcessError: if the Kohya_SS training run fails.
    """
    progress(0.1, desc="Preparing data...")

    # Save images + auto-caption
    os.makedirs("train", exist_ok=True)
    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        # gr.Files yields temp-file handles/paths, NOT PIL images: the
        # original called .save() on the handle, which crashes. Open via
        # PIL, normalize to RGB, and re-save as JPEG.
        img_path = f"train/img_{i}.jpg"
        src = upload.name if hasattr(upload, "name") else upload
        Image.open(src).convert("RGB").save(img_path)
        caption = generate_caption(img_path, trigger_word)
        with open(f"train/img_{i}.txt", "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (settings sized for an HF Spaces T4-class GPU).
    cmd = [
        "python", "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--output_dir=output",
        "--resolution=512",
        "--network_dim=32",
        "--lr=1e-4",
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]
    progress(0.8, desc="Training LoRA...")
    # Argument list + shell=False avoids shell-injection/quoting issues of
    # the original multi-line shell string; check=True surfaces failures
    # instead of returning a path to a file that was never written.
    subprocess.run(cmd, check=True)
    # NOTE(review): Kohya_SS names the output file from its --output_name
    # option; confirm "lora.safetensors" matches the actual artifact.
    return "output/lora.safetensors"


# ===== 4. Gradio UI =====
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    gr.Markdown(
        """
        ## 🎨 Weights.gg-Style LoRA Trainer
        Upload 30 images + set a trigger word to train a custom LoRA.
        """
    )
    with gr.Row():
        with gr.Column():
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True,
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)",
            )
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")
        with gr.Column():
            output = gr.File(label="Download LoRA")
            gallery = gr.Gallery(label="Training Preview")

    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train",
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)