"""Model loading utilities."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from transformers import AutoModelForCausalLM, AutoTokenizer

from src.config import settings
@dataclass
class ModelBundle:
    """Holds tokenizer and model together.

    When loading fails (or mock mode is forced), ``is_mock`` is True,
    ``tokenizer``/``model`` are None, and ``load_error`` records why.
    """

    # Tokenizer instance, or None in mock mode.
    tokenizer: Any
    # Model instance, or None in mock mode.
    model: Any
    # Name of the model actually loaded, or "mock-rule-based".
    active_model_name: str
    # True when no real model is available and rule-based fallback is used.
    is_mock: bool = False
    # Human-readable reason the real model could not be loaded, if any.
    load_error: str = ""
def load_model_bundle() -> ModelBundle:
    """Load Qwen2.5-Coder first, then fallback if needed."""
    if settings.force_mock_mode:
        # Mock mode was requested explicitly via configuration; skip loading.
        return ModelBundle(
            tokenizer=None,
            model=None,
            active_model_name="mock-rule-based",
            is_mock=True,
            load_error="FORCE_MOCK_MODE=true",
        )

    last_error = None
    # Try each configured model in order of preference.
    for name in (
        settings.model_name,
        settings.fallback_model_name,
        settings.final_fallback_model_name,
    ):
        try:
            tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
            net = AutoModelForCausalLM.from_pretrained(
                name,
                trust_remote_code=True,
                device_map="auto",
            )
            return ModelBundle(
                tokenizer=tok,
                model=net,
                active_model_name=name,
            )
        except Exception as exc:  # pragma: no cover
            # Remember the failure and move on to the next candidate.
            last_error = exc

    # Every candidate failed: degrade to the rule-based mock backend.
    return ModelBundle(
        tokenizer=None,
        model=None,
        active_model_name="mock-rule-based",
        is_mock=True,
        load_error=str(last_error),
    )