File size: 3,566 Bytes
a5723a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from transformers import pipeline
import gc
import json

# Catalogue of model/task pairs exposed by the demo; one UI tab is built per
# entry. Each entry has these keys:
#   name        -- tab title shown in the UI
#   task        -- transformers pipeline task name
#   model       -- model id handed to the pipeline
#   input_type  -- "text" or "image" selects the input widget; any other
#                  value falls back to a plain textbox
#   output_type -- descriptive only; not read anywhere in the visible code
# Append further entries to expose more models/tasks.
MODEL_CONFIGS = [
    {
        "name": "Text Generation (GPT-2)",
        "task": "text-generation",
        "model": "gpt2",
        "input_type": "text",
        "output_type": "text",
    },
    {
        "name": "Image Classification (ViT)",
        "task": "image-classification",
        "model": "google/vit-base-patch16-224",
        "input_type": "image",
        "output_type": "label",
    },
]

# Session UI state: the key of the model most recently loaded and the raw
# output of the last run. gr.State values are per browser session (Gradio
# docs), so "shared" here means shared across this app's tabs.
# NOTE(review): created outside any gr.Blocks context; Gradio only attaches a
# component to an app when it is created or .render()-ed inside one — confirm
# it is rendered before event listeners use it.
shared_state = gr.State(value={"active_model": None, "last_result": None})

# Lazily filled cache of loaded pipelines, keyed by "task|model".
# NOTE(review): a plain module-level dict, so presumably shared by every
# session served by this process (unlike the per-session state above).
model_cache = {}

def load_model(task, model_name, device=-1):
    """Build a transformers ``pipeline`` for ``task`` backed by ``model_name``.

    Args:
        task: pipeline task identifier, e.g. ``"text-generation"``.
        model_name: model id or path forwarded as ``pipeline(model=...)``.
        device: forwarded to ``pipeline(device=...)``. The default ``-1``
            keeps inference on CPU (the previous hard-coded behaviour). Pass
            ``0`` for the first CUDA GPU, or ``"auto"`` to use GPU 0 when
            CUDA is available and fall back to CPU otherwise.

    Returns:
        The constructed pipeline object.
    """
    if device == "auto":
        # Resolve here: "auto" is not a valid value for pipeline(device=...).
        device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=model_name, device=device)

def unload_model(model_key):
    """Evict ``model_key`` from ``model_cache`` and try to reclaim its memory.

    Does nothing when the key is not cached. Otherwise the cache entry is
    dropped, a garbage-collection pass runs, and on CUDA machines PyTorch's
    unused cached GPU memory is released.
    """
    if model_key not in model_cache:
        return
    model_cache.pop(model_key)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def _make_input_component(input_type):
    """Create the input widget matching a config's ``input_type``.

    ``"text"`` gives a labelled textbox and ``"image"`` an image upload.
    Anything else gives a generic textbox.
    """
    if input_type == "text":
        return gr.Textbox(label="Input Text")
    if input_type == "image":
        # type="pil": gr.Image passes numpy arrays to handlers by default, but
        # transformers image pipelines expect a PIL image, a path or a URL.
        return gr.Image(label="Input Image", type="pil")
    return gr.Textbox(label="Input")


def _build_model_tab(config):
    """Render one tab for ``config`` and wire its Load / Unload / Run buttons.

    Must be called inside an active ``gr.Blocks`` context. The handlers are
    created per call, so each tab closes over its *own* ``config`` and
    ``model_key``. If they were defined directly in a ``for`` loop body,
    every tab would see the last config, because closures bind variables
    rather than values.
    """
    model_key = f"{config['task']}|{config['model']}"

    def do_load(state):
        """Load (or reuse) this tab's pipeline and mark it as active."""
        if model_key not in model_cache:
            model_cache[model_key] = load_model(config["task"], config["model"])
        new_state = dict(state)
        new_state["active_model"] = model_key
        return f"Loaded: {model_key}", new_state

    def do_unload(state):
        """Drop this tab's pipeline from the cache."""
        unload_model(model_key)
        new_state = dict(state)
        # Clear the pointer only when it refers to this model. Unloading one
        # tab must not make a different, still-loaded model look inactive.
        if new_state.get("active_model") == model_key:
            new_state["active_model"] = None
        return f"Unloaded: {model_key}", new_state

    def do_run(inp, state):
        """Run this tab's pipeline on ``inp`` and record the raw result."""
        # Single .get() instead of `in` + index: model_cache is module-level,
        # so another request could unload the entry between the two steps.
        pipe = model_cache.get(model_key)
        if pipe is None:
            return "Model not loaded!", state
        if inp is None or (isinstance(inp, str) and not inp.strip()):
            return "No input provided!", state
        result = pipe(inp)
        new_state = dict(state)
        new_state["last_result"] = result
        return str(result), new_state

    with gr.Tab(config["name"]):
        status = gr.Markdown(f"**Model:** {config['model']}<br>**Task:** {config['task']}")
        load_btn = gr.Button("Load Model")
        unload_btn = gr.Button("Unload Model")
        input_comp = _make_input_component(config["input_type"])
        run_btn = gr.Button("Run Model")
        output_comp = gr.Textbox(label="Output", lines=4)

        load_btn.click(do_load, shared_state, [status, shared_state])
        unload_btn.click(do_unload, shared_state, [status, shared_state])
        run_btn.click(do_run, [input_comp, shared_state], [output_comp, shared_state])


def pretty_json(state):
    """Serialize the session state for the read-only "Shared State" panel."""
    # default=str: the stored pipeline result is not guaranteed to be
    # JSON-native. The panel is display-only, so stringify rather than fail.
    return json.dumps(state, indent=2, ensure_ascii=False, default=str)


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model, Multi-Task Gradio Demo\n_Switch between models and tasks in one Space!_")
    # shared_state is created at module level, outside any Blocks context, so
    # attach it to this app explicitly before event listeners reference it.
    shared_state.render()
    with gr.Tabs():
        for config in MODEL_CONFIGS:
            _build_model_tab(config)

    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    shared_state.change(pretty_json, shared_state, shared_state_box)

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. from tests or another
    # script) builds `demo` without starting a server; `python app.py`
    # still launches it.
    demo.launch()