# Page header captured together with this file: "Spaces: Paused"
import os
import subprocess
import sys

import spaces
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)
# Opt in to the hf_transfer download backend for Hugging Face Hub downloads.
# NOTE(review): this runs after `transformers` has already been imported; if
# the hub client reads the variable at import time the flag may be picked up
# too late -- confirm, and move above the imports if so.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Install required package
def install_flash_attn():
    """Install ``flash-attn`` with pip, skipping its CUDA build step.

    The child process inherits the current environment with
    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` added on top.  Passing only that
    one variable as ``env`` would replace the whole environment and drop
    PATH, HOME, proxy and CUDA settings that pip may rely on.

    pip is invoked as ``<current interpreter> -m pip`` with an argument list
    (no shell), so the package lands in the interpreter running this script.

    Returns:
        None. Installation is best-effort: a non-zero pip exit status is not
        raised, matching the previous behaviour.
    """
    child_env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash-attn",
            "--no-build-isolation",
        ],
        env=child_env,
        check=False,  # deliberate: do not abort start-up if the install fails
    )
# Architecture name -> model class used to load it.
# get_model_summary looks up the first entry of a checkpoint's
# ``config.architectures`` here; names that are not listed fall back to
# AutoModelForCausalLM at the lookup site.
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}
# Helper: turn a checkpoint's stored quantization settings into a config object
def _build_bnb_config(quantization_config):
    """Build a ``BitsAndBytesConfig`` from a checkpoint's quantization settings.

    Args:
        quantization_config: The ``quantization_config`` attribute read from
            the model config, or a falsy value when the model is not
            quantized. A mapping is expected; an object exposing ``to_dict()``
            is converted first.

    Returns:
        BitsAndBytesConfig | None: ``None`` when no settings are present,
        otherwise a config whose fields are copied from the settings, with
        the defaults below used for any missing key.
    """
    if not quantization_config:
        return None
    if hasattr(quantization_config, "to_dict"):
        # Plain dicts have no to_dict(); this only fires for config objects.
        quantization_config = quantization_config.to_dict()

    read = quantization_config.get
    return BitsAndBytesConfig(
        load_in_4bit=read('load_in_4bit', False),
        load_in_8bit=read('load_in_8bit', False),
        bnb_4bit_compute_dtype=read('bnb_4bit_compute_dtype', torch.float16),
        bnb_4bit_quant_type=read('bnb_4bit_quant_type', 'nf4'),
        bnb_4bit_use_double_quant=read('bnb_4bit_use_double_quant', False),
        llm_int8_enable_fp32_cpu_offload=read('llm_int8_enable_fp32_cpu_offload', False),
        llm_int8_has_fp16_weight=read('llm_int8_has_fp16_weight', False),
        llm_int8_skip_modules=read('llm_int8_skip_modules', None),
        llm_int8_threshold=read('llm_int8_threshold', 6.0),
    )


# Load a model by name and return its printed module structure
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.

    The model config is fetched first to pick the loader class (via
    ARCHITECTURE_MAP, defaulting to AutoModelForCausalLM) and to detect
    quantization settings. The model is then loaded and rendered with
    ``str(model)``.

    Args:
        model_name (str): The name of the model to retrieve the summary for.

    Returns:
        tuple: ``(summary, error)``. On success ``summary`` is the model's
        string representation and ``error`` is ``""``; on failure ``summary``
        is ``""`` and ``error`` describes the exception. No exception derived
        from ``Exception`` propagates to the caller.
    """
    try:
        # Fetch the model configuration
        config = AutoConfig.from_pretrained(model_name)

        # ``architectures`` can be missing or empty; in that case the lookup
        # below falls through to the AutoModelForCausalLM default instead of
        # failing on ``[0]``.
        architectures = getattr(config, 'architectures', None)
        architecture = architectures[0] if architectures else None

        quantization_config = getattr(config, 'quantization_config', None)
        bnb_config = _build_bnb_config(quantization_config)

        # Get the appropriate model class from the architecture map
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)

        # The quantization settings must go in ``quantization_config``;
        # ``config`` is reserved for the model configuration itself.
        # NOTE(security): trust_remote_code=True lets the repository named by
        # ``model_name`` run its own Python code during loading.
        model = model_class.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            trust_remote_code=True,
        )
        if model is None:
            return "Model architecture not found.", ""

        # Move to device only if the model is not quantized
        if not quantization_config:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

        return str(model), ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        # Top-level boundary: callers expect an error string, never a raise.
        return "", str(e)