Aloukik21 committed on
Commit
b31917b
·
verified ·
1 Parent(s): 1f14b6a

Update handler to use Aloukik21/trainer cache

Browse files
Files changed (1) hide show
  1. rp_handler.py +98 -45
rp_handler.py CHANGED
@@ -72,24 +72,36 @@ MODEL_PRESETS = {
72
  "flux_schnell": "train_lora_flux_schnell_24gb.yaml",
73
  }
74
 
75
- # HuggingFace repos used by each model (for pre-warming)
76
- MODEL_HF_REPOS = {
77
- "wan21_1b": ["Wan-AI/Wan2.1-T2V-1.3B-Diffusers"],
78
- "wan21_14b": ["Wan-AI/Wan2.1-T2V-14B-Diffusers"],
79
- "wan22_14b": ["ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16"],
80
- "qwen_image": ["Qwen/Qwen-Image"],
81
- "qwen_image_edit": ["Qwen/Qwen-Image-Edit"],
82
- "qwen_image_edit_2509": ["Qwen/Qwen-Image-Edit"],
83
- "flux_dev": ["black-forest-labs/FLUX.1-dev"],
84
- "flux_schnell": ["black-forest-labs/FLUX.1-schnell"],
 
 
 
85
  }
86
 
87
- # Accuracy Recovery Adapters (smaller files, can be pre-downloaded)
88
- ARA_FILES = {
89
- "wan22_14b": "ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors",
90
- "qwen_image": "ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors",
 
 
 
 
 
 
91
  }
92
 
 
 
 
93
 
94
  # =============================================================================
95
  # Cleanup Functions
@@ -196,52 +208,95 @@ def get_environment_info():
196
  }
197
 
198
 
199
- def find_cached_model(hf_repo: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  """
201
- Find cached model path on RunPod.
202
 
203
  Args:
204
- hf_repo: HuggingFace repo ID (e.g., 'black-forest-labs/FLUX.1-dev')
205
 
206
  Returns:
207
- Path to cached model, or original repo ID if not cached
208
  """
209
  if not IS_RUNPOD_CACHE:
210
- return hf_repo
211
 
212
- # Convert "Org/Repo" -> "models--Org--Repo"
213
- cache_name = hf_repo.replace("/", "--")
214
  snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
215
 
216
  if snapshots_dir.exists():
217
  snapshots = list(snapshots_dir.iterdir())
218
  if snapshots:
219
- cached_path = str(snapshots[0])
220
- logger.info(f"Using cached model: {hf_repo} -> {cached_path}")
221
- return cached_path
 
222
 
223
- logger.info(f"Model not cached, will download: {hf_repo}")
224
- return hf_repo
225
 
226
 
227
  def check_model_cache_status(model_key: str) -> dict:
228
- """Check if model files are cached."""
229
- if model_key not in MODEL_HF_REPOS:
230
  return {"cached": False, "reason": "unknown model"}
231
 
232
- repos = MODEL_HF_REPOS[model_key]
233
- status = {"repos": {}}
 
 
 
234
 
235
- for repo in repos:
236
- cache_name = repo.replace("/", "--")
237
- snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
238
 
239
- if snapshots_dir.exists() and list(snapshots_dir.iterdir()):
240
- status["repos"][repo] = "cached"
 
 
 
 
 
241
  else:
242
- status["repos"][repo] = "not cached"
 
 
 
243
 
244
- status["all_cached"] = all(s == "cached" for s in status["repos"].values())
245
  return status
246
 
247
 
@@ -306,14 +361,12 @@ def run_training(params):
306
  if "trigger_word" in params:
307
  process["trigger_word"] = params["trigger_word"]
308
 
309
- # Check if we should use cached model path
310
- if IS_RUNPOD_CACHE and "model" in process:
311
- original_path = process["model"].get("name_or_path", "")
312
- if original_path:
313
- cached_path = find_cached_model(original_path)
314
- if cached_path != original_path:
315
- process["model"]["name_or_path"] = cached_path
316
- logger.info(f"Using cached model path: {cached_path}")
317
 
318
  # Save config
319
  config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
 
72
  "flux_schnell": "train_lora_flux_schnell_24gb.yaml",
73
  }
74
 
75
# All models cached in single HuggingFace repo for RunPod caching
CACHE_REPO = "Aloukik21/trainer"

# Map model keys to subfolder in cache repo.
# Values are directory names under a snapshot of CACHE_REPO.
MODEL_CACHE_PATHS = {
    # NOTE(review): wan21_1b points at the 14B cache folder, yet MODEL_HF_REPOS
    # below lists distinct 1.3B vs 14B repos -- confirm the shared-base claim.
    "wan21_1b": "wan21-14b",  # Uses same base, different config
    "wan21_14b": "wan21-14b",
    "wan22_14b": "wan22-14b",
    "qwen_image": "qwen-image",
    "qwen_image_edit": "qwen-image",  # Same base model
    "qwen_image_edit_2509": "qwen-image",
    "flux_dev": "flux-dev",
    "flux_schnell": "flux-schnell",
}

