Spaces:
Runtime error
Runtime error
"""Model manager for handling multiple LLMs in Hugging Face Spaces."""
import os | |
from typing import Dict, Any, Optional, List | |
import logging | |
from dataclasses import dataclass | |
from enum import Enum | |
import huggingface_hub | |
from llama_cpp import Llama | |
class ModelType(Enum):
    """Task category a configured model is intended to serve.

    Each member's value is the lowercase spelling of its name, so
    ``ModelType("code")`` resolves to ``ModelType.CODE``.
    """

    REASONING = "reasoning"  # multi-step reasoning
    CODE = "code"            # code generation / understanding
    CHAT = "chat"            # conversational replies
    PLANNING = "planning"    # planning tasks
    ANALYSIS = "analysis"    # analysis tasks
# @dataclass generates the keyword-argument __init__ that the
# ``ModelConfig(repo_id=..., filename=..., ...)`` calls rely on; without it
# those calls raise ``TypeError: ModelConfig() takes no arguments``.
@dataclass
class ModelConfig:
    """Download location and llama.cpp load settings for one model."""

    # Hugging Face Hub repository id the weights are downloaded from.
    repo_id: str
    # Name of the weights file inside ``repo_id``.
    filename: str
    # Task category this model is registered for.
    model_type: ModelType
    # Passed to ``Llama`` as ``n_ctx`` (context window in tokens).
    context_size: int = 4096
    # Passed to ``Llama`` as ``n_gpu_layers`` (layers offloaded to the GPU).
    gpu_layers: int = 35
    # Passed to ``Llama`` as ``n_batch``.
    batch_size: int = 512
    # Passed to ``Llama`` as ``n_threads``.
    threads: int = 8
