model_explorer4 / utils.py
dwb2023's picture
Update utils.py
461b9e0 verified
raw
history blame
No virus
3.4 kB
import subprocess
import os
import torch
from transformers import BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM, LlavaNextForConditionalGeneration, LlavaForConditionalGeneration, PaliGemmaForConditionalGeneration, Idefics2ForConditionalGeneration
import spaces
# Ask the Hugging Face Hub client to use hf_transfer for downloads.
# NOTE(review): huggingface_hub appears to read this variable once, when it is
# first imported; transformers (imported above) already imports it, so this
# assignment may have no effect. Setting it before the transformers import
# (with the `hf_transfer` package installed) would be needed — confirm.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
# Install required package
def install_flash_attn():
    """Install the ``flash-attn`` package into the running interpreter.

    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` is exported to the pip process
    (by its name, it tells flash-attn's build to skip compiling the CUDA
    extension locally — confirm against flash-attn's setup.py), and
    ``--no-build-isolation`` makes the build use the packages already
    installed in this environment (e.g. ``torch``).

    This is best-effort: a failed install is not raised.

    Returns:
        subprocess.CompletedProcess: The result of the pip invocation, so a
        caller can inspect ``returncode`` if it cares.
    """
    import sys

    # Extend, rather than replace, the current environment. Passing only the
    # one variable would strip PATH, HOME, CUDA_* etc. from the child process.
    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    # Run pip through this interpreter so the package lands in the same
    # environment that will import it; an argument list needs no shell.
    return subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash-attn",
            "--no-build-isolation",
        ],
        env=env,
        check=False,
    )
# Lookup from the ``architectures[0]`` string of a model config to the
# transformers class used to load it; each keyword name doubles as its key.
ARCHITECTURE_MAP = dict(
    LlavaNextForConditionalGeneration=LlavaNextForConditionalGeneration,
    LlavaForConditionalGeneration=LlavaForConditionalGeneration,
    PaliGemmaForConditionalGeneration=PaliGemmaForConditionalGeneration,
    Idefics2ForConditionalGeneration=Idefics2ForConditionalGeneration,
    AutoModelForCausalLM=AutoModelForCausalLM,
)
def _quantization_dict(quantization_config):
    """Return a model config's ``quantization_config`` as a plain dict.

    A falsy value (e.g. ``None``) becomes ``{}``. Objects exposing
    ``to_dict()`` are converted with it; anything else goes through ``dict()``.
    """
    if not quantization_config:
        return {}
    if hasattr(quantization_config, "to_dict"):
        return quantization_config.to_dict()
    return dict(quantization_config)


def _build_bnb_config(quantization_config):
    """Translate a checkpoint's bitsandbytes settings into a ``BitsAndBytesConfig``.

    Args:
        quantization_config: The ``quantization_config`` attribute of a model
            config (dict or config object), or ``None``.

    Returns:
        BitsAndBytesConfig | None: ``None`` when there is no quantization
        config, or when it declares a ``quant_method`` other than
        ``"bitsandbytes"`` (e.g. GPTQ/AWQ), which ``BitsAndBytesConfig``
        cannot describe.
    """
    qc = _quantization_dict(quantization_config)
    if not qc:
        return None
    # A config that omits quant_method is assumed to be bitsandbytes.
    if qc.get("quant_method", "bitsandbytes") != "bitsandbytes":
        return None
    return BitsAndBytesConfig(
        load_in_4bit=qc.get("load_in_4bit", False),
        load_in_8bit=qc.get("load_in_8bit", False),
        bnb_4bit_compute_dtype=qc.get("bnb_4bit_compute_dtype", torch.float16),
        bnb_4bit_quant_type=qc.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_use_double_quant=qc.get("bnb_4bit_use_double_quant", False),
        llm_int8_enable_fp32_cpu_offload=qc.get("llm_int8_enable_fp32_cpu_offload", False),
        llm_int8_has_fp16_weight=qc.get("llm_int8_has_fp16_weight", False),
        llm_int8_skip_modules=qc.get("llm_int8_skip_modules", None),
        llm_int8_threshold=qc.get("llm_int8_threshold", 6.0),
    )


def _resolve_model_class(config):
    """Pick the loader class for *config* from ``ARCHITECTURE_MAP``.

    Unknown architectures, and configs whose ``architectures`` is missing,
    ``None`` or empty, fall back to ``AutoModelForCausalLM``.
    """
    architectures = getattr(config, "architectures", None) or []
    if not architectures:
        return AutoModelForCausalLM
    return ARCHITECTURE_MAP.get(architectures[0], AutoModelForCausalLM)


# Load a model and return its printed module tree; @spaces.GPU requests a GPU
# for the call on Hugging Face ZeroGPU Spaces.
@spaces.GPU
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.

    The model is fully loaded with ``from_pretrained`` and its ``str()``
    (the PyTorch module tree) is returned. Non-quantized models are moved to
    CUDA when available, otherwise to CPU.

    Args:
        model_name (str): Hub repository id (or local path) of the model.

    Returns:
        tuple: ``(summary, error)``. On success ``summary`` is the module tree
        and ``error`` is ``""``; on failure ``summary`` is ``""`` and ``error``
        describes the exception. Exceptions are never propagated.
    """
    try:
        config = AutoConfig.from_pretrained(model_name)
        quantization_config = getattr(config, "quantization_config", None)
        bnb_config = _build_bnb_config(quantization_config)
        model_class = _resolve_model_class(config)

        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the checkpoint. If model_name comes from untrusted users this
        # is arbitrary code execution — consider an allow-list.
        load_kwargs = {"trust_remote_code": True}
        if bnb_config is not None:
            # BitsAndBytesConfig is a quantization config, not a model
            # config: it belongs in quantization_config=, never config=.
            load_kwargs["quantization_config"] = bnb_config
        model = model_class.from_pretrained(model_name, **load_kwargs)

        # Only move full-precision models; transformers does not support
        # .to() for (at least some) quantized models.
        if model is not None and not quantization_config:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)
        model_summary = str(model) if model is not None else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except OSError as ee:  # EnvironmentError is an alias of OSError
        return "", f"EnvironmentError: {ee}"
    except Exception as e:  # UI boundary: report any other failure as text
        return "", str(e)