# Original HuggingFace repos (fallback if cache not available)
MODEL_HF_REPOS = {
    "wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16",
    "qwen_image": "Qwen/Qwen-Image",
    "qwen_image_edit": "Qwen/Qwen-Image-Edit",
    "qwen_image_edit_2509": "Qwen/Qwen-Image-Edit",
    "flux_dev": "black-forest-labs/FLUX.1-dev",
    "flux_schnell": "black-forest-labs/FLUX.1-schnell",
}

# Accuracy Recovery Adapters path in cache repo
# (subfolder name inside CACHE_REPO snapshots, joined onto the snapshot path)
ARA_CACHE_PATH = "accuracy_recovery_adapters"
104
+
105
 
106
  # =============================================================================
107
  # Cleanup Functions
 
208
  }
209
 
210
 
211
def find_cached_model(model_key: str) -> str:
    """
    Find cached model path on RunPod from Aloukik21/trainer repo.

    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Path to the model subfolder inside the newest cached snapshot of
        CACHE_REPO, or the original HF repo id from MODEL_HF_REPOS
        (empty string for an unknown key) when the cache is unavailable.
    """
    if not IS_RUNPOD_CACHE:
        return MODEL_HF_REPOS.get(model_key, "")

    # HF hub cache layout: models--Org--Repo/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # iterdir() order is filesystem-dependent; when several snapshot
        # revisions exist, use the most recently modified one rather than
        # an arbitrary snapshots[0]. Non-directories are skipped.
        snapshots = sorted(
            (p for p in snapshots_dir.iterdir() if p.is_dir()),
            key=lambda p: p.stat().st_mtime,
        )
        if snapshots:
            # Get the subfolder for this model
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            if subfolder:
                cached_path = snapshots[-1] / subfolder
                if cached_path.exists():
                    logger.info(f"Using cached model: {model_key} -> {cached_path}")
                    return str(cached_path)

    # Fallback to original repo
    original_repo = MODEL_HF_REPOS.get(model_key, "")
    logger.info(f"Model not in cache, using original: {original_repo}")
    return original_repo
243
+
244
+
245
def find_cached_ara(adapter_name: str) -> str:
    """
    Find cached accuracy recovery adapter.

    Args:
        adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')

    Returns:
        Path to the adapter inside the newest cached snapshot of CACHE_REPO,
        or the original 'ostris/accuracy_recovery_adapters/<name>' HF path
        when the cache is unavailable.
    """
    if not IS_RUNPOD_CACHE:
        return f"ostris/accuracy_recovery_adapters/{adapter_name}"

    # HF hub cache layout: models--Org--Repo/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # Pick the newest snapshot revision; iterdir() order is arbitrary.
        snapshots = sorted(
            (p for p in snapshots_dir.iterdir() if p.is_dir()),
            key=lambda p: p.stat().st_mtime,
        )
        if snapshots:
            cached_path = snapshots[-1] / ARA_CACHE_PATH / adapter_name
            if cached_path.exists():
                logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
                return str(cached_path)

    return f"ostris/accuracy_recovery_adapters/{adapter_name}"
 
270
 
271
 
272
def check_model_cache_status(model_key: str) -> dict:
    """Check if model files are cached in Aloukik21/trainer.

    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Dict with at least ``cached`` (bool). When the key is known it also
        carries ``model``, ``cache_repo``, ``subfolder``; when a snapshot was
        inspected it carries ``path`` (str or None); ``reason`` explains a
        negative result where one is known.
    """
    if model_key not in MODEL_CACHE_PATHS:
        return {"cached": False, "reason": "unknown model"}

    subfolder = MODEL_CACHE_PATHS[model_key]
    status = {
        "model": model_key,
        "cache_repo": CACHE_REPO,
        "subfolder": subfolder,
        "cached": False,  # default; flipped below if the snapshot has the model
    }

    # HF hub cache layout: models--Org--Repo/snapshots/<revision>/
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if not snapshots_dir.exists():
        status["reason"] = "cache repo not found"
        return status

    # Prefer the newest snapshot revision; iterdir() order is arbitrary.
    snapshots = sorted(
        (p for p in snapshots_dir.iterdir() if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
    )
    if not snapshots:
        status["reason"] = "no snapshots in cache repo"
        return status

    model_path = snapshots[-1] / subfolder
    exists = model_path.exists()  # stat once instead of twice
    status["cached"] = exists
    status["path"] = str(model_path) if exists else None
    return status
301
 
302
 
 
361
  if "trigger_word" in params:
362
  process["trigger_word"] = params["trigger_word"]
363
 
364
+ # Check if we should use cached model path from Aloukik21/trainer
365
+ if "model" in process:
366
+ cached_path = find_cached_model(model_key)
367
+ if cached_path:
368
+ process["model"]["name_or_path"] = cached_path
369
+ logger.info(f"Model path set to: {cached_path}")
 
 
370
 
371
  # Save config
372
  config_dir = os.path.join(AI_TOOLKIT_DIR, "config")