import gradio as gr
import torch
from transformers import pipeline
import gc
import json
# Define available models/tasks
MODEL_CONFIGS = [
    {
        "name": "Text Generation (GPT-2)",
        "task": "text-generation",
        "model": "gpt2",
        "input_type": "text",
        "output_type": "text"
    },
    {
        "name": "Image Classification (ViT)",
        "task": "image-classification",
        "model": "google/vit-base-patch16-224",
        "input_type": "image",
        "output_type": "label"
    },
    # Add more models/tasks as needed, e.g. the commented entry below
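    # Illustrative extra entry (not part of the original app; the task and
    # model name are real transformers/Hub identifiers, but untested here):
    # {
    #     "name": "Sentiment Analysis (DistilBERT)",
    #     "task": "sentiment-analysis",
    #     "model": "distilbert-base-uncased-finetuned-sst-2-english",
    #     "input_type": "text",
    #     "output_type": "label"
    # },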
]
# Cache of loaded pipelines, keyed by "task|model", for lazy loading
model_cache = {}

def load_model(task, model_name):
    # Use the GPU when one is available; device_map="auto" is an alternative
    # for models that need to be sharded across devices.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=model_name, device=device)

def unload_model(model_key):
    if model_key in model_cache:
        del model_cache[model_key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
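# Illustrative cache lifecycle (not executed by the app itself):
#   key = "text-generation|gpt2"
#   model_cache[key] = load_model("text-generation", "gpt2")
#   out = model_cache[key]("Hello")   # run the cached pipeline directly
#   unload_model(key)                 # drop it and release GPU memory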
with gr.Blocks() as demo:
    gr.Markdown(
        "# Multi-Model, Multi-Task Gradio Demo\n"
        "_Switch between models and tasks in one Space!_"
    )
    # State shared across the tabs (per browser session). It must be created
    # inside the Blocks context to be registered with the app.
    shared_state = gr.State({"active_model": None, "last_result": None})
    with gr.Tabs():
        for config in MODEL_CONFIGS:
            with gr.Tab(config["name"]):
                status = gr.Markdown(
                    f"**Model:** {config['model']}<br>**Task:** {config['task']}"
                )
                load_btn = gr.Button("Load Model")
                unload_btn = gr.Button("Unload Model")
                if config["input_type"] == "text":
                    input_comp = gr.Textbox(label="Input Text")
                elif config["input_type"] == "image":
                    input_comp = gr.Image(label="Input Image")
                else:
                    input_comp = gr.Textbox(label="Input")
                run_btn = gr.Button("Run Model")
                output_comp = gr.Textbox(label="Output", lines=4)
                model_key = f"{config['task']}|{config['model']}"
                # model_key/config are bound as default arguments; closures
                # defined in a loop would otherwise all see the values from
                # the final iteration.
                def do_load(state, model_key=model_key, config=config):
                    if model_key not in model_cache:
                        model_cache[model_key] = load_model(config["task"], config["model"])
                    state = dict(state)
                    state["active_model"] = model_key
                    return f"Loaded: {model_key}", state

                def do_unload(state, model_key=model_key):
                    unload_model(model_key)
                    state = dict(state)
                    state["active_model"] = None
                    return f"Unloaded: {model_key}", state

                def do_run(inp, state, model_key=model_key):
                    if model_key not in model_cache:
                        return "Model not loaded!", state
                    pipe = model_cache[model_key]
                    result = pipe(inp)
                    state = dict(state)
                    state["last_result"] = result
                    return str(result), state
                load_btn.click(do_load, shared_state, [status, shared_state])
                unload_btn.click(do_unload, shared_state, [status, shared_state])
                run_btn.click(do_run, [input_comp, shared_state], [output_comp, shared_state])
    # Shared state display
    def pretty_json(state):
        # default=str keeps this robust if a pipeline ever returns
        # non-JSON-serializable values.
        return json.dumps(state, indent=2, ensure_ascii=False, default=str)

    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    # Listening to gr.State.change requires a recent Gradio release; on older
    # versions, chain .then(pretty_json, shared_state, shared_state_box) onto
    # each click event above instead.
    shared_state.change(pretty_json, shared_state, shared_state_box)

demo.launch()
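# Assumed local setup (not part of the original Space files):
#   pip install gradio torch transformers
#   python app.py
# On Spaces, this file is app.py; the same three packages belong in
# requirements.txt.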