Update handler to use Aloukik21/trainer cache
Browse files- rp_handler.py +98 -45
rp_handler.py
CHANGED
|
@@ -72,24 +72,36 @@ MODEL_PRESETS = {
|
|
| 72 |
"flux_schnell": "train_lora_flux_schnell_24gb.yaml",
|
| 73 |
}
|
| 74 |
|
| 75 |
-
#
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
"
|
| 81 |
-
"
|
| 82 |
-
"
|
| 83 |
-
"
|
| 84 |
-
"
|
|
|
|
|
|
|
|
|
|
| 85 |
}
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
|
| 89 |
-
"
|
| 90 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
}
|
| 92 |
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# =============================================================================
|
| 95 |
# Cleanup Functions
|
|
@@ -196,52 +208,95 @@ def get_environment_info():
|
|
| 196 |
}
|
| 197 |
|
| 198 |
|
| 199 |
-
def find_cached_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
"""
|
| 201 |
-
Find cached
|
| 202 |
|
| 203 |
Args:
|
| 204 |
-
|
| 205 |
|
| 206 |
Returns:
|
| 207 |
-
Path to cached
|
| 208 |
"""
|
| 209 |
if not IS_RUNPOD_CACHE:
|
| 210 |
-
return
|
| 211 |
|
| 212 |
-
|
| 213 |
-
cache_name = hf_repo.replace("/", "--")
|
| 214 |
snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
|
| 215 |
|
| 216 |
if snapshots_dir.exists():
|
| 217 |
snapshots = list(snapshots_dir.iterdir())
|
| 218 |
if snapshots:
|
| 219 |
-
cached_path =
|
| 220 |
-
|
| 221 |
-
|
|
|
|
| 222 |
|
| 223 |
-
|
| 224 |
-
return hf_repo
|
| 225 |
|
| 226 |
|
| 227 |
def check_model_cache_status(model_key: str) -> dict:
|
| 228 |
-
"""Check if model files are cached."""
|
| 229 |
-
if model_key not in
|
| 230 |
return {"cached": False, "reason": "unknown model"}
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
else:
|
| 242 |
-
status["
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
status["all_cached"] = all(s == "cached" for s in status["repos"].values())
|
| 245 |
return status
|
| 246 |
|
| 247 |
|
|
@@ -306,14 +361,12 @@ def run_training(params):
|
|
| 306 |
if "trigger_word" in params:
|
| 307 |
process["trigger_word"] = params["trigger_word"]
|
| 308 |
|
| 309 |
-
# Check if we should use cached model path
|
| 310 |
-
if
|
| 311 |
-
|
| 312 |
-
if
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
process["model"]["name_or_path"] = cached_path
|
| 316 |
-
logger.info(f"Using cached model path: {cached_path}")
|
| 317 |
|
| 318 |
# Save config
|
| 319 |
config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
|
|
|
|
| 72 |
"flux_schnell": "train_lora_flux_schnell_24gb.yaml",
|
| 73 |
}
|
| 74 |
|
| 75 |
+
# All trainable base models are mirrored into one HuggingFace repo so a single
# RunPod network-volume cache entry covers every preset.
CACHE_REPO = "Aloukik21/trainer"

# Subfolder inside CACHE_REPO that holds each model's weights.
# Several keys intentionally share a subfolder (same base weights,
# different training config).
MODEL_CACHE_PATHS = {
    "wan21_1b": "wan21-14b",  # Uses same base, different config
    "wan21_14b": "wan21-14b",
    "wan22_14b": "wan22-14b",
    "qwen_image": "qwen-image",
    "qwen_image_edit": "qwen-image",  # Same base model
    "qwen_image_edit_2509": "qwen-image",
    "flux_dev": "flux-dev",
    "flux_schnell": "flux-schnell",
}

# Upstream HuggingFace repos, used as the fallback when the
# consolidated cache repo is unavailable on this worker.
MODEL_HF_REPOS = {
    "wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16",
    "qwen_image": "Qwen/Qwen-Image",
    "qwen_image_edit": "Qwen/Qwen-Image-Edit",
    "qwen_image_edit_2509": "Qwen/Qwen-Image-Edit",
    "flux_dev": "black-forest-labs/FLUX.1-dev",
    "flux_schnell": "black-forest-labs/FLUX.1-schnell",
}

