# demo/backend/tasks/get_available_model_provider.py
# Last change by tfrere in commit 0e34dc4: "add get available model provider to benchmark generation" (7.28 kB)
import os
import logging
import json
from huggingface_hub import model_info, InferenceClient
from dotenv import load_dotenv
# Providers favoured when a model is served by several inference providers.
PREFERRED_PROVIDERS = ["sambanova", "novita"]


def filter_providers(providers, *, preferred=None):
    """
    Filter providers to only include preferred ones.

    Args:
        providers: Iterable of provider names
        preferred: Provider names to keep; defaults to PREFERRED_PROVIDERS
    Returns:
        List of the preferred providers, in their original order
    """
    wanted = tuple(PREFERRED_PROVIDERS if preferred is None else preferred)
    return [provider for provider in providers if provider in wanted]


def prioritize_providers(providers, *, preferred=None):
    """
    Move preferred providers to the front, keeping all others.

    ``providers`` is iterated exactly once, so one-shot iterables such as
    generators are fully supported.

    Args:
        providers: Iterable of provider names
        preferred: Provider names to put first; defaults to PREFERRED_PROVIDERS
    Returns:
        New list: preferred providers (input order) followed by the rest
        (input order)
    """
    wanted = tuple(PREFERRED_PROVIDERS if preferred is None else preferred)
    front, back = [], []
    for provider in providers:
        # Route each provider into exactly one bucket in a single pass.
        (front if provider in wanted else back).append(provider)
    return front + back
# Module-level logger, plus a root configuration with a timestamped
# "time - LEVEL - message" layout (basicConfig is a no-op if the root
# logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def is_vision_model(model_name: str) -> bool:
    """
    Check if the model is a vision model based on its name.

    This is a heuristic: a case-insensitive substring search for a few common
    vision markers, not a lookup of the model's real modality.

    Args:
        model_name: Name of the model (e.g. "Qwen/Qwen2.5-VL-7B-Instruct")
    Returns:
        True if the name contains a vision marker, False otherwise
    """
    # Markers must be lower-case because the name is lower-cased before
    # matching; an upper-case marker such as "-VL-" could never match.
    vision_indicators = ("-vl-", "vision", "clip", "image")
    lowered = model_name.lower()
    return any(indicator in lowered for indicator in vision_indicators)
def get_test_payload(model_name: str) -> dict:
    """
    Build the request payload used to probe a model.

    Args:
        model_name: Name of the model (currently ignored: every model is
            probed as a text model)
    Returns:
        A freshly built dict with a short "Hello" prompt capped at 5 new tokens
    """
    generation_params = dict(max_new_tokens=5)
    return dict(inputs="Hello", parameters=generation_params)
def test_provider(model_name: str, provider: str, verbose: bool = False, timeout: float = 10) -> bool:
    """
    Check whether a specific inference provider can serve a model.

    Sends a minimal chat-completion request ("Hello", at most 5 tokens) through
    ``InferenceClient`` pinned to ``provider``; any successful call counts as
    "available".

    Args:
        model_name: Name of the model on the Hub
        provider: Provider to test
        verbose: Whether to log detailed information (errors are only logged
            when True)
        timeout: Request timeout in seconds passed to ``InferenceClient``

    Returns:
        True if the provider answered the test request, False otherwise.
        Every exception, including a missing HF_TOKEN, is swallowed and
        reported as False.
    """
    try:
        # Load environment variables (HF_TOKEN may come from a .env file)
        load_dotenv()
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        if verbose:
            logger.info(f"Testing provider {provider} for model {model_name}")

        # Pin the client to a single provider so the probe tests exactly that route
        client = InferenceClient(
            model=model_name,
            token=hf_token,
            provider=provider,
            timeout=timeout,  # generous default to allow model loading
        )

        try:
            # The response content is irrelevant; only success or failure matters
            client.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5,
            )
        except Exception as e:
            if verbose:
                error_message = str(e)
                logger.error(f"Error with provider {provider}: {error_message}")
                # Prefer the HTTP status attached to the exception (HTTP errors
                # expose ``response``); fall back to scanning the message text.
                status_code = getattr(getattr(e, "response", None), "status_code", None)
                if status_code == 429 or "status_code=429" in error_message:
                    logger.warning(f"Provider {provider} rate limited. You may need to wait or upgrade your plan.")
                elif status_code == 401 or "status_code=401" in error_message:
                    logger.warning(f"Authentication failed for provider {provider}. Check your token.")
                elif status_code == 503 or "status_code=503" in error_message:
                    logger.warning(f"Provider {provider} service unavailable. Model may be loading or provider is down.")
                elif "timed out" in error_message.lower():
                    logger.error(f"Timeout error with provider {provider} - request timed out after {timeout} seconds")
            return False

        if verbose:
            logger.info(f"Provider {provider} is available for {model_name}")
        return True
    except Exception as e:
        # Configuration/client-construction failures are also reported as "unavailable"
        if verbose:
            logger.error(f"Error in test_provider: {str(e)}")
        return False
def _provider_names(mapping):
    """
    Extract provider names from a ``ModelInfo.inference_provider_mapping`` value.

    Handles the dict form keyed by provider name. NOTE(review): some
    huggingface_hub releases are believed to return a list of mapping objects
    with a ``provider`` attribute instead; that form is handled defensively.
    Confirm against the installed version.
    """
    if isinstance(mapping, dict):
        return list(mapping)
    return [getattr(entry, "provider", entry) for entry in mapping]


def get_available_model_provider(model_name, verbose=False):
    """
    Get the first available provider for a given model.

    The providers listed on the Hub for ``model_name`` are ordered with
    ``prioritize_providers`` and probed one at a time with ``test_provider``.
    The first one that answers is returned.

    Args:
        model_name: Name of the model on the Hub
        verbose: Whether to log detailed information
    Returns:
        First available provider name, or None if none is available or any
        error occurs (errors are only logged when verbose is True)
    """
    try:
        # Load environment variables (HF_TOKEN may come from a .env file)
        load_dotenv()
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        # Authenticate the metadata request too, so private repositories resolve
        info = model_info(model_name, expand="inferenceProviderMapping", token=hf_token)

        # The attribute can exist but be None; checking only `hasattr` would let
        # `.keys()` fail on None and be misreported as a generic error.
        mapping = getattr(info, "inference_provider_mapping", None)
        if mapping is None:
            if verbose:
                logger.info(f"No inference providers found for {model_name}")
            return None

        providers = _provider_names(mapping)
        if not providers:
            if verbose:
                logger.info(f"Empty list of providers for {model_name}")
            return None

        # Preferred providers are probed first
        providers = prioritize_providers(providers)
        if verbose:
            logger.info(f"Available providers for {model_name}: {', '.join(providers)}")

        for provider in providers:
            if test_provider(model_name, provider, verbose):
                return provider
        return None
    except Exception as e:
        if verbose:
            logger.error(f"Error in get_available_model_provider: {str(e)}")
        return None
if __name__ == "__main__":
    # Probe a fixed set of instruct models and report the provider selected for
    # each one (None when no provider answered), in the same order as `models`.
    models = [
        "Qwen/QwQ-32B",
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "mistralai/Mistral-Small-24B-Instruct-2501",
    ]
    providers = [get_available_model_provider(model, verbose=True) for model in models]
    print(f"Providers {len(providers)}: {providers}")