# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import os
import shutil
import subprocess
import sys

import gradio as gr
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# ===== 1. Initialize BLIP-2 for Auto-Captioning =====
def load_blip_model():
    """Load the BLIP-2 captioning processor and model.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` reports a GPU and ``"cpu"`` otherwise;
        the model has already been moved to that device.
    """
    checkpoint = "Salesforce/blip2-opt-2.7b"

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    blip_processor = Blip2Processor.from_pretrained(checkpoint)

    # Weights are requested in half precision regardless of the device.
    # NOTE(review): float16 inference on CPU may be slow or unsupported
    # depending on the installed torch build -- confirm before relying on
    # the CPU fallback.
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
    )
    blip_model = blip_model.to(target_device)

    return blip_processor, blip_model, target_device


# Loaded once at import time; generate_caption() reads these module globals.
processor, model, device = load_blip_model()
def generate_caption(image_path, trigger_word):
    """Caption an image with BLIP-2 and prefix the caption with the trigger word.

    Args:
        image_path: Path of the image file to caption.
        trigger_word: Token inserted in square brackets at the start of the
            caption.

    Returns:
        ``"a photo of [<trigger_word>], <caption>"``, or just
        ``"a photo of [<trigger_word>]"`` when the model produces an empty
        caption (avoids a dangling ", " in the caption file).
    """
    # Image.open() is lazy and keeps the file handle open; the context manager
    # releases it as soon as the processor has read the pixel data.
    with Image.open(image_path) as image:
        inputs = processor(image, return_tensors="pt")

    # Cast floating-point inputs to whatever precision the model was loaded
    # with instead of hard-coding it, so the two cannot drift apart.
    inputs = inputs.to(device, model.dtype)

    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    prefix = f"a photo of [{trigger_word}]"
    return f"{prefix}, {caption}" if caption else prefix
# ===== 2. Install Kohya_SS Manually =====
# One-time bootstrap executed at import: clone the repository and install it
# into the interpreter that is running this app. Installation stays
# best-effort (a failed step is reported, not raised) so the UI can still
# start, but failures are no longer silent.
if not os.path.exists("kohya_ss"):
    print("⬇️ Installing Kohya_SS...")
    # "python -m pip" guarantees the packages land in this interpreter's
    # environment rather than whichever "pip" happens to be first on PATH.
    _pip_install = [sys.executable, "-m", "pip", "install"]
    _install_steps = (
        (["git", "clone", "https://github.com/bmaltais/kohya_ss"], None),
        (_pip_install + ["-r", "requirements.txt"], "kohya_ss"),
        (_pip_install + ["."], "kohya_ss"),
    )
    for _argv, _cwd in _install_steps:
        # The pip steps run inside the checkout; skip them if the clone failed.
        if _cwd is not None and not os.path.isdir(_cwd):
            print(f"WARNING: skipping {' '.join(_argv)!r}: {_cwd!r} does not exist")
            continue
        try:
            _status = subprocess.run(_argv, cwd=_cwd).returncode
        except OSError as _err:
            # Raised when the executable itself (e.g. git) cannot be started.
            print(f"WARNING: could not run {' '.join(_argv)!r}: {_err}")
            continue
        if _status != 0:
            print(f"WARNING: {' '.join(_argv)!r} exited with status {_status}")
# ===== 3. Training Function =====
def _open_upload_as_rgb(upload):
    """Return one uploaded item as an in-memory RGB PIL image.

    Accepts a PIL image, a filesystem path, or an object that exposes its
    path as ``.name`` (the temp-file wrapper some Gradio versions pass).
    """
    if isinstance(upload, Image.Image):
        return upload.convert("RGB")

    if isinstance(upload, (str, os.PathLike)):
        path = upload
    else:
        path = getattr(upload, "name", None)
    if not path:
        raise gr.Error(f"Unsupported upload: {upload!r}")

    try:
        with Image.open(path) as opened:
            # convert() loads the pixels, so the result outlives the handle.
            return opened.convert("RGB")
    except OSError as err:
        raise gr.Error(f"Could not read image {path!r}: {err}") from err


