dwb2023 commited on
Commit
40ff259
1 Parent(s): cc04f60

Update app.py

Browse files

adjust architecture for certain models

Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -26,9 +26,33 @@ def get_model_summary(model_name):
26
  return model_cache[model_name], ""
27
 
28
  try:
 
 
 
 
 
 
 
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  model_summary = str(model)
34
  model_cache[model_name] = model_summary
 
26
  return model_cache[model_name], ""
27
 
28
  try:
29
+ # Fetch the config.json file
30
+ config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
31
+ response = requests.get(config_url)
32
+ response.raise_for_status()
33
+ config = response.json()
34
+ architecture = config["architectures"][0]
35
+
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
+ # Select the correct model class based on the architecture
39
+ if architecture == "LlavaNextForConditionalGeneration":
40
+ from transformers import LlavaNextForConditionalGeneration
41
+ model = LlavaNextForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
42
+ elif architecture == "LlavaForConditionalGeneration":
43
+ from transformers import LlavaForConditionalGeneration
44
+ model = LlavaForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
45
+ elif architecture == "PaliGemmaForConditionalGeneration":
46
+ from transformers import PaliGemmaForConditionalGeneration
47
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
48
+ elif architecture == "Idefics2ForConditionalGeneration":
49
+ from transformers import Idefics2ForConditionalGeneration
50
+ model = Idefics2ForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
51
+ elif architecture == "MiniCPMV":
52
+ from transformers import MiniCPMV
53
+ model = MiniCPMV.from_pretrained(model_name, trust_remote_code=True).to(device)
54
+ else:
55
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
56
 
57
  model_summary = str(model)
58
  model_cache[model_name] = model_summary