# omniscience/model_utils.py
import os

# hf_transfer must be enabled before huggingface_hub (pulled in via
# transformers below) is imported, since it reads this variable at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import subprocess

import spaces  # import before torch, as recommended on ZeroGPU Spaces
import torch
from transformers import (
    AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig,
    GroundingDinoForObjectDetection, Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration, LlavaNextForConditionalGeneration,
    Owlv2ForObjectDetection, PaliGemmaForConditionalGeneration, SamModel,
)


def install_flash_attn():
    """Install flash-attn at runtime (CUDA build skipped via env flag)."""
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        # Merge os.environ: passing only the flag would wipe PATH and break pip.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
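
# install_flash_attn() is defined but never invoked at import time. If
# flash-attn kernels are wanted, a sketch of the intended call order would be
# (illustrative only; model support for flash attention varies):
#
#     install_flash_attn()
#     model = AutoModelForCausalLM.from_pretrained(
#         ..., attn_implementation="flash_attention_2"
#     )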


ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "Owlv2ForObjectDetection": Owlv2ForObjectDetection,
    "GroundingDinoForObjectDetection": GroundingDinoForObjectDetection,
    "SamModel": SamModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}
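
# Illustrative lookup (not part of the runtime flow): an architecture string
# absent from the map falls back to AutoModelForCausalLM, so unknown
# checkpoints still load. "MistralForCausalLM" below is a hypothetical key:
#
#     ARCHITECTURE_MAP.get("MistralForCausalLM", AutoModelForCausalLM)
#     # -> AutoModelForCausalLM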


@spaces.GPU
def get_model_summary(model_name):
    """Return (markdown summary, error string) for the given Hub model id."""
    try:
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        architecture = config.architectures[0]
        quantization_config = getattr(config, 'quantization_config', None)
        if quantization_config:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get('load_in_4bit', False),
                load_in_8bit=quantization_config.get('load_in_8bit', False),
                bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
                bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
                bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
                llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
                llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
                llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
            )
        else:
            bnb_config = None
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
        # BitsAndBytes settings go in quantization_config; the config kwarg
        # expects a PretrainedConfig, not a BitsAndBytesConfig.
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )
        # bitsandbytes places quantized weights itself; only move
        # full-precision models onto the GPU explicitly.
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        model_summary = str(model) if model else "Model architecture not found."
        config_content = config.to_json_string() if config else "Configuration not found."
        return f"## Model Architecture\n\n{model_summary}\n\n## Configuration\n\n{config_content}", ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)