# File size: 3,584 Bytes
# 8889bbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from gradio.themes import ThemeClass as Theme
import numpy as np
import argparse
import gradio as gr
from typing import Any, Iterator
from typing import Iterator, List, Optional, Tuple
import filelock
import glob
import json
import time
from gradio.routes import Request
from gradio.utils import SyncToAsyncIterator, async_iteration
from gradio.helpers import special_args
import anyio
from typing import AsyncGenerator, Callable, Literal, Union, cast

from gradio_client.documentation import document, set_documentation_group
from gradio.components import Button, Component
from gradio.events import Dependency, EventListenerMethod
from typing import List, Optional, Union, Dict, Tuple
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download


def create_class_func_registry():
    """Create an independent registry that maps class names to classes.

    Returns:
        A 3-tuple ``(registry, register_registry, get_registry)``:

        * ``registry`` -- the backing dict, keyed by ``cls.__name__``.
        * ``register_registry(cls, exist_ok=False)`` -- stores ``cls`` under its
          ``__name__`` and returns ``cls`` unchanged, so it can be used as a
          class decorator. Asserts the name is not already present unless
          ``exist_ok`` is true.
        * ``get_registry(name)`` -- returns the class stored under ``name``,
          asserting that it has been registered.
    """
    table = {}

    def register_registry(cls, exist_ok=False):
        # Re-registering the same name is an error unless explicitly allowed.
        assert exist_ok or cls.__name__ not in table, f'{cls} already in registry: {table}'
        key = cls.__name__
        table[key] = cls
        return cls

    def get_registry(name):
        assert name in table, f'{name} not in registry: {table}'
        found = table[name]
        return found

    return table, register_registry, get_registry

# Module-wide demo registry built by create_class_func_registry():
#   DEMOS          -- dict mapping a demo class's __name__ to the class itself
#   register_demo  -- class decorator that adds a class to DEMOS (asserts on duplicates)
#   get_demo_class -- looks up a registered demo class by name (asserts if missing)
DEMOS, register_demo, get_demo_class = create_class_func_registry()


class BaseDemo(object):
    """Common base class for every demo.

    Every demo should subclass ``BaseDemo`` and be registered with
    ``@register_demo``.
    """

    def __init__(self) -> None:
        """Initialise the demo; the base class holds no state."""

    @property
    def tab_name(self):
        """Label for this demo's tab; the base class uses ``"Demo"``."""
        name = "Demo"
        return name

    def create_demo(
            self,
            title: Optional[str] = None,
            description: Optional[str] = None,
            **kwargs,
    ) -> gr.Blocks:
        """Build the demo UI.

        The base implementation builds nothing and returns ``None``; the
        ``gr.Blocks`` return annotation indicates subclasses are meant to
        override this and return their Blocks.

        Args:
            title: Optional title for the demo.
            description: Optional description for the demo.
            **kwargs: Extra options for subclass implementations; ignored here.
        """
        return None


@document()
class CustomTabbedInterface(gr.Blocks):
    """A Blocks layout that renders each given interface in its own tab,
    optionally preceded by a centred title and a Markdown description."""

    def __init__(
        self,
        interface_list: list[gr.Interface],
        tab_names: Optional[list[str]] = None,
        title: Optional[str] = None,
        description: Optional[str] = None,
        theme: Optional[gr.Theme] = None,
        analytics_enabled: Optional[bool] = None,
        css: Optional[str] = None,
    ):
        """
        Parameters:
            interface_list: a list of interfaces to be rendered in tabs.
            tab_names: a list of tab names. If None, the tab names will be "Tab 0", "Tab 1", etc. If fewer names than interfaces are given, the remaining tabs get those same default names instead of being dropped; extra names are ignored.
            title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
            description: optional Markdown text rendered below the title and above the tabs.
            theme: a theme passed through to gr.Blocks.
            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
            css: custom css or path to custom css file to apply to entire Blocks
        Returns:
            a Gradio Tabbed Interface for the given interfaces
        """
        super().__init__(
            title=title or "Gradio",
            theme=theme,
            analytics_enabled=analytics_enabled,
            mode="tabbed_interface",
            css=css,
        )
        self.description = description

        # Materialise once so len() works for any iterable and it is consumed only once.
        interfaces = list(interface_list)
        names = list(tab_names) if tab_names is not None else []
        # Previously zip() silently dropped interfaces that had no matching name;
        # pad with the same default labels used when tab_names is None.
        if len(names) < len(interfaces):
            names.extend(f"Tab {i}" for i in range(len(names), len(interfaces)))

        with self:
            if title:
                gr.Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
                )
            if description:
                gr.Markdown(description)
            with gr.Tabs():
                for interface, tab_name in zip(interfaces, names):
                    with gr.Tab(label=tab_name):
                        interface.render()