def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train a LoRA with Kohya's trainer.

    Args:
        images: Uploaded items from the UI; each may be a file path, an
            object with a ``.name`` path attribute, or a PIL image.
        trigger_word: Token embedded in every generated caption.
        progress: Gradio progress tracker.

    Returns:
        The path ``"output/lora.safetensors"`` of the trained LoRA.

    Raises:
        gr.Error: If no images or no trigger word were supplied, an upload
            cannot be read, or training finished without producing the file.
        subprocess.CalledProcessError: If the training process exits non-zero.
    """
    trigger_word = (trigger_word or "").strip()
    if not images:
        raise gr.Error("Upload at least one image before training.")
    if not trigger_word:
        raise gr.Error("Enter a trigger word before training.")

    progress(0.1, desc="Preparing data...")

    # Start from an empty dataset so images left over from a previous run
    # cannot leak into this one.
    shutil.rmtree("train", ignore_errors=True)
    # Kohya's DreamBooth-style loader only reads images from
    # "<repeats>_<name>" subfolders of train_data_dir; files placed directly
    # in train_data_dir are ignored.
    image_dir = os.path.join("train", "10_subject")
    os.makedirs(image_dir)

    # Save images + auto-caption
    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        img_path = os.path.join(image_dir, f"img_{i}.jpg")
        # JPEG cannot store alpha or palette modes, hence the RGB conversion.
        _open_upload_as_rgb(upload).save(img_path)
        caption = generate_caption(img_path, trigger_word)
        with open(os.path.join(image_dir, f"img_{i}.txt"), "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (optimized for HF Spaces T4 GPU)
    # An argument list (no shell) avoids quoting pitfalls, and sys.executable
    # keeps training in the same environment the dependencies were installed in.
    # NOTE(review): the checkpoint is SDXL while the script is the generic
    # train_network.py; Kohya ships a separate SDXL entry point -- confirm
    # which script/path matches the installed checkout.
    command = [
        sys.executable,
        "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--caption_extension=.txt",  # captions above are written as .txt
        "--output_dir=output",
        "--output_name=lora",  # so the result is output/lora.safetensors
        "--save_model_as=safetensors",
        "--resolution=512",
        "--network_module=networks.lora",  # required to select the LoRA network
        "--network_dim=32",
        "--learning_rate=1e-4",  # "--lr" is an ambiguous abbreviation
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]
    progress(0.8, desc="Training LoRA...")
    subprocess.run(command, check=True)

    result_path = "output/lora.safetensors"
    if not os.path.isfile(result_path):
        raise gr.Error(f"Training finished but {result_path} was not created.")
    return result_path
# ===== 4. Gradio UI =====
# Two-column layout: uploads, trigger word and the train button on the left;
# the downloadable LoRA file and a preview gallery on the right.
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    # Emoji restored from mis-encoded bytes in the captured source.
    gr.Markdown(
        "\n"
        "## 🎨 Weights.gg-Style LoRA Trainer\n"
        "Upload 30 images + set a trigger word to train a custom LoRA.\n"
    )
    with gr.Row():
        with gr.Column():
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True,
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)",
            )
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")
        with gr.Column():
            output = gr.File(label="Download LoRA")
            # NOTE: the gallery is displayed but no event in this block
            # writes to it.
            gallery = gr.Gallery(label="Training Preview")
    # Clicking the button runs train_lora(images, trigger) and shows the
    # returned file path in the download component.
    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train",
    )
if __name__ == "__main__":
    # Progress reporting and long-running handlers rely on Gradio's request
    # queue; enabling it explicitly is harmless where it is already the
    # default and required where it is not.
    demo.queue()
    # Bind to all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)