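"""Gradio Space that loads a Hugging Face checkpoint and displays its
architecture (the string form of the instantiated model)."""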
import os
import subprocess

import einops  # noqa: F401 -- imported for remote-code models that need it
import gradio as gr
import requests
import spaces
import torch
import torchvision  # noqa: F401 -- imported for remote-code models that need it
from huggingface_hub import login
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells its
# setup to skip compiling CUDA kernels and take the prebuilt-wheel path instead.
# Merge with os.environ so pip still sees PATH and the rest of the environment.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
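# Note: nothing below passes attn_implementation="flash_attention_2" explicitly,
# so the install above is a best-effort optimization for checkpoints whose
# (remote) code picks flash-attn up on its own.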
# HF_TOKEN is expected as a Space secret; skip login when it is absent so the
# app can still start and serve public models.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=True)
# Cache mapping model name -> architecture summary (the models themselves are
# not kept around).
model_cache = {}
# Function to get the model summary
@spaces.GPU
def get_model_summary(model_name):
    if model_name in model_cache:
        return model_cache[model_name], ""
    try:
        # Fetch config.json first to learn the architecture without
        # downloading any weights.
        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
        response = requests.get(config_url, headers=headers)
        response.raise_for_status()
        config = response.json()
        architecture = config["architectures"][0]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Pick the transformers class that matches the declared architecture;
        # anything unrecognized falls back to AutoModelForCausalLM.
        architecture_to_class = {
            "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
            "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
            "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
            "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
        }
        if architecture == "MiniCPMV":
            # MiniCPMV is not exported by transformers; it ships as remote
            # code, so AutoModel + trust_remote_code resolves the real class.
            model_class = AutoModel
        else:
            model_class = architecture_to_class.get(architecture, AutoModelForCausalLM)

        model = model_class.from_pretrained(model_name, trust_remote_code=True).to(device)

        model_summary = str(model)
        model_cache[model_name] = model_summary
        return model_summary, ""
    except Exception as e:
        return "", str(e)
# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(
                label="Model Name",
                placeholder="Enter the model name here OR select an example below...",
                lines=1,
            )
            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                examples=[
                    ["llava-hf/llava-v1.6-mistral-7b-hf"],
                    ["xtuner/llava-phi-3-mini-hf"],
                    ["xtuner/llava-llama-3-8b-v1_1-transformers"],
                    ["vikhyatk/moondream2"],
                    ["openbmb/MiniCPM-Llama3-V-2_5"],
                    ["microsoft/Phi-3-vision-128k-instruct"],
                    ["google/paligemma-3b-mix-224"],
                    ["HuggingFaceM4/idefics2-8b-chatty"],
                    ["microsoft/llava-med-v1.5-mistral-7b"],
                ],
                inputs=textbox,
            )
            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[
                    ["google/gemma-7b"],
                    ["microsoft/Phi-3-mini-4k-instruct"],
                    ["meta-llama/Meta-Llama-3-8B"],
                    ["mistralai/Mistral-7B-Instruct-v0.3"],
                ],
                inputs=textbox,
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(
                label="Model Architecture",
                lines=20,
                placeholder="Model architecture will appear here...",
                show_copy_button=True,
            )
            error_output = gr.Textbox(
                label="Error",
                lines=10,
                placeholder="Exceptions will appear here...",
                show_copy_button=True,
            )

    # get_model_summary already returns (summary, error), so it can be wired
    # to the button directly; no wrapper function is needed.
    submit_button.click(fn=get_model_summary, inputs=textbox, outputs=[output, error_output])

# Launch the interface
demo.launch()
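# If requests should be serialized on shared GPU hardware, a common variant is:
#   demo.queue().launch()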