|
|
|
|
|
""" |
|
|
Quick validation for the specific errors from the previous log |
|
|
""" |
|
|
|
|
|
def test_device_mesh_issue():
    """Check the exact prior error: No module named 'torch.distributed.device_mesh'.

    Returns:
        bool: True when ``ParallelismConfig`` imports cleanly or fails for an
        unrelated reason; False only when the original ``device_mesh``
        ImportError is still present.
    """
    print("🔍 Testing device_mesh issue...")
    try:
        # Local import on purpose: the import itself is the behavior under test.
        # The imported name is intentionally unused.
        from accelerate.parallelism_config import ParallelismConfig
        print("✅ accelerate.parallelism_config: OK (device_mesh not required)")
        return True
    except ImportError as e:
        if "device_mesh" in str(e):
            print(f"❌ device_mesh still required: {e}")
            return False
        # Any other ImportError (e.g. accelerate not installed at all) is not
        # the error this check hunts for, so it counts as a pass.
        print(f"⚠️ Other import issue: {e}")
        return True
|
|
|
|
|
def test_transformers_generation():
    """Check that ``transformers.generation`` exposes its public generation API.

    Returns:
        bool: True if ``GenerationConfig`` and ``GenerationMixin`` import
        successfully, False on any ImportError.
    """
    print("🔍 Testing transformers generation utils...")
    try:
        # Local import on purpose: the import itself is the behavior under test.
        from transformers.generation import GenerationConfig, GenerationMixin
        print("✅ transformers.generation: OK")
        return True
    except ImportError as e:
        print(f"❌ transformers.generation failed: {e}")
        return False
|
|
|
|
|
def test_mistral_model_import():
    """Check the specific Mistral model import that previously failed.

    Returns:
        bool: True when ``MistralForCausalLM`` imports, or fails for a reason
        other than the ``device_mesh`` error; False only when the original
        ``device_mesh`` ImportError persists.
    """
    print("🔍 Testing mistral model import...")
    try:
        # Local import on purpose: the import itself is the behavior under test.
        # The imported name is intentionally unused.
        from transformers.models.mistral.modeling_mistral import MistralForCausalLM
        print("✅ MistralForCausalLM: OK")
        return True
    except ImportError as e:
        if "device_mesh" in str(e):
            print(f"❌ Mistral still needs device_mesh: {e}")
            return False
        # Unrelated import failures (e.g. transformers missing) are not the
        # regression this check targets.
        print(f"⚠️ Mistral other issue: {e}")
        return True
|
|
|
|
|
def test_tokenizer_enum_issue():
    """Check the previously observed tokenizer enum/variant incompatibility.

    NOTE(review): ``from_pretrained`` needs either network access or a local
    model cache; without either it raises, which is treated as an unrelated
    failure below.

    Returns:
        bool: False only when the error message mentions "enum" or "variant"
        (the original failure); True for a clean load or any other error.
    """
    print("🔍 Testing tokenizer enum compatibility...")
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        print("✅ DialoGPT tokenizer: No enum issues")
        return True
    except Exception as e:
        # Broad catch is deliberate: the enum failure surfaced as a non-ImportError.
        if "enum" in str(e).lower() or "variant" in str(e).lower():
            print(f"❌ Tokenizer enum issue persists: {e}")
            return False
        print(f"⚠️ Tokenizer other issue: {e}")
        return True
|
|
|
|
|
def main():
    """Run every validation check, print a summary, and report the pass rate.

    A check that crashes (raises out of its own try/except) is recorded as a
    failure rather than aborting the remaining checks.
    """
    print("🚨 Validation: Previous Error Conditions")
    print("=" * 50)

    # (label, callable) pairs; each callable returns True on pass.
    tests = [
        ("Device Mesh Issue", test_device_mesh_issue),
        ("Transformers Generation", test_transformers_generation),
        ("Mistral Model Import", test_mistral_model_import),
        ("Tokenizer Enum Issue", test_tokenizer_enum_issue),
    ]

    results = []
    for name, test_func in tests:
        print(f"\n🧪 {name}:")
        try:
            results.append(test_func())
        except Exception as e:
            # A crash in a check is itself a failure; keep going so the
            # remaining checks still run.
            print(f"❌ Test crashed: {e}")
            results.append(False)

    print("\n" + "=" * 50)
    passed = sum(results)
    total = len(results)

    if passed == total:
        print("✅ ALL TESTS PASSED - Previous errors should be resolved!")
    else:
        print(f"⚠️ {passed}/{total} tests passed - Some issues may persist")

    print(f"Success rate: {passed}/{total} ({100*passed/total:.1f}%)")
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|