HighCWu commited on
Commit
01c9658
1 Parent(s): f77bf32

add gradio dir

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. gradio-modified/gradio/.dockerignore +2 -0
  2. gradio-modified/gradio/__init__.py +86 -0
  3. gradio-modified/gradio/blocks.py +1673 -0
  4. gradio-modified/gradio/components.py +0 -0
  5. gradio-modified/gradio/context.py +14 -0
  6. gradio-modified/gradio/data_classes.py +55 -0
  7. gradio-modified/gradio/deprecation.py +45 -0
  8. gradio-modified/gradio/documentation.py +193 -0
  9. gradio-modified/gradio/encryptor.py +31 -0
  10. gradio-modified/gradio/events.py +723 -0
  11. gradio-modified/gradio/examples.py +327 -0
  12. gradio-modified/gradio/exceptions.py +23 -0
  13. gradio-modified/gradio/external.py +462 -0
  14. gradio-modified/gradio/external_utils.py +186 -0
  15. gradio-modified/gradio/flagging.py +560 -0
  16. gradio-modified/gradio/helpers.py +792 -0
  17. gradio-modified/gradio/inputs.py +473 -0
  18. gradio-modified/gradio/interface.py +844 -0
  19. gradio-modified/gradio/interpretation.py +255 -0
  20. gradio-modified/gradio/ipython_ext.py +17 -0
  21. gradio-modified/gradio/launches.json +1 -0
  22. gradio-modified/gradio/layouts.py +377 -0
  23. gradio-modified/gradio/media_data.py +0 -0
  24. gradio-modified/gradio/mix.py +128 -0
  25. gradio-modified/gradio/networking.py +185 -0
  26. gradio-modified/gradio/outputs.py +334 -0
  27. gradio-modified/gradio/pipelines.py +191 -0
  28. gradio-modified/gradio/processing_utils.py +755 -0
  29. gradio-modified/gradio/queueing.py +446 -0
  30. gradio-modified/gradio/reload.py +59 -0
  31. gradio-modified/gradio/routes.py +622 -0
  32. gradio-modified/gradio/serializing.py +208 -0
  33. gradio-modified/gradio/strings.py +41 -0
  34. gradio-modified/gradio/templates.py +563 -0
  35. gradio-modified/{templates → gradio/templates}/frontend/assets/BlockLabel.37da86a3.js +0 -0
  36. gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.cc0aed40.js +0 -0
  37. gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.e110d966.css +0 -0
  38. gradio-modified/{templates → gradio/templates}/frontend/assets/Column.06c172ac.js +0 -0
  39. gradio-modified/{templates → gradio/templates}/frontend/assets/File.60a988f4.js +0 -0
  40. gradio-modified/{templates → gradio/templates}/frontend/assets/Image.4a41f1aa.js +0 -0
  41. gradio-modified/{templates → gradio/templates}/frontend/assets/Image.95fa511c.js +0 -0
  42. gradio-modified/{templates → gradio/templates}/frontend/assets/Model3D.b44fd6f2.js +0 -0
  43. gradio-modified/{templates → gradio/templates}/frontend/assets/ModifyUpload.2cfe71e4.js +0 -0
  44. gradio-modified/{templates → gradio/templates}/frontend/assets/Tabs.6b500f1a.js +0 -0
  45. gradio-modified/{templates → gradio/templates}/frontend/assets/Upload.5d0148e8.js +0 -0
  46. gradio-modified/{templates → gradio/templates}/frontend/assets/Webcam.8816836e.js +0 -0
  47. gradio-modified/{templates → gradio/templates}/frontend/assets/_commonjsHelpers.88e99c8f.js +0 -0
  48. gradio-modified/{templates → gradio/templates}/frontend/assets/color.509e5f03.js +0 -0
  49. gradio-modified/{templates → gradio/templates}/frontend/assets/csv.27f5436c.js +0 -0
  50. gradio-modified/{templates → gradio/templates}/frontend/assets/dsv.7fe76a93.js +0 -0
gradio-modified/gradio/.dockerignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ templates/frontend
2
+ templates/frontend/**/*
gradio-modified/gradio/__init__.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pkgutil
2
+
3
+ import gradio.components as components
4
+ import gradio.inputs as inputs
5
+ import gradio.outputs as outputs
6
+ import gradio.processing_utils
7
+ import gradio.templates
8
+ from gradio.blocks import Blocks
9
+ from gradio.components import (
10
+ HTML,
11
+ JSON,
12
+ Audio,
13
+ Button,
14
+ Carousel,
15
+ Chatbot,
16
+ Checkbox,
17
+ Checkboxgroup,
18
+ CheckboxGroup,
19
+ ColorPicker,
20
+ DataFrame,
21
+ Dataframe,
22
+ Dataset,
23
+ Dropdown,
24
+ File,
25
+ Gallery,
26
+ Highlight,
27
+ Highlightedtext,
28
+ HighlightedText,
29
+ Image,
30
+ Interpretation,
31
+ Json,
32
+ Label,
33
+ LinePlot,
34
+ Markdown,
35
+ Model3D,
36
+ Number,
37
+ Plot,
38
+ Radio,
39
+ ScatterPlot,
40
+ Slider,
41
+ State,
42
+ StatusTracker,
43
+ Text,
44
+ Textbox,
45
+ TimeSeries,
46
+ Timeseries,
47
+ UploadButton,
48
+ Variable,
49
+ Video,
50
+ component,
51
+ )
52
+ from gradio.exceptions import Error
53
+ from gradio.flagging import (
54
+ CSVLogger,
55
+ FlaggingCallback,
56
+ HuggingFaceDatasetJSONSaver,
57
+ HuggingFaceDatasetSaver,
58
+ SimpleCSVLogger,
59
+ )
60
+ from gradio.helpers import Progress
61
+ from gradio.helpers import create_examples as Examples
62
+ from gradio.helpers import make_waveform, skip, update
63
+ from gradio.interface import Interface, TabbedInterface, close_all
64
+ from gradio.ipython_ext import load_ipython_extension
65
+ from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
66
+ from gradio.mix import Parallel, Series
67
+ from gradio.routes import Request, mount_gradio_app
68
+ from gradio.templates import (
69
+ Files,
70
+ ImageMask,
71
+ ImagePaint,
72
+ List,
73
+ Matrix,
74
+ Mic,
75
+ Microphone,
76
+ Numpy,
77
+ Paint,
78
+ Pil,
79
+ PlayableVideo,
80
+ Sketchpad,
81
+ TextArea,
82
+ Webcam,
83
+ )
84
+
85
+ current_pkg_version = pkgutil.get_data(__name__, "version.txt").decode("ascii").strip()
86
+ __version__ = current_pkg_version
gradio-modified/gradio/blocks.py ADDED
@@ -0,0 +1,1673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import getpass
5
+ import inspect
6
+ import json
7
+ import os
8
+ import pkgutil
9
+ import random
10
+ import sys
11
+ import time
12
+ import warnings
13
+ import webbrowser
14
+ from abc import abstractmethod
15
+ from pathlib import Path
16
+ from types import ModuleType
17
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
18
+
19
+ import anyio
20
+ import requests
21
+ from anyio import CapacityLimiter
22
+ from typing_extensions import Literal
23
+
24
+ from gradio import (
25
+ components,
26
+ encryptor,
27
+ external,
28
+ networking,
29
+ queueing,
30
+ routes,
31
+ strings,
32
+ utils,
33
+ )
34
+ from gradio.context import Context
35
+ from gradio.deprecation import check_deprecated_parameters
36
+ from gradio.documentation import document, set_documentation_group
37
+ from gradio.exceptions import DuplicateBlockError, InvalidApiName
38
+ from gradio.helpers import create_tracker, skip, special_args
39
+ from gradio.tunneling import CURRENT_TUNNELS
40
+ from gradio.utils import (
41
+ TupleNoPrint,
42
+ check_function_inputs_match,
43
+ component_or_layout_class,
44
+ delete_none,
45
+ get_cancel_function,
46
+ get_continuous_fn,
47
+ )
48
+
49
+ set_documentation_group("blocks")
50
+
51
+
52
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
53
+ import comet_ml
54
+ from fastapi.applications import FastAPI
55
+
56
+ from gradio.components import Component
57
+
58
+
59
+ class Block:
60
+ def __init__(
61
+ self,
62
+ *,
63
+ render: bool = True,
64
+ elem_id: str | None = None,
65
+ visible: bool = True,
66
+ root_url: str | None = None, # URL that is prepended to all file paths
67
+ _skip_init_processing: bool = False, # Used for loading from Spaces
68
+ **kwargs,
69
+ ):
70
+ self._id = Context.id
71
+ Context.id += 1
72
+ self.visible = visible
73
+ self.elem_id = elem_id
74
+ self.root_url = root_url
75
+ self._skip_init_processing = _skip_init_processing
76
+ self._style = {}
77
+ self.parent: BlockContext | None = None
78
+
79
+ if render:
80
+ self.render()
81
+ check_deprecated_parameters(self.__class__.__name__, **kwargs)
82
+
83
+ def render(self):
84
+ """
85
+ Adds self into appropriate BlockContext
86
+ """
87
+ if Context.root_block is not None and self._id in Context.root_block.blocks:
88
+ raise DuplicateBlockError(
89
+ f"A block with id: {self._id} has already been rendered in the current Blocks."
90
+ )
91
+ if Context.block is not None:
92
+ Context.block.add(self)
93
+ if Context.root_block is not None:
94
+ Context.root_block.blocks[self._id] = self
95
+ if isinstance(self, components.TempFileManager):
96
+ Context.root_block.temp_file_sets.append(self.temp_files)
97
+ return self
98
+
99
+ def unrender(self):
100
+ """
101
+ Removes self from BlockContext if it has been rendered (otherwise does nothing).
102
+ Removes self from the layout and collection of blocks, but does not delete any event triggers.
103
+ """
104
+ if Context.block is not None:
105
+ try:
106
+ Context.block.children.remove(self)
107
+ except ValueError:
108
+ pass
109
+ if Context.root_block is not None:
110
+ try:
111
+ del Context.root_block.blocks[self._id]
112
+ except KeyError:
113
+ pass
114
+ return self
115
+
116
+ def get_block_name(self) -> str:
117
+ """
118
+ Gets block's class name.
119
+
120
+ If it is template component it gets the parent's class name.
121
+
122
+ @return: class name
123
+ """
124
+ return (
125
+ self.__class__.__base__.__name__.lower()
126
+ if hasattr(self, "is_template")
127
+ else self.__class__.__name__.lower()
128
+ )
129
+
130
+ def get_expected_parent(self) -> Type[BlockContext] | None:
131
+ return None
132
+
133
+ def set_event_trigger(
134
+ self,
135
+ event_name: str,
136
+ fn: Callable | None,
137
+ inputs: Component | List[Component] | Set[Component] | None,
138
+ outputs: Component | List[Component] | None,
139
+ preprocess: bool = True,
140
+ postprocess: bool = True,
141
+ scroll_to_output: bool = False,
142
+ show_progress: bool = True,
143
+ api_name: str | None = None,
144
+ js: str | None = None,
145
+ no_target: bool = False,
146
+ queue: bool | None = None,
147
+ batch: bool = False,
148
+ max_batch_size: int = 4,
149
+ cancels: List[int] | None = None,
150
+ every: float | None = None,
151
+ ) -> Dict[str, Any]:
152
+ """
153
+ Adds an event to the component's dependencies.
154
+ Parameters:
155
+ event_name: event name
156
+ fn: Callable function
157
+ inputs: input list
158
+ outputs: output list
159
+ preprocess: whether to run the preprocess methods of components
160
+ postprocess: whether to run the postprocess methods of components
161
+ scroll_to_output: whether to scroll to output of dependency on trigger
162
+ show_progress: whether to show progress animation while running.
163
+ api_name: Defining this parameter exposes the endpoint in the api docs
164
+ js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
165
+ no_target: if True, sets "targets" to [], used for Blocks "load" event
166
+ batch: whether this function takes in a batch of inputs
167
+ max_batch_size: the maximum batch size to send to the function
168
+ cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
169
+ Returns: None
170
+ """
171
+ # Support for singular parameter
172
+ if isinstance(inputs, set):
173
+ inputs_as_dict = True
174
+ inputs = sorted(inputs, key=lambda x: x._id)
175
+ else:
176
+ inputs_as_dict = False
177
+ if inputs is None:
178
+ inputs = []
179
+ elif not isinstance(inputs, list):
180
+ inputs = [inputs]
181
+
182
+ if isinstance(outputs, set):
183
+ outputs = sorted(outputs, key=lambda x: x._id)
184
+ else:
185
+ if outputs is None:
186
+ outputs = []
187
+ elif not isinstance(outputs, list):
188
+ outputs = [outputs]
189
+
190
+ if fn is not None and not cancels:
191
+ check_function_inputs_match(fn, inputs, inputs_as_dict)
192
+
193
+ if Context.root_block is None:
194
+ raise AttributeError(
195
+ f"{event_name}() and other events can only be called within a Blocks context."
196
+ )
197
+ if every is not None and every <= 0:
198
+ raise ValueError("Parameter every must be positive or None")
199
+ if every and batch:
200
+ raise ValueError(
201
+ f"Cannot run {event_name} event in a batch and every {every} seconds. "
202
+ "Either batch is True or every is non-zero but not both."
203
+ )
204
+
205
+ if every and fn:
206
+ fn = get_continuous_fn(fn, every)
207
+ elif every:
208
+ raise ValueError("Cannot set a value for `every` without a `fn`.")
209
+
210
+ Context.root_block.fns.append(
211
+ BlockFunction(fn, inputs, outputs, preprocess, postprocess, inputs_as_dict)
212
+ )
213
+ if api_name is not None:
214
+ api_name_ = utils.append_unique_suffix(
215
+ api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
216
+ )
217
+ if not (api_name == api_name_):
218
+ warnings.warn(
219
+ "api_name {} already exists, using {}".format(api_name, api_name_)
220
+ )
221
+ api_name = api_name_
222
+
223
+ dependency = {
224
+ "targets": [self._id] if not no_target else [],
225
+ "trigger": event_name,
226
+ "inputs": [block._id for block in inputs],
227
+ "outputs": [block._id for block in outputs],
228
+ "backend_fn": fn is not None,
229
+ "js": js,
230
+ "queue": False if fn is None else queue,
231
+ "api_name": api_name,
232
+ "scroll_to_output": scroll_to_output,
233
+ "show_progress": show_progress,
234
+ "every": every,
235
+ "batch": batch,
236
+ "max_batch_size": max_batch_size,
237
+ "cancels": cancels or [],
238
+ }
239
+ Context.root_block.dependencies.append(dependency)
240
+ return dependency
241
+
242
+ def get_config(self):
243
+ return {
244
+ "visible": self.visible,
245
+ "elem_id": self.elem_id,
246
+ "style": self._style,
247
+ "root_url": self.root_url,
248
+ }
249
+
250
+ @staticmethod
251
+ @abstractmethod
252
+ def update(**kwargs) -> Dict:
253
+ return {}
254
+
255
+ @classmethod
256
+ def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
257
+ del generic_update["__type__"]
258
+ specific_update = cls.update(**generic_update)
259
+ return specific_update
260
+
261
+
262
+ class BlockContext(Block):
263
+ def __init__(
264
+ self,
265
+ visible: bool = True,
266
+ render: bool = True,
267
+ **kwargs,
268
+ ):
269
+ """
270
+ Parameters:
271
+ visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
272
+ render: If False, this will not be included in the Blocks config file at all.
273
+ """
274
+ self.children: List[Block] = []
275
+ super().__init__(visible=visible, render=render, **kwargs)
276
+
277
+ def __enter__(self):
278
+ self.parent = Context.block
279
+ Context.block = self
280
+ return self
281
+
282
+ def add(self, child: Block):
283
+ child.parent = self
284
+ self.children.append(child)
285
+
286
+ def fill_expected_parents(self):
287
+ children = []
288
+ pseudo_parent = None
289
+ for child in self.children:
290
+ expected_parent = child.get_expected_parent()
291
+ if not expected_parent or isinstance(self, expected_parent):
292
+ pseudo_parent = None
293
+ children.append(child)
294
+ else:
295
+ if pseudo_parent is not None and isinstance(
296
+ pseudo_parent, expected_parent
297
+ ):
298
+ pseudo_parent.children.append(child)
299
+ else:
300
+ pseudo_parent = expected_parent(render=False)
301
+ children.append(pseudo_parent)
302
+ pseudo_parent.children = [child]
303
+ if Context.root_block:
304
+ Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
305
+ child.parent = pseudo_parent
306
+ self.children = children
307
+
308
+ def __exit__(self, *args):
309
+ if getattr(self, "allow_expected_parents", True):
310
+ self.fill_expected_parents()
311
+ Context.block = self.parent
312
+
313
+ def postprocess(self, y):
314
+ """
315
+ Any postprocessing needed to be performed on a block context.
316
+ """
317
+ return y
318
+
319
+
320
+ class BlockFunction:
321
+ def __init__(
322
+ self,
323
+ fn: Callable | None,
324
+ inputs: List[Component],
325
+ outputs: List[Component],
326
+ preprocess: bool,
327
+ postprocess: bool,
328
+ inputs_as_dict: bool,
329
+ ):
330
+ self.fn = fn
331
+ self.inputs = inputs
332
+ self.outputs = outputs
333
+ self.preprocess = preprocess
334
+ self.postprocess = postprocess
335
+ self.total_runtime = 0
336
+ self.total_runs = 0
337
+ self.inputs_as_dict = inputs_as_dict
338
+
339
+ def __str__(self):
340
+ return str(
341
+ {
342
+ "fn": getattr(self.fn, "__name__", "fn")
343
+ if self.fn is not None
344
+ else None,
345
+ "preprocess": self.preprocess,
346
+ "postprocess": self.postprocess,
347
+ }
348
+ )
349
+
350
+ def __repr__(self):
351
+ return str(self)
352
+
353
+
354
+ class class_or_instancemethod(classmethod):
355
+ def __get__(self, instance, type_):
356
+ descr_get = super().__get__ if instance is None else self.__func__.__get__
357
+ return descr_get(instance, type_)
358
+
359
+
360
+ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
361
+ """
362
+ Converts a dictionary of updates into a format that can be sent to the frontend.
363
+ E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
364
+ Into -> {"__type__": "update", "value": 2.0, "mode": "static"}
365
+
366
+ Parameters:
367
+ block: The Block that is being updated with this update dictionary.
368
+ update_dict: The original update dictionary
369
+ postprocess: Whether to postprocess the "value" key of the update dictionary.
370
+ """
371
+ if update_dict.get("__type__", "") == "generic_update":
372
+ update_dict = block.get_specific_update(update_dict)
373
+ if update_dict.get("value") is components._Keywords.NO_VALUE:
374
+ update_dict.pop("value")
375
+ prediction_value = delete_none(update_dict, skip_value=True)
376
+ if "value" in prediction_value and postprocess:
377
+ assert isinstance(
378
+ block, components.IOComponent
379
+ ), f"Component {block.__class__} does not support value"
380
+ prediction_value["value"] = block.postprocess(prediction_value["value"])
381
+ return prediction_value
382
+
383
+
384
+ def convert_component_dict_to_list(
385
+ outputs_ids: List[int], predictions: Dict
386
+ ) -> List | Dict:
387
+ """
388
+ Converts a dictionary of component updates into a list of updates in the order of
389
+ the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
390
+ E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
391
+ Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
392
+ """
393
+ keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
394
+ if all(keys_are_blocks):
395
+ reordered_predictions = [skip() for _ in outputs_ids]
396
+ for component, value in predictions.items():
397
+ if component._id not in outputs_ids:
398
+ raise ValueError(
399
+ f"Returned component {component} not specified as output of function."
400
+ )
401
+ output_index = outputs_ids.index(component._id)
402
+ reordered_predictions[output_index] = value
403
+ predictions = utils.resolve_singleton(reordered_predictions)
404
+ elif any(keys_are_blocks):
405
+ raise ValueError(
406
+ "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
407
+ )
408
+ return predictions
409
+
410
+
411
+ @document("load")
412
+ class Blocks(BlockContext):
413
+ """
414
+ Blocks is Gradio's low-level API that allows you to create more custom web
415
+ applications and demos than Interfaces (yet still entirely in Python).
416
+
417
+
418
+ Compared to the Interface class, Blocks offers more flexibility and control over:
419
+ (1) the layout of components (2) the events that
420
+ trigger the execution of functions (3) data flows (e.g. inputs can trigger outputs,
421
+ which can trigger the next level of outputs). Blocks also offers ways to group
422
+ together related demos such as with tabs.
423
+
424
+
425
+ The basic usage of Blocks is as follows: create a Blocks object, then use it as a
426
+ context (with the "with" statement), and then define layouts, components, or events
427
+ within the Blocks context. Finally, call the launch() method to launch the demo.
428
+
429
+ Example:
430
+ import gradio as gr
431
+ def update(name):
432
+ return f"Welcome to Gradio, {name}!"
433
+
434
+ with gr.Blocks() as demo:
435
+ gr.Markdown("Start typing below and then click **Run** to see the output.")
436
+ with gr.Row():
437
+ inp = gr.Textbox(placeholder="What is your name?")
438
+ out = gr.Textbox()
439
+ btn = gr.Button("Run")
440
+ btn.click(fn=update, inputs=inp, outputs=out)
441
+
442
+ demo.launch()
443
+ Demos: blocks_hello, blocks_flipper, blocks_speech_text_sentiment, generate_english_german, sound_alert
444
+ Guides: blocks_and_event_listeners, controlling_layout, state_in_blocks, custom_CSS_and_JS, custom_interpretations_with_blocks, using_blocks_like_functions
445
+ """
446
+
447
+ def __init__(
448
+ self,
449
+ theme: str = "default",
450
+ analytics_enabled: bool | None = None,
451
+ mode: str = "blocks",
452
+ title: str = "Gradio",
453
+ css: str | None = None,
454
+ **kwargs,
455
+ ):
456
+ """
457
+ Parameters:
458
+ theme: which theme to use - right now, only "default" is supported.
459
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
460
+ mode: a human-friendly name for the kind of Blocks or Interface being created.
461
+ title: The tab title to display when this is opened in a browser window.
462
+ css: custom css or path to custom css file to apply to entire Blocks
463
+ """
464
+ # Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
465
+ self.limiter = None
466
+ self.save_to = None
467
+ self.theme = theme
468
+ self.encrypt = False
469
+ self.share = False
470
+ self.enable_queue = None
471
+ self.max_threads = 40
472
+ self.show_error = True
473
+ if css is not None and Path(css).exists():
474
+ with open(css) as css_file:
475
+ self.css = css_file.read()
476
+ else:
477
+ self.css = css
478
+
479
+ # For analytics_enabled and allow_flagging: (1) first check for
480
+ # parameter, (2) check for env variable, (3) default to True/"manual"
481
+ self.analytics_enabled = (
482
+ analytics_enabled
483
+ if analytics_enabled is not None
484
+ else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
485
+ )
486
+
487
+ super().__init__(render=False, **kwargs)
488
+ self.blocks: Dict[int, Block] = {}
489
+ self.fns: List[BlockFunction] = []
490
+ self.dependencies = []
491
+ self.mode = mode
492
+
493
+ self.is_running = False
494
+ self.local_url = None
495
+ self.share_url = None
496
+ self.width = None
497
+ self.height = None
498
+ self.api_open = True
499
+
500
+ self.ip_address = ""
501
+ self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
502
+ self.favicon_path = None
503
+ self.auth = None
504
+ self.dev_mode = True
505
+ self.app_id = random.getrandbits(64)
506
+ self.temp_file_sets = []
507
+ self.title = title
508
+ self.show_api = True
509
+
510
+ # Only used when an Interface is loaded from a config
511
+ self.predict = None
512
+ self.input_components = None
513
+ self.output_components = None
514
+ self.__name__ = None
515
+ self.api_mode = None
516
+
517
+ if self.analytics_enabled:
518
+ self.ip_address = utils.get_local_ip_address()
519
+ data = {
520
+ "mode": self.mode,
521
+ "ip_address": self.ip_address,
522
+ "custom_css": self.css is not None,
523
+ "theme": self.theme,
524
+ "version": (pkgutil.get_data(__name__, "version.txt") or b"")
525
+ .decode("ascii")
526
+ .strip(),
527
+ }
528
+ utils.initiated_analytics(data)
529
+
530
+ @classmethod
531
+ def from_config(
532
+ cls, config: dict, fns: List[Callable], root_url: str | None = None
533
+ ) -> Blocks:
534
+ """
535
+ Factory method that creates a Blocks from a config and list of functions.
536
+
537
+ Parameters:
538
+ config: a dictionary containing the configuration of the Blocks.
539
+ fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config.
540
+ root_url: an optional root url to use for the components in the Blocks. Allows serving files from an external URL.
541
+ """
542
+ config = copy.deepcopy(config)
543
+ components_config = config["components"]
544
+ original_mapping: Dict[int, Block] = {}
545
+
546
+ def get_block_instance(id: int) -> Block:
547
+ for block_config in components_config:
548
+ if block_config["id"] == id:
549
+ break
550
+ else:
551
+ raise ValueError("Cannot find block with id {}".format(id))
552
+ cls = component_or_layout_class(block_config["type"])
553
+ block_config["props"].pop("type", None)
554
+ block_config["props"].pop("name", None)
555
+ style = block_config["props"].pop("style", None)
556
+ if block_config["props"].get("root_url") is None and root_url:
557
+ block_config["props"]["root_url"] = root_url + "/"
558
+ # Any component has already processed its initial value, so we skip that step here
559
+ block = cls(**block_config["props"], _skip_init_processing=True)
560
+ if style and isinstance(block, components.IOComponent):
561
+ block.style(**style)
562
+ return block
563
+
564
+ def iterate_over_children(children_list):
565
+ for child_config in children_list:
566
+ id = child_config["id"]
567
+ block = get_block_instance(id)
568
+
569
+ original_mapping[id] = block
570
+
571
+ children = child_config.get("children")
572
+ if children is not None:
573
+ assert isinstance(
574
+ block, BlockContext
575
+ ), f"Invalid config, Block with id {id} has children but is not a BlockContext."
576
+ with block:
577
+ iterate_over_children(children)
578
+
579
+ with Blocks(theme=config["theme"], css=config["theme"]) as blocks:
580
+ # ID 0 should be the root Blocks component
581
+ original_mapping[0] = Context.root_block or blocks
582
+
583
+ iterate_over_children(config["layout"]["children"])
584
+
585
+ first_dependency = None
586
+
587
+ # add the event triggers
588
+ for dependency, fn in zip(config["dependencies"], fns):
589
+ # We used to add a "fake_event" to the config to cache examples
590
+ # without removing it. This was causing bugs in calling gr.Interface.load
591
+ # We fixed the issue by removing "fake_event" from the config in examples.py
592
+ # but we still need to skip these events when loading the config to support
593
+ # older demos
594
+ if dependency["trigger"] == "fake_event":
595
+ continue
596
+ targets = dependency.pop("targets")
597
+ trigger = dependency.pop("trigger")
598
+ dependency.pop("backend_fn")
599
+ dependency.pop("documentation", None)
600
+ dependency["inputs"] = [
601
+ original_mapping[i] for i in dependency["inputs"]
602
+ ]
603
+ dependency["outputs"] = [
604
+ original_mapping[o] for o in dependency["outputs"]
605
+ ]
606
+ dependency.pop("status_tracker", None)
607
+ dependency["preprocess"] = False
608
+ dependency["postprocess"] = False
609
+
610
+ for target in targets:
611
+ dependency = original_mapping[target].set_event_trigger(
612
+ event_name=trigger, fn=fn, **dependency
613
+ )
614
+ if first_dependency is None:
615
+ first_dependency = dependency
616
+
617
+ # Allows some use of Interface-specific methods with loaded Spaces
618
+ if first_dependency and Context.root_block:
619
+ blocks.predict = [fns[0]]
620
+ blocks.input_components = [
621
+ Context.root_block.blocks[i] for i in first_dependency["inputs"]
622
+ ]
623
+ blocks.output_components = [
624
+ Context.root_block.blocks[o] for o in first_dependency["outputs"]
625
+ ]
626
+ blocks.__name__ = "Interface"
627
+ blocks.api_mode = True
628
+
629
+ return blocks
630
+
631
+ def __str__(self):
632
+ return self.__repr__()
633
+
634
+ def __repr__(self):
635
+ num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
636
+ repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
637
+ repr += "\n" + "-" * len(repr)
638
+ for d, dependency in enumerate(self.dependencies):
639
+ if dependency["backend_fn"]:
640
+ repr += f"\nfn_index={d}"
641
+ repr += "\n inputs:"
642
+ for input_id in dependency["inputs"]:
643
+ block = self.blocks[input_id]
644
+ repr += "\n |-{}".format(str(block))
645
+ repr += "\n outputs:"
646
+ for output_id in dependency["outputs"]:
647
+ block = self.blocks[output_id]
648
+ repr += "\n |-{}".format(str(block))
649
+ return repr
650
+
651
+ def render(self):
652
+ if Context.root_block is not None:
653
+ if self._id in Context.root_block.blocks:
654
+ raise DuplicateBlockError(
655
+ f"A block with id: {self._id} has already been rendered in the current Blocks."
656
+ )
657
+ if not set(Context.root_block.blocks).isdisjoint(self.blocks):
658
+ raise DuplicateBlockError(
659
+ "At least one block in this Blocks has already been rendered."
660
+ )
661
+
662
+ Context.root_block.blocks.update(self.blocks)
663
+ Context.root_block.fns.extend(self.fns)
664
+ dependency_offset = len(Context.root_block.dependencies)
665
+ for i, dependency in enumerate(self.dependencies):
666
+ api_name = dependency["api_name"]
667
+ if api_name is not None:
668
+ api_name_ = utils.append_unique_suffix(
669
+ api_name,
670
+ [dep["api_name"] for dep in Context.root_block.dependencies],
671
+ )
672
+ if not (api_name == api_name_):
673
+ warnings.warn(
674
+ "api_name {} already exists, using {}".format(
675
+ api_name, api_name_
676
+ )
677
+ )
678
+ dependency["api_name"] = api_name_
679
+ dependency["cancels"] = [
680
+ c + dependency_offset for c in dependency["cancels"]
681
+ ]
682
+ # Recreate the cancel function so that it has the latest
683
+ # dependency fn indices. This is necessary to properly cancel
684
+ # events in the backend
685
+ if dependency["cancels"]:
686
+ updated_cancels = [
687
+ Context.root_block.dependencies[i]
688
+ for i in dependency["cancels"]
689
+ ]
690
+ new_fn = BlockFunction(
691
+ get_cancel_function(updated_cancels)[0],
692
+ [],
693
+ [],
694
+ False,
695
+ True,
696
+ False,
697
+ )
698
+ Context.root_block.fns[dependency_offset + i] = new_fn
699
+ Context.root_block.dependencies.append(dependency)
700
+ Context.root_block.temp_file_sets.extend(self.temp_file_sets)
701
+
702
+ if Context.block is not None:
703
+ Context.block.children.extend(self.children)
704
+ return self
705
+
706
+ def is_callable(self, fn_index: int = 0) -> bool:
707
+ """Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
708
+ block_fn = self.fns[fn_index]
709
+ dependency = self.dependencies[fn_index]
710
+
711
+ if inspect.isasyncgenfunction(block_fn.fn):
712
+ return False
713
+ if inspect.isgeneratorfunction(block_fn.fn):
714
+ return False
715
+ for input_id in dependency["inputs"]:
716
+ block = self.blocks[input_id]
717
+ if getattr(block, "stateful", False):
718
+ return False
719
+ for output_id in dependency["outputs"]:
720
+ block = self.blocks[output_id]
721
+ if getattr(block, "stateful", False):
722
+ return False
723
+
724
+ return True
725
+
726
+ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
727
+ """
728
+ Allows Blocks objects to be called as functions. Supply the parameters to the
729
+ function as positional arguments. To choose which function to call, use the
730
+ fn_index parameter, which must be a keyword argument.
731
+
732
+ Parameters:
733
+ *inputs: the parameters to pass to the function
734
+ fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
735
+ api_name: The api_name of the dependency to call. Will take precedence over fn_index.
736
+ """
737
+ if api_name is not None:
738
+ inferred_fn_index = next(
739
+ (
740
+ i
741
+ for i, d in enumerate(self.dependencies)
742
+ if d.get("api_name") == api_name
743
+ ),
744
+ None,
745
+ )
746
+ if inferred_fn_index is None:
747
+ raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
748
+ fn_index = inferred_fn_index
749
+ if not (self.is_callable(fn_index)):
750
+ raise ValueError(
751
+ "This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
752
+ )
753
+
754
+ inputs = list(inputs)
755
+ processed_inputs = self.serialize_data(fn_index, inputs)
756
+ batch = self.dependencies[fn_index]["batch"]
757
+ if batch:
758
+ processed_inputs = [[inp] for inp in processed_inputs]
759
+
760
+ outputs = utils.synchronize_async(
761
+ self.process_api,
762
+ fn_index=fn_index,
763
+ inputs=processed_inputs,
764
+ request=None,
765
+ state={},
766
+ )
767
+ outputs = outputs["data"]
768
+
769
+ if batch:
770
+ outputs = [out[0] for out in outputs]
771
+
772
+ processed_outputs = self.deserialize_data(fn_index, outputs)
773
+ processed_outputs = utils.resolve_singleton(processed_outputs)
774
+
775
+ return processed_outputs
776
+
777
+ async def call_function(
778
+ self,
779
+ fn_index: int,
780
+ processed_input: List[Any],
781
+ iterator: Iterator[Any] | None = None,
782
+ requests: routes.Request | List[routes.Request] | None = None,
783
+ event_id: str | None = None,
784
+ ):
785
+ """
786
+ Calls function with given index and preprocessed input, and measures process time.
787
+ Parameters:
788
+ fn_index: index of function to call
789
+ processed_input: preprocessed input to pass to function
790
+ iterator: iterator to use if function is a generator
791
+ requests: requests to pass to function
792
+ event_id: id of event in queue
793
+ """
794
+ block_fn = self.fns[fn_index]
795
+ assert block_fn.fn, f"function with index {fn_index} not defined."
796
+ is_generating = False
797
+
798
+ if block_fn.inputs_as_dict:
799
+ processed_input = [
800
+ {
801
+ input_component: data
802
+ for input_component, data in zip(block_fn.inputs, processed_input)
803
+ }
804
+ ]
805
+
806
+ if isinstance(requests, list):
807
+ request = requests[0]
808
+ else:
809
+ request = requests
810
+ processed_input, progress_index = special_args(
811
+ block_fn.fn,
812
+ processed_input,
813
+ request,
814
+ )
815
+ progress_tracker = (
816
+ processed_input[progress_index] if progress_index is not None else None
817
+ )
818
+
819
+ start = time.time()
820
+
821
+ if iterator is None: # If not a generator function that has already run
822
+ if progress_tracker is not None and progress_index is not None:
823
+ progress_tracker, fn = create_tracker(
824
+ self, event_id, block_fn.fn, progress_tracker.track_tqdm
825
+ )
826
+ processed_input[progress_index] = progress_tracker
827
+ else:
828
+ fn = block_fn.fn
829
+
830
+ if inspect.iscoroutinefunction(fn):
831
+ prediction = await fn(*processed_input)
832
+ else:
833
+ prediction = await anyio.to_thread.run_sync(
834
+ fn, *processed_input, limiter=self.limiter
835
+ )
836
+ else:
837
+ prediction = None
838
+
839
+ if inspect.isasyncgenfunction(block_fn.fn):
840
+ raise ValueError("Gradio does not support async generators.")
841
+ if inspect.isgeneratorfunction(block_fn.fn):
842
+ if not self.enable_queue:
843
+ raise ValueError("Need to enable queue to use generators.")
844
+ try:
845
+ if iterator is None:
846
+ iterator = prediction
847
+ prediction = await anyio.to_thread.run_sync(
848
+ utils.async_iteration, iterator, limiter=self.limiter
849
+ )
850
+ is_generating = True
851
+ except StopAsyncIteration:
852
+ n_outputs = len(self.dependencies[fn_index].get("outputs"))
853
+ prediction = (
854
+ components._Keywords.FINISHED_ITERATING
855
+ if n_outputs == 1
856
+ else (components._Keywords.FINISHED_ITERATING,) * n_outputs
857
+ )
858
+ iterator = None
859
+
860
+ duration = time.time() - start
861
+
862
+ return {
863
+ "prediction": prediction,
864
+ "duration": duration,
865
+ "is_generating": is_generating,
866
+ "iterator": iterator,
867
+ }
868
+
869
+ def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
870
+ dependency = self.dependencies[fn_index]
871
+ processed_input = []
872
+
873
+ for i, input_id in enumerate(dependency["inputs"]):
874
+ block = self.blocks[input_id]
875
+ assert isinstance(
876
+ block, components.IOComponent
877
+ ), f"{block.__class__} Component with id {input_id} not a valid input component."
878
+ serialized_input = block.serialize(inputs[i])
879
+ processed_input.append(serialized_input)
880
+
881
+ return processed_input
882
+
883
+ def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
884
+ dependency = self.dependencies[fn_index]
885
+ predictions = []
886
+
887
+ for o, output_id in enumerate(dependency["outputs"]):
888
+ block = self.blocks[output_id]
889
+ assert isinstance(
890
+ block, components.IOComponent
891
+ ), f"{block.__class__} Component with id {output_id} not a valid output component."
892
+ deserialized = block.deserialize(outputs[o])
893
+ predictions.append(deserialized)
894
+
895
+ return predictions
896
+
897
+ def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
898
+ block_fn = self.fns[fn_index]
899
+ dependency = self.dependencies[fn_index]
900
+
901
+ if block_fn.preprocess:
902
+ processed_input = []
903
+ for i, input_id in enumerate(dependency["inputs"]):
904
+ block = self.blocks[input_id]
905
+ assert isinstance(
906
+ block, components.Component
907
+ ), f"{block.__class__} Component with id {input_id} not a valid input component."
908
+ if getattr(block, "stateful", False):
909
+ processed_input.append(state.get(input_id))
910
+ else:
911
+ processed_input.append(block.preprocess(inputs[i]))
912
+ else:
913
+ processed_input = inputs
914
+ return processed_input
915
+
916
+ def postprocess_data(
917
+ self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
918
+ ):
919
+ block_fn = self.fns[fn_index]
920
+ dependency = self.dependencies[fn_index]
921
+ batch = dependency["batch"]
922
+
923
+ if type(predictions) is dict and len(predictions) > 0:
924
+ predictions = convert_component_dict_to_list(
925
+ dependency["outputs"], predictions
926
+ )
927
+
928
+ if len(dependency["outputs"]) == 1 and not (batch):
929
+ predictions = [
930
+ predictions,
931
+ ]
932
+
933
+ output = []
934
+ for i, output_id in enumerate(dependency["outputs"]):
935
+ if predictions[i] is components._Keywords.FINISHED_ITERATING:
936
+ output.append(None)
937
+ continue
938
+ block = self.blocks[output_id]
939
+ if getattr(block, "stateful", False):
940
+ if not utils.is_update(predictions[i]):
941
+ state[output_id] = predictions[i]
942
+ output.append(None)
943
+ else:
944
+ prediction_value = predictions[i]
945
+ if utils.is_update(prediction_value):
946
+ assert isinstance(prediction_value, dict)
947
+ prediction_value = postprocess_update_dict(
948
+ block=block,
949
+ update_dict=prediction_value,
950
+ postprocess=block_fn.postprocess,
951
+ )
952
+ elif block_fn.postprocess:
953
+ assert isinstance(
954
+ block, components.Component
955
+ ), f"{block.__class__} Component with id {output_id} not a valid output component."
956
+ prediction_value = block.postprocess(prediction_value)
957
+ output.append(prediction_value)
958
+ return output
959
+
960
+ async def process_api(
961
+ self,
962
+ fn_index: int,
963
+ inputs: List[Any],
964
+ state: Dict[int, Any],
965
+ request: routes.Request | List[routes.Request] | None = None,
966
+ iterators: Dict[int, Any] | None = None,
967
+ event_id: str | None = None,
968
+ ) -> Dict[str, Any]:
969
+ """
970
+ Processes API calls from the frontend. First preprocesses the data,
971
+ then runs the relevant function, then postprocesses the output.
972
+ Parameters:
973
+ fn_index: Index of function to run.
974
+ inputs: input data received from the frontend
975
+ username: name of user if authentication is set up (not used)
976
+ state: data stored from stateful components for session (key is input block id)
977
+ iterators: the in-progress iterators for each generator function (key is function index)
978
+ Returns: None
979
+ """
980
+ block_fn = self.fns[fn_index]
981
+ batch = self.dependencies[fn_index]["batch"]
982
+
983
+ if batch:
984
+ max_batch_size = self.dependencies[fn_index]["max_batch_size"]
985
+ batch_sizes = [len(inp) for inp in inputs]
986
+ batch_size = batch_sizes[0]
987
+ if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
988
+ block_fn.fn
989
+ ):
990
+ raise ValueError("Gradio does not support generators in batch mode.")
991
+ if not all(x == batch_size for x in batch_sizes):
992
+ raise ValueError(
993
+ f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
994
+ )
995
+ if batch_size > max_batch_size:
996
+ raise ValueError(
997
+ f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
998
+ )
999
+
1000
+ inputs = [
1001
+ self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
1002
+ ]
1003
+ result = await self.call_function(
1004
+ fn_index, list(zip(*inputs)), None, request
1005
+ )
1006
+ preds = result["prediction"]
1007
+ data = [
1008
+ self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
1009
+ ]
1010
+ data = list(zip(*data))
1011
+ is_generating, iterator = None, None
1012
+ else:
1013
+ inputs = self.preprocess_data(fn_index, inputs, state)
1014
+ iterator = iterators.get(fn_index, None) if iterators else None
1015
+ result = await self.call_function(
1016
+ fn_index, inputs, iterator, request, event_id
1017
+ )
1018
+ data = self.postprocess_data(fn_index, result["prediction"], state)
1019
+ is_generating, iterator = result["is_generating"], result["iterator"]
1020
+
1021
+ block_fn.total_runtime += result["duration"]
1022
+ block_fn.total_runs += 1
1023
+
1024
+ return {
1025
+ "data": data,
1026
+ "is_generating": is_generating,
1027
+ "iterator": iterator,
1028
+ "duration": result["duration"],
1029
+ "average_duration": block_fn.total_runtime / block_fn.total_runs,
1030
+ }
1031
+
1032
+ async def create_limiter(self):
1033
+ self.limiter = (
1034
+ None
1035
+ if self.max_threads == 40
1036
+ else CapacityLimiter(total_tokens=self.max_threads)
1037
+ )
1038
+
1039
+ def get_config(self):
1040
+ return {"type": "column"}
1041
+
1042
+ def get_config_file(self):
1043
+ config = {
1044
+ "version": routes.VERSION,
1045
+ "mode": self.mode,
1046
+ "dev_mode": self.dev_mode,
1047
+ "components": [],
1048
+ "theme": self.theme,
1049
+ "css": self.css,
1050
+ "title": self.title or "Gradio",
1051
+ "is_space": self.is_space,
1052
+ "enable_queue": getattr(self, "enable_queue", False), # launch attributes
1053
+ "show_error": getattr(self, "show_error", False),
1054
+ "show_api": self.show_api,
1055
+ "is_colab": utils.colab_check(),
1056
+ }
1057
+
1058
+ def getLayout(block):
1059
+ if not isinstance(block, BlockContext):
1060
+ return {"id": block._id}
1061
+ children_layout = []
1062
+ for child in block.children:
1063
+ children_layout.append(getLayout(child))
1064
+ return {"id": block._id, "children": children_layout}
1065
+
1066
+ config["layout"] = getLayout(self)
1067
+
1068
+ for _id, block in self.blocks.items():
1069
+ config["components"].append(
1070
+ {
1071
+ "id": _id,
1072
+ "type": (block.get_block_name()),
1073
+ "props": utils.delete_none(block.get_config())
1074
+ if hasattr(block, "get_config")
1075
+ else {},
1076
+ }
1077
+ )
1078
+ config["dependencies"] = self.dependencies
1079
+ return config
1080
+
1081
+ def __enter__(self):
1082
+ if Context.block is None:
1083
+ Context.root_block = self
1084
+ self.parent = Context.block
1085
+ Context.block = self
1086
+ return self
1087
+
1088
+ def __exit__(self, *args):
1089
+ super().fill_expected_parents()
1090
+ Context.block = self.parent
1091
+ # Configure the load events before root_block is reset
1092
+ self.attach_load_events()
1093
+ if self.parent is None:
1094
+ Context.root_block = None
1095
+ else:
1096
+ self.parent.children.extend(self.children)
1097
+ self.config = self.get_config_file()
1098
+ self.app = routes.App.create_app(self)
1099
+
1100
+ @class_or_instancemethod
1101
+ def load(
1102
+ self_or_cls,
1103
+ fn: Callable | None = None,
1104
+ inputs: List[Component] | None = None,
1105
+ outputs: List[Component] | None = None,
1106
+ api_name: str | None = None,
1107
+ scroll_to_output: bool = False,
1108
+ show_progress: bool = True,
1109
+ queue=None,
1110
+ batch: bool = False,
1111
+ max_batch_size: int = 4,
1112
+ preprocess: bool = True,
1113
+ postprocess: bool = True,
1114
+ every: float | None = None,
1115
+ _js: str | None = None,
1116
+ *,
1117
+ name: str | None = None,
1118
+ src: str | None = None,
1119
+ api_key: str | None = None,
1120
+ alias: str | None = None,
1121
+ **kwargs,
1122
+ ) -> Blocks | Dict[str, Any] | None:
1123
+ """
1124
+ For reverse compatibility reasons, this is both a class method and an instance
1125
+ method, the two of which, confusingly, do two completely different things.
1126
+
1127
+
1128
+ Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Equivalent to gradio.Interface.load()
1129
+
1130
+
1131
+ Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
1132
+ Parameters:
1133
+ name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
1134
+ src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
1135
+ api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
1136
+ alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
1137
+ fn: Instance Method - the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
1138
+ inputs: Instance Method - List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
1139
+ outputs: Instance Method - List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
1140
+ api_name: Instance Method - Defining this parameter exposes the endpoint in the api docs
1141
+ scroll_to_output: Instance Method - If True, will scroll to output component on completion
1142
+ show_progress: Instance Method - If True, will show progress animation while pending
1143
+ queue: Instance Method - If True, will place the request on the queue, if the queue exists
1144
+ batch: Instance Method - If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
1145
+ max_batch_size: Instance Method - Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
1146
+ preprocess: Instance Method - If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
1147
+ postprocess: Instance Method - If False, will not run postprocessing of component data before returning 'fn' output to the browser.
1148
+ every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
1149
+ Example:
1150
+ import gradio as gr
1151
+ import datetime
1152
+ with gr.Blocks() as demo:
1153
+ def get_time():
1154
+ return datetime.datetime.now().time()
1155
+ dt = gr.Textbox(label="Current time")
1156
+ demo.load(get_time, inputs=None, outputs=dt)
1157
+ demo.launch()
1158
+ """
1159
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
1160
+ if isinstance(self_or_cls, type):
1161
+ if name is None:
1162
+ raise ValueError(
1163
+ "Blocks.load() requires passing parameters as keyword arguments"
1164
+ )
1165
+ return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs)
1166
+ else:
1167
+ return self_or_cls.set_event_trigger(
1168
+ event_name="load",
1169
+ fn=fn,
1170
+ inputs=inputs,
1171
+ outputs=outputs,
1172
+ api_name=api_name,
1173
+ preprocess=preprocess,
1174
+ postprocess=postprocess,
1175
+ scroll_to_output=scroll_to_output,
1176
+ show_progress=show_progress,
1177
+ js=_js,
1178
+ queue=queue,
1179
+ batch=batch,
1180
+ max_batch_size=max_batch_size,
1181
+ every=every,
1182
+ no_target=True,
1183
+ )
1184
+
1185
+ def clear(self):
1186
+ """Resets the layout of the Blocks object."""
1187
+ self.blocks = {}
1188
+ self.fns = []
1189
+ self.dependencies = []
1190
+ self.children = []
1191
+ return self
1192
+
1193
+ @document()
1194
+ def queue(
1195
+ self,
1196
+ concurrency_count: int = 1,
1197
+ status_update_rate: float | Literal["auto"] = "auto",
1198
+ client_position_to_load_data: int | None = None,
1199
+ default_enabled: bool | None = None,
1200
+ api_open: bool = True,
1201
+ max_size: int | None = None,
1202
+ ):
1203
+ """
1204
+ You can control the rate of processed requests by creating a queue. This will allow you to set the number of requests to be processed at one time, and will let users know their position in the queue.
1205
+ Parameters:
1206
+ concurrency_count: Number of worker threads that will be processing requests from the queue concurrently. Increasing this number will increase the rate at which requests are processed, but will also increase the memory usage of the queue.
1207
+ status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
1208
+ client_position_to_load_data: DEPRECATED. This parameter is deprecated and has no effect.
1209
+ default_enabled: Deprecated and has no effect.
1210
+ api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
1211
+ max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
1212
+ Example:
1213
+ demo = gr.Interface(gr.Textbox(), gr.Image(), image_generator)
1214
+ demo.queue(concurrency_count=3)
1215
+ demo.launch()
1216
+ """
1217
+ if default_enabled is not None:
1218
+ warnings.warn(
1219
+ "The default_enabled parameter of queue has no effect and will be removed "
1220
+ "in a future version of gradio."
1221
+ )
1222
+ self.enable_queue = True
1223
+ self.api_open = api_open
1224
+ if client_position_to_load_data is not None:
1225
+ warnings.warn("The client_position_to_load_data parameter is deprecated.")
1226
+ self._queue = queueing.Queue(
1227
+ live_updates=status_update_rate == "auto",
1228
+ concurrency_count=concurrency_count,
1229
+ update_intervals=status_update_rate if status_update_rate != "auto" else 1,
1230
+ max_size=max_size,
1231
+ blocks_dependencies=self.dependencies,
1232
+ )
1233
+ self.config = self.get_config_file()
1234
+ return self
1235
+
1236
+ def launch(
1237
+ self,
1238
+ inline: bool | None = None,
1239
+ inbrowser: bool = False,
1240
+ share: bool | None = None,
1241
+ debug: bool = False,
1242
+ enable_queue: bool | None = None,
1243
+ max_threads: int = 40,
1244
+ auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
1245
+ auth_message: str | None = None,
1246
+ prevent_thread_lock: bool = False,
1247
+ show_error: bool = False,
1248
+ server_name: str | None = None,
1249
+ server_port: int | None = None,
1250
+ show_tips: bool = False,
1251
+ height: int = 500,
1252
+ width: int | str = "100%",
1253
+ encrypt: bool = False,
1254
+ favicon_path: str | None = None,
1255
+ ssl_keyfile: str | None = None,
1256
+ ssl_certfile: str | None = None,
1257
+ ssl_keyfile_password: str | None = None,
1258
+ quiet: bool = False,
1259
+ show_api: bool = True,
1260
+ _frontend: bool = True,
1261
+ ) -> Tuple[FastAPI, str, str]:
1262
+ """
1263
+ Launches a simple web server that serves the demo. Can also be used to create a
1264
+ public link used by anyone to access the demo from their browser by setting share=True.
1265
+
1266
+ Parameters:
1267
+ inline: whether to display in the interface inline in an iframe. Defaults to True in python notebooks; False otherwise.
1268
+ inbrowser: whether to automatically launch the interface in a new tab on the default browser.
1269
+ share: whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. If not provided, it is set to False by default every time, except when running in Google Colab. When localhost is not accessible (e.g. Google Colab), setting share=False is not supported.
1270
+ debug: if True, blocks the main thread from running. If running in Google Colab, this is needed to print the errors in the cell output.
1271
+ auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
1272
+ auth_message: If provided, HTML message provided on login page.
1273
+ prevent_thread_lock: If True, the interface will block the main thread while the server is running.
1274
+ show_error: If True, any errors in the interface will be displayed in an alert modal and printed in the browser console log
1275
+ server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. If None, will search for an available port starting at 7860.
1276
+ server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
1277
+ show_tips: if True, will occasionally show tips about new Gradio features
1278
+ enable_queue: DEPRECATED (use .queue() method instead.) if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
1279
+ max_threads: the maximum number of total threads that the Gradio app can generate in parallel. The default is inherited from the starlette library (currently 40). Applies whether the queue is enabled or not. But if queuing is enabled, this parameter is increaseed to be at least the concurrency_count of the queue.
1280
+ width: The width in pixels of the iframe element containing the interface (used if inline=True)
1281
+ height: The height in pixels of the iframe element containing the interface (used if inline=True)
1282
+ encrypt: If True, flagged data will be encrypted by key provided by creator at launch
1283
+ favicon_path: If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
1284
+ ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
1285
+ ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
1286
+ ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
1287
+ quiet: If True, suppresses most print statements.
1288
+ show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
1289
+ Returns:
1290
+ app: FastAPI app object that is running the demo
1291
+ local_url: Locally accessible link to the demo
1292
+ share_url: Publicly accessible link to the demo (if share=True, otherwise None)
1293
+ Example:
1294
+ import gradio as gr
1295
+ def reverse(text):
1296
+ return text[::-1]
1297
+ demo = gr.Interface(reverse, "text", "text")
1298
+ demo.launch(share=True, auth=("username", "password"))
1299
+ """
1300
+ self.dev_mode = False
1301
+ if (
1302
+ auth
1303
+ and not callable(auth)
1304
+ and not isinstance(auth[0], tuple)
1305
+ and not isinstance(auth[0], list)
1306
+ ):
1307
+ self.auth = [auth]
1308
+ else:
1309
+ self.auth = auth
1310
+ self.auth_message = auth_message
1311
+ self.show_tips = show_tips
1312
+ self.show_error = show_error
1313
+ self.height = height
1314
+ self.width = width
1315
+ self.favicon_path = favicon_path
1316
+ self.progress_tracking = any(
1317
+ block_fn.fn is not None and special_args(block_fn.fn)[1] is not None
1318
+ for block_fn in self.fns
1319
+ )
1320
+
1321
+ if enable_queue is not None:
1322
+ self.enable_queue = enable_queue
1323
+ warnings.warn(
1324
+ "The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
1325
+ DeprecationWarning,
1326
+ )
1327
+
1328
+ if self.is_space:
1329
+ self.enable_queue = self.enable_queue is not False
1330
+ else:
1331
+ self.enable_queue = self.enable_queue is True
1332
+ if self.enable_queue and not hasattr(self, "_queue"):
1333
+ self.queue()
1334
+ self.show_api = self.api_open if self.enable_queue else show_api
1335
+
1336
+ if not self.enable_queue and self.progress_tracking:
1337
+ raise ValueError("Progress tracking requires queuing to be enabled.")
1338
+
1339
+ for dep in self.dependencies:
1340
+ for i in dep["cancels"]:
1341
+ if not self.queue_enabled_for_fn(i):
1342
+ raise ValueError(
1343
+ "In order to cancel an event, the queue for that event must be enabled! "
1344
+ "You may get this error by either 1) passing a function that uses the yield keyword "
1345
+ "into an interface without enabling the queue or 2) defining an event that cancels "
1346
+ "another event without enabling the queue. Both can be solved by calling .queue() "
1347
+ "before .launch()"
1348
+ )
1349
+ if dep["batch"] and (
1350
+ dep["queue"] is False
1351
+ or (dep["queue"] is None and not self.enable_queue)
1352
+ ):
1353
+ raise ValueError("In order to use batching, the queue must be enabled.")
1354
+
1355
+ self.config = self.get_config_file()
1356
+ self.encrypt = encrypt
1357
+ self.max_threads = max(
1358
+ self._queue.max_thread_count if self.enable_queue else 0, max_threads
1359
+ )
1360
+ if self.encrypt:
1361
+ self.encryption_key = encryptor.get_key(
1362
+ getpass.getpass("Enter key for encryption: ")
1363
+ )
1364
+
1365
+ if self.is_running:
1366
+ assert isinstance(
1367
+ self.local_url, str
1368
+ ), f"Invalid local_url: {self.local_url}"
1369
+ if not (quiet):
1370
+ print(
1371
+ "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
1372
+ )
1373
+ else:
1374
+ server_name, server_port, local_url, app, server = networking.start_server(
1375
+ self,
1376
+ server_name,
1377
+ server_port,
1378
+ ssl_keyfile,
1379
+ ssl_certfile,
1380
+ ssl_keyfile_password,
1381
+ )
1382
+ self.server_name = server_name
1383
+ self.local_url = local_url
1384
+ self.server_port = server_port
1385
+ self.server_app = app
1386
+ self.server = server
1387
+ self.is_running = True
1388
+ self.is_colab = utils.colab_check()
1389
+ self.protocol = (
1390
+ "https"
1391
+ if self.local_url.startswith("https") or self.is_colab
1392
+ else "http"
1393
+ )
1394
+
1395
+ if self.enable_queue:
1396
+ self._queue.set_url(self.local_url)
1397
+
1398
+ # Cannot run async functions in background other than app's scope.
1399
+ # Workaround by triggering the app endpoint
1400
+ requests.get(f"{self.local_url}startup-events")
1401
+
1402
+ if self.enable_queue:
1403
+ if self.encrypt:
1404
+ raise ValueError("Cannot queue with encryption enabled.")
1405
+ utils.launch_counter()
1406
+
1407
+ self.share = (
1408
+ share
1409
+ if share is not None
1410
+ else True
1411
+ if self.is_colab and self.enable_queue
1412
+ else False
1413
+ )
1414
+
1415
+ # If running in a colab or not able to access localhost,
1416
+ # a shareable link must be created.
1417
+ if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
1418
+ raise ValueError(
1419
+ "When localhost is not accessible, a shareable link must be created. Please set share=True."
1420
+ )
1421
+
1422
+ if self.is_colab:
1423
+ if not quiet:
1424
+ if debug:
1425
+ print(strings.en["COLAB_DEBUG_TRUE"])
1426
+ else:
1427
+ print(strings.en["COLAB_DEBUG_FALSE"])
1428
+ if not self.share:
1429
+ print(strings.en["COLAB_WARNING"].format(self.server_port))
1430
+ if self.enable_queue and not self.share:
1431
+ raise ValueError(
1432
+ "When using queueing in Colab, a shareable link must be created. Please set share=True."
1433
+ )
1434
+ else:
1435
+ print(
1436
+ strings.en["RUNNING_LOCALLY_SEPARATED"].format(
1437
+ self.protocol, self.server_name, self.server_port
1438
+ )
1439
+ )
1440
+
1441
+ if self.share:
1442
+ if self.is_space:
1443
+ raise RuntimeError("Share is not supported when you are in Spaces")
1444
+ try:
1445
+ if self.share_url is None:
1446
+ self.share_url = networking.setup_tunnel(
1447
+ self.server_name, self.server_port
1448
+ )
1449
+ print(strings.en["SHARE_LINK_DISPLAY"].format(self.share_url))
1450
+ if not (quiet):
1451
+ print(strings.en["SHARE_LINK_MESSAGE"])
1452
+ except RuntimeError:
1453
+ if self.analytics_enabled:
1454
+ utils.error_analytics(self.ip_address, "Not able to set up tunnel")
1455
+ self.share_url = None
1456
+ self.share = False
1457
+ print(strings.en["COULD_NOT_GET_SHARE_LINK"])
1458
+ else:
1459
+ if not (quiet):
1460
+ print(strings.en["PUBLIC_SHARE_TRUE"])
1461
+ self.share_url = None
1462
+
1463
+ if inbrowser:
1464
+ link = self.share_url if self.share and self.share_url else self.local_url
1465
+ webbrowser.open(link)
1466
+
1467
+ # Check if running in a Python notebook in which case, display inline
1468
+ if inline is None:
1469
+ inline = utils.ipython_check() and (self.auth is None)
1470
+ if inline:
1471
+ if self.auth is not None:
1472
+ print(
1473
+ "Warning: authentication is not supported inline. Please"
1474
+ "click the link to access the interface in a new tab."
1475
+ )
1476
+ try:
1477
+ from IPython.display import HTML, Javascript, display # type: ignore
1478
+
1479
+ if self.share and self.share_url:
1480
+ while not networking.url_ok(self.share_url):
1481
+ time.sleep(0.25)
1482
+ display(
1483
+ HTML(
1484
+ f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
1485
+ )
1486
+ )
1487
+ elif self.is_colab:
1488
+ # modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
1489
+ code = """(async (port, path, width, height, cache, element) => {
1490
+ if (!google.colab.kernel.accessAllowed && !cache) {
1491
+ return;
1492
+ }
1493
+ element.appendChild(document.createTextNode(''));
1494
+ const url = await google.colab.kernel.proxyPort(port, {cache});
1495
+
1496
+ const external_link = document.createElement('div');
1497
+ external_link.innerHTML = `
1498
+ <div style="font-family: monospace; margin-bottom: 0.5rem">
1499
+ Running on <a href=${new URL(path, url).toString()} target="_blank">
1500
+ https://localhost:${port}${path}
1501
+ </a>
1502
+ </div>
1503
+ `;
1504
+ element.appendChild(external_link);
1505
+
1506
+ const iframe = document.createElement('iframe');
1507
+ iframe.src = new URL(path, url).toString();
1508
+ iframe.height = height;
1509
+ iframe.allow = "autoplay; camera; microphone; clipboard-read; clipboard-write;"
1510
+ iframe.width = width;
1511
+ iframe.style.border = 0;
1512
+ element.appendChild(iframe);
1513
+ })""" + "({port}, {path}, {width}, {height}, {cache}, window.element)".format(
1514
+ port=json.dumps(self.server_port),
1515
+ path=json.dumps("/"),
1516
+ width=json.dumps(self.width),
1517
+ height=json.dumps(self.height),
1518
+ cache=json.dumps(False),
1519
+ )
1520
+
1521
+ display(Javascript(code))
1522
+ else:
1523
+ display(
1524
+ HTML(
1525
+ f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
1526
+ )
1527
+ )
1528
+ except ImportError:
1529
+ pass
1530
+
1531
+ if getattr(self, "analytics_enabled", False):
1532
+ data = {
1533
+ "launch_method": "browser" if inbrowser else "inline",
1534
+ "is_google_colab": self.is_colab,
1535
+ "is_sharing_on": self.share,
1536
+ "share_url": self.share_url,
1537
+ "ip_address": self.ip_address,
1538
+ "enable_queue": self.enable_queue,
1539
+ "show_tips": self.show_tips,
1540
+ "server_name": server_name,
1541
+ "server_port": server_port,
1542
+ "is_spaces": self.is_space,
1543
+ "mode": self.mode,
1544
+ }
1545
+ utils.launch_analytics(data)
1546
+
1547
+ utils.show_tip(self)
1548
+
1549
+ # Block main thread if debug==True
1550
+ if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
1551
+ self.block_thread()
1552
+ # Block main thread if running in a script to stop script from exiting
1553
+ is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
1554
+
1555
+ if not prevent_thread_lock and not is_in_interactive_mode:
1556
+ self.block_thread()
1557
+
1558
+ return TupleNoPrint((self.server_app, self.local_url, self.share_url))
1559
+
1560
+ def integrate(
1561
+ self,
1562
+ comet_ml: comet_ml.Experiment | None = None,
1563
+ wandb: ModuleType | None = None,
1564
+ mlflow: ModuleType | None = None,
1565
+ ) -> None:
1566
+ """
1567
+ A catch-all method for integrating with other libraries. This method should be run after launch()
1568
+ Parameters:
1569
+ comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
1570
+ wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
1571
+ mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
1572
+ """
1573
+ analytics_integration = ""
1574
+ if comet_ml is not None:
1575
+ analytics_integration = "CometML"
1576
+ comet_ml.log_other("Created from", "Gradio")
1577
+ if self.share_url is not None:
1578
+ comet_ml.log_text("gradio: " + self.share_url)
1579
+ comet_ml.end()
1580
+ elif self.local_url:
1581
+ comet_ml.log_text("gradio: " + self.local_url)
1582
+ comet_ml.end()
1583
+ else:
1584
+ raise ValueError("Please run `launch()` first.")
1585
+ if wandb is not None:
1586
+ analytics_integration = "WandB"
1587
+ if self.share_url is not None:
1588
+ wandb.log(
1589
+ {
1590
+ "Gradio panel": wandb.Html(
1591
+ '<iframe src="'
1592
+ + self.share_url
1593
+ + '" width="'
1594
+ + str(self.width)
1595
+ + '" height="'
1596
+ + str(self.height)
1597
+ + '" frameBorder="0"></iframe>'
1598
+ )
1599
+ }
1600
+ )
1601
+ else:
1602
+ print(
1603
+ "The WandB integration requires you to "
1604
+ "`launch(share=True)` first."
1605
+ )
1606
+ if mlflow is not None:
1607
+ analytics_integration = "MLFlow"
1608
+ if self.share_url is not None:
1609
+ mlflow.log_param("Gradio Interface Share Link", self.share_url)
1610
+ else:
1611
+ mlflow.log_param("Gradio Interface Local Link", self.local_url)
1612
+ if self.analytics_enabled and analytics_integration:
1613
+ data = {"integration": analytics_integration}
1614
+ utils.integration_analytics(data)
1615
+
1616
+ def close(self, verbose: bool = True) -> None:
1617
+ """
1618
+ Closes the Interface that was launched and frees the port.
1619
+ """
1620
+ try:
1621
+ if self.enable_queue:
1622
+ self._queue.close()
1623
+ self.server.close()
1624
+ self.is_running = False
1625
+ if verbose:
1626
+ print("Closing server running on port: {}".format(self.server_port))
1627
+ except (AttributeError, OSError): # can't close if not running
1628
+ pass
1629
+
1630
+ def block_thread(
1631
+ self,
1632
+ ) -> None:
1633
+ """Block main thread until interrupted by user."""
1634
+ try:
1635
+ while True:
1636
+ time.sleep(0.1)
1637
+ except (KeyboardInterrupt, OSError):
1638
+ print("Keyboard interruption in main thread... closing server.")
1639
+ self.server.close()
1640
+ for tunnel in CURRENT_TUNNELS:
1641
+ tunnel.kill()
1642
+
1643
+ def attach_load_events(self):
1644
+ """Add a load event for every component whose initial value should be randomized."""
1645
+ if Context.root_block:
1646
+ for component in Context.root_block.blocks.values():
1647
+ if (
1648
+ isinstance(component, components.IOComponent)
1649
+ and component.load_event_to_attach
1650
+ ):
1651
+ load_fn, every = component.load_event_to_attach
1652
+ # Use set_event_trigger to avoid ambiguity between load class/instance method
1653
+ self.set_event_trigger(
1654
+ "load",
1655
+ load_fn,
1656
+ None,
1657
+ component,
1658
+ no_target=True,
1659
+ queue=False,
1660
+ every=every,
1661
+ )
1662
+
1663
+ def startup_events(self):
1664
+ """Events that should be run when the app containing this block starts up."""
1665
+
1666
+ if self.enable_queue:
1667
+ utils.run_coro_in_background(self._queue.start, (self.progress_tracking,))
1668
+ utils.run_coro_in_background(self.create_limiter)
1669
+
1670
+ def queue_enabled_for_fn(self, fn_index: int):
1671
+ if self.dependencies[fn_index]["queue"] is None:
1672
+ return self.enable_queue
1673
+ return self.dependencies[fn_index]["queue"]
gradio-modified/gradio/components.py ADDED
The diff for this file is too large to render. See raw diff
 
gradio-modified/gradio/context.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Defines the Context class, which is used to store the state of all Blocks that are being rendered.
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
8
+ from gradio.blocks import BlockContext, Blocks
9
+
10
+
11
+ class Context:
12
+ root_block: Blocks | None = None # The current root block that holds all blocks.
13
+ block: BlockContext | None = None # The current block that children are added to.
14
+ id: int = 0 # Running id to uniquely refer to any block that gets defined
gradio-modified/gradio/data_classes.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic data models and other dataclasses. This is the only file that uses Optional[]
2
+ typing syntax instead of | None syntax to work with pydantic"""
3
+
4
+ from enum import Enum, auto
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from pydantic import BaseModel
8
+
9
+
10
+ class PredictBody(BaseModel):
11
+ session_hash: Optional[str]
12
+ event_id: Optional[str]
13
+ data: List[Any]
14
+ fn_index: Optional[int]
15
+ batched: Optional[
16
+ bool
17
+ ] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
18
+ request: Optional[
19
+ Union[Dict, List[Dict]]
20
+ ] = None # dictionary of request headers, query parameters, url, etc. (used to to pass in request for queuing)
21
+
22
+
23
+ class ResetBody(BaseModel):
24
+ session_hash: str
25
+ fn_index: int
26
+
27
+
28
+ class InterfaceTypes(Enum):
29
+ STANDARD = auto()
30
+ INPUT_ONLY = auto()
31
+ OUTPUT_ONLY = auto()
32
+ UNIFIED = auto()
33
+
34
+
35
+ class Estimation(BaseModel):
36
+ msg: Optional[str] = "estimation"
37
+ rank: Optional[int] = None
38
+ queue_size: int
39
+ avg_event_process_time: Optional[float]
40
+ avg_event_concurrent_process_time: Optional[float]
41
+ rank_eta: Optional[float] = None
42
+ queue_eta: float
43
+
44
+
45
+ class ProgressUnit(BaseModel):
46
+ index: Optional[int]
47
+ length: Optional[int]
48
+ unit: Optional[str]
49
+ progress: Optional[float]
50
+ desc: Optional[str]
51
+
52
+
53
+ class Progress(BaseModel):
54
+ msg: str = "progress"
55
+ progress_data: List[ProgressUnit] = []
gradio-modified/gradio/deprecation.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+
4
+ def simple_deprecated_notice(term: str) -> str:
5
+ return f"`{term}` parameter is deprecated, and it has no effect"
6
+
7
+
8
+ def use_in_launch(term: str) -> str:
9
+ return f"`{term}` is deprecated in `Interface()`, please use it within `launch()` instead."
10
+
11
+
12
+ DEPRECATION_MESSAGE = {
13
+ "optional": simple_deprecated_notice("optional"),
14
+ "keep_filename": simple_deprecated_notice("keep_filename"),
15
+ "numeric": simple_deprecated_notice("numeric"),
16
+ "verbose": simple_deprecated_notice("verbose"),
17
+ "allow_screenshot": simple_deprecated_notice("allow_screenshot"),
18
+ "layout": simple_deprecated_notice("layout"),
19
+ "show_input": simple_deprecated_notice("show_input"),
20
+ "show_output": simple_deprecated_notice("show_output"),
21
+ "capture_session": simple_deprecated_notice("capture_session"),
22
+ "api_mode": simple_deprecated_notice("api_mode"),
23
+ "show_tips": use_in_launch("show_tips"),
24
+ "encrypt": use_in_launch("encrypt"),
25
+ "enable_queue": use_in_launch("enable_queue"),
26
+ "server_name": use_in_launch("server_name"),
27
+ "server_port": use_in_launch("server_port"),
28
+ "width": use_in_launch("width"),
29
+ "height": use_in_launch("height"),
30
+ "plot": "The 'plot' parameter has been deprecated. Use the new Plot component instead",
31
+ "type": "The 'type' parameter has been deprecated. Use the Number component instead.",
32
+ }
33
+
34
+
35
+ def check_deprecated_parameters(cls: str, **kwargs) -> None:
36
+ for key, value in DEPRECATION_MESSAGE.items():
37
+ if key in kwargs:
38
+ kwargs.pop(key)
39
+ # Interestingly, using DeprecationWarning causes warning to not appear.
40
+ warnings.warn(value)
41
+
42
+ if len(kwargs) != 0:
43
+ warnings.warn(
44
+ f"You have unused kwarg parameters in {cls}, please remove them: {kwargs}"
45
+ )
gradio-modified/gradio/documentation.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contains methods that generate documentation for Gradio functions and classes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ from typing import Callable, Dict, List, Tuple
7
+
8
+ classes_to_document = {}
9
+ documentation_group = None
10
+
11
+
12
+ def set_documentation_group(m):
13
+ global documentation_group
14
+ documentation_group = m
15
+ if m not in classes_to_document:
16
+ classes_to_document[m] = []
17
+
18
+
19
+ def document(*fns):
20
+ """
21
+ Defines the @document decorator which adds classes or functions to the Gradio
22
+ documentation at www.gradio.app/docs.
23
+
24
+ Usage examples:
25
+ - Put @document() above a class to document the class and its constructor.
26
+ - Put @document(fn1, fn2) above a class to also document the class methods fn1 and fn2.
27
+ """
28
+
29
+ def inner_doc(cls):
30
+ global documentation_group
31
+ classes_to_document[documentation_group].append((cls, fns))
32
+ return cls
33
+
34
+ return inner_doc
35
+
36
+
37
+ def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, str | None]:
38
+ """
39
+ Generates documentation for any function.
40
+ Parameters:
41
+ fn: Function to document
42
+ Returns:
43
+ description: General description of fn
44
+ parameters: A list of dicts for each parameter, storing data for the parameter name, annotation and doc
45
+ return: A dict storing data for the returned annotation and doc
46
+ example: Code for an example use of the fn
47
+ """
48
+ doc_str = inspect.getdoc(fn) or ""
49
+ doc_lines = doc_str.split("\n")
50
+ signature = inspect.signature(fn)
51
+ description, parameters, returns, examples = [], {}, [], []
52
+ mode = "description"
53
+ for line in doc_lines:
54
+ line = line.rstrip()
55
+ if line == "Parameters:":
56
+ mode = "parameter"
57
+ elif line == "Example:":
58
+ mode = "example"
59
+ elif line == "Returns:":
60
+ mode = "return"
61
+ else:
62
+ if mode == "description":
63
+ description.append(line if line.strip() else "<br>")
64
+ continue
65
+ assert (
66
+ line.startswith(" ") or line.strip() == ""
67
+ ), f"Documentation format for {fn.__name__} has format error in line: {line}"
68
+ line = line[4:]
69
+ if mode == "parameter":
70
+ colon_index = line.index(": ")
71
+ assert (
72
+ colon_index > -1
73
+ ), f"Documentation format for {fn.__name__} has format error in line: {line}"
74
+ parameter = line[:colon_index]
75
+ parameter_doc = line[colon_index + 2 :]
76
+ parameters[parameter] = parameter_doc
77
+ elif mode == "return":
78
+ returns.append(line)
79
+ elif mode == "example":
80
+ examples.append(line)
81
+ description_doc = " ".join(description)
82
+ parameter_docs = []
83
+ for param_name, param in signature.parameters.items():
84
+ if param_name.startswith("_"):
85
+ continue
86
+ if param_name == "kwargs" and param_name not in parameters:
87
+ continue
88
+ parameter_doc = {
89
+ "name": param_name,
90
+ "annotation": param.annotation,
91
+ "doc": parameters.get(param_name),
92
+ }
93
+ if param_name in parameters:
94
+ del parameters[param_name]
95
+ if param.default != inspect.Parameter.empty:
96
+ default = param.default
97
+ if type(default) == str:
98
+ default = '"' + default + '"'
99
+ if default.__class__.__module__ != "builtins":
100
+ default = f"{default.__class__.__name__}()"
101
+ parameter_doc["default"] = default
102
+ elif parameter_doc["doc"] is not None and "kwargs" in parameter_doc["doc"]:
103
+ parameter_doc["kwargs"] = True
104
+ parameter_docs.append(parameter_doc)
105
+ assert (
106
+ len(parameters) == 0
107
+ ), f"Documentation format for {fn.__name__} documents nonexistent parameters: {''.join(parameters.keys())}"
108
+ if len(returns) == 0:
109
+ return_docs = {}
110
+ elif len(returns) == 1:
111
+ return_docs = {"annotation": signature.return_annotation, "doc": returns[0]}
112
+ else:
113
+ return_docs = {}
114
+ # raise ValueError("Does not support multiple returns yet.")
115
+ examples_doc = "\n".join(examples) if len(examples) > 0 else None
116
+ return description_doc, parameter_docs, return_docs, examples_doc
117
+
118
+
119
+ def document_cls(cls):
120
+ doc_str = inspect.getdoc(cls)
121
+ if doc_str is None:
122
+ return "", {}, ""
123
+ tags = {}
124
+ description_lines = []
125
+ mode = "description"
126
+ for line in doc_str.split("\n"):
127
+ line = line.rstrip()
128
+ if line.endswith(":") and " " not in line:
129
+ mode = line[:-1].lower()
130
+ tags[mode] = []
131
+ elif line.split(" ")[0].endswith(":") and not line.startswith(" "):
132
+ tag = line[: line.index(":")].lower()
133
+ value = line[line.index(":") + 2 :]
134
+ tags[tag] = value
135
+ else:
136
+ if mode == "description":
137
+ description_lines.append(line if line.strip() else "<br>")
138
+ else:
139
+ assert (
140
+ line.startswith(" ") or not line.strip()
141
+ ), f"Documentation format for {cls.__name__} has format error in line: {line}"
142
+ tags[mode].append(line[4:])
143
+ if "example" in tags:
144
+ example = "\n".join(tags["example"])
145
+ del tags["example"]
146
+ else:
147
+ example = None
148
+ for key, val in tags.items():
149
+ if isinstance(val, list):
150
+ tags[key] = "<br>".join(val)
151
+ description = " ".join(description_lines).replace("\n", "<br>")
152
+ return description, tags, example
153
+
154
+
155
+ def generate_documentation():
156
+ documentation = {}
157
+ for mode, class_list in classes_to_document.items():
158
+ documentation[mode] = []
159
+ for cls, fns in class_list:
160
+ fn_to_document = cls if inspect.isfunction(cls) else cls.__init__
161
+ _, parameter_doc, return_doc, _ = document_fn(fn_to_document)
162
+ cls_description, cls_tags, cls_example = document_cls(cls)
163
+ cls_documentation = {
164
+ "class": cls,
165
+ "name": cls.__name__,
166
+ "description": cls_description,
167
+ "tags": cls_tags,
168
+ "parameters": parameter_doc,
169
+ "returns": return_doc,
170
+ "example": cls_example,
171
+ "fns": [],
172
+ }
173
+ for fn_name in fns:
174
+ fn = getattr(cls, fn_name)
175
+ (
176
+ description_doc,
177
+ parameter_docs,
178
+ return_docs,
179
+ examples_doc,
180
+ ) = document_fn(fn)
181
+ cls_documentation["fns"].append(
182
+ {
183
+ "fn": fn,
184
+ "name": fn_name,
185
+ "description": description_doc,
186
+ "tags": {},
187
+ "parameters": parameter_docs,
188
+ "returns": return_docs,
189
+ "example": examples_doc,
190
+ }
191
+ )
192
+ documentation[mode].append(cls_documentation)
193
+ return documentation
gradio-modified/gradio/encryptor.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Crypto import Random
2
+ from Crypto.Cipher import AES
3
+ from Crypto.Hash import SHA256
4
+
5
+
6
+ def get_key(password: str) -> bytes:
7
+ """Generates an encryption key based on the password provided."""
8
+ key = SHA256.new(password.encode()).digest()
9
+ return key
10
+
11
+
12
+ def encrypt(key: bytes, source: bytes) -> bytes:
13
+ """Encrypts source data using the provided encryption key"""
14
+ IV = Random.new().read(AES.block_size) # generate IV
15
+ encryptor = AES.new(key, AES.MODE_CBC, IV)
16
+ padding = AES.block_size - len(source) % AES.block_size # calculate needed padding
17
+ source += bytes([padding]) * padding # Python 2.x: source += chr(padding) * padding
18
+ data = IV + encryptor.encrypt(source) # store the IV at the beginning and encrypt
19
+ return data
20
+
21
+
22
+ def decrypt(key: bytes, source: bytes) -> bytes:
23
+ IV = source[: AES.block_size] # extract the IV from the beginning
24
+ decryptor = AES.new(key, AES.MODE_CBC, IV)
25
+ data = decryptor.decrypt(source[AES.block_size :]) # decrypt
26
+ padding = data[-1] # pick the padding value from the end; Python 2.x: ord(data[-1])
27
+ if (
28
+ data[-padding:] != bytes([padding]) * padding
29
+ ): # Python 2.x: chr(padding) * padding
30
+ raise ValueError("Invalid padding...")
31
+ return data[:-padding] # remove the padding
gradio-modified/gradio/events.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contains all of the events that can be triggered in a gr.Blocks() app, with the exception
2
+ of the on-page-load event, which is defined in gr.Blocks().load()."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set
8
+
9
+ from gradio.blocks import Block
10
+ from gradio.utils import get_cancel_function
11
+
12
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
13
+ from gradio.components import Component, StatusTracker
14
+
15
+
16
+ def set_cancel_events(
17
+ block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
18
+ ):
19
+ if cancels:
20
+ if not isinstance(cancels, list):
21
+ cancels = [cancels]
22
+ cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
23
+ block.set_event_trigger(
24
+ event_name,
25
+ cancel_fn,
26
+ inputs=None,
27
+ outputs=None,
28
+ queue=False,
29
+ preprocess=False,
30
+ cancels=fn_indices_to_cancel,
31
+ )
32
+
33
+
34
+ class EventListener(Block):
35
+ pass
36
+
37
+
38
+ class Changeable(EventListener):
39
+ def change(
40
+ self,
41
+ fn: Callable | None,
42
+ inputs: Component | List[Component] | Set[Component] | None = None,
43
+ outputs: Component | List[Component] | None = None,
44
+ api_name: str | None = None,
45
+ status_tracker: StatusTracker | None = None,
46
+ scroll_to_output: bool = False,
47
+ show_progress: bool = True,
48
+ queue: bool | None = None,
49
+ batch: bool = False,
50
+ max_batch_size: int = 4,
51
+ preprocess: bool = True,
52
+ postprocess: bool = True,
53
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
54
+ every: float | None = None,
55
+ _js: str | None = None,
56
+ ):
57
+ """
58
+ This event is triggered when the component's input value changes (e.g. when the user types in a textbox
59
+ or uploads an image). This method can be used when this component is in a Gradio Blocks.
60
+
61
+ Parameters:
62
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
63
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
64
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
65
+ api_name: Defining this parameter exposes the endpoint in the api docs
66
+ scroll_to_output: If True, will scroll to output component on completion
67
+ show_progress: If True, will show progress animation while pending
68
+ queue: If True, will place the request on the queue, if the queue exists
69
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
70
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
71
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
72
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
73
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
74
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
75
+ """
76
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
77
+ if status_tracker:
78
+ warnings.warn(
79
+ "The 'status_tracker' parameter has been deprecated and has no effect."
80
+ )
81
+ dep = self.set_event_trigger(
82
+ "change",
83
+ fn,
84
+ inputs,
85
+ outputs,
86
+ preprocess=preprocess,
87
+ postprocess=postprocess,
88
+ scroll_to_output=scroll_to_output,
89
+ show_progress=show_progress,
90
+ api_name=api_name,
91
+ js=_js,
92
+ queue=queue,
93
+ batch=batch,
94
+ max_batch_size=max_batch_size,
95
+ every=every,
96
+ )
97
+ set_cancel_events(self, "change", cancels)
98
+ return dep
99
+
100
+
101
+ class Clickable(EventListener):
102
+ def click(
103
+ self,
104
+ fn: Callable | None,
105
+ inputs: Component | List[Component] | Set[Component] | None = None,
106
+ outputs: Component | List[Component] | None = None,
107
+ api_name: str | None = None,
108
+ status_tracker: StatusTracker | None = None,
109
+ scroll_to_output: bool = False,
110
+ show_progress: bool = True,
111
+ queue=None,
112
+ batch: bool = False,
113
+ max_batch_size: int = 4,
114
+ preprocess: bool = True,
115
+ postprocess: bool = True,
116
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
117
+ every: float | None = None,
118
+ _js: str | None = None,
119
+ ):
120
+ """
121
+ This event is triggered when the component (e.g. a button) is clicked.
122
+ This method can be used when this component is in a Gradio Blocks.
123
+
124
+ Parameters:
125
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
126
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
127
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
128
+ api_name: Defining this parameter exposes the endpoint in the api docs
129
+ scroll_to_output: If True, will scroll to output component on completion
130
+ show_progress: If True, will show progress animation while pending
131
+ queue: If True, will place the request on the queue, if the queue exists
132
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
133
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
134
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
135
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
136
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
137
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
138
+ """
139
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
140
+ if status_tracker:
141
+ warnings.warn(
142
+ "The 'status_tracker' parameter has been deprecated and has no effect."
143
+ )
144
+
145
+ dep = self.set_event_trigger(
146
+ "click",
147
+ fn,
148
+ inputs,
149
+ outputs,
150
+ preprocess=preprocess,
151
+ postprocess=postprocess,
152
+ scroll_to_output=scroll_to_output,
153
+ show_progress=show_progress,
154
+ api_name=api_name,
155
+ js=_js,
156
+ queue=queue,
157
+ batch=batch,
158
+ max_batch_size=max_batch_size,
159
+ every=every,
160
+ )
161
+ set_cancel_events(self, "click", cancels)
162
+ return dep
163
+
164
+
165
+ class Submittable(EventListener):
166
+ def submit(
167
+ self,
168
+ fn: Callable | None,
169
+ inputs: Component | List[Component] | Set[Component] | None = None,
170
+ outputs: Component | List[Component] | None = None,
171
+ api_name: str | None = None,
172
+ status_tracker: StatusTracker | None = None,
173
+ scroll_to_output: bool = False,
174
+ show_progress: bool = True,
175
+ queue: bool | None = None,
176
+ batch: bool = False,
177
+ max_batch_size: int = 4,
178
+ preprocess: bool = True,
179
+ postprocess: bool = True,
180
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
181
+ every: float | None = None,
182
+ _js: str | None = None,
183
+ ):
184
+ """
185
+ This event is triggered when the user presses the Enter key while the component (e.g. a textbox) is focused.
186
+ This method can be used when this component is in a Gradio Blocks.
187
+
188
+
189
+ Parameters:
190
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
191
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
192
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
193
+ api_name: Defining this parameter exposes the endpoint in the api docs
194
+ scroll_to_output: If True, will scroll to output component on completion
195
+ show_progress: If True, will show progress animation while pending
196
+ queue: If True, will place the request on the queue, if the queue exists
197
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
198
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
199
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
200
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
201
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
202
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
203
+ """
204
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
205
+ if status_tracker:
206
+ warnings.warn(
207
+ "The 'status_tracker' parameter has been deprecated and has no effect."
208
+ )
209
+
210
+ dep = self.set_event_trigger(
211
+ "submit",
212
+ fn,
213
+ inputs,
214
+ outputs,
215
+ preprocess=preprocess,
216
+ postprocess=postprocess,
217
+ scroll_to_output=scroll_to_output,
218
+ show_progress=show_progress,
219
+ api_name=api_name,
220
+ js=_js,
221
+ queue=queue,
222
+ batch=batch,
223
+ max_batch_size=max_batch_size,
224
+ every=every,
225
+ )
226
+ set_cancel_events(self, "submit", cancels)
227
+ return dep
228
+
229
+
230
+ class Editable(EventListener):
231
+ def edit(
232
+ self,
233
+ fn: Callable | None,
234
+ inputs: Component | List[Component] | Set[Component] | None = None,
235
+ outputs: Component | List[Component] | None = None,
236
+ api_name: str | None = None,
237
+ status_tracker: StatusTracker | None = None,
238
+ scroll_to_output: bool = False,
239
+ show_progress: bool = True,
240
+ queue: bool | None = None,
241
+ batch: bool = False,
242
+ max_batch_size: int = 4,
243
+ preprocess: bool = True,
244
+ postprocess: bool = True,
245
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
246
+ every: float | None = None,
247
+ _js: str | None = None,
248
+ ):
249
+ """
250
+ This event is triggered when the user edits the component (e.g. image) using the
251
+ built-in editor. This method can be used when this component is in a Gradio Blocks.
252
+
253
+ Parameters:
254
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
255
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
256
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
257
+ api_name: Defining this parameter exposes the endpoint in the api docs
258
+ scroll_to_output: If True, will scroll to output component on completion
259
+ show_progress: If True, will show progress animation while pending
260
+ queue: If True, will place the request on the queue, if the queue exists
261
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
262
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
263
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
264
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
265
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
266
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
267
+ """
268
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
269
+ if status_tracker:
270
+ warnings.warn(
271
+ "The 'status_tracker' parameter has been deprecated and has no effect."
272
+ )
273
+
274
+ dep = self.set_event_trigger(
275
+ "edit",
276
+ fn,
277
+ inputs,
278
+ outputs,
279
+ preprocess=preprocess,
280
+ postprocess=postprocess,
281
+ scroll_to_output=scroll_to_output,
282
+ show_progress=show_progress,
283
+ api_name=api_name,
284
+ js=_js,
285
+ queue=queue,
286
+ batch=batch,
287
+ max_batch_size=max_batch_size,
288
+ every=every,
289
+ )
290
+ set_cancel_events(self, "edit", cancels)
291
+ return dep
292
+
293
+
294
+ class Clearable(EventListener):
295
+ def clear(
296
+ self,
297
+ fn: Callable | None,
298
+ inputs: Component | List[Component] | Set[Component] | None = None,
299
+ outputs: Component | List[Component] | None = None,
300
+ api_name: str | None = None,
301
+ status_tracker: StatusTracker | None = None,
302
+ scroll_to_output: bool = False,
303
+ show_progress: bool = True,
304
+ queue: bool | None = None,
305
+ batch: bool = False,
306
+ max_batch_size: int = 4,
307
+ preprocess: bool = True,
308
+ postprocess: bool = True,
309
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
310
+ every: float | None = None,
311
+ _js: str | None = None,
312
+ ):
313
+ """
314
+ This event is triggered when the user clears the component (e.g. image or audio)
315
+ using the X button for the component. This method can be used when this component is in a Gradio Blocks.
316
+
317
+ Parameters:
318
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
319
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
320
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
321
+ api_name: Defining this parameter exposes the endpoint in the api docs
322
+ scroll_to_output: If True, will scroll to output component on completion
323
+ show_progress: If True, will show progress animation while pending
324
+ queue: If True, will place the request on the queue, if the queue exists
325
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
326
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
327
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
328
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
329
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
330
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
331
+ """
332
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
333
+ if status_tracker:
334
+ warnings.warn(
335
+ "The 'status_tracker' parameter has been deprecated and has no effect."
336
+ )
337
+
338
+ dep = self.set_event_trigger(
339
+ "submit",
340
+ fn,
341
+ inputs,
342
+ outputs,
343
+ preprocess=preprocess,
344
+ postprocess=postprocess,
345
+ scroll_to_output=scroll_to_output,
346
+ show_progress=show_progress,
347
+ api_name=api_name,
348
+ js=_js,
349
+ queue=queue,
350
+ batch=batch,
351
+ max_batch_size=max_batch_size,
352
+ every=every,
353
+ )
354
+ set_cancel_events(self, "submit", cancels)
355
+ return dep
356
+
357
+
358
+ class Playable(EventListener):
359
+ def play(
360
+ self,
361
+ fn: Callable | None,
362
+ inputs: Component | List[Component] | Set[Component] | None = None,
363
+ outputs: Component | List[Component] | None = None,
364
+ api_name: str | None = None,
365
+ status_tracker: StatusTracker | None = None,
366
+ scroll_to_output: bool = False,
367
+ show_progress: bool = True,
368
+ queue: bool | None = None,
369
+ batch: bool = False,
370
+ max_batch_size: int = 4,
371
+ preprocess: bool = True,
372
+ postprocess: bool = True,
373
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
374
+ every: float | None = None,
375
+ _js: str | None = None,
376
+ ):
377
+ """
378
+ This event is triggered when the user plays the component (e.g. audio or video).
379
+ This method can be used when this component is in a Gradio Blocks.
380
+
381
+ Parameters:
382
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
383
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
384
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
385
+ api_name: Defining this parameter exposes the endpoint in the api docs
386
+ scroll_to_output: If True, will scroll to output component on completion
387
+ show_progress: If True, will show progress animation while pending
388
+ queue: If True, will place the request on the queue, if the queue exists
389
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
390
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
391
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
392
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
393
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
394
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
395
+ """
396
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
397
+ if status_tracker:
398
+ warnings.warn(
399
+ "The 'status_tracker' parameter has been deprecated and has no effect."
400
+ )
401
+
402
+ dep = self.set_event_trigger(
403
+ "play",
404
+ fn,
405
+ inputs,
406
+ outputs,
407
+ preprocess=preprocess,
408
+ postprocess=postprocess,
409
+ scroll_to_output=scroll_to_output,
410
+ show_progress=show_progress,
411
+ api_name=api_name,
412
+ js=_js,
413
+ queue=queue,
414
+ batch=batch,
415
+ max_batch_size=max_batch_size,
416
+ every=every,
417
+ )
418
+ set_cancel_events(self, "play", cancels)
419
+ return dep
420
+
421
+ def pause(
422
+ self,
423
+ fn: Callable | None,
424
+ inputs: Component | List[Component] | Set[Component] | None = None,
425
+ outputs: Component | List[Component] | None = None,
426
+ api_name: str | None = None,
427
+ status_tracker: StatusTracker | None = None,
428
+ scroll_to_output: bool = False,
429
+ show_progress: bool = True,
430
+ queue: bool | None = None,
431
+ batch: bool = False,
432
+ max_batch_size: int = 4,
433
+ preprocess: bool = True,
434
+ postprocess: bool = True,
435
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
436
+ every: float | None = None,
437
+ _js: str | None = None,
438
+ ):
439
+ """
440
+ This event is triggered when the user pauses the component (e.g. audio or video).
441
+ This method can be used when this component is in a Gradio Blocks.
442
+
443
+ Parameters:
444
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
445
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
446
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
447
+ api_name: Defining this parameter exposes the endpoint in the api docs
448
+ scroll_to_output: If True, will scroll to output component on completion
449
+ show_progress: If True, will show progress animation while pending
450
+ queue: If True, will place the request on the queue, if the queue exists
451
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
452
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
453
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
454
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
455
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
456
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
457
+ """
458
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
459
+ if status_tracker:
460
+ warnings.warn(
461
+ "The 'status_tracker' parameter has been deprecated and has no effect."
462
+ )
463
+
464
+ dep = self.set_event_trigger(
465
+ "pause",
466
+ fn,
467
+ inputs,
468
+ outputs,
469
+ preprocess=preprocess,
470
+ postprocess=postprocess,
471
+ scroll_to_output=scroll_to_output,
472
+ show_progress=show_progress,
473
+ api_name=api_name,
474
+ js=_js,
475
+ queue=queue,
476
+ batch=batch,
477
+ max_batch_size=max_batch_size,
478
+ every=every,
479
+ )
480
+ set_cancel_events(self, "pause", cancels)
481
+ return dep
482
+
483
+ def stop(
484
+ self,
485
+ fn: Callable | None,
486
+ inputs: Component | List[Component] | Set[Component] | None = None,
487
+ outputs: Component | List[Component] | None = None,
488
+ api_name: str | None = None,
489
+ status_tracker: StatusTracker | None = None,
490
+ scroll_to_output: bool = False,
491
+ show_progress: bool = True,
492
+ queue: bool | None = None,
493
+ batch: bool = False,
494
+ max_batch_size: int = 4,
495
+ preprocess: bool = True,
496
+ postprocess: bool = True,
497
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
498
+ every: float | None = None,
499
+ _js: str | None = None,
500
+ ):
501
+ """
502
+ This event is triggered when the user stops the component (e.g. audio or video).
503
+ This method can be used when this component is in a Gradio Blocks.
504
+
505
+ Parameters:
506
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
507
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
508
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
509
+ api_name: Defining this parameter exposes the endpoint in the api docs
510
+ scroll_to_output: If True, will scroll to output component on completion
511
+ show_progress: If True, will show progress animation while pending
512
+ queue: If True, will place the request on the queue, if the queue exists
513
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
514
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
515
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
516
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
517
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
518
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
519
+ """
520
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
521
+ if status_tracker:
522
+ warnings.warn(
523
+ "The 'status_tracker' parameter has been deprecated and has no effect."
524
+ )
525
+
526
+ dep = self.set_event_trigger(
527
+ "stop",
528
+ fn,
529
+ inputs,
530
+ outputs,
531
+ preprocess=preprocess,
532
+ postprocess=postprocess,
533
+ scroll_to_output=scroll_to_output,
534
+ show_progress=show_progress,
535
+ api_name=api_name,
536
+ js=_js,
537
+ queue=queue,
538
+ batch=batch,
539
+ max_batch_size=max_batch_size,
540
+ every=every,
541
+ )
542
+ set_cancel_events(self, "stop", cancels)
543
+ return dep
544
+
545
+
546
+ class Streamable(EventListener):
547
+ def stream(
548
+ self,
549
+ fn: Callable | None,
550
+ inputs: Component | List[Component] | Set[Component] | None = None,
551
+ outputs: Component | List[Component] | None = None,
552
+ api_name: str | None = None,
553
+ status_tracker: StatusTracker | None = None,
554
+ scroll_to_output: bool = False,
555
+ show_progress: bool = False,
556
+ queue: bool | None = None,
557
+ batch: bool = False,
558
+ max_batch_size: int = 4,
559
+ preprocess: bool = True,
560
+ postprocess: bool = True,
561
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
562
+ every: float | None = None,
563
+ _js: str | None = None,
564
+ ):
565
+ """
566
+ This event is triggered when the user streams the component (e.g. a live webcam
567
+ component). This method can be used when this component is in a Gradio Blocks.
568
+
569
+ Parameters:
570
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
571
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
572
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
573
+ api_name: Defining this parameter exposes the endpoint in the api docs
574
+ scroll_to_output: If True, will scroll to output component on completion
575
+ show_progress: If True, will show progress animation while pending
576
+ queue: If True, will place the request on the queue, if the queue exists
577
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
578
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
579
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
580
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
581
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
582
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
583
+ """
584
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
585
+ self.streaming = True
586
+
587
+ if status_tracker:
588
+ warnings.warn(
589
+ "The 'status_tracker' parameter has been deprecated and has no effect."
590
+ )
591
+
592
+ dep = self.set_event_trigger(
593
+ "stream",
594
+ fn,
595
+ inputs,
596
+ outputs,
597
+ preprocess=preprocess,
598
+ postprocess=postprocess,
599
+ scroll_to_output=scroll_to_output,
600
+ show_progress=show_progress,
601
+ api_name=api_name,
602
+ js=_js,
603
+ queue=queue,
604
+ batch=batch,
605
+ max_batch_size=max_batch_size,
606
+ every=every,
607
+ )
608
+ set_cancel_events(self, "stream", cancels)
609
+ return dep
610
+
611
+
612
+ class Blurrable(EventListener):
613
+ def blur(
614
+ self,
615
+ fn: Callable | None,
616
+ inputs: Component | List[Component] | Set[Component] | None = None,
617
+ outputs: Component | List[Component] | None = None,
618
+ api_name: str | None = None,
619
+ scroll_to_output: bool = False,
620
+ show_progress: bool = True,
621
+ queue: bool | None = None,
622
+ batch: bool = False,
623
+ max_batch_size: int = 4,
624
+ preprocess: bool = True,
625
+ postprocess: bool = True,
626
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
627
+ every: float | None = None,
628
+ _js: str | None = None,
629
+ ):
630
+ """
631
+ This event is triggered when the component's is unfocused/blurred (e.g. when the user clicks outside of a textbox). This method can be used when this component is in a Gradio Blocks.
632
+
633
+ Parameters:
634
+ fn: Callable function
635
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
636
+ outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
637
+ api_name: Defining this parameter exposes the endpoint in the api docs
638
+ scroll_to_output: If True, will scroll to output component on completion
639
+ show_progress: If True, will show progress animation while pending
640
+ queue: If True, will place the request on the queue, if the queue exists
641
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
642
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
643
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
644
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
645
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
646
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
647
+ """
648
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
649
+
650
+ self.set_event_trigger(
651
+ "blur",
652
+ fn,
653
+ inputs,
654
+ outputs,
655
+ preprocess=preprocess,
656
+ postprocess=postprocess,
657
+ scroll_to_output=scroll_to_output,
658
+ show_progress=show_progress,
659
+ api_name=api_name,
660
+ js=_js,
661
+ queue=queue,
662
+ batch=batch,
663
+ max_batch_size=max_batch_size,
664
+ every=every,
665
+ )
666
+ set_cancel_events(self, "blur", cancels)
667
+
668
+
669
+ class Uploadable(EventListener):
670
+ def upload(
671
+ self,
672
+ fn: Callable | None,
673
+ inputs: List[Component],
674
+ outputs: Component | List[Component] | None = None,
675
+ api_name: str | None = None,
676
+ scroll_to_output: bool = False,
677
+ show_progress: bool = True,
678
+ queue: bool | None = None,
679
+ batch: bool = False,
680
+ max_batch_size: int = 4,
681
+ preprocess: bool = True,
682
+ postprocess: bool = True,
683
+ cancels: List[Dict[str, Any]] | None = None,
684
+ every: float | None = None,
685
+ _js: str | None = None,
686
+ ):
687
+ """
688
+ This event is triggered when the user uploads a file into the component (e.g. when the user uploads a video into a video component). This method can be used when this component is in a Gradio Blocks.
689
+
690
+ Parameters:
691
+ fn: Callable function
692
+ inputs: List of inputs
693
+ outputs: List of outputs
694
+ api_name: Defining this parameter exposes the endpoint in the api docs
695
+ scroll_to_output: If True, will scroll to output component on completion
696
+ show_progress: If True, will show progress animation while pending
697
+ queue: If True, will place the request on the queue, if the queue exists
698
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
699
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
700
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
701
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
702
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
703
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
704
+ """
705
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
706
+
707
+ self.set_event_trigger(
708
+ "upload",
709
+ fn,
710
+ inputs,
711
+ outputs,
712
+ preprocess=preprocess,
713
+ postprocess=postprocess,
714
+ scroll_to_output=scroll_to_output,
715
+ show_progress=show_progress,
716
+ api_name=api_name,
717
+ js=_js,
718
+ queue=queue,
719
+ batch=batch,
720
+ max_batch_size=max_batch_size,
721
+ every=every,
722
+ )
723
+ set_cancel_events(self, "upload", cancels)
gradio-modified/gradio/examples.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines helper methods useful for loading and caching Interface examples.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import csv
8
+ import os
9
+ import warnings
10
+ from pathlib import Path
11
+ from typing import TYPE_CHECKING, Any, Callable, List
12
+
13
+ from gradio import utils
14
+ from gradio.components import Dataset
15
+ from gradio.context import Context
16
+ from gradio.documentation import document, set_documentation_group
17
+ from gradio.flagging import CSVLogger
18
+
19
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
20
+ from gradio.components import IOComponent
21
+
22
+ CACHED_FOLDER = "gradio_cached_examples"
23
+ LOG_FILE = "log.csv"
24
+
25
+ set_documentation_group("component-helpers")
26
+
27
+
28
+ def create_examples(
29
+ examples: List[Any] | List[List[Any]] | str,
30
+ inputs: IOComponent | List[IOComponent],
31
+ outputs: IOComponent | List[IOComponent] | None = None,
32
+ fn: Callable | None = None,
33
+ cache_examples: bool = False,
34
+ examples_per_page: int = 10,
35
+ _api_mode: bool = False,
36
+ label: str | None = None,
37
+ elem_id: str | None = None,
38
+ run_on_click: bool = False,
39
+ preprocess: bool = True,
40
+ postprocess: bool = True,
41
+ batch: bool = False,
42
+ ):
43
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
44
+ examples_obj = Examples(
45
+ examples=examples,
46
+ inputs=inputs,
47
+ outputs=outputs,
48
+ fn=fn,
49
+ cache_examples=cache_examples,
50
+ examples_per_page=examples_per_page,
51
+ _api_mode=_api_mode,
52
+ label=label,
53
+ elem_id=elem_id,
54
+ run_on_click=run_on_click,
55
+ preprocess=preprocess,
56
+ postprocess=postprocess,
57
+ batch=batch,
58
+ _initiated_directly=False,
59
+ )
60
+ utils.synchronize_async(examples_obj.create)
61
+ return examples_obj
62
+
63
+
64
+ @document()
65
+ class Examples:
66
+ """
67
+ This class is a wrapper over the Dataset component and can be used to create Examples
68
+ for Blocks / Interfaces. Populates the Dataset component with examples and
69
+ assigns event listener so that clicking on an example populates the input/output
70
+ components. Optionally handles example caching for fast inference.
71
+
72
+ Demos: blocks_inputs, fake_gan
73
+ Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ examples: List[Any] | List[List[Any]] | str,
79
+ inputs: IOComponent | List[IOComponent],
80
+ outputs: IOComponent | List[IOComponent] | None = None,
81
+ fn: Callable | None = None,
82
+ cache_examples: bool = False,
83
+ examples_per_page: int = 10,
84
+ _api_mode: bool = False,
85
+ label: str | None = "Examples",
86
+ elem_id: str | None = None,
87
+ run_on_click: bool = False,
88
+ preprocess: bool = True,
89
+ postprocess: bool = True,
90
+ batch: bool = False,
91
+ _initiated_directly: bool = True,
92
+ ):
93
+ """
94
+ Parameters:
95
+ examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
96
+ inputs: the component or list of components corresponding to the examples
97
+ outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
98
+ fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
99
+ cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
100
+ examples_per_page: how many examples to show per page.
101
+ label: the label to use for the examples component (by default, "Examples")
102
+ elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
103
+ run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
104
+ preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
105
+ postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
106
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
107
+ """
108
+ if _initiated_directly:
109
+ warnings.warn(
110
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
111
+ )
112
+
113
+ if cache_examples and (fn is None or outputs is None):
114
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
115
+
116
+ if not isinstance(inputs, list):
117
+ inputs = [inputs]
118
+ if outputs and not isinstance(outputs, list):
119
+ outputs = [outputs]
120
+
121
+ working_directory = Path().absolute()
122
+
123
+ if examples is None:
124
+ raise ValueError("The parameter `examples` cannot be None")
125
+ elif isinstance(examples, list) and (
126
+ len(examples) == 0 or isinstance(examples[0], list)
127
+ ):
128
+ pass
129
+ elif (
130
+ isinstance(examples, list) and len(inputs) == 1
131
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
132
+ examples = [[e] for e in examples]
133
+ elif isinstance(examples, str):
134
+ if not Path(examples).exists():
135
+ raise FileNotFoundError(
136
+ "Could not find examples directory: " + examples
137
+ )
138
+ working_directory = examples
139
+ if not (Path(examples) / LOG_FILE).exists():
140
+ if len(inputs) == 1:
141
+ examples = [[e] for e in os.listdir(examples)]
142
+ else:
143
+ raise FileNotFoundError(
144
+ "Could not find log file (required for multiple inputs): "
145
+ + LOG_FILE
146
+ )
147
+ else:
148
+ with open(Path(examples) / LOG_FILE) as logs:
149
+ examples = list(csv.reader(logs))
150
+ examples = [
151
+ examples[i][: len(inputs)] for i in range(1, len(examples))
152
+ ] # remove header and unnecessary columns
153
+
154
+ else:
155
+ raise ValueError(
156
+ "The parameter `examples` must either be a string directory or a list"
157
+ "(if there is only 1 input component) or (more generally), a nested "
158
+ "list, where each sublist represents a set of inputs."
159
+ )
160
+
161
+ input_has_examples = [False] * len(inputs)
162
+ for example in examples:
163
+ for idx, example_for_input in enumerate(example):
164
+ if not (example_for_input is None):
165
+ try:
166
+ input_has_examples[idx] = True
167
+ except IndexError:
168
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
169
+
170
+ inputs_with_examples = [
171
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
172
+ ]
173
+ non_none_examples = [
174
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
175
+ for example in examples
176
+ ]
177
+
178
+ self.examples = examples
179
+ self.non_none_examples = non_none_examples
180
+ self.inputs = inputs
181
+ self.inputs_with_examples = inputs_with_examples
182
+ self.outputs = outputs
183
+ self.fn = fn
184
+ self.cache_examples = cache_examples
185
+ self._api_mode = _api_mode
186
+ self.preprocess = preprocess
187
+ self.postprocess = postprocess
188
+ self.batch = batch
189
+
190
+ with utils.set_directory(working_directory):
191
+ self.processed_examples = [
192
+ [
193
+ component.postprocess(sample)
194
+ for component, sample in zip(inputs, example)
195
+ ]
196
+ for example in examples
197
+ ]
198
+ self.non_none_processed_examples = [
199
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
200
+ for example in self.processed_examples
201
+ ]
202
+ if cache_examples:
203
+ for example in self.examples:
204
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
205
+ warnings.warn(
206
+ "Examples are being cached but not all input components have "
207
+ "example values. This may result in an exception being thrown by "
208
+ "your function. If you do get an error while caching examples, make "
209
+ "sure all of your inputs have example values for all of your examples "
210
+ "or you provide default values for those particular parameters in your function."
211
+ )
212
+ break
213
+
214
+ with utils.set_directory(working_directory):
215
+ self.dataset = Dataset(
216
+ components=inputs_with_examples,
217
+ samples=non_none_examples,
218
+ type="index",
219
+ label=label,
220
+ samples_per_page=examples_per_page,
221
+ elem_id=elem_id,
222
+ )
223
+
224
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
225
+ self.cached_file = Path(self.cached_folder) / "log.csv"
226
+ self.cache_examples = cache_examples
227
+ self.run_on_click = run_on_click
228
+
229
+ async def create(self) -> None:
230
+ """Caches the examples if self.cache_examples is True and creates the Dataset
231
+ component to hold the examples"""
232
+
233
+ async def load_example(example_id):
234
+ if self.cache_examples:
235
+ processed_example = self.non_none_processed_examples[
236
+ example_id
237
+ ] + await self.load_from_cache(example_id)
238
+ else:
239
+ processed_example = self.non_none_processed_examples[example_id]
240
+ return utils.resolve_singleton(processed_example)
241
+
242
+ if Context.root_block:
243
+ if self.cache_examples and self.outputs:
244
+ targets = self.inputs_with_examples
245
+ else:
246
+ targets = self.inputs
247
+ self.dataset.click(
248
+ load_example,
249
+ inputs=[self.dataset],
250
+ outputs=targets, # type: ignore
251
+ postprocess=False,
252
+ queue=False,
253
+ )
254
+ if self.run_on_click and not self.cache_examples:
255
+ if self.fn is None:
256
+ raise ValueError("Cannot run_on_click if no function is provided")
257
+ self.dataset.click(
258
+ self.fn,
259
+ inputs=self.inputs, # type: ignore
260
+ outputs=self.outputs, # type: ignore
261
+ )
262
+
263
+ if self.cache_examples:
264
+ await self.cache()
265
+
266
+ async def cache(self) -> None:
267
+ """
268
+ Caches all of the examples so that their predictions can be shown immediately.
269
+ """
270
+ if Path(self.cached_file).exists():
271
+ print(
272
+ f"Using cache from '{Path(self.cached_folder).resolve()}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
273
+ )
274
+ else:
275
+ if Context.root_block is None:
276
+ raise ValueError("Cannot cache examples if not in a Blocks context")
277
+
278
+ print(f"Caching examples at: '{Path(self.cached_file).resolve()}'")
279
+ cache_logger = CSVLogger()
280
+
281
+ # create a fake dependency to process the examples and get the predictions
282
+ dependency = Context.root_block.set_event_trigger(
283
+ event_name="fake_event",
284
+ fn=self.fn,
285
+ inputs=self.inputs_with_examples, # type: ignore
286
+ outputs=self.outputs, # type: ignore
287
+ preprocess=self.preprocess and not self._api_mode,
288
+ postprocess=self.postprocess and not self._api_mode,
289
+ batch=self.batch,
290
+ )
291
+
292
+ fn_index = Context.root_block.dependencies.index(dependency)
293
+ assert self.outputs is not None
294
+ cache_logger.setup(self.outputs, self.cached_folder)
295
+ for example_id, _ in enumerate(self.examples):
296
+ processed_input = self.processed_examples[example_id]
297
+ if self.batch:
298
+ processed_input = [[value] for value in processed_input]
299
+ prediction = await Context.root_block.process_api(
300
+ fn_index=fn_index, inputs=processed_input, request=None, state={}
301
+ )
302
+ output = prediction["data"]
303
+ if self.batch:
304
+ output = [value[0] for value in output]
305
+ cache_logger.flag(output)
306
+ # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
307
+ Context.root_block.dependencies.remove(dependency)
308
+ Context.root_block.fns.pop(fn_index)
309
+
310
+ async def load_from_cache(self, example_id: int) -> List[Any]:
311
+ """Loads a particular cached example for the interface.
312
+ Parameters:
313
+ example_id: The id of the example to process (zero-indexed).
314
+ """
315
+ with open(self.cached_file) as cache:
316
+ examples = list(csv.reader(cache))
317
+ example = examples[example_id + 1] # +1 to adjust for header
318
+ output = []
319
+ assert self.outputs is not None
320
+ for component, value in zip(self.outputs, example):
321
+ try:
322
+ value_as_dict = ast.literal_eval(value)
323
+ assert utils.is_update(value_as_dict)
324
+ output.append(value_as_dict)
325
+ except (ValueError, TypeError, SyntaxError, AssertionError):
326
+ output.append(component.serialize(value, self.cached_folder))
327
+ return output
gradio-modified/gradio/exceptions.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class DuplicateBlockError(ValueError):
2
+ """Raised when a Blocks contains more than one Block with the same id"""
3
+
4
+ pass
5
+
6
+
7
+ class TooManyRequestsError(Exception):
8
+ """Raised when the Hugging Face API returns a 429 status code."""
9
+
10
+ pass
11
+
12
+
13
+ class InvalidApiName(ValueError):
14
+ pass
15
+
16
+
17
+ class Error(Exception):
18
+ def __init__(self, message: str):
19
+ self.message = message
20
+ super().__init__(self.message)
21
+
22
+ def __str__(self):
23
+ return repr(self.message)
gradio-modified/gradio/external.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module should not be used directly as its API is subject to change. Instead,
2
+ use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import re
8
+ import uuid
9
+ import warnings
10
+ from copy import deepcopy
11
+ from typing import TYPE_CHECKING, Callable, Dict
12
+
13
+ import requests
14
+
15
+ import gradio
16
+ from gradio import components, utils
17
+ from gradio.exceptions import TooManyRequestsError
18
+ from gradio.external_utils import (
19
+ cols_to_rows,
20
+ encode_to_base64,
21
+ get_tabular_examples,
22
+ get_ws_fn,
23
+ postprocess_label,
24
+ rows_to_cols,
25
+ streamline_spaces_interface,
26
+ use_websocket,
27
+ )
28
+ from gradio.processing_utils import to_binary
29
+
30
+ if TYPE_CHECKING:
31
+ from gradio.blocks import Blocks
32
+ from gradio.interface import Interface
33
+
34
+
35
+ def load_blocks_from_repo(
36
+ name: str,
37
+ src: str | None = None,
38
+ api_key: str | None = None,
39
+ alias: str | None = None,
40
+ **kwargs,
41
+ ) -> Blocks:
42
+ """Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
43
+ if src is None:
44
+ # Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
45
+ tokens = name.split("/")
46
+ assert (
47
+ len(tokens) > 1
48
+ ), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
49
+ src = tokens[0]
50
+ name = "/".join(tokens[1:])
51
+
52
+ factory_methods: Dict[str, Callable] = {
53
+ # for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
54
+ "huggingface": from_model,
55
+ "models": from_model,
56
+ "spaces": from_spaces,
57
+ }
58
+ assert src.lower() in factory_methods, "parameter: src must be one of {}".format(
59
+ factory_methods.keys()
60
+ )
61
+
62
+ blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
63
+ return blocks
64
+
65
+
66
+ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
67
+ model_url = "https://huggingface.co/{}".format(model_name)
68
+ api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
69
+ print("Fetching model from: {}".format(model_url))
70
+
71
+ headers = {"Authorization": f"Bearer {api_key}"} if api_key is not None else {}
72
+
73
+ # Checking if model exists, and if so, it gets the pipeline
74
+ response = requests.request("GET", api_url, headers=headers)
75
+ assert (
76
+ response.status_code == 200
77
+ ), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
78
+ p = response.json().get("pipeline_tag")
79
+
80
+ pipelines = {
81
+ "audio-classification": {
82
+ # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
83
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
84
+ "outputs": components.Label(label="Class"),
85
+ "preprocess": lambda i: to_binary,
86
+ "postprocess": lambda r: postprocess_label(
87
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()}
88
+ ),
89
+ },
90
+ "audio-to-audio": {
91
+ # example model: facebook/xm_transformer_sm_all-en
92
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
93
+ "outputs": components.Audio(label="Output"),
94
+ "preprocess": to_binary,
95
+ "postprocess": encode_to_base64,
96
+ },
97
+ "automatic-speech-recognition": {
98
+ # example model: facebook/wav2vec2-base-960h
99
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
100
+ "outputs": components.Textbox(label="Output"),
101
+ "preprocess": to_binary,
102
+ "postprocess": lambda r: r.json()["text"],
103
+ },
104
+ "feature-extraction": {
105
+ # example model: julien-c/distilbert-feature-extraction
106
+ "inputs": components.Textbox(label="Input"),
107
+ "outputs": components.Dataframe(label="Output"),
108
+ "preprocess": lambda x: {"inputs": x},
109
+ "postprocess": lambda r: r.json()[0],
110
+ },
111
+ "fill-mask": {
112
+ "inputs": components.Textbox(label="Input"),
113
+ "outputs": components.Label(label="Classification"),
114
+ "preprocess": lambda x: {"inputs": x},
115
+ "postprocess": lambda r: postprocess_label(
116
+ {i["token_str"]: i["score"] for i in r.json()}
117
+ ),
118
+ },
119
+ "image-classification": {
120
+ # Example: google/vit-base-patch16-224
121
+ "inputs": components.Image(type="filepath", label="Input Image"),
122
+ "outputs": components.Label(label="Classification"),
123
+ "preprocess": to_binary,
124
+ "postprocess": lambda r: postprocess_label(
125
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()}
126
+ ),
127
+ },
128
+ "question-answering": {
129
+ # Example: deepset/xlm-roberta-base-squad2
130
+ "inputs": [
131
+ components.Textbox(lines=7, label="Context"),
132
+ components.Textbox(label="Question"),
133
+ ],
134
+ "outputs": [
135
+ components.Textbox(label="Answer"),
136
+ components.Label(label="Score"),
137
+ ],
138
+ "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
139
+ "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
140
+ },
141
+ "summarization": {
142
+ # Example: facebook/bart-large-cnn
143
+ "inputs": components.Textbox(label="Input"),
144
+ "outputs": components.Textbox(label="Summary"),
145
+ "preprocess": lambda x: {"inputs": x},
146
+ "postprocess": lambda r: r.json()[0]["summary_text"],
147
+ },
148
+ "text-classification": {
149
+ # Example: distilbert-base-uncased-finetuned-sst-2-english
150
+ "inputs": components.Textbox(label="Input"),
151
+ "outputs": components.Label(label="Classification"),
152
+ "preprocess": lambda x: {"inputs": x},
153
+ "postprocess": lambda r: postprocess_label(
154
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
155
+ ),
156
+ },
157
+ "text-generation": {
158
+ # Example: gpt2
159
+ "inputs": components.Textbox(label="Input"),
160
+ "outputs": components.Textbox(label="Output"),
161
+ "preprocess": lambda x: {"inputs": x},
162
+ "postprocess": lambda r: r.json()[0]["generated_text"],
163
+ },
164
+ "text2text-generation": {
165
+ # Example: valhalla/t5-small-qa-qg-hl
166
+ "inputs": components.Textbox(label="Input"),
167
+ "outputs": components.Textbox(label="Generated Text"),
168
+ "preprocess": lambda x: {"inputs": x},
169
+ "postprocess": lambda r: r.json()[0]["generated_text"],
170
+ },
171
+ "translation": {
172
+ "inputs": components.Textbox(label="Input"),
173
+ "outputs": components.Textbox(label="Translation"),
174
+ "preprocess": lambda x: {"inputs": x},
175
+ "postprocess": lambda r: r.json()[0]["translation_text"],
176
+ },
177
+ "zero-shot-classification": {
178
+ # Example: facebook/bart-large-mnli
179
+ "inputs": [
180
+ components.Textbox(label="Input"),
181
+ components.Textbox(label="Possible class names (" "comma-separated)"),
182
+ components.Checkbox(label="Allow multiple true classes"),
183
+ ],
184
+ "outputs": components.Label(label="Classification"),
185
+ "preprocess": lambda i, c, m: {
186
+ "inputs": i,
187
+ "parameters": {"candidate_labels": c, "multi_class": m},
188
+ },
189
+ "postprocess": lambda r: postprocess_label(
190
+ {
191
+ r.json()["labels"][i]: r.json()["scores"][i]
192
+ for i in range(len(r.json()["labels"]))
193
+ }
194
+ ),
195
+ },
196
+ "sentence-similarity": {
197
+ # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
198
+ "inputs": [
199
+ components.Textbox(
200
+ value="That is a happy person", label="Source Sentence"
201
+ ),
202
+ components.Textbox(
203
+ lines=7,
204
+ placeholder="Separate each sentence by a newline",
205
+ label="Sentences to compare to",
206
+ ),
207
+ ],
208
+ "outputs": components.Label(label="Classification"),
209
+ "preprocess": lambda src, sentences: {
210
+ "inputs": {
211
+ "source_sentence": src,
212
+ "sentences": [s for s in sentences.splitlines() if s != ""],
213
+ }
214
+ },
215
+ "postprocess": lambda r: postprocess_label(
216
+ {f"sentence {i}": v for i, v in enumerate(r.json())}
217
+ ),
218
+ },
219
+ "text-to-speech": {
220
+ # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
221
+ "inputs": components.Textbox(label="Input"),
222
+ "outputs": components.Audio(label="Audio"),
223
+ "preprocess": lambda x: {"inputs": x},
224
+ "postprocess": encode_to_base64,
225
+ },
226
+ "text-to-image": {
227
+ # example model: osanseviero/BigGAN-deep-128
228
+ "inputs": components.Textbox(label="Input"),
229
+ "outputs": components.Image(label="Output"),
230
+ "preprocess": lambda x: {"inputs": x},
231
+ "postprocess": encode_to_base64,
232
+ },
233
+ "token-classification": {
234
+ # example model: huggingface-course/bert-finetuned-ner
235
+ "inputs": components.Textbox(label="Input"),
236
+ "outputs": components.HighlightedText(label="Output"),
237
+ "preprocess": lambda x: {"inputs": x},
238
+ "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
239
+ },
240
+ }
241
+
242
+ if p in ["tabular-classification", "tabular-regression"]:
243
+ example_data = get_tabular_examples(model_name)
244
+ col_names, example_data = cols_to_rows(example_data)
245
+ example_data = [[example_data]] if example_data else None
246
+
247
+ pipelines[p] = {
248
+ "inputs": components.Dataframe(
249
+ label="Input Rows",
250
+ type="pandas",
251
+ headers=col_names,
252
+ col_count=(len(col_names), "fixed"),
253
+ ),
254
+ "outputs": components.Dataframe(
255
+ label="Predictions", type="array", headers=["prediction"]
256
+ ),
257
+ "preprocess": rows_to_cols,
258
+ "postprocess": lambda r: {
259
+ "headers": ["prediction"],
260
+ "data": [[pred] for pred in json.loads(r.text)],
261
+ },
262
+ "examples": example_data,
263
+ }
264
+
265
+ if p is None or not (p in pipelines):
266
+ raise ValueError("Unsupported pipeline type: {}".format(p))
267
+
268
+ pipeline = pipelines[p]
269
+
270
+ def query_huggingface_api(*params):
271
+ # Convert to a list of input components
272
+ data = pipeline["preprocess"](*params)
273
+ if isinstance(
274
+ data, dict
275
+ ): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
276
+ data.update({"options": {"wait_for_model": True}})
277
+ data = json.dumps(data)
278
+ response = requests.request("POST", api_url, headers=headers, data=data)
279
+ if not (response.status_code == 200):
280
+ errors_json = response.json()
281
+ errors, warns = "", ""
282
+ if errors_json.get("error"):
283
+ errors = f", Error: {errors_json.get('error')}"
284
+ if errors_json.get("warnings"):
285
+ warns = f", Warnings: {errors_json.get('warnings')}"
286
+ raise ValueError(
287
+ f"Could not complete request to HuggingFace API, Status Code: {response.status_code}"
288
+ + errors
289
+ + warns
290
+ )
291
+ if (
292
+ p == "token-classification"
293
+ ): # Handle as a special case since HF API only returns the named entities and we need the input as well
294
+ ner_groups = response.json()
295
+ input_string = params[0]
296
+ response = utils.format_ner_list(input_string, ner_groups)
297
+ output = pipeline["postprocess"](response)
298
+ return output
299
+
300
+ if alias is None:
301
+ query_huggingface_api.__name__ = model_name
302
+ else:
303
+ query_huggingface_api.__name__ = alias
304
+
305
+ interface_info = {
306
+ "fn": query_huggingface_api,
307
+ "inputs": pipeline["inputs"],
308
+ "outputs": pipeline["outputs"],
309
+ "title": model_name,
310
+ "examples": pipeline.get("examples"),
311
+ }
312
+
313
+ kwargs = dict(interface_info, **kwargs)
314
+ kwargs["_api_mode"] = True # So interface doesn't run pre/postprocess.
315
+ interface = gradio.Interface(**kwargs)
316
+ return interface
317
+
318
+
319
+ def from_spaces(
320
+ space_name: str, api_key: str | None, alias: str | None, **kwargs
321
+ ) -> Blocks:
322
+ space_url = "https://huggingface.co/spaces/{}".format(space_name)
323
+
324
+ print("Fetching Space from: {}".format(space_url))
325
+
326
+ headers = {}
327
+ if api_key is not None:
328
+ headers["Authorization"] = f"Bearer {api_key}"
329
+
330
+ iframe_url = (
331
+ requests.get(
332
+ f"https://huggingface.co/api/spaces/{space_name}/host", headers=headers
333
+ )
334
+ .json()
335
+ .get("host")
336
+ )
337
+
338
+ if iframe_url is None:
339
+ raise ValueError(
340
+ f"Could not find Space: {space_name}. If it is a private or gated Space, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
341
+ )
342
+
343
+ r = requests.get(iframe_url, headers=headers)
344
+
345
+ result = re.search(
346
+ r"window.gradio_config = (.*?);[\s]*</script>", r.text
347
+ ) # some basic regex to extract the config
348
+ try:
349
+ config = json.loads(result.group(1)) # type: ignore
350
+ except AttributeError:
351
+ raise ValueError("Could not load the Space: {}".format(space_name))
352
+ if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
353
+ return from_spaces_interface(
354
+ space_name, config, alias, api_key, iframe_url, **kwargs
355
+ )
356
+ else: # Create a Blocks for Gradio 3.x Spaces
357
+ if kwargs:
358
+ warnings.warn(
359
+ "You cannot override parameters for this Space by passing in kwargs. "
360
+ "Instead, please load the Space as a function and use it to create a "
361
+ "Blocks or Interface locally. You may find this Guide helpful: "
362
+ "https://gradio.app/using_blocks_like_functions/"
363
+ )
364
+ return from_spaces_blocks(config, api_key, iframe_url)
365
+
366
+
367
+ def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Blocks:
368
+ api_url = "{}/api/predict/".format(iframe_url)
369
+
370
+ headers = {"Content-Type": "application/json"}
371
+ if api_key is not None:
372
+ headers["Authorization"] = f"Bearer {api_key}"
373
+ ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss")
374
+
375
+ ws_fn = get_ws_fn(ws_url, headers)
376
+
377
+ fns = []
378
+ for d, dependency in enumerate(config["dependencies"]):
379
+ if dependency["backend_fn"]:
380
+
381
+ def get_fn(outputs, fn_index, use_ws):
382
+ def fn(*data):
383
+ data = json.dumps({"data": data, "fn_index": fn_index})
384
+ hash_data = json.dumps(
385
+ {"fn_index": fn_index, "session_hash": str(uuid.uuid4())}
386
+ )
387
+ if use_ws:
388
+ result = utils.synchronize_async(ws_fn, data, hash_data)
389
+ output = result["data"]
390
+ else:
391
+ response = requests.post(api_url, headers=headers, data=data)
392
+ result = json.loads(response.content.decode("utf-8"))
393
+ try:
394
+ output = result["data"]
395
+ except KeyError:
396
+ if "error" in result and "429" in result["error"]:
397
+ raise TooManyRequestsError(
398
+ "Too many requests to the Hugging Face API"
399
+ )
400
+ raise KeyError(
401
+ f"Could not find 'data' key in response from external Space. Response received: {result}"
402
+ )
403
+ if len(outputs) == 1:
404
+ output = output[0]
405
+ return output
406
+
407
+ return fn
408
+
409
+ fn = get_fn(
410
+ deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
411
+ )
412
+ fns.append(fn)
413
+ else:
414
+ fns.append(None)
415
+ return gradio.Blocks.from_config(config, fns, iframe_url)
416
+
417
+
418
+ def from_spaces_interface(
419
+ model_name: str,
420
+ config: Dict,
421
+ alias: str | None,
422
+ api_key: str | None,
423
+ iframe_url: str,
424
+ **kwargs,
425
+ ) -> Interface:
426
+
427
+ config = streamline_spaces_interface(config)
428
+ api_url = "{}/api/predict/".format(iframe_url)
429
+ headers = {"Content-Type": "application/json"}
430
+ if api_key is not None:
431
+ headers["Authorization"] = f"Bearer {api_key}"
432
+
433
+ # The function should call the API with preprocessed data
434
+ def fn(*data):
435
+ data = json.dumps({"data": data})
436
+ response = requests.post(api_url, headers=headers, data=data)
437
+ result = json.loads(response.content.decode("utf-8"))
438
+ try:
439
+ output = result["data"]
440
+ except KeyError:
441
+ if "error" in result and "429" in result["error"]:
442
+ raise TooManyRequestsError("Too many requests to the Hugging Face API")
443
+ raise KeyError(
444
+ f"Could not find 'data' key in response from external Space. Response received: {result}"
445
+ )
446
+ if (
447
+ len(config["outputs"]) == 1
448
+ ): # if the fn is supposed to return a single value, pop it
449
+ output = output[0]
450
+ if len(config["outputs"]) == 1 and isinstance(
451
+ output, list
452
+ ): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
453
+ output = output[0]
454
+ return output
455
+
456
+ fn.__name__ = alias if (alias is not None) else model_name
457
+ config["fn"] = fn
458
+
459
+ kwargs = dict(config, **kwargs)
460
+ kwargs["_api_mode"] = True
461
+ interface = gradio.Interface(**kwargs)
462
+ return interface
gradio-modified/gradio/external_utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility function for gradio/external.py"""
2
+
3
+ import base64
4
+ import json
5
+ import math
6
+ import operator
7
+ import re
8
+ import warnings
9
+ from typing import Any, Dict, List, Tuple
10
+
11
+ import requests
12
+ import websockets
13
+ import yaml
14
+ from packaging import version
15
+ from websockets.legacy.protocol import WebSocketCommonProtocol
16
+
17
+ from gradio import components, exceptions
18
+
19
+ ##################
20
+ # Helper functions for processing tabular data
21
+ ##################
22
+
23
+
24
+ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
25
+ readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
26
+ if readme.status_code != 200:
27
+ warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
28
+ example_data = {}
29
+ else:
30
+ yaml_regex = re.search(
31
+ "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
32
+ )
33
+ if yaml_regex is None:
34
+ example_data = {}
35
+ else:
36
+ example_yaml = next(
37
+ yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
38
+ )
39
+ example_data = example_yaml.get("widget", {}).get("structuredData", {})
40
+ if not example_data:
41
+ raise ValueError(
42
+ f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
43
+ "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
44
+ "for a reference on how to provide example data to your model."
45
+ )
46
+ # replace nan with string NaN for inference API
47
+ for data in example_data.values():
48
+ for i, val in enumerate(data):
49
+ if isinstance(val, float) and math.isnan(val):
50
+ data[i] = "NaN"
51
+ return example_data
52
+
53
+
54
+ def cols_to_rows(
55
+ example_data: Dict[str, List[float]]
56
+ ) -> Tuple[List[str], List[List[float]]]:
57
+ headers = list(example_data.keys())
58
+ n_rows = max(len(example_data[header] or []) for header in headers)
59
+ data = []
60
+ for row_index in range(n_rows):
61
+ row_data = []
62
+ for header in headers:
63
+ col = example_data[header] or []
64
+ if row_index >= len(col):
65
+ row_data.append("NaN")
66
+ else:
67
+ row_data.append(col[row_index])
68
+ data.append(row_data)
69
+ return headers, data
70
+
71
+
72
+ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
73
+ data_column_wise = {}
74
+ for i, header in enumerate(incoming_data["headers"]):
75
+ data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
76
+ return {"inputs": {"data": data_column_wise}}
77
+
78
+
79
+ ##################
80
+ # Helper functions for processing other kinds of data
81
+ ##################
82
+
83
+
84
+ def postprocess_label(scores: Dict) -> Dict:
85
+ sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
86
+ return {
87
+ "label": sorted_pred[0][0],
88
+ "confidences": [
89
+ {"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
90
+ ],
91
+ }
92
+
93
+
94
+ def encode_to_base64(r: requests.Response) -> str:
95
+ # Handles the different ways HF API returns the prediction
96
+ base64_repr = base64.b64encode(r.content).decode("utf-8")
97
+ data_prefix = ";base64,"
98
+ # Case 1: base64 representation already includes data prefix
99
+ if data_prefix in base64_repr:
100
+ return base64_repr
101
+ else:
102
+ content_type = r.headers.get("content-type")
103
+ # Case 2: the data prefix is a key in the response
104
+ if content_type == "application/json":
105
+ try:
106
+ content_type = r.json()[0]["content-type"]
107
+ base64_repr = r.json()[0]["blob"]
108
+ except KeyError:
109
+ raise ValueError(
110
+ "Cannot determine content type returned" "by external API."
111
+ )
112
+ # Case 3: the data prefix is included in the response headers
113
+ else:
114
+ pass
115
+ new_base64 = "data:{};base64,".format(content_type) + base64_repr
116
+ return new_base64
117
+
118
+
119
+ ##################
120
+ # Helper functions for connecting to websockets
121
+ ##################
122
+
123
+
124
+ async def get_pred_from_ws(
125
+ websocket: WebSocketCommonProtocol, data: str, hash_data: str
126
+ ) -> Dict[str, Any]:
127
+ completed = False
128
+ resp = {}
129
+ while not completed:
130
+ msg = await websocket.recv()
131
+ resp = json.loads(msg)
132
+ if resp["msg"] == "queue_full":
133
+ raise exceptions.Error("Queue is full! Please try again.")
134
+ if resp["msg"] == "send_hash":
135
+ await websocket.send(hash_data)
136
+ elif resp["msg"] == "send_data":
137
+ await websocket.send(data)
138
+ completed = resp["msg"] == "process_completed"
139
+ return resp["output"]
140
+
141
+
142
+ def get_ws_fn(ws_url, headers):
143
+ async def ws_fn(data, hash_data):
144
+ async with websockets.connect( # type: ignore
145
+ ws_url, open_timeout=10, extra_headers=headers
146
+ ) as websocket:
147
+ return await get_pred_from_ws(websocket, data, hash_data)
148
+
149
+ return ws_fn
150
+
151
+
152
+ def use_websocket(config, dependency):
153
+ queue_enabled = config.get("enable_queue", False)
154
+ queue_uses_websocket = version.parse(
155
+ config.get("version", "2.0")
156
+ ) >= version.Version("3.2")
157
+ dependency_uses_queue = dependency.get("queue", False) is not False
158
+ return queue_enabled and queue_uses_websocket and dependency_uses_queue
159
+
160
+
161
+ ##################
162
+ # Helper function for cleaning up an Interface loaded from HF Spaces
163
+ ##################
164
+
165
+
166
+ def streamline_spaces_interface(config: Dict) -> Dict:
167
+ """Streamlines the interface config dictionary to remove unnecessary keys."""
168
+ config["inputs"] = [
169
+ components.get_component_instance(component)
170
+ for component in config["input_components"]
171
+ ]
172
+ config["outputs"] = [
173
+ components.get_component_instance(component)
174
+ for component in config["output_components"]
175
+ ]
176
+ parameters = {
177
+ "article",
178
+ "description",
179
+ "flagging_options",
180
+ "inputs",
181
+ "outputs",
182
+ "theme",
183
+ "title",
184
+ }
185
+ config = {k: config[k] for k in parameters}
186
+ return config
gradio-modified/gradio/flagging.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import datetime
5
+ import io
6
+ import json
7
+ import os
8
+ import uuid
9
+ from abc import ABC, abstractmethod
10
+ from pathlib import Path
11
+ from typing import TYPE_CHECKING, Any, List
12
+
13
+ import gradio as gr
14
+ from gradio import encryptor, utils
15
+ from gradio.documentation import document, set_documentation_group
16
+
17
+ if TYPE_CHECKING:
18
+ from gradio.components import IOComponent
19
+
20
+ set_documentation_group("flagging")
21
+
22
+
23
+ def _get_dataset_features_info(is_new, components):
24
+ """
25
+ Takes in a list of components and returns a dataset features info
26
+
27
+ Parameters:
28
+ is_new: boolean, whether the dataset is new or not
29
+ components: list of components
30
+
31
+ Returns:
32
+ infos: a dictionary of the dataset features
33
+ file_preview_types: dictionary mapping of gradio components to appropriate string.
34
+ header: list of header strings
35
+
36
+ """
37
+ infos = {"flagged": {"features": {}}}
38
+ # File previews for certain input and output types
39
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
40
+ headers = []
41
+
42
+ # Generate the headers and dataset_infos
43
+ if is_new:
44
+
45
+ for component in components:
46
+ headers.append(component.label)
47
+ infos["flagged"]["features"][component.label] = {
48
+ "dtype": "string",
49
+ "_type": "Value",
50
+ }
51
+ if isinstance(component, tuple(file_preview_types)):
52
+ headers.append(component.label + " file")
53
+ for _component, _type in file_preview_types.items():
54
+ if isinstance(component, _component):
55
+ infos["flagged"]["features"][
56
+ (component.label or "") + " file"
57
+ ] = {"_type": _type}
58
+ break
59
+
60
+ headers.append("flag")
61
+ infos["flagged"]["features"]["flag"] = {
62
+ "dtype": "string",
63
+ "_type": "Value",
64
+ }
65
+
66
+ return infos, file_preview_types, headers
67
+
68
+
69
+ class FlaggingCallback(ABC):
70
+ """
71
+ An abstract class for defining the methods that any FlaggingCallback should have.
72
+ """
73
+
74
+ @abstractmethod
75
+ def setup(self, components: List[IOComponent], flagging_dir: str):
76
+ """
77
+ This method should be overridden and ensure that everything is set up correctly for flag().
78
+ This method gets called once at the beginning of the Interface.launch() method.
79
+ Parameters:
80
+ components: Set of components that will provide flagged data.
81
+ flagging_dir: A string, typically containing the path to the directory where the flagging file should be storied (provided as an argument to Interface.__init__()).
82
+ """
83
+ pass
84
+
85
+ @abstractmethod
86
+ def flag(
87
+ self,
88
+ flag_data: List[Any],
89
+ flag_option: str | None = None,
90
+ flag_index: int | None = None,
91
+ username: str | None = None,
92
+ ) -> int:
93
+ """
94
+ This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
95
+ This gets called every time the <flag> button is pressed.
96
+ Parameters:
97
+ interface: The Interface object that is being used to launch the flagging interface.
98
+ flag_data: The data to be flagged.
99
+ flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
100
+ flag_index (optional): The index of the sample that is being flagged.
101
+ username (optional): The username of the user that is flagging the data, if logged in.
102
+ Returns:
103
+ (int) The total number of samples that have been flagged.
104
+ """
105
+ pass
106
+
107
+
108
+ @document()
109
+ class SimpleCSVLogger(FlaggingCallback):
110
+ """
111
+ A simplified implementation of the FlaggingCallback abstract class
112
+ provided for illustrative purposes. Each flagged sample (both the input and output data)
113
+ is logged to a CSV file on the machine running the gradio app.
114
+ Example:
115
+ import gradio as gr
116
+ def image_classifier(inp):
117
+ return {'cat': 0.3, 'dog': 0.7}
118
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
119
+ flagging_callback=SimpleCSVLogger())
120
+ """
121
+
122
+ def __init__(self):
123
+ pass
124
+
125
+ def setup(self, components: List[IOComponent], flagging_dir: str | Path):
126
+ self.components = components
127
+ self.flagging_dir = flagging_dir
128
+ os.makedirs(flagging_dir, exist_ok=True)
129
+
130
+ def flag(
131
+ self,
132
+ flag_data: List[Any],
133
+ flag_option: str | None = None,
134
+ flag_index: int | None = None,
135
+ username: str | None = None,
136
+ ) -> int:
137
+ flagging_dir = self.flagging_dir
138
+ log_filepath = Path(flagging_dir) / "log.csv"
139
+
140
+ csv_data = []
141
+ for component, sample in zip(self.components, flag_data):
142
+ save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
143
+ component.label or ""
144
+ )
145
+ csv_data.append(
146
+ component.deserialize(
147
+ sample,
148
+ save_dir,
149
+ None,
150
+ )
151
+ )
152
+
153
+ with open(log_filepath, "a", newline="") as csvfile:
154
+ writer = csv.writer(csvfile)
155
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
156
+
157
+ with open(log_filepath, "r") as csvfile:
158
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
159
+ return line_count
160
+
161
+
162
+ @document()
163
+ class CSVLogger(FlaggingCallback):
164
+ """
165
+ The default implementation of the FlaggingCallback abstract class. Each flagged
166
+ sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
167
+ Example:
168
+ import gradio as gr
169
+ def image_classifier(inp):
170
+ return {'cat': 0.3, 'dog': 0.7}
171
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
172
+ flagging_callback=CSVLogger())
173
+ Guides: using_flagging
174
+ """
175
+
176
+ def __init__(self):
177
+ pass
178
+
179
+ def setup(
180
+ self,
181
+ components: List[IOComponent],
182
+ flagging_dir: str | Path,
183
+ encryption_key: bytes | None = None,
184
+ ):
185
+ self.components = components
186
+ self.flagging_dir = flagging_dir
187
+ self.encryption_key = encryption_key
188
+ os.makedirs(flagging_dir, exist_ok=True)
189
+
190
+ def flag(
191
+ self,
192
+ flag_data: List[Any],
193
+ flag_option: str | None = None,
194
+ flag_index: int | None = None,
195
+ username: str | None = None,
196
+ ) -> int:
197
+ flagging_dir = self.flagging_dir
198
+ log_filepath = Path(flagging_dir) / "log.csv"
199
+ is_new = not Path(log_filepath).exists()
200
+ headers = [
201
+ component.label or f"component {idx}"
202
+ for idx, component in enumerate(self.components)
203
+ ] + [
204
+ "flag",
205
+ "username",
206
+ "timestamp",
207
+ ]
208
+
209
+ csv_data = []
210
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
211
+ save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
212
+ component.label or f"component {idx}"
213
+ )
214
+ if utils.is_update(sample):
215
+ csv_data.append(str(sample))
216
+ else:
217
+ csv_data.append(
218
+ component.deserialize(
219
+ sample,
220
+ save_dir=save_dir,
221
+ encryption_key=self.encryption_key,
222
+ )
223
+ if sample is not None
224
+ else ""
225
+ )
226
+ csv_data.append(flag_option if flag_option is not None else "")
227
+ csv_data.append(username if username is not None else "")
228
+ csv_data.append(str(datetime.datetime.now()))
229
+
230
+ def replace_flag_at_index(file_content: str, flag_index: int):
231
+ file_content_ = io.StringIO(file_content)
232
+ content = list(csv.reader(file_content_))
233
+ header = content[0]
234
+ flag_col_index = header.index("flag")
235
+ content[flag_index][flag_col_index] = flag_option # type: ignore
236
+ output = io.StringIO()
237
+ writer = csv.writer(output)
238
+ writer.writerows(utils.sanitize_list_for_csv(content))
239
+ return output.getvalue()
240
+
241
+ if self.encryption_key:
242
+ output = io.StringIO()
243
+ if not is_new:
244
+ with open(log_filepath, "rb", encoding="utf-8") as csvfile:
245
+ encrypted_csv = csvfile.read()
246
+ decrypted_csv = encryptor.decrypt(
247
+ self.encryption_key, encrypted_csv
248
+ )
249
+ file_content = decrypted_csv.decode()
250
+ if flag_index is not None:
251
+ file_content = replace_flag_at_index(file_content, flag_index)
252
+ output.write(file_content)
253
+ writer = csv.writer(output)
254
+ if flag_index is None:
255
+ if is_new:
256
+ writer.writerow(utils.sanitize_list_for_csv(headers))
257
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
258
+ with open(log_filepath, "wb", encoding="utf-8") as csvfile:
259
+ csvfile.write(
260
+ encryptor.encrypt(self.encryption_key, output.getvalue().encode())
261
+ )
262
+ else:
263
+ if flag_index is None:
264
+ with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
265
+ writer = csv.writer(csvfile)
266
+ if is_new:
267
+ writer.writerow(utils.sanitize_list_for_csv(headers))
268
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
269
+ else:
270
+ with open(log_filepath, encoding="utf-8") as csvfile:
271
+ file_content = csvfile.read()
272
+ file_content = replace_flag_at_index(file_content, flag_index)
273
+ with open(
274
+ log_filepath, "w", newline="", encoding="utf-8"
275
+ ) as csvfile: # newline parameter needed for Windows
276
+ csvfile.write(file_content)
277
+ with open(log_filepath, "r", encoding="utf-8") as csvfile:
278
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
279
+ return line_count
280
+
281
+
282
+ @document()
283
+ class HuggingFaceDatasetSaver(FlaggingCallback):
284
+ """
285
+ A callback that saves each flagged sample (both the input and output data)
286
+ to a HuggingFace dataset.
287
+ Example:
288
+ import gradio as gr
289
+ hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
290
+ def image_classifier(inp):
291
+ return {'cat': 0.3, 'dog': 0.7}
292
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
293
+ allow_flagging="manual", flagging_callback=hf_writer)
294
+ Guides: using_flagging
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ hf_token: str,
300
+ dataset_name: str,
301
+ organization: str | None = None,
302
+ private: bool = False,
303
+ ):
304
+ """
305
+ Parameters:
306
+ hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
307
+ dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
308
+ organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
309
+ private: Whether the dataset should be private (defaults to False).
310
+ """
311
+ self.hf_token = hf_token
312
+ self.dataset_name = dataset_name
313
+ self.organization_name = organization
314
+ self.dataset_private = private
315
+
316
+ def setup(self, components: List[IOComponent], flagging_dir: str):
317
+ """
318
+ Params:
319
+ flagging_dir (str): local directory where the dataset is cloned,
320
+ updated, and pushed from.
321
+ """
322
+ try:
323
+ import huggingface_hub
324
+ except (ImportError, ModuleNotFoundError):
325
+ raise ImportError(
326
+ "Package `huggingface_hub` not found is needed "
327
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
328
+ )
329
+ path_to_dataset_repo = huggingface_hub.create_repo(
330
+ name=self.dataset_name,
331
+ token=self.hf_token,
332
+ private=self.dataset_private,
333
+ repo_type="dataset",
334
+ exist_ok=True,
335
+ )
336
+ self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
337
+ self.components = components
338
+ self.flagging_dir = flagging_dir
339
+ self.dataset_dir = Path(flagging_dir) / self.dataset_name
340
+ self.repo = huggingface_hub.Repository(
341
+ local_dir=str(self.dataset_dir),
342
+ clone_from=path_to_dataset_repo,
343
+ use_auth_token=self.hf_token,
344
+ )
345
+ self.repo.git_pull(lfs=True)
346
+
347
+ # Should filename be user-specified?
348
+ self.log_file = Path(self.dataset_dir) / "data.csv"
349
+ self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
350
+
351
+ def flag(
352
+ self,
353
+ flag_data: List[Any],
354
+ flag_option: str | None = None,
355
+ flag_index: int | None = None,
356
+ username: str | None = None,
357
+ ) -> int:
358
+ self.repo.git_pull(lfs=True)
359
+
360
+ is_new = not Path(self.log_file).exists()
361
+
362
+ with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
363
+ writer = csv.writer(csvfile)
364
+
365
+ # File previews for certain input and output types
366
+ infos, file_preview_types, headers = _get_dataset_features_info(
367
+ is_new, self.components
368
+ )
369
+
370
+ # Generate the headers and dataset_infos
371
+ if is_new:
372
+ writer.writerow(utils.sanitize_list_for_csv(headers))
373
+
374
+ # Generate the row corresponding to the flagged sample
375
+ csv_data = []
376
+ for component, sample in zip(self.components, flag_data):
377
+ save_dir = Path(
378
+ self.dataset_dir
379
+ ) / utils.strip_invalid_filename_characters(component.label or "")
380
+ filepath = component.deserialize(sample, save_dir, None)
381
+ csv_data.append(filepath)
382
+ if isinstance(component, tuple(file_preview_types)):
383
+ csv_data.append(
384
+ "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
385
+ )
386
+ csv_data.append(flag_option if flag_option is not None else "")
387
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
388
+
389
+ if is_new:
390
+ json.dump(infos, open(self.infos_file, "w"))
391
+
392
+ with open(self.log_file, "r", encoding="utf-8") as csvfile:
393
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
394
+
395
+ self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
396
+
397
+ return line_count
398
+
399
+
400
+ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
401
+ """
402
+ A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format.
403
+
404
+ Each data sample is saved in a different JSONL file,
405
+ allowing multiple users to use flagging simultaneously.
406
+ Saving to a single CSV would cause errors as only one user can edit at the same time.
407
+
408
+ """
409
+
410
+ def __init__(
411
+ self,
412
+ hf_foken: str,
413
+ dataset_name: str,
414
+ organization: str | None = None,
415
+ private: bool = False,
416
+ verbose: bool = True,
417
+ ):
418
+ """
419
+ Params:
420
+ hf_token (str): The token to use to access the huggingface API.
421
+ dataset_name (str): The name of the dataset to save the data to, e.g.
422
+ "image-classifier-1"
423
+ organization (str): The name of the organization to which to attach
424
+ the datasets. If None, the dataset attaches to the user only.
425
+ private (bool): If the dataset does not already exist, whether it
426
+ should be created as a private dataset or public. Private datasets
427
+ may require paid huggingface.co accounts
428
+ verbose (bool): Whether to print out the status of the dataset
429
+ creation.
430
+ """
431
+ self.hf_foken = hf_foken
432
+ self.dataset_name = dataset_name
433
+ self.organization_name = organization
434
+ self.dataset_private = private
435
+ self.verbose = verbose
436
+
437
+ def setup(self, components: List[IOComponent], flagging_dir: str):
438
+ """
439
+ Params:
440
+ components List[Component]: list of components for flagging
441
+ flagging_dir (str): local directory where the dataset is cloned,
442
+ updated, and pushed from.
443
+ """
444
+ try:
445
+ import huggingface_hub
446
+ except (ImportError, ModuleNotFoundError):
447
+ raise ImportError(
448
+ "Package `huggingface_hub` not found is needed "
449
+ "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
450
+ )
451
+ path_to_dataset_repo = huggingface_hub.create_repo(
452
+ name=self.dataset_name,
453
+ token=self.hf_foken,
454
+ private=self.dataset_private,
455
+ repo_type="dataset",
456
+ exist_ok=True,
457
+ )
458
+ self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
459
+ self.components = components
460
+ self.flagging_dir = flagging_dir
461
+ self.dataset_dir = Path(flagging_dir) / self.dataset_name
462
+ self.repo = huggingface_hub.Repository(
463
+ local_dir=str(self.dataset_dir),
464
+ clone_from=path_to_dataset_repo,
465
+ use_auth_token=self.hf_foken,
466
+ )
467
+ self.repo.git_pull(lfs=True)
468
+
469
+ self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
470
+
471
+ def flag(
472
+ self,
473
+ flag_data: List[Any],
474
+ flag_option: str | None = None,
475
+ flag_index: int | None = None,
476
+ username: str | None = None,
477
+ ) -> str:
478
+ self.repo.git_pull(lfs=True)
479
+
480
+ # Generate unique folder for the flagged sample
481
+ unique_name = self.get_unique_name() # unique name for folder
482
+ folder_name = (
483
+ Path(self.dataset_dir) / unique_name
484
+ ) # unique folder for specific example
485
+ os.makedirs(folder_name)
486
+
487
+ # Now uses the existence of `dataset_infos.json` to determine if new
488
+ is_new = not Path(self.infos_file).exists()
489
+
490
+ # File previews for certain input and output types
491
+ infos, file_preview_types, _ = _get_dataset_features_info(
492
+ is_new, self.components
493
+ )
494
+
495
+ # Generate the row and header corresponding to the flagged sample
496
+ csv_data = []
497
+ headers = []
498
+
499
+ for component, sample in zip(self.components, flag_data):
500
+ headers.append(component.label)
501
+
502
+ try:
503
+ save_dir = Path(folder_name) / utils.strip_invalid_filename_characters(
504
+ component.label or ""
505
+ )
506
+ filepath = component.deserialize(sample, save_dir, None)
507
+ except Exception:
508
+ # Could not parse 'sample' (mostly) because it was None and `component.save_flagged`
509
+ # does not handle None cases.
510
+ # for example: Label (line 3109 of components.py raises an error if data is None)
511
+ filepath = None
512
+
513
+ if isinstance(component, tuple(file_preview_types)):
514
+ headers.append(component.label or "" + " file")
515
+
516
+ csv_data.append(
517
+ "{}/resolve/main/{}/{}".format(
518
+ self.path_to_dataset_repo, unique_name, filepath
519
+ )
520
+ if filepath is not None
521
+ else None
522
+ )
523
+
524
+ csv_data.append(filepath)
525
+ headers.append("flag")
526
+ csv_data.append(flag_option if flag_option is not None else "")
527
+
528
+ # Creates metadata dict from row data and dumps it
529
+ metadata_dict = {
530
+ header: _csv_data for header, _csv_data in zip(headers, csv_data)
531
+ }
532
+ self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
533
+
534
+ if is_new:
535
+ json.dump(infos, open(self.infos_file, "w"))
536
+
537
+ self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
538
+ return unique_name
539
+
540
+ def get_unique_name(self):
541
+ id = uuid.uuid4()
542
+ return str(id)
543
+
544
+ def dump_json(self, thing: dict, file_path: str | Path) -> None:
545
+ with open(file_path, "w+", encoding="utf8") as f:
546
+ json.dump(thing, f)
547
+
548
+
549
+ class FlagMethod:
550
+ """
551
+ Helper class that contains the flagging button option and callback
552
+ """
553
+
554
+ def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
555
+ self.flagging_callback = flagging_callback
556
+ self.flag_option = flag_option
557
+ self.__name__ = "Flag"
558
+
559
+ def __call__(self, *flag_data):
560
+ self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option)
gradio-modified/gradio/helpers.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines helper methods useful for loading and caching Interface examples.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import csv
8
+ import inspect
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import threading
13
+ import warnings
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple
16
+
17
+ import matplotlib
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import PIL
21
+
22
+ from gradio import processing_utils, routes, utils
23
+ from gradio.context import Context
24
+ from gradio.documentation import document, set_documentation_group
25
+ from gradio.flagging import CSVLogger
26
+
27
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
28
+ from gradio.components import IOComponent
29
+
30
+ CACHED_FOLDER = "gradio_cached_examples"
31
+ LOG_FILE = "log.csv"
32
+
33
+ set_documentation_group("helpers")
34
+
35
+
36
+ def create_examples(
37
+ examples: List[Any] | List[List[Any]] | str,
38
+ inputs: IOComponent | List[IOComponent],
39
+ outputs: IOComponent | List[IOComponent] | None = None,
40
+ fn: Callable | None = None,
41
+ cache_examples: bool = False,
42
+ examples_per_page: int = 10,
43
+ _api_mode: bool = False,
44
+ label: str | None = None,
45
+ elem_id: str | None = None,
46
+ run_on_click: bool = False,
47
+ preprocess: bool = True,
48
+ postprocess: bool = True,
49
+ batch: bool = False,
50
+ ):
51
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
52
+ examples_obj = Examples(
53
+ examples=examples,
54
+ inputs=inputs,
55
+ outputs=outputs,
56
+ fn=fn,
57
+ cache_examples=cache_examples,
58
+ examples_per_page=examples_per_page,
59
+ _api_mode=_api_mode,
60
+ label=label,
61
+ elem_id=elem_id,
62
+ run_on_click=run_on_click,
63
+ preprocess=preprocess,
64
+ postprocess=postprocess,
65
+ batch=batch,
66
+ _initiated_directly=False,
67
+ )
68
+ utils.synchronize_async(examples_obj.create)
69
+ return examples_obj
70
+
71
+
72
+ @document()
73
+ class Examples:
74
+ """
75
+ This class is a wrapper over the Dataset component and can be used to create Examples
76
+ for Blocks / Interfaces. Populates the Dataset component with examples and
77
+ assigns event listener so that clicking on an example populates the input/output
78
+ components. Optionally handles example caching for fast inference.
79
+
80
+ Demos: blocks_inputs, fake_gan
81
+ Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ examples: List[Any] | List[List[Any]] | str,
87
+ inputs: IOComponent | List[IOComponent],
88
+ outputs: Optional[IOComponent | List[IOComponent]] = None,
89
+ fn: Optional[Callable] = None,
90
+ cache_examples: bool = False,
91
+ examples_per_page: int = 10,
92
+ _api_mode: bool = False,
93
+ label: str = "Examples",
94
+ elem_id: Optional[str] = None,
95
+ run_on_click: bool = False,
96
+ preprocess: bool = True,
97
+ postprocess: bool = True,
98
+ batch: bool = False,
99
+ _initiated_directly: bool = True,
100
+ ):
101
+ """
102
+ Parameters:
103
+ examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
104
+ inputs: the component or list of components corresponding to the examples
105
+ outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
106
+ fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
107
+ cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
108
+ examples_per_page: how many examples to show per page.
109
+ label: the label to use for the examples component (by default, "Examples")
110
+ elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
111
+ run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
112
+ preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
113
+ postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
114
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
115
+ """
116
+ if _initiated_directly:
117
+ warnings.warn(
118
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
119
+ )
120
+
121
+ if cache_examples and (fn is None or outputs is None):
122
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
123
+
124
+ if not isinstance(inputs, list):
125
+ inputs = [inputs]
126
+ if not isinstance(outputs, list):
127
+ outputs = [outputs]
128
+
129
+ working_directory = Path().absolute()
130
+
131
+ if examples is None:
132
+ raise ValueError("The parameter `examples` cannot be None")
133
+ elif isinstance(examples, list) and (
134
+ len(examples) == 0 or isinstance(examples[0], list)
135
+ ):
136
+ pass
137
+ elif (
138
+ isinstance(examples, list) and len(inputs) == 1
139
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
140
+ examples = [[e] for e in examples]
141
+ elif isinstance(examples, str):
142
+ if not os.path.exists(examples):
143
+ raise FileNotFoundError(
144
+ "Could not find examples directory: " + examples
145
+ )
146
+ working_directory = examples
147
+ if not os.path.exists(os.path.join(examples, LOG_FILE)):
148
+ if len(inputs) == 1:
149
+ examples = [[e] for e in os.listdir(examples)]
150
+ else:
151
+ raise FileNotFoundError(
152
+ "Could not find log file (required for multiple inputs): "
153
+ + LOG_FILE
154
+ )
155
+ else:
156
+ with open(os.path.join(examples, LOG_FILE)) as logs:
157
+ examples = list(csv.reader(logs))
158
+ examples = [
159
+ examples[i][: len(inputs)] for i in range(1, len(examples))
160
+ ] # remove header and unnecessary columns
161
+
162
+ else:
163
+ raise ValueError(
164
+ "The parameter `examples` must either be a string directory or a list"
165
+ "(if there is only 1 input component) or (more generally), a nested "
166
+ "list, where each sublist represents a set of inputs."
167
+ )
168
+
169
+ input_has_examples = [False] * len(inputs)
170
+ for example in examples:
171
+ for idx, example_for_input in enumerate(example):
172
+ if not (example_for_input is None):
173
+ try:
174
+ input_has_examples[idx] = True
175
+ except IndexError:
176
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
177
+
178
+ inputs_with_examples = [
179
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
180
+ ]
181
+ non_none_examples = [
182
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
183
+ for example in examples
184
+ ]
185
+
186
+ self.examples = examples
187
+ self.non_none_examples = non_none_examples
188
+ self.inputs = inputs
189
+ self.inputs_with_examples = inputs_with_examples
190
+ self.outputs = outputs
191
+ self.fn = fn
192
+ self.cache_examples = cache_examples
193
+ self._api_mode = _api_mode
194
+ self.preprocess = preprocess
195
+ self.postprocess = postprocess
196
+ self.batch = batch
197
+
198
+ with utils.set_directory(working_directory):
199
+ self.processed_examples = [
200
+ [
201
+ component.postprocess(sample)
202
+ for component, sample in zip(inputs, example)
203
+ ]
204
+ for example in examples
205
+ ]
206
+ self.non_none_processed_examples = [
207
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
208
+ for example in self.processed_examples
209
+ ]
210
+ if cache_examples:
211
+ for example in self.examples:
212
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
213
+ warnings.warn(
214
+ "Examples are being cached but not all input components have "
215
+ "example values. This may result in an exception being thrown by "
216
+ "your function. If you do get an error while caching examples, make "
217
+ "sure all of your inputs have example values for all of your examples "
218
+ "or you provide default values for those particular parameters in your function."
219
+ )
220
+ break
221
+
222
+ from gradio.components import Dataset
223
+
224
+ with utils.set_directory(working_directory):
225
+ self.dataset = Dataset(
226
+ components=inputs_with_examples,
227
+ samples=non_none_examples,
228
+ type="index",
229
+ label=label,
230
+ samples_per_page=examples_per_page,
231
+ elem_id=elem_id,
232
+ )
233
+
234
+ self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
235
+ self.cached_file = os.path.join(self.cached_folder, "log.csv")
236
+ self.cache_examples = cache_examples
237
+ self.run_on_click = run_on_click
238
+
239
+ async def create(self) -> None:
240
+ """Caches the examples if self.cache_examples is True and creates the Dataset
241
+ component to hold the examples"""
242
+
243
+ async def load_example(example_id):
244
+ if self.cache_examples:
245
+ processed_example = self.non_none_processed_examples[
246
+ example_id
247
+ ] + await self.load_from_cache(example_id)
248
+ else:
249
+ processed_example = self.non_none_processed_examples[example_id]
250
+ return utils.resolve_singleton(processed_example)
251
+
252
+ if Context.root_block:
253
+ self.dataset.click(
254
+ load_example,
255
+ inputs=[self.dataset],
256
+ outputs=self.inputs_with_examples
257
+ + (self.outputs if self.cache_examples else []),
258
+ postprocess=False,
259
+ queue=False,
260
+ )
261
+ if self.run_on_click and not self.cache_examples:
262
+ self.dataset.click(
263
+ self.fn,
264
+ inputs=self.inputs,
265
+ outputs=self.outputs,
266
+ )
267
+
268
+ if self.cache_examples:
269
+ await self.cache()
270
+
271
+ async def cache(self) -> None:
272
+ """
273
+ Caches all of the examples so that their predictions can be shown immediately.
274
+ """
275
+ if os.path.exists(self.cached_file):
276
+ print(
277
+ f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
278
+ )
279
+ else:
280
+ if Context.root_block is None:
281
+ raise ValueError("Cannot cache examples if not in a Blocks context")
282
+
283
+ print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
284
+ cache_logger = CSVLogger()
285
+
286
+ # create a fake dependency to process the examples and get the predictions
287
+ dependency = Context.root_block.set_event_trigger(
288
+ event_name="fake_event",
289
+ fn=self.fn,
290
+ inputs=self.inputs_with_examples,
291
+ outputs=self.outputs,
292
+ preprocess=self.preprocess and not self._api_mode,
293
+ postprocess=self.postprocess and not self._api_mode,
294
+ batch=self.batch,
295
+ )
296
+
297
+ fn_index = Context.root_block.dependencies.index(dependency)
298
+ cache_logger.setup(self.outputs, self.cached_folder)
299
+ for example_id, _ in enumerate(self.examples):
300
+ processed_input = self.processed_examples[example_id]
301
+ if self.batch:
302
+ processed_input = [[value] for value in processed_input]
303
+ prediction = await Context.root_block.process_api(
304
+ fn_index=fn_index, inputs=processed_input, request=None, state={}
305
+ )
306
+ output = prediction["data"]
307
+ if self.batch:
308
+ output = [value[0] for value in output]
309
+ cache_logger.flag(output)
310
+ # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
311
+ Context.root_block.dependencies.remove(dependency)
312
+ Context.root_block.fns.pop(fn_index)
313
+
314
+ async def load_from_cache(self, example_id: int) -> List[Any]:
315
+ """Loads a particular cached example for the interface.
316
+ Parameters:
317
+ example_id: The id of the example to process (zero-indexed).
318
+ """
319
+ with open(self.cached_file) as cache:
320
+ examples = list(csv.reader(cache))
321
+ example = examples[example_id + 1] # +1 to adjust for header
322
+ output = []
323
+ for component, value in zip(self.outputs, example):
324
+ try:
325
+ value_as_dict = ast.literal_eval(value)
326
+ assert utils.is_update(value_as_dict)
327
+ output.append(value_as_dict)
328
+ except (ValueError, TypeError, SyntaxError, AssertionError):
329
+ output.append(component.serialize(value, self.cached_folder))
330
+ return output
331
+
332
+
333
+ class TrackedIterable:
334
+ def __init__(
335
+ self,
336
+ iterable: Iterable,
337
+ index: int | None,
338
+ length: int | None,
339
+ desc: str | None,
340
+ unit: str | None,
341
+ _tqdm=None,
342
+ progress: float = None,
343
+ ) -> None:
344
+ self.iterable = iterable
345
+ self.index = index
346
+ self.length = length
347
+ self.desc = desc
348
+ self.unit = unit
349
+ self._tqdm = _tqdm
350
+ self.progress = progress
351
+
352
+
353
+ @document("__call__", "tqdm")
354
+ class Progress(Iterable):
355
+ """
356
+ The Progress class provides a custom progress tracker that is used in a function signature.
357
+ To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance.
358
+ The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable.
359
+ The Progress tracker is currently only available with `queue()`.
360
+ Example:
361
+ import gradio as gr
362
+ import time
363
+ def my_function(x, progress=gr.Progress()):
364
+ progress(0, desc="Starting...")
365
+ time.sleep(1)
366
+ for i in progress.tqdm(range(100)):
367
+ time.sleep(0.1)
368
+ return x
369
+ gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch()
370
+ Demos: progress
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ track_tqdm: bool = False,
376
+ _active: bool = False,
377
+ _callback: Callable = None,
378
+ _event_id: str = None,
379
+ ):
380
+ """
381
+ Parameters:
382
+ track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function.
383
+ """
384
+ self.track_tqdm = track_tqdm
385
+ self._active = _active
386
+ self._callback = _callback
387
+ self._event_id = _event_id
388
+ self.iterables: List[TrackedIterable] = []
389
+
390
+ def __len__(self):
391
+ return self.iterables[-1].length
392
+
393
+ def __iter__(self):
394
+ return self
395
+
396
+ def __next__(self):
397
+ """
398
+ Updates progress tracker with next item in iterable.
399
+ """
400
+ if self._active:
401
+ current_iterable = self.iterables[-1]
402
+ while (
403
+ not hasattr(current_iterable.iterable, "__next__")
404
+ and len(self.iterables) > 0
405
+ ):
406
+ current_iterable = self.iterables.pop()
407
+ self._callback(
408
+ event_id=self._event_id,
409
+ iterables=self.iterables,
410
+ )
411
+ current_iterable.index += 1
412
+ try:
413
+ return next(current_iterable.iterable)
414
+ except StopIteration:
415
+ self.iterables.pop()
416
+ raise StopIteration
417
+ else:
418
+ return self
419
+
420
+ def __call__(
421
+ self,
422
+ progress: float | Tuple[int, int | None] | None,
423
+ desc: str | None = None,
424
+ total: float | None = None,
425
+ unit: str = "steps",
426
+ _tqdm=None,
427
+ ):
428
+ """
429
+ Updates progress tracker with progress and message text.
430
+ Parameters:
431
+ progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar.
432
+ desc: description to display.
433
+ total: estimated total number of steps.
434
+ unit: unit of iterations.
435
+ """
436
+ if self._active:
437
+ if isinstance(progress, tuple):
438
+ index, total = progress
439
+ progress = None
440
+ else:
441
+ index = None
442
+ self._callback(
443
+ event_id=self._event_id,
444
+ iterables=self.iterables
445
+ + [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)],
446
+ )
447
+ else:
448
+ return progress
449
+
450
+ def tqdm(
451
+ self,
452
+ iterable: Iterable | None,
453
+ desc: str = None,
454
+ total: float = None,
455
+ unit: str = "steps",
456
+ _tqdm=None,
457
+ *args,
458
+ **kwargs,
459
+ ):
460
+ """
461
+ Attaches progress tracker to iterable, like tqdm.
462
+ Parameters:
463
+ iterable: iterable to attach progress tracker to.
464
+ desc: description to display.
465
+ total: estimated total number of steps.
466
+ unit: unit of iterations.
467
+ """
468
+ if iterable is None:
469
+ new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm)
470
+ self.iterables.append(new_iterable)
471
+ self._callback(event_id=self._event_id, iterables=self.iterables)
472
+ return
473
+ length = len(iterable) if hasattr(iterable, "__len__") else None
474
+ self.iterables.append(
475
+ TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm)
476
+ )
477
+ return self
478
+
479
+ def update(self, n=1):
480
+ """
481
+ Increases latest iterable with specified number of steps.
482
+ Parameters:
483
+ n: number of steps completed.
484
+ """
485
+ if self._active and len(self.iterables) > 0:
486
+ current_iterable = self.iterables[-1]
487
+ current_iterable.index += n
488
+ self._callback(
489
+ event_id=self._event_id,
490
+ iterables=self.iterables,
491
+ )
492
+ else:
493
+ return
494
+
495
+ def close(self, _tqdm):
496
+ """
497
+ Removes iterable with given _tqdm.
498
+ """
499
+ if self._active:
500
+ for i in range(len(self.iterables)):
501
+ if id(self.iterables[i]._tqdm) == id(_tqdm):
502
+ self.iterables.pop(i)
503
+ break
504
+ self._callback(
505
+ event_id=self._event_id,
506
+ iterables=self.iterables,
507
+ )
508
+ else:
509
+ return
510
+
511
+
512
+ def create_tracker(root_blocks, event_id, fn, track_tqdm):
513
+
514
+ progress = Progress(
515
+ _active=True, _callback=root_blocks._queue.set_progress, _event_id=event_id
516
+ )
517
+ if not track_tqdm:
518
+ return progress, fn
519
+
520
+ try:
521
+ _tqdm = __import__("tqdm")
522
+ except ModuleNotFoundError:
523
+ return progress, fn
524
+ if not hasattr(root_blocks, "_progress_tracker_per_thread"):
525
+ root_blocks._progress_tracker_per_thread = {}
526
+
527
+ def init_tqdm(self, iterable=None, desc=None, *args, **kwargs):
528
+ self._progress = root_blocks._progress_tracker_per_thread.get(
529
+ threading.get_ident()
530
+ )
531
+ if self._progress is not None:
532
+ self._progress.event_id = event_id
533
+ self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
534
+ kwargs["file"] = open(os.devnull, "w")
535
+ self.__init__orig__(iterable, desc, *args, **kwargs)
536
+
537
+ def iter_tqdm(self):
538
+ if self._progress is not None:
539
+ return self._progress
540
+ else:
541
+ return self.__iter__orig__()
542
+
543
+ def update_tqdm(self, n=1):
544
+ if self._progress is not None:
545
+ self._progress.update(n)
546
+ return self.__update__orig__(n)
547
+
548
+ def close_tqdm(self):
549
+ if self._progress is not None:
550
+ self._progress.close(self)
551
+ return self.__close__orig__()
552
+
553
+ def exit_tqdm(self, exc_type, exc_value, traceback):
554
+ if self._progress is not None:
555
+ self._progress.close(self)
556
+ return self.__exit__orig__(exc_type, exc_value, traceback)
557
+
558
+ if not hasattr(_tqdm.tqdm, "__init__orig__"):
559
+ _tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__
560
+ _tqdm.tqdm.__init__ = init_tqdm
561
+ if not hasattr(_tqdm.tqdm, "__update__orig__"):
562
+ _tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update
563
+ _tqdm.tqdm.update = update_tqdm
564
+ if not hasattr(_tqdm.tqdm, "__close__orig__"):
565
+ _tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close
566
+ _tqdm.tqdm.close = close_tqdm
567
+ if not hasattr(_tqdm.tqdm, "__exit__orig__"):
568
+ _tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__
569
+ _tqdm.tqdm.__exit__ = exit_tqdm
570
+ if not hasattr(_tqdm.tqdm, "__iter__orig__"):
571
+ _tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__
572
+ _tqdm.tqdm.__iter__ = iter_tqdm
573
+ if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"):
574
+ _tqdm.auto.tqdm = _tqdm.tqdm
575
+
576
+ def tracked_fn(*args):
577
+ thread_id = threading.get_ident()
578
+ root_blocks._progress_tracker_per_thread[thread_id] = progress
579
+ response = fn(*args)
580
+ del root_blocks._progress_tracker_per_thread[thread_id]
581
+ return response
582
+
583
+ return progress, tracked_fn
584
+
585
+
586
+ def special_args(
587
+ fn: Callable,
588
+ inputs: List[Any] | None = None,
589
+ request: routes.Request | None = None,
590
+ ):
591
+ """
592
+ Checks if function has special arguments Request (via annotation) or Progress (via default value).
593
+ If inputs is provided, these values will be loaded into the inputs array.
594
+ Parameters:
595
+ block_fn: function to check.
596
+ inputs: array to load special arguments into.
597
+ request: request to load into inputs.
598
+ Returns:
599
+ updated inputs, request index, progress index
600
+ """
601
+ signature = inspect.signature(fn)
602
+ positional_args = []
603
+ for i, param in enumerate(signature.parameters.values()):
604
+ if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
605
+ break
606
+ positional_args.append(param)
607
+ progress_index = None
608
+ for i, param in enumerate(positional_args):
609
+ if isinstance(param.default, Progress):
610
+ progress_index = i
611
+ if inputs is not None:
612
+ inputs.insert(i, param.default)
613
+ elif param.annotation == routes.Request:
614
+ if inputs is not None:
615
+ inputs.insert(i, request)
616
+ if inputs is not None:
617
+ while len(inputs) < len(positional_args):
618
+ i = len(inputs)
619
+ param = positional_args[i]
620
+ if param.default == param.empty:
621
+ warnings.warn("Unexpected argument. Filling with None.")
622
+ inputs.append(None)
623
+ else:
624
+ inputs.append(param.default)
625
+ return inputs or [], progress_index
626
+
627
+
628
+ @document()
629
+ def update(**kwargs) -> dict:
630
+ """
631
+ Updates component properties. When a function passed into a Gradio Interface or a Blocks events returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component.
632
+ This is a shorthand for using the update method on a component.
633
+ For example, rather than using gr.Number.update(...) you can just use gr.update(...).
634
+ Note that your editor's autocompletion will suggest proper parameters
635
+ if you use the update method on the component.
636
+ Demos: blocks_essay, blocks_update, blocks_essay_update
637
+
638
+ Parameters:
639
+ kwargs: Key-word arguments used to update the component's properties.
640
+ Example:
641
+ # Blocks Example
642
+ import gradio as gr
643
+ with gr.Blocks() as demo:
644
+ radio = gr.Radio([1, 2, 4], label="Set the value of the number")
645
+ number = gr.Number(value=2, interactive=True)
646
+ radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
647
+ demo.launch()
648
+
649
+ # Interface example
650
+ import gradio as gr
651
+ def change_textbox(choice):
652
+ if choice == "short":
653
+ return gr.Textbox.update(lines=2, visible=True)
654
+ elif choice == "long":
655
+ return gr.Textbox.update(lines=8, visible=True)
656
+ else:
657
+ return gr.Textbox.update(visible=False)
658
+ gr.Interface(
659
+ change_textbox,
660
+ gr.Radio(
661
+ ["short", "long", "none"], label="What kind of essay would you like to write?"
662
+ ),
663
+ gr.Textbox(lines=2),
664
+ live=True,
665
+ ).launch()
666
+ """
667
+ kwargs["__type__"] = "generic_update"
668
+ return kwargs
669
+
670
+
671
+ def skip() -> dict:
672
+ return update()
673
+
674
+
675
+ @document()
676
+ def make_waveform(
677
+ audio: str | Tuple[int, np.ndarray],
678
+ *,
679
+ bg_color: str = "#f3f4f6",
680
+ bg_image: str = None,
681
+ fg_alpha: float = 0.75,
682
+ bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
683
+ bar_count: int = 50,
684
+ bar_width: float = 0.6,
685
+ ):
686
+ """
687
+ Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
688
+ Parameters:
689
+ audio: Audio file path or tuple of (sample_rate, audio_data)
690
+ bg_color: Background color of waveform (ignored if bg_image is provided)
691
+ bg_image: Background image of waveform
692
+ fg_alpha: Opacity of foreground waveform
693
+ bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
694
+ bar_count: Number of bars in waveform
695
+ bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
696
+ Returns:
697
+ A filepath to the output video.
698
+ """
699
+ if isinstance(audio, str):
700
+ audio_file = audio
701
+ audio = processing_utils.audio_from_file(audio)
702
+ else:
703
+ tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
704
+ processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
705
+ audio_file = tmp_wav.name
706
+ duration = round(len(audio[1]) / audio[0], 4)
707
+
708
+ # Helper methods to create waveform
709
+ def hex_to_RGB(hex_str):
710
+ return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
711
+
712
+ def get_color_gradient(c1, c2, n):
713
+ assert n > 1
714
+ c1_rgb = np.array(hex_to_RGB(c1)) / 255
715
+ c2_rgb = np.array(hex_to_RGB(c2)) / 255
716
+ mix_pcts = [x / (n - 1) for x in range(n)]
717
+ rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
718
+ return [
719
+ "#" + "".join([format(int(round(val * 255)), "02x") for val in item])
720
+ for item in rgb_colors
721
+ ]
722
+
723
+ # Reshape audio to have a fixed number of bars
724
+ samples = audio[1]
725
+ if len(samples.shape) > 1:
726
+ samples = np.mean(samples, 1)
727
+ bins_to_pad = bar_count - (len(samples) % bar_count)
728
+ samples = np.pad(samples, [(0, bins_to_pad)])
729
+ samples = np.reshape(samples, (bar_count, -1))
730
+ samples = np.abs(samples)
731
+ samples = np.max(samples, 1)
732
+
733
+ matplotlib.use("Agg")
734
+ plt.clf()
735
+ # Plot waveform
736
+ color = (
737
+ bars_color
738
+ if isinstance(bars_color, str)
739
+ else get_color_gradient(bars_color[0], bars_color[1], bar_count)
740
+ )
741
+ plt.bar(
742
+ np.arange(0, bar_count),
743
+ samples * 2,
744
+ bottom=(-1 * samples),
745
+ width=bar_width,
746
+ color=color,
747
+ )
748
+ plt.axis("off")
749
+ plt.margins(x=0)
750
+ tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
751
+ savefig_kwargs = {"bbox_inches": "tight"}
752
+ if bg_image is not None:
753
+ savefig_kwargs["transparent"] = True
754
+ else:
755
+ savefig_kwargs["facecolor"] = bg_color
756
+ plt.savefig(tmp_img.name, **savefig_kwargs)
757
+ waveform_img = PIL.Image.open(tmp_img.name)
758
+ waveform_img = waveform_img.resize((1000, 200))
759
+
760
+ # Composite waveform with background image
761
+ if bg_image is not None:
762
+ waveform_array = np.array(waveform_img)
763
+ waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
764
+ waveform_img = PIL.Image.fromarray(waveform_array)
765
+
766
+ bg_img = PIL.Image.open(bg_image)
767
+ waveform_width, waveform_height = waveform_img.size
768
+ bg_width, bg_height = bg_img.size
769
+ if waveform_width != bg_width:
770
+ bg_img = bg_img.resize(
771
+ (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2))
772
+ )
773
+ bg_width, bg_height = bg_img.size
774
+ composite_height = max(bg_height, waveform_height)
775
+ composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF")
776
+ composite.paste(bg_img, (0, composite_height - bg_height))
777
+ composite.paste(
778
+ waveform_img, (0, composite_height - waveform_height), waveform_img
779
+ )
780
+ composite.save(tmp_img.name)
781
+ img_width, img_height = composite.size
782
+ else:
783
+ img_width, img_height = waveform_img.size
784
+ waveform_img.save(tmp_img.name)
785
+
786
+ # Convert waveform to video with ffmpeg
787
+ output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
788
+
789
+ ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}"""
790
+
791
+ subprocess.call(ffmpeg_cmd, shell=True)
792
+ return output_mp4.name
gradio-modified/gradio/inputs.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # type: ignore
2
+ """
3
+ This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
4
+ `InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are
5
+ automatically added to a registry, which allows them to be easily referenced in other parts of the code.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from typing import Any, List, Optional, Tuple
12
+
13
+ from gradio import components
14
+
15
+
16
+ class Textbox(components.Textbox):
17
+ def __init__(
18
+ self,
19
+ lines: int = 1,
20
+ placeholder: Optional[str] = None,
21
+ default: str = "",
22
+ numeric: Optional[bool] = False,
23
+ type: Optional[str] = "text",
24
+ label: Optional[str] = None,
25
+ optional: bool = False,
26
+ ):
27
+ warnings.warn(
28
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
29
+ )
30
+ super().__init__(
31
+ value=default,
32
+ lines=lines,
33
+ placeholder=placeholder,
34
+ label=label,
35
+ numeric=numeric,
36
+ type=type,
37
+ optional=optional,
38
+ )
39
+
40
+
41
+ class Number(components.Number):
42
+ """
43
+ Component creates a field for user to enter numeric input. Provides a number as an argument to the wrapped function.
44
+ Input type: float
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ default: Optional[float] = None,
50
+ label: Optional[str] = None,
51
+ optional: bool = False,
52
+ ):
53
+ """
54
+ Parameters:
55
+ default (float): default value.
56
+ label (str): component name in interface.
57
+ optional (bool): If True, the interface can be submitted with no value for this component.
58
+ """
59
+ warnings.warn(
60
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
61
+ )
62
+ super().__init__(value=default, label=label, optional=optional)
63
+
64
+
65
+ class Slider(components.Slider):
66
+ """
67
+ Component creates a slider that ranges from `minimum` to `maximum`. Provides number as an argument to the wrapped function.
68
+ Input type: float
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ minimum: float = 0,
74
+ maximum: float = 100,
75
+ step: Optional[float] = None,
76
+ default: Optional[float] = None,
77
+ label: Optional[str] = None,
78
+ optional: bool = False,
79
+ ):
80
+ """
81
+ Parameters:
82
+ minimum (float): minimum value for slider.
83
+ maximum (float): maximum value for slider.
84
+ step (float): increment between slider values.
85
+ default (float): default value.
86
+ label (str): component name in interface.
87
+ optional (bool): this parameter is ignored.
88
+ """
89
+ warnings.warn(
90
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
91
+ )
92
+
93
+ super().__init__(
94
+ value=default,
95
+ minimum=minimum,
96
+ maximum=maximum,
97
+ step=step,
98
+ label=label,
99
+ optional=optional,
100
+ )
101
+
102
+
103
+ class Checkbox(components.Checkbox):
104
+ """
105
+ Component creates a checkbox that can be set to `True` or `False`. Provides a boolean as an argument to the wrapped function.
106
+ Input type: bool
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ default: bool = False,
112
+ label: Optional[str] = None,
113
+ optional: bool = False,
114
+ ):
115
+ """
116
+ Parameters:
117
+ label (str): component name in interface.
118
+ default (bool): if True, checked by default.
119
+ optional (bool): this parameter is ignored.
120
+ """
121
+ warnings.warn(
122
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
123
+ )
124
+ super().__init__(value=default, label=label, optional=optional)
125
+
126
+
127
+ class CheckboxGroup(components.CheckboxGroup):
128
+ """
129
+ Component creates a set of checkboxes of which a subset can be selected. Provides a list of strings representing the selected choices as an argument to the wrapped function.
130
+ Input type: Union[List[str], List[int]]
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ choices: List[str],
136
+ default: List[str] = [],
137
+ type: str = "value",
138
+ label: Optional[str] = None,
139
+ optional: bool = False,
140
+ ):
141
+ """
142
+ Parameters:
143
+ choices (List[str]): list of options to select from.
144
+ default (List[str]): default selected list of options.
145
+ type (str): Type of value to be returned by component. "value" returns the list of strings of the choices selected, "index" returns the list of indicies of the choices selected.
146
+ label (str): component name in interface.
147
+ optional (bool): this parameter is ignored.
148
+ """
149
+ warnings.warn(
150
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
151
+ )
152
+ super().__init__(
153
+ value=default,
154
+ choices=choices,
155
+ type=type,
156
+ label=label,
157
+ optional=optional,
158
+ )
159
+
160
+
161
+ class Radio(components.Radio):
162
+ """
163
+ Component creates a set of radio buttons of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
164
+ Input type: Union[str, int]
165
+ """
166
+
167
+ def __init__(
168
+ self,
169
+ choices: List[str],
170
+ type: str = "value",
171
+ default: Optional[str] = None,
172
+ label: Optional[str] = None,
173
+ optional: bool = False,
174
+ ):
175
+ """
176
+ Parameters:
177
+ choices (List[str]): list of options to select from.
178
+ type (str): Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
179
+ default (str): the button selected by default. If None, no button is selected by default.
180
+ label (str): component name in interface.
181
+ optional (bool): this parameter is ignored.
182
+ """
183
+ warnings.warn(
184
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
185
+ )
186
+ super().__init__(
187
+ choices=choices,
188
+ type=type,
189
+ value=default,
190
+ label=label,
191
+ optional=optional,
192
+ )
193
+
194
+
195
+ class Dropdown(components.Dropdown):
196
+ """
197
+ Component creates a dropdown of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
198
+ Input type: Union[str, int]
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ choices: List[str],
204
+ type: str = "value",
205
+ default: Optional[str] = None,
206
+ label: Optional[str] = None,
207
+ optional: bool = False,
208
+ ):
209
+ """
210
+ Parameters:
211
+ choices (List[str]): list of options to select from.
212
+ type (str): Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
213
+ default (str): default value selected in dropdown. If None, no value is selected by default.
214
+ label (str): component name in interface.
215
+ optional (bool): this parameter is ignored.
216
+ """
217
+ warnings.warn(
218
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
219
+ )
220
+ super().__init__(
221
+ choices=choices,
222
+ type=type,
223
+ value=default,
224
+ label=label,
225
+ optional=optional,
226
+ )
227
+
228
+
229
+ class Image(components.Image):
230
+ """
231
+ Component creates an image upload box with editing capabilities.
232
+ Input type: Union[numpy.array, PIL.Image, file-object]
233
+ """
234
+
235
+ def __init__(
236
+ self,
237
+ shape: Tuple[int, int] = None,
238
+ image_mode: str = "RGB",
239
+ invert_colors: bool = False,
240
+ source: str = "upload",
241
+ tool: str = "editor",
242
+ type: str = "numpy",
243
+ label: str = None,
244
+ optional: bool = False,
245
+ ):
246
+ """
247
+ Parameters:
248
+ shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
249
+ image_mode (str): How to process the uploaded image. Accepts any of the PIL image modes, e.g. "RGB" for color images, "RGBA" to include the transparency mask, "L" for black-and-white images.
250
+ invert_colors (bool): whether to invert the image as a preprocessing step.
251
+ source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
252
+ tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
253
+ type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
254
+ label (str): component name in interface.
255
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
256
+ """
257
+ warnings.warn(
258
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
259
+ )
260
+ super().__init__(
261
+ shape=shape,
262
+ image_mode=image_mode,
263
+ invert_colors=invert_colors,
264
+ source=source,
265
+ tool=tool,
266
+ type=type,
267
+ label=label,
268
+ optional=optional,
269
+ )
270
+
271
+
272
+ class Video(components.Video):
273
+ """
274
+ Component creates a video file upload that is converted to a file path.
275
+
276
+ Input type: filepath
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ type: Optional[str] = None,
282
+ source: str = "upload",
283
+ label: Optional[str] = None,
284
+ optional: bool = False,
285
+ ):
286
+ """
287
+ Parameters:
288
+ type (str): Type of video format to be returned by component, such as 'avi' or 'mp4'. If set to None, video will keep uploaded format.
289
+ source (str): Source of video. "upload" creates a box where user can drop an video file, "webcam" allows user to record a video from their webcam.
290
+ label (str): component name in interface.
291
+ optional (bool): If True, the interface can be submitted with no uploaded video, in which case the input value is None.
292
+ """
293
+ warnings.warn(
294
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
295
+ )
296
+ super().__init__(format=type, source=source, label=label, optional=optional)
297
+
298
+
299
+ class Audio(components.Audio):
300
+ """
301
+ Component accepts audio input files.
302
+ Input type: Union[Tuple[int, numpy.array], file-object, numpy.array]
303
+ """
304
+
305
+ def __init__(
306
+ self,
307
+ source: str = "upload",
308
+ type: str = "numpy",
309
+ label: str = None,
310
+ optional: bool = False,
311
+ ):
312
+ """
313
+ Parameters:
314
+ source (str): Source of audio. "upload" creates a box where user can drop an audio file, "microphone" creates a microphone input.
315
+ type (str): Type of value to be returned by component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
316
+ label (str): component name in interface.
317
+ optional (bool): If True, the interface can be submitted with no uploaded audio, in which case the input value is None.
318
+ """
319
+ warnings.warn(
320
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
321
+ )
322
+ super().__init__(source=source, type=type, label=label, optional=optional)
323
+
324
+
325
+ class File(components.File):
326
+ """
327
+ Component accepts generic file uploads.
328
+ Input type: Union[file-object, bytes, List[Union[file-object, bytes]]]
329
+ """
330
+
331
+ def __init__(
332
+ self,
333
+ file_count: str = "single",
334
+ type: str = "file",
335
+ label: Optional[str] = None,
336
+ keep_filename: bool = True,
337
+ optional: bool = False,
338
+ ):
339
+ """
340
+ Parameters:
341
+ file_count (str): if single, allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be list for each file in case of "multiple" or "directory".
342
+ type (str): Type of value to be returned by component. "file" returns a temporary file object whose path can be retrieved by file_obj.name, "binary" returns an bytes object.
343
+ label (str): component name in interface.
344
+ keep_filename (bool): DEPRECATED. Original filename always kept.
345
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
346
+ """
347
+ warnings.warn(
348
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
349
+ )
350
+ super().__init__(
351
+ file_count=file_count,
352
+ type=type,
353
+ label=label,
354
+ keep_filename=keep_filename,
355
+ optional=optional,
356
+ )
357
+
358
+
359
+ class Dataframe(components.Dataframe):
360
+ """
361
+ Component accepts 2D input through a spreadsheet interface.
362
+ Input type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ headers: Optional[List[str]] = None,
368
+ row_count: int = 3,
369
+ col_count: Optional[int] = 3,
370
+ datatype: str | List[str] = "str",
371
+ col_width: int | List[int] = None,
372
+ default: Optional[List[List[Any]]] = None,
373
+ type: str = "pandas",
374
+ label: Optional[str] = None,
375
+ optional: bool = False,
376
+ ):
377
+ """
378
+ Parameters:
379
+ headers (List[str]): Header names to dataframe. If None, no headers are shown.
380
+ row_count (int): Limit number of rows for input.
381
+ col_count (int): Limit number of columns for input. If equal to 1, return data will be one-dimensional. Ignored if `headers` is provided.
382
+ datatype (Union[str, List[str]]): Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", and "date".
383
+ col_width (Union[int, List[int]]): Width of columns in pixels. Can be provided as single value or list of values per column.
384
+ default (List[List[Any]]): Default value
385
+ type (str): Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python array.
386
+ label (str): component name in interface.
387
+ optional (bool): this parameter is ignored.
388
+ """
389
+ warnings.warn(
390
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
391
+ )
392
+ super().__init__(
393
+ value=default,
394
+ headers=headers,
395
+ row_count=row_count,
396
+ col_count=col_count,
397
+ datatype=datatype,
398
+ col_width=col_width,
399
+ type=type,
400
+ label=label,
401
+ optional=optional,
402
+ )
403
+
404
+
405
+ class Timeseries(components.Timeseries):
406
+ """
407
+ Component accepts pandas.DataFrame uploaded as a timeseries csv file.
408
+ Input type: pandas.DataFrame
409
+ """
410
+
411
+ def __init__(
412
+ self,
413
+ x: Optional[str] = None,
414
+ y: str | List[str] = None,
415
+ label: Optional[str] = None,
416
+ optional: bool = False,
417
+ ):
418
+ """
419
+ Parameters:
420
+ x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
421
+ y (Union[str, List[str]]): Column name of y series, or list of column names if multiple series. None if csv has no headers, in which case every column after first is a y series.
422
+ label (str): component name in interface.
423
+ optional (bool): If True, the interface can be submitted with no uploaded csv file, in which case the input value is None.
424
+ """
425
+ warnings.warn(
426
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
427
+ )
428
+ super().__init__(x=x, y=y, label=label, optional=optional)
429
+
430
+
431
+ class State(components.State):
432
+ """
433
+ Special hidden component that stores state across runs of the interface.
434
+ Input type: Any
435
+ """
436
+
437
+ def __init__(
438
+ self,
439
+ label: str = None,
440
+ default: Any = None,
441
+ ):
442
+ """
443
+ Parameters:
444
+ label (str): component name in interface (not used).
445
+ default (Any): the initial value of the state.
446
+ optional (bool): this parameter is ignored.
447
+ """
448
+ warnings.warn(
449
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
450
+ )
451
+ super().__init__(value=default, label=label)
452
+
453
+
454
+ class Image3D(components.Model3D):
455
+ """
456
+ Used for 3D image model output.
457
+ Input type: File object of type (.obj, glb, or .gltf)
458
+ """
459
+
460
+ def __init__(
461
+ self,
462
+ label: Optional[str] = None,
463
+ optional: bool = False,
464
+ ):
465
+ """
466
+ Parameters:
467
+ label (str): component name in interface.
468
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
469
+ """
470
+ warnings.warn(
471
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
472
+ )
473
+ super().__init__(label=label, optional=optional)
gradio-modified/gradio/interface.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is the core file in the `gradio` package, and defines the Interface class,
3
+ including various methods for constructing an interface and then launching it.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import inspect
9
+ import json
10
+ import os
11
+ import pkgutil
12
+ import re
13
+ import warnings
14
+ import weakref
15
+ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
16
+
17
+ from markdown_it import MarkdownIt
18
+ from mdit_py_plugins.dollarmath.index import dollarmath_plugin
19
+ from mdit_py_plugins.footnote.index import footnote_plugin
20
+
21
+ from gradio import Examples, interpretation, utils
22
+ from gradio.blocks import Blocks
23
+ from gradio.components import (
24
+ Button,
25
+ Interpretation,
26
+ IOComponent,
27
+ Markdown,
28
+ State,
29
+ get_component_instance,
30
+ )
31
+ from gradio.data_classes import InterfaceTypes
32
+ from gradio.documentation import document, set_documentation_group
33
+ from gradio.events import Changeable, Streamable
34
+ from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
35
+ from gradio.layouts import Column, Row, Tab, Tabs
36
+ from gradio.pipelines import load_from_pipeline
37
+
38
+ set_documentation_group("interface")
39
+
40
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
41
+ from transformers.pipelines.base import Pipeline
42
+
43
+
44
+ @document("launch", "load", "from_pipeline", "integrate", "queue")
45
+ class Interface(Blocks):
46
+ """
47
+ Interface is Gradio's main high-level class, and allows you to create a web-based GUI / demo
48
+ around a machine learning model (or any Python function) in a few lines of code.
49
+ You must specify three parameters: (1) the function to create a GUI for (2) the desired input components and
50
+ (3) the desired output components. Additional parameters can be used to control the appearance
51
+ and behavior of the demo.
52
+
53
+ Example:
54
+ import gradio as gr
55
+
56
+ def image_classifier(inp):
57
+ return {'cat': 0.3, 'dog': 0.7}
58
+
59
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
60
+ demo.launch()
61
+ Demos: hello_world, hello_world_3, gpt_j
62
+ Guides: quickstart, key_features, sharing_your_app, interface_state, reactive_interfaces, advanced_interface_features, setting_up_a_gradio_demo_for_maximum_performance
63
+ """
64
+
65
+ # stores references to all currently existing Interface instances
66
+ instances: weakref.WeakSet = weakref.WeakSet()
67
+
68
+ @classmethod
69
+ def get_instances(cls) -> List[Interface]:
70
+ """
71
+ :return: list of all current instances.
72
+ """
73
+ return list(Interface.instances)
74
+
75
+ @classmethod
76
+ def load(
77
+ cls,
78
+ name: str,
79
+ src: str | None = None,
80
+ api_key: str | None = None,
81
+ alias: str | None = None,
82
+ **kwargs,
83
+ ) -> Interface:
84
+ """
85
+ Class method that constructs an Interface from a Hugging Face repo. Can accept
86
+ model repos (if src is "models") or Space repos (if src is "spaces"). The input
87
+ and output components are automatically loaded from the repo.
88
+ Parameters:
89
+ name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
90
+ src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
91
+ api_key: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
92
+ alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
93
+ Returns:
94
+ a Gradio Interface object for the given model
95
+ Example:
96
+ import gradio as gr
97
+ description = "Story generation with GPT"
98
+ examples = [["An adventurer is approached by a mysterious stranger in the tavern for a new quest."]]
99
+ demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples)
100
+ demo.launch()
101
+ """
102
+ return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
103
+
104
+ @classmethod
105
+ def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
106
+ """
107
+ Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
108
+ The input and output components are automatically determined from the pipeline.
109
+ Parameters:
110
+ pipeline: the pipeline object to use.
111
+ Returns:
112
+ a Gradio Interface object from the given Pipeline
113
+ Example:
114
+ import gradio as gr
115
+ from transformers import pipeline
116
+ pipe = pipeline("image-classification")
117
+ gr.Interface.from_pipeline(pipe).launch()
118
+ """
119
+ interface_info = load_from_pipeline(pipeline)
120
+ kwargs = dict(interface_info, **kwargs)
121
+ interface = cls(**kwargs)
122
+ return interface
123
+
124
+ def __init__(
125
+ self,
126
+ fn: Callable,
127
+ inputs: str | IOComponent | List[str | IOComponent] | None,
128
+ outputs: str | IOComponent | List[str | IOComponent] | None,
129
+ examples: List[Any] | List[List[Any]] | str | None = None,
130
+ cache_examples: bool | None = None,
131
+ examples_per_page: int = 10,
132
+ live: bool = False,
133
+ interpretation: Callable | str | None = None,
134
+ num_shap: float = 2.0,
135
+ title: str | None = None,
136
+ description: str | None = None,
137
+ article: str | None = None,
138
+ thumbnail: str | None = None,
139
+ theme: str = "default",
140
+ css: str | None = None,
141
+ allow_flagging: str | None = None,
142
+ flagging_options: List[str] | None = None,
143
+ flagging_dir: str = "flagged",
144
+ flagging_callback: FlaggingCallback = CSVLogger(),
145
+ analytics_enabled: bool | None = None,
146
+ batch: bool = False,
147
+ max_batch_size: int = 4,
148
+ _api_mode: bool = False,
149
+ **kwargs,
150
+ ):
151
+ """
152
+ Parameters:
153
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
154
+ inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
155
+ outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
156
+ examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
157
+ cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
158
+ examples_per_page: If examples are provided, how many to display per page.
159
+ live: whether the interface should automatically rerun if any of the inputs change.
160
+ interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.
161
+ num_shap: a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
162
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
163
+ description: a description for the interface; if provided, appears above the input and output components and beneath the title in regular font. Accepts Markdown and HTML content.
164
+ article: an expanded article explaining the interface; if provided, appears below the input and output components in regular font. Accepts Markdown and HTML content.
165
+ thumbnail: path or url to image to use as display image when the web demo is shared on social media.
166
+ theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable.
167
+ css: custom css or path to custom css file to use with interface.
168
+ allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
169
+ flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual".
170
+ flagging_dir: what to name the directory where flagged data is stored.
171
+ flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
172
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
173
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
174
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
175
+ """
176
+ super().__init__(
177
+ analytics_enabled=analytics_enabled,
178
+ mode="interface",
179
+ css=css,
180
+ title=title or "Gradio",
181
+ theme=theme,
182
+ **kwargs,
183
+ )
184
+
185
+ if isinstance(fn, list):
186
+ raise DeprecationWarning(
187
+ "The `fn` parameter only accepts a single function, support for a list "
188
+ "of functions has been deprecated. Please use gradio.mix.Parallel "
189
+ "instead."
190
+ )
191
+
192
+ self.interface_type = InterfaceTypes.STANDARD
193
+ if (inputs is None or inputs == []) and (outputs is None or outputs == []):
194
+ raise ValueError("Must provide at least one of `inputs` or `outputs`")
195
+ elif outputs is None or outputs == []:
196
+ outputs = []
197
+ self.interface_type = InterfaceTypes.INPUT_ONLY
198
+ elif inputs is None or inputs == []:
199
+ inputs = []
200
+ self.interface_type = InterfaceTypes.OUTPUT_ONLY
201
+
202
+ assert isinstance(inputs, (str, list, IOComponent))
203
+ assert isinstance(outputs, (str, list, IOComponent))
204
+
205
+ if not isinstance(inputs, list):
206
+ inputs = [inputs]
207
+ if not isinstance(outputs, list):
208
+ outputs = [outputs]
209
+
210
+ if self.is_space and cache_examples is None:
211
+ self.cache_examples = True
212
+ else:
213
+ self.cache_examples = cache_examples or False
214
+
215
+ state_input_indexes = [
216
+ idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
217
+ ]
218
+ state_output_indexes = [
219
+ idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
220
+ ]
221
+
222
+ if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
223
+ pass
224
+ elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
225
+ raise ValueError(
226
+ "If using 'state', there must be exactly one state input and one state output."
227
+ )
228
+ else:
229
+ state_input_index = state_input_indexes[0]
230
+ state_output_index = state_output_indexes[0]
231
+ if inputs[state_input_index] == "state":
232
+ default = utils.get_default_args(fn)[state_input_index]
233
+ state_variable = State(value=default) # type: ignore
234
+ else:
235
+ state_variable = inputs[state_input_index]
236
+
237
+ inputs[state_input_index] = state_variable
238
+ outputs[state_output_index] = state_variable
239
+
240
+ if cache_examples:
241
+ warnings.warn(
242
+ "Cache examples cannot be used with state inputs and outputs."
243
+ "Setting cache_examples to False."
244
+ )
245
+ self.cache_examples = False
246
+
247
+ self.input_components = [
248
+ get_component_instance(i, render=False) for i in inputs
249
+ ]
250
+ self.output_components = [
251
+ get_component_instance(o, render=False) for o in outputs
252
+ ]
253
+
254
+ for component in self.input_components + self.output_components:
255
+ if not (isinstance(component, IOComponent)):
256
+ raise ValueError(
257
+ f"{component} is not a valid input/output component for Interface."
258
+ )
259
+
260
+ if len(self.input_components) == len(self.output_components):
261
+ same_components = [
262
+ i is o for i, o in zip(self.input_components, self.output_components)
263
+ ]
264
+ if all(same_components):
265
+ self.interface_type = InterfaceTypes.UNIFIED
266
+
267
+ if self.interface_type in [
268
+ InterfaceTypes.STANDARD,
269
+ InterfaceTypes.OUTPUT_ONLY,
270
+ ]:
271
+ for o in self.output_components:
272
+ assert isinstance(o, IOComponent)
273
+ o.interactive = False # Force output components to be non-interactive
274
+
275
+ if (
276
+ interpretation is None
277
+ or isinstance(interpretation, list)
278
+ or callable(interpretation)
279
+ ):
280
+ self.interpretation = interpretation
281
+ elif isinstance(interpretation, str):
282
+ self.interpretation = [
283
+ interpretation.lower() for _ in self.input_components
284
+ ]
285
+ else:
286
+ raise ValueError("Invalid value for parameter: interpretation")
287
+
288
+ self.api_mode = _api_mode
289
+ self.fn = fn
290
+ self.fn_durations = [0, 0]
291
+ self.__name__ = getattr(fn, "__name__", "fn")
292
+ self.live = live
293
+ self.title = title
294
+
295
+ CLEANER = re.compile("<.*?>")
296
+
297
+ def clean_html(raw_html):
298
+ cleantext = re.sub(CLEANER, "", raw_html)
299
+ return cleantext
300
+
301
+ md = (
302
+ MarkdownIt(
303
+ "js-default",
304
+ {
305
+ "linkify": True,
306
+ "typographer": True,
307
+ "html": True,
308
+ },
309
+ )
310
+ .use(dollarmath_plugin)
311
+ .use(footnote_plugin)
312
+ .enable("table")
313
+ )
314
+
315
+ simple_description = None
316
+ if description is not None:
317
+ description = md.render(description)
318
+ simple_description = clean_html(description)
319
+ self.simple_description = simple_description
320
+ self.description = description
321
+ if article is not None:
322
+ article = utils.readme_to_html(article)
323
+ article = md.render(article)
324
+ self.article = article
325
+
326
+ self.thumbnail = thumbnail
327
+ self.theme = theme or os.getenv("GRADIO_THEME", "default")
328
+ if not (self.theme == "default"):
329
+ warnings.warn("Currently, only the 'default' theme is supported.")
330
+
331
+ self.examples = examples
332
+ self.num_shap = num_shap
333
+ self.examples_per_page = examples_per_page
334
+
335
+ self.simple_server = None
336
+
337
+ # For analytics_enabled and allow_flagging: (1) first check for
338
+ # parameter, (2) check for env variable, (3) default to True/"manual"
339
+ self.analytics_enabled = (
340
+ analytics_enabled
341
+ if analytics_enabled is not None
342
+ else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
343
+ )
344
+ if allow_flagging is None:
345
+ allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
346
+ if allow_flagging is True:
347
+ warnings.warn(
348
+ "The `allow_flagging` parameter in `Interface` now"
349
+ "takes a string value ('auto', 'manual', or 'never')"
350
+ ", not a boolean. Setting parameter to: 'manual'."
351
+ )
352
+ self.allow_flagging = "manual"
353
+ elif allow_flagging == "manual":
354
+ self.allow_flagging = "manual"
355
+ elif allow_flagging is False:
356
+ warnings.warn(
357
+ "The `allow_flagging` parameter in `Interface` now"
358
+ "takes a string value ('auto', 'manual', or 'never')"
359
+ ", not a boolean. Setting parameter to: 'never'."
360
+ )
361
+ self.allow_flagging = "never"
362
+ elif allow_flagging == "never":
363
+ self.allow_flagging = "never"
364
+ elif allow_flagging == "auto":
365
+ self.allow_flagging = "auto"
366
+ else:
367
+ raise ValueError(
368
+ "Invalid value for `allow_flagging` parameter."
369
+ "Must be: 'auto', 'manual', or 'never'."
370
+ )
371
+
372
+ self.flagging_options = flagging_options
373
+ self.flagging_callback = flagging_callback
374
+ self.flagging_dir = flagging_dir
375
+ self.batch = batch
376
+ self.max_batch_size = max_batch_size
377
+
378
+ self.save_to = None # Used for selenium tests
379
+ self.share = None
380
+ self.share_url = None
381
+ self.local_url = None
382
+
383
+ self.favicon_path = None
384
+
385
+ if self.analytics_enabled:
386
+ data = {
387
+ "mode": self.mode,
388
+ "fn": fn,
389
+ "inputs": inputs,
390
+ "outputs": outputs,
391
+ "live": live,
392
+ "ip_address": self.ip_address,
393
+ "interpretation": interpretation,
394
+ "allow_flagging": allow_flagging,
395
+ "custom_css": self.css is not None,
396
+ "theme": self.theme,
397
+ "version": (pkgutil.get_data(__name__, "version.txt") or b"")
398
+ .decode("ascii")
399
+ .strip(),
400
+ }
401
+ utils.initiated_analytics(data)
402
+
403
+ utils.version_check()
404
+ Interface.instances.add(self)
405
+
406
+ param_names = inspect.getfullargspec(self.fn)[0]
407
+ for component, param_name in zip(self.input_components, param_names):
408
+ assert isinstance(component, IOComponent)
409
+ if component.label is None:
410
+ component.label = param_name
411
+ for i, component in enumerate(self.output_components):
412
+ assert isinstance(component, IOComponent)
413
+ if component.label is None:
414
+ if len(self.output_components) == 1:
415
+ component.label = "output"
416
+ else:
417
+ component.label = "output " + str(i)
418
+
419
+ if self.allow_flagging != "never":
420
+ if (
421
+ self.interface_type == InterfaceTypes.UNIFIED
422
+ or self.allow_flagging == "auto"
423
+ ):
424
+ self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
425
+ elif self.interface_type == InterfaceTypes.INPUT_ONLY:
426
+ pass
427
+ else:
428
+ self.flagging_callback.setup(
429
+ self.input_components + self.output_components, self.flagging_dir # type: ignore
430
+ )
431
+
432
+ # Render the Gradio UI
433
+ with self:
434
+ self.render_title_description()
435
+
436
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
437
+ interpretation_btn, interpretation_set = None, None
438
+ input_component_column, interpret_component_column = None, None
439
+
440
+ with Row().style(equal_height=False):
441
+ if self.interface_type in [
442
+ InterfaceTypes.STANDARD,
443
+ InterfaceTypes.INPUT_ONLY,
444
+ InterfaceTypes.UNIFIED,
445
+ ]:
446
+ (
447
+ submit_btn,
448
+ clear_btn,
449
+ stop_btn,
450
+ flag_btns,
451
+ input_component_column,
452
+ interpret_component_column,
453
+ interpretation_set,
454
+ ) = self.render_input_column()
455
+ if self.interface_type in [
456
+ InterfaceTypes.STANDARD,
457
+ InterfaceTypes.OUTPUT_ONLY,
458
+ ]:
459
+ (
460
+ submit_btn_out,
461
+ clear_btn_2_out,
462
+ stop_btn_2_out,
463
+ flag_btns_out,
464
+ interpretation_btn,
465
+ ) = self.render_output_column(submit_btn)
466
+ submit_btn = submit_btn or submit_btn_out
467
+ clear_btn = clear_btn or clear_btn_2_out
468
+ stop_btn = stop_btn or stop_btn_2_out
469
+ flag_btns = flag_btns or flag_btns_out
470
+
471
+ assert clear_btn is not None, "Clear button not rendered"
472
+
473
+ self.attach_submit_events(submit_btn, stop_btn)
474
+ self.attach_clear_events(
475
+ clear_btn, input_component_column, interpret_component_column
476
+ )
477
+ self.attach_interpretation_events(
478
+ interpretation_btn,
479
+ interpretation_set,
480
+ input_component_column,
481
+ interpret_component_column,
482
+ )
483
+
484
+ self.render_flagging_buttons(flag_btns)
485
+ self.render_examples()
486
+ self.render_article()
487
+
488
+ self.config = self.get_config_file()
489
+
490
+ def render_title_description(self) -> None:
491
+ if self.title:
492
+ Markdown(
493
+ "<h1 style='text-align: center; margin-bottom: 1rem'>"
494
+ + self.title
495
+ + "</h1>"
496
+ )
497
+ if self.description:
498
+ Markdown(self.description)
499
+
500
+ def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
501
+ if self.flagging_options is None:
502
+ return [(Button("Flag"), None)]
503
+ else:
504
+ return [
505
+ (
506
+ Button("Flag as " + flag_option),
507
+ flag_option,
508
+ )
509
+ for flag_option in self.flagging_options
510
+ ]
511
+
512
+ def render_input_column(
513
+ self,
514
+ ) -> Tuple[
515
+ Button | None,
516
+ Button | None,
517
+ Button | None,
518
+ List | None,
519
+ Column,
520
+ Column | None,
521
+ List[Interpretation] | None,
522
+ ]:
523
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
524
+ interpret_component_column, interpretation_set = None, None
525
+
526
+ with Column(variant="panel"):
527
+ input_component_column = Column()
528
+ with input_component_column:
529
+ for component in self.input_components:
530
+ component.render()
531
+ if self.interpretation:
532
+ interpret_component_column = Column(visible=False)
533
+ interpretation_set = []
534
+ with interpret_component_column:
535
+ for component in self.input_components:
536
+ interpretation_set.append(Interpretation(component))
537
+ with Row():
538
+ if self.interface_type in [
539
+ InterfaceTypes.STANDARD,
540
+ InterfaceTypes.INPUT_ONLY,
541
+ ]:
542
+ clear_btn = Button("Clear")
543
+ if not self.live:
544
+ submit_btn = Button("Submit", variant="primary")
545
+ # Stopping jobs only works if the queue is enabled
546
+ # We don't know if the queue is enabled when the interface
547
+ # is created. We use whether a generator function is provided
548
+ # as a proxy of whether the queue will be enabled.
549
+ # Using a generator function without the queue will raise an error.
550
+ if inspect.isgeneratorfunction(self.fn):
551
+ stop_btn = Button("Stop", variant="stop")
552
+ elif self.interface_type == InterfaceTypes.UNIFIED:
553
+ clear_btn = Button("Clear")
554
+ submit_btn = Button("Submit", variant="primary")
555
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
556
+ stop_btn = Button("Stop", variant="stop")
557
+ if self.allow_flagging == "manual":
558
+ flag_btns = self.render_flag_btns()
559
+ elif self.allow_flagging == "auto":
560
+ flag_btns = [(submit_btn, None)]
561
+ return (
562
+ submit_btn,
563
+ clear_btn,
564
+ stop_btn,
565
+ flag_btns,
566
+ input_component_column,
567
+ interpret_component_column,
568
+ interpretation_set,
569
+ )
570
+
571
+ def render_output_column(
572
+ self,
573
+ submit_btn_in: Button | None,
574
+ ) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
575
+ submit_btn = submit_btn_in
576
+ interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
577
+
578
+ with Column(variant="panel"):
579
+ for component in self.output_components:
580
+ if not (isinstance(component, State)):
581
+ component.render()
582
+ with Row():
583
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
584
+ clear_btn = Button("Clear")
585
+ submit_btn = Button("Generate", variant="primary")
586
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
587
+ # Stopping jobs only works if the queue is enabled
588
+ # We don't know if the queue is enabled when the interface
589
+ # is created. We use whether a generator function is provided
590
+ # as a proxy of whether the queue will be enabled.
591
+ # Using a generator function without the queue will raise an error.
592
+ stop_btn = Button("Stop", variant="stop")
593
+ if self.allow_flagging == "manual":
594
+ flag_btns = self.render_flag_btns()
595
+ elif self.allow_flagging == "auto":
596
+ assert submit_btn is not None, "Submit button not rendered"
597
+ flag_btns = [(submit_btn, None)]
598
+ if self.interpretation:
599
+ interpretation_btn = Button("Interpret")
600
+
601
+ return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
602
+
603
+ def render_article(self):
604
+ if self.article:
605
+ Markdown(self.article)
606
+
607
+ def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
608
+ if self.live:
609
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
610
+ assert submit_btn is not None, "Submit button not rendered"
611
+ super().load(self.fn, None, self.output_components)
612
+ # For output-only interfaces, the user probably still want a "generate"
613
+ # button even if the Interface is live
614
+ submit_btn.click(
615
+ self.fn,
616
+ None,
617
+ self.output_components,
618
+ api_name="predict",
619
+ preprocess=not (self.api_mode),
620
+ postprocess=not (self.api_mode),
621
+ batch=self.batch,
622
+ max_batch_size=self.max_batch_size,
623
+ )
624
+ else:
625
+ for component in self.input_components:
626
+ if isinstance(component, Streamable) and component.streaming:
627
+ component.stream(
628
+ self.fn,
629
+ self.input_components,
630
+ self.output_components,
631
+ api_name="predict",
632
+ preprocess=not (self.api_mode),
633
+ postprocess=not (self.api_mode),
634
+ )
635
+ continue
636
+ if isinstance(component, Changeable):
637
+ component.change(
638
+ self.fn,
639
+ self.input_components,
640
+ self.output_components,
641
+ api_name="predict",
642
+ preprocess=not (self.api_mode),
643
+ postprocess=not (self.api_mode),
644
+ )
645
+ else:
646
+ assert submit_btn is not None, "Submit button not rendered"
647
+ pred = submit_btn.click(
648
+ self.fn,
649
+ self.input_components,
650
+ self.output_components,
651
+ api_name="predict",
652
+ scroll_to_output=True,
653
+ preprocess=not (self.api_mode),
654
+ postprocess=not (self.api_mode),
655
+ batch=self.batch,
656
+ max_batch_size=self.max_batch_size,
657
+ )
658
+ if stop_btn:
659
+ stop_btn.click(
660
+ None,
661
+ inputs=None,
662
+ outputs=None,
663
+ cancels=[pred],
664
+ )
665
+
666
+ def attach_clear_events(
667
+ self,
668
+ clear_btn: Button,
669
+ input_component_column: Column | None,
670
+ interpret_component_column: Column | None,
671
+ ):
672
+ clear_btn.click(
673
+ None,
674
+ [],
675
+ (
676
+ self.input_components
677
+ + self.output_components
678
+ + ([input_component_column] if input_component_column else [])
679
+ + ([interpret_component_column] if self.interpretation else [])
680
+ ), # type: ignore
681
+ _js=f"""() => {json.dumps(
682
+ [getattr(component, "cleared_value", None)
683
+ for component in self.input_components + self.output_components] + (
684
+ [Column.update(visible=True)]
685
+ if self.interface_type
686
+ in [
687
+ InterfaceTypes.STANDARD,
688
+ InterfaceTypes.INPUT_ONLY,
689
+ InterfaceTypes.UNIFIED,
690
+ ]
691
+ else []
692
+ )
693
+ + ([Column.update(visible=False)] if self.interpretation else [])
694
+ )}
695
+ """,
696
+ )
697
+
698
+ def attach_interpretation_events(
699
+ self,
700
+ interpretation_btn: Button | None,
701
+ interpretation_set: List[Interpretation] | None,
702
+ input_component_column: Column | None,
703
+ interpret_component_column: Column | None,
704
+ ):
705
+ if interpretation_btn:
706
+ interpretation_btn.click(
707
+ self.interpret_func,
708
+ inputs=self.input_components + self.output_components,
709
+ outputs=interpretation_set
710
+ or [] + [input_component_column, interpret_component_column], # type: ignore
711
+ preprocess=False,
712
+ )
713
+
714
+ def render_flagging_buttons(self, flag_btns: List | None):
715
+ if flag_btns:
716
+ if self.interface_type in [
717
+ InterfaceTypes.STANDARD,
718
+ InterfaceTypes.OUTPUT_ONLY,
719
+ InterfaceTypes.UNIFIED,
720
+ ]:
721
+ if (
722
+ self.interface_type == InterfaceTypes.UNIFIED
723
+ or self.allow_flagging == "auto"
724
+ ):
725
+ flag_components = self.input_components
726
+ else:
727
+ flag_components = self.input_components + self.output_components
728
+ for flag_btn, flag_option in flag_btns:
729
+ flag_method = FlagMethod(self.flagging_callback, flag_option)
730
+ flag_btn.click(
731
+ flag_method,
732
+ inputs=flag_components,
733
+ outputs=[],
734
+ preprocess=False,
735
+ queue=False,
736
+ )
737
+
738
+ def render_examples(self):
739
+ if self.examples:
740
+ non_state_inputs = [
741
+ c for c in self.input_components if not isinstance(c, State)
742
+ ]
743
+ non_state_outputs = [
744
+ c for c in self.output_components if not isinstance(c, State)
745
+ ]
746
+ self.examples_handler = Examples(
747
+ examples=self.examples,
748
+ inputs=non_state_inputs, # type: ignore
749
+ outputs=non_state_outputs, # type: ignore
750
+ fn=self.fn,
751
+ cache_examples=self.cache_examples,
752
+ examples_per_page=self.examples_per_page,
753
+ _api_mode=self.api_mode,
754
+ batch=self.batch,
755
+ )
756
+
757
+ def __str__(self):
758
+ return self.__repr__()
759
+
760
+ def __repr__(self):
761
+ repr = f"Gradio Interface for: {self.__name__}"
762
+ repr += "\n" + "-" * len(repr)
763
+ repr += "\ninputs:"
764
+ for component in self.input_components:
765
+ repr += "\n|-{}".format(str(component))
766
+ repr += "\noutputs:"
767
+ for component in self.output_components:
768
+ repr += "\n|-{}".format(str(component))
769
+ return repr
770
+
771
+ async def interpret_func(self, *args):
772
+ return await self.interpret(list(args)) + [
773
+ Column.update(visible=False),
774
+ Column.update(visible=True),
775
+ ]
776
+
777
+ async def interpret(self, raw_input: List[Any]) -> List[Any]:
778
+ return [
779
+ {"original": raw_value, "interpretation": interpretation}
780
+ for interpretation, raw_value in zip(
781
+ (await interpretation.run_interpret(self, raw_input))[0], raw_input
782
+ )
783
+ ]
784
+
785
+ def test_launch(self) -> None:
786
+ """
787
+ Deprecated.
788
+ """
789
+ warnings.warn("The Interface.test_launch() function is deprecated.")
790
+
791
+
792
+ @document()
793
+ class TabbedInterface(Blocks):
794
+ """
795
+ A TabbedInterface is created by providing a list of Interfaces, each of which gets
796
+ rendered in a separate tab.
797
+ Demos: stt_or_tts
798
+ """
799
+
800
+ def __init__(
801
+ self,
802
+ interface_list: List[Interface],
803
+ tab_names: List[str] | None = None,
804
+ title: str | None = None,
805
+ theme: str = "default",
806
+ analytics_enabled: bool | None = None,
807
+ css: str | None = None,
808
+ ):
809
+ """
810
+ Parameters:
811
+ interface_list: a list of interfaces to be rendered in tabs.
812
+ tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
813
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
814
+ theme: which theme to use - right now, only "default" is supported.
815
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
816
+ css: custom css or path to custom css file to apply to entire Blocks
817
+ Returns:
818
+ a Gradio Tabbed Interface for the given interfaces
819
+ """
820
+ super().__init__(
821
+ title=title or "Gradio",
822
+ theme=theme,
823
+ analytics_enabled=analytics_enabled,
824
+ mode="tabbed_interface",
825
+ css=css,
826
+ )
827
+ if tab_names is None:
828
+ tab_names = ["Tab {}".format(i) for i in range(len(interface_list))]
829
+ with self:
830
+ if title:
831
+ Markdown(
832
+ "<h1 style='text-align: center; margin-bottom: 1rem'>"
833
+ + title
834
+ + "</h1>"
835
+ )
836
+ with Tabs():
837
+ for (interface, tab_name) in zip(interface_list, tab_names):
838
+ with Tab(label=tab_name):
839
+ interface.render()
840
+
841
+
842
+ def close_all(verbose: bool = True) -> None:
843
+ for io in Interface.get_instances():
844
+ io.close(verbose)
gradio-modified/gradio/interpretation.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+
4
+ import numpy as np
5
+
6
+ from gradio import utils
7
+ from gradio.components import Label, Number
8
+
9
+
10
+ async def run_interpret(interface, raw_input):
11
+ """
12
+ Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
13
+ interpretation for a certain set of UI component types, as well as the custom interpretation case.
14
+ Parameters:
15
+ raw_input: a list of raw inputs to apply the interpretation(s) on.
16
+ """
17
+ if isinstance(interface.interpretation, list): # Either "default" or "shap"
18
+ processed_input = [
19
+ input_component.preprocess(raw_input[i])
20
+ for i, input_component in enumerate(interface.input_components)
21
+ ]
22
+ original_output = await interface.call_function(0, processed_input)
23
+ original_output = original_output["prediction"]
24
+
25
+ if len(interface.output_components) == 1:
26
+ original_output = [original_output]
27
+
28
+ scores, alternative_outputs = [], []
29
+
30
+ for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
31
+ if interp == "default":
32
+ input_component = interface.input_components[i]
33
+ neighbor_raw_input = list(raw_input)
34
+ if input_component.interpret_by_tokens:
35
+ tokens, neighbor_values, masks = input_component.tokenize(x)
36
+ interface_scores = []
37
+ alternative_output = []
38
+ for neighbor_input in neighbor_values:
39
+ neighbor_raw_input[i] = neighbor_input
40
+ processed_neighbor_input = [
41
+ input_component.preprocess(neighbor_raw_input[i])
42
+ for i, input_component in enumerate(
43
+ interface.input_components
44
+ )
45
+ ]
46
+
47
+ neighbor_output = await interface.call_function(
48
+ 0, processed_neighbor_input
49
+ )
50
+ neighbor_output = neighbor_output["prediction"]
51
+ if len(interface.output_components) == 1:
52
+ neighbor_output = [neighbor_output]
53
+ processed_neighbor_output = [
54
+ output_component.postprocess(neighbor_output[i])
55
+ for i, output_component in enumerate(
56
+ interface.output_components
57
+ )
58
+ ]
59
+
60
+ alternative_output.append(processed_neighbor_output)
61
+ interface_scores.append(
62
+ quantify_difference_in_label(
63
+ interface, original_output, neighbor_output
64
+ )
65
+ )
66
+ alternative_outputs.append(alternative_output)
67
+ scores.append(
68
+ input_component.get_interpretation_scores(
69
+ raw_input[i],
70
+ neighbor_values,
71
+ interface_scores,
72
+ masks=masks,
73
+ tokens=tokens,
74
+ )
75
+ )
76
+ else:
77
+ (
78
+ neighbor_values,
79
+ interpret_kwargs,
80
+ ) = input_component.get_interpretation_neighbors(x)
81
+ interface_scores = []
82
+ alternative_output = []
83
+ for neighbor_input in neighbor_values:
84
+ neighbor_raw_input[i] = neighbor_input
85
+ processed_neighbor_input = [
86
+ input_component.preprocess(neighbor_raw_input[i])
87
+ for i, input_component in enumerate(
88
+ interface.input_components
89
+ )
90
+ ]
91
+ neighbor_output = await interface.call_function(
92
+ 0, processed_neighbor_input
93
+ )
94
+ neighbor_output = neighbor_output["prediction"]
95
+ if len(interface.output_components) == 1:
96
+ neighbor_output = [neighbor_output]
97
+ processed_neighbor_output = [
98
+ output_component.postprocess(neighbor_output[i])
99
+ for i, output_component in enumerate(
100
+ interface.output_components
101
+ )
102
+ ]
103
+
104
+ alternative_output.append(processed_neighbor_output)
105
+ interface_scores.append(
106
+ quantify_difference_in_label(
107
+ interface, original_output, neighbor_output
108
+ )
109
+ )
110
+ alternative_outputs.append(alternative_output)
111
+ interface_scores = [-score for score in interface_scores]
112
+ scores.append(
113
+ input_component.get_interpretation_scores(
114
+ raw_input[i],
115
+ neighbor_values,
116
+ interface_scores,
117
+ **interpret_kwargs
118
+ )
119
+ )
120
+ elif interp == "shap" or interp == "shapley":
121
+ try:
122
+ import shap # type: ignore
123
+ except (ImportError, ModuleNotFoundError):
124
+ raise ValueError(
125
+ "The package `shap` is required for this interpretation method. Try: `pip install shap`"
126
+ )
127
+ input_component = interface.input_components[i]
128
+ if not (input_component.interpret_by_tokens):
129
+ raise ValueError(
130
+ "Input component {} does not support `shap` interpretation".format(
131
+ input_component
132
+ )
133
+ )
134
+
135
+ tokens, _, masks = input_component.tokenize(x)
136
+
137
+ # construct a masked version of the input
138
+ def get_masked_prediction(binary_mask):
139
+ masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
140
+ preds = []
141
+ for masked_x in masked_xs:
142
+ processed_masked_input = copy.deepcopy(processed_input)
143
+ processed_masked_input[i] = input_component.preprocess(masked_x)
144
+ new_output = utils.synchronize_async(
145
+ interface.call_function, 0, processed_masked_input
146
+ )
147
+ new_output = new_output["prediction"]
148
+ if len(interface.output_components) == 1:
149
+ new_output = [new_output]
150
+ pred = get_regression_or_classification_value(
151
+ interface, original_output, new_output
152
+ )
153
+ preds.append(pred)
154
+ return np.array(preds)
155
+
156
+ num_total_segments = len(tokens)
157
+ explainer = shap.KernelExplainer(
158
+ get_masked_prediction, np.zeros((1, num_total_segments))
159
+ )
160
+ shap_values = explainer.shap_values(
161
+ np.ones((1, num_total_segments)),
162
+ nsamples=int(interface.num_shap * num_total_segments),
163
+ silent=True,
164
+ )
165
+ scores.append(
166
+ input_component.get_interpretation_scores(
167
+ raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
168
+ )
169
+ )
170
+ alternative_outputs.append([])
171
+ elif interp is None:
172
+ scores.append(None)
173
+ alternative_outputs.append([])
174
+ else:
175
+ raise ValueError("Unknown intepretation method: {}".format(interp))
176
+ return scores, alternative_outputs
177
+ else: # custom interpretation function
178
+ processed_input = [
179
+ input_component.preprocess(raw_input[i])
180
+ for i, input_component in enumerate(interface.input_components)
181
+ ]
182
+ interpreter = interface.interpretation
183
+ interpretation = interpreter(*processed_input)
184
+ if len(raw_input) == 1:
185
+ interpretation = [interpretation]
186
+ return interpretation, []
187
+
188
+
189
+ def diff(original, perturbed):
190
+ try: # try computing numerical difference
191
+ score = float(original) - float(perturbed)
192
+ except ValueError: # otherwise, look at strict difference in label
193
+ score = int(not (original == perturbed))
194
+ return score
195
+
196
+
197
+ def quantify_difference_in_label(interface, original_output, perturbed_output):
198
+ output_component = interface.output_components[0]
199
+ post_original_output = output_component.postprocess(original_output[0])
200
+ post_perturbed_output = output_component.postprocess(perturbed_output[0])
201
+
202
+ if isinstance(output_component, Label):
203
+ original_label = post_original_output["label"]
204
+ perturbed_label = post_perturbed_output["label"]
205
+
206
+ # Handle different return types of Label interface
207
+ if "confidences" in post_original_output:
208
+ original_confidence = original_output[0][original_label]
209
+ perturbed_confidence = perturbed_output[0][original_label]
210
+ score = original_confidence - perturbed_confidence
211
+ else:
212
+ score = diff(original_label, perturbed_label)
213
+ return score
214
+
215
+ elif isinstance(output_component, Number):
216
+ score = diff(post_original_output, post_perturbed_output)
217
+ return score
218
+
219
+ else:
220
+ raise ValueError(
221
+ "This interpretation method doesn't support the Output component: {}".format(
222
+ output_component
223
+ )
224
+ )
225
+
226
+
227
+ def get_regression_or_classification_value(
228
+ interface, original_output, perturbed_output
229
+ ):
230
+ """Used to combine regression/classification for Shap interpretation method."""
231
+ output_component = interface.output_components[0]
232
+ post_original_output = output_component.postprocess(original_output[0])
233
+ post_perturbed_output = output_component.postprocess(perturbed_output[0])
234
+
235
+ if type(output_component) == Label:
236
+ original_label = post_original_output["label"]
237
+ perturbed_label = post_perturbed_output["label"]
238
+
239
+ # Handle different return types of Label interface
240
+ if "confidences" in post_original_output:
241
+ if math.isnan(perturbed_output[0][original_label]):
242
+ return 0
243
+ return perturbed_output[0][original_label]
244
+ else:
245
+ score = diff(
246
+ perturbed_label, original_label
247
+ ) # Intentionally inverted order of arguments.
248
+ return score
249
+
250
+ else:
251
+ raise ValueError(
252
+ "This interpretation method doesn't support the Output component: {}".format(
253
+ output_component
254
+ )
255
+ )
gradio-modified/gradio/ipython_ext.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from IPython.core.magic import needs_local_scope, register_cell_magic
3
+ except ImportError:
4
+ pass
5
+
6
+ import gradio
7
+
8
+
9
+ def load_ipython_extension(ipython):
10
+ __demo = gradio.Blocks()
11
+
12
+ @register_cell_magic
13
+ @needs_local_scope
14
+ def blocks(line, cell, local_ns=None):
15
+ with __demo.clear():
16
+ exec(cell, None, local_ns)
17
+ __demo.launch(quiet=True)
gradio-modified/gradio/launches.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"launches": 145}
gradio-modified/gradio/layouts.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Callable, List, Type
5
+
6
+ from gradio.blocks import BlockContext
7
+ from gradio.documentation import document, set_documentation_group
8
+
9
+ set_documentation_group("layout")
10
+
11
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
12
+ from gradio.components import Component
13
+
14
+
15
+ @document()
16
+ class Row(BlockContext):
17
+ """
18
+ Row is a layout element within Blocks that renders all children horizontally.
19
+ Example:
20
+ with gradio.Blocks() as demo:
21
+ with gradio.Row():
22
+ gr.Image("lion.jpg")
23
+ gr.Image("tiger.jpg")
24
+ demo.launch()
25
+ Guides: controlling_layout
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ *,
31
+ variant: str = "default",
32
+ visible: bool = True,
33
+ elem_id: str | None = None,
34
+ **kwargs,
35
+ ):
36
+ """
37
+ Parameters:
38
+ variant: row type, 'default' (no background), 'panel' (gray background color and rounded corners), or 'compact' (rounded corners and no internal gap).
39
+ visible: If False, row will be hidden.
40
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
41
+ """
42
+ self.variant = variant
43
+ if variant == "compact":
44
+ self.allow_expected_parents = False
45
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
46
+
47
+ def get_config(self):
48
+ return {"type": "row", "variant": self.variant, **super().get_config()}
49
+
50
+ @staticmethod
51
+ def update(
52
+ visible: bool | None = None,
53
+ ):
54
+ return {
55
+ "visible": visible,
56
+ "__type__": "update",
57
+ }
58
+
59
+ def style(
60
+ self,
61
+ *,
62
+ equal_height: bool | None = None,
63
+ mobile_collapse: bool | None = None,
64
+ **kwargs,
65
+ ):
66
+ """
67
+ Styles the Row.
68
+ Parameters:
69
+ equal_height: If True, makes every child element have equal height
70
+ mobile_collapse: DEPRECATED.
71
+ """
72
+ if equal_height is not None:
73
+ self._style["equal_height"] = equal_height
74
+ if mobile_collapse is not None:
75
+ warnings.warn("mobile_collapse is no longer supported.")
76
+ return self
77
+
78
+
79
+ @document()
80
+ class Column(BlockContext):
81
+ """
82
+ Column is a layout element within Blocks that renders all children vertically. The widths of columns can be set through the `scale` and `min_width` parameters.
83
+ If a certain scale results in a column narrower than min_width, the min_width parameter will win.
84
+ Example:
85
+ with gradio.Blocks() as demo:
86
+ with gradio.Row():
87
+ with gradio.Column(scale=1):
88
+ text1 = gr.Textbox()
89
+ text2 = gr.Textbox()
90
+ with gradio.Column(scale=4):
91
+ btn1 = gr.Button("Button 1")
92
+ btn2 = gr.Button("Button 2")
93
+ Guides: controlling_layout
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ *,
99
+ scale: int = 1,
100
+ min_width: int = 320,
101
+ variant: str = "default",
102
+ visible: bool = True,
103
+ elem_id: str | None = None,
104
+ **kwargs,
105
+ ):
106
+ """
107
+ Parameters:
108
+ scale: relative width compared to adjacent Columns. For example, if Column A has scale=2, and Column B has scale=1, A will be twice as wide as B.
109
+ min_width: minimum pixel width of Column, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in a column narrower than min_width, the min_width parameter will be respected first.
110
+ variant: column type, 'default' (no background), 'panel' (gray background color and rounded corners), or 'compact' (rounded corners and no internal gap).
111
+ visible: If False, column will be hidden.
112
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
113
+ """
114
+ self.scale = scale
115
+ self.min_width = min_width
116
+ self.variant = variant
117
+ if variant == "compact":
118
+ self.allow_expected_parents = False
119
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
120
+
121
+ def get_config(self):
122
+ return {
123
+ "type": "column",
124
+ "variant": self.variant,
125
+ "scale": self.scale,
126
+ "min_width": self.min_width,
127
+ **super().get_config(),
128
+ }
129
+
130
+ @staticmethod
131
+ def update(
132
+ variant: str | None = None,
133
+ visible: bool | None = None,
134
+ ):
135
+ return {
136
+ "variant": variant,
137
+ "visible": visible,
138
+ "__type__": "update",
139
+ }
140
+
141
+
142
+ class Tabs(BlockContext):
143
+ """
144
+ Tabs is a layout element within Blocks that can contain multiple "Tab" Components.
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ *,
150
+ selected: int | str | None = None,
151
+ visible: bool = True,
152
+ elem_id: str | None = None,
153
+ **kwargs,
154
+ ):
155
+ """
156
+ Parameters:
157
+ selected: The currently selected tab. Must correspond to an id passed to the one of the child TabItems. Defaults to the first TabItem.
158
+ visible: If False, Tabs will be hidden.
159
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
160
+ """
161
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
162
+ self.selected = selected
163
+
164
+ def get_config(self):
165
+ return {"selected": self.selected, **super().get_config()}
166
+
167
+ @staticmethod
168
+ def update(
169
+ selected: int | str | None = None,
170
+ ):
171
+ return {
172
+ "selected": selected,
173
+ "__type__": "update",
174
+ }
175
+
176
+ def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
177
+ """
178
+ Parameters:
179
+ fn: Callable function
180
+ inputs: List of inputs
181
+ outputs: List of outputs
182
+ Returns: None
183
+ """
184
+ self.set_event_trigger("change", fn, inputs, outputs)
185
+
186
+
187
+ @document()
188
+ class Tab(BlockContext):
189
+ """
190
+ Tab (or its alias TabItem) is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
191
+ Example:
192
+ with gradio.Blocks() as demo:
193
+ with gradio.Tab("Lion"):
194
+ gr.Image("lion.jpg")
195
+ gr.Button("New Lion")
196
+ with gradio.Tab("Tiger"):
197
+ gr.Image("tiger.jpg")
198
+ gr.Button("New Tiger")
199
+ Guides: controlling_layout
200
+ """
201
+
202
+ def __init__(
203
+ self,
204
+ label: str,
205
+ *,
206
+ id: int | str | None = None,
207
+ elem_id: str | None = None,
208
+ **kwargs,
209
+ ):
210
+ """
211
+ Parameters:
212
+ label: The visual label for the tab
213
+ id: An optional identifier for the tab, required if you wish to control the selected tab from a predict function.
214
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
215
+ """
216
+ super().__init__(elem_id=elem_id, **kwargs)
217
+ self.label = label
218
+ self.id = id
219
+
220
+ def get_config(self):
221
+ return {
222
+ "label": self.label,
223
+ "id": self.id,
224
+ **super().get_config(),
225
+ }
226
+
227
+ def select(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
228
+ """
229
+ Parameters:
230
+ fn: Callable function
231
+ inputs: List of inputs
232
+ outputs: List of outputs
233
+ Returns: None
234
+ """
235
+ self.set_event_trigger("select", fn, inputs, outputs)
236
+
237
+ def get_expected_parent(self) -> Type[Tabs]:
238
+ return Tabs
239
+
240
+ def get_block_name(self):
241
+ return "tabitem"
242
+
243
+
244
+ TabItem = Tab
245
+
246
+
247
+ class Group(BlockContext):
248
+ """
249
+ Group is a layout element within Blocks which groups together children so that
250
+ they do not have any padding or margin between them.
251
+ Example:
252
+ with gradio.Group():
253
+ gr.Textbox(label="First")
254
+ gr.Textbox(label="Last")
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ *,
260
+ visible: bool = True,
261
+ elem_id: str | None = None,
262
+ **kwargs,
263
+ ):
264
+ """
265
+ Parameters:
266
+ visible: If False, group will be hidden.
267
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
268
+ """
269
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
270
+
271
+ def get_config(self):
272
+ return {"type": "group", **super().get_config()}
273
+
274
+ @staticmethod
275
+ def update(
276
+ visible: bool | None = None,
277
+ ):
278
+ return {
279
+ "visible": visible,
280
+ "__type__": "update",
281
+ }
282
+
283
+
284
+ @document()
285
+ class Box(BlockContext):
286
+ """
287
+ Box is a a layout element which places children in a box with rounded corners and
288
+ some padding around them.
289
+ Example:
290
+ with gradio.Box():
291
+ gr.Textbox(label="First")
292
+ gr.Textbox(label="Last")
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ *,
298
+ visible: bool = True,
299
+ elem_id: str | None = None,
300
+ **kwargs,
301
+ ):
302
+ """
303
+ Parameters:
304
+ visible: If False, box will be hidden.
305
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
306
+ """
307
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
308
+
309
+ def get_config(self):
310
+ return {"type": "box", **super().get_config()}
311
+
312
+ @staticmethod
313
+ def update(
314
+ visible: bool | None = None,
315
+ ):
316
+ return {
317
+ "visible": visible,
318
+ "__type__": "update",
319
+ }
320
+
321
+ def style(self, **kwargs):
322
+ return self
323
+
324
+
325
+ class Form(BlockContext):
326
+ def get_config(self):
327
+ return {"type": "form", **super().get_config()}
328
+
329
+
330
+ @document()
331
+ class Accordion(BlockContext):
332
+ """
333
+ Accordion is a layout element which can be toggled to show/hide the contained content.
334
+ Example:
335
+ with gradio.Accordion("See Details"):
336
+ gr.Markdown("lorem ipsum")
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ label,
342
+ *,
343
+ open: bool = True,
344
+ visible: bool = True,
345
+ elem_id: str | None = None,
346
+ **kwargs,
347
+ ):
348
+ """
349
+ Parameters:
350
+ label: name of accordion section.
351
+ open: if True, accordion is open by default.
352
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
353
+ """
354
+ self.label = label
355
+ self.open = open
356
+ super().__init__(visible=visible, elem_id=elem_id, **kwargs)
357
+
358
+ def get_config(self):
359
+ return {
360
+ "type": "accordion",
361
+ "open": self.open,
362
+ "label": self.label,
363
+ **super().get_config(),
364
+ }
365
+
366
+ @staticmethod
367
+ def update(
368
+ open: bool | None = None,
369
+ label: str | None = None,
370
+ visible: bool | None = None,
371
+ ):
372
+ return {
373
+ "visible": visible,
374
+ "label": label,
375
+ "open": open,
376
+ "__type__": "update",
377
+ }
gradio-modified/gradio/media_data.py ADDED
The diff for this file is too large to render. See raw diff
 
gradio-modified/gradio/mix.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ways to transform interfaces to produce new interfaces
3
+ """
4
+ import asyncio
5
+ import warnings
6
+
7
+ import gradio
8
+ from gradio.documentation import document, set_documentation_group
9
+
10
+ set_documentation_group("mix_interface")
11
+
12
+
13
+ @document()
14
+ class Parallel(gradio.Interface):
15
+ """
16
+ Creates a new Interface consisting of multiple Interfaces in parallel (comparing their outputs).
17
+ The Interfaces to put in Parallel must share the same input components (but can have different output components).
18
+
19
+ Demos: interface_parallel, interface_parallel_load
20
+ Guides: advanced_interface_features
21
+ """
22
+
23
+ def __init__(self, *interfaces: gradio.Interface, **options):
24
+ """
25
+ Parameters:
26
+ interfaces: any number of Interface objects that are to be compared in parallel
27
+ options: additional kwargs that are passed into the new Interface object to customize it
28
+ Returns:
29
+ an Interface object comparing the given models
30
+ """
31
+ outputs = []
32
+
33
+ for interface in interfaces:
34
+ if not (isinstance(interface, gradio.Interface)):
35
+ warnings.warn(
36
+ "Parallel requires all inputs to be of type Interface. "
37
+ "May not work as expected."
38
+ )
39
+ outputs.extend(interface.output_components)
40
+
41
+ async def parallel_fn(*args):
42
+ return_values_with_durations = await asyncio.gather(
43
+ *[interface.call_function(0, list(args)) for interface in interfaces]
44
+ )
45
+ return_values = [rv["prediction"] for rv in return_values_with_durations]
46
+ combined_list = []
47
+ for interface, return_value in zip(interfaces, return_values):
48
+ if len(interface.output_components) == 1:
49
+ combined_list.append(return_value)
50
+ else:
51
+ combined_list.extend(return_value)
52
+ if len(outputs) == 1:
53
+ return combined_list[0]
54
+ return combined_list
55
+
56
+ parallel_fn.__name__ = " | ".join([io.__name__ for io in interfaces])
57
+
58
+ kwargs = {
59
+ "fn": parallel_fn,
60
+ "inputs": interfaces[0].input_components,
61
+ "outputs": outputs,
62
+ }
63
+ kwargs.update(options)
64
+ super().__init__(**kwargs)
65
+
66
+
67
+ @document()
68
+ class Series(gradio.Interface):
69
+ """
70
+ Creates a new Interface from multiple Interfaces in series (the output of one is fed as the input to the next,
71
+ and so the input and output components must agree between the interfaces).
72
+
73
+ Demos: interface_series, interface_series_load
74
+ Guides: advanced_interface_features
75
+ """
76
+
77
+ def __init__(self, *interfaces: gradio.Interface, **options):
78
+ """
79
+ Parameters:
80
+ interfaces: any number of Interface objects that are to be connected in series
81
+ options: additional kwargs that are passed into the new Interface object to customize it
82
+ Returns:
83
+ an Interface object connecting the given models
84
+ """
85
+
86
+ async def connected_fn(*data):
87
+ for idx, interface in enumerate(interfaces):
88
+ # skip preprocessing for first interface since the Series interface will include it
89
+ if idx > 0 and not (interface.api_mode):
90
+ data = [
91
+ input_component.preprocess(data[i])
92
+ for i, input_component in enumerate(interface.input_components)
93
+ ]
94
+
95
+ # run all of predictions sequentially
96
+ data = (await interface.call_function(0, list(data)))["prediction"]
97
+ if len(interface.output_components) == 1:
98
+ data = [data]
99
+
100
+ # skip postprocessing for final interface since the Series interface will include it
101
+ if idx < len(interfaces) - 1 and not (interface.api_mode):
102
+ data = [
103
+ output_component.postprocess(data[i])
104
+ for i, output_component in enumerate(
105
+ interface.output_components
106
+ )
107
+ ]
108
+
109
+ if len(interface.output_components) == 1: # type: ignore
110
+ return data[0]
111
+ return data
112
+
113
+ for interface in interfaces:
114
+ if not (isinstance(interface, gradio.Interface)):
115
+ warnings.warn(
116
+ "Series requires all inputs to be of type Interface. May "
117
+ "not work as expected."
118
+ )
119
+ connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces])
120
+
121
+ kwargs = {
122
+ "fn": connected_fn,
123
+ "inputs": interfaces[0].input_components,
124
+ "outputs": interfaces[-1].output_components,
125
+ "_api_mode": interfaces[0].api_mode, # TODO: set api_mode per-interface
126
+ }
127
+ kwargs.update(options)
128
+ super().__init__(**kwargs)
gradio-modified/gradio/networking.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines helper methods useful for setting up ports, launching servers, and
3
+ creating tunnels.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import socket
9
+ import threading
10
+ import time
11
+ import warnings
12
+ from typing import TYPE_CHECKING, Tuple
13
+
14
+ import requests
15
+ import uvicorn
16
+
17
+ from gradio.routes import App
18
+ from gradio.tunneling import Tunnel
19
+
20
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
21
+ from gradio.blocks import Blocks
22
+
23
+ # By default, the local server will try to open on localhost, port 7860.
24
+ # If that is not available, then it will try 7861, 7862, ... 7959.
25
+ INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
26
+ TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
27
+ LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
28
+ GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
29
+
30
+
31
+ class Server(uvicorn.Server):
32
+ def install_signal_handlers(self):
33
+ pass
34
+
35
+ def run_in_thread(self):
36
+ self.thread = threading.Thread(target=self.run, daemon=True)
37
+ self.thread.start()
38
+ while not self.started:
39
+ time.sleep(1e-3)
40
+
41
+ def close(self):
42
+ self.should_exit = True
43
+ self.thread.join()
44
+
45
+
46
+ def get_first_available_port(initial: int, final: int) -> int:
47
+ """
48
+ Gets the first open port in a specified range of port numbers
49
+ Parameters:
50
+ initial: the initial value in the range of port numbers
51
+ final: final (exclusive) value in the range of port numbers, should be greater than `initial`
52
+ Returns:
53
+ port: the first open port in the range
54
+ """
55
+ for port in range(initial, final):
56
+ try:
57
+ s = socket.socket() # create a socket object
58
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
59
+ s.bind((LOCALHOST_NAME, port)) # Bind to the port
60
+ s.close()
61
+ return port
62
+ except OSError:
63
+ pass
64
+ raise OSError(
65
+ "All ports from {} to {} are in use. Please close a port.".format(
66
+ initial, final - 1
67
+ )
68
+ )
69
+
70
+
71
+ def configure_app(app: App, blocks: Blocks) -> App:
72
+ auth = blocks.auth
73
+ if auth is not None:
74
+ if not callable(auth):
75
+ app.auth = {account[0]: account[1] for account in auth}
76
+ else:
77
+ app.auth = auth
78
+ else:
79
+ app.auth = None
80
+ app.blocks = blocks
81
+ app.cwd = os.getcwd()
82
+ app.favicon_path = blocks.favicon_path
83
+ app.tokens = {}
84
+ return app
85
+
86
+
87
+ def start_server(
88
+ blocks: Blocks,
89
+ server_name: str | None = None,
90
+ server_port: int | None = None,
91
+ ssl_keyfile: str | None = None,
92
+ ssl_certfile: str | None = None,
93
+ ssl_keyfile_password: str | None = None,
94
+ ) -> Tuple[str, int, str, App, Server]:
95
+ """Launches a local server running the provided Interface
96
+ Parameters:
97
+ blocks: The Blocks object to run on the server
98
+ server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
99
+ server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
100
+ auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login.
101
+ ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
102
+ ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
103
+ ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
104
+ Returns:
105
+ port: the port number the server is running on
106
+ path_to_local_server: the complete address that the local server can be accessed at
107
+ app: the FastAPI app object
108
+ server: the server object that is a subclass of uvicorn.Server (used to close the server)
109
+ """
110
+ server_name = server_name or LOCALHOST_NAME
111
+ # if port is not specified, search for first available port
112
+ if server_port is None:
113
+ port = get_first_available_port(
114
+ INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
115
+ )
116
+ else:
117
+ try:
118
+ s = socket.socket()
119
+ s.bind((LOCALHOST_NAME, server_port))
120
+ s.close()
121
+ except OSError:
122
+ raise OSError(
123
+ "Port {} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all().".format(
124
+ server_port
125
+ )
126
+ )
127
+ port = server_port
128
+
129
+ url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
130
+
131
+ if ssl_keyfile is not None:
132
+ if ssl_certfile is None:
133
+ raise ValueError(
134
+ "ssl_certfile must be provided if ssl_keyfile is provided."
135
+ )
136
+ path_to_local_server = "https://{}:{}/".format(url_host_name, port)
137
+ else:
138
+ path_to_local_server = "http://{}:{}/".format(url_host_name, port)
139
+
140
+ app = App.create_app(blocks)
141
+
142
+ if blocks.save_to is not None: # Used for selenium tests
143
+ blocks.save_to["port"] = port
144
+ config = uvicorn.Config(
145
+ app=app,
146
+ port=port,
147
+ host=server_name,
148
+ log_level="warning",
149
+ ssl_keyfile=ssl_keyfile,
150
+ ssl_certfile=ssl_certfile,
151
+ ssl_keyfile_password=ssl_keyfile_password,
152
+ ws_max_size=1024 * 1024 * 1024, # Setting max websocket size to be 1 GB
153
+ )
154
+ server = Server(config=config)
155
+ server.run_in_thread()
156
+ return server_name, port, path_to_local_server, app, server
157
+
158
+
159
+ def setup_tunnel(local_host: str, local_port: int) -> str:
160
+ response = requests.get(GRADIO_API_SERVER)
161
+ if response and response.status_code == 200:
162
+ try:
163
+ payload = response.json()[0]
164
+ remote_host, remote_port = payload["host"], int(payload["port"])
165
+ tunnel = Tunnel(remote_host, remote_port, local_host, local_port)
166
+ address = tunnel.start_tunnel()
167
+ return address
168
+ except Exception as e:
169
+ raise RuntimeError(str(e))
170
+ else:
171
+ raise RuntimeError("Could not get share link from Gradio API Server.")
172
+
173
+
174
+ def url_ok(url: str) -> bool:
175
+ try:
176
+ for _ in range(5):
177
+ with warnings.catch_warnings():
178
+ warnings.filterwarnings("ignore")
179
+ r = requests.head(url, timeout=3, verify=False)
180
+ if r.status_code in (200, 401, 302): # 401 or 302 if auth is set
181
+ return True
182
+ time.sleep(0.500)
183
+ except (ConnectionError, requests.exceptions.ConnectionError):
184
+ return False
185
+ return False
gradio-modified/gradio/outputs.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # type: ignore
2
+ """
3
+ This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
4
+ `OutputComponent`, and each class must define a path to its template. All of the subclasses of `OutputComponent` are
5
+ automatically added to a registry, which allows them to be easily referenced in other parts of the code.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from typing import Dict, List, Optional
12
+
13
+ from gradio import components
14
+
15
+
16
+ class Textbox(components.Textbox):
17
+ def __init__(
18
+ self,
19
+ type: str = "text",
20
+ label: Optional[str] = None,
21
+ ):
22
+ warnings.warn(
23
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
24
+ )
25
+ super().__init__(label=label, type=type)
26
+
27
+
28
+ class Image(components.Image):
29
+ """
30
+ Component displays an output image.
31
+ Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
32
+ """
33
+
34
+ def __init__(
35
+ self, type: str = "auto", plot: bool = False, label: Optional[str] = None
36
+ ):
37
+ """
38
+ Parameters:
39
+ type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
40
+ plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
41
+ label (str): component name in interface.
42
+ """
43
+ warnings.warn(
44
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
45
+ )
46
+ if plot:
47
+ type = "plot"
48
+ super().__init__(type=type, label=label)
49
+
50
+
51
+ class Video(components.Video):
52
+ """
53
+ Used for video output.
54
+ Output type: filepath
55
+ """
56
+
57
+ def __init__(self, type: Optional[str] = None, label: Optional[str] = None):
58
+ """
59
+ Parameters:
60
+ type (str): Type of video format to be passed to component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, video will keep returned format.
61
+ label (str): component name in interface.
62
+ """
63
+ warnings.warn(
64
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
65
+ )
66
+ super().__init__(format=type, label=label)
67
+
68
+
69
+ class Audio(components.Audio):
70
+ """
71
+ Creates an audio player that plays the output audio.
72
+ Output type: Union[Tuple[int, numpy.array], str]
73
+ """
74
+
75
+ def __init__(self, type: str = "auto", label: Optional[str] = None):
76
+ """
77
+ Parameters:
78
+ type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data as 16-bit int numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file, "auto" detects return type.
79
+ label (str): component name in interface.
80
+ """
81
+ warnings.warn(
82
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
83
+ )
84
+ super().__init__(type=type, label=label)
85
+
86
+
87
+ class File(components.File):
88
+ """
89
+ Used for file output.
90
+ Output type: Union[file-like, str]
91
+ """
92
+
93
+ def __init__(self, label: Optional[str] = None):
94
+ """
95
+ Parameters:
96
+ label (str): component name in interface.
97
+ """
98
+ warnings.warn(
99
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
100
+ )
101
+ super().__init__(label=label)
102
+
103
+
104
+ class Dataframe(components.Dataframe):
105
+ """
106
+ Component displays 2D output through a spreadsheet interface.
107
+ Output type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ headers: Optional[List[str]] = None,
113
+ max_rows: Optional[int] = 20,
114
+ max_cols: Optional[int] = None,
115
+ overflow_row_behaviour: str = "paginate",
116
+ type: str = "auto",
117
+ label: Optional[str] = None,
118
+ ):
119
+ """
120
+ Parameters:
121
+ headers (List[str]): Header names to dataframe. Only applicable if type is "numpy" or "array".
122
+ max_rows (int): Maximum number of rows to display at once. Set to None for infinite.
123
+ max_cols (int): Maximum number of columns to display at once. Set to None for infinite.
124
+ overflow_row_behaviour (str): If set to "paginate", will create pages for overflow rows. If set to "show_ends", will show initial and final rows and truncate middle rows.
125
+ type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array, "auto" detects return type.
126
+ label (str): component name in interface.
127
+ """
128
+ warnings.warn(
129
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
130
+ )
131
+ super().__init__(
132
+ headers=headers,
133
+ type=type,
134
+ label=label,
135
+ max_rows=max_rows,
136
+ max_cols=max_cols,
137
+ overflow_row_behaviour=overflow_row_behaviour,
138
+ )
139
+
140
+
141
+ class Timeseries(components.Timeseries):
142
+ """
143
+ Component accepts pandas.DataFrame.
144
+ Output type: pandas.DataFrame
145
+ """
146
+
147
+ def __init__(
148
+ self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
149
+ ):
150
+ """
151
+ Parameters:
152
+ x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
153
+ y (Union[str, List[str]]): Column name of y series, or list of column names if multiple series. None if csv has no headers, in which case every column after first is a y series.
154
+ label (str): component name in interface.
155
+ """
156
+ warnings.warn(
157
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
158
+ )
159
+ super().__init__(x=x, y=y, label=label)
160
+
161
+
162
+ class State(components.State):
163
+ """
164
+ Special hidden component that stores state across runs of the interface.
165
+ Output type: Any
166
+ """
167
+
168
+ def __init__(self, label: Optional[str] = None):
169
+ """
170
+ Parameters:
171
+ label (str): component name in interface (not used).
172
+ """
173
+ warnings.warn(
174
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
175
+ )
176
+ super().__init__(label=label)
177
+
178
+
179
+ class Label(components.Label):
180
+ """
181
+ Component outputs a classification label, along with confidence scores of top categories if provided. Confidence scores are represented as a dictionary mapping labels to scores between 0 and 1.
182
+ Output type: Union[Dict[str, float], str, int, float]
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ num_top_classes: Optional[int] = None,
188
+ type: str = "auto",
189
+ label: Optional[str] = None,
190
+ ):
191
+ """
192
+ Parameters:
193
+ num_top_classes (int): number of most confident classes to show.
194
+ type (str): Type of value to be passed to component. "value" expects a single out label, "confidences" expects a dictionary mapping labels to confidence scores, "auto" detects return type.
195
+ label (str): component name in interface.
196
+ """
197
+ warnings.warn(
198
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
199
+ )
200
+ super().__init__(num_top_classes=num_top_classes, type=type, label=label)
201
+
202
+
203
+ class KeyValues:
204
+ """
205
+ Component displays a table representing values for multiple fields.
206
+ Output type: Union[Dict, List[Tuple[str, Union[str, int, float]]]]
207
+ """
208
+
209
+ def __init__(self, value: str = " ", *, label: Optional[str] = None, **kwargs):
210
+ """
211
+ Parameters:
212
+ value (str): IGNORED
213
+ label (str): component name in interface.
214
+ """
215
+ raise DeprecationWarning(
216
+ "The KeyValues component is deprecated. Please use the DataFrame or JSON "
217
+ "components instead."
218
+ )
219
+
220
+
221
+ class HighlightedText(components.HighlightedText):
222
+ """
223
+ Component creates text that contains spans that are highlighted by category or numerical value.
224
+ Output is represent as a list of Tuple pairs, where the first element represents the span of text represented by the tuple, and the second element represents the category or value of the text.
225
+ Output type: List[Tuple[str, Union[float, str]]]
226
+ """
227
+
228
+ def __init__(
229
+ self,
230
+ color_map: Dict[str, str] = None,
231
+ label: Optional[str] = None,
232
+ show_legend: bool = False,
233
+ ):
234
+ """
235
+ Parameters:
236
+ color_map (Dict[str, str]): Map between category and respective colors
237
+ label (str): component name in interface.
238
+ show_legend (bool): whether to show span categories in a separate legend or inline.
239
+ """
240
+ warnings.warn(
241
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
242
+ )
243
+ super().__init__(color_map=color_map, label=label, show_legend=show_legend)
244
+
245
+
246
+ class JSON(components.JSON):
247
+ """
248
+ Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.
249
+ Output type: Union[str, Any]
250
+ """
251
+
252
+ def __init__(self, label: Optional[str] = None):
253
+ """
254
+ Parameters:
255
+ label (str): component name in interface.
256
+ """
257
+ warnings.warn(
258
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
259
+ )
260
+ super().__init__(label=label)
261
+
262
+
263
+ class HTML(components.HTML):
264
+ """
265
+ Used for HTML output. Expects an HTML valid string.
266
+ Output type: str
267
+ """
268
+
269
+ def __init__(self, label: Optional[str] = None):
270
+ """
271
+ Parameters:
272
+ label (str): component name in interface.
273
+ """
274
+ super().__init__(label=label)
275
+
276
+
277
+ class Carousel(components.Carousel):
278
+ """
279
+ Component displays a set of output components that can be scrolled through.
280
+ """
281
+
282
+ def __init__(
283
+ self,
284
+ components: components.Component | List[components.Component],
285
+ label: Optional[str] = None,
286
+ ):
287
+ """
288
+ Parameters:
289
+ components (Union[List[Component], Component]): Classes of component(s) that will be scrolled through.
290
+ label (str): component name in interface.
291
+ """
292
+ warnings.warn(
293
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
294
+ )
295
+ super().__init__(components=components, label=label)
296
+
297
+
298
+ class Chatbot(components.Chatbot):
299
+ """
300
+ Component displays a chatbot output showing both user submitted messages and responses
301
+ Output type: List[Tuple[str, str]]
302
+ """
303
+
304
+ def __init__(self, label: Optional[str] = None):
305
+ """
306
+ Parameters:
307
+ label (str): component name in interface (not used).
308
+ """
309
+ warnings.warn(
310
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
311
+ )
312
+ super().__init__(label=label)
313
+
314
+
315
+ class Image3D(components.Model3D):
316
+ """
317
+ Used for 3D image model output.
318
+ Input type: File object of type (.obj, glb, or .gltf)
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ clear_color=None,
324
+ label: Optional[str] = None,
325
+ ):
326
+ """
327
+ Parameters:
328
+ label (str): component name in interface.
329
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
330
+ """
331
+ warnings.warn(
332
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
333
+ )
334
+ super().__init__(clear_color=clear_color, label=label)
gradio-modified/gradio/pipelines.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module should not be used directly as its API is subject to change. Instead,
2
+ please use the `gr.Interface.from_pipeline()` function."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict
7
+
8
+ from gradio import components
9
+
10
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
11
+ from transformers import pipelines
12
+
13
+
14
+ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
15
+ """
16
+ Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
17
+ pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
18
+ Returns:
19
+ (dict): a dictionary of kwargs that can be used to construct an Interface object
20
+ """
21
+ try:
22
+ import transformers
23
+ from transformers import pipelines
24
+ except ImportError:
25
+ raise ImportError(
26
+ "transformers not installed. Please try `pip install transformers`"
27
+ )
28
+ if not isinstance(pipeline, pipelines.base.Pipeline):
29
+ raise ValueError("pipeline must be a transformers.Pipeline")
30
+
31
+ # Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
32
+ # version of the transformers library that the user has installed.
33
+ if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
34
+ pipeline, pipelines.audio_classification.AudioClassificationPipeline
35
+ ):
36
+ pipeline_info = {
37
+ "inputs": components.Audio(
38
+ source="microphone", type="filepath", label="Input"
39
+ ),
40
+ "outputs": components.Label(label="Class"),
41
+ "preprocess": lambda i: {"inputs": i},
42
+ "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
43
+ }
44
+ elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
45
+ pipeline,
46
+ pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
47
+ ):
48
+ pipeline_info = {
49
+ "inputs": components.Audio(
50
+ source="microphone", type="filepath", label="Input"
51
+ ),
52
+ "outputs": components.Textbox(label="Output"),
53
+ "preprocess": lambda i: {"inputs": i},
54
+ "postprocess": lambda r: r["text"],
55
+ }
56
+ elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
57
+ pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
58
+ ):
59
+ pipeline_info = {
60
+ "inputs": components.Textbox(label="Input"),
61
+ "outputs": components.Dataframe(label="Output"),
62
+ "preprocess": lambda x: {"inputs": x},
63
+ "postprocess": lambda r: r[0],
64
+ }
65
+ elif hasattr(transformers, "FillMaskPipeline") and isinstance(
66
+ pipeline, pipelines.fill_mask.FillMaskPipeline
67
+ ):
68
+ pipeline_info = {
69
+ "inputs": components.Textbox(label="Input"),
70
+ "outputs": components.Label(label="Classification"),
71
+ "preprocess": lambda x: {"inputs": x},
72
+ "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
73
+ }
74
+ elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
75
+ pipeline, pipelines.image_classification.ImageClassificationPipeline
76
+ ):
77
+ pipeline_info = {
78
+ "inputs": components.Image(type="filepath", label="Input Image"),
79
+ "outputs": components.Label(type="confidences", label="Classification"),
80
+ "preprocess": lambda i: {"images": i},
81
+ "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
82
+ }
83
+ elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
84
+ pipeline, pipelines.question_answering.QuestionAnsweringPipeline
85
+ ):
86
+ pipeline_info = {
87
+ "inputs": [
88
+ components.Textbox(lines=7, label="Context"),
89
+ components.Textbox(label="Question"),
90
+ ],
91
+ "outputs": [
92
+ components.Textbox(label="Answer"),
93
+ components.Label(label="Score"),
94
+ ],
95
+ "preprocess": lambda c, q: {"context": c, "question": q},
96
+ "postprocess": lambda r: (r["answer"], r["score"]),
97
+ }
98
+ elif hasattr(transformers, "SummarizationPipeline") and isinstance(
99
+ pipeline, pipelines.text2text_generation.SummarizationPipeline
100
+ ):
101
+ pipeline_info = {
102
+ "inputs": components.Textbox(lines=7, label="Input"),
103
+ "outputs": components.Textbox(label="Summary"),
104
+ "preprocess": lambda x: {"inputs": x},
105
+ "postprocess": lambda r: r[0]["summary_text"],
106
+ }
107
+ elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
108
+ pipeline, pipelines.text_classification.TextClassificationPipeline
109
+ ):
110
+ pipeline_info = {
111
+ "inputs": components.Textbox(label="Input"),
112
+ "outputs": components.Label(label="Classification"),
113
+ "preprocess": lambda x: [x],
114
+ "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
115
+ }
116
+ elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
117
+ pipeline, pipelines.text_generation.TextGenerationPipeline
118
+ ):
119
+ pipeline_info = {
120
+ "inputs": components.Textbox(label="Input"),
121
+ "outputs": components.Textbox(label="Output"),
122
+ "preprocess": lambda x: {"text_inputs": x},
123
+ "postprocess": lambda r: r[0]["generated_text"],
124
+ }
125
+ elif hasattr(transformers, "TranslationPipeline") and isinstance(
126
+ pipeline, pipelines.text2text_generation.TranslationPipeline
127
+ ):
128
+ pipeline_info = {
129
+ "inputs": components.Textbox(label="Input"),
130
+ "outputs": components.Textbox(label="Translation"),
131
+ "preprocess": lambda x: [x],
132
+ "postprocess": lambda r: r[0]["translation_text"],
133
+ }
134
+ elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
135
+ pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
136
+ ):
137
+ pipeline_info = {
138
+ "inputs": components.Textbox(label="Input"),
139
+ "outputs": components.Textbox(label="Generated Text"),
140
+ "preprocess": lambda x: [x],
141
+ "postprocess": lambda r: r[0]["generated_text"],
142
+ }
143
+ elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
144
+ pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
145
+ ):
146
+ pipeline_info = {
147
+ "inputs": [
148
+ components.Textbox(label="Input"),
149
+ components.Textbox(label="Possible class names (" "comma-separated)"),
150
+ components.Checkbox(label="Allow multiple true classes"),
151
+ ],
152
+ "outputs": components.Label(label="Classification"),
153
+ "preprocess": lambda i, c, m: {
154
+ "sequences": i,
155
+ "candidate_labels": c,
156
+ "multi_label": m,
157
+ },
158
+ "postprocess": lambda r: {
159
+ r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
160
+ },
161
+ }
162
+ else:
163
+ raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
164
+
165
+ # define the function that will be called by the Interface
166
+ def fn(*params):
167
+ data = pipeline_info["preprocess"](*params)
168
+ # special cases that needs to be handled differently
169
+ if isinstance(
170
+ pipeline,
171
+ (
172
+ pipelines.text_classification.TextClassificationPipeline,
173
+ pipelines.text2text_generation.Text2TextGenerationPipeline,
174
+ pipelines.text2text_generation.TranslationPipeline,
175
+ ),
176
+ ):
177
+ data = pipeline(*data)
178
+ else:
179
+ data = pipeline(**data)
180
+ output = pipeline_info["postprocess"](data)
181
+ return output
182
+
183
+ interface_info = pipeline_info.copy()
184
+ interface_info["fn"] = fn
185
+ del interface_info["preprocess"]
186
+ del interface_info["postprocess"]
187
+
188
+ # define the title/description of the Interface
189
+ interface_info["title"] = pipeline.model.__class__.__name__
190
+
191
+ return interface_info
gradio-modified/gradio/processing_utils.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import hashlib
5
+ import json
6
+ import mimetypes
7
+ import os
8
+ import pathlib
9
+ import shutil
10
+ import subprocess
11
+ import tempfile
12
+ import urllib.request
13
+ import warnings
14
+ from io import BytesIO
15
+ from pathlib import Path
16
+ from typing import Dict, Tuple
17
+
18
+ import numpy as np
19
+ import requests
20
+ from ffmpy import FFmpeg, FFprobe, FFRuntimeError
21
+ from PIL import Image, ImageOps, PngImagePlugin
22
+
23
+ from gradio import encryptor, utils
24
+
25
+ with warnings.catch_warnings():
26
+ warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
27
+ from pydub import AudioSegment
28
+
29
+
30
+ #########################
31
+ # GENERAL
32
+ #########################
33
+
34
+
35
+ def to_binary(x: str | Dict) -> bytes:
36
+ """Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
37
+ if isinstance(x, dict):
38
+ if x.get("data"):
39
+ base64str = x["data"]
40
+ else:
41
+ base64str = encode_url_or_file_to_base64(x["name"])
42
+ else:
43
+ base64str = x
44
+ return base64.b64decode(base64str.split(",")[1])
45
+
46
+
47
+ #########################
48
+ # IMAGE PRE-PROCESSING
49
+ #########################
50
+
51
+
52
+ def decode_base64_to_image(encoding: str) -> Image.Image:
53
+ content = encoding.split(";")[1]
54
+ image_encoded = content.split(",")[1]
55
+ return Image.open(BytesIO(base64.b64decode(image_encoded)))
56
+
57
+
58
+ def encode_url_or_file_to_base64(path: str | Path, encryption_key: bytes | None = None):
59
+ if utils.validate_url(str(path)):
60
+ return encode_url_to_base64(str(path), encryption_key=encryption_key)
61
+ else:
62
+ return encode_file_to_base64(str(path), encryption_key=encryption_key)
63
+
64
+
65
+ def get_mimetype(filename: str) -> str | None:
66
+ mimetype = mimetypes.guess_type(filename)[0]
67
+ if mimetype is not None:
68
+ mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
69
+ return mimetype
70
+
71
+
72
+ def get_extension(encoding: str) -> str | None:
73
+ encoding = encoding.replace("audio/wav", "audio/x-wav")
74
+ type = mimetypes.guess_type(encoding)[0]
75
+ if type == "audio/flac": # flac is not supported by mimetypes
76
+ return "flac"
77
+ elif type is None:
78
+ return None
79
+ extension = mimetypes.guess_extension(type)
80
+ if extension is not None and extension.startswith("."):
81
+ extension = extension[1:]
82
+ return extension
83
+
84
+
85
+ def encode_file_to_base64(f, encryption_key=None):
86
+ with open(f, "rb") as file:
87
+ encoded_string = base64.b64encode(file.read())
88
+ if encryption_key:
89
+ encoded_string = encryptor.decrypt(encryption_key, encoded_string)
90
+ base64_str = str(encoded_string, "utf-8")
91
+ mimetype = get_mimetype(f)
92
+ return (
93
+ "data:"
94
+ + (mimetype if mimetype is not None else "")
95
+ + ";base64,"
96
+ + base64_str
97
+ )
98
+
99
+
100
+ def encode_url_to_base64(url, encryption_key=None):
101
+ encoded_string = base64.b64encode(requests.get(url).content)
102
+ if encryption_key:
103
+ encoded_string = encryptor.decrypt(encryption_key, encoded_string)
104
+ base64_str = str(encoded_string, "utf-8")
105
+ mimetype = get_mimetype(url)
106
+ return (
107
+ "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
108
+ )
109
+
110
+
111
+ def encode_plot_to_base64(plt):
112
+ with BytesIO() as output_bytes:
113
+ plt.savefig(output_bytes, format="png")
114
+ bytes_data = output_bytes.getvalue()
115
+ base64_str = str(base64.b64encode(bytes_data), "utf-8")
116
+ return "data:image/png;base64," + base64_str
117
+
118
+
119
+ def save_array_to_file(image_array, dir=None):
120
+ pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
121
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
122
+ pil_image.save(file_obj)
123
+ return file_obj
124
+
125
+
126
+ def save_pil_to_file(pil_image, dir=None):
127
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
128
+ pil_image.save(file_obj)
129
+ return file_obj
130
+
131
+
132
+ def encode_pil_to_base64(pil_image):
133
+ with BytesIO() as output_bytes:
134
+
135
+ # Copy any text-only metadata
136
+ use_metadata = False
137
+ metadata = PngImagePlugin.PngInfo()
138
+ for key, value in pil_image.info.items():
139
+ if isinstance(key, str) and isinstance(value, str):
140
+ metadata.add_text(key, value)
141
+ use_metadata = True
142
+
143
+ pil_image.save(
144
+ output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
145
+ )
146
+ bytes_data = output_bytes.getvalue()
147
+ base64_str = str(base64.b64encode(bytes_data), "utf-8")
148
+ return "data:image/png;base64," + base64_str
149
+
150
+
151
+ def encode_array_to_base64(image_array):
152
+ with BytesIO() as output_bytes:
153
+ pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
154
+ pil_image.save(output_bytes, "PNG")
155
+ bytes_data = output_bytes.getvalue()
156
+ base64_str = str(base64.b64encode(bytes_data), "utf-8")
157
+ return "data:image/png;base64," + base64_str
158
+
159
+
160
+ def resize_and_crop(img, size, crop_type="center"):
161
+ """
162
+ Resize and crop an image to fit the specified size.
163
+ args:
164
+ size: `(width, height)` tuple. Pass `None` for either width or height
165
+ to only crop and resize the other.
166
+ crop_type: can be 'top', 'middle' or 'bottom', depending on this
167
+ value, the image will cropped getting the 'top/left', 'middle' or
168
+ 'bottom/right' of the image to fit the size.
169
+ raises:
170
+ ValueError: if an invalid `crop_type` is provided.
171
+ """
172
+ if crop_type == "top":
173
+ center = (0, 0)
174
+ elif crop_type == "center":
175
+ center = (0.5, 0.5)
176
+ else:
177
+ raise ValueError
178
+
179
+ resize = list(size)
180
+ if size[0] is None:
181
+ resize[0] = img.size[0]
182
+ if size[1] is None:
183
+ resize[1] = img.size[1]
184
+ return ImageOps.fit(img, resize, centering=center) # type: ignore
185
+
186
+
187
+ ##################
188
+ # Audio
189
+ ##################
190
+
191
+
192
+ def audio_from_file(filename, crop_min=0, crop_max=100):
193
+ try:
194
+ audio = AudioSegment.from_file(filename)
195
+ except FileNotFoundError as e:
196
+ isfile = Path(filename).is_file()
197
+ msg = (
198
+ f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
199
+ + " Please install `ffmpeg` in your system to use non-WAV audio file formats"
200
+ " and make sure `ffprobe` is in your PATH."
201
+ if isfile
202
+ else ""
203
+ )
204
+ raise RuntimeError(msg) from e
205
+ if crop_min != 0 or crop_max != 100:
206
+ audio_start = len(audio) * crop_min / 100
207
+ audio_end = len(audio) * crop_max / 100
208
+ audio = audio[audio_start:audio_end]
209
+ data = np.array(audio.get_array_of_samples())
210
+ if audio.channels > 1:
211
+ data = data.reshape(-1, audio.channels)
212
+ return audio.frame_rate, data
213
+
214
+
215
+ def audio_to_file(sample_rate, data, filename):
216
+ data = convert_to_16_bit_wav(data)
217
+ audio = AudioSegment(
218
+ data.tobytes(),
219
+ frame_rate=sample_rate,
220
+ sample_width=data.dtype.itemsize,
221
+ channels=(1 if len(data.shape) == 1 else data.shape[1]),
222
+ )
223
+ file = audio.export(filename, format="wav")
224
+ file.close() # type: ignore
225
+
226
+
227
+ def convert_to_16_bit_wav(data):
228
+ # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
229
+ warning = "Trying to convert audio automatically from {} to 16-bit int format."
230
+ if data.dtype in [np.float64, np.float32, np.float16]:
231
+ warnings.warn(warning.format(data.dtype))
232
+ data = data / np.abs(data).max()
233
+ data = data * 32767
234
+ data = data.astype(np.int16)
235
+ elif data.dtype == np.int32:
236
+ warnings.warn(warning.format(data.dtype))
237
+ data = data / 65538
238
+ data = data.astype(np.int16)
239
+ elif data.dtype == np.int16:
240
+ pass
241
+ elif data.dtype == np.uint16:
242
+ warnings.warn(warning.format(data.dtype))
243
+ data = data - 32768
244
+ data = data.astype(np.int16)
245
+ elif data.dtype == np.uint8:
246
+ warnings.warn(warning.format(data.dtype))
247
+ data = data * 257 - 32768
248
+ data = data.astype(np.int16)
249
+ else:
250
+ raise ValueError(
251
+ "Audio data cannot be converted automatically from "
252
+ f"{data.dtype} to 16-bit int format."
253
+ )
254
+ return data
255
+
256
+
257
+ ##################
258
+ # OUTPUT
259
+ ##################
260
+
261
+
262
+ def decode_base64_to_binary(encoding) -> Tuple[bytes, str | None]:
263
+ extension = get_extension(encoding)
264
+ data = encoding.split(",")[1]
265
+ return base64.b64decode(data), extension
266
+
267
+
268
+ def decode_base64_to_file(
269
+ encoding, encryption_key=None, file_path=None, dir=None, prefix=None
270
+ ):
271
+ if dir is not None:
272
+ os.makedirs(dir, exist_ok=True)
273
+ data, extension = decode_base64_to_binary(encoding)
274
+ if file_path is not None and prefix is None:
275
+ filename = Path(file_path).name
276
+ prefix = filename
277
+ if "." in filename:
278
+ prefix = filename[0 : filename.index(".")]
279
+ extension = filename[filename.index(".") + 1 :]
280
+
281
+ if prefix is not None:
282
+ prefix = utils.strip_invalid_filename_characters(prefix)
283
+
284
+ if extension is None:
285
+ file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
286
+ else:
287
+ file_obj = tempfile.NamedTemporaryFile(
288
+ delete=False,
289
+ prefix=prefix,
290
+ suffix="." + extension,
291
+ dir=dir,
292
+ )
293
+ if encryption_key is not None:
294
+ data = encryptor.encrypt(encryption_key, data)
295
+ file_obj.write(data)
296
+ file_obj.flush()
297
+ return file_obj
298
+
299
+
300
+ def dict_or_str_to_json_file(jsn, dir=None):
301
+ if dir is not None:
302
+ os.makedirs(dir, exist_ok=True)
303
+
304
+ file_obj = tempfile.NamedTemporaryFile(
305
+ delete=False, suffix=".json", dir=dir, mode="w+"
306
+ )
307
+ if isinstance(jsn, str):
308
+ jsn = json.loads(jsn)
309
+ json.dump(jsn, file_obj)
310
+ file_obj.flush()
311
+ return file_obj
312
+
313
+
314
+ def file_to_json(file_path: str | Path) -> Dict:
315
+ with open(file_path) as f:
316
+ return json.load(f)
317
+
318
+
319
+ class TempFileManager:
320
+ """
321
+ A class that should be inherited by any Component that needs to manage temporary files.
322
+ It should be instantiated in the __init__ method of the component.
323
+ """
324
+
325
+ def __init__(self) -> None:
326
+ # Set stores all the temporary files created by this component.
327
+ self.temp_files = set()
328
+
329
+ def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
330
+ sha1 = hashlib.sha1()
331
+ with open(file_path, "rb") as f:
332
+ for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
333
+ sha1.update(chunk)
334
+ return sha1.hexdigest()
335
+
336
+ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
337
+ sha1 = hashlib.sha1()
338
+ remote = urllib.request.urlopen(url)
339
+ max_file_size = 100 * 1024 * 1024 # 100MB
340
+ total_read = 0
341
+ while True:
342
+ data = remote.read(chunk_num_blocks * sha1.block_size)
343
+ total_read += chunk_num_blocks * sha1.block_size
344
+ if not data or total_read > max_file_size:
345
+ break
346
+ sha1.update(data)
347
+ return sha1.hexdigest()
348
+
349
+ def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
350
+ file_name = Path(file_path_or_url).name
351
+ prefix, extension = file_name, None
352
+ if "." in file_name:
353
+ prefix = file_name[0 : file_name.index(".")]
354
+ extension = "." + file_name[file_name.index(".") + 1 :]
355
+ else:
356
+ extension = ""
357
+ prefix = utils.strip_invalid_filename_characters(prefix)
358
+ return prefix, extension
359
+
360
+ def get_temp_file_path(self, file_path: str) -> str:
361
+ prefix, extension = self.get_prefix_and_extension(file_path)
362
+ file_hash = self.hash_file(file_path)
363
+ return prefix + file_hash + extension
364
+
365
+ def get_temp_url_path(self, url: str) -> str:
366
+ prefix, extension = self.get_prefix_and_extension(url)
367
+ file_hash = self.hash_url(url)
368
+ return prefix + file_hash + extension
369
+
370
+ def make_temp_copy_if_needed(self, file_path: str) -> str:
371
+ """Returns a temporary file path for a copy of the given file path if it does
372
+ not already exist. Otherwise returns the path to the existing temp file."""
373
+ f = tempfile.NamedTemporaryFile()
374
+ temp_dir = Path(f.name).parent
375
+
376
+ temp_file_path = self.get_temp_file_path(file_path)
377
+ f.name = str(temp_dir / temp_file_path)
378
+ full_temp_file_path = str(Path(f.name).resolve())
379
+
380
+ if not Path(full_temp_file_path).exists():
381
+ shutil.copy2(file_path, full_temp_file_path)
382
+
383
+ self.temp_files.add(full_temp_file_path)
384
+ return full_temp_file_path
385
+
386
+ def download_temp_copy_if_needed(self, url: str) -> str:
387
+ """Downloads a file and makes a temporary file path for a copy if does not already
388
+ exist. Otherwise returns the path to the existing temp file."""
389
+ f = tempfile.NamedTemporaryFile()
390
+ temp_dir = Path(f.name).parent
391
+
392
+ temp_file_path = self.get_temp_url_path(url)
393
+ f.name = str(temp_dir / temp_file_path)
394
+ full_temp_file_path = str(Path(f.name).resolve())
395
+
396
+ if not Path(full_temp_file_path).exists():
397
+ with requests.get(url, stream=True) as r:
398
+ with open(full_temp_file_path, "wb") as f:
399
+ shutil.copyfileobj(r.raw, f)
400
+
401
+ self.temp_files.add(full_temp_file_path)
402
+ return full_temp_file_path
403
+
404
+
405
+ def create_tmp_copy_of_file(file_path, dir=None):
406
+ if dir is not None:
407
+ os.makedirs(dir, exist_ok=True)
408
+ file_name = Path(file_path).name
409
+ prefix, extension = file_name, None
410
+ if "." in file_name:
411
+ prefix = file_name[0 : file_name.index(".")]
412
+ extension = file_name[file_name.index(".") + 1 :]
413
+ prefix = utils.strip_invalid_filename_characters(prefix)
414
+ if extension is None:
415
+ file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
416
+ else:
417
+ file_obj = tempfile.NamedTemporaryFile(
418
+ delete=False,
419
+ prefix=prefix,
420
+ suffix="." + extension,
421
+ dir=dir,
422
+ )
423
+ shutil.copy2(file_path, file_obj.name)
424
+ return file_obj
425
+
426
+
427
+ def _convert(image, dtype, force_copy=False, uniform=False):
428
+ """
429
+ Adapted from: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/dtype.py#L510-L531
430
+
431
+ Convert an image to the requested data-type.
432
+ Warnings are issued in case of precision loss, or when negative values
433
+ are clipped during conversion to unsigned integer types (sign loss).
434
+ Floating point values are expected to be normalized and will be clipped
435
+ to the range [0.0, 1.0] or [-1.0, 1.0] when converting to unsigned or
436
+ signed integers respectively.
437
+ Numbers are not shifted to the negative side when converting from
438
+ unsigned to signed integer types. Negative values will be clipped when
439
+ converting to unsigned integers.
440
+ Parameters
441
+ ----------
442
+ image : ndarray
443
+ Input image.
444
+ dtype : dtype
445
+ Target data-type.
446
+ force_copy : bool, optional
447
+ Force a copy of the data, irrespective of its current dtype.
448
+ uniform : bool, optional
449
+ Uniformly quantize the floating point range to the integer range.
450
+ By default (uniform=False) floating point values are scaled and
451
+ rounded to the nearest integers, which minimizes back and forth
452
+ conversion errors.
453
+ .. versionchanged :: 0.15
454
+ ``_convert`` no longer warns about possible precision or sign
455
+ information loss. See discussions on these warnings at:
456
+ https://github.com/scikit-image/scikit-image/issues/2602
457
+ https://github.com/scikit-image/scikit-image/issues/543#issuecomment-208202228
458
+ https://github.com/scikit-image/scikit-image/pull/3575
459
+ References
460
+ ----------
461
+ .. [1] DirectX data conversion rules.
462
+ https://msdn.microsoft.com/en-us/library/windows/desktop/dd607323%28v=vs.85%29.aspx
463
+ .. [2] Data Conversions. In "OpenGL ES 2.0 Specification v2.0.25",
464
+ pp 7-8. Khronos Group, 2010.
465
+ .. [3] Proper treatment of pixels as integers. A.W. Paeth.
466
+ In "Graphics Gems I", pp 249-256. Morgan Kaufmann, 1990.
467
+ .. [4] Dirty Pixels. J. Blinn. In "Jim Blinn's corner: Dirty Pixels",
468
+ pp 47-57. Morgan Kaufmann, 1998.
469
+ """
470
+ dtype_range = {
471
+ bool: (False, True),
472
+ np.bool_: (False, True),
473
+ np.bool8: (False, True),
474
+ float: (-1, 1),
475
+ np.float_: (-1, 1),
476
+ np.float16: (-1, 1),
477
+ np.float32: (-1, 1),
478
+ np.float64: (-1, 1),
479
+ }
480
+
481
+ def _dtype_itemsize(itemsize, *dtypes):
482
+ """Return first of `dtypes` with itemsize greater than `itemsize`
483
+ Parameters
484
+ ----------
485
+ itemsize: int
486
+ The data type object element size.
487
+ Other Parameters
488
+ ----------------
489
+ *dtypes:
490
+ Any Object accepted by `np.dtype` to be converted to a data
491
+ type object
492
+ Returns
493
+ -------
494
+ dtype: data type object
495
+ First of `dtypes` with itemsize greater than `itemsize`.
496
+ """
497
+ return next(dt for dt in dtypes if np.dtype(dt).itemsize >= itemsize)
498
+
499
+ def _dtype_bits(kind, bits, itemsize=1):
500
+ """Return dtype of `kind` that can store a `bits` wide unsigned int
501
+ Parameters:
502
+ kind: str
503
+ Data type kind.
504
+ bits: int
505
+ Desired number of bits.
506
+ itemsize: int
507
+ The data type object element size.
508
+ Returns
509
+ -------
510
+ dtype: data type object
511
+ Data type of `kind` that can store a `bits` wide unsigned int
512
+ """
513
+
514
+ s = next(
515
+ i
516
+ for i in (itemsize,) + (2, 4, 8)
517
+ if bits < (i * 8) or (bits == (i * 8) and kind == "u")
518
+ )
519
+
520
+ return np.dtype(kind + str(s))
521
+
522
+ def _scale(a, n, m, copy=True):
523
+ """Scale an array of unsigned/positive integers from `n` to `m` bits.
524
+ Numbers can be represented exactly only if `m` is a multiple of `n`.
525
+ Parameters
526
+ ----------
527
+ a : ndarray
528
+ Input image array.
529
+ n : int
530
+ Number of bits currently used to encode the values in `a`.
531
+ m : int
532
+ Desired number of bits to encode the values in `out`.
533
+ copy : bool, optional
534
+ If True, allocates and returns new array. Otherwise, modifies
535
+ `a` in place.
536
+ Returns
537
+ -------
538
+ out : array
539
+ Output image array. Has the same kind as `a`.
540
+ """
541
+ kind = a.dtype.kind
542
+ if n > m and a.max() < 2**m:
543
+ return a.astype(_dtype_bits(kind, m))
544
+ elif n == m:
545
+ return a.copy() if copy else a
546
+ elif n > m:
547
+ # downscale with precision loss
548
+ if copy:
549
+ b = np.empty(a.shape, _dtype_bits(kind, m))
550
+ np.floor_divide(a, 2 ** (n - m), out=b, dtype=a.dtype, casting="unsafe")
551
+ return b
552
+ else:
553
+ a //= 2 ** (n - m)
554
+ return a
555
+ elif m % n == 0:
556
+ # exact upscale to a multiple of `n` bits
557
+ if copy:
558
+ b = np.empty(a.shape, _dtype_bits(kind, m))
559
+ np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
560
+ return b
561
+ else:
562
+ a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
563
+ a *= (2**m - 1) // (2**n - 1)
564
+ return a
565
+ else:
566
+ # upscale to a multiple of `n` bits,
567
+ # then downscale with precision loss
568
+ o = (m // n + 1) * n
569
+ if copy:
570
+ b = np.empty(a.shape, _dtype_bits(kind, o))
571
+ np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
572
+ b //= 2 ** (o - m)
573
+ return b
574
+ else:
575
+ a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
576
+ a *= (2**o - 1) // (2**n - 1)
577
+ a //= 2 ** (o - m)
578
+ return a
579
+
580
+ image = np.asarray(image)
581
+ dtypeobj_in = image.dtype
582
+ if dtype is np.floating:
583
+ dtypeobj_out = np.dtype("float64")
584
+ else:
585
+ dtypeobj_out = np.dtype(dtype)
586
+ dtype_in = dtypeobj_in.type
587
+ dtype_out = dtypeobj_out.type
588
+ kind_in = dtypeobj_in.kind
589
+ kind_out = dtypeobj_out.kind
590
+ itemsize_in = dtypeobj_in.itemsize
591
+ itemsize_out = dtypeobj_out.itemsize
592
+
593
+ # Below, we do an `issubdtype` check. Its purpose is to find out
594
+ # whether we can get away without doing any image conversion. This happens
595
+ # when:
596
+ #
597
+ # - the output and input dtypes are the same or
598
+ # - when the output is specified as a type, and the input dtype
599
+ # is a subclass of that type (e.g. `np.floating` will allow
600
+ # `float32` and `float64` arrays through)
601
+
602
+ if np.issubdtype(dtype_in, np.obj2sctype(dtype)):
603
+ if force_copy:
604
+ image = image.copy()
605
+ return image
606
+
607
+ if kind_in in "ui":
608
+ imin_in = np.iinfo(dtype_in).min
609
+ imax_in = np.iinfo(dtype_in).max
610
+ if kind_out in "ui":
611
+ imin_out = np.iinfo(dtype_out).min # type: ignore
612
+ imax_out = np.iinfo(dtype_out).max # type: ignore
613
+
614
+ # any -> binary
615
+ if kind_out == "b":
616
+ return image > dtype_in(dtype_range[dtype_in][1] / 2)
617
+
618
+ # binary -> any
619
+ if kind_in == "b":
620
+ result = image.astype(dtype_out)
621
+ if kind_out != "f":
622
+ result *= dtype_out(dtype_range[dtype_out][1])
623
+ return result
624
+
625
+ # float -> any
626
+ if kind_in == "f":
627
+ if kind_out == "f":
628
+ # float -> float
629
+ return image.astype(dtype_out)
630
+
631
+ if np.min(image) < -1.0 or np.max(image) > 1.0:
632
+ raise ValueError("Images of type float must be between -1 and 1.")
633
+ # floating point -> integer
634
+ # use float type that can represent output integer type
635
+ computation_type = _dtype_itemsize(
636
+ itemsize_out, dtype_in, np.float32, np.float64
637
+ )
638
+
639
+ if not uniform:
640
+ if kind_out == "u":
641
+ image_out = np.multiply(image, imax_out, dtype=computation_type) # type: ignore
642
+ else:
643
+ image_out = np.multiply(
644
+ image, (imax_out - imin_out) / 2, dtype=computation_type # type: ignore
645
+ )
646
+ image_out -= 1.0 / 2.0
647
+ np.rint(image_out, out=image_out)
648
+ np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
649
+ elif kind_out == "u":
650
+ image_out = np.multiply(image, imax_out + 1, dtype=computation_type) # type: ignore
651
+ np.clip(image_out, 0, imax_out, out=image_out) # type: ignore
652
+ else:
653
+ image_out = np.multiply(
654
+ image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type # type: ignore
655
+ )
656
+ np.floor(image_out, out=image_out)
657
+ np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
658
+ return image_out.astype(dtype_out)
659
+
660
+ # signed/unsigned int -> float
661
+ if kind_out == "f":
662
+ # use float type that can exactly represent input integers
663
+ computation_type = _dtype_itemsize(
664
+ itemsize_in, dtype_out, np.float32, np.float64
665
+ )
666
+
667
+ if kind_in == "u":
668
+ # using np.divide or np.multiply doesn't copy the data
669
+ # until the computation time
670
+ image = np.multiply(image, 1.0 / imax_in, dtype=computation_type) # type: ignore
671
+ # DirectX uses this conversion also for signed ints
672
+ # if imin_in:
673
+ # np.maximum(image, -1.0, out=image)
674
+ else:
675
+ image = np.add(image, 0.5, dtype=computation_type)
676
+ image *= 2 / (imax_in - imin_in) # type: ignore
677
+
678
+ return np.asarray(image, dtype_out)
679
+
680
+ # unsigned int -> signed/unsigned int
681
+ if kind_in == "u":
682
+ if kind_out == "i":
683
+ # unsigned int -> signed int
684
+ image = _scale(image, 8 * itemsize_in, 8 * itemsize_out - 1)
685
+ return image.view(dtype_out)
686
+ else:
687
+ # unsigned int -> unsigned int
688
+ return _scale(image, 8 * itemsize_in, 8 * itemsize_out)
689
+
690
+ # signed int -> unsigned int
691
+ if kind_out == "u":
692
+ image = _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out)
693
+ result = np.empty(image.shape, dtype_out)
694
+ np.maximum(image, 0, out=result, dtype=image.dtype, casting="unsafe")
695
+ return result
696
+
697
+ # signed int -> signed int
698
+ if itemsize_in > itemsize_out:
699
+ return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)
700
+
701
+ image = image.astype(_dtype_bits("i", itemsize_out * 8))
702
+ image -= imin_in # type: ignore
703
+ image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
704
+ image += imin_out # type: ignore
705
+ return image.astype(dtype_out)
706
+
707
+
708
+ def ffmpeg_installed() -> bool:
709
+ return shutil.which("ffmpeg") is not None
710
+
711
+
712
+ def video_is_playable(video_filepath: str) -> bool:
713
+ """Determines if a video is playable in the browser.
714
+
715
+ A video is playable if it has a playable container and codec.
716
+ .mp4 -> h264
717
+ .webm -> vp9
718
+ .ogg -> theora
719
+ """
720
+ try:
721
+ container = pathlib.Path(video_filepath).suffix.lower()
722
+ probe = FFprobe(
723
+ global_options="-show_format -show_streams -select_streams v -print_format json",
724
+ inputs={video_filepath: None},
725
+ )
726
+ output = probe.run(stderr=subprocess.PIPE, stdout=subprocess.PIPE)
727
+ output = json.loads(output[0])
728
+ video_codec = output["streams"][0]["codec_name"]
729
+ return (container, video_codec) in [
730
+ (".mp4", "h264"),
731
+ (".ogg", "theora"),
732
+ (".webm", "vp9"),
733
+ ]
734
+ # If anything goes wrong, assume the video can be played to not convert downstream
735
+ except (FFRuntimeError, IndexError, KeyError):
736
+ return True
737
+
738
+
739
+ def convert_video_to_playable_mp4(video_path: str) -> str:
740
+ """Convert the video to mp4. If something goes wrong return the original video."""
741
+ try:
742
+ output_path = pathlib.Path(video_path).with_suffix(".mp4")
743
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
744
+ shutil.copy2(video_path, tmp_file.name)
745
+ # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
746
+ ff = FFmpeg(
747
+ inputs={str(tmp_file.name): None},
748
+ outputs={str(output_path): None},
749
+ global_options="-y -loglevel quiet",
750
+ )
751
+ ff.run()
752
+ except FFRuntimeError as e:
753
+ print(f"Error converting video to browser-playable format {str(e)}")
754
+ output_path = video_path
755
+ return str(output_path)
gradio-modified/gradio/queueing.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import copy
5
+ import sys
6
+ import time
7
+ from collections import deque
8
+ from typing import Any, Deque, Dict, List, Tuple
9
+
10
+ import fastapi
11
+
12
+ from gradio.data_classes import Estimation, PredictBody, Progress, ProgressUnit
13
+ from gradio.helpers import TrackedIterable
14
+ from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
15
+
16
+
17
+ class Event:
18
+ def __init__(
19
+ self,
20
+ websocket: fastapi.WebSocket,
21
+ session_hash: str,
22
+ fn_index: int,
23
+ ):
24
+ self.websocket = websocket
25
+ self.session_hash: str = session_hash
26
+ self.fn_index: int = fn_index
27
+ self._id = f"{self.session_hash}_{self.fn_index}"
28
+ self.data: PredictBody | None = None
29
+ self.lost_connection_time: float | None = None
30
+ self.token: str | None = None
31
+ self.progress: Progress | None = None
32
+ self.progress_pending: bool = False
33
+
34
+ async def disconnect(self, code: int = 1000):
35
+ await self.websocket.close(code=code)
36
+
37
+
38
+ class Queue:
39
+ def __init__(
40
+ self,
41
+ live_updates: bool,
42
+ concurrency_count: int,
43
+ update_intervals: float,
44
+ max_size: int | None,
45
+ blocks_dependencies: List,
46
+ ):
47
+ self.event_queue: Deque[Event] = deque()
48
+ self.events_pending_reconnection = []
49
+ self.stopped = False
50
+ self.max_thread_count = concurrency_count
51
+ self.update_intervals = update_intervals
52
+ self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
53
+ self.delete_lock = asyncio.Lock()
54
+ self.server_path = None
55
+ self.duration_history_total = 0
56
+ self.duration_history_count = 0
57
+ self.avg_process_time = 0
58
+ self.avg_concurrent_process_time = None
59
+ self.queue_duration = 1
60
+ self.live_updates = live_updates
61
+ self.sleep_when_free = 0.05
62
+ self.progress_update_sleep_when_free = 0.1
63
+ self.max_size = max_size
64
+ self.blocks_dependencies = blocks_dependencies
65
+ self.access_token = ""
66
+
67
+ async def start(self, progress_tracking=False):
68
+ run_coro_in_background(self.start_processing)
69
+ if progress_tracking:
70
+ run_coro_in_background(self.start_progress_tracking)
71
+ if not self.live_updates:
72
+ run_coro_in_background(self.notify_clients)
73
+
74
+ def close(self):
75
+ self.stopped = True
76
+
77
+ def resume(self):
78
+ self.stopped = False
79
+
80
+ def set_url(self, url: str):
81
+ self.server_path = url
82
+
83
+ def set_access_token(self, token: str):
84
+ self.access_token = token
85
+
86
+ def get_active_worker_count(self) -> int:
87
+ count = 0
88
+ for worker in self.active_jobs:
89
+ if worker is not None:
90
+ count += 1
91
+ return count
92
+
93
+ def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
94
+ if not (self.event_queue):
95
+ return None, False
96
+
97
+ first_event = self.event_queue.popleft()
98
+ events = [first_event]
99
+
100
+ event_fn_index = first_event.fn_index
101
+ batch = self.blocks_dependencies[event_fn_index]["batch"]
102
+
103
+ if batch:
104
+ batch_size = self.blocks_dependencies[event_fn_index]["max_batch_size"]
105
+ rest_of_batch = [
106
+ event for event in self.event_queue if event.fn_index == event_fn_index
107
+ ][: batch_size - 1]
108
+ events.extend(rest_of_batch)
109
+ [self.event_queue.remove(event) for event in rest_of_batch]
110
+
111
+ return events, batch
112
+
113
+ async def start_processing(self) -> None:
114
+ while not self.stopped:
115
+ if not self.event_queue:
116
+ await asyncio.sleep(self.sleep_when_free)
117
+ continue
118
+
119
+ if not (None in self.active_jobs):
120
+ await asyncio.sleep(self.sleep_when_free)
121
+ continue
122
+ # Using mutex to avoid editing a list in use
123
+ async with self.delete_lock:
124
+ events, batch = self.get_events_in_batch()
125
+
126
+ if events:
127
+ self.active_jobs[self.active_jobs.index(None)] = events
128
+ task = run_coro_in_background(self.process_events, events, batch)
129
+ run_coro_in_background(self.broadcast_live_estimations)
130
+ set_task_name(task, events[0].session_hash, events[0].fn_index, batch)
131
+
132
+ async def start_progress_tracking(self) -> None:
133
+ while not self.stopped:
134
+ if not any(self.active_jobs):
135
+ await asyncio.sleep(self.progress_update_sleep_when_free)
136
+ continue
137
+
138
+ for job in self.active_jobs:
139
+ if job is None:
140
+ continue
141
+ for event in job:
142
+ if event.progress_pending and event.progress:
143
+ event.progress_pending = False
144
+ client_awake = await self.send_message(
145
+ event, event.progress.dict()
146
+ )
147
+ if not client_awake:
148
+ await self.clean_event(event)
149
+
150
+ await asyncio.sleep(self.progress_update_sleep_when_free)
151
+
152
+ def set_progress(
153
+ self,
154
+ event_id: str,
155
+ iterables: List[TrackedIterable] | None,
156
+ ):
157
+ if iterables is None:
158
+ return
159
+ for job in self.active_jobs:
160
+ if job is None:
161
+ continue
162
+ for evt in job:
163
+ if evt._id == event_id:
164
+ progress_data: List[ProgressUnit] = []
165
+ for iterable in iterables:
166
+ progress_unit = ProgressUnit(
167
+ index=iterable.index,
168
+ length=iterable.length,
169
+ unit=iterable.unit,
170
+ progress=iterable.progress,
171
+ desc=iterable.desc,
172
+ )
173
+ progress_data.append(progress_unit)
174
+ evt.progress = Progress(progress_data=progress_data)
175
+ evt.progress_pending = True
176
+
177
+ def push(self, event: Event) -> int | None:
178
+ """
179
+ Add event to queue, or return None if Queue is full
180
+ Parameters:
181
+ event: Event to add to Queue
182
+ Returns:
183
+ rank of submitted Event
184
+ """
185
+ queue_len = len(self.event_queue)
186
+ if self.max_size is not None and queue_len >= self.max_size:
187
+ return None
188
+ self.event_queue.append(event)
189
+ return queue_len
190
+
191
+ async def clean_event(self, event: Event) -> None:
192
+ if event in self.event_queue:
193
+ async with self.delete_lock:
194
+ self.event_queue.remove(event)
195
+
196
+ async def broadcast_live_estimations(self) -> None:
197
+ """
198
+ Runs 2 functions sequentially instead of concurrently. Otherwise dced clients are tried to get deleted twice.
199
+ """
200
+ if self.live_updates:
201
+ await self.broadcast_estimations()
202
+
203
+ async def gather_event_data(self, event: Event) -> bool:
204
+ """
205
+ Gather data for the event
206
+
207
+ Parameters:
208
+ event:
209
+ """
210
+ if not event.data:
211
+ client_awake = await self.send_message(event, {"msg": "send_data"})
212
+ if not client_awake:
213
+ return False
214
+ event.data = await self.get_message(event)
215
+ return True
216
+
217
+ async def notify_clients(self) -> None:
218
+ """
219
+ Notify clients about events statuses in the queue periodically.
220
+ """
221
+ while not self.stopped:
222
+ await asyncio.sleep(self.update_intervals)
223
+ if self.event_queue:
224
+ await self.broadcast_estimations()
225
+
226
+ async def broadcast_estimations(self) -> None:
227
+ estimation = self.get_estimation()
228
+ # Send all messages concurrently
229
+ await asyncio.gather(
230
+ *[
231
+ self.send_estimation(event, estimation, rank)
232
+ for rank, event in enumerate(self.event_queue)
233
+ ]
234
+ )
235
+
236
+ async def send_estimation(
237
+ self, event: Event, estimation: Estimation, rank: int
238
+ ) -> Estimation:
239
+ """
240
+ Send estimation about ETA to the client.
241
+
242
+ Parameters:
243
+ event:
244
+ estimation:
245
+ rank:
246
+ """
247
+ estimation.rank = rank
248
+
249
+ if self.avg_concurrent_process_time is not None:
250
+ estimation.rank_eta = (
251
+ estimation.rank * self.avg_concurrent_process_time
252
+ + self.avg_process_time
253
+ )
254
+ if None not in self.active_jobs:
255
+ # Add estimated amount of time for a thread to get empty
256
+ estimation.rank_eta += self.avg_concurrent_process_time
257
+ client_awake = await self.send_message(event, estimation.dict())
258
+ if not client_awake:
259
+ await self.clean_event(event)
260
+ return estimation
261
+
262
+ def update_estimation(self, duration: float) -> None:
263
+ """
264
+ Update estimation by last x element's average duration.
265
+
266
+ Parameters:
267
+ duration:
268
+ """
269
+ self.duration_history_total += duration
270
+ self.duration_history_count += 1
271
+ self.avg_process_time = (
272
+ self.duration_history_total / self.duration_history_count
273
+ )
274
+ self.avg_concurrent_process_time = self.avg_process_time / min(
275
+ self.max_thread_count, self.duration_history_count
276
+ )
277
+ self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)
278
+
279
+ def get_estimation(self) -> Estimation:
280
+ return Estimation(
281
+ queue_size=len(self.event_queue),
282
+ avg_event_process_time=self.avg_process_time,
283
+ avg_event_concurrent_process_time=self.avg_concurrent_process_time,
284
+ queue_eta=self.queue_duration,
285
+ )
286
+
287
+ def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
288
+ return {
289
+ "url": str(websocket.url),
290
+ "headers": dict(websocket.headers),
291
+ "query_params": dict(websocket.query_params),
292
+ "path_params": dict(websocket.path_params),
293
+ "client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore
294
+ }
295
+
296
+ async def call_prediction(self, events: List[Event], batch: bool):
297
+ data = events[0].data
298
+ assert data is not None, "No event data"
299
+ token = events[0].token
300
+ data.event_id = events[0]._id if not batch else None
301
+ try:
302
+ data.request = self.get_request_params(events[0].websocket)
303
+ except ValueError:
304
+ pass
305
+
306
+ if batch:
307
+ data.data = list(zip(*[event.data.data for event in events if event.data]))
308
+ data.request = [
309
+ self.get_request_params(event.websocket)
310
+ for event in events
311
+ if event.data
312
+ ]
313
+ data.batched = True
314
+
315
+ response = await AsyncRequest(
316
+ method=AsyncRequest.Method.POST,
317
+ url=f"{self.server_path}api/predict",
318
+ json=dict(data),
319
+ headers={"Authorization": f"Bearer {self.access_token}"},
320
+ cookies={"access-token": token} if token is not None else None,
321
+ )
322
+ return response
323
+
324
+ async def process_events(self, events: List[Event], batch: bool) -> None:
325
+ awake_events: List[Event] = []
326
+ try:
327
+ for event in events:
328
+ client_awake = await self.gather_event_data(event)
329
+ if client_awake:
330
+ client_awake = await self.send_message(
331
+ event, {"msg": "process_starts"}
332
+ )
333
+ if client_awake:
334
+ awake_events.append(event)
335
+ if not awake_events:
336
+ return
337
+ begin_time = time.time()
338
+ response = await self.call_prediction(awake_events, batch)
339
+ if response.has_exception:
340
+ for event in awake_events:
341
+ await self.send_message(
342
+ event,
343
+ {
344
+ "msg": "process_completed",
345
+ "output": {"error": str(response.exception)},
346
+ "success": False,
347
+ },
348
+ )
349
+ elif response.json.get("is_generating", False):
350
+ old_response = response
351
+ while response.json.get("is_generating", False):
352
+ # Python 3.7 doesn't have named tasks.
353
+ # In order to determine if a task was cancelled, we
354
+ # ping the websocket to see if it was closed mid-iteration.
355
+ if sys.version_info < (3, 8):
356
+ is_alive = await self.send_message(event, {"msg": "alive?"})
357
+ if not is_alive:
358
+ return
359
+ old_response = response
360
+ open_ws = []
361
+ for event in awake_events:
362
+ open = await self.send_message(
363
+ event,
364
+ {
365
+ "msg": "process_generating",
366
+ "output": old_response.json,
367
+ "success": old_response.status == 200,
368
+ },
369
+ )
370
+ open_ws.append(open)
371
+ awake_events = [
372
+ e for e, is_open in zip(awake_events, open_ws) if is_open
373
+ ]
374
+ if not awake_events:
375
+ return
376
+ response = await self.call_prediction(awake_events, batch)
377
+ for event in awake_events:
378
+ if response.status != 200:
379
+ relevant_response = response
380
+ else:
381
+ relevant_response = old_response
382
+
383
+ await self.send_message(
384
+ event,
385
+ {
386
+ "msg": "process_completed",
387
+ "output": relevant_response.json,
388
+ "success": relevant_response.status == 200,
389
+ },
390
+ )
391
+ else:
392
+ output = copy.deepcopy(response.json)
393
+ for e, event in enumerate(awake_events):
394
+ if batch and "data" in output:
395
+ output["data"] = list(zip(*response.json.get("data")))[e]
396
+ await self.send_message(
397
+ event,
398
+ {
399
+ "msg": "process_completed",
400
+ "output": output,
401
+ "success": response.status == 200,
402
+ },
403
+ )
404
+ end_time = time.time()
405
+ if response.status == 200:
406
+ self.update_estimation(end_time - begin_time)
407
+ finally:
408
+ for event in awake_events:
409
+ try:
410
+ await event.disconnect()
411
+ except Exception:
412
+ pass
413
+ self.active_jobs[self.active_jobs.index(events)] = None
414
+ for event in awake_events:
415
+ await self.clean_event(event)
416
+ # Always reset the state of the iterator
417
+ # If the job finished successfully, this has no effect
418
+ # If the job is cancelled, this will enable future runs
419
+ # to start "from scratch"
420
+ await self.reset_iterators(event.session_hash, event.fn_index)
421
+
422
+ async def send_message(self, event, data: Dict) -> bool:
423
+ try:
424
+ await event.websocket.send_json(data=data)
425
+ return True
426
+ except:
427
+ await self.clean_event(event)
428
+ return False
429
+
430
+ async def get_message(self, event) -> PredictBody | None:
431
+ try:
432
+ data = await event.websocket.receive_json()
433
+ return PredictBody(**data)
434
+ except:
435
+ await self.clean_event(event)
436
+ return None
437
+
438
+ async def reset_iterators(self, session_hash: str, fn_index: int):
439
+ await AsyncRequest(
440
+ method=AsyncRequest.Method.POST,
441
+ url=f"{self.server_path}reset",
442
+ json={
443
+ "session_hash": session_hash,
444
+ "fn_index": fn_index,
445
+ },
446
+ )
gradio-modified/gradio/reload.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ Contains the functions that run when `gradio` is called from the command line. Specifically, allows
4
+
5
+ $ gradio app.py, to run app.py in reload mode where any changes in the app.py file or Gradio library reloads the demo.
6
+ $ gradio app.py my_demo, to use variable names other than "demo"
7
+ """
8
+ import inspect
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ import gradio
14
+ from gradio import networking
15
+
16
+
17
+ def run_in_reload_mode():
18
+ args = sys.argv[1:]
19
+ if len(args) == 0:
20
+ raise ValueError("No file specified.")
21
+ if len(args) == 1:
22
+ demo_name = "demo"
23
+ else:
24
+ demo_name = args[1]
25
+
26
+ original_path = args[0]
27
+ abs_original_path = Path(original_path).name
28
+ path = str(Path(original_path).resolve())
29
+ path = path.replace("/", ".")
30
+ path = path.replace("\\", ".")
31
+ filename = Path(path).stem
32
+
33
+ gradio_folder = Path(inspect.getfile(gradio)).parent
34
+
35
+ port = networking.get_first_available_port(
36
+ networking.INITIAL_PORT_VALUE,
37
+ networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
38
+ )
39
+ print(
40
+ f"\nLaunching in *reload mode* on: http://{networking.LOCALHOST_NAME}:{port} (Press CTRL+C to quit)\n"
41
+ )
42
+ command = f"uvicorn {filename}:{demo_name}.app --reload --port {port} --log-level warning "
43
+ message = "Watching:"
44
+
45
+ message_change_count = 0
46
+ if str(gradio_folder).strip():
47
+ command += f'--reload-dir "{gradio_folder}" '
48
+ message += f" '{gradio_folder}'"
49
+ message_change_count += 1
50
+
51
+ abs_parent = Path(abs_original_path).parent
52
+ if str(abs_parent).strip():
53
+ command += f'--reload-dir "{abs_parent}"'
54
+ if message_change_count == 1:
55
+ message += ","
56
+ message += f" '{abs_parent}'"
57
+
58
+ print(message + "\n")
59
+ os.system(command)
gradio-modified/gradio/routes.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements a FastAPI server to run the gradio interface. Note that some types in this
2
+ module use the Optional/Union notation so that they work correctly with pydantic."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import asyncio
7
+ import inspect
8
+ import json
9
+ import mimetypes
10
+ import os
11
+ import posixpath
12
+ import secrets
13
+ import traceback
14
+ from collections import defaultdict
15
+ from copy import deepcopy
16
+ from pathlib import Path
17
+ from typing import Any, Dict, List, Optional, Type
18
+ from urllib.parse import urlparse
19
+
20
+ import fastapi
21
+ import markupsafe
22
+ import orjson
23
+ import pkg_resources
24
+ from fastapi import Depends, FastAPI, HTTPException, WebSocket, status
25
+ from fastapi.middleware.cors import CORSMiddleware
26
+ from fastapi.responses import (
27
+ FileResponse,
28
+ HTMLResponse,
29
+ JSONResponse,
30
+ PlainTextResponse,
31
+ )
32
+ from fastapi.security import OAuth2PasswordRequestForm
33
+ from fastapi.templating import Jinja2Templates
34
+ from jinja2.exceptions import TemplateNotFound
35
+ from starlette.responses import RedirectResponse
36
+ from starlette.websockets import WebSocketState
37
+
38
+ import gradio
39
+ from gradio import utils
40
+ from gradio.data_classes import PredictBody, ResetBody
41
+ from gradio.documentation import document, set_documentation_group
42
+ from gradio.exceptions import Error
43
+ from gradio.queueing import Estimation, Event
44
+ from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
45
+
46
+ mimetypes.init()
47
+
48
+ STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
49
+ STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
50
+ BUILD_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/assets")
51
+ VERSION_FILE = pkg_resources.resource_filename("gradio", "version.txt")
52
+ with open(VERSION_FILE) as version_file:
53
+ VERSION = version_file.read()
54
+
55
+
56
+ class ORJSONResponse(JSONResponse):
57
+ media_type = "application/json"
58
+
59
+ @staticmethod
60
+ def _render(content: Any) -> bytes:
61
+ return orjson.dumps(
62
+ content,
63
+ option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
64
+ default=str,
65
+ )
66
+
67
+ def render(self, content: Any) -> bytes:
68
+ return ORJSONResponse._render(content)
69
+
70
+ @staticmethod
71
+ def _render_str(content: Any) -> str:
72
+ return ORJSONResponse._render(content).decode("utf-8")
73
+
74
+
75
+ def toorjson(value):
76
+ return markupsafe.Markup(
77
+ ORJSONResponse._render_str(value)
78
+ .replace("<", "\\u003c")
79
+ .replace(">", "\\u003e")
80
+ .replace("&", "\\u0026")
81
+ .replace("'", "\\u0027")
82
+ )
83
+
84
+
85
+ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
86
+ templates.env.filters["toorjson"] = toorjson
87
+
88
+
89
+ ###########
90
+ # Auth
91
+ ###########
92
+
93
+
94
+ class App(FastAPI):
95
+ """
96
+ FastAPI App Wrapper
97
+ """
98
+
99
+ def __init__(self, **kwargs):
100
+ self.tokens = {}
101
+ self.auth = None
102
+ self.blocks: gradio.Blocks | None = None
103
+ self.state_holder = {}
104
+ self.iterators = defaultdict(dict)
105
+ self.lock = asyncio.Lock()
106
+ self.queue_token = secrets.token_urlsafe(32)
107
+ self.startup_events_triggered = False
108
+ super().__init__(**kwargs)
109
+
110
+ def configure_app(self, blocks: gradio.Blocks) -> None:
111
+ auth = blocks.auth
112
+ if auth is not None:
113
+ if not callable(auth):
114
+ self.auth = {account[0]: account[1] for account in auth}
115
+ else:
116
+ self.auth = auth
117
+ else:
118
+ self.auth = None
119
+
120
+ self.blocks = blocks
121
+ if hasattr(self.blocks, "_queue"):
122
+ self.blocks._queue.set_access_token(self.queue_token)
123
+ self.cwd = os.getcwd()
124
+ self.favicon_path = blocks.favicon_path
125
+ self.tokens = {}
126
+
127
+ def get_blocks(self) -> gradio.Blocks:
128
+ if self.blocks is None:
129
+ raise ValueError("No Blocks has been configured for this app.")
130
+ return self.blocks
131
+
132
+ @staticmethod
133
+ def create_app(blocks: gradio.Blocks) -> App:
134
+ app = App(default_response_class=ORJSONResponse)
135
+ app.configure_app(blocks)
136
+
137
+ app.add_middleware(
138
+ CORSMiddleware,
139
+ allow_origins=["*"],
140
+ allow_methods=["*"],
141
+ allow_headers=["*"],
142
+ )
143
+
144
+ @app.get("/user")
145
+ @app.get("/user/")
146
+ def get_current_user(request: fastapi.Request) -> Optional[str]:
147
+ token = request.cookies.get("access-token")
148
+ return app.tokens.get(token)
149
+
150
+ @app.get("/login_check")
151
+ @app.get("/login_check/")
152
+ def login_check(user: str = Depends(get_current_user)):
153
+ if app.auth is None or not (user is None):
154
+ return
155
+ raise HTTPException(
156
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
157
+ )
158
+
159
+ async def ws_login_check(websocket: WebSocket) -> Optional[str]:
160
+ token = websocket.cookies.get("access-token")
161
+ return token # token is returned to allow request in queue
162
+
163
+ @app.get("/token")
164
+ @app.get("/token/")
165
+ def get_token(request: fastapi.Request) -> dict:
166
+ token = request.cookies.get("access-token")
167
+ return {"token": token, "user": app.tokens.get(token)}
168
+
169
+ @app.get("/app_id")
170
+ @app.get("/app_id/")
171
+ def app_id(request: fastapi.Request) -> dict:
172
+ return {"app_id": app.get_blocks().app_id}
173
+
174
+ @app.post("/login")
175
+ @app.post("/login/")
176
+ def login(form_data: OAuth2PasswordRequestForm = Depends()):
177
+ username, password = form_data.username, form_data.password
178
+ if app.auth is None:
179
+ return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
180
+ if (
181
+ not callable(app.auth)
182
+ and username in app.auth
183
+ and app.auth[username] == password
184
+ ) or (callable(app.auth) and app.auth.__call__(username, password)):
185
+ token = secrets.token_urlsafe(16)
186
+ app.tokens[token] = username
187
+ response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
188
+ response.set_cookie(key="access-token", value=token, httponly=True)
189
+ return response
190
+ else:
191
+ raise HTTPException(status_code=400, detail="Incorrect credentials.")
192
+
193
+ ###############
194
+ # Main Routes
195
+ ###############
196
+
197
+ @app.head("/", response_class=HTMLResponse)
198
+ @app.get("/", response_class=HTMLResponse)
199
+ def main(request: fastapi.Request, user: str = Depends(get_current_user)):
200
+ mimetypes.add_type("application/javascript", ".js")
201
+ blocks = app.get_blocks()
202
+
203
+ if app.auth is None or not (user is None):
204
+ config = app.get_blocks().config
205
+ else:
206
+ config = {
207
+ "auth_required": True,
208
+ "auth_message": blocks.auth_message,
209
+ }
210
+
211
+ try:
212
+ template = (
213
+ "frontend/share.html" if blocks.share else "frontend/index.html"
214
+ )
215
+ return templates.TemplateResponse(
216
+ template, {"request": request, "config": config}
217
+ )
218
+ except TemplateNotFound:
219
+ if blocks.share:
220
+ raise ValueError(
221
+ "Did you install Gradio from source files? Share mode only "
222
+ "works when Gradio is installed through the pip package."
223
+ )
224
+ else:
225
+ raise ValueError(
226
+ "Did you install Gradio from source files? You need to build "
227
+ "the frontend by running /scripts/build_frontend.sh"
228
+ )
229
+
230
+ @app.get("/config/", dependencies=[Depends(login_check)])
231
+ @app.get("/config", dependencies=[Depends(login_check)])
232
+ def get_config():
233
+ return app.get_blocks().config
234
+
235
+ @app.get("/static/{path:path}")
236
+ def static_resource(path: str):
237
+ static_file = safe_join(STATIC_PATH_LIB, path)
238
+ if static_file is not None:
239
+ return FileResponse(static_file)
240
+ raise HTTPException(status_code=404, detail="Static file not found")
241
+
242
+ @app.get("/assets/{path:path}")
243
+ def build_resource(path: str):
244
+ build_file = safe_join(BUILD_PATH_LIB, path)
245
+ if build_file is not None:
246
+ return FileResponse(build_file)
247
+ raise HTTPException(status_code=404, detail="Build file not found")
248
+
249
+ @app.get("/favicon.ico")
250
+ async def favicon():
251
+ blocks = app.get_blocks()
252
+ if blocks.favicon_path is None:
253
+ return static_resource("img/logo.svg")
254
+ else:
255
+ return FileResponse(blocks.favicon_path)
256
+
257
+ @app.get("/file={path:path}", dependencies=[Depends(login_check)])
258
+ def file(path: str):
259
+ blocks = app.get_blocks()
260
+ if utils.validate_url(path):
261
+ return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND)
262
+ if Path(app.cwd).resolve() in Path(path).resolve().parents or Path(
263
+ path
264
+ ).resolve() in set().union(*blocks.temp_file_sets):
265
+ return FileResponse(
266
+ Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
267
+ )
268
+ else:
269
+ raise ValueError(
270
+ f"File cannot be fetched: {path}. All files must contained within the Gradio python app working directory, or be a temp file created by the Gradio python app."
271
+ )
272
+
273
+ @app.get("/file/{path:path}", dependencies=[Depends(login_check)])
274
+ def file_deprecated(path: str):
275
+ return file(path)
276
+
277
+ @app.post("/reset/")
278
+ @app.post("/reset")
279
+ async def reset_iterator(body: ResetBody):
280
+ if body.session_hash not in app.iterators:
281
+ return {"success": False}
282
+ async with app.lock:
283
+ app.iterators[body.session_hash][body.fn_index] = None
284
+ app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
285
+ return {"success": True}
286
+
287
+ async def run_predict(
288
+ body: PredictBody,
289
+ request: Request | List[Request],
290
+ fn_index_inferred: int,
291
+ username: str = Depends(get_current_user),
292
+ ):
293
+ if hasattr(body, "session_hash"):
294
+ if body.session_hash not in app.state_holder:
295
+ app.state_holder[body.session_hash] = {
296
+ _id: deepcopy(getattr(block, "value", None))
297
+ for _id, block in app.get_blocks().blocks.items()
298
+ if getattr(block, "stateful", False)
299
+ }
300
+ session_state = app.state_holder[body.session_hash]
301
+ iterators = app.iterators[body.session_hash]
302
+ # The should_reset set keeps track of the fn_indices
303
+ # that have been cancelled. When a job is cancelled,
304
+ # the /reset route will mark the jobs as having been reset.
305
+ # That way if the cancel job finishes BEFORE the job being cancelled
306
+ # the job being cancelled will not overwrite the state of the iterator.
307
+ # In all cases, should_reset will be the empty set the next time
308
+ # the fn_index is run.
309
+ app.iterators[body.session_hash]["should_reset"] = set([])
310
+ else:
311
+ session_state = {}
312
+ iterators = {}
313
+ event_id = getattr(body, "event_id", None)
314
+ raw_input = body.data
315
+ fn_index = body.fn_index
316
+ batch = app.get_blocks().dependencies[fn_index_inferred]["batch"]
317
+ if not (body.batched) and batch:
318
+ raw_input = [raw_input]
319
+ try:
320
+ output = await app.get_blocks().process_api(
321
+ fn_index=fn_index_inferred,
322
+ inputs=raw_input,
323
+ request=request,
324
+ state=session_state,
325
+ iterators=iterators,
326
+ event_id=event_id,
327
+ )
328
+ iterator = output.pop("iterator", None)
329
+ if hasattr(body, "session_hash"):
330
+ if fn_index in app.iterators[body.session_hash]["should_reset"]:
331
+ app.iterators[body.session_hash][fn_index] = None
332
+ else:
333
+ app.iterators[body.session_hash][fn_index] = iterator
334
+ if isinstance(output, Error):
335
+ raise output
336
+ except BaseException as error:
337
+ show_error = app.get_blocks().show_error or isinstance(error, Error)
338
+ traceback.print_exc()
339
+ return JSONResponse(
340
+ content={"error": str(error) if show_error else None},
341
+ status_code=500,
342
+ )
343
+
344
+ if not (body.batched) and batch:
345
+ output["data"] = output["data"][0]
346
+ return output
347
+
348
+ # had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
349
+ @app.post("/run/{api_name}", dependencies=[Depends(login_check)])
350
+ @app.post("/run/{api_name}/", dependencies=[Depends(login_check)])
351
+ @app.post("/api/{api_name}", dependencies=[Depends(login_check)])
352
+ @app.post("/api/{api_name}/", dependencies=[Depends(login_check)])
353
+ async def predict(
354
+ api_name: str,
355
+ body: PredictBody,
356
+ request: fastapi.Request,
357
+ username: str = Depends(get_current_user),
358
+ ):
359
+ fn_index_inferred = None
360
+ if body.fn_index is None:
361
+ for i, fn in enumerate(app.get_blocks().dependencies):
362
+ if fn["api_name"] == api_name:
363
+ fn_index_inferred = i
364
+ break
365
+ if fn_index_inferred is None:
366
+ return JSONResponse(
367
+ content={
368
+ "error": f"This app has no endpoint /api/{api_name}/."
369
+ },
370
+ status_code=500,
371
+ )
372
+ else:
373
+ fn_index_inferred = body.fn_index
374
+ if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
375
+ fn_index_inferred
376
+ ):
377
+ if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
378
+ raise HTTPException(
379
+ status_code=status.HTTP_401_UNAUTHORIZED,
380
+ detail="Not authorized to skip the queue",
381
+ )
382
+
383
+ # If this fn_index cancels jobs, then the only input we need is the
384
+ # current session hash
385
+ if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
386
+ body.data = [body.session_hash]
387
+ if body.request:
388
+ if body.batched:
389
+ gr_request = [Request(**req) for req in body.request]
390
+ else:
391
+ assert isinstance(body.request, dict)
392
+ gr_request = Request(**body.request)
393
+ else:
394
+ gr_request = Request(request)
395
+ result = await run_predict(
396
+ body=body,
397
+ fn_index_inferred=fn_index_inferred,
398
+ username=username,
399
+ request=gr_request,
400
+ )
401
+ return result
402
+
403
+ @app.websocket("/queue/join")
404
+ async def join_queue(
405
+ websocket: WebSocket,
406
+ token: Optional[str] = Depends(ws_login_check),
407
+ ):
408
+ blocks = app.get_blocks()
409
+ if app.auth is not None and token is None:
410
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
411
+ return
412
+ if blocks._queue.server_path is None:
413
+ app_url = get_server_url_from_ws_url(str(websocket.url))
414
+ blocks._queue.set_url(app_url)
415
+ await websocket.accept()
416
+ # In order to cancel jobs, we need the session_hash and fn_index
417
+ # to create a unique id for each job
418
+ await websocket.send_json({"msg": "send_hash"})
419
+ session_info = await websocket.receive_json()
420
+ event = Event(
421
+ websocket, session_info["session_hash"], session_info["fn_index"]
422
+ )
423
+ # set the token into Event to allow using the same token for call_prediction
424
+ event.token = token
425
+ event.session_hash = session_info["session_hash"]
426
+
427
+ # Continuous events are not put in the queue so that they do not
428
+ # occupy the queue's resource as they are expected to run forever
429
+ if blocks.dependencies[event.fn_index].get("every", 0):
430
+ await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
431
+ await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
432
+ task = run_coro_in_background(
433
+ blocks._queue.process_events, [event], False
434
+ )
435
+ set_task_name(task, event.session_hash, event.fn_index, batch=False)
436
+ else:
437
+ rank = blocks._queue.push(event)
438
+
439
+ if rank is None:
440
+ await blocks._queue.send_message(event, {"msg": "queue_full"})
441
+ await event.disconnect()
442
+ return
443
+ estimation = blocks._queue.get_estimation()
444
+ await blocks._queue.send_estimation(event, estimation, rank)
445
+ while True:
446
+ await asyncio.sleep(60)
447
+ if websocket.application_state == WebSocketState.DISCONNECTED:
448
+ return
449
+
450
+ @app.get(
451
+ "/queue/status",
452
+ dependencies=[Depends(login_check)],
453
+ response_model=Estimation,
454
+ )
455
+ async def get_queue_status():
456
+ return app.get_blocks()._queue.get_estimation()
457
+
458
+ @app.get("/startup-events")
459
+ async def startup_events():
460
+ if not app.startup_events_triggered:
461
+ app.get_blocks().startup_events()
462
+ app.startup_events_triggered = True
463
+ return True
464
+ return False
465
+
466
+ @app.get("/robots.txt", response_class=PlainTextResponse)
467
+ def robots_txt():
468
+ if app.get_blocks().share:
469
+ return "User-agent: *\nDisallow: /"
470
+ else:
471
+ return "User-agent: *\nDisallow: "
472
+
473
+ return app
474
+
475
+
476
+ ########
477
+ # Helper functions
478
+ ########
479
+
480
+
481
+ def safe_join(directory: str, path: str) -> str | None:
482
+ """Safely path to a base directory to avoid escaping the base directory.
483
+ Borrowed from: werkzeug.security.safe_join"""
484
+ _os_alt_seps: List[str] = list(
485
+ sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/"
486
+ )
487
+
488
+ if path != "":
489
+ filename = posixpath.normpath(path)
490
+ else:
491
+ return directory
492
+
493
+ if (
494
+ any(sep in filename for sep in _os_alt_seps)
495
+ or os.path.isabs(filename)
496
+ or filename == ".."
497
+ or filename.startswith("../")
498
+ ):
499
+ return None
500
+ return posixpath.join(directory, filename)
501
+
502
+
503
+ def get_types(cls_set: List[Type]):
504
+ docset = []
505
+ types = []
506
+ for cls in cls_set:
507
+ doc = inspect.getdoc(cls) or ""
508
+ doc_lines = doc.split("\n")
509
+ for line in doc_lines:
510
+ if "value (" in line:
511
+ types.append(line.split("value (")[1].split(")")[0])
512
+ docset.append(doc_lines[1].split(":")[-1])
513
+ return docset, types
514
+
515
+
516
+ def get_server_url_from_ws_url(ws_url: str):
517
+ ws_url_parsed = urlparse(ws_url)
518
+ scheme = "http" if ws_url_parsed.scheme == "ws" else "https"
519
+ port = f":{ws_url_parsed.port}" if ws_url_parsed.port else ""
520
+ return f"{scheme}://{ws_url_parsed.hostname}{port}{ws_url_parsed.path.replace('queue/join', '')}"
521
+
522
+
523
+ set_documentation_group("routes")
524
+
525
+
526
+ class Obj:
527
+ """
528
+ Using a class to convert dictionaries into objects. Used by the `Request` class.
529
+ Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
530
+ """
531
+
532
+ def __init__(self, dict1):
533
+ self.__dict__.update(dict1)
534
+
535
+ def __str__(self) -> str:
536
+ return str(self.__dict__)
537
+
538
+ def __repr__(self) -> str:
539
+ return str(self.__dict__)
540
+
541
+
542
+ @document()
543
+ class Request:
544
+ """
545
+ A Gradio request object that can be used to access the request headers, cookies,
546
+ query parameters and other information about the request from within the prediction
547
+ function. The class is a thin wrapper around the fastapi.Request class. Attributes
548
+ of this class include: `headers`, `client`, `query_params`, and `path_params`,
549
+ Example:
550
+ import gradio as gr
551
+ def echo(name, request: gr.Request):
552
+ print("Request headers dictionary:", request.headers)
553
+ print("IP address:", request.client.host)
554
+ return name
555
+ io = gr.Interface(echo, "textbox", "textbox").launch()
556
+ """
557
+
558
+ def __init__(self, request: fastapi.Request | None = None, **kwargs):
559
+ """
560
+ Can be instantiated with either a fastapi.Request or by manually passing in
561
+ attributes (needed for websocket-based queueing).
562
+ Parameters:
563
+ request: A fastapi.Request
564
+ """
565
+ self.request = request
566
+ self.kwargs: Dict = kwargs
567
+
568
+ def dict_to_obj(self, d):
569
+ if isinstance(d, dict):
570
+ return json.loads(json.dumps(d), object_hook=Obj)
571
+ else:
572
+ return d
573
+
574
+ def __getattr__(self, name):
575
+ if self.request:
576
+ return self.dict_to_obj(getattr(self.request, name))
577
+ else:
578
+ try:
579
+ obj = self.kwargs[name]
580
+ except KeyError:
581
+ raise AttributeError(f"'Request' object has no attribute '{name}'")
582
+ return self.dict_to_obj(obj)
583
+
584
+
585
+ @document()
586
+ def mount_gradio_app(
587
+ app: fastapi.FastAPI,
588
+ blocks: gradio.Blocks,
589
+ path: str,
590
+ gradio_api_url: str | None = None,
591
+ ) -> fastapi.FastAPI:
592
+ """Mount a gradio.Blocks to an existing FastAPI application.
593
+
594
+ Parameters:
595
+ app: The parent FastAPI application.
596
+ blocks: The blocks object we want to mount to the parent app.
597
+ path: The path at which the gradio application will be mounted.
598
+ gradio_api_url: The full url at which the gradio app will run. This is only needed if deploying to Huggingface spaces of if the websocket endpoints of your deployed app are on a different network location than the gradio app. If deploying to spaces, set gradio_api_url to 'http://localhost:7860/'
599
+ Example:
600
+ from fastapi import FastAPI
601
+ import gradio as gr
602
+ app = FastAPI()
603
+ @app.get("/")
604
+ def read_main():
605
+ return {"message": "This is your main app"}
606
+ io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
607
+ app = gr.mount_gradio_app(app, io, path="/gradio")
608
+ # Then run `uvicorn run:app` from the terminal and navigate to http://localhost:8000/gradio.
609
+ """
610
+ blocks.dev_mode = False
611
+ blocks.config = blocks.get_config_file()
612
+ gradio_app = App.create_app(blocks)
613
+
614
+ @app.on_event("startup")
615
+ async def start_queue():
616
+ if gradio_app.get_blocks().enable_queue:
617
+ if gradio_api_url:
618
+ gradio_app.get_blocks()._queue.set_url(gradio_api_url)
619
+ gradio_app.get_blocks().startup_events()
620
+
621
+ app.mount(path, gradio_app)
622
+ return app
gradio-modified/gradio/serializing.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+ from typing import Any, Dict
6
+
7
+ from gradio import processing_utils, utils
8
+
9
+
10
+ class Serializable(ABC):
11
+ @abstractmethod
12
+ def serialize(
13
+ self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
14
+ ):
15
+ """
16
+ Convert data from human-readable format to serialized format for a browser.
17
+ """
18
+ pass
19
+
20
+ @abstractmethod
21
+ def deserialize(
22
+ self,
23
+ x: Any,
24
+ save_dir: str | Path | None = None,
25
+ encryption_key: bytes | None = None,
26
+ ):
27
+ """
28
+ Convert data from serialized format for a browser to human-readable format.
29
+ """
30
+ pass
31
+
32
+
33
+ class SimpleSerializable(Serializable):
34
+ def serialize(
35
+ self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
36
+ ) -> Any:
37
+ """
38
+ Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
39
+ Parameters:
40
+ x: Input data to serialize
41
+ load_dir: Ignored
42
+ encryption_key: Ignored
43
+ """
44
+ return x
45
+
46
+ def deserialize(
47
+ self,
48
+ x: Any,
49
+ save_dir: str | Path | None = None,
50
+ encryption_key: bytes | None = None,
51
+ ):
52
+ """
53
+ Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
54
+ Parameters:
55
+ x: Input data to deserialize
56
+ save_dir: Ignored
57
+ encryption_key: Ignored
58
+ """
59
+ return x
60
+
61
+
62
+ class ImgSerializable(Serializable):
63
+ def serialize(
64
+ self,
65
+ x: str | None,
66
+ load_dir: str | Path = "",
67
+ encryption_key: bytes | None = None,
68
+ ) -> str | None:
69
+ """
70
+ Convert from human-friendly version of a file (string filepath) to a seralized
71
+ representation (base64).
72
+ Parameters:
73
+ x: String path to file to serialize
74
+ load_dir: Path to directory containing x
75
+ encryption_key: Used to encrypt the file
76
+ """
77
+ if x is None or x == "":
78
+ return None
79
+ return processing_utils.encode_url_or_file_to_base64(
80
+ Path(load_dir) / x, encryption_key=encryption_key
81
+ )
82
+
83
+ def deserialize(
84
+ self,
85
+ x: str | None,
86
+ save_dir: str | Path | None = None,
87
+ encryption_key: bytes | None = None,
88
+ ) -> str | None:
89
+ """
90
+ Convert from serialized representation of a file (base64) to a human-friendly
91
+ version (string filepath). Optionally, save the file to the directory specified by save_dir
92
+ Parameters:
93
+ x: Base64 representation of image to deserialize into a string filepath
94
+ save_dir: Path to directory to save the deserialized image to
95
+ encryption_key: Used to decrypt the file
96
+ """
97
+ if x is None or x == "":
98
+ return None
99
+ file = processing_utils.decode_base64_to_file(
100
+ x, dir=save_dir, encryption_key=encryption_key
101
+ )
102
+ return file.name
103
+
104
+
105
+ class FileSerializable(Serializable):
106
+ def serialize(
107
+ self,
108
+ x: str | None,
109
+ load_dir: str | Path = "",
110
+ encryption_key: bytes | None = None,
111
+ ) -> Dict | None:
112
+ """
113
+ Convert from human-friendly version of a file (string filepath) to a
114
+ seralized representation (base64)
115
+ Parameters:
116
+ x: String path to file to serialize
117
+ load_dir: Path to directory containing x
118
+ encryption_key: Used to encrypt the file
119
+ """
120
+ if x is None or x == "":
121
+ return None
122
+ filename = Path(load_dir) / x
123
+ return {
124
+ "name": filename,
125
+ "data": processing_utils.encode_url_or_file_to_base64(
126
+ filename, encryption_key=encryption_key
127
+ ),
128
+ "orig_name": Path(filename).name,
129
+ "is_file": False,
130
+ }
131
+
132
+ def deserialize(
133
+ self,
134
+ x: str | Dict | None,
135
+ save_dir: Path | str | None = None,
136
+ encryption_key: bytes | None = None,
137
+ ) -> str | None:
138
+ """
139
+ Convert from serialized representation of a file (base64) to a human-friendly
140
+ version (string filepath). Optionally, save the file to the directory specified by `save_dir`
141
+ Parameters:
142
+ x: Base64 representation of file to deserialize into a string filepath
143
+ save_dir: Path to directory to save the deserialized file to
144
+ encryption_key: Used to decrypt the file
145
+ """
146
+ if x is None:
147
+ return None
148
+ if isinstance(save_dir, Path):
149
+ save_dir = str(save_dir)
150
+ if isinstance(x, str):
151
+ file_name = processing_utils.decode_base64_to_file(
152
+ x, dir=save_dir, encryption_key=encryption_key
153
+ ).name
154
+ elif isinstance(x, dict):
155
+ if x.get("is_file", False):
156
+ if utils.validate_url(x["name"]):
157
+ file_name = x["name"]
158
+ else:
159
+ file_name = processing_utils.create_tmp_copy_of_file(
160
+ x["name"], dir=save_dir
161
+ ).name
162
+ else:
163
+ file_name = processing_utils.decode_base64_to_file(
164
+ x["data"], dir=save_dir, encryption_key=encryption_key
165
+ ).name
166
+ else:
167
+ raise ValueError(
168
+ f"A FileSerializable component cannot only deserialize a string or a dict, not a: {type(x)}"
169
+ )
170
+ return file_name
171
+
172
+
173
+ class JSONSerializable(Serializable):
174
+ def serialize(
175
+ self,
176
+ x: str | None,
177
+ load_dir: str | Path = "",
178
+ encryption_key: bytes | None = None,
179
+ ) -> Dict | None:
180
+ """
181
+ Convert from a a human-friendly version (string path to json file) to a
182
+ serialized representation (json string)
183
+ Parameters:
184
+ x: String path to json file to read to get json string
185
+ load_dir: Path to directory containing x
186
+ encryption_key: Ignored
187
+ """
188
+ if x is None or x == "":
189
+ return None
190
+ return processing_utils.file_to_json(Path(load_dir) / x)
191
+
192
+ def deserialize(
193
+ self,
194
+ x: str | Dict,
195
+ save_dir: str | Path | None = None,
196
+ encryption_key: bytes | None = None,
197
+ ) -> str | None:
198
+ """
199
+ Convert from serialized representation (json string) to a human-friendly
200
+ version (string path to json file). Optionally, save the file to the directory specified by `save_dir`
201
+ Parameters:
202
+ x: Json string
203
+ save_dir: Path to save the deserialized json file to
204
+ encryption_key: Ignored
205
+ """
206
+ if x is None:
207
+ return None
208
+ return processing_utils.dict_or_str_to_json_file(x, dir=save_dir).name
gradio-modified/gradio/strings.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import requests
4
+
5
+ MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
6
+
7
+ en = {
8
+ "RUNNING_LOCALLY": "Running on local URL: {}",
9
+ "RUNNING_LOCALLY_SEPARATED": "Running on local URL: {}://{}:{}",
10
+ "SHARE_LINK_DISPLAY": "Running on public URL: {}",
11
+ "COULD_NOT_GET_SHARE_LINK": "\nCould not create share link, please check your internet connection.",
12
+ "COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
13
+ "PUBLIC_SHARE_TRUE": "\nTo create a public link, set `share=True` in `launch()`.",
14
+ "MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
15
+ "GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
16
+ "BETA_INVITE": "\nThanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB",
17
+ "COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
18
+ "To turn off, set debug=False in launch().",
19
+ "COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
20
+ "COLAB_WARNING": "Note: opening Chrome Inspector may crash demo inside Colab notebooks.",
21
+ "SHARE_LINK_MESSAGE": "\nThis share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces",
22
+ "INLINE_DISPLAY_BELOW": "Interface loading below...",
23
+ "TIPS": [
24
+ "You can add authentication to your app with the `auth=` kwarg in the `launch()` command; for example: `gr.Interface(...).launch(auth=('username', 'password'))`",
25
+ "Let users specify why they flagged input with the `flagging_options=` kwarg; for example: `gr.Interface(..., flagging_options=['too slow', 'incorrect output', 'other'])`",
26
+ "You can show or hide the button for flagging with the `allow_flagging=` kwarg; for example: gr.Interface(..., allow_flagging=False)",
27
+ "The inputs and outputs flagged by the users are stored in the flagging directory, specified by the flagging_dir= kwarg. You can view this data through the interface by setting the examples= kwarg to the flagging directory; for example gr.Interface(..., examples='flagged')",
28
+ "You can add a title and description to your interface using the `title=` and `description=` kwargs. The `article=` kwarg can be used to add a description under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum'). Try using Markdown!",
29
+ "For a classification or regression model, set `interpretation='default'` to see why the model made a prediction.",
30
+ ],
31
+ }
32
+
33
+ try:
34
+ updated_messaging = requests.get(MESSAGING_API_ENDPOINT, timeout=3).json()
35
+ en.update(updated_messaging)
36
+ except (
37
+ requests.ConnectionError,
38
+ requests.exceptions.ReadTimeout,
39
+ json.decoder.JSONDecodeError,
40
+ ): # Use default messaging
41
+ pass
gradio-modified/gradio/templates.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import typing
4
+ from typing import Any, Callable, Tuple
5
+
6
+ import numpy as np
7
+ from PIL.Image import Image
8
+
9
+ from gradio import components
10
+
11
+
12
+ class TextArea(components.Textbox):
13
+ """
14
+ Sets: lines=7
15
+ """
16
+
17
+ is_template = True
18
+
19
+ def __init__(
20
+ self,
21
+ value: str | Callable | None = "",
22
+ *,
23
+ lines: int = 7,
24
+ max_lines: int = 20,
25
+ placeholder: str | None = None,
26
+ label: str | None = None,
27
+ show_label: bool = True,
28
+ interactive: bool | None = None,
29
+ visible: bool = True,
30
+ elem_id: str | None = None,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(
34
+ value=value,
35
+ lines=lines,
36
+ max_lines=max_lines,
37
+ placeholder=placeholder,
38
+ label=label,
39
+ show_label=show_label,
40
+ interactive=interactive,
41
+ visible=visible,
42
+ elem_id=elem_id,
43
+ **kwargs,
44
+ )
45
+
46
+
47
+ class Webcam(components.Image):
48
+ """
49
+ Sets: source="webcam", interactive=True
50
+ """
51
+
52
+ is_template = True
53
+
54
+ def __init__(
55
+ self,
56
+ value: str | Image | np.ndarray | None = None,
57
+ *,
58
+ shape: Tuple[int, int] | None = None,
59
+ image_mode: str = "RGB",
60
+ invert_colors: bool = False,
61
+ source: str = "webcam",
62
+ tool: str | None = None,
63
+ type: str = "numpy",
64
+ label: str | None = None,
65
+ show_label: bool = True,
66
+ interactive: bool | None = True,
67
+ visible: bool = True,
68
+ streaming: bool = False,
69
+ elem_id: str | None = None,
70
+ mirror_webcam: bool = True,
71
+ **kwargs,
72
+ ):
73
+ super().__init__(
74
+ value=value,
75
+ shape=shape,
76
+ image_mode=image_mode,
77
+ invert_colors=invert_colors,
78
+ source=source,
79
+ tool=tool,
80
+ type=type,
81
+ label=label,
82
+ show_label=show_label,
83
+ interactive=interactive,
84
+ visible=visible,
85
+ streaming=streaming,
86
+ elem_id=elem_id,
87
+ mirror_webcam=mirror_webcam,
88
+ **kwargs,
89
+ )
90
+
91
+
92
+ class Sketchpad(components.Image):
93
+ """
94
+ Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True, interactive=True
95
+ """
96
+
97
+ is_template = True
98
+
99
+ def __init__(
100
+ self,
101
+ value: str | Image | np.ndarray | None = None,
102
+ *,
103
+ shape: Tuple[int, int] = (28, 28),
104
+ image_mode: str = "L",
105
+ invert_colors: bool = True,
106
+ source: str = "canvas",
107
+ tool: str | None = None,
108
+ type: str = "numpy",
109
+ label: str | None = None,
110
+ show_label: bool = True,
111
+ interactive: bool | None = True,
112
+ visible: bool = True,
113
+ streaming: bool = False,
114
+ elem_id: str | None = None,
115
+ mirror_webcam: bool = True,
116
+ **kwargs,
117
+ ):
118
+ super().__init__(
119
+ value=value,
120
+ shape=shape,
121
+ image_mode=image_mode,
122
+ invert_colors=invert_colors,
123
+ source=source,
124
+ tool=tool,
125
+ type=type,
126
+ label=label,
127
+ show_label=show_label,
128
+ interactive=interactive,
129
+ visible=visible,
130
+ streaming=streaming,
131
+ elem_id=elem_id,
132
+ mirror_webcam=mirror_webcam,
133
+ **kwargs,
134
+ )
135
+
136
+
137
+ class Paint(components.Image):
138
+ """
139
+ Sets: source="canvas", tool="color-sketch", interactive=True
140
+ """
141
+
142
+ is_template = True
143
+
144
+ def __init__(
145
+ self,
146
+ value: str | Image | np.ndarray | None = None,
147
+ *,
148
+ shape: Tuple[int, int] | None = None,
149
+ image_mode: str = "RGB",
150
+ invert_colors: bool = False,
151
+ source: str = "canvas",
152
+ tool: str = "color-sketch",
153
+ type: str = "numpy",
154
+ label: str | None = None,
155
+ show_label: bool = True,
156
+ interactive: bool | None = True,
157
+ visible: bool = True,
158
+ streaming: bool = False,
159
+ elem_id: str | None = None,
160
+ mirror_webcam: bool = True,
161
+ **kwargs,
162
+ ):
163
+ super().__init__(
164
+ value=value,
165
+ shape=shape,
166
+ image_mode=image_mode,
167
+ invert_colors=invert_colors,
168
+ source=source,
169
+ tool=tool,
170
+ type=type,
171
+ label=label,
172
+ show_label=show_label,
173
+ interactive=interactive,
174
+ visible=visible,
175
+ streaming=streaming,
176
+ elem_id=elem_id,
177
+ mirror_webcam=mirror_webcam,
178
+ **kwargs,
179
+ )
180
+
181
+
182
+ class ImageMask(components.Image):
183
+ """
184
+ Sets: source="upload", tool="sketch", interactive=True
185
+ """
186
+
187
+ is_template = True
188
+
189
+ def __init__(
190
+ self,
191
+ value: str | Image | np.ndarray | None = None,
192
+ *,
193
+ shape: Tuple[int, int] | None = None,
194
+ image_mode: str = "RGB",
195
+ invert_colors: bool = False,
196
+ source: str = "upload",
197
+ tool: str = "sketch",
198
+ type: str = "numpy",
199
+ label: str | None = None,
200
+ show_label: bool = True,
201
+ interactive: bool | None = True,
202
+ visible: bool = True,
203
+ streaming: bool = False,
204
+ elem_id: str | None = None,
205
+ mirror_webcam: bool = True,
206
+ **kwargs,
207
+ ):
208
+ super().__init__(
209
+ value=value,
210
+ shape=shape,
211
+ image_mode=image_mode,
212
+ invert_colors=invert_colors,
213
+ source=source,
214
+ tool=tool,
215
+ type=type,
216
+ label=label,
217
+ show_label=show_label,
218
+ interactive=interactive,
219
+ visible=visible,
220
+ streaming=streaming,
221
+ elem_id=elem_id,
222
+ mirror_webcam=mirror_webcam,
223
+ **kwargs,
224
+ )
225
+
226
+
227
+ class ImagePaint(components.Image):
228
+ """
229
+ Sets: source="upload", tool="color-sketch", interactive=True
230
+ """
231
+
232
+ is_template = True
233
+
234
+ def __init__(
235
+ self,
236
+ value: str | Image | np.ndarray | None = None,
237
+ *,
238
+ shape: Tuple[int, int] | None = None,
239
+ image_mode: str = "RGB",
240
+ invert_colors: bool = False,
241
+ source: str = "upload",
242
+ tool: str = "color-sketch",
243
+ type: str = "numpy",
244
+ label: str | None = None,
245
+ show_label: bool = True,
246
+ interactive: bool | None = True,
247
+ visible: bool = True,
248
+ streaming: bool = False,
249
+ elem_id: str | None = None,
250
+ mirror_webcam: bool = True,
251
+ **kwargs,
252
+ ):
253
+ super().__init__(
254
+ value=value,
255
+ shape=shape,
256
+ image_mode=image_mode,
257
+ invert_colors=invert_colors,
258
+ source=source,
259
+ tool=tool,
260
+ type=type,
261
+ label=label,
262
+ show_label=show_label,
263
+ interactive=interactive,
264
+ visible=visible,
265
+ streaming=streaming,
266
+ elem_id=elem_id,
267
+ mirror_webcam=mirror_webcam,
268
+ **kwargs,
269
+ )
270
+
271
+
272
+ class Pil(components.Image):
273
+ """
274
+ Sets: type="pil"
275
+ """
276
+
277
+ is_template = True
278
+
279
+ def __init__(
280
+ self,
281
+ value: str | Image | np.ndarray | None = None,
282
+ *,
283
+ shape: Tuple[int, int] | None = None,
284
+ image_mode: str = "RGB",
285
+ invert_colors: bool = False,
286
+ source: str = "upload",
287
+ tool: str | None = None,
288
+ type: str = "pil",
289
+ label: str | None = None,
290
+ show_label: bool = True,
291
+ interactive: bool | None = None,
292
+ visible: bool = True,
293
+ streaming: bool = False,
294
+ elem_id: str | None = None,
295
+ mirror_webcam: bool = True,
296
+ **kwargs,
297
+ ):
298
+ super().__init__(
299
+ value=value,
300
+ shape=shape,
301
+ image_mode=image_mode,
302
+ invert_colors=invert_colors,
303
+ source=source,
304
+ tool=tool,
305
+ type=type,
306
+ label=label,
307
+ show_label=show_label,
308
+ interactive=interactive,
309
+ visible=visible,
310
+ streaming=streaming,
311
+ elem_id=elem_id,
312
+ mirror_webcam=mirror_webcam,
313
+ **kwargs,
314
+ )
315
+
316
+
317
+ class PlayableVideo(components.Video):
318
+ """
319
+ Sets: format="mp4"
320
+ """
321
+
322
+ is_template = True
323
+
324
+ def __init__(
325
+ self,
326
+ value: str | Callable | None = None,
327
+ *,
328
+ format: str | None = "mp4",
329
+ source: str = "upload",
330
+ label: str | None = None,
331
+ show_label: bool = True,
332
+ interactive: bool | None = None,
333
+ visible: bool = True,
334
+ elem_id: str | None = None,
335
+ mirror_webcam: bool = True,
336
+ include_audio: bool | None = None,
337
+ **kwargs,
338
+ ):
339
+ super().__init__(
340
+ value=value,
341
+ format=format,
342
+ source=source,
343
+ label=label,
344
+ show_label=show_label,
345
+ interactive=interactive,
346
+ visible=visible,
347
+ elem_id=elem_id,
348
+ mirror_webcam=mirror_webcam,
349
+ include_audio=include_audio,
350
+ **kwargs,
351
+ )
352
+
353
+
354
+ class Microphone(components.Audio):
355
+ """
356
+ Sets: source="microphone"
357
+ """
358
+
359
+ is_template = True
360
+
361
+ def __init__(
362
+ self,
363
+ value: str | Tuple[int, np.ndarray] | Callable | None = None,
364
+ *,
365
+ source: str = "microphone",
366
+ type: str = "numpy",
367
+ label: str | None = None,
368
+ show_label: bool = True,
369
+ interactive: bool | None = None,
370
+ visible: bool = True,
371
+ streaming: bool = False,
372
+ elem_id: str | None = None,
373
+ **kwargs,
374
+ ):
375
+ super().__init__(
376
+ value=value,
377
+ source=source,
378
+ type=type,
379
+ label=label,
380
+ show_label=show_label,
381
+ interactive=interactive,
382
+ visible=visible,
383
+ streaming=streaming,
384
+ elem_id=elem_id,
385
+ **kwargs,
386
+ )
387
+
388
+
389
+ class Files(components.File):
390
+ """
391
+ Sets: file_count="multiple"
392
+ """
393
+
394
+ is_template = True
395
+
396
+ def __init__(
397
+ self,
398
+ value: str | typing.List[str] | Callable | None = None,
399
+ *,
400
+ file_count: str = "multiple",
401
+ type: str = "file",
402
+ label: str | None = None,
403
+ show_label: bool = True,
404
+ interactive: bool | None = None,
405
+ visible: bool = True,
406
+ elem_id: str | None = None,
407
+ **kwargs,
408
+ ):
409
+ super().__init__(
410
+ value=value,
411
+ file_count=file_count,
412
+ type=type,
413
+ label=label,
414
+ show_label=show_label,
415
+ interactive=interactive,
416
+ visible=visible,
417
+ elem_id=elem_id,
418
+ **kwargs,
419
+ )
420
+
421
+
422
+ class Numpy(components.Dataframe):
423
+ """
424
+ Sets: type="numpy"
425
+ """
426
+
427
+ is_template = True
428
+
429
+ def __init__(
430
+ self,
431
+ value: typing.List[typing.List[Any]] | Callable | None = None,
432
+ *,
433
+ headers: typing.List[str] | None = None,
434
+ row_count: int | Tuple[int, str] = (1, "dynamic"),
435
+ col_count: int | Tuple[int, str] | None = None,
436
+ datatype: str | typing.List[str] = "str",
437
+ type: str = "numpy",
438
+ max_rows: int | None = 20,
439
+ max_cols: int | None = None,
440
+ overflow_row_behaviour: str = "paginate",
441
+ label: str | None = None,
442
+ show_label: bool = True,
443
+ interactive: bool | None = None,
444
+ visible: bool = True,
445
+ elem_id: str | None = None,
446
+ wrap: bool = False,
447
+ **kwargs,
448
+ ):
449
+ super().__init__(
450
+ value=value,
451
+ headers=headers,
452
+ row_count=row_count,
453
+ col_count=col_count,
454
+ datatype=datatype,
455
+ type=type,
456
+ max_rows=max_rows,
457
+ max_cols=max_cols,
458
+ overflow_row_behaviour=overflow_row_behaviour,
459
+ label=label,
460
+ show_label=show_label,
461
+ interactive=interactive,
462
+ visible=visible,
463
+ elem_id=elem_id,
464
+ wrap=wrap,
465
+ **kwargs,
466
+ )
467
+
468
+
469
+ class Matrix(components.Dataframe):
470
+ """
471
+ Sets: type="array"
472
+ """
473
+
474
+ is_template = True
475
+
476
+ def __init__(
477
+ self,
478
+ value: typing.List[typing.List[Any]] | Callable | None = None,
479
+ *,
480
+ headers: typing.List[str] | None = None,
481
+ row_count: int | Tuple[int, str] = (1, "dynamic"),
482
+ col_count: int | Tuple[int, str] | None = None,
483
+ datatype: str | typing.List[str] = "str",
484
+ type: str = "array",
485
+ max_rows: int | None = 20,
486
+ max_cols: int | None = None,
487
+ overflow_row_behaviour: str = "paginate",
488
+ label: str | None = None,
489
+ show_label: bool = True,
490
+ interactive: bool | None = None,
491
+ visible: bool = True,
492
+ elem_id: str | None = None,
493
+ wrap: bool = False,
494
+ **kwargs,
495
+ ):
496
+ super().__init__(
497
+ value=value,
498
+ headers=headers,
499
+ row_count=row_count,
500
+ col_count=col_count,
501
+ datatype=datatype,
502
+ type=type,
503
+ max_rows=max_rows,
504
+ max_cols=max_cols,
505
+ overflow_row_behaviour=overflow_row_behaviour,
506
+ label=label,
507
+ show_label=show_label,
508
+ interactive=interactive,
509
+ visible=visible,
510
+ elem_id=elem_id,
511
+ wrap=wrap,
512
+ **kwargs,
513
+ )
514
+
515
+
516
+ class List(components.Dataframe):
517
+ """
518
+ Sets: type="array", col_count=1
519
+ """
520
+
521
+ is_template = True
522
+
523
+ def __init__(
524
+ self,
525
+ value: typing.List[typing.List[Any]] | Callable | None = None,
526
+ *,
527
+ headers: typing.List[str] | None = None,
528
+ row_count: int | Tuple[int, str] = (1, "dynamic"),
529
+ col_count: int | Tuple[int, str] = 1,
530
+ datatype: str | typing.List[str] = "str",
531
+ type: str = "array",
532
+ max_rows: int | None = 20,
533
+ max_cols: int | None = None,
534
+ overflow_row_behaviour: str = "paginate",
535
+ label: str | None = None,
536
+ show_label: bool = True,
537
+ interactive: bool | None = None,
538
+ visible: bool = True,
539
+ elem_id: str | None = None,
540
+ wrap: bool = False,
541
+ **kwargs,
542
+ ):
543
+ super().__init__(
544
+ value=value,
545
+ headers=headers,
546
+ row_count=row_count,
547
+ col_count=col_count,
548
+ datatype=datatype,
549
+ type=type,
550
+ max_rows=max_rows,
551
+ max_cols=max_cols,
552
+ overflow_row_behaviour=overflow_row_behaviour,
553
+ label=label,
554
+ show_label=show_label,
555
+ interactive=interactive,
556
+ visible=visible,
557
+ elem_id=elem_id,
558
+ wrap=wrap,
559
+ **kwargs,
560
+ )
561
+
562
+
563
+ Mic = Microphone
gradio-modified/{templates → gradio/templates}/frontend/assets/BlockLabel.37da86a3.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.cc0aed40.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.e110d966.css RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Column.06c172ac.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/File.60a988f4.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Image.4a41f1aa.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Image.95fa511c.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Model3D.b44fd6f2.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/ModifyUpload.2cfe71e4.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Tabs.6b500f1a.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Upload.5d0148e8.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/Webcam.8816836e.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/_commonjsHelpers.88e99c8f.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/color.509e5f03.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/csv.27f5436c.js RENAMED
File without changes
gradio-modified/{templates → gradio/templates}/frontend/assets/dsv.7fe76a93.js RENAMED
File without changes