# model_explorer2 / utils.py
# Copied from the Hugging Face Space file viewer
# (commit ef08154, "Update utils.py", 2.92 kB).
import subprocess
import os, requests
import torch, torchvision
import spaces
from huggingface_hub import login
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, LlavaNextForConditionalGeneration, LlavaForConditionalGeneration, PaliGemmaForConditionalGeneration, Idefics2ForConditionalGeneration
# Install required package
def install_flash_attn():
    """Install the ``flash-attn`` package with pip, skipping its CUDA build.

    Runs ``pip install flash-attn --no-build-isolation`` in a child process
    with ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` set. The call is
    best-effort: the exit status of pip is not checked and nothing is
    returned.
    """
    # Extend the current environment instead of replacing it: passing a bare
    # dict as ``env`` would drop PATH, HOME, virtualenv and proxy settings,
    # so pip might not be found or might install into the wrong place.
    child_env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    # An argument list (no shell) runs the same command without going
    # through shell parsing.
    subprocess.run(
        ["pip", "install", "flash-attn", "--no-build-isolation"],
        env=child_env,
    )
# Authenticate with Hugging Face
def authenticate_hf(token):
    """Log in to the Hugging Face Hub with the given access token.

    Forwards ``token`` to ``huggingface_hub.login`` with
    ``add_to_git_credential=True``. Returns nothing.
    """
    login_options = {
        "token": token,
        "add_to_git_credential": True,
    }
    login(**login_options)
# Summaries already computed, keyed by model name, so that repeat requests
# for the same model skip the config download and the model load.
model_cache = {}


@spaces.GPU
def get_model_summary(model_name):
    """Return a printable summary of a Hugging Face model's module tree.

    Downloads the repo's ``config.json``, picks the loader class named by the
    first entry of its ``architectures`` list (falling back to
    ``AutoModelForCausalLM`` when the entry is missing or not one of the
    known classes), loads the model and returns ``str(model)``. Successful
    summaries are stored in ``model_cache``.

    Args:
        model_name: Hub repo id used both in the config URL and as the
            argument to ``from_pretrained``.

    Returns:
        A ``(summary, error)`` tuple of strings. On success ``error`` is
        ``""``. On any failure ``summary`` is ``""`` and ``error`` is the
        exception message (or its ``repr`` when the message is empty).
        Exceptions are never propagated to the caller.
    """
    if model_name in model_cache:
        return model_cache[model_name], ""

    try:
        # Fetch the config.json file. Only send an Authorization header when
        # a token is configured; formatting an unset HF_TOKEN would send the
        # literal string "Bearer None".
        token = os.getenv("HF_TOKEN")
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
        # A timeout keeps an unresponsive server from blocking this call
        # indefinitely.
        response = requests.get(config_url, headers=headers, timeout=30)
        response.raise_for_status()
        config = response.json()

        # A config without "architectures" falls through to the generic
        # loader instead of failing with a bare KeyError message.
        architectures = config.get("architectures") or []
        architecture = architectures[0] if architectures else None

        # Architecture name -> loader class; anything else uses the
        # AutoModelForCausalLM fallback.
        loaders = {
            "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
            "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
            "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
            "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
        }
        model_class = loaders.get(architecture, AutoModelForCausalLM)

        # The model is treated as quantized purely on the basis of its name.
        is_quantized = "quantized" in model_name.lower()
        bnb_config = BitsAndBytesConfig(load_in_4bit=True) if is_quantized else None

        # The BitsAndBytesConfig must go in ``quantization_config``; the
        # ``config`` argument is reserved for the model's own configuration.
        # NOTE(review): trust_remote_code=True lets the repo's custom code
        # run locally. model_name appears to be caller-supplied - confirm
        # this is acceptable for the deployment.
        model = model_class.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            trust_remote_code=True,
        )

        # Move to device only if the model is not quantized.
        if not is_quantized:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

        model_summary = str(model)
        model_cache[model_name] = model_summary
        return model_summary, ""
    except Exception as e:
        # Boundary handler: report every failure through the error slot.
        # Fall back to repr() so an exception with an empty message does not
        # produce ("", ""), which would look like a silent success.
        return "", str(e) or repr(e)