import os
import subprocess

import requests
import torch
import torchvision  # noqa: F401 (some trust_remote_code models expect this at runtime)
import einops  # noqa: F401 (some trust_remote_code models expect this at runtime)
import gradio as gr
import spaces
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)
from huggingface_hub import login

# Install flash-attn at startup, skipping the CUDA build. Extend the existing
# environment instead of replacing it, so pip can still find PATH etc.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
    check=False,  # flash-attn is optional; don't crash the Space if it fails
)

# Authenticate so gated repos (e.g. meta-llama, google) can be fetched.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=True)

# Cache for storing loaded models and their summaries
model_cache = {}

# Load a model and return (architecture summary, error message), caching the summary
@spaces.GPU
def get_model_summary(model_name):
    if model_name in model_cache:
        return model_cache[model_name], ""

    try:
        # Read the model's declared architecture from its config.json,
        # e.g. {"architectures": ["LlavaForConditionalGeneration"], ...}
        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
        response = requests.get(config_url, headers=headers, timeout=30)
        response.raise_for_status()
        config = response.json()
        architecture = config["architectures"][0]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Map declared architectures to their transformers classes; anything
        # unrecognized falls back to AutoModelForCausalLM.
        arch_to_class = {
            "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
            "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
            "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
            "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
            # transformers does not export a MiniCPMV class; the modeling code
            # ships with the repo, so AutoModel + trust_remote_code resolves it.
            "MiniCPMV": AutoModel,
        }
        model_class = arch_to_class.get(architecture, AutoModelForCausalLM)
        model = model_class.from_pretrained(model_name, trust_remote_code=True).to(device)
        
        model_summary = str(model)
        model_cache[model_name] = model_summary
        return model_summary, ""
    except Exception as e:
        return "", str(e)

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(label="Model Name", placeholder="Enter a model name or select an example below...", lines=1)
            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                examples=[
                    ["llava-hf/llava-v1.6-mistral-7b-hf"],
                    ["xtuner/llava-phi-3-mini-hf"],
                    ["xtuner/llava-llama-3-8b-v1_1-transformers"],
                    ["vikhyatk/moondream2"],
                    ["openbmb/MiniCPM-Llama3-V-2_5"],
                    ["microsoft/Phi-3-vision-128k-instruct"],
                    ["google/paligemma-3b-mix-224"],
                    ["HuggingFaceM4/idefics2-8b-chatty"],
                    ["microsoft/llava-med-v1.5-mistral-7b"]
                ],
                inputs=textbox
            )
            
            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[
                    ["google/gemma-7b"],
                    ["microsoft/Phi-3-mini-4k-instruct"],
                    ["meta-llama/Meta-Llama-3-8B"],
                    ["mistralai/Mistral-7B-Instruct-v0.3"]
                ],
                inputs=textbox
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(label="Model Architecture", lines=20, placeholder="Model architecture will appear here...", show_copy_button=True)
            error_output = gr.Textbox(label="Error", lines=10, placeholder="Exceptions will appear here...", show_copy_button=True)

    # get_model_summary already returns (summary, error), so wire it directly.
    submit_button.click(fn=get_model_summary, inputs=textbox, outputs=[output, error_output])

# Launch the interface
demo.launch()
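
# A minimal sketch of querying the running app programmatically, assuming the
# default local URL (gradio_client is Gradio's companion client package;
# client.view_api() lists the exact endpoint name if predict() needs one):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   summary, error = client.predict("google/gemma-7b")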