# model_explorer4/utils.py
import os
import subprocess
from functools import lru_cache

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)
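
# Use the hf_transfer download backend (if installed) for faster Hub downloads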
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


# Install flash-attn at runtime, skipping the CUDA build step during installation
def install_flash_attn():
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )


# Architecture to model class mapping
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}


# Get the model summary, cached so repeated lookups do not reload the model
@lru_cache(maxsize=10)
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.

    Args:
        model_name (str): The name of the model to retrieve the summary for.

    Returns:
        tuple: A tuple of (model summary, error message); the error message is empty on success.
    """
    try:
        # Fetch the model configuration from the Hub
        config = AutoConfig.from_pretrained(model_name)
        architecture = config.architectures[0]
        quantization_config = getattr(config, 'quantization_config', None)
        # Set up BitsAndBytesConfig if the checkpoint ships a quantization config
        if quantization_config:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get('load_in_4bit', False),
                load_in_8bit=quantization_config.get('load_in_8bit', False),
                bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
                bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
                bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
                llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
                llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
                llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
            )
        else:
            bnb_config = None
        # Pick the model class for this architecture, falling back to AutoModelForCausalLM
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)

        # Load the model; quantization settings go through `quantization_config`,
        # not `config`, which expects a PretrainedConfig
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )

        # Move to a device only if the model is not quantized
        # (bitsandbytes-quantized models are placed on the GPU at load time)
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)