"""
RunPod Serverless Handler - Wrapper for AI-Toolkit

Does NOT modify ai-toolkit code, only wraps it.

Supports RunPod model caching via HuggingFace integration.
"""

import os
import sys
import subprocess
import traceback
import logging
import uuid
from pathlib import Path

RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache"
RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub"

IS_RUNPOD_CACHE = os.path.exists("/runpod-volume")

if IS_RUNPOD_CACHE:
    os.environ["HF_HOME"] = RUNPOD_CACHE_BASE
    os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE
    os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE
    os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets"

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
os.environ["DISABLE_TELEMETRY"] = "YES"

HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit")

# These imports come after the HF cache environment variables are configured above
import runpod
import torch
import yaml
import gc
import shutil

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model used by the most recent training run on this worker (drives cleanup behavior)
CURRENT_MODEL = None

MODEL_PRESETS = {
    "wan21_1b": "train_lora_wan21_1b_24gb.yaml",
    "wan21_14b": "train_lora_wan21_14b_24gb.yaml",
    "wan22_14b": "train_lora_wan22_14b_24gb.yaml",
    "qwen_image": "train_lora_qwen_image_24gb.yaml",
    "qwen_image_edit": "train_lora_qwen_image_edit_32gb.yaml",
    "qwen_image_edit_2509": "train_lora_qwen_image_edit_2509_32gb.yaml",
    "flux_dev": "train_lora_flux_24gb.yaml",
    "flux_schnell": "train_lora_flux_schnell_24gb.yaml",
}
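
# Each preset key maps to one of ai-toolkit's bundled example configs;
# load_example_config() below resolves the filename under <AI_TOOLKIT_DIR>/config/examples/.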

# HF repo whose cached snapshot on the RunPod volume holds the pre-fetched model weights
CACHE_REPO = "Aloukik21/trainer"

# Subfolder of CACHE_REPO containing each model's weights
MODEL_CACHE_PATHS = {
    "wan21_1b": "wan21-14b",
    "wan21_14b": "wan21-14b",
    "wan22_14b": "wan22-14b",
    "qwen_image": "qwen-image",
    "qwen_image_edit": "qwen-image",
    "qwen_image_edit_2509": "qwen-image",
    "flux_dev": "flux-dev",
    "flux_schnell": "flux-schnell",
}

# Original HuggingFace repos, used when a model is not found in the cache
MODEL_HF_REPOS = {
    "wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16",
    "qwen_image": "Qwen/Qwen-Image",
    "qwen_image_edit": "Qwen/Qwen-Image-Edit",
    "qwen_image_edit_2509": "Qwen/Qwen-Image-Edit",
    "flux_dev": "black-forest-labs/FLUX.1-dev",
    "flux_schnell": "black-forest-labs/FLUX.1-schnell",
}

# Subfolder of CACHE_REPO containing accuracy recovery adapters
ARA_CACHE_PATH = "accuracy_recovery_adapters"
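
# Illustrative layout the cache lookups below expect (an assumption: a standard
# HuggingFace hub cache for CACHE_REPO on the RunPod volume, not verified here):
#
#   /runpod-volume/huggingface-cache/hub/models--Aloukik21--trainer/snapshots/<revision>/
#       flux-dev/                      # MODEL_CACHE_PATHS["flux_dev"]
#       wan22-14b/                     # MODEL_CACHE_PATHS["wan22_14b"]
#       qwen-image/                    # MODEL_CACHE_PATHS["qwen_image*"]
#       accuracy_recovery_adapters/    # ARA_CACHE_PATH
#           wan22_14b_t2i_torchao_uint4.safetensors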


def cleanup_gpu_memory():
    """Aggressively clean up GPU memory."""
    logger.info("Cleaning up GPU memory...")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logger.info(f"GPU memory after cleanup: {get_gpu_info()}")


def cleanup_temp_files():
    """Clean up temporary training files."""
    logger.info("Cleaning up temporary files...")

    # Remove generated job configs from ai-toolkit's config directory
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    for f in os.listdir(config_dir):
        if f.endswith('.yaml') and f.startswith(('lora_', 'test_', 'my_')):
            try:
                os.remove(os.path.join(config_dir, f))
                logger.info(f"Removed temp config: {f}")
            except Exception as e:
                logger.warning(f"Failed to remove {f}: {e}")

    # Remove dataset/latent caches left in the workspace
    workspace_dirs = ["/workspace/dataset", "/workspace/output"]
    for ws_dir in workspace_dirs:
        if os.path.exists(ws_dir):
            for item in os.listdir(ws_dir):
                item_path = os.path.join(ws_dir, item)
                if item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
                    try:
                        if os.path.isdir(item_path):
                            shutil.rmtree(item_path)
                        else:
                            os.remove(item_path)
                        logger.info(f"Removed cache: {item_path}")
                    except Exception as e:
                        logger.warning(f"Failed to remove {item_path}: {e}")


def cleanup_before_training(new_model: str):
    """Full cleanup before starting new model training."""
    global CURRENT_MODEL

    if CURRENT_MODEL and CURRENT_MODEL != new_model:
        logger.info(f"Switching from {CURRENT_MODEL} to {new_model} - performing full cleanup")
        cleanup_gpu_memory()
        cleanup_temp_files()
    elif CURRENT_MODEL == new_model:
        logger.info(f"Same model {new_model} - light cleanup only")
        cleanup_gpu_memory()
    else:
        logger.info(f"First training run with {new_model}")

    CURRENT_MODEL = new_model

    # Use .get() so a missing GPU does not raise a KeyError here
    gpu_info = get_gpu_info()
    logger.info(f"Ready for training. GPU: {gpu_info.get('name', 'unavailable')}, Free: {gpu_info.get('free_gb', 'n/a')}GB")


def get_gpu_info():
    """Get GPU information."""
    if not torch.cuda.is_available():
        return {"available": False}
    props = torch.cuda.get_device_properties(0)
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    return {
        "available": True,
        "name": props.name,
        "total_gb": round(total_mem / (1024**3), 2),
        "free_gb": round(free_mem / (1024**3), 2),
    }


def get_environment_info():
    """Get environment information for debugging."""
    return {
        "is_runpod_cache": IS_RUNPOD_CACHE,
        "hf_home": os.environ.get("HF_HOME", "not set"),
        "hf_token_set": bool(HF_TOKEN),
        "gpu": get_gpu_info(),
        "ai_toolkit_dir": AI_TOOLKIT_DIR,
        "cache_exists": os.path.exists(RUNPOD_HF_CACHE) if IS_RUNPOD_CACHE else False,
    }


def find_cached_model(model_key: str) -> str:
    """
    Find cached model path on RunPod from Aloukik21/trainer repo.

    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Path to the cached model subfolder, or the original HF repo if not cached
    """
    if not IS_RUNPOD_CACHE:
        return MODEL_HF_REPOS.get(model_key, "")

    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            # Use the first snapshot (typically only one revision is cached)
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            if subfolder:
                cached_path = snapshots[0] / subfolder
                if cached_path.exists():
                    logger.info(f"Using cached model: {model_key} -> {cached_path}")
                    return str(cached_path)

    original_repo = MODEL_HF_REPOS.get(model_key, "")
    logger.info(f"Model not in cache, using original: {original_repo}")
    return original_repo


def find_cached_ara(adapter_name: str) -> str:
    """
    Find cached accuracy recovery adapter.

    Args:
        adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')

    Returns:
        Path to the cached adapter, or the original HF path
    """
    if not IS_RUNPOD_CACHE:
        return f"ostris/accuracy_recovery_adapters/{adapter_name}"

    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            cached_path = snapshots[0] / ARA_CACHE_PATH / adapter_name
            if cached_path.exists():
                logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
                return str(cached_path)

    return f"ostris/accuracy_recovery_adapters/{adapter_name}"


def check_model_cache_status(model_key: str) -> dict:
    """Check if model files are cached in Aloukik21/trainer."""
    if model_key not in MODEL_CACHE_PATHS:
        return {"cached": False, "reason": "unknown model"}

    status = {
        "model": model_key,
        "cache_repo": CACHE_REPO,
        "subfolder": MODEL_CACHE_PATHS.get(model_key),
    }

    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            model_path = snapshots[0] / subfolder
            status["cached"] = model_path.exists()
            status["path"] = str(model_path) if model_path.exists() else None
        else:
            status["cached"] = False
    else:
        status["cached"] = False
        status["reason"] = "cache repo not found"

    return status


def load_example_config(model_key):
    """Load example config from ai-toolkit."""
    if model_key not in MODEL_PRESETS:
        raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}")

    config_file = MODEL_PRESETS[model_key]
    config_path = os.path.join(AI_TOOLKIT_DIR, "config", "examples", config_file)

    with open(config_path, 'r') as f:
        return yaml.safe_load(f)
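
# Rough shape of the loaded config, inferred from the keys run_training() overrides
# below (illustrative only; the authoritative layout is the ai-toolkit example YAML):
#
#   config:
#     name: <job name>
#     process:
#       - model: {name_or_path: ...}
#         network: {linear: ..., linear_alpha: ...}
#         train: {steps: ..., batch_size: ..., lr: ...}
#         save: {save_every: ...}
#         sample: {sample_every: ..., prompts: [...]}
#         datasets: [{folder_path: ..., resolution: ..., num_frames: ...}]
#         training_folder: <output path>
#         trigger_word: <optional>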


def run_training(params):
    """Run training using ai-toolkit."""
    model_key = params.get("model", "wan22_14b")

    # Free resources from any previous job (full cleanup when switching models)
    cleanup_before_training(model_key)

    # Start from the ai-toolkit example config for this model
    config = load_example_config(model_key)

    # Job name (also used as the config filename)
    job_name = params.get("name", f"lora_{model_key}_{uuid.uuid4().hex[:6]}")
    config["config"]["name"] = job_name

    process = config["config"]["process"][0]

    # Dataset and output locations
    process["datasets"][0]["folder_path"] = params.get("dataset_path", "/workspace/dataset")
    process["training_folder"] = params.get("output_path", "/workspace/output")

    # Optional training overrides
    if "steps" in params:
        process["train"]["steps"] = params["steps"]
    if "batch_size" in params:
        process["train"]["batch_size"] = params["batch_size"]
    if "learning_rate" in params:
        process["train"]["lr"] = params["learning_rate"]
    if "lora_rank" in params:
        process["network"]["linear"] = params["lora_rank"]
        process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
    if "save_every" in params:
        process["save"]["save_every"] = params["save_every"]
    if "sample_every" in params:
        process["sample"]["sample_every"] = params["sample_every"]
    if "resolution" in params:
        process["datasets"][0]["resolution"] = params["resolution"]
    if "num_frames" in params:
        process["datasets"][0]["num_frames"] = params["num_frames"]
    if "sample_prompts" in params:
        process["sample"]["prompts"] = params["sample_prompts"]
    if "trigger_word" in params:
        process["trigger_word"] = params["trigger_word"]

    # Point the model at the RunPod cache when available
    if "model" in process:
        cached_path = find_cached_model(model_key)
        if cached_path:
            process["model"]["name_or_path"] = cached_path
            logger.info(f"Model path set to: {cached_path}")

    # Write the final config where ai-toolkit expects it
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    config_path = os.path.join(config_dir, f"{job_name}.yaml")

    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.info(f"Config saved: {config_path}")
    logger.info(f"Starting: {job_name}")

    # Run ai-toolkit as a subprocess and stream its output to the logger
    cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
    logger.info(f"Command: {' '.join(cmd)}")

    proc = subprocess.Popen(
        cmd,
        cwd=AI_TOOLKIT_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    for line in proc.stdout:
        logger.info(line.rstrip())

    proc.wait()

    cleanup_gpu_memory()

    if proc.returncode != 0:
        raise RuntimeError(f"Training failed with code {proc.returncode}")

    return {
        "status": "success",
        "job_name": job_name,
        "output_path": process["training_folder"],
        "model": model_key,
    }
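
# Example request payloads (illustrative values; the parameter names mirror what
# handler() and run_training() read in this file -- adjust to your own job):
#
#   {"input": {"action": "list_models"}}
#   {"input": {"action": "status"}}
#   {"input": {"action": "check_cache", "model": "flux_dev"}}
#   {"input": {"action": "cleanup"}}
#   {"input": {
#       "action": "train",
#       "model": "wan22_14b",
#       "params": {
#           "name": "my_lora",
#           "dataset_path": "/workspace/dataset",
#           "steps": 2000,
#           "lora_rank": 32,
#           "trigger_word": "mytoken",
#           "sample_prompts": ["a photo of mytoken"]
#       }
#   }}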


def handler(job):
    """RunPod handler."""
    job_input = job.get("input", {})
    action = job_input.get("action", "train")

    logger.info(f"Action: {action}, GPU: {get_gpu_info()}")

    try:
        if action == "list_models":
            return {"status": "success", "models": list(MODEL_PRESETS.keys())}

        elif action == "status":
            return {
                "status": "success",
                "environment": get_environment_info(),
            }

        elif action == "check_cache":
            model_key = job_input.get("model")
            if model_key:
                cache_status = check_model_cache_status(model_key)
            else:
                cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS.keys()}
            return {"status": "success", "cache": cache_status}

        elif action == "cleanup":
            cleanup_gpu_memory()
            cleanup_temp_files()
            global CURRENT_MODEL
            CURRENT_MODEL = None
            return {
                "status": "success",
                "message": "Cleanup complete",
                "gpu": get_gpu_info(),
            }

        elif action == "train":
            params = job_input.get("params", {})
            params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
            return run_training(params)

        else:
            return {"status": "error", "error": f"Unknown action: {action}"}

    except Exception as e:
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}
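
# For a quick local smoke test outside RunPod (assumed workflow, not part of the
# deployed handler), the function can be called directly with a fake job dict:
#   print(handler({"input": {"action": "status"}}))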


if __name__ == "__main__":
    logger.info("Starting AI-Toolkit RunPod Handler")
    logger.info(f"Environment: {get_environment_info()}")
    runpod.serverless.start({"handler": handler})