# Subfolder inside CACHE_REPO holding the Accuracy Recovery Adapters.
ARA_CACHE_PATH = "accuracy_recovery_adapters"
|
| 104 |
+
|
| 105 |
|
| 106 |
# =============================================================================
|
| 107 |
# Cleanup Functions
|
|
|
|
| 208 |
}
|
| 209 |
|
| 210 |
|
| 211 |
+
def find_cached_model(model_key: str) -> str:
    """
    Find cached model path on RunPod from the Aloukik21/trainer cache repo.

    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Filesystem path to the cached model subfolder, or the original
        HuggingFace repo id if the model is not cached (empty string for
        an unknown key).
    """
    if not IS_RUNPOD_CACHE:
        return MODEL_HF_REPOS.get(model_key, "")

    # HF cache layout: <cache>/models--{org}--{name}/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # Only directories are snapshots; pick the most recently modified one
        # instead of an arbitrary iterdir() entry so a stale revision does not
        # shadow a newer download when multiple snapshots coexist.
        snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
        if snapshots:
            snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            if subfolder:
                cached_path = snapshot / subfolder
                if cached_path.exists():
                    logger.info(f"Using cached model: {model_key} -> {cached_path}")
                    return str(cached_path)

    # Fall back to downloading from the original upstream repo.
    original_repo = MODEL_HF_REPOS.get(model_key, "")
    logger.info(f"Model not in cache, using original: {original_repo}")
    return original_repo
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def find_cached_ara(adapter_name: str) -> str:
    """
    Find a cached accuracy recovery adapter in the Aloukik21/trainer repo.

    Args:
        adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')

    Returns:
        Filesystem path to the cached adapter, or the original
        'ostris/accuracy_recovery_adapters/<adapter_name>' HF path if not cached.
    """
    # Single fallback value so the two return branches cannot drift apart.
    fallback = f"ostris/accuracy_recovery_adapters/{adapter_name}"
    if not IS_RUNPOD_CACHE:
        return fallback

    # HF cache layout: <cache>/models--{org}--{name}/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # Prefer the most recently modified snapshot directory; iterdir()
        # order is unspecified, so snapshots[0] could be a stale revision.
        snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
        if snapshots:
            snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
            cached_path = snapshot / ARA_CACHE_PATH / adapter_name
            if cached_path.exists():
                logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
                return str(cached_path)

    return fallback
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
def check_model_cache_status(model_key: str) -> dict:
    """
    Check whether a model's files are cached in the Aloukik21/trainer repo.

    Args:
        model_key: Model key from MODEL_CACHE_PATHS (e.g., 'flux_dev').

    Returns:
        Dict always containing 'cached'; for known models also 'model',
        'cache_repo', 'subfolder', plus 'path' when cached and a 'reason'
        when not.
    """
    if model_key not in MODEL_CACHE_PATHS:
        return {"cached": False, "reason": "unknown model"}

    status = {
        "model": model_key,
        "cache_repo": CACHE_REPO,
        "subfolder": MODEL_CACHE_PATHS.get(model_key),
    }

    # HF cache layout: <cache>/models--{org}--{name}/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # Same snapshot selection as find_cached_model: newest directory wins.
        snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
        if snapshots:
            snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
            model_path = snapshot / MODEL_CACHE_PATHS[model_key]
            is_cached = model_path.exists()  # evaluate once, reuse below
            status["cached"] = is_cached
            status["path"] = str(model_path) if is_cached else None
        else:
            status["cached"] = False
            status["reason"] = "no snapshots in cache repo"
    else:
        status["cached"] = False
        status["reason"] = "cache repo not found"

    return status
|
| 301 |
|
| 302 |
|
|
|
|
| 361 |
if "trigger_word" in params:
|
| 362 |
process["trigger_word"] = params["trigger_word"]
|
| 363 |
|
| 364 |
+
# Check if we should use cached model path from Aloukik21/trainer
|
| 365 |
+
if "model" in process:
|
| 366 |
+
cached_path = find_cached_model(model_key)
|
| 367 |
+
if cached_path:
|
| 368 |
+
process["model"]["name_or_path"] = cached_path
|
| 369 |
+
logger.info(f"Model path set to: {cached_path}")
|
|
|
|
|
|
|
| 370 |
|
| 371 |
# Save config
|
| 372 |
config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
|