dwb2023 committed
Commit 5f6d3e9
1 Parent(s): 79ab92b

Update utils.py


refine bnb config and architecture mappings

Files changed (1)
  1. utils.py +53 -51
utils.py CHANGED
@@ -1,9 +1,8 @@
 import subprocess
-import os, requests
-import torch, torchvision
-import spaces
-from huggingface_hub import login
-from transformers import BitsAndBytesConfig, AutoModelForCausalLM, LlavaNextForConditionalGeneration, LlavaForConditionalGeneration, PaliGemmaForConditionalGeneration, Idefics2ForConditionalGeneration
+import os
+import torch
+from transformers import BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM, LlavaNextForConditionalGeneration, LlavaForConditionalGeneration, PaliGemmaForConditionalGeneration, Idefics2ForConditionalGeneration
+from functools import lru_cache
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
@@ -15,63 +14,66 @@ def install_flash_attn():
         shell=True,
     )
 
-# Authenticate with Hugging Face
-def authenticate_hf(token):
-    login(token=token, add_to_git_credential=True)
-
-# Function to get the model summary
-model_cache = {}
+# Architecture to model class mapping
+ARCHITECTURE_MAP = {
+    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
+    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
+    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
+    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
+    "AutoModelForCausalLM": AutoModelForCausalLM
+}
 
-@spaces.GPU
+# Function to get the model summary with caching
+@lru_cache(maxsize=10)
 def get_model_summary(model_name):
-    if model_name in model_cache:
-        return model_cache[model_name], ""
-
+    """
+    Retrieve the model summary for the given model name.
+
+    Args:
+        model_name (str): The name of the model to retrieve the summary for.
+
+    Returns:
+        tuple: A tuple containing the model summary (str) and an error message (str), if any.
+    """
     try:
-        # Fetch the config.json file
-        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
-        headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
-        response = requests.get(config_url, headers=headers)
-        response.raise_for_status()
-        config = response.json()
-        architecture = config["architectures"][0]
-
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # Check if the model is quantized
-        is_quantized = "quantized" in model_name.lower()
+        # Fetch the model configuration
+        config = AutoConfig.from_pretrained(model_name)
+        architecture = config.architectures[0]
+        quantization_config = getattr(config, 'quantization_config', None)
 
         # Set up BitsAndBytesConfig if the model is quantized
-        bnb_config = BitsAndBytesConfig(load_in_4bit=True) if is_quantized else None
-
-        # Load the model based on its architecture and quantization status
-        if architecture == "LlavaNextForConditionalGeneration":
-            model = LlavaNextForConditionalGeneration.from_pretrained(
-                model_name, config=bnb_config, trust_remote_code=True
-            )
-        elif architecture == "LlavaForConditionalGeneration":
-            model = LlavaForConditionalGeneration.from_pretrained(
-                model_name, config=bnb_config, trust_remote_code=True
-            )
-        elif architecture == "PaliGemmaForConditionalGeneration":
-            model = PaliGemmaForConditionalGeneration.from_pretrained(
-                model_name, config=bnb_config, trust_remote_code=True
-            )
-        elif architecture == "Idefics2ForConditionalGeneration":
-            model = Idefics2ForConditionalGeneration.from_pretrained(
-                model_name, config=bnb_config, trust_remote_code=True
+        if quantization_config:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=quantization_config.get('load_in_4bit', False),
+                load_in_8bit=quantization_config.get('load_in_8bit', False),
+                bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
+                bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
+                bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
+                llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
+                llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
+                llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
+                llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
             )
         else:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name, config=bnb_config, trust_remote_code=True
-            )
+            bnb_config = None
+
+        # Get the appropriate model class from the architecture map
+        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
+
+        # Load the model
+        model = model_class.from_pretrained(
+            model_name, quantization_config=bnb_config, trust_remote_code=True
+        )
 
         # Move to device only if the model is not quantized
-        if not is_quantized:
-            model = model.to(device)
+        if model and not quantization_config:
+            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
-        model_summary = str(model)
-        model_cache[model_name] = model_summary
+        model_summary = str(model) if model else "Model architecture not found."
         return model_summary, ""
+    except ValueError as ve:
+        return "", f"ValueError: {ve}"
+    except EnvironmentError as ee:
+        return "", f"EnvironmentError: {ee}"
     except Exception as e:
        return "", str(e)