class ModelManager:
    """Load and cache the llama.cpp models this Space uses, one per task key.

    ``model_configs`` maps a short key (``"reasoning"``, ``"code"``, ...) to a
    ``ModelConfig``. ``initialize_model`` downloads the configured file from
    the Hugging Face Hub into ``<model_dir>/<key>`` and loads it with
    ``llama_cpp.Llama``; loaded instances are kept in ``self.models``.
    """

    def __init__(self):
        # Root for downloaded weights: $SPACE_CACHE_DIR if set, else /tmp/models.
        self.model_dir = os.getenv('SPACE_CACHE_DIR', '/tmp/models')
        # Loaded models, keyed like ``model_configs``.
        self.models: Dict[str, Llama] = {}
        self.logger = logging.getLogger(__name__)

        # Per-key model configurations.
        # NOTE(review): verify these against the Hub before deploying; a bad
        # repo_id/filename only surfaces at runtime as a logged download error.
        #   * "Nidum-Llama-3.2-3B-Uncensored-GGUF" has no "<owner>/" prefix.
        #   * "deepseek-ai/JanusFlow-1.3B" and "gpt-omni/mini-omni2" may not
        #     publish GGUF files at all.
        #   * Generic names such as "model.gguf" must match a real file in
        #     each repo.
        self.model_configs = {
            "reasoning": ModelConfig(
                repo_id="rrbale/pruned-qwen-moe",
                filename="model-Q6_K.gguf",
                model_type=ModelType.REASONING,
            ),
            "code": ModelConfig(
                repo_id="YorkieOH10/deepseek-coder-6.7B-kexer-Q8_0-GGUF",
                filename="model.gguf",
                model_type=ModelType.CODE,
            ),
            "chat": ModelConfig(
                repo_id="Nidum-Llama-3.2-3B-Uncensored-GGUF",
                filename="model-Q6_K.gguf",
                model_type=ModelType.CHAT,
            ),
            "planning": ModelConfig(
                repo_id="deepseek-ai/JanusFlow-1.3B",
                filename="model.gguf",
                model_type=ModelType.PLANNING,
            ),
            "analysis": ModelConfig(
                repo_id="prithivMLmods/QwQ-4B-Instruct",
                filename="model.gguf",
                model_type=ModelType.ANALYSIS,
                context_size=8192,
                gpu_layers=40,
            ),
            "general": ModelConfig(
                repo_id="gpt-omni/mini-omni2",
                filename="mini-omni2.gguf",
                model_type=ModelType.CHAT,
            ),
        }
        os.makedirs(self.model_dir, exist_ok=True)

    async def initialize_model(self, model_key: str) -> Optional[Llama]:
        """Download (if needed) and load the model registered under ``model_key``.

        The weights are fetched with ``huggingface_hub.hf_hub_download`` into
        ``<model_dir>/<model_key>``. The loaded model is stored in
        ``self.models[model_key]``, replacing any previous instance.

        Args:
            model_key: A key of ``self.model_configs``.

        Returns:
            The loaded ``Llama`` instance, or ``None`` if the key is unknown or
            downloading/loading failed. Failures are logged, not raised.

        Note:
            Although declared ``async``, the download and the model load are
            synchronous calls and block the event loop while they run.
        """
        if model_key not in self.model_configs:
            self.logger.error(
                "Unknown model key %r; expected one of %s",
                model_key, sorted(self.model_configs),
            )
            return None
        config = self.model_configs[model_key]

        try:
            cache_dir = os.path.join(self.model_dir, model_key)
            os.makedirs(cache_dir, exist_ok=True)

            self.logger.info("Downloading %s model...", model_key)
            # ``local_dir_use_symlinks`` is intentionally not passed: it only
            # applies together with ``local_dir`` (not ``cache_dir``) and is
            # deprecated in recent huggingface_hub releases.
            model_path = huggingface_hub.hf_hub_download(
                repo_id=config.repo_id,
                filename=config.filename,
                repo_type="model",
                cache_dir=cache_dir,
            )
            model = self._load_llama(model_key, model_path, config)
        except Exception as e:
            # Best-effort loading: keep the traceback in the log, signal via None.
            self.logger.exception("Error initializing %s model: %s", model_key, e)
            return None

        self.models[model_key] = model
        return model

    def _load_llama(self, model_key: str, model_path: str, config: ModelConfig) -> Llama:
        """Load ``model_path`` with the config's GPU settings, else CPU-only.

        If the GPU-configured load raises, retry with a small CPU-only setup:
        a 2048-token context, batch size 256, 4 threads and no offloaded layers.
        An exception from the CPU retry propagates to the caller.
        """
        try:
            model = Llama(
                model_path=model_path,
                n_ctx=config.context_size,
                n_batch=config.batch_size,
                n_threads=config.threads,
                n_gpu_layers=config.gpu_layers,
                main_gpu=0,
                tensor_split=None,  # no explicit multi-GPU split; library default
            )
            # NOTE(review): if llama-cpp-python was built without GPU support,
            # the call above may succeed while running on CPU. In that case
            # this message is misleading; confirm the build used in the Space.
            self.logger.info("%s model loaded with GPU acceleration!", model_key)
        except Exception as e:
            self.logger.warning(
                "GPU loading failed for %s: %s, falling back to CPU...", model_key, e
            )
            model = Llama(
                model_path=model_path,
                n_ctx=2048,
                n_batch=256,
                n_threads=4,
                n_gpu_layers=0,
            )
            self.logger.info("%s model loaded in CPU-only mode", model_key)
        return model

    async def get_model(self, model_key: str) -> Optional[Llama]:
        """Return the cached model for ``model_key``, loading it on first use.

        Returns ``None`` when loading fails. A failed load is not cached, so
        the next call retries it.
        """
        if model_key not in self.models:
            return await self.initialize_model(model_key)
        return self.models[model_key]

    async def initialize_all_models(self):
        """Load every configured model in turn.

        Models that are already loaded are reloaded. A failure for one key is
        logged by ``initialize_model`` and does not stop the remaining keys.
        """
        for model_key in self.model_configs:
            await self.initialize_model(model_key)

    def get_best_model_for_task(self, task_type: str) -> str:
        """Map a task type to the model key that should handle it.

        Unrecognised task types fall back to ``"general"``.
        """
        task_model_mapping = {
            "reasoning": "reasoning",
            "code": "code",
            "chat": "chat",
            "planning": "planning",
            "analysis": "analysis",
            "general": "general",
        }
        return task_model_mapping.get(task_type, "general")