md896 committed on
Commit
ac3911c
·
1 Parent(s): bc20ef9

changes in ultimate sota

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +22 -13
ultimate_sota_training.py CHANGED
@@ -98,21 +98,30 @@ import torch
98
  from datasets import Dataset
99
 
100
  # --- CRITICAL FIXES FOR HF JOBS ---
101
- # 1. Mock vllm: TRL's GRPOTrainer (v0.18+) has a buggy import path that hard-fails if vllm is missing,
102
- # even if you don't intend to use it. We mock the entire vllm hierarchy.
103
  import sys
 
 
104
  from unittest.mock import MagicMock
105
- for m in [
106
- "vllm",
107
- "vllm.distributed",
108
- "vllm.distributed.device_communicators",
109
- "vllm.distributed.device_communicators.pynccl",
110
- "vllm.model_executor",
111
- "vllm.model_executor.parallel_utils",
112
- ]:
113
- sys.modules[m] = MagicMock()
114
-
115
- # 2. Mock llm_blender: It unconditionally tries to import TRANSFORMERS_CACHE which was removed in transformers 4.40+.
 
 
 
 
 
 
 
116
  import transformers.utils.hub
117
  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
118
  transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
 
98
  from datasets import Dataset
99
 
100
  # --- CRITICAL FIXES FOR HF JOBS ---
101
+ # 1. Mock vllm: TRL's GRPOTrainer (v0.18+) has a buggy import path that hard-fails if vllm is missing.
102
+ # We must provide a mock that satisfies both 'import' and 'importlib.util.find_spec'.
103
  import sys
104
+ import types
105
+ import importlib.machinery
106
  from unittest.mock import MagicMock
107
+
108
+ def mock_vllm_hierarchy():
109
+ for m_name in [
110
+ "vllm",
111
+ "vllm.distributed",
112
+ "vllm.distributed.device_communicators",
113
+ "vllm.distributed.device_communicators.pynccl",
114
+ "vllm.model_executor",
115
+ "vllm.model_executor.parallel_utils",
116
+ ]:
117
+ mock_m = MagicMock(spec=types.ModuleType)
118
+ mock_m.__name__ = m_name
119
+ mock_m.__spec__ = importlib.machinery.ModuleSpec(m_name, None)
120
+ sys.modules[m_name] = mock_m
121
+
122
+ mock_vllm_hierarchy()
123
+
124
+ # 2. Mock llm_blender: Fix for TRANSFORMERS_CACHE removal in transformers 4.40+.
125
  import transformers.utils.hub
126
  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
127
  transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"