microhan committed
Commit d4576ce
1 Parent(s): 1aa5ab2

update module gradio

update module gradio files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +2 -0
  2. gradio/.dockerignore +2 -0
  3. gradio/__init__.py +93 -0
  4. gradio/__pycache__/__init__.cpython-38.pyc +0 -0
  5. gradio/__pycache__/blocks.cpython-38.pyc +0 -0
  6. gradio/__pycache__/components.cpython-38.pyc +0 -0
  7. gradio/__pycache__/context.cpython-38.pyc +0 -0
  8. gradio/__pycache__/data_classes.cpython-38.pyc +0 -0
  9. gradio/__pycache__/deprecation.cpython-38.pyc +0 -0
  10. gradio/__pycache__/documentation.cpython-38.pyc +0 -0
  11. gradio/__pycache__/events.cpython-38.pyc +0 -0
  12. gradio/__pycache__/exceptions.cpython-38.pyc +0 -0
  13. gradio/__pycache__/external.cpython-38.pyc +0 -0
  14. gradio/__pycache__/external_utils.cpython-38.pyc +0 -0
  15. gradio/__pycache__/flagging.cpython-38.pyc +0 -0
  16. gradio/__pycache__/helpers.cpython-38.pyc +0 -0
  17. gradio/__pycache__/inputs.cpython-38.pyc +0 -0
  18. gradio/__pycache__/interface.cpython-38.pyc +0 -0
  19. gradio/__pycache__/interpretation.cpython-38.pyc +0 -0
  20. gradio/__pycache__/ipython_ext.cpython-38.pyc +0 -0
  21. gradio/__pycache__/layouts.cpython-38.pyc +0 -0
  22. gradio/__pycache__/media_data.cpython-38.pyc +0 -0
  23. gradio/__pycache__/mix.cpython-38.pyc +0 -0
  24. gradio/__pycache__/networking.cpython-38.pyc +0 -0
  25. gradio/__pycache__/outputs.cpython-38.pyc +0 -0
  26. gradio/__pycache__/pipelines.cpython-38.pyc +0 -0
  27. gradio/__pycache__/processing_utils.cpython-38.pyc +0 -0
  28. gradio/__pycache__/queueing.cpython-38.pyc +0 -0
  29. gradio/__pycache__/ranged_response.cpython-38.pyc +0 -0
  30. gradio/__pycache__/reload.cpython-38.pyc +0 -0
  31. gradio/__pycache__/routes.cpython-38.pyc +0 -0
  32. gradio/__pycache__/serializing.cpython-38.pyc +0 -0
  33. gradio/__pycache__/strings.cpython-38.pyc +0 -0
  34. gradio/__pycache__/templates.cpython-38.pyc +0 -0
  35. gradio/__pycache__/tunneling.cpython-38.pyc +0 -0
  36. gradio/__pycache__/utils.cpython-38.pyc +0 -0
  37. gradio/blocks.py +1779 -0
  38. gradio/components.py +0 -0
  39. gradio/context.py +18 -0
  40. gradio/data_classes.py +55 -0
  41. gradio/deprecation.py +45 -0
  42. gradio/documentation.py +261 -0
  43. gradio/events.py +298 -0
  44. gradio/exceptions.py +39 -0
  45. gradio/external.py +512 -0
  46. gradio/external_utils.py +185 -0
  47. gradio/flagging.py +555 -0
  48. gradio/helpers.py +839 -0
  49. gradio/inputs.py +473 -0
  50. gradio/interface.py +888 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ gradio/templates/cdn/assets/index.3c2bfbb6.js.map filter=lfs diff=lfs merge=lfs -text
+ gradio/templates/frontend/assets/index.756cf7e0.js.map filter=lfs diff=lfs merge=lfs -text
gradio/.dockerignore ADDED
@@ -0,0 +1,2 @@
+ templates/frontend
+ templates/frontend/**/*
gradio/__init__.py ADDED
@@ -0,0 +1,93 @@
+ import pkgutil
+
+ import gradio.components as components
+ import gradio.inputs as inputs
+ import gradio.outputs as outputs
+ import gradio.processing_utils
+ import gradio.templates
+ import gradio.themes as themes
+ from gradio.blocks import Blocks
+ from gradio.components import (
+     HTML,
+     JSON,
+     Audio,
+     BarPlot,
+     Button,
+     Carousel,
+     Chatbot,
+     Checkbox,
+     Checkboxgroup,
+     CheckboxGroup,
+     Code,
+     ColorPicker,
+     DataFrame,
+     Dataframe,
+     Dataset,
+     Dropdown,
+     File,
+     Gallery,
+     Highlight,
+     Highlightedtext,
+     HighlightedText,
+     Image,
+     Interpretation,
+     Json,
+     Label,
+     LinePlot,
+     Markdown,
+     Model3D,
+     Number,
+     Plot,
+     Radio,
+     ScatterPlot,
+     Slider,
+     State,
+     StatusTracker,
+     Text,
+     Textbox,
+     TimeSeries,
+     Timeseries,
+     UploadButton,
+     Variable,
+     Video,
+     component,
+ )
+ from gradio.events import SelectData
+ from gradio.exceptions import Error
+ from gradio.flagging import (
+     CSVLogger,
+     FlaggingCallback,
+     HuggingFaceDatasetJSONSaver,
+     HuggingFaceDatasetSaver,
+     SimpleCSVLogger,
+ )
+ from gradio.helpers import EventData, Progress
+ from gradio.helpers import create_examples as Examples
+ from gradio.helpers import make_waveform, skip, update
+ from gradio.interface import Interface, TabbedInterface, close_all
+ from gradio.ipython_ext import load_ipython_extension
+ from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
+ from gradio.mix import Parallel, Series
+ from gradio.routes import Request, mount_gradio_app
+ from gradio.templates import (
+     Files,
+     ImageMask,
+     ImagePaint,
+     List,
+     Matrix,
+     Mic,
+     Microphone,
+     Numpy,
+     Paint,
+     Pil,
+     PlayableVideo,
+     Sketchpad,
+     TextArea,
+     Webcam,
+ )
+ from gradio.themes import Base as Theme
+
+ current_pkg_version = (
+     (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
+ )
+ __version__ = current_pkg_version
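
The new gradio/__init__.py above defines the package's public surface (Blocks, Interface, the component classes, flagging callbacks, layout helpers, and templates). As a quick orientation, here is a minimal usage sketch of that exported API; the greet function and the component choices are illustrative assumptions, not part of this commit.

    # Minimal sketch of the API re-exported by gradio/__init__.py (illustrative only).
    import gradio as gr

    def greet(name: str) -> str:
        # Hypothetical demo function, not part of the diff.
        return f"Hello, {name}!"

    demo = gr.Interface(
        fn=greet,
        inputs=gr.Textbox(label="Name"),
        outputs=gr.Textbox(label="Greeting"),
    )

    if __name__ == "__main__":
        demo.launch()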
gradio/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.62 kB).
gradio/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (57.4 kB).
gradio/__pycache__/components.cpython-38.pyc ADDED
Binary file (181 kB).
gradio/__pycache__/context.cpython-38.pyc ADDED
Binary file (707 Bytes).
gradio/__pycache__/data_classes.cpython-38.pyc ADDED
Binary file (2.25 kB).
gradio/__pycache__/deprecation.cpython-38.pyc ADDED
Binary file (1.59 kB).
gradio/__pycache__/documentation.cpython-38.pyc ADDED
Binary file (6.57 kB).
gradio/__pycache__/events.cpython-38.pyc ADDED
Binary file (9.84 kB).
gradio/__pycache__/exceptions.cpython-38.pyc ADDED
Binary file (1.75 kB).
gradio/__pycache__/external.cpython-38.pyc ADDED
Binary file (14.6 kB).
gradio/__pycache__/external_utils.cpython-38.pyc ADDED
Binary file (5.83 kB).
gradio/__pycache__/flagging.cpython-38.pyc ADDED
Binary file (16.4 kB).
gradio/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (28.4 kB).
gradio/__pycache__/inputs.cpython-38.pyc ADDED
Binary file (17.1 kB).
gradio/__pycache__/interface.cpython-38.pyc ADDED
Binary file (27.9 kB).
gradio/__pycache__/interpretation.cpython-38.pyc ADDED
Binary file (9.13 kB).
gradio/__pycache__/ipython_ext.cpython-38.pyc ADDED
Binary file (860 Bytes).
gradio/__pycache__/layouts.cpython-38.pyc ADDED
Binary file (11.6 kB).
gradio/__pycache__/media_data.cpython-38.pyc ADDED
Binary file (473 kB).
gradio/__pycache__/mix.cpython-38.pyc ADDED
Binary file (4.78 kB).
gradio/__pycache__/networking.cpython-38.pyc ADDED
Binary file (6.28 kB).
gradio/__pycache__/outputs.cpython-38.pyc ADDED
Binary file (12.5 kB).
gradio/__pycache__/pipelines.cpython-38.pyc ADDED
Binary file (6.86 kB).
gradio/__pycache__/processing_utils.cpython-38.pyc ADDED
Binary file (22.8 kB).
gradio/__pycache__/queueing.cpython-38.pyc ADDED
Binary file (12.8 kB).
gradio/__pycache__/ranged_response.cpython-38.pyc ADDED
Binary file (5.55 kB).
gradio/__pycache__/reload.cpython-38.pyc ADDED
Binary file (1.75 kB).
gradio/__pycache__/routes.cpython-38.pyc ADDED
Binary file (24.3 kB).
gradio/__pycache__/serializing.cpython-38.pyc ADDED
Binary file (6.67 kB).
gradio/__pycache__/strings.cpython-38.pyc ADDED
Binary file (3.33 kB).
gradio/__pycache__/templates.cpython-38.pyc ADDED
Binary file (8.88 kB).
gradio/__pycache__/tunneling.cpython-38.pyc ADDED
Binary file (2.99 kB).
gradio/__pycache__/utils.cpython-38.pyc ADDED
Binary file (34 kB).
gradio/blocks.py ADDED
@@ -0,0 +1,1779 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import inspect
5
+ import json
6
+ import os
7
+ import random
8
+ import secrets
9
+ import sys
10
+ import time
11
+ import warnings
12
+ import webbrowser
13
+ from abc import abstractmethod
14
+ from types import ModuleType
15
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
16
+
17
+ import anyio
18
+ import requests
19
+ from anyio import CapacityLimiter
20
+ from typing_extensions import Literal
21
+
22
+ from gradio import components, external, networking, queueing, routes, strings, utils
23
+ from gradio.context import Context
24
+ from gradio.deprecation import check_deprecated_parameters
25
+ from gradio.documentation import document, set_documentation_group
26
+ from gradio.exceptions import DuplicateBlockError, InvalidApiName
27
+ from gradio.helpers import EventData, create_tracker, skip, special_args
28
+ from gradio.themes import Default as DefaultTheme
29
+ from gradio.themes import ThemeClass as Theme
30
+ from gradio.tunneling import CURRENT_TUNNELS
31
+ from gradio.utils import (
32
+ GRADIO_VERSION,
33
+ TupleNoPrint,
34
+ check_function_inputs_match,
35
+ component_or_layout_class,
36
+ delete_none,
37
+ get_cancel_function,
38
+ get_continuous_fn,
39
+ )
40
+
41
+ set_documentation_group("blocks")
42
+
43
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
44
+ import comet_ml
45
+ from fastapi.applications import FastAPI
46
+
47
+ from gradio.components import Component
48
+
49
+
50
+ class Block:
51
+ def __init__(
52
+ self,
53
+ *,
54
+ render: bool = True,
55
+ elem_id: str | None = None,
56
+ elem_classes: List[str] | str | None = None,
57
+ visible: bool = True,
58
+ root_url: str | None = None, # URL that is prepended to all file paths
59
+ _skip_init_processing: bool = False, # Used for loading from Spaces
60
+ **kwargs,
61
+ ):
62
+ self._id = Context.id
63
+ Context.id += 1
64
+ self.visible = visible
65
+ self.elem_id = elem_id
66
+ self.elem_classes = (
67
+ [elem_classes] if isinstance(elem_classes, str) else elem_classes
68
+ )
69
+ self.root_url = root_url
70
+ self.share_token = secrets.token_urlsafe(32)
71
+ self._skip_init_processing = _skip_init_processing
72
+ self._style = {}
73
+ self.parent: BlockContext | None = None
74
+ self.root = ""
75
+
76
+ if render:
77
+ self.render()
78
+ check_deprecated_parameters(self.__class__.__name__, **kwargs)
79
+
80
+ def render(self):
81
+ """
82
+ Adds self into appropriate BlockContext
83
+ """
84
+ if Context.root_block is not None and self._id in Context.root_block.blocks:
85
+ raise DuplicateBlockError(
86
+ f"A block with id: {self._id} has already been rendered in the current Blocks."
87
+ )
88
+ if Context.block is not None:
89
+ Context.block.add(self)
90
+ if Context.root_block is not None:
91
+ Context.root_block.blocks[self._id] = self
92
+ if isinstance(self, components.TempFileManager):
93
+ Context.root_block.temp_file_sets.append(self.temp_files)
94
+ return self
95
+
96
+ def unrender(self):
97
+ """
98
+ Removes self from BlockContext if it has been rendered (otherwise does nothing).
99
+ Removes self from the layout and collection of blocks, but does not delete any event triggers.
100
+ """
101
+ if Context.block is not None:
102
+ try:
103
+ Context.block.children.remove(self)
104
+ except ValueError:
105
+ pass
106
+ if Context.root_block is not None:
107
+ try:
108
+ del Context.root_block.blocks[self._id]
109
+ except KeyError:
110
+ pass
111
+ return self
112
+
113
+ def get_block_name(self) -> str:
114
+ """
115
+ Gets block's class name.
116
+
117
+ If it is a template component, it gets the parent's class name.
118
+
119
+ @return: class name
120
+ """
121
+ return (
122
+ self.__class__.__base__.__name__.lower()
123
+ if hasattr(self, "is_template")
124
+ else self.__class__.__name__.lower()
125
+ )
126
+
127
+ def get_expected_parent(self) -> Type[BlockContext] | None:
128
+ return None
129
+
130
+ def set_event_trigger(
131
+ self,
132
+ event_name: str,
133
+ fn: Callable | None,
134
+ inputs: Component | List[Component] | Set[Component] | None,
135
+ outputs: Component | List[Component] | None,
136
+ preprocess: bool = True,
137
+ postprocess: bool = True,
138
+ scroll_to_output: bool = False,
139
+ show_progress: bool = True,
140
+ api_name: str | None = None,
141
+ js: str | None = None,
142
+ no_target: bool = False,
143
+ queue: bool | None = None,
144
+ batch: bool = False,
145
+ max_batch_size: int = 4,
146
+ cancels: List[int] | None = None,
147
+ every: float | None = None,
148
+ collects_event_data: bool | None = None,
149
+ trigger_after: int | None = None,
150
+ trigger_only_on_success: bool = False,
151
+ ) -> Tuple[Dict[str, Any], int]:
152
+ """
153
+ Adds an event to the component's dependencies.
154
+ Parameters:
155
+ event_name: event name
156
+ fn: Callable function
157
+ inputs: input list
158
+ outputs: output list
159
+ preprocess: whether to run the preprocess methods of components
160
+ postprocess: whether to run the postprocess methods of components
161
+ scroll_to_output: whether to scroll to output of dependency on trigger
162
+ show_progress: whether to show progress animation while running.
163
+ api_name: Defining this parameter exposes the endpoint in the api docs
164
+ js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
165
+ no_target: if True, sets "targets" to [], used for Blocks "load" event
166
+ batch: whether this function takes in a batch of inputs
167
+ max_batch_size: the maximum batch size to send to the function
168
+ cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
169
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
170
+ collects_event_data: whether to collect event data for this event
171
+ trigger_after: if set, this event will be triggered after 'trigger_after' function index
172
+ trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
173
+ Returns: dependency information, dependency index
174
+ """
175
+ # Support for singular parameter
176
+ if isinstance(inputs, set):
177
+ inputs_as_dict = True
178
+ inputs = sorted(inputs, key=lambda x: x._id)
179
+ else:
180
+ inputs_as_dict = False
181
+ if inputs is None:
182
+ inputs = []
183
+ elif not isinstance(inputs, list):
184
+ inputs = [inputs]
185
+
186
+ if isinstance(outputs, set):
187
+ outputs = sorted(outputs, key=lambda x: x._id)
188
+ else:
189
+ if outputs is None:
190
+ outputs = []
191
+ elif not isinstance(outputs, list):
192
+ outputs = [outputs]
193
+
194
+ if fn is not None and not cancels:
195
+ check_function_inputs_match(fn, inputs, inputs_as_dict)
196
+
197
+ if Context.root_block is None:
198
+ raise AttributeError(
199
+ f"{event_name}() and other events can only be called within a Blocks context."
200
+ )
201
+ if every is not None and every <= 0:
202
+ raise ValueError("Parameter every must be positive or None")
203
+ if every and batch:
204
+ raise ValueError(
205
+ f"Cannot run {event_name} event in a batch and every {every} seconds. "
206
+ "Either batch is True or every is non-zero but not both."
207
+ )
208
+
209
+ if every and fn:
210
+ fn = get_continuous_fn(fn, every)
211
+ elif every:
212
+ raise ValueError("Cannot set a value for `every` without a `fn`.")
213
+
214
+ _, progress_index, event_data_index = (
215
+ special_args(fn) if fn else (None, None, None)
216
+ )
217
+ Context.root_block.fns.append(
218
+ BlockFunction(
219
+ fn,
220
+ inputs,
221
+ outputs,
222
+ preprocess,
223
+ postprocess,
224
+ inputs_as_dict,
225
+ progress_index is not None,
226
+ )
227
+ )
228
+ if api_name is not None:
229
+ api_name_ = utils.append_unique_suffix(
230
+ api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
231
+ )
232
+ if not (api_name == api_name_):
233
+ warnings.warn(
234
+ "api_name {} already exists, using {}".format(api_name, api_name_)
235
+ )
236
+ api_name = api_name_
237
+
238
+ if collects_event_data is None:
239
+ collects_event_data = event_data_index is not None
240
+
241
+ dependency = {
242
+ "targets": [self._id] if not no_target else [],
243
+ "trigger": event_name,
244
+ "inputs": [block._id for block in inputs],
245
+ "outputs": [block._id for block in outputs],
246
+ "backend_fn": fn is not None,
247
+ "js": js,
248
+ "queue": False if fn is None else queue,
249
+ "api_name": api_name,
250
+ "scroll_to_output": scroll_to_output,
251
+ "show_progress": show_progress,
252
+ "every": every,
253
+ "batch": batch,
254
+ "max_batch_size": max_batch_size,
255
+ "cancels": cancels or [],
256
+ "types": {
257
+ "continuous": bool(every),
258
+ "generator": inspect.isgeneratorfunction(fn) or bool(every),
259
+ },
260
+ "collects_event_data": collects_event_data,
261
+ "trigger_after": trigger_after,
262
+ "trigger_only_on_success": trigger_only_on_success,
263
+ }
264
+ Context.root_block.dependencies.append(dependency)
265
+ return dependency, len(Context.root_block.dependencies) - 1
266
+
267
+ def get_config(self):
268
+ return {
269
+ "visible": self.visible,
270
+ "elem_id": self.elem_id,
271
+ "elem_classes": self.elem_classes,
272
+ "style": self._style,
273
+ "root_url": self.root_url,
274
+ }
275
+
276
+ @staticmethod
277
+ @abstractmethod
278
+ def update(**kwargs) -> Dict:
279
+ return {}
280
+
281
+ @classmethod
282
+ def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
283
+ generic_update = generic_update.copy()
284
+ del generic_update["__type__"]
285
+ specific_update = cls.update(**generic_update)
286
+ return specific_update
287
+
288
+
289
+ class BlockContext(Block):
290
+ def __init__(
291
+ self,
292
+ visible: bool = True,
293
+ render: bool = True,
294
+ **kwargs,
295
+ ):
296
+ """
297
+ Parameters:
298
+ visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
299
+ render: If False, this will not be included in the Blocks config file at all.
300
+ """
301
+ self.children: List[Block] = []
302
+ Block.__init__(self, visible=visible, render=render, **kwargs)
303
+
304
+ def __enter__(self):
305
+ self.parent = Context.block
306
+ Context.block = self
307
+ return self
308
+
309
+ def add(self, child: Block):
310
+ child.parent = self
311
+ self.children.append(child)
312
+
313
+ def fill_expected_parents(self):
314
+ children = []
315
+ pseudo_parent = None
316
+ for child in self.children:
317
+ expected_parent = child.get_expected_parent()
318
+ if not expected_parent or isinstance(self, expected_parent):
319
+ pseudo_parent = None
320
+ children.append(child)
321
+ else:
322
+ if pseudo_parent is not None and isinstance(
323
+ pseudo_parent, expected_parent
324
+ ):
325
+ pseudo_parent.children.append(child)
326
+ else:
327
+ pseudo_parent = expected_parent(render=False)
328
+ children.append(pseudo_parent)
329
+ pseudo_parent.children = [child]
330
+ if Context.root_block:
331
+ Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
332
+ child.parent = pseudo_parent
333
+ self.children = children
334
+
335
+ def __exit__(self, *args):
336
+ if getattr(self, "allow_expected_parents", True):
337
+ self.fill_expected_parents()
338
+ Context.block = self.parent
339
+
340
+ def postprocess(self, y):
341
+ """
342
+ Any postprocessing needed to be performed on a block context.
343
+ """
344
+ return y
345
+
346
+
347
+ class BlockFunction:
348
+ def __init__(
349
+ self,
350
+ fn: Callable | None,
351
+ inputs: List[Component],
352
+ outputs: List[Component],
353
+ preprocess: bool,
354
+ postprocess: bool,
355
+ inputs_as_dict: bool,
356
+ tracks_progress: bool = False,
357
+ ):
358
+ self.fn = fn
359
+ self.inputs = inputs
360
+ self.outputs = outputs
361
+ self.preprocess = preprocess
362
+ self.postprocess = postprocess
363
+ self.tracks_progress = tracks_progress
364
+ self.total_runtime = 0
365
+ self.total_runs = 0
366
+ self.inputs_as_dict = inputs_as_dict
367
+ self.name = getattr(fn, "__name__", "fn") if fn is not None else None
368
+
369
+ def __str__(self):
370
+ return str(
371
+ {
372
+ "fn": self.name,
373
+ "preprocess": self.preprocess,
374
+ "postprocess": self.postprocess,
375
+ }
376
+ )
377
+
378
+ def __repr__(self):
379
+ return str(self)
380
+
381
+
382
+ class class_or_instancemethod(classmethod):
383
+ def __get__(self, instance, type_):
384
+ descr_get = super().__get__ if instance is None else self.__func__.__get__
385
+ return descr_get(instance, type_)
386
+
387
+
388
+ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
389
+ """
390
+ Converts a dictionary of updates into a format that can be sent to the frontend.
391
+ E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
392
+ Into -> {"__type__": "update", "value": 2.0, "mode": "static"}
393
+
394
+ Parameters:
395
+ block: The Block that is being updated with this update dictionary.
396
+ update_dict: The original update dictionary
397
+ postprocess: Whether to postprocess the "value" key of the update dictionary.
398
+ """
399
+ if update_dict.get("__type__", "") == "generic_update":
400
+ update_dict = block.get_specific_update(update_dict)
401
+ if update_dict.get("value") is components._Keywords.NO_VALUE:
402
+ update_dict.pop("value")
403
+ interactive = update_dict.pop("interactive", None)
404
+ if interactive is not None:
405
+ update_dict["mode"] = "dynamic" if interactive else "static"
406
+ prediction_value = delete_none(update_dict, skip_value=True)
407
+ if "value" in prediction_value and postprocess:
408
+ assert isinstance(
409
+ block, components.IOComponent
410
+ ), f"Component {block.__class__} does not support value"
411
+ prediction_value["value"] = block.postprocess(prediction_value["value"])
412
+ return prediction_value
413
+
414
+
415
+ def convert_component_dict_to_list(
416
+ outputs_ids: List[int], predictions: Dict
417
+ ) -> List | Dict:
418
+ """
419
+ Converts a dictionary of component updates into a list of updates in the order of
420
+ the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
421
+ E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
422
+ Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
423
+ """
424
+ keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
425
+ if all(keys_are_blocks):
426
+ reordered_predictions = [skip() for _ in outputs_ids]
427
+ for component, value in predictions.items():
428
+ if component._id not in outputs_ids:
429
+ raise ValueError(
430
+ f"Returned component {component} not specified as output of function."
431
+ )
432
+ output_index = outputs_ids.index(component._id)
433
+ reordered_predictions[output_index] = value
434
+ predictions = utils.resolve_singleton(reordered_predictions)
435
+ elif any(keys_are_blocks):
436
+ raise ValueError(
437
+ "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
438
+ )
439
+ return predictions
440
+
441
+
442
+ @document("launch", "queue", "integrate", "load")
443
+ class Blocks(BlockContext):
444
+ """
445
+ Blocks is Gradio's low-level API that allows you to create more custom web
446
+ applications and demos than Interfaces (yet still entirely in Python).
447
+
448
+
449
+ Compared to the Interface class, Blocks offers more flexibility and control over:
450
+ (1) the layout of components (2) the events that
451
+ trigger the execution of functions (3) data flows (e.g. inputs can trigger outputs,
452
+ which can trigger the next level of outputs). Blocks also offers ways to group
453
+ together related demos such as with tabs.
454
+
455
+
456
+ The basic usage of Blocks is as follows: create a Blocks object, then use it as a
457
+ context (with the "with" statement), and then define layouts, components, or events
458
+ within the Blocks context. Finally, call the launch() method to launch the demo.
459
+
460
+ Example:
461
+ import gradio as gr
462
+ def update(name):
463
+ return f"Welcome to Gradio, {name}!"
464
+
465
+ with gr.Blocks() as demo:
466
+ gr.Markdown("Start typing below and then click **Run** to see the output.")
467
+ with gr.Row():
468
+ inp = gr.Textbox(placeholder="What is your name?")
469
+ out = gr.Textbox()
470
+ btn = gr.Button("Run")
471
+ btn.click(fn=update, inputs=inp, outputs=out)
472
+
473
+ demo.launch()
474
+ Demos: blocks_hello, blocks_flipper, blocks_speech_text_sentiment, generate_english_german, sound_alert
475
+ Guides: blocks_and_event_listeners, controlling_layout, state_in_blocks, custom_CSS_and_JS, custom_interpretations_with_blocks, using_blocks_like_functions
476
+ """
477
+
478
+ def __init__(
479
+ self,
480
+ theme: Theme | str | None = None,
481
+ analytics_enabled: bool | None = None,
482
+ mode: str = "blocks",
483
+ title: str = "Gradio",
484
+ css: str | None = None,
485
+ **kwargs,
486
+ ):
487
+ """
488
+ Parameters:
489
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
490
+ mode: a human-friendly name for the kind of Blocks or Interface being created.
491
+ title: The tab title to display when this is opened in a browser window.
492
+ css: custom css or path to custom css file to apply to entire Blocks
493
+ """
494
+ # Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
495
+ self.limiter = None
496
+ self.save_to = None
497
+ if theme is None:
498
+ theme = DefaultTheme()
499
+ elif isinstance(theme, str):
500
+ try:
501
+ theme = Theme.from_hub(theme)
502
+ except Exception as e:
503
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
504
+ theme = DefaultTheme()
505
+ if not isinstance(theme, Theme):
506
+ warnings.warn("Theme should be a class loaded from gradio.themes")
507
+ theme = DefaultTheme()
508
+ self.theme = theme
509
+ self.theme_css = theme._get_theme_css()
510
+ self.stylesheets = theme._stylesheets
511
+ self.encrypt = False
512
+ self.share = False
513
+ self.enable_queue = None
514
+ self.max_threads = 40
515
+ self.show_error = True
516
+ if css is not None and os.path.exists(css):
517
+ with open(css) as css_file:
518
+ self.css = css_file.read()
519
+ else:
520
+ self.css = css
521
+
522
+ # For analytics_enabled and allow_flagging: (1) first check for
523
+ # parameter, (2) check for env variable, (3) default to True/"manual"
524
+ self.analytics_enabled = (
525
+ analytics_enabled
526
+ if analytics_enabled is not None
527
+ else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
528
+ )
529
+ if not self.analytics_enabled:
530
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
531
+ super().__init__(render=False, **kwargs)
532
+ self.blocks: Dict[int, Block] = {}
533
+ self.fns: List[BlockFunction] = []
534
+ self.dependencies = []
535
+ self.mode = mode
536
+
537
+ self.is_running = False
538
+ self.local_url = None
539
+ self.share_url = None
540
+ self.width = None
541
+ self.height = None
542
+ self.api_open = True
543
+
544
+ self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
545
+ self.favicon_path = None
546
+ self.auth = None
547
+ self.dev_mode = True
548
+ self.app_id = random.getrandbits(64)
549
+ self.temp_file_sets = []
550
+ self.title = title
551
+ self.show_api = True
552
+
553
+ # Only used when an Interface is loaded from a config
554
+ self.predict = None
555
+ self.input_components = None
556
+ self.output_components = None
557
+ self.__name__ = None
558
+ self.api_mode = None
559
+ self.progress_tracking = None
560
+
561
+ self.file_directories = []
562
+
563
+ if self.analytics_enabled:
564
+ data = {
565
+ "mode": self.mode,
566
+ "custom_css": self.css is not None,
567
+ "theme": self.theme,
568
+ "version": GRADIO_VERSION,
569
+ }
570
+ utils.initiated_analytics(data)
571
+
572
+ @classmethod
573
+ def from_config(
574
+ cls,
575
+ config: dict,
576
+ fns: List[Callable],
577
+ root_url: str | None = None,
578
+ ) -> Blocks:
579
+ """
580
+ Factory method that creates a Blocks from a config and list of functions.
581
+
582
+ Parameters:
583
+ config: a dictionary containing the configuration of the Blocks.
584
+ fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config.
585
+ root_url: an optional root url to use for the components in the Blocks. Allows serving files from an external URL.
586
+ """
587
+ config = copy.deepcopy(config)
588
+ components_config = config["components"]
589
+ original_mapping: Dict[int, Block] = {}
590
+
591
+ def get_block_instance(id: int) -> Block:
592
+ for block_config in components_config:
593
+ if block_config["id"] == id:
594
+ break
595
+ else:
596
+ raise ValueError("Cannot find block with id {}".format(id))
597
+ cls = component_or_layout_class(block_config["type"])
598
+ block_config["props"].pop("type", None)
599
+ block_config["props"].pop("name", None)
600
+ style = block_config["props"].pop("style", None)
601
+ if block_config["props"].get("root_url") is None and root_url:
602
+ block_config["props"]["root_url"] = root_url + "/"
603
+ # Any component has already processed its initial value, so we skip that step here
604
+ block = cls(**block_config["props"], _skip_init_processing=True)
605
+ if style and isinstance(block, components.IOComponent):
606
+ block.style(**style)
607
+ return block
608
+
609
+ def iterate_over_children(children_list):
610
+ for child_config in children_list:
611
+ id = child_config["id"]
612
+ block = get_block_instance(id)
613
+
614
+ original_mapping[id] = block
615
+
616
+ children = child_config.get("children")
617
+ if children is not None:
618
+ assert isinstance(
619
+ block, BlockContext
620
+ ), f"Invalid config, Block with id {id} has children but is not a BlockContext."
621
+ with block:
622
+ iterate_over_children(children)
623
+
624
+ derived_fields = ["types"]
625
+
626
+ with Blocks() as blocks:
627
+ # ID 0 should be the root Blocks component
628
+ original_mapping[0] = Context.root_block or blocks
629
+
630
+ iterate_over_children(config["layout"]["children"])
631
+
632
+ first_dependency = None
633
+
634
+ # add the event triggers
635
+ for dependency, fn in zip(config["dependencies"], fns):
636
+ # We used to add a "fake_event" to the config to cache examples
637
+ # without removing it. This was causing bugs in calling gr.Interface.load
638
+ # We fixed the issue by removing "fake_event" from the config in examples.py
639
+ # but we still need to skip these events when loading the config to support
640
+ # older demos
641
+ if dependency["trigger"] == "fake_event":
642
+ continue
643
+ for field in derived_fields:
644
+ dependency.pop(field, None)
645
+ targets = dependency.pop("targets")
646
+ trigger = dependency.pop("trigger")
647
+ dependency.pop("backend_fn")
648
+ dependency.pop("documentation", None)
649
+ dependency["inputs"] = [
650
+ original_mapping[i] for i in dependency["inputs"]
651
+ ]
652
+ dependency["outputs"] = [
653
+ original_mapping[o] for o in dependency["outputs"]
654
+ ]
655
+ dependency.pop("status_tracker", None)
656
+ dependency["preprocess"] = False
657
+ dependency["postprocess"] = False
658
+
659
+ for target in targets:
660
+ dependency = original_mapping[target].set_event_trigger(
661
+ event_name=trigger, fn=fn, **dependency
662
+ )[0]
663
+ if first_dependency is None:
664
+ first_dependency = dependency
665
+
666
+ # Allows some use of Interface-specific methods with loaded Spaces
667
+ if first_dependency and Context.root_block:
668
+ blocks.predict = [fns[0]]
669
+ blocks.input_components = [
670
+ Context.root_block.blocks[i] for i in first_dependency["inputs"]
671
+ ]
672
+ blocks.output_components = [
673
+ Context.root_block.blocks[o] for o in first_dependency["outputs"]
674
+ ]
675
+ blocks.__name__ = "Interface"
676
+ blocks.api_mode = True
677
+
678
+ return blocks
679
+
680
+ def __str__(self):
681
+ return self.__repr__()
682
+
683
+ def __repr__(self):
684
+ num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
685
+ repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
686
+ repr += "\n" + "-" * len(repr)
687
+ for d, dependency in enumerate(self.dependencies):
688
+ if dependency["backend_fn"]:
689
+ repr += f"\nfn_index={d}"
690
+ repr += "\n inputs:"
691
+ for input_id in dependency["inputs"]:
692
+ block = self.blocks[input_id]
693
+ repr += "\n |-{}".format(str(block))
694
+ repr += "\n outputs:"
695
+ for output_id in dependency["outputs"]:
696
+ block = self.blocks[output_id]
697
+ repr += "\n |-{}".format(str(block))
698
+ return repr
699
+
700
+ def render(self):
701
+ if Context.root_block is not None:
702
+ if self._id in Context.root_block.blocks:
703
+ raise DuplicateBlockError(
704
+ f"A block with id: {self._id} has already been rendered in the current Blocks."
705
+ )
706
+ if not set(Context.root_block.blocks).isdisjoint(self.blocks):
707
+ raise DuplicateBlockError(
708
+ "At least one block in this Blocks has already been rendered."
709
+ )
710
+
711
+ Context.root_block.blocks.update(self.blocks)
712
+ Context.root_block.fns.extend(self.fns)
713
+ dependency_offset = len(Context.root_block.dependencies)
714
+ for i, dependency in enumerate(self.dependencies):
715
+ api_name = dependency["api_name"]
716
+ if api_name is not None:
717
+ api_name_ = utils.append_unique_suffix(
718
+ api_name,
719
+ [dep["api_name"] for dep in Context.root_block.dependencies],
720
+ )
721
+ if not (api_name == api_name_):
722
+ warnings.warn(
723
+ "api_name {} already exists, using {}".format(
724
+ api_name, api_name_
725
+ )
726
+ )
727
+ dependency["api_name"] = api_name_
728
+ dependency["cancels"] = [
729
+ c + dependency_offset for c in dependency["cancels"]
730
+ ]
731
+ if dependency.get("trigger_after") is not None:
732
+ dependency["trigger_after"] += dependency_offset
733
+ # Recreate the cancel function so that it has the latest
734
+ # dependency fn indices. This is necessary to properly cancel
735
+ # events in the backend
736
+ if dependency["cancels"]:
737
+ updated_cancels = [
738
+ Context.root_block.dependencies[i]
739
+ for i in dependency["cancels"]
740
+ ]
741
+ new_fn = BlockFunction(
742
+ get_cancel_function(updated_cancels)[0],
743
+ [],
744
+ [],
745
+ False,
746
+ True,
747
+ False,
748
+ )
749
+ Context.root_block.fns[dependency_offset + i] = new_fn
750
+ Context.root_block.dependencies.append(dependency)
751
+ Context.root_block.temp_file_sets.extend(self.temp_file_sets)
752
+
753
+ if Context.block is not None:
754
+ Context.block.children.extend(self.children)
755
+ return self
756
+
757
+ def is_callable(self, fn_index: int = 0) -> bool:
758
+ """Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
759
+ block_fn = self.fns[fn_index]
760
+ dependency = self.dependencies[fn_index]
761
+
762
+ if inspect.isasyncgenfunction(block_fn.fn):
763
+ return False
764
+ if inspect.isgeneratorfunction(block_fn.fn):
765
+ return False
766
+ for input_id in dependency["inputs"]:
767
+ block = self.blocks[input_id]
768
+ if getattr(block, "stateful", False):
769
+ return False
770
+ for output_id in dependency["outputs"]:
771
+ block = self.blocks[output_id]
772
+ if getattr(block, "stateful", False):
773
+ return False
774
+
775
+ return True
776
+
777
+ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
778
+ """
779
+ Allows Blocks objects to be called as functions. Supply the parameters to the
780
+ function as positional arguments. To choose which function to call, use the
781
+ fn_index parameter, which must be a keyword argument.
782
+
783
+ Parameters:
784
+ *inputs: the parameters to pass to the function
785
+ fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
786
+ api_name: The api_name of the dependency to call. Will take precedence over fn_index.
787
+ """
788
+ if api_name is not None:
789
+ inferred_fn_index = next(
790
+ (
791
+ i
792
+ for i, d in enumerate(self.dependencies)
793
+ if d.get("api_name") == api_name
794
+ ),
795
+ None,
796
+ )
797
+ if inferred_fn_index is None:
798
+ raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
799
+ fn_index = inferred_fn_index
800
+ if not (self.is_callable(fn_index)):
801
+ raise ValueError(
802
+ "This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
803
+ )
804
+
805
+ inputs = list(inputs)
806
+ processed_inputs = self.serialize_data(fn_index, inputs)
807
+ batch = self.dependencies[fn_index]["batch"]
808
+ if batch:
809
+ processed_inputs = [[inp] for inp in processed_inputs]
810
+
811
+ outputs = utils.synchronize_async(
812
+ self.process_api,
813
+ fn_index=fn_index,
814
+ inputs=processed_inputs,
815
+ request=None,
816
+ state={},
817
+ )
818
+ outputs = outputs["data"]
819
+
820
+ if batch:
821
+ outputs = [out[0] for out in outputs]
822
+
823
+ processed_outputs = self.deserialize_data(fn_index, outputs)
824
+ processed_outputs = utils.resolve_singleton(processed_outputs)
825
+
826
+ return processed_outputs
827
+
828
+ async def call_function(
829
+ self,
830
+ fn_index: int,
831
+ processed_input: List[Any],
832
+ iterator: Iterator[Any] | None = None,
833
+ requests: routes.Request | List[routes.Request] | None = None,
834
+ event_id: str | None = None,
835
+ event_data: EventData | None = None,
836
+ ):
837
+ """
838
+ Calls function with given index and preprocessed input, and measures process time.
839
+ Parameters:
840
+ fn_index: index of function to call
841
+ processed_input: preprocessed input to pass to function
842
+ iterator: iterator to use if function is a generator
843
+ requests: requests to pass to function
844
+ event_id: id of event in queue
845
+ event_data: data associated with event trigger
846
+ """
847
+ block_fn = self.fns[fn_index]
848
+ assert block_fn.fn, f"function with index {fn_index} not defined."
849
+ is_generating = False
850
+
851
+ if block_fn.inputs_as_dict:
852
+ processed_input = [
853
+ {
854
+ input_component: data
855
+ for input_component, data in zip(block_fn.inputs, processed_input)
856
+ }
857
+ ]
858
+
859
+ if isinstance(requests, list):
860
+ request = requests[0]
861
+ else:
862
+ request = requests
863
+ processed_input, progress_index, _ = special_args(
864
+ block_fn.fn, processed_input, request, event_data
865
+ )
866
+ progress_tracker = (
867
+ processed_input[progress_index] if progress_index is not None else None
868
+ )
869
+
870
+ start = time.time()
871
+
872
+ if iterator is None: # If not a generator function that has already run
873
+ if progress_tracker is not None and progress_index is not None:
874
+ progress_tracker, fn = create_tracker(
875
+ self, event_id, block_fn.fn, progress_tracker.track_tqdm
876
+ )
877
+ processed_input[progress_index] = progress_tracker
878
+ else:
879
+ fn = block_fn.fn
880
+
881
+ if inspect.iscoroutinefunction(fn):
882
+ prediction = await fn(*processed_input)
883
+ else:
884
+ prediction = await anyio.to_thread.run_sync(
885
+ fn, *processed_input, limiter=self.limiter
886
+ )
887
+ else:
888
+ prediction = None
889
+
890
+ if inspect.isasyncgenfunction(block_fn.fn):
891
+ raise ValueError("Gradio does not support async generators.")
892
+ if inspect.isgeneratorfunction(block_fn.fn):
893
+ if not self.enable_queue:
894
+ raise ValueError("Need to enable queue to use generators.")
895
+ try:
896
+ if iterator is None:
897
+ iterator = prediction
898
+ prediction = await anyio.to_thread.run_sync(
899
+ utils.async_iteration, iterator, limiter=self.limiter
900
+ )
901
+ is_generating = True
902
+ except StopAsyncIteration:
903
+ n_outputs = len(self.dependencies[fn_index].get("outputs"))
904
+ prediction = (
905
+ components._Keywords.FINISHED_ITERATING
906
+ if n_outputs == 1
907
+ else (components._Keywords.FINISHED_ITERATING,) * n_outputs
908
+ )
909
+ iterator = None
910
+
911
+ duration = time.time() - start
912
+
913
+ return {
914
+ "prediction": prediction,
915
+ "duration": duration,
916
+ "is_generating": is_generating,
917
+ "iterator": iterator,
918
+ }
919
+
920
+ def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
921
+ dependency = self.dependencies[fn_index]
922
+ processed_input = []
923
+
924
+ for i, input_id in enumerate(dependency["inputs"]):
925
+ block = self.blocks[input_id]
926
+ assert isinstance(
927
+ block, components.IOComponent
928
+ ), f"{block.__class__} Component with id {input_id} not a valid input component."
929
+ serialized_input = block.serialize(inputs[i])
930
+ processed_input.append(serialized_input)
931
+
932
+ return processed_input
933
+
934
+ def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
935
+ dependency = self.dependencies[fn_index]
936
+ predictions = []
937
+
938
+ for o, output_id in enumerate(dependency["outputs"]):
939
+ block = self.blocks[output_id]
940
+ assert isinstance(
941
+ block, components.IOComponent
942
+ ), f"{block.__class__} Component with id {output_id} not a valid output component."
943
+ deserialized = block.deserialize(outputs[o], root_url=block.root_url)
944
+ predictions.append(deserialized)
945
+
946
+ return predictions
947
+
948
+ def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
949
+ block_fn = self.fns[fn_index]
950
+ dependency = self.dependencies[fn_index]
951
+
952
+ if block_fn.preprocess:
953
+ processed_input = []
954
+ for i, input_id in enumerate(dependency["inputs"]):
955
+ block = self.blocks[input_id]
956
+ assert isinstance(
957
+ block, components.Component
958
+ ), f"{block.__class__} Component with id {input_id} not a valid input component."
959
+ if getattr(block, "stateful", False):
960
+ processed_input.append(state.get(input_id))
961
+ else:
962
+ processed_input.append(block.preprocess(inputs[i]))
963
+ else:
964
+ processed_input = inputs
965
+ return processed_input
966
+
967
+ def postprocess_data(
968
+ self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
969
+ ):
970
+ block_fn = self.fns[fn_index]
971
+ dependency = self.dependencies[fn_index]
972
+ batch = dependency["batch"]
973
+
974
+ if type(predictions) is dict and len(predictions) > 0:
975
+ predictions = convert_component_dict_to_list(
976
+ dependency["outputs"], predictions
977
+ )
978
+
979
+ if len(dependency["outputs"]) == 1 and not (batch):
980
+ predictions = [
981
+ predictions,
982
+ ]
983
+
984
+ output = []
985
+ for i, output_id in enumerate(dependency["outputs"]):
986
+ try:
987
+ if predictions[i] is components._Keywords.FINISHED_ITERATING:
988
+ output.append(None)
989
+ continue
990
+ except (IndexError, KeyError):
991
+ raise ValueError(
992
+ f"Number of output components does not match number of values returned from function {block_fn.name}"
993
+ )
994
+ block = self.blocks[output_id]
995
+ if getattr(block, "stateful", False):
996
+ if not utils.is_update(predictions[i]):
997
+ state[output_id] = predictions[i]
998
+ output.append(None)
999
+ else:
1000
+ prediction_value = predictions[i]
1001
+ if utils.is_update(prediction_value):
1002
+ assert isinstance(prediction_value, dict)
1003
+ prediction_value = postprocess_update_dict(
1004
+ block=block,
1005
+ update_dict=prediction_value,
1006
+ postprocess=block_fn.postprocess,
1007
+ )
1008
+ elif block_fn.postprocess:
1009
+ assert isinstance(
1010
+ block, components.Component
1011
+ ), f"{block.__class__} Component with id {output_id} not a valid output component."
1012
+ prediction_value = block.postprocess(prediction_value)
1013
+ output.append(prediction_value)
1014
+
1015
+ return output
1016
+
1017
+ async def process_api(
1018
+ self,
1019
+ fn_index: int,
1020
+ inputs: List[Any],
1021
+ state: Dict[int, Any],
1022
+ request: routes.Request | List[routes.Request] | None = None,
1023
+ iterators: Dict[int, Any] | None = None,
1024
+ event_id: str | None = None,
1025
+ event_data: EventData | None = None,
1026
+ ) -> Dict[str, Any]:
1027
+ """
1028
+ Processes API calls from the frontend. First preprocesses the data,
1029
+ then runs the relevant function, then postprocesses the output.
1030
+ Parameters:
1031
+ fn_index: Index of function to run.
1032
+ inputs: input data received from the frontend
1033
+ username: name of user if authentication is set up (not used)
1034
+ state: data stored from stateful components for session (key is input block id)
1035
+ iterators: the in-progress iterators for each generator function (key is function index)
1036
+ event_id: id of event that triggered this API call
1037
+ event_data: data associated with the event trigger itself
1038
+ Returns: None
1039
+ """
1040
+ block_fn = self.fns[fn_index]
1041
+ batch = self.dependencies[fn_index]["batch"]
1042
+
1043
+ if batch:
1044
+ max_batch_size = self.dependencies[fn_index]["max_batch_size"]
1045
+ batch_sizes = [len(inp) for inp in inputs]
1046
+ batch_size = batch_sizes[0]
1047
+ if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
1048
+ block_fn.fn
1049
+ ):
1050
+ raise ValueError("Gradio does not support generators in batch mode.")
1051
+ if not all(x == batch_size for x in batch_sizes):
1052
+ raise ValueError(
1053
+ f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
1054
+ )
1055
+ if batch_size > max_batch_size:
1056
+ raise ValueError(
1057
+ f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
1058
+ )
1059
+
1060
+ inputs = [
1061
+ self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
1062
+ ]
1063
+ result = await self.call_function(
1064
+ fn_index, list(zip(*inputs)), None, request, event_id, event_data
1065
+ )
1066
+ preds = result["prediction"]
1067
+ data = [
1068
+ self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
1069
+ ]
1070
+ data = list(zip(*data))
1071
+ is_generating, iterator = None, None
1072
+ else:
1073
+ inputs = self.preprocess_data(fn_index, inputs, state)
1074
+ iterator = iterators.get(fn_index, None) if iterators else None
1075
+ result = await self.call_function(
1076
+ fn_index, inputs, iterator, request, event_id, event_data
1077
+ )
1078
+ data = self.postprocess_data(fn_index, result["prediction"], state)
1079
+ is_generating, iterator = result["is_generating"], result["iterator"]
1080
+
1081
+ block_fn.total_runtime += result["duration"]
1082
+ block_fn.total_runs += 1
1083
+
1084
+ return {
1085
+ "data": data,
1086
+ "is_generating": is_generating,
1087
+ "iterator": iterator,
1088
+ "duration": result["duration"],
1089
+ "average_duration": block_fn.total_runtime / block_fn.total_runs,
1090
+ }
1091
+
1092
+ async def create_limiter(self):
1093
+ self.limiter = (
1094
+ None
1095
+ if self.max_threads == 40
1096
+ else CapacityLimiter(total_tokens=self.max_threads)
1097
+ )
1098
+
1099
+ def get_config(self):
1100
+ return {"type": "column"}
1101
+
1102
+ def get_config_file(self):
1103
+ config = {
1104
+ "version": routes.VERSION,
1105
+ "mode": self.mode,
1106
+ "dev_mode": self.dev_mode,
1107
+ "analytics_enabled": self.analytics_enabled,
1108
+ "components": [],
1109
+ "css": self.css,
1110
+ "title": self.title or "Gradio",
1111
+ "is_space": self.is_space,
1112
+ "enable_queue": getattr(self, "enable_queue", False), # launch attributes
1113
+ "show_error": getattr(self, "show_error", False),
1114
+ "show_api": self.show_api,
1115
+ "is_colab": utils.colab_check(),
1116
+ "stylesheets": self.stylesheets,
1117
+ "root": self.root,
1118
+ }
1119
+
1120
+ def getLayout(block):
1121
+ if not isinstance(block, BlockContext):
1122
+ return {"id": block._id}
1123
+ children_layout = []
1124
+ for child in block.children:
1125
+ children_layout.append(getLayout(child))
1126
+ return {"id": block._id, "children": children_layout}
1127
+
1128
+ config["layout"] = getLayout(self)
1129
+
1130
+ for _id, block in self.blocks.items():
1131
+ config["components"].append(
1132
+ {
1133
+ "id": _id,
1134
+ "type": (block.get_block_name()),
1135
+ "props": utils.delete_none(block.get_config())
1136
+ if hasattr(block, "get_config")
1137
+ else {},
1138
+ }
1139
+ )
1140
+ config["dependencies"] = self.dependencies
1141
+ return config
1142
+
1143
+ def __enter__(self):
1144
+ if Context.block is None:
1145
+ Context.root_block = self
1146
+ self.parent = Context.block
1147
+ Context.block = self
1148
+ return self
1149
+
1150
+ def __exit__(self, *args):
1151
+ super().fill_expected_parents()
1152
+ Context.block = self.parent
1153
+ # Configure the load events before root_block is reset
1154
+ self.attach_load_events()
1155
+ if self.parent is None:
1156
+ Context.root_block = None
1157
+ else:
1158
+ self.parent.children.extend(self.children)
1159
+ self.config = self.get_config_file()
1160
+ self.app = routes.App.create_app(self)
1161
+ self.progress_tracking = any(block_fn.tracks_progress for block_fn in self.fns)
1162
+
1163
+ @class_or_instancemethod
1164
+ def load(
1165
+ self_or_cls,
1166
+ fn: Callable | None = None,
1167
+ inputs: List[Component] | None = None,
1168
+ outputs: List[Component] | None = None,
1169
+ api_name: str | None = None,
1170
+ scroll_to_output: bool = False,
1171
+ show_progress: bool = True,
1172
+ queue=None,
1173
+ batch: bool = False,
1174
+ max_batch_size: int = 4,
1175
+ preprocess: bool = True,
1176
+ postprocess: bool = True,
1177
+ every: float | None = None,
1178
+ _js: str | None = None,
1179
+ *,
1180
+ name: str | None = None,
1181
+ src: str | None = None,
1182
+ api_key: str | None = None,
1183
+ alias: str | None = None,
1184
+ **kwargs,
1185
+ ) -> Blocks | Dict[str, Any] | None:
1186
+ """
1187
+ For reverse compatibility reasons, this is both a class method and an instance
1188
+ method, the two of which, confusingly, do two completely different things.
1189
+
1190
+
1191
+ Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Equivalent to gradio.Interface.load()
1192
+
1193
+
1194
+ Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
1195
+ Parameters:
1196
+ name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
1197
+ src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
1198
+ api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
1199
+ alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
1200
+ fn: Instance Method - the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
1201
+ inputs: Instance Method - List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
1202
+ outputs: Instance Method - List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
1203
+ api_name: Instance Method - Defining this parameter exposes the endpoint in the api docs
1204
+ scroll_to_output: Instance Method - If True, will scroll to output component on completion
1205
+ show_progress: Instance Method - If True, will show progress animation while pending
1206
+ queue: Instance Method - If True, will place the request on the queue, if the queue exists
1207
+ batch: Instance Method - If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
1208
+ max_batch_size: Instance Method - Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
1209
+ preprocess: Instance Method - If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
1210
+ postprocess: Instance Method - If False, will not run postprocessing of component data before returning 'fn' output to the browser.
1211
+ every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
1212
+ Example:
1213
+ import gradio as gr
1214
+ import datetime
1215
+ with gr.Blocks() as demo:
1216
+ def get_time():
1217
+ return datetime.datetime.now().time()
1218
+ dt = gr.Textbox(label="Current time")
1219
+ demo.load(get_time, inputs=None, outputs=dt)
1220
+ demo.launch()
1221
+ """
1222
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
1223
+ if isinstance(self_or_cls, type):
1224
+ if name is None:
1225
+ raise ValueError(
1226
+ "Blocks.load() requires passing parameters as keyword arguments"
1227
+ )
1228
+ return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs)
1229
+ else:
1230
+ return self_or_cls.set_event_trigger(
1231
+ event_name="load",
1232
+ fn=fn,
1233
+ inputs=inputs,
1234
+ outputs=outputs,
1235
+ api_name=api_name,
1236
+ preprocess=preprocess,
1237
+ postprocess=postprocess,
1238
+ scroll_to_output=scroll_to_output,
1239
+ show_progress=show_progress,
1240
+ js=_js,
1241
+ queue=queue,
1242
+ batch=batch,
1243
+ max_batch_size=max_batch_size,
1244
+ every=every,
1245
+ no_target=True,
1246
+ )[0]
1247
+
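A minimal sketch (not part of the committed diff) of the two call styles documented above; the model name is illustrative:

    import gradio as gr

    # Class method: builds a Blocks demo around a hosted Hugging Face model.
    remote_demo = gr.Blocks.load(name="models/facebook/bart-base")

    # Instance method: attaches a function that runs when the page first loads.
    with gr.Blocks() as page:
        box = gr.Textbox(label="Greeting")
        page.load(fn=lambda: "Hello!", inputs=None, outputs=box)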
1248
+ def clear(self):
1249
+ """Resets the layout of the Blocks object."""
1250
+ self.blocks = {}
1251
+ self.fns = []
1252
+ self.dependencies = []
1253
+ self.children = []
1254
+ return self
1255
+
1256
+ @document()
1257
+ def queue(
1258
+ self,
1259
+ concurrency_count: int = 1,
1260
+ status_update_rate: float | Literal["auto"] = "auto",
1261
+ client_position_to_load_data: int | None = None,
1262
+ default_enabled: bool | None = None,
1263
+ api_open: bool = True,
1264
+ max_size: int | None = None,
1265
+ ):
1266
+ """
1267
+ You can control the rate of processed requests by creating a queue. This will allow you to set the number of requests to be processed at one time, and will let users know their position in the queue.
1268
+ Parameters:
1269
+ concurrency_count: Number of worker threads that will be processing requests from the queue concurrently. Increasing this number will increase the rate at which requests are processed, but will also increase the memory usage of the queue.
1270
+ status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
1271
+ client_position_to_load_data: DEPRECATED. This parameter is deprecated and has no effect.
1272
+ default_enabled: Deprecated and has no effect.
1273
+ api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
1274
+ max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
1275
+ Example: (Blocks)
1276
+ with gr.Blocks() as demo:
1277
+ button = gr.Button(label="Generate Image")
1278
+ button.click(fn=image_generator, inputs=gr.Textbox(), outputs=gr.Image())
1279
+ demo.queue(concurrency_count=3)
1280
+ demo.launch()
1281
+ Example: (Interface)
1282
+ demo = gr.Interface(image_generator, gr.Textbox(), gr.Image())
1283
+ demo.queue(concurrency_count=3)
1284
+ demo.launch()
1285
+ """
1286
+ if default_enabled is not None:
1287
+ warnings.warn(
1288
+ "The default_enabled parameter of queue has no effect and will be removed "
1289
+ "in a future version of gradio."
1290
+ )
1291
+ self.enable_queue = True
1292
+ self.api_open = api_open
1293
+ if client_position_to_load_data is not None:
1294
+ warnings.warn("The client_position_to_load_data parameter is deprecated.")
1295
+ self._queue = queueing.Queue(
1296
+ live_updates=status_update_rate == "auto",
1297
+ concurrency_count=concurrency_count,
1298
+ update_intervals=status_update_rate if status_update_rate != "auto" else 1,
1299
+ max_size=max_size,
1300
+ blocks_dependencies=self.dependencies,
1301
+ )
1302
+ self.config = self.get_config_file()
1303
+ self.app = routes.App.create_app(self)
1304
+ return self
1305
+
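Because queue() returns self (see above), it can be chained straight into launch(); a small usage sketch of the non-deprecated parameters:

    import gradio as gr

    demo = gr.Interface(lambda x: x, "text", "text")
    # Cap the queue at 20 pending events and keep the REST endpoints behind the queue.
    demo.queue(concurrency_count=2, max_size=20, api_open=False).launch()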
1306
+ def launch(
1307
+ self,
1308
+ inline: bool | None = None,
1309
+ inbrowser: bool = False,
1310
+ share: bool | None = None,
1311
+ debug: bool = False,
1312
+ enable_queue: bool | None = None,
1313
+ max_threads: int = 40,
1314
+ auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
1315
+ auth_message: str | None = None,
1316
+ prevent_thread_lock: bool = False,
1317
+ show_error: bool = False,
1318
+ server_name: str | None = None,
1319
+ server_port: int | None = None,
1320
+ show_tips: bool = False,
1321
+ height: int = 500,
1322
+ width: int | str = "100%",
1323
+ encrypt: bool | None = None,
1324
+ favicon_path: str | None = None,
1325
+ ssl_keyfile: str | None = None,
1326
+ ssl_certfile: str | None = None,
1327
+ ssl_keyfile_password: str | None = None,
1328
+ quiet: bool = False,
1329
+ show_api: bool = True,
1330
+ file_directories: List[str] | None = None,
1331
+ _frontend: bool = True,
1332
+ ) -> Tuple[FastAPI, str, str]:
1333
+ """
1334
+ Launches a simple web server that serves the demo. Can also be used to create a
1335
+ public link used by anyone to access the demo from their browser by setting share=True.
1336
+
1337
+ Parameters:
1338
+ inline: whether to display in the interface inline in an iframe. Defaults to True in python notebooks; False otherwise.
1339
+ inbrowser: whether to automatically launch the interface in a new tab on the default browser.
1340
+ share: whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. If not provided, it is set to False by default every time, except when running in Google Colab. When localhost is not accessible (e.g. Google Colab), setting share=False is not supported.
1341
+ debug: if True, blocks the main thread from running. If running in Google Colab, this is needed to print the errors in the cell output.
1342
+ auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
1343
+ auth_message: If provided, HTML message provided on login page.
1344
+ prevent_thread_lock: If True, the interface will block the main thread while the server is running.
1345
+ show_error: If True, any errors in the interface will be displayed in an alert modal and printed in the browser console log
1346
+ server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. If None, will search for an available port starting at 7860.
1347
+ server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
1348
+ show_tips: if True, will occasionally show tips about new Gradio features
1349
+ enable_queue: DEPRECATED (use .queue() method instead.) if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
1350
+ max_threads: the maximum number of total threads that the Gradio app can generate in parallel. The default is inherited from the starlette library (currently 40). Applies whether the queue is enabled or not. If queuing is enabled, this parameter is increased to be at least the concurrency_count of the queue.
1351
+ width: The width in pixels of the iframe element containing the interface (used if inline=True)
1352
+ height: The height in pixels of the iframe element containing the interface (used if inline=True)
1353
+ encrypt: DEPRECATED. Has no effect.
1354
+ favicon_path: If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
1355
+ ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
1356
+ ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
1357
+ ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
1358
+ quiet: If True, suppresses most print statements.
1359
+ show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
1360
+ file_directories: List of directories that gradio is allowed to serve files from (in addition to the directory containing the gradio python file). Must be absolute paths. Warning: any files in these directories or its children are potentially accessible to all users of your app.
1361
+ Returns:
1362
+ app: FastAPI app object that is running the demo
1363
+ local_url: Locally accessible link to the demo
1364
+ share_url: Publicly accessible link to the demo (if share=True, otherwise None)
1365
+ Example: (Blocks)
1366
+ import gradio as gr
1367
+ def reverse(text):
1368
+ return text[::-1]
1369
+ with gr.Blocks() as demo:
1370
+ button = gr.Button(value="Reverse")
1371
+ button.click(reverse, gr.Textbox(), gr.Textbox())
1372
+ demo.launch(share=True, auth=("username", "password"))
1373
+ Example: (Interface)
1374
+ import gradio as gr
1375
+ def reverse(text):
1376
+ return text[::-1]
1377
+ demo = gr.Interface(reverse, "text", "text")
1378
+ demo.launch(share=True, auth=("username", "password"))
1379
+ """
1380
+ self.dev_mode = False
1381
+ if (
1382
+ auth
1383
+ and not callable(auth)
1384
+ and not isinstance(auth[0], tuple)
1385
+ and not isinstance(auth[0], list)
1386
+ ):
1387
+ self.auth = [auth]
1388
+ else:
1389
+ self.auth = auth
1390
+ self.auth_message = auth_message
1391
+ self.show_tips = show_tips
1392
+ self.show_error = show_error
1393
+ self.height = height
1394
+ self.width = width
1395
+ self.favicon_path = favicon_path
1396
+
1397
+ if enable_queue is not None:
1398
+ self.enable_queue = enable_queue
1399
+ warnings.warn(
1400
+ "The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
1401
+ DeprecationWarning,
1402
+ )
1403
+ if encrypt is not None:
1404
+ warnings.warn(
1405
+ "The `encrypt` parameter has been deprecated and has no effect.",
1406
+ DeprecationWarning,
1407
+ )
1408
+
1409
+ if self.is_space:
1410
+ self.enable_queue = self.enable_queue is not False
1411
+ else:
1412
+ self.enable_queue = self.enable_queue is True
1413
+ if self.enable_queue and not hasattr(self, "_queue"):
1414
+ self.queue()
1415
+ self.show_api = self.api_open if self.enable_queue else show_api
1416
+
1417
+ self.file_directories = file_directories if file_directories is not None else []
1418
+ if not isinstance(self.file_directories, list):
1419
+ raise ValueError("file_directories must be a list of directories.")
1420
+
1421
+ if not self.enable_queue and self.progress_tracking:
1422
+ raise ValueError("Progress tracking requires queuing to be enabled.")
1423
+
1424
+ for dep in self.dependencies:
1425
+ for i in dep["cancels"]:
1426
+ if not self.queue_enabled_for_fn(i):
1427
+ raise ValueError(
1428
+ "In order to cancel an event, the queue for that event must be enabled! "
1429
+ "You may get this error by either 1) passing a function that uses the yield keyword "
1430
+ "into an interface without enabling the queue or 2) defining an event that cancels "
1431
+ "another event without enabling the queue. Both can be solved by calling .queue() "
1432
+ "before .launch()"
1433
+ )
1434
+ if dep["batch"] and (
1435
+ dep["queue"] is False
1436
+ or (dep["queue"] is None and not self.enable_queue)
1437
+ ):
1438
+ raise ValueError("In order to use batching, the queue must be enabled.")
1439
+
1440
+ self.config = self.get_config_file()
1441
+ self.max_threads = max(
1442
+ self._queue.max_thread_count if self.enable_queue else 0, max_threads
1443
+ )
1444
+
1445
+ if self.is_running:
1446
+ assert isinstance(
1447
+ self.local_url, str
1448
+ ), f"Invalid local_url: {self.local_url}"
1449
+ if not (quiet):
1450
+ print(
1451
+ "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
1452
+ )
1453
+ else:
1454
+ server_name, server_port, local_url, app, server = networking.start_server(
1455
+ self,
1456
+ server_name,
1457
+ server_port,
1458
+ ssl_keyfile,
1459
+ ssl_certfile,
1460
+ ssl_keyfile_password,
1461
+ )
1462
+ self.server_name = server_name
1463
+ self.local_url = local_url
1464
+ self.server_port = server_port
1465
+ self.server_app = app
1466
+ self.server = server
1467
+ self.is_running = True
1468
+ self.is_colab = utils.colab_check()
1469
+ self.is_kaggle = utils.kaggle_check()
1470
+ self.is_sagemaker = utils.sagemaker_check()
1471
+
1472
+ self.protocol = (
1473
+ "https"
1474
+ if self.local_url.startswith("https") or self.is_colab
1475
+ else "http"
1476
+ )
1477
+
1478
+ if self.enable_queue:
1479
+ self._queue.set_url(self.local_url)
1480
+
1481
+ # Cannot run async functions in background other than app's scope.
1482
+ # Workaround by triggering the app endpoint
1483
+ requests.get(f"{self.local_url}startup-events")
1484
+
1485
+ utils.launch_counter()
1486
+
1487
+ if share is None:
1488
+ if self.is_colab and self.enable_queue:
1489
+ if not quiet:
1490
+ print(
1491
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
1492
+ )
1493
+ self.share = True
1494
+ elif self.is_kaggle:
1495
+ if not quiet:
1496
+ print(
1497
+ "Kaggle notebooks require sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
1498
+ )
1499
+ self.share = True
1500
+ elif self.is_sagemaker:
1501
+ if not quiet:
1502
+ print(
1503
+ "Sagemaker notebooks may require sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
1504
+ )
1505
+ self.share = True
1506
+ else:
1507
+ self.share = False
1508
+ else:
1509
+ self.share = share
1510
+
1511
+ # If running in a colab or not able to access localhost,
1512
+ # a shareable link must be created.
1513
+ if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
1514
+ raise ValueError(
1515
+ "When localhost is not accessible, a shareable link must be created. Please set share=True."
1516
+ )
1517
+
1518
+ if self.is_colab:
1519
+ if not quiet:
1520
+ if debug:
1521
+ print(strings.en["COLAB_DEBUG_TRUE"])
1522
+ else:
1523
+ print(strings.en["COLAB_DEBUG_FALSE"])
1524
+ if not self.share:
1525
+ print(strings.en["COLAB_WARNING"].format(self.server_port))
1526
+ if self.enable_queue and not self.share:
1527
+ raise ValueError(
1528
+ "When using queueing in Colab, a shareable link must be created. Please set share=True."
1529
+ )
1530
+ else:
1531
+ print(
1532
+ strings.en["RUNNING_LOCALLY_SEPARATED"].format(
1533
+ self.protocol, self.server_name, self.server_port
1534
+ )
1535
+ )
1536
+
1537
+ if self.share:
1538
+ if self.is_space:
1539
+ raise RuntimeError("Share is not supported when you are in Spaces")
1540
+ try:
1541
+ if self.share_url is None:
1542
+ self.share_url = networking.setup_tunnel(
1543
+ self.server_name, self.server_port, self.share_token
1544
+ )
1545
+ print(strings.en["SHARE_LINK_DISPLAY"].format(self.share_url))
1546
+ if not (quiet):
1547
+ print(strings.en["SHARE_LINK_MESSAGE"])
1548
+ except (RuntimeError, requests.exceptions.ConnectionError):
1549
+ if self.analytics_enabled:
1550
+ utils.error_analytics("Not able to set up tunnel")
1551
+ self.share_url = None
1552
+ self.share = False
1553
+ print(strings.en["COULD_NOT_GET_SHARE_LINK"])
1554
+ else:
1555
+ if not (quiet):
1556
+ print(strings.en["PUBLIC_SHARE_TRUE"])
1557
+ self.share_url = None
1558
+
1559
+ if inbrowser:
1560
+ link = self.share_url if self.share and self.share_url else self.local_url
1561
+ webbrowser.open(link)
1562
+
1563
+ # Check if running in a Python notebook in which case, display inline
1564
+ if inline is None:
1565
+ inline = utils.ipython_check() and (self.auth is None)
1566
+ if inline:
1567
+ if self.auth is not None:
1568
+ print(
1569
+ "Warning: authentication is not supported inline. Please"
1570
+ "click the link to access the interface in a new tab."
1571
+ )
1572
+ try:
1573
+ from IPython.display import HTML, Javascript, display # type: ignore
1574
+
1575
+ if self.share and self.share_url:
1576
+ while not networking.url_ok(self.share_url):
1577
+ time.sleep(0.25)
1578
+ display(
1579
+ HTML(
1580
+ f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
1581
+ )
1582
+ )
1583
+ elif self.is_colab:
1584
+ # modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
1585
+ code = """(async (port, path, width, height, cache, element) => {
1586
+ if (!google.colab.kernel.accessAllowed && !cache) {
1587
+ return;
1588
+ }
1589
+ element.appendChild(document.createTextNode(''));
1590
+ const url = await google.colab.kernel.proxyPort(port, {cache});
1591
+
1592
+ const external_link = document.createElement('div');
1593
+ external_link.innerHTML = `
1594
+ <div style="font-family: monospace; margin-bottom: 0.5rem">
1595
+ Running on <a href=${new URL(path, url).toString()} target="_blank">
1596
+ https://localhost:${port}${path}
1597
+ </a>
1598
+ </div>
1599
+ `;
1600
+ element.appendChild(external_link);
1601
+
1602
+ const iframe = document.createElement('iframe');
1603
+ iframe.src = new URL(path, url).toString();
1604
+ iframe.height = height;
1605
+ iframe.allow = "autoplay; camera; microphone; clipboard-read; clipboard-write;"
1606
+ iframe.width = width;
1607
+ iframe.style.border = 0;
1608
+ element.appendChild(iframe);
1609
+ })""" + "({port}, {path}, {width}, {height}, {cache}, window.element)".format(
1610
+ port=json.dumps(self.server_port),
1611
+ path=json.dumps("/"),
1612
+ width=json.dumps(self.width),
1613
+ height=json.dumps(self.height),
1614
+ cache=json.dumps(False),
1615
+ )
1616
+
1617
+ display(Javascript(code))
1618
+ else:
1619
+ display(
1620
+ HTML(
1621
+ f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
1622
+ )
1623
+ )
1624
+ except ImportError:
1625
+ pass
1626
+
1627
+ if getattr(self, "analytics_enabled", False):
1628
+ data = {
1629
+ "launch_method": "browser" if inbrowser else "inline",
1630
+ "is_google_colab": self.is_colab,
1631
+ "is_sharing_on": self.share,
1632
+ "share_url": self.share_url,
1633
+ "enable_queue": self.enable_queue,
1634
+ "show_tips": self.show_tips,
1635
+ "server_name": server_name,
1636
+ "server_port": server_port,
1637
+ "is_spaces": self.is_space,
1638
+ "mode": self.mode,
1639
+ }
1640
+ utils.launch_analytics(data)
1641
+ utils.launched_telemetry(self, data)
1642
+
1643
+ utils.show_tip(self)
1644
+
1645
+ # Block main thread if debug==True
1646
+ if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
1647
+ self.block_thread()
1648
+ # Block main thread if running in a script to stop script from exiting
1649
+ is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
1650
+
1651
+ if not prevent_thread_lock and not is_in_interactive_mode:
1652
+ self.block_thread()
1653
+
1654
+ return TupleNoPrint((self.server_app, self.local_url, self.share_url))
1655
+
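A short sketch of the return values documented above; prevent_thread_lock keeps the interpreter free so the tuple can be inspected:

    import gradio as gr

    demo = gr.Interface(lambda x: x, "text", "text")
    app, local_url, share_url = demo.launch(prevent_thread_lock=True)
    print(local_url)   # e.g. http://127.0.0.1:7860/
    print(share_url)   # None unless share=True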
1656
+ def integrate(
1657
+ self,
1658
+ comet_ml: comet_ml.Experiment | None = None,
1659
+ wandb: ModuleType | None = None,
1660
+ mlflow: ModuleType | None = None,
1661
+ ) -> None:
1662
+ """
1663
+ A catch-all method for integrating with other libraries. This method should be run after launch()
1664
+ Parameters:
1665
+ comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
1666
+ wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
1667
+ mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
1668
+ """
1669
+ analytics_integration = ""
1670
+ if comet_ml is not None:
1671
+ analytics_integration = "CometML"
1672
+ comet_ml.log_other("Created from", "Gradio")
1673
+ if self.share_url is not None:
1674
+ comet_ml.log_text("gradio: " + self.share_url)
1675
+ comet_ml.end()
1676
+ elif self.local_url:
1677
+ comet_ml.log_text("gradio: " + self.local_url)
1678
+ comet_ml.end()
1679
+ else:
1680
+ raise ValueError("Please run `launch()` first.")
1681
+ if wandb is not None:
1682
+ analytics_integration = "WandB"
1683
+ if self.share_url is not None:
1684
+ wandb.log(
1685
+ {
1686
+ "Gradio panel": wandb.Html(
1687
+ '<iframe src="'
1688
+ + self.share_url
1689
+ + '" width="'
1690
+ + str(self.width)
1691
+ + '" height="'
1692
+ + str(self.height)
1693
+ + '" frameBorder="0"></iframe>'
1694
+ )
1695
+ }
1696
+ )
1697
+ else:
1698
+ print(
1699
+ "The WandB integration requires you to "
1700
+ "`launch(share=True)` first."
1701
+ )
1702
+ if mlflow is not None:
1703
+ analytics_integration = "MLFlow"
1704
+ if self.share_url is not None:
1705
+ mlflow.log_param("Gradio Interface Share Link", self.share_url)
1706
+ else:
1707
+ mlflow.log_param("Gradio Interface Local Link", self.local_url)
1708
+ if self.analytics_enabled and analytics_integration:
1709
+ data = {"integration": analytics_integration}
1710
+ utils.integration_analytics(data)
1711
+
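A hedged sketch of the integrate() flow for Weights & Biases; it assumes the wandb package is installed and a run has been started with wandb.init():

    import gradio as gr
    import wandb  # assumption: wandb is installed and wandb.init() has been called

    demo = gr.Interface(lambda x: x, "text", "text")
    # The WandB branch above logs an iframe of the share link, so launch with share=True first.
    demo.launch(share=True, prevent_thread_lock=True)
    demo.integrate(wandb=wandb)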
1712
+ def close(self, verbose: bool = True) -> None:
1713
+ """
1714
+ Closes the Interface that was launched and frees the port.
1715
+ """
1716
+ try:
1717
+ if self.enable_queue:
1718
+ self._queue.close()
1719
+ self.server.close()
1720
+ self.is_running = False
1721
+ # So that the startup events (starting the queue)
1722
+ # happen the next time the app is launched
1723
+ self.app.startup_events_triggered = False
1724
+ if verbose:
1725
+ print("Closing server running on port: {}".format(self.server_port))
1726
+ except (AttributeError, OSError): # can't close if not running
1727
+ pass
1728
+
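close() pairs naturally with prevent_thread_lock when launching from a script or notebook; a minimal sketch:

    import gradio as gr

    demo = gr.Interface(lambda x: x, "text", "text")
    demo.launch(prevent_thread_lock=True)
    # ... interact with the running app ...
    demo.close()  # frees the port and resets the startup events for the next launch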
1729
+ def block_thread(
1730
+ self,
1731
+ ) -> None:
1732
+ """Block main thread until interrupted by user."""
1733
+ try:
1734
+ while True:
1735
+ time.sleep(0.1)
1736
+ except (KeyboardInterrupt, OSError):
1737
+ print("Keyboard interruption in main thread... closing server.")
1738
+ self.server.close()
1739
+ for tunnel in CURRENT_TUNNELS:
1740
+ tunnel.kill()
1741
+
1742
+ def attach_load_events(self):
1743
+ """Add a load event for every component whose initial value should be randomized."""
1744
+ if Context.root_block:
1745
+ for component in Context.root_block.blocks.values():
1746
+ if (
1747
+ isinstance(component, components.IOComponent)
1748
+ and component.load_event_to_attach
1749
+ ):
1750
+ load_fn, every = component.load_event_to_attach
1751
+ # Use set_event_trigger to avoid ambiguity between load class/instance method
1752
+ dep = self.set_event_trigger(
1753
+ "load",
1754
+ load_fn,
1755
+ None,
1756
+ component,
1757
+ no_target=True,
1758
+ # If every is None, for sure skip the queue
1759
+ # else, let the enable_queue parameter take precedence
1760
+ # this will raise a nice error message if every is used
1761
+ # without queue
1762
+ queue=False if every is None else None,
1763
+ every=every,
1764
+ )[0]
1765
+ component.load_event = dep
1766
+
1767
+ def startup_events(self):
1768
+ """Events that should be run when the app containing this block starts up."""
1769
+
1770
+ if self.enable_queue:
1771
+ utils.run_coro_in_background(self._queue.start, (self.progress_tracking,))
1772
+ # So that processing can resume in case the queue was stopped
1773
+ self._queue.stopped = False
1774
+ utils.run_coro_in_background(self.create_limiter)
1775
+
1776
+ def queue_enabled_for_fn(self, fn_index: int):
1777
+ if self.dependencies[fn_index]["queue"] is None:
1778
+ return self.enable_queue
1779
+ return self.dependencies[fn_index]["queue"]
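queue_enabled_for_fn() falls back to the app-level setting when an event leaves queue=None; a sketch of the per-event override:

    import time
    import gradio as gr

    def slow_echo(text):
        time.sleep(5)
        return text

    with gr.Blocks() as demo:
        box_in, box_out = gr.Textbox(), gr.Textbox()
        btn = gr.Button("Run")
        # queue=True on the event overrides demo.enable_queue for this dependency only.
        btn.click(slow_echo, box_in, box_out, queue=True)
    demo.queue().launch()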
gradio/components.py ADDED
The diff for this file is too large to render. See raw diff
 
gradio/context.py ADDED
@@ -0,0 +1,18 @@
1
+ # Defines the Context class, which is used to store the state of all Blocks that are being rendered.
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
8
+ from gradio.blocks import BlockContext, Blocks
9
+
10
+
11
+ class Context:
12
+ root_block: Blocks | None = None # The current root block that holds all blocks.
13
+ block: BlockContext | None = None # The current block that children are added to.
14
+ id: int = 0 # Running id to uniquely refer to any block that gets defined
15
+ ip_address: str | None = None # The IP address of the user.
16
+ access_token: str | None = (
17
+ None # The HF token that is provided when loading private models or Spaces
18
+ )
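A small sketch of how this shared state behaves while a single top-level Blocks app is being built (the reset on exit corresponds to the blocks.py hunk at the top of this diff):

    import gradio as gr
    from gradio.context import Context

    with gr.Blocks() as demo:
        # While the `with` block is open, the demo being built is the root block.
        assert Context.root_block is demo
    # On exit of a top-level Blocks, root_block is reset to None.
    assert Context.root_block is None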
gradio/data_classes.py ADDED
@@ -0,0 +1,55 @@
1
+ """Pydantic data models and other dataclasses. This is the only file that uses Optional[]
2
+ typing syntax instead of | None syntax to work with pydantic"""
3
+ from enum import Enum, auto
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ from pydantic import BaseModel
7
+
8
+
9
+ class PredictBody(BaseModel):
10
+ session_hash: Optional[str]
11
+ event_id: Optional[str]
12
+ data: List[Any]
13
+ event_data: Optional[Any]
14
+ fn_index: Optional[int]
15
+ batched: Optional[
16
+ bool
17
+ ] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
18
+ request: Optional[
19
+ Union[Dict, List[Dict]]
20
+ ] = None # dictionary of request headers, query parameters, url, etc. (used to pass in the request for queuing)
21
+
22
+
23
+ class ResetBody(BaseModel):
24
+ session_hash: str
25
+ fn_index: int
26
+
27
+
28
+ class InterfaceTypes(Enum):
29
+ STANDARD = auto()
30
+ INPUT_ONLY = auto()
31
+ OUTPUT_ONLY = auto()
32
+ UNIFIED = auto()
33
+
34
+
35
+ class Estimation(BaseModel):
36
+ msg: Optional[str] = "estimation"
37
+ rank: Optional[int] = None
38
+ queue_size: int
39
+ avg_event_process_time: Optional[float]
40
+ avg_event_concurrent_process_time: Optional[float]
41
+ rank_eta: Optional[float] = None
42
+ queue_eta: float
43
+
44
+
45
+ class ProgressUnit(BaseModel):
46
+ index: Optional[int]
47
+ length: Optional[int]
48
+ unit: Optional[str]
49
+ progress: Optional[float]
50
+ desc: Optional[str]
51
+
52
+
53
+ class Progress(BaseModel):
54
+ msg: str = "progress"
55
+ progress_data: List[ProgressUnit] = []
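A minimal sketch of constructing the request models above (all field values are illustrative):

    from gradio.data_classes import PredictBody, ResetBody

    body = PredictBody(data=["hello world"], fn_index=0, session_hash="abc123")
    reset = ResetBody(session_hash="abc123", fn_index=0)
    print(body.batched)  # False by default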
gradio/deprecation.py ADDED
@@ -0,0 +1,45 @@
1
+ import warnings
2
+
3
+
4
+ def simple_deprecated_notice(term: str) -> str:
5
+ return f"`{term}` parameter is deprecated, and it has no effect"
6
+
7
+
8
+ def use_in_launch(term: str) -> str:
9
+ return f"`{term}` is deprecated in `Interface()`, please use it within `launch()` instead."
10
+
11
+
12
+ DEPRECATION_MESSAGE = {
13
+ "optional": simple_deprecated_notice("optional"),
14
+ "keep_filename": simple_deprecated_notice("keep_filename"),
15
+ "numeric": simple_deprecated_notice("numeric"),
16
+ "verbose": simple_deprecated_notice("verbose"),
17
+ "allow_screenshot": simple_deprecated_notice("allow_screenshot"),
18
+ "layout": simple_deprecated_notice("layout"),
19
+ "show_input": simple_deprecated_notice("show_input"),
20
+ "show_output": simple_deprecated_notice("show_output"),
21
+ "capture_session": simple_deprecated_notice("capture_session"),
22
+ "api_mode": simple_deprecated_notice("api_mode"),
23
+ "show_tips": use_in_launch("show_tips"),
24
+ "encrypt": simple_deprecated_notice("encrypt"),
25
+ "enable_queue": use_in_launch("enable_queue"),
26
+ "server_name": use_in_launch("server_name"),
27
+ "server_port": use_in_launch("server_port"),
28
+ "width": use_in_launch("width"),
29
+ "height": use_in_launch("height"),
30
+ "plot": "The 'plot' parameter has been deprecated. Use the new Plot component instead",
31
+ "type": "The 'type' parameter has been deprecated. Use the Number component instead.",
32
+ }
33
+
34
+
35
+ def check_deprecated_parameters(cls: str, **kwargs) -> None:
36
+ for key, value in DEPRECATION_MESSAGE.items():
37
+ if key in kwargs:
38
+ kwargs.pop(key)
39
+ # Interestingly, using DeprecationWarning causes warning to not appear.
40
+ warnings.warn(value)
41
+
42
+ if len(kwargs) != 0:
43
+ warnings.warn(
44
+ f"You have unused kwarg parameters in {cls}, please remove them: {kwargs}"
45
+ )
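A quick sketch of how the helper above behaves when a component receives retired keyword arguments (bogus_kwarg is hypothetical):

    import warnings
    from gradio.deprecation import check_deprecated_parameters

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        check_deprecated_parameters("Textbox", optional=True, bogus_kwarg=1)
    # One warning for the deprecated `optional`, one for the leftover unused kwarg.
    print([str(w.message) for w in caught])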
gradio/documentation.py ADDED
@@ -0,0 +1,261 @@
1
+ """Contains methods that generate documentation for Gradio functions and classes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ from typing import Callable, Dict, List, Tuple
7
+
8
+ classes_to_document = {}
9
+ classes_inherit_documentation = {}
10
+ documentation_group = None
11
+
12
+
13
+ def set_documentation_group(m):
14
+ global documentation_group
15
+ documentation_group = m
16
+ if m not in classes_to_document:
17
+ classes_to_document[m] = []
18
+
19
+
20
+ def extract_instance_attr_doc(cls, attr):
21
+ code = inspect.getsource(cls.__init__)
22
+ lines = [line.strip() for line in code.split("\n")]
23
+ i = None
24
+ for i, line in enumerate(lines):
25
+ if line.startswith("self." + attr + ":") or line.startswith(
26
+ "self." + attr + " ="
27
+ ):
28
+ break
29
+ assert i is not None, f"Could not find {attr} in {cls.__name__}"
30
+ start_line = lines.index('"""', i)
31
+ end_line = lines.index('"""', start_line + 1)
32
+ for j in range(i + 1, start_line):
33
+ assert not lines[j].startswith("self."), (
34
+ f"Found another attribute before docstring for {attr} in {cls.__name__}: "
35
+ + lines[j]
36
+ + "\n start:"
37
+ + lines[i]
38
+ )
39
+ doc_string = " ".join(lines[start_line + 1 : end_line])
40
+ return doc_string
41
+
42
+
43
+ def document(*fns, inherit=False):
44
+ """
45
+ Defines the @document decorator which adds classes or functions to the Gradio
46
+ documentation at www.gradio.app/docs.
47
+
48
+ Usage examples:
49
+ - Put @document() above a class to document the class and its constructor.
50
+ - Put @document("fn1", "fn2") above a class to also document methods fn1 and fn2.
51
+ - Put @document("*fn3") with an asterisk above a class to document the instance attribute methods f3.
52
+ """
53
+
54
+ def inner_doc(cls):
55
+ global documentation_group
56
+ if inherit:
57
+ classes_inherit_documentation[cls] = None
58
+ classes_to_document[documentation_group].append((cls, fns))
59
+ return cls
60
+
61
+ return inner_doc
62
+
63
+
64
+ def document_fn(fn: Callable, cls) -> Tuple[str, List[Dict], Dict, str | None]:
65
+ """
66
+ Generates documentation for any function.
67
+ Parameters:
68
+ fn: Function to document
69
+ Returns:
70
+ description: General description of fn
71
+ parameters: A list of dicts for each parameter, storing data for the parameter name, annotation and doc
72
+ return: A dict storing data for the returned annotation and doc
73
+ example: Code for an example use of the fn
74
+ """
75
+ doc_str = inspect.getdoc(fn) or ""
76
+ doc_lines = doc_str.split("\n")
77
+ signature = inspect.signature(fn)
78
+ description, parameters, returns, examples = [], {}, [], []
79
+ mode = "description"
80
+ for line in doc_lines:
81
+ line = line.rstrip()
82
+ if line == "Parameters:":
83
+ mode = "parameter"
84
+ elif line.startswith("Example:"):
85
+ mode = "example"
86
+ if "(" in line and ")" in line:
87
+ c = line.split("(")[1].split(")")[0]
88
+ if c != cls.__name__:
89
+ mode = "ignore"
90
+ elif line == "Returns:":
91
+ mode = "return"
92
+ else:
93
+ if mode == "description":
94
+ description.append(line if line.strip() else "<br>")
95
+ continue
96
+ assert (
97
+ line.startswith(" ") or line.strip() == ""
98
+ ), f"Documentation format for {fn.__name__} has format error in line: {line}"
99
+ line = line[4:]
100
+ if mode == "parameter":
101
+ colon_index = line.index(": ")
102
+ assert (
103
+ colon_index > -1
104
+ ), f"Documentation format for {fn.__name__} has format error in line: {line}"
105
+ parameter = line[:colon_index]
106
+ parameter_doc = line[colon_index + 2 :]
107
+ parameters[parameter] = parameter_doc
108
+ elif mode == "return":
109
+ returns.append(line)
110
+ elif mode == "example":
111
+ examples.append(line)
112
+ description_doc = " ".join(description)
113
+ parameter_docs = []
114
+ for param_name, param in signature.parameters.items():
115
+ if param_name.startswith("_"):
116
+ continue
117
+ if param_name == "kwargs" and param_name not in parameters:
118
+ continue
119
+ parameter_doc = {
120
+ "name": param_name,
121
+ "annotation": param.annotation,
122
+ "doc": parameters.get(param_name),
123
+ }
124
+ if param_name in parameters:
125
+ del parameters[param_name]
126
+ if param.default != inspect.Parameter.empty:
127
+ default = param.default
128
+ if type(default) == str:
129
+ default = '"' + default + '"'
130
+ if default.__class__.__module__ != "builtins":
131
+ default = f"{default.__class__.__name__}()"
132
+ parameter_doc["default"] = default
133
+ elif parameter_doc["doc"] is not None and "kwargs" in parameter_doc["doc"]:
134
+ parameter_doc["kwargs"] = True
135
+ parameter_docs.append(parameter_doc)
136
+ assert (
137
+ len(parameters) == 0
138
+ ), f"Documentation format for {fn.__name__} documents nonexistent parameters: {''.join(parameters.keys())}"
139
+ if len(returns) == 0:
140
+ return_docs = {}
141
+ elif len(returns) == 1:
142
+ return_docs = {"annotation": signature.return_annotation, "doc": returns[0]}
143
+ else:
144
+ return_docs = {}
145
+ # raise ValueError("Does not support multiple returns yet.")
146
+ examples_doc = "\n".join(examples) if len(examples) > 0 else None
147
+ return description_doc, parameter_docs, return_docs, examples_doc
148
+
149
+
150
+ def document_cls(cls):
151
+ doc_str = inspect.getdoc(cls)
152
+ if doc_str is None:
153
+ return "", {}, ""
154
+ tags = {}
155
+ description_lines = []
156
+ mode = "description"
157
+ for line in doc_str.split("\n"):
158
+ line = line.rstrip()
159
+ if line.endswith(":") and " " not in line:
160
+ mode = line[:-1].lower()
161
+ tags[mode] = []
162
+ elif line.split(" ")[0].endswith(":") and not line.startswith(" "):
163
+ tag = line[: line.index(":")].lower()
164
+ value = line[line.index(":") + 2 :]
165
+ tags[tag] = value
166
+ else:
167
+ if mode == "description":
168
+ description_lines.append(line if line.strip() else "<br>")
169
+ else:
170
+ assert (
171
+ line.startswith(" ") or not line.strip()
172
+ ), f"Documentation format for {cls.__name__} has format error in line: {line}"
173
+ tags[mode].append(line[4:])
174
+ if "example" in tags:
175
+ example = "\n".join(tags["example"])
176
+ del tags["example"]
177
+ else:
178
+ example = None
179
+ for key, val in tags.items():
180
+ if isinstance(val, list):
181
+ tags[key] = "<br>".join(val)
182
+ description = " ".join(description_lines).replace("\n", "<br>")
183
+ return description, tags, example
184
+
185
+
186
+ def generate_documentation():
187
+ documentation = {}
188
+ for mode, class_list in classes_to_document.items():
189
+ documentation[mode] = []
190
+ for cls, fns in class_list:
191
+ fn_to_document = cls if inspect.isfunction(cls) else cls.__init__
192
+ _, parameter_doc, return_doc, _ = document_fn(fn_to_document, cls)
193
+ cls_description, cls_tags, cls_example = document_cls(cls)
194
+ cls_documentation = {
195
+ "class": cls,
196
+ "name": cls.__name__,
197
+ "description": cls_description,
198
+ "tags": cls_tags,
199
+ "parameters": parameter_doc,
200
+ "returns": return_doc,
201
+ "example": cls_example,
202
+ "fns": [],
203
+ }
204
+ for fn_name in fns:
205
+ instance_attribute_fn = fn_name.startswith("*")
206
+ if instance_attribute_fn:
207
+ fn_name = fn_name[1:]
208
+ # Instance attribute fns are classes
209
+ # whose __call__ method determines their behavior
210
+ fn = getattr(cls(), fn_name).__call__
211
+ else:
212
+ fn = getattr(cls, fn_name)
213
+ if not callable(fn):
214
+ description_doc = str(fn)
215
+ parameter_docs = {}
216
+ return_docs = {}
217
+ examples_doc = ""
218
+ override_signature = f"gr.{cls.__name__}.{fn_name}"
219
+ else:
220
+ (
221
+ description_doc,
222
+ parameter_docs,
223
+ return_docs,
224
+ examples_doc,
225
+ ) = document_fn(fn, cls)
226
+ override_signature = None
227
+ if instance_attribute_fn:
228
+ description_doc = extract_instance_attr_doc(cls, fn_name)
229
+ cls_documentation["fns"].append(
230
+ {
231
+ "fn": fn,
232
+ "name": fn_name,
233
+ "description": description_doc,
234
+ "tags": {},
235
+ "parameters": parameter_docs,
236
+ "returns": return_docs,
237
+ "example": examples_doc,
238
+ "override_signature": override_signature,
239
+ }
240
+ )
241
+ documentation[mode].append(cls_documentation)
242
+ if cls in classes_inherit_documentation:
243
+ classes_inherit_documentation[cls] = cls_documentation["fns"]
244
+ for mode, class_list in classes_to_document.items():
245
+ for i, (cls, _) in enumerate(class_list):
246
+ for super_class in classes_inherit_documentation:
247
+ if (
248
+ inspect.isclass(cls)
249
+ and issubclass(cls, super_class)
250
+ and cls != super_class
251
+ ):
252
+ for inherited_fn in classes_inherit_documentation[super_class]:
253
+ inherited_fn = dict(inherited_fn)
254
+ try:
255
+ inherited_fn["description"] = extract_instance_attr_doc(
256
+ cls, inherited_fn["name"]
257
+ )
258
+ except (ValueError, AssertionError):
259
+ pass
260
+ documentation[mode][i]["fns"].append(inherited_fn)
261
+ return documentation
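A toy sketch of the registration side of the decorator above; the class is hypothetical and exists only to show the bookkeeping:

    from gradio.documentation import classes_to_document, document, set_documentation_group

    set_documentation_group("demos")

    @document()
    class Widget:
        """A stand-in class used only to show how @document registers entries."""

    # The decorator appends (class, fns) tuples under the active group.
    print(classes_to_document["demos"])  # roughly: [(<class 'Widget'>, ())]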
gradio/events.py ADDED
@@ -0,0 +1,298 @@
1
+ """Contains all of the events that can be triggered in a gr.Blocks() app, with the exception
2
+ of the on-page-load event, which is defined in gr.Blocks().load()."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple
8
+
9
+ from gradio.blocks import Block
10
+ from gradio.documentation import document, set_documentation_group
11
+ from gradio.helpers import EventData
12
+ from gradio.utils import get_cancel_function
13
+
14
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
15
+ from gradio.components import Component, StatusTracker
16
+
17
+ set_documentation_group("events")
18
+
19
+
20
+ def set_cancel_events(
21
+ block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
22
+ ):
23
+ if cancels:
24
+ if not isinstance(cancels, list):
25
+ cancels = [cancels]
26
+ cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
27
+ block.set_event_trigger(
28
+ event_name,
29
+ cancel_fn,
30
+ inputs=None,
31
+ outputs=None,
32
+ queue=False,
33
+ preprocess=False,
34
+ cancels=fn_indices_to_cancel,
35
+ )
36
+
37
+
38
+ class EventListener(Block):
39
+ def __init__(self: Any):
40
+ for event_listener_class in EventListener.__subclasses__():
41
+ if isinstance(self, event_listener_class):
42
+ event_listener_class.__init__(self)
43
+
44
+
45
+ class Dependency(dict):
46
+ def __init__(self, trigger, key_vals, dep_index):
47
+ super().__init__(key_vals)
48
+ self.trigger = trigger
49
+ self.then = EventListenerMethod(
50
+ self.trigger,
51
+ "then",
52
+ trigger_after=dep_index,
53
+ trigger_only_on_success=False,
54
+ )
55
+ """
56
+ Triggered after directly preceding event is completed, regardless of success or failure.
57
+ """
58
+ self.success = EventListenerMethod(
59
+ self.trigger,
60
+ "success",
61
+ trigger_after=dep_index,
62
+ trigger_only_on_success=True,
63
+ )
64
+ """
65
+ Triggered after directly preceding event is completed, if it was successful.
66
+ """
67
+
68
+
69
+ class EventListenerMethod:
70
+ """
71
+ Wraps an event listener method on a Block; calling it attaches the given function to that event.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ trigger: Block,
77
+ event_name: str,
78
+ show_progress: bool = True,
79
+ callback: Callable | None = None,
80
+ trigger_after: int | None = None,
81
+ trigger_only_on_success: bool = False,
82
+ ):
83
+ self.trigger = trigger
84
+ self.event_name = event_name
85
+ self.show_progress = show_progress
86
+ self.callback = callback
87
+ self.trigger_after = trigger_after
88
+ self.trigger_only_on_success = trigger_only_on_success
89
+
90
+ def __call__(
91
+ self,
92
+ fn: Callable | None,
93
+ inputs: Component | List[Component] | Set[Component] | None = None,
94
+ outputs: Component | List[Component] | None = None,
95
+ api_name: str | None = None,
96
+ status_tracker: StatusTracker | None = None,
97
+ scroll_to_output: bool = False,
98
+ show_progress: bool | None = None,
99
+ queue: bool | None = None,
100
+ batch: bool = False,
101
+ max_batch_size: int = 4,
102
+ preprocess: bool = True,
103
+ postprocess: bool = True,
104
+ cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
105
+ every: float | None = None,
106
+ _js: str | None = None,
107
+ ) -> Dependency:
108
+ """
109
+ Parameters:
110
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
111
+ inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
112
+ outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
113
+ api_name: Defining this parameter exposes the endpoint in the api docs
114
+ scroll_to_output: If True, will scroll to output component on completion
115
+ show_progress: If True, will show progress animation while pending
116
+ queue: If True, will place the request on the queue, if the queue exists
117
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
118
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
119
+ preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
120
+ postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
121
+ cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
122
+ every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
123
+ """
124
+ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
125
+ if status_tracker:
126
+ warnings.warn(
127
+ "The 'status_tracker' parameter has been deprecated and has no effect."
128
+ )
129
+ dep, dep_index = self.trigger.set_event_trigger(
130
+ self.event_name,
131
+ fn,
132
+ inputs,
133
+ outputs,
134
+ preprocess=preprocess,
135
+ postprocess=postprocess,
136
+ scroll_to_output=scroll_to_output,
137
+ show_progress=show_progress
138
+ if show_progress is not None
139
+ else self.show_progress,
140
+ api_name=api_name,
141
+ js=_js,
142
+ queue=queue,
143
+ batch=batch,
144
+ max_batch_size=max_batch_size,
145
+ every=every,
146
+ trigger_after=self.trigger_after,
147
+ trigger_only_on_success=self.trigger_only_on_success,
148
+ )
149
+ set_cancel_events(self.trigger, self.event_name, cancels)
150
+ if self.callback:
151
+ self.callback()
152
+ return Dependency(self.trigger, dep, dep_index)
153
+
154
+
155
+ @document("*change", inherit=True)
156
+ class Changeable(EventListener):
157
+ def __init__(self):
158
+ self.change = EventListenerMethod(self, "change")
159
+ """
160
+ This event is triggered when the component's input value changes (e.g. when the user types in a textbox
161
+ or uploads an image). This method can be used when this component is in a Gradio Blocks.
162
+ """
163
+
164
+
165
+ @document("*click", inherit=True)
166
+ class Clickable(EventListener):
167
+ def __init__(self):
168
+ self.click = EventListenerMethod(self, "click")
169
+ """
170
+ This event is triggered when the component (e.g. a button) is clicked.
171
+ This method can be used when this component is in a Gradio Blocks.
172
+ """
173
+
174
+
175
+ @document("*submit", inherit=True)
176
+ class Submittable(EventListener):
177
+ def __init__(self):
178
+ self.submit = EventListenerMethod(self, "submit")
179
+ """
180
+ This event is triggered when the user presses the Enter key while the component (e.g. a textbox) is focused.
181
+ This method can be used when this component is in a Gradio Blocks.
182
+ """
183
+
184
+
185
+ @document("*edit", inherit=True)
186
+ class Editable(EventListener):
187
+ def __init__(self):
188
+ self.edit = EventListenerMethod(self, "edit")
189
+ """
190
+ This event is triggered when the user edits the component (e.g. image) using the
191
+ built-in editor. This method can be used when this component is in a Gradio Blocks.
192
+ """
193
+
194
+
195
+ @document("*clear", inherit=True)
196
+ class Clearable(EventListener):
197
+ def __init__(self):
198
+ self.clear = EventListenerMethod(self, "clear")
199
+ """
200
+ This event is triggered when the user clears the component (e.g. image or audio)
201
+ using the X button for the component. This method can be used when this component is in a Gradio Blocks.
202
+ """
203
+
204
+
205
+ @document("*play", "*pause", "*stop", inherit=True)
206
+ class Playable(EventListener):
207
+ def __init__(self):
208
+ self.play = EventListenerMethod(self, "play")
209
+ """
210
+ This event is triggered when the user plays the component (e.g. audio or video).
211
+ This method can be used when this component is in a Gradio Blocks.
212
+ """
213
+
214
+ self.pause = EventListenerMethod(self, "pause")
215
+ """
216
+ This event is triggered when the user pauses the component (e.g. audio or video).
217
+ This method can be used when this component is in a Gradio Blocks.
218
+ """
219
+
220
+ self.stop = EventListenerMethod(self, "stop")
221
+ """
222
+ This event is triggered when the user stops the component (e.g. audio or video).
223
+ This method can be used when this component is in a Gradio Blocks.
224
+ """
225
+
226
+
227
+ @document("*stream", inherit=True)
228
+ class Streamable(EventListener):
229
+ def __init__(self):
230
+ self.streaming: bool
231
+ self.stream = EventListenerMethod(
232
+ self,
233
+ "stream",
234
+ show_progress=False,
235
+ callback=lambda: setattr(self, "streaming", True),
236
+ )
237
+ """
238
+ This event is triggered when the user streams the component (e.g. a live webcam
239
+ component). This method can be used when this component is in a Gradio Blocks.
240
+ """
241
+
242
+
243
+ @document("*blur", inherit=True)
244
+ class Blurrable(EventListener):
245
+ def __init__(self):
246
+ self.blur = EventListenerMethod(self, "blur")
247
+ """
248
+ This event is triggered when the component is unfocused/blurred (e.g. when the user clicks outside of a textbox). This method can be used when this component is in a Gradio Blocks.
249
+ """
250
+
251
+
252
+ @document("*upload", inherit=True)
253
+ class Uploadable(EventListener):
254
+ def __init__(self):
255
+ self.upload = EventListenerMethod(self, "upload")
256
+ """
257
+ This event is triggered when the user uploads a file into the component (e.g. when the user uploads a video into a video component). This method can be used when this component is in a Gradio Blocks.
258
+ """
259
+
260
+
261
+ @document("*release", inherit=True)
262
+ class Releaseable(EventListener):
263
+ def __init__(self):
264
+ self.release = EventListenerMethod(self, "release")
265
+ """
266
+ This event is triggered when the user releases the mouse on this component (e.g. when the user releases the slider). This method can be used when this component is in a Gradio Blocks.
267
+ """
268
+
269
+
270
+ @document("*select", inherit=True)
271
+ class Selectable(EventListener):
272
+ def __init__(self):
273
+ self.selectable: bool = False
274
+ self.select = EventListenerMethod(
275
+ self, "select", callback=lambda: setattr(self, "selectable", True)
276
+ )
277
+ """
278
+ This event is triggered when the user selects from within the Component.
279
+ This event has EventData of type gradio.SelectData that carries information, accessible through SelectData.index and SelectData.value.
280
+ See EventData documentation on how to use this event data.
281
+ """
282
+
283
+
284
+ class SelectData(EventData):
285
+ def __init__(self, target: Block | None, data: Any):
286
+ super().__init__(target, data)
287
+ self.index: int | Tuple[int, int] = data["index"]
288
+ """
289
+ The index of the selected item. Is a tuple if the component is two dimensional or selection is a range.
290
+ """
291
+ self.value: Any = data["value"]
292
+ """
293
+ The value of the selected item.
294
+ """
295
+ self.selected: bool = data.get("selected", True)
296
+ """
297
+ True if the item was selected, False if deselected.
298
+ """
gradio/exceptions.py ADDED
@@ -0,0 +1,39 @@
1
+ from gradio.documentation import document, set_documentation_group
2
+
3
+ set_documentation_group("helpers")
4
+
5
+
6
+ class DuplicateBlockError(ValueError):
7
+ """Raised when a Blocks contains more than one Block with the same id"""
8
+
9
+ pass
10
+
11
+
12
+ class TooManyRequestsError(Exception):
13
+ """Raised when the Hugging Face API returns a 429 status code."""
14
+
15
+ pass
16
+
17
+
18
+ class InvalidApiName(ValueError):
19
+ pass
20
+
21
+
22
+ @document()
23
+ class Error(Exception):
24
+ """
25
+ This class allows you to pass custom error messages to the user. You can do so by raising a gr.Error("custom message") anywhere in the code, and when that line is executed the custom message will appear in a modal on the demo.
26
+
27
+ Demos: calculator
28
+ """
29
+
30
+ def __init__(self, message: str):
31
+ """
32
+ Parameters:
33
+ message: The error message to be displayed to the user.
34
+ """
35
+ self.message = message
36
+ super().__init__(self.message)
37
+
38
+ def __str__(self):
39
+ return repr(self.message)
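A minimal sketch of raising the Error class documented above from within a prediction function:

    import gradio as gr

    def divide(num, denom):
        if denom == 0:
            # Surfaces as a modal with this message in the running demo.
            raise gr.Error("Cannot divide by zero!")
        return num / denom

    demo = gr.Interface(divide, [gr.Number(), gr.Number()], gr.Number())
    demo.launch()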
gradio/external.py ADDED
@@ -0,0 +1,512 @@
1
+ """This module should not be used directly as its API is subject to change. Instead,
2
+ use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import re
8
+ import uuid
9
+ import warnings
10
+ from copy import deepcopy
11
+ from typing import TYPE_CHECKING, Callable, Dict
12
+
13
+ import requests
14
+
15
+ import gradio
16
+ from gradio import components, utils
17
+ from gradio.context import Context
18
+ from gradio.exceptions import Error, TooManyRequestsError
19
+ from gradio.external_utils import (
20
+ cols_to_rows,
21
+ encode_to_base64,
22
+ get_tabular_examples,
23
+ get_ws_fn,
24
+ postprocess_label,
25
+ rows_to_cols,
26
+ streamline_spaces_interface,
27
+ use_websocket,
28
+ )
29
+ from gradio.processing_utils import to_binary
30
+
31
+ if TYPE_CHECKING:
32
+ from gradio.blocks import Blocks
33
+ from gradio.interface import Interface
34
+
35
+
36
+ def load_blocks_from_repo(
37
+ name: str,
38
+ src: str | None = None,
39
+ api_key: str | None = None,
40
+ alias: str | None = None,
41
+ **kwargs,
42
+ ) -> Blocks:
43
+ """Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
44
+ if src is None:
45
+ # Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
46
+ tokens = name.split("/")
47
+ assert (
48
+ len(tokens) > 1
49
+ ), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
50
+ src = tokens[0]
51
+ name = "/".join(tokens[1:])
52
+
53
+ factory_methods: Dict[str, Callable] = {
54
+ # for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
55
+ "huggingface": from_model,
56
+ "models": from_model,
57
+ "spaces": from_spaces,
58
+ }
59
+ assert src.lower() in factory_methods, "parameter: src must be one of {}".format(
60
+ factory_methods.keys()
61
+ )
62
+
63
+ if api_key is not None:
64
+ if Context.access_token is not None and Context.access_token != api_key:
65
+ warnings.warn(
66
+ """You are loading a model/Space with a different access token than the one you used to load a previous model/Space. This is not recommended, as it may cause unexpected behavior."""
67
+ )
68
+ Context.access_token = api_key
69
+
70
+ blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
71
+ return blocks
72
+
73
+
74
+ def chatbot_preprocess(text, state):
75
+ payload = {
76
+ "inputs": {"generated_responses": None, "past_user_inputs": None, "text": text}
77
+ }
78
+ if state is not None:
79
+ payload["inputs"]["generated_responses"] = state["conversation"][
80
+ "generated_responses"
81
+ ]
82
+ payload["inputs"]["past_user_inputs"] = state["conversation"][
83
+ "past_user_inputs"
84
+ ]
85
+
86
+ return payload
87
+
88
+
89
+ def chatbot_postprocess(response):
90
+ response_json = response.json()
91
+ chatbot_value = list(
92
+ zip(
93
+ response_json["conversation"]["past_user_inputs"],
94
+ response_json["conversation"]["generated_responses"],
95
+ )
96
+ )
97
+ return chatbot_value, response_json
98
+
99
+
100
+ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
101
+ model_url = "https://huggingface.co/{}".format(model_name)
102
+ api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
103
+ print("Fetching model from: {}".format(model_url))
104
+
105
+ headers = {"Authorization": f"Bearer {api_key}"} if api_key is not None else {}
106
+
107
+ # Checking if model exists, and if so, it gets the pipeline
108
+ response = requests.request("GET", api_url, headers=headers)
109
+ assert (
110
+ response.status_code == 200
111
+ ), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
112
+ p = response.json().get("pipeline_tag")
113
+ pipelines = {
114
+ "audio-classification": {
115
+ # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
116
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
117
+ "outputs": components.Label(label="Class"),
118
+ "preprocess": lambda i: to_binary,
119
+ "postprocess": lambda r: postprocess_label(
120
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()}
121
+ ),
122
+ },
123
+ "audio-to-audio": {
124
+ # example model: facebook/xm_transformer_sm_all-en
125
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
126
+ "outputs": components.Audio(label="Output"),
127
+ "preprocess": to_binary,
128
+ "postprocess": encode_to_base64,
129
+ },
130
+ "automatic-speech-recognition": {
131
+ # example model: facebook/wav2vec2-base-960h
132
+ "inputs": components.Audio(source="upload", type="filepath", label="Input"),
133
+ "outputs": components.Textbox(label="Output"),
134
+ "preprocess": to_binary,
135
+ "postprocess": lambda r: r.json()["text"],
136
+ },
137
+ "conversational": {
138
+ "inputs": [components.Textbox(), components.State()], # type: ignore
139
+ "outputs": [components.Chatbot(), components.State()], # type: ignore
140
+ "preprocess": chatbot_preprocess,
141
+ "postprocess": chatbot_postprocess,
142
+ },
143
+ "feature-extraction": {
144
+ # example model: julien-c/distilbert-feature-extraction
145
+ "inputs": components.Textbox(label="Input"),
146
+ "outputs": components.Dataframe(label="Output"),
147
+ "preprocess": lambda x: {"inputs": x},
148
+ "postprocess": lambda r: r.json()[0],
149
+ },
150
+ "fill-mask": {
151
+ "inputs": components.Textbox(label="Input"),
152
+ "outputs": components.Label(label="Classification"),
153
+ "preprocess": lambda x: {"inputs": x},
154
+ "postprocess": lambda r: postprocess_label(
155
+ {i["token_str"]: i["score"] for i in r.json()}
156
+ ),
157
+ },
158
+ "image-classification": {
159
+ # Example: google/vit-base-patch16-224
160
+ "inputs": components.Image(type="filepath", label="Input Image"),
161
+ "outputs": components.Label(label="Classification"),
162
+ "preprocess": to_binary,
163
+ "postprocess": lambda r: postprocess_label(
164
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()}
165
+ ),
166
+ },
167
+ "image-to-text": {
168
+ "inputs": components.Image(type="filepath", label="Input Image"),
169
+ "outputs": components.Textbox(),
170
+ "preprocess": to_binary,
171
+ "postprocess": lambda r: r.json()[0]["generated_text"],
172
+ },
173
+ "question-answering": {
174
+ # Example: deepset/xlm-roberta-base-squad2
175
+ "inputs": [
176
+ components.Textbox(lines=7, label="Context"),
177
+ components.Textbox(label="Question"),
178
+ ],
179
+ "outputs": [
180
+ components.Textbox(label="Answer"),
181
+ components.Label(label="Score"),
182
+ ],
183
+ "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
184
+ "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
185
+ },
186
+ "summarization": {
187
+ # Example: facebook/bart-large-cnn
188
+ "inputs": components.Textbox(label="Input"),
189
+ "outputs": components.Textbox(label="Summary"),
190
+ "preprocess": lambda x: {"inputs": x},
191
+ "postprocess": lambda r: r.json()[0]["summary_text"],
192
+ },
193
+ "text-classification": {
194
+ # Example: distilbert-base-uncased-finetuned-sst-2-english
195
+ "inputs": components.Textbox(label="Input"),
196
+ "outputs": components.Label(label="Classification"),
197
+ "preprocess": lambda x: {"inputs": x},
198
+ "postprocess": lambda r: postprocess_label(
199
+ {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
200
+ ),
201
+ },
202
+ "text-generation": {
203
+ # Example: gpt2
204
+ "inputs": components.Textbox(label="Input"),
205
+ "outputs": components.Textbox(label="Output"),
206
+ "preprocess": lambda x: {"inputs": x},
207
+ "postprocess": lambda r: r.json()[0]["generated_text"],
208
+ },
209
+ "text2text-generation": {
210
+ # Example: valhalla/t5-small-qa-qg-hl
211
+ "inputs": components.Textbox(label="Input"),
212
+ "outputs": components.Textbox(label="Generated Text"),
213
+ "preprocess": lambda x: {"inputs": x},
214
+ "postprocess": lambda r: r.json()[0]["generated_text"],
215
+ },
216
+ "translation": {
217
+ "inputs": components.Textbox(label="Input"),
218
+ "outputs": components.Textbox(label="Translation"),
219
+ "preprocess": lambda x: {"inputs": x},
220
+ "postprocess": lambda r: r.json()[0]["translation_text"],
221
+ },
222
+ "zero-shot-classification": {
223
+ # Example: facebook/bart-large-mnli
224
+ "inputs": [
225
+ components.Textbox(label="Input"),
226
+ components.Textbox(label="Possible class names (" "comma-separated)"),
227
+ components.Checkbox(label="Allow multiple true classes"),
228
+ ],
229
+ "outputs": components.Label(label="Classification"),
230
+ "preprocess": lambda i, c, m: {
231
+ "inputs": i,
232
+ "parameters": {"candidate_labels": c, "multi_class": m},
233
+ },
234
+ "postprocess": lambda r: postprocess_label(
235
+ {
236
+ r.json()["labels"][i]: r.json()["scores"][i]
237
+ for i in range(len(r.json()["labels"]))
238
+ }
239
+ ),
240
+ },
241
+ "sentence-similarity": {
242
+ # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
243
+ "inputs": [
244
+ components.Textbox(
245
+ value="That is a happy person", label="Source Sentence"
246
+ ),
247
+ components.Textbox(
248
+ lines=7,
249
+ placeholder="Separate each sentence by a newline",
250
+ label="Sentences to compare to",
251
+ ),
252
+ ],
253
+ "outputs": components.Label(label="Classification"),
254
+ "preprocess": lambda src, sentences: {
255
+ "inputs": {
256
+ "source_sentence": src,
257
+ "sentences": [s for s in sentences.splitlines() if s != ""],
258
+ }
259
+ },
260
+ "postprocess": lambda r: postprocess_label(
261
+ {f"sentence {i}": v for i, v in enumerate(r.json())}
262
+ ),
263
+ },
264
+ "text-to-speech": {
265
+ # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
266
+ "inputs": components.Textbox(label="Input"),
267
+ "outputs": components.Audio(label="Audio"),
268
+ "preprocess": lambda x: {"inputs": x},
269
+ "postprocess": encode_to_base64,
270
+ },
271
+ "text-to-image": {
272
+ # example model: osanseviero/BigGAN-deep-128
273
+ "inputs": components.Textbox(label="Input"),
274
+ "outputs": components.Image(label="Output"),
275
+ "preprocess": lambda x: {"inputs": x},
276
+ "postprocess": encode_to_base64,
277
+ },
278
+ "token-classification": {
279
+ # example model: huggingface-course/bert-finetuned-ner
280
+ "inputs": components.Textbox(label="Input"),
281
+ "outputs": components.HighlightedText(label="Output"),
282
+ "preprocess": lambda x: {"inputs": x},
283
+ "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
284
+ },
285
+ }
286
+
287
+ if p in ["tabular-classification", "tabular-regression"]:
288
+ example_data = get_tabular_examples(model_name)
289
+ col_names, example_data = cols_to_rows(example_data)
290
+ example_data = [[example_data]] if example_data else None
291
+
292
+ pipelines[p] = {
293
+ "inputs": components.Dataframe(
294
+ label="Input Rows",
295
+ type="pandas",
296
+ headers=col_names,
297
+ col_count=(len(col_names), "fixed"),
298
+ ),
299
+ "outputs": components.Dataframe(
300
+ label="Predictions", type="array", headers=["prediction"]
301
+ ),
302
+ "preprocess": rows_to_cols,
303
+ "postprocess": lambda r: {
304
+ "headers": ["prediction"],
305
+ "data": [[pred] for pred in json.loads(r.text)],
306
+ },
307
+ "examples": example_data,
308
+ }
309
+
310
+ if p is None or p not in pipelines:
311
+ raise ValueError("Unsupported pipeline type: {}".format(p))
312
+
313
+ pipeline = pipelines[p]
314
+
315
+ def query_huggingface_api(*params):
316
+ # Convert to a list of input components
317
+ data = pipeline["preprocess"](*params)
318
+ if isinstance(
319
+ data, dict
320
+ ): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
321
+ data.update({"options": {"wait_for_model": True}})
322
+ data = json.dumps(data)
323
+ response = requests.request("POST", api_url, headers=headers, data=data)
324
+ if response.status_code != 200:
325
+ errors_json = response.json()
326
+ errors, warns = "", ""
327
+ if errors_json.get("error"):
328
+ errors = f", Error: {errors_json.get('error')}"
329
+ if errors_json.get("warnings"):
330
+ warns = f", Warnings: {errors_json.get('warnings')}"
331
+ raise Error(
332
+ f"Could not complete request to HuggingFace API, Status Code: {response.status_code}"
333
+ + errors
334
+ + warns
335
+ )
336
+ if (
337
+ p == "token-classification"
338
+ ): # Handle as a special case since HF API only returns the named entities and we need the input as well
339
+ ner_groups = response.json()
340
+ input_string = params[0]
341
+ response = utils.format_ner_list(input_string, ner_groups)
342
+ output = pipeline["postprocess"](response)
343
+ return output
344
+
345
+ if alias is None:
346
+ query_huggingface_api.__name__ = model_name
347
+ else:
348
+ query_huggingface_api.__name__ = alias
349
+
350
+ interface_info = {
351
+ "fn": query_huggingface_api,
352
+ "inputs": pipeline["inputs"],
353
+ "outputs": pipeline["outputs"],
354
+ "title": model_name,
355
+ "examples": pipeline.get("examples"),
356
+ }
357
+
358
+ kwargs = dict(interface_info, **kwargs)
359
+
360
+ # So interface doesn't run pre/postprocess
361
+ # except for conversational interfaces which
362
+ # are stateful
363
+ kwargs["_api_mode"] = p != "conversational"
364
+
365
+ interface = gradio.Interface(**kwargs)
366
+ return interface
367
+
368
+
369
+ def from_spaces(
370
+ space_name: str, api_key: str | None, alias: str | None, **kwargs
371
+ ) -> Blocks:
372
+ space_url = "https://huggingface.co/spaces/{}".format(space_name)
373
+
374
+ print("Fetching Space from: {}".format(space_url))
375
+
376
+ headers = {}
377
+ if api_key is not None:
378
+ headers["Authorization"] = f"Bearer {api_key}"
379
+
380
+ iframe_url = (
381
+ requests.get(
382
+ f"https://huggingface.co/api/spaces/{space_name}/host", headers=headers
383
+ )
384
+ .json()
385
+ .get("host")
386
+ )
387
+
388
+ if iframe_url is None:
389
+ raise ValueError(
390
+ f"Could not find Space: {space_name}. If it is a private or gated Space, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
391
+ )
392
+
393
+ r = requests.get(iframe_url, headers=headers)
394
+
395
+ result = re.search(
396
+ r"window.gradio_config = (.*?);[\s]*</script>", r.text
397
+ ) # some basic regex to extract the config
398
+ try:
399
+ config = json.loads(result.group(1)) # type: ignore
400
+ except AttributeError:
401
+ raise ValueError("Could not load the Space: {}".format(space_name))
402
+ if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
403
+ return from_spaces_interface(
404
+ space_name, config, alias, api_key, iframe_url, **kwargs
405
+ )
406
+ else: # Create a Blocks for Gradio 3.x Spaces
407
+ if kwargs:
408
+ warnings.warn(
409
+ "You cannot override parameters for this Space by passing in kwargs. "
410
+ "Instead, please load the Space as a function and use it to create a "
411
+ "Blocks or Interface locally. You may find this Guide helpful: "
412
+ "https://gradio.app/using_blocks_like_functions/"
413
+ )
414
+ return from_spaces_blocks(config, api_key, iframe_url)
415
+
416
+
417
+ def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Blocks:
418
+ api_url = "{}/api/predict/".format(iframe_url)
419
+
420
+ headers = {"Content-Type": "application/json"}
421
+ if api_key is not None:
422
+ headers["Authorization"] = f"Bearer {api_key}"
423
+ ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss")
424
+
425
+ ws_fn = get_ws_fn(ws_url, headers)
426
+
427
+ fns = []
428
+ for d, dependency in enumerate(config["dependencies"]):
429
+ if dependency["backend_fn"]:
430
+
431
+ def get_fn(outputs, fn_index, use_ws):
432
+ def fn(*data):
433
+ data = json.dumps({"data": data, "fn_index": fn_index})
434
+ hash_data = json.dumps(
435
+ {"fn_index": fn_index, "session_hash": str(uuid.uuid4())}
436
+ )
437
+ if use_ws:
438
+ result = utils.synchronize_async(ws_fn, data, hash_data)
439
+ output = result["data"]
440
+ else:
441
+ response = requests.post(api_url, headers=headers, data=data)
442
+ result = json.loads(response.content.decode("utf-8"))
443
+ try:
444
+ output = result["data"]
445
+ except KeyError:
446
+ if "error" in result and "429" in result["error"]:
447
+ raise TooManyRequestsError(
448
+ "Too many requests to the Hugging Face API"
449
+ )
450
+ raise KeyError(
451
+ f"Could not find 'data' key in response from external Space. Response received: {result}"
452
+ )
453
+ if len(outputs) == 1:
454
+ output = output[0]
455
+ return output
456
+
457
+ return fn
458
+
459
+ fn = get_fn(
460
+ deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
461
+ )
462
+ fns.append(fn)
463
+ else:
464
+ fns.append(None)
465
+ return gradio.Blocks.from_config(config, fns, iframe_url)
466
+
467
+
468
+ def from_spaces_interface(
469
+ model_name: str,
470
+ config: Dict,
471
+ alias: str | None,
472
+ api_key: str | None,
473
+ iframe_url: str,
474
+ **kwargs,
475
+ ) -> Interface:
476
+
477
+ config = streamline_spaces_interface(config)
478
+ api_url = "{}/api/predict/".format(iframe_url)
479
+ headers = {"Content-Type": "application/json"}
480
+ if api_key is not None:
481
+ headers["Authorization"] = f"Bearer {api_key}"
482
+
483
+ # The function should call the API with preprocessed data
484
+ def fn(*data):
485
+ data = json.dumps({"data": data})
486
+ response = requests.post(api_url, headers=headers, data=data)
487
+ result = json.loads(response.content.decode("utf-8"))
488
+ try:
489
+ output = result["data"]
490
+ except KeyError:
491
+ if "error" in result and "429" in result["error"]:
492
+ raise TooManyRequestsError("Too many requests to the Hugging Face API")
493
+ raise KeyError(
494
+ f"Could not find 'data' key in response from external Space. Response received: {result}"
495
+ )
496
+ if (
497
+ len(config["outputs"]) == 1
498
+ ): # if the fn is supposed to return a single value, pop it
499
+ output = output[0]
500
+ if len(config["outputs"]) == 1 and isinstance(
501
+ output, list
502
+ ): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
503
+ output = output[0]
504
+ return output
505
+
506
+ fn.__name__ = alias if (alias is not None) else model_name
507
+ config["fn"] = fn
508
+
509
+ kwargs = dict(config, **kwargs)
510
+ kwargs["_api_mode"] = True
511
+ interface = gradio.Interface(**kwargs)
512
+ return interface
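
A note on usage: as the module docstring above says, `external.py` is not meant to be imported directly; the supported entry points are `gr.Interface.load()` and `gr.Blocks.load()`, which dispatch to `from_model()` or `from_spaces()` depending on whether the name starts with `models/` (or `huggingface/`) or `spaces/`. A minimal sketch, with the Space name left as a placeholder:

    import gradio as gr

    # "models/..." routes through from_model() and the Hugging Face Inference API
    demo = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")

    # "spaces/..." would route through from_spaces() instead, e.g.:
    # demo = gr.Blocks.load(name="spaces/<user>/<space-name>")

    demo.launch()
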
gradio/external_utils.py ADDED
@@ -0,0 +1,185 @@
1
+ """Utility function for gradio/external.py"""
2
+
3
+ import base64
4
+ import json
5
+ import math
6
+ import operator
7
+ import re
8
+ import warnings
9
+ from typing import Any, Dict, List, Tuple
10
+
11
+ import requests
12
+ import websockets
13
+ import yaml
14
+ from packaging import version
15
+ from websockets.legacy.protocol import WebSocketCommonProtocol
16
+
17
+ from gradio import components, exceptions
18
+
19
+ ##################
20
+ # Helper functions for processing tabular data
21
+ ##################
22
+
23
+
24
+ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
25
+ readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
26
+ if readme.status_code != 200:
27
+ warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
28
+ example_data = {}
29
+ else:
30
+ yaml_regex = re.search(
31
+ "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
32
+ )
33
+ if yaml_regex is None:
34
+ example_data = {}
35
+ else:
36
+ example_yaml = next(
37
+ yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
38
+ )
39
+ example_data = example_yaml.get("widget", {}).get("structuredData", {})
40
+ if not example_data:
41
+ raise ValueError(
42
+ f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
43
+ "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
44
+ "for a reference on how to provide example data to your model."
45
+ )
46
+ # replace nan with string NaN for inference API
47
+ for data in example_data.values():
48
+ for i, val in enumerate(data):
49
+ if isinstance(val, float) and math.isnan(val):
50
+ data[i] = "NaN"
51
+ return example_data
52
+
53
+
54
+ def cols_to_rows(
55
+ example_data: Dict[str, List[float]]
56
+ ) -> Tuple[List[str], List[List[float]]]:
57
+ headers = list(example_data.keys())
58
+ n_rows = max(len(example_data[header] or []) for header in headers)
59
+ data = []
60
+ for row_index in range(n_rows):
61
+ row_data = []
62
+ for header in headers:
63
+ col = example_data[header] or []
64
+ if row_index >= len(col):
65
+ row_data.append("NaN")
66
+ else:
67
+ row_data.append(col[row_index])
68
+ data.append(row_data)
69
+ return headers, data
70
+
71
+
72
+ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
73
+ data_column_wise = {}
74
+ for i, header in enumerate(incoming_data["headers"]):
75
+ data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
76
+ return {"inputs": {"data": data_column_wise}}
77
+
78
+
79
+ ##################
80
+ # Helper functions for processing other kinds of data
81
+ ##################
82
+
83
+
84
+ def postprocess_label(scores: Dict) -> Dict:
85
+ sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
86
+ return {
87
+ "label": sorted_pred[0][0],
88
+ "confidences": [
89
+ {"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
90
+ ],
91
+ }
92
+
93
+
94
+ def encode_to_base64(r: requests.Response) -> str:
95
+ # Handles the different ways HF API returns the prediction
96
+ base64_repr = base64.b64encode(r.content).decode("utf-8")
97
+ data_prefix = ";base64,"
98
+ # Case 1: base64 representation already includes data prefix
99
+ if data_prefix in base64_repr:
100
+ return base64_repr
101
+ else:
102
+ content_type = r.headers.get("content-type")
103
+ # Case 2: the data prefix is a key in the response
104
+ if content_type == "application/json":
105
+ try:
106
+ content_type = r.json()[0]["content-type"]
107
+ base64_repr = r.json()[0]["blob"]
108
+ except KeyError:
109
+ raise ValueError(
110
+ "Cannot determine content type returned" "by external API."
111
+ )
112
+ # Case 3: the data prefix is included in the response headers
113
+ else:
114
+ pass
115
+ new_base64 = "data:{};base64,".format(content_type) + base64_repr
116
+ return new_base64
117
+
118
+
119
+ ##################
120
+ # Helper functions for connecting to websockets
121
+ ##################
122
+
123
+
124
+ async def get_pred_from_ws(
125
+ websocket: WebSocketCommonProtocol, data: str, hash_data: str
126
+ ) -> Dict[str, Any]:
127
+ completed = False
128
+ resp = {}
129
+ while not completed:
130
+ msg = await websocket.recv()
131
+ resp = json.loads(msg)
132
+ if resp["msg"] == "queue_full":
133
+ raise exceptions.Error("Queue is full! Please try again.")
134
+ if resp["msg"] == "send_hash":
135
+ await websocket.send(hash_data)
136
+ elif resp["msg"] == "send_data":
137
+ await websocket.send(data)
138
+ completed = resp["msg"] == "process_completed"
139
+ return resp["output"]
140
+
141
+
142
+ def get_ws_fn(ws_url, headers):
143
+ async def ws_fn(data, hash_data):
144
+ async with websockets.connect( # type: ignore
145
+ ws_url, open_timeout=10, extra_headers=headers
146
+ ) as websocket:
147
+ return await get_pred_from_ws(websocket, data, hash_data)
148
+
149
+ return ws_fn
150
+
151
+
152
+ def use_websocket(config, dependency):
153
+ queue_enabled = config.get("enable_queue", False)
154
+ queue_uses_websocket = version.parse(
155
+ config.get("version", "2.0")
156
+ ) >= version.Version("3.2")
157
+ dependency_uses_queue = dependency.get("queue", False) is not False
158
+ return queue_enabled and queue_uses_websocket and dependency_uses_queue
159
+
160
+
161
+ ##################
162
+ # Helper function for cleaning up an Interface loaded from HF Spaces
163
+ ##################
164
+
165
+
166
+ def streamline_spaces_interface(config: Dict) -> Dict:
167
+ """Streamlines the interface config dictionary to remove unnecessary keys."""
168
+ config["inputs"] = [
169
+ components.get_component_instance(component)
170
+ for component in config["input_components"]
171
+ ]
172
+ config["outputs"] = [
173
+ components.get_component_instance(component)
174
+ for component in config["output_components"]
175
+ ]
176
+ parameters = {
177
+ "article",
178
+ "description",
179
+ "flagging_options",
180
+ "inputs",
181
+ "outputs",
182
+ "title",
183
+ }
184
+ config = {k: config[k] for k in parameters}
185
+ return config
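
For reference, a short sketch of how the tabular and label helpers above reshape data; the values below are made up for illustration:

    from gradio.external_utils import cols_to_rows, postprocess_label, rows_to_cols

    example = {"age": [39, 50], "hours-per-week": [40, 13]}
    headers, rows = cols_to_rows(example)
    # headers == ["age", "hours-per-week"]; rows == [[39, 40], [50, 13]]

    payload = rows_to_cols({"headers": headers, "data": rows})
    # payload == {"inputs": {"data": {"age": ["39", "50"], "hours-per-week": ["40", "13"]}}}

    postprocess_label({"cat": 0.3, "dog": 0.7})
    # == {"label": "dog", "confidences": [{"label": "dog", "confidence": 0.7},
    #                                     {"label": "cat", "confidence": 0.3}]}
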
gradio/flagging.py ADDED
@@ -0,0 +1,555 @@
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import datetime
5
+ import json
6
+ import os
7
+ import time
8
+ import uuid
9
+ from abc import ABC, abstractmethod
10
+ from distutils.version import StrictVersion
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, List
13
+
14
+ import pkg_resources
15
+
16
+ import gradio as gr
17
+ from gradio import utils
18
+ from gradio.documentation import document, set_documentation_group
19
+
20
+ if TYPE_CHECKING:
21
+ from gradio.components import IOComponent
22
+
23
+ set_documentation_group("flagging")
24
+
25
+
26
+ def _get_dataset_features_info(is_new, components):
27
+ """
28
+ Takes in a list of components and returns a dataset features info
29
+
30
+ Parameters:
31
+ is_new: boolean, whether the dataset is new or not
32
+ components: list of components
33
+
34
+ Returns:
35
+ infos: a dictionary of the dataset features
36
+ file_preview_types: dictionary mapping of gradio components to appropriate string.
37
+ headers: list of header strings
38
+
39
+ """
40
+ infos = {"flagged": {"features": {}}}
41
+ # File previews for certain input and output types
42
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
43
+ headers = []
44
+
45
+ # Generate the headers and dataset_infos
46
+ if is_new:
47
+
48
+ for component in components:
49
+ headers.append(component.label)
50
+ infos["flagged"]["features"][component.label] = {
51
+ "dtype": "string",
52
+ "_type": "Value",
53
+ }
54
+ if isinstance(component, tuple(file_preview_types)):
55
+ headers.append(component.label + " file")
56
+ for _component, _type in file_preview_types.items():
57
+ if isinstance(component, _component):
58
+ infos["flagged"]["features"][
59
+ (component.label or "") + " file"
60
+ ] = {"_type": _type}
61
+ break
62
+
63
+ headers.append("flag")
64
+ infos["flagged"]["features"]["flag"] = {
65
+ "dtype": "string",
66
+ "_type": "Value",
67
+ }
68
+
69
+ return infos, file_preview_types, headers
70
+
71
+
72
+ class FlaggingCallback(ABC):
73
+ """
74
+ An abstract class for defining the methods that any FlaggingCallback should have.
75
+ """
76
+
77
+ @abstractmethod
78
+ def setup(self, components: List[IOComponent], flagging_dir: str):
79
+ """
80
+ This method should be overridden and ensure that everything is set up correctly for flag().
81
+ This method gets called once at the beginning of the Interface.launch() method.
82
+ Parameters:
83
+ components: Set of components that will provide flagged data.
84
+ flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
85
+ """
86
+ pass
87
+
88
+ @abstractmethod
89
+ def flag(
90
+ self,
91
+ flag_data: List[Any],
92
+ flag_option: str = "",
93
+ username: str | None = None,
94
+ ) -> int:
95
+ """
96
+ This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
97
+ This gets called every time the <flag> button is pressed.
98
+ Parameters:
99
+ interface: The Interface object that is being used to launch the flagging interface.
100
+ flag_data: The data to be flagged.
101
+ flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
102
+ username (optional): The username of the user that is flagging the data, if logged in.
103
+ Returns:
104
+ (int) The total number of samples that have been flagged.
105
+ """
106
+ pass
107
+
108
+
109
+ @document()
110
+ class SimpleCSVLogger(FlaggingCallback):
111
+ """
112
+ A simplified implementation of the FlaggingCallback abstract class
113
+ provided for illustrative purposes. Each flagged sample (both the input and output data)
114
+ is logged to a CSV file on the machine running the gradio app.
115
+ Example:
116
+ import gradio as gr
117
+ def image_classifier(inp):
118
+ return {'cat': 0.3, 'dog': 0.7}
119
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
120
+ flagging_callback=SimpleCSVLogger())
121
+ """
122
+
123
+ def __init__(self):
124
+ pass
125
+
126
+ def setup(self, components: List[IOComponent], flagging_dir: str | Path):
127
+ self.components = components
128
+ self.flagging_dir = flagging_dir
129
+ os.makedirs(flagging_dir, exist_ok=True)
130
+
131
+ def flag(
132
+ self,
133
+ flag_data: List[Any],
134
+ flag_option: str = "",
135
+ username: str | None = None,
136
+ ) -> int:
137
+ flagging_dir = self.flagging_dir
138
+ log_filepath = Path(flagging_dir) / "log.csv"
139
+
140
+ csv_data = []
141
+ for component, sample in zip(self.components, flag_data):
142
+ save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
143
+ component.label or ""
144
+ )
145
+ csv_data.append(
146
+ component.deserialize(
147
+ sample,
148
+ save_dir,
149
+ None,
150
+ )
151
+ )
152
+
153
+ with open(log_filepath, "a", newline="") as csvfile:
154
+ writer = csv.writer(csvfile)
155
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
156
+
157
+ with open(log_filepath, "r") as csvfile:
158
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
159
+ return line_count
160
+
161
+
162
+ @document()
163
+ class CSVLogger(FlaggingCallback):
164
+ """
165
+ The default implementation of the FlaggingCallback abstract class. Each flagged
166
+ sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
167
+ Example:
168
+ import gradio as gr
169
+ def image_classifier(inp):
170
+ return {'cat': 0.3, 'dog': 0.7}
171
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
172
+ flagging_callback=CSVLogger())
173
+ Guides: using_flagging
174
+ """
175
+
176
+ def __init__(self):
177
+ pass
178
+
179
+ def setup(
180
+ self,
181
+ components: List[IOComponent],
182
+ flagging_dir: str | Path,
183
+ ):
184
+ self.components = components
185
+ self.flagging_dir = flagging_dir
186
+ os.makedirs(flagging_dir, exist_ok=True)
187
+
188
+ def flag(
189
+ self,
190
+ flag_data: List[Any],
191
+ flag_option: str = "",
192
+ username: str | None = None,
193
+ ) -> int:
194
+ flagging_dir = self.flagging_dir
195
+ log_filepath = Path(flagging_dir) / "log.csv"
196
+ is_new = not Path(log_filepath).exists()
197
+ headers = [
198
+ getattr(component, "label", None) or f"component {idx}"
199
+ for idx, component in enumerate(self.components)
200
+ ] + [
201
+ "flag",
202
+ "username",
203
+ "timestamp",
204
+ ]
205
+
206
+ csv_data = []
207
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
208
+ save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
209
+ getattr(component, "label", None) or f"component {idx}"
210
+ )
211
+ if utils.is_update(sample):
212
+ csv_data.append(str(sample))
213
+ else:
214
+ csv_data.append(
215
+ component.deserialize(sample, save_dir=save_dir)
216
+ if sample is not None
217
+ else ""
218
+ )
219
+ csv_data.append(flag_option)
220
+ csv_data.append(username if username is not None else "")
221
+ csv_data.append(str(datetime.datetime.now()))
222
+
223
+ with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
224
+ writer = csv.writer(csvfile)
225
+ if is_new:
226
+ writer.writerow(utils.sanitize_list_for_csv(headers))
227
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
228
+
229
+ with open(log_filepath, "r", encoding="utf-8") as csvfile:
230
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
231
+ return line_count
232
+
233
+
234
+ @document()
235
+ class HuggingFaceDatasetSaver(FlaggingCallback):
236
+ """
237
+ A callback that saves each flagged sample (both the input and output data)
238
+ to a HuggingFace dataset.
239
+ Example:
240
+ import gradio as gr
241
+ hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
242
+ def image_classifier(inp):
243
+ return {'cat': 0.3, 'dog': 0.7}
244
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
245
+ allow_flagging="manual", flagging_callback=hf_writer)
246
+ Guides: using_flagging
247
+ """
248
+
249
+ def __init__(
250
+ self,
251
+ hf_token: str,
252
+ dataset_name: str,
253
+ organization: str | None = None,
254
+ private: bool = False,
255
+ ):
256
+ """
257
+ Parameters:
258
+ hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
259
+ dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
260
+ organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
261
+ private: Whether the dataset should be private (defaults to False).
262
+ """
263
+ self.hf_token = hf_token
264
+ self.dataset_name = dataset_name
265
+ self.organization_name = organization
266
+ self.dataset_private = private
267
+
268
+ def setup(self, components: List[IOComponent], flagging_dir: str):
269
+ """
270
+ Params:
271
+ flagging_dir (str): local directory where the dataset is cloned,
272
+ updated, and pushed from.
273
+ """
274
+ try:
275
+ import huggingface_hub
276
+ except (ImportError, ModuleNotFoundError):
277
+ raise ImportError(
278
+ "Package `huggingface_hub` not found is needed "
279
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
280
+ )
281
+ hh_version = pkg_resources.get_distribution("huggingface_hub").version
282
+ try:
283
+ if StrictVersion(hh_version) < StrictVersion("0.6.0"):
284
+ raise ImportError(
285
+ "The `huggingface_hub` package must be version 0.6.0 or higher"
286
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
287
+ )
288
+ except ValueError:
289
+ pass
290
+ repo_id = huggingface_hub.get_full_repo_name(
291
+ self.dataset_name, token=self.hf_token
292
+ )
293
+ path_to_dataset_repo = huggingface_hub.create_repo(
294
+ repo_id=repo_id,
295
+ token=self.hf_token,
296
+ private=self.dataset_private,
297
+ repo_type="dataset",
298
+ exist_ok=True,
299
+ )
300
+ self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
301
+ self.components = components
302
+ self.flagging_dir = flagging_dir
303
+ self.dataset_dir = Path(flagging_dir) / self.dataset_name
304
+ self.repo = huggingface_hub.Repository(
305
+ local_dir=str(self.dataset_dir),
306
+ clone_from=path_to_dataset_repo,
307
+ use_auth_token=self.hf_token,
308
+ )
309
+ self.repo.git_pull(lfs=True)
310
+
311
+ # Should filename be user-specified?
312
+ self.log_file = Path(self.dataset_dir) / "data.csv"
313
+ self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
314
+
315
+ def flag(
316
+ self,
317
+ flag_data: List[Any],
318
+ flag_option: str = "",
319
+ username: str | None = None,
320
+ ) -> int:
321
+ self.repo.git_pull(lfs=True)
322
+
323
+ is_new = not Path(self.log_file).exists()
324
+
325
+ with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
326
+ writer = csv.writer(csvfile)
327
+
328
+ # File previews for certain input and output types
329
+ infos, file_preview_types, headers = _get_dataset_features_info(
330
+ is_new, self.components
331
+ )
332
+
333
+ # Generate the headers and dataset_infos
334
+ if is_new:
335
+ writer.writerow(utils.sanitize_list_for_csv(headers))
336
+
337
+ # Generate the row corresponding to the flagged sample
338
+ csv_data = []
339
+ for component, sample in zip(self.components, flag_data):
340
+ save_dir = Path(
341
+ self.dataset_dir
342
+ ) / utils.strip_invalid_filename_characters(component.label or "")
343
+ filepath = component.deserialize(sample, save_dir, None)
344
+ csv_data.append(filepath)
345
+ if isinstance(component, tuple(file_preview_types)):
346
+ csv_data.append(
347
+ "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
348
+ )
349
+ csv_data.append(flag_option)
350
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
351
+
352
+ if is_new:
353
+ json.dump(infos, open(self.infos_file, "w"))
354
+
355
+ with open(self.log_file, "r", encoding="utf-8") as csvfile:
356
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
357
+
358
+ self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
359
+
360
+ return line_count
361
+
362
+
363
+ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
364
+ """
365
+ A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format.
366
+
367
+ Each data sample is saved in a different JSONL file,
368
+ allowing multiple users to use flagging simultaneously.
369
+ Saving to a single CSV would cause errors as only one user can edit at the same time.
370
+
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ hf_token: str,
376
+ dataset_name: str,
377
+ organization: str | None = None,
378
+ private: bool = False,
379
+ verbose: bool = True,
380
+ ):
381
+ """
382
+ Params:
383
+ hf_token (str): The token to use to access the huggingface API.
384
+ dataset_name (str): The name of the dataset to save the data to, e.g.
385
+ "image-classifier-1"
386
+ organization (str): The name of the organization to which to attach
387
+ the datasets. If None, the dataset attaches to the user only.
388
+ private (bool): If the dataset does not already exist, whether it
389
+ should be created as a private dataset or public. Private datasets
390
+ may require paid huggingface.co accounts
391
+ verbose (bool): Whether to print out the status of the dataset
392
+ creation.
393
+ """
394
+ self.hf_token = hf_token
395
+ self.dataset_name = dataset_name
396
+ self.organization_name = organization
397
+ self.dataset_private = private
398
+ self.verbose = verbose
399
+
400
+ def setup(self, components: List[IOComponent], flagging_dir: str):
401
+ """
402
+ Params:
403
+ components List[Component]: list of components for flagging
404
+ flagging_dir (str): local directory where the dataset is cloned,
405
+ updated, and pushed from.
406
+ """
407
+ try:
408
+ import huggingface_hub
409
+ except (ImportError, ModuleNotFoundError):
410
+ raise ImportError(
411
+ "Package `huggingface_hub` not found is needed "
412
+ "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
413
+ )
414
+ hh_version = pkg_resources.get_distribution("huggingface_hub").version
415
+ try:
416
+ if StrictVersion(hh_version) < StrictVersion("0.6.0"):
417
+ raise ImportError(
418
+ "The `huggingface_hub` package must be version 0.6.0 or higher"
419
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
420
+ )
421
+ except ValueError:
422
+ pass
423
+ repo_id = huggingface_hub.get_full_repo_name(
424
+ self.dataset_name, token=self.hf_token
425
+ )
426
+ path_to_dataset_repo = huggingface_hub.create_repo(
427
+ repo_id=repo_id,
428
+ token=self.hf_token,
429
+ private=self.dataset_private,
430
+ repo_type="dataset",
431
+ exist_ok=True,
432
+ )
433
+ self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
434
+ self.components = components
435
+ self.flagging_dir = flagging_dir
436
+ self.dataset_dir = Path(flagging_dir) / self.dataset_name
437
+ self.repo = huggingface_hub.Repository(
438
+ local_dir=str(self.dataset_dir),
439
+ clone_from=path_to_dataset_repo,
440
+ use_auth_token=self.hf_token,
441
+ )
442
+ self.repo.git_pull(lfs=True)
443
+
444
+ self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
445
+
446
+ def flag(
447
+ self,
448
+ flag_data: List[Any],
449
+ flag_option: str = "",
450
+ username: str | None = None,
451
+ ) -> str:
452
+ self.repo.git_pull(lfs=True)
453
+
454
+ # Generate unique folder for the flagged sample
455
+ unique_name = self.get_unique_name() # unique name for folder
456
+ folder_name = (
457
+ Path(self.dataset_dir) / unique_name
458
+ ) # unique folder for specific example
459
+ os.makedirs(folder_name)
460
+
461
+ # Now uses the existence of `dataset_infos.json` to determine if new
462
+ is_new = not Path(self.infos_file).exists()
463
+
464
+ # File previews for certain input and output types
465
+ infos, file_preview_types, _ = _get_dataset_features_info(
466
+ is_new, self.components
467
+ )
468
+
469
+ # Generate the row and header corresponding to the flagged sample
470
+ csv_data = []
471
+ headers = []
472
+
473
+ for component, sample in zip(self.components, flag_data):
474
+ headers.append(component.label)
475
+
476
+ try:
477
+ save_dir = Path(folder_name) / utils.strip_invalid_filename_characters(
478
+ component.label or ""
479
+ )
480
+ filepath = component.deserialize(sample, save_dir, None)
481
+ except Exception:
482
+ # Could not parse 'sample' (mostly) because it was None and `component.deserialize`
483
+ # does not handle None cases.
484
+ # for example: Label (line 3109 of components.py raises an error if data is None)
485
+ filepath = None
486
+
487
+ if isinstance(component, tuple(file_preview_types)):
488
+ headers.append((component.label or "") + " file")
489
+
490
+ csv_data.append(
491
+ "{}/resolve/main/{}/{}".format(
492
+ self.path_to_dataset_repo, unique_name, filepath
493
+ )
494
+ if filepath is not None
495
+ else None
496
+ )
497
+
498
+ csv_data.append(filepath)
499
+ headers.append("flag")
500
+ csv_data.append(flag_option)
501
+
502
+ # Creates metadata dict from row data and dumps it
503
+ metadata_dict = {
504
+ header: _csv_data for header, _csv_data in zip(headers, csv_data)
505
+ }
506
+ self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
507
+
508
+ if is_new:
509
+ json.dump(infos, open(self.infos_file, "w"))
510
+
511
+ self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
512
+ return unique_name
513
+
514
+ def get_unique_name(self):
515
+ id = uuid.uuid4()
516
+ return str(id)
517
+
518
+ def dump_json(self, thing: dict, file_path: str | Path) -> None:
519
+ with open(file_path, "w+", encoding="utf8") as f:
520
+ json.dump(thing, f)
521
+
522
+
523
+ class FlagMethod:
524
+ """
525
+ Helper class that contains the flagging options and calls the flagging method. Also
526
+ provides visual feedback to the user when flag is clicked.
527
+ """
528
+
529
+ def __init__(
530
+ self,
531
+ flagging_callback: FlaggingCallback,
532
+ label: str,
533
+ value: str,
534
+ visual_feedback: bool = True,
535
+ ):
536
+ self.flagging_callback = flagging_callback
537
+ self.label = label
538
+ self.value = value
539
+ self.__name__ = "Flag"
540
+ self.visual_feedback = visual_feedback
541
+
542
+ def __call__(self, *flag_data):
543
+ try:
544
+ self.flagging_callback.flag(list(flag_data), flag_option=self.value)
545
+ except Exception as e:
546
+ print("Error while flagging: {}".format(e))
547
+ if self.visual_feedback:
548
+ return "Error!"
549
+ if not self.visual_feedback:
550
+ return
551
+ time.sleep(0.8) # to provide enough time for the user to observe button change
552
+ return self.reset()
553
+
554
+ def reset(self):
555
+ return gr.Button.update(value=self.label, interactive=True)
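
Since the abstract `FlaggingCallback` above only requires `setup()` and `flag()`, a custom logger can be passed to `gr.Interface(..., flagging_callback=...)` in the same way as the built-in ones. A minimal sketch of such a subclass (illustrative only; it assumes every flagged value is JSON-serializable):

    from __future__ import annotations

    import json
    from pathlib import Path
    from typing import Any, List

    from gradio.flagging import FlaggingCallback


    class JSONLinesLogger(FlaggingCallback):
        """Appends each flagged sample to a local JSONL file."""

        def setup(self, components, flagging_dir: str):
            self.components = components
            self.log_file = Path(flagging_dir) / "flagged.jsonl"
            self.log_file.parent.mkdir(parents=True, exist_ok=True)

        def flag(
            self,
            flag_data: List[Any],
            flag_option: str = "",
            username: str | None = None,
        ) -> int:
            # Append one JSON record per flagged sample
            record = {"data": flag_data, "flag": flag_option, "username": username or ""}
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(record) + "\n")
            # Return the running count of flagged samples, as the base class expects
            with open(self.log_file, encoding="utf-8") as f:
                return sum(1 for _ in f)
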
gradio/helpers.py ADDED
@@ -0,0 +1,839 @@
1
+ """
2
+ Defines helper methods useful for loading and caching Interface examples.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import csv
8
+ import inspect
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import threading
13
+ import warnings
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
16
+
17
+ import matplotlib
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import PIL
21
+ import PIL.Image
22
+
23
+ from gradio import processing_utils, routes, utils
24
+ from gradio.context import Context
25
+ from gradio.documentation import document, set_documentation_group
26
+ from gradio.flagging import CSVLogger
27
+
28
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
29
+ from gradio.blocks import Block
30
+ from gradio.components import IOComponent
31
+
32
+ CACHED_FOLDER = "gradio_cached_examples"
33
+ LOG_FILE = "log.csv"
34
+
35
+ set_documentation_group("helpers")
36
+
37
+
38
+ def create_examples(
39
+ examples: List[Any] | List[List[Any]] | str,
40
+ inputs: IOComponent | List[IOComponent],
41
+ outputs: IOComponent | List[IOComponent] | None = None,
42
+ fn: Callable | None = None,
43
+ cache_examples: bool = False,
44
+ examples_per_page: int = 10,
45
+ _api_mode: bool = False,
46
+ label: str | None = None,
47
+ elem_id: str | None = None,
48
+ run_on_click: bool = False,
49
+ preprocess: bool = True,
50
+ postprocess: bool = True,
51
+ batch: bool = False,
52
+ ):
53
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
54
+ examples_obj = Examples(
55
+ examples=examples,
56
+ inputs=inputs,
57
+ outputs=outputs,
58
+ fn=fn,
59
+ cache_examples=cache_examples,
60
+ examples_per_page=examples_per_page,
61
+ _api_mode=_api_mode,
62
+ label=label,
63
+ elem_id=elem_id,
64
+ run_on_click=run_on_click,
65
+ preprocess=preprocess,
66
+ postprocess=postprocess,
67
+ batch=batch,
68
+ _initiated_directly=False,
69
+ )
70
+ utils.synchronize_async(examples_obj.create)
71
+ return examples_obj
72
+
73
+
74
+ @document()
75
+ class Examples:
76
+ """
77
+ This class is a wrapper over the Dataset component and can be used to create Examples
78
+ for Blocks / Interfaces. Populates the Dataset component with examples and
79
+ assigns event listener so that clicking on an example populates the input/output
80
+ components. Optionally handles example caching for fast inference.
81
+
82
+ Demos: blocks_inputs, fake_gan
83
+ Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ examples: List[Any] | List[List[Any]] | str,
89
+ inputs: IOComponent | List[IOComponent],
90
+ outputs: IOComponent | List[IOComponent] | None = None,
91
+ fn: Callable | None = None,
92
+ cache_examples: bool = False,
93
+ examples_per_page: int = 10,
94
+ _api_mode: bool = False,
95
+ label: str | None = "Examples",
96
+ elem_id: str | None = None,
97
+ run_on_click: bool = False,
98
+ preprocess: bool = True,
99
+ postprocess: bool = True,
100
+ batch: bool = False,
101
+ _initiated_directly: bool = True,
102
+ ):
103
+ """
104
+ Parameters:
105
+ examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
106
+ inputs: the component or list of components corresponding to the examples
107
+ outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
108
+ fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
109
+ cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
110
+ examples_per_page: how many examples to show per page.
111
+ label: the label to use for the examples component (by default, "Examples")
112
+ elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
113
+ run_on_click: if cache_examples is False, clicking on an example does not run the function by default. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
114
+ preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
115
+ postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
116
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
117
+ """
118
+ if _initiated_directly:
119
+ warnings.warn(
120
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
121
+ )
122
+
123
+ if cache_examples and (fn is None or outputs is None):
124
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
125
+
126
+ if not isinstance(inputs, list):
127
+ inputs = [inputs]
128
+ if outputs and not isinstance(outputs, list):
129
+ outputs = [outputs]
130
+
131
+ working_directory = Path().absolute()
132
+
133
+ if examples is None:
134
+ raise ValueError("The parameter `examples` cannot be None")
135
+ elif isinstance(examples, list) and (
136
+ len(examples) == 0 or isinstance(examples[0], list)
137
+ ):
138
+ pass
139
+ elif (
140
+ isinstance(examples, list) and len(inputs) == 1
141
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
142
+ examples = [[e] for e in examples]
143
+ elif isinstance(examples, str):
144
+ if not Path(examples).exists():
145
+ raise FileNotFoundError(
146
+ "Could not find examples directory: " + examples
147
+ )
148
+ working_directory = examples
149
+ if not (Path(examples) / LOG_FILE).exists():
150
+ if len(inputs) == 1:
151
+ examples = [[e] for e in os.listdir(examples)]
152
+ else:
153
+ raise FileNotFoundError(
154
+ "Could not find log file (required for multiple inputs): "
155
+ + LOG_FILE
156
+ )
157
+ else:
158
+ with open(Path(examples) / LOG_FILE) as logs:
159
+ examples = list(csv.reader(logs))
160
+ examples = [
161
+ examples[i][: len(inputs)] for i in range(1, len(examples))
162
+ ] # remove header and unnecessary columns
163
+
164
+ else:
165
+ raise ValueError(
166
+ "The parameter `examples` must either be a string directory or a list"
167
+ "(if there is only 1 input component) or (more generally), a nested "
168
+ "list, where each sublist represents a set of inputs."
169
+ )
170
+
171
+ input_has_examples = [False] * len(inputs)
172
+ for example in examples:
173
+ for idx, example_for_input in enumerate(example):
174
+ if not (example_for_input is None):
175
+ try:
176
+ input_has_examples[idx] = True
177
+ except IndexError:
178
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
179
+
180
+ inputs_with_examples = [
181
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
182
+ ]
183
+ non_none_examples = [
184
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
185
+ for example in examples
186
+ ]
187
+
188
+ self.examples = examples
189
+ self.non_none_examples = non_none_examples
190
+ self.inputs = inputs
191
+ self.inputs_with_examples = inputs_with_examples
192
+ self.outputs = outputs
193
+ self.fn = fn
194
+ self.cache_examples = cache_examples
195
+ self._api_mode = _api_mode
196
+ self.preprocess = preprocess
197
+ self.postprocess = postprocess
198
+ self.batch = batch
199
+
200
+ with utils.set_directory(working_directory):
201
+ self.processed_examples = [
202
+ [
203
+ component.postprocess(sample)
204
+ for component, sample in zip(inputs, example)
205
+ ]
206
+ for example in examples
207
+ ]
208
+ self.non_none_processed_examples = [
209
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
210
+ for example in self.processed_examples
211
+ ]
212
+ if cache_examples:
213
+ for example in self.examples:
214
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
215
+ warnings.warn(
216
+ "Examples are being cached but not all input components have "
217
+ "example values. This may result in an exception being thrown by "
218
+ "your function. If you do get an error while caching examples, make "
219
+ "sure all of your inputs have example values for all of your examples "
220
+ "or you provide default values for those particular parameters in your function."
221
+ )
222
+ break
223
+
224
+ from gradio import components
225
+
226
+ with utils.set_directory(working_directory):
227
+ self.dataset = components.Dataset(
228
+ components=inputs_with_examples,
229
+ samples=non_none_examples,
230
+ type="index",
231
+ label=label,
232
+ samples_per_page=examples_per_page,
233
+ elem_id=elem_id,
234
+ )
235
+
236
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
237
+ self.cached_file = Path(self.cached_folder) / "log.csv"
238
+ self.cache_examples = cache_examples
239
+ self.run_on_click = run_on_click
240
+
241
+ async def create(self) -> None:
242
+ """Caches the examples if self.cache_examples is True and creates the Dataset
243
+ component to hold the examples"""
244
+
245
+ async def load_example(example_id):
246
+ if self.cache_examples:
247
+ processed_example = self.non_none_processed_examples[
248
+ example_id
249
+ ] + await self.load_from_cache(example_id)
250
+ else:
251
+ processed_example = self.non_none_processed_examples[example_id]
252
+ return utils.resolve_singleton(processed_example)
253
+
254
+ if Context.root_block:
255
+ if self.cache_examples and self.outputs:
256
+ targets = self.inputs_with_examples + self.outputs
257
+ else:
258
+ targets = self.inputs_with_examples
259
+ self.dataset.click(
260
+ load_example,
261
+ inputs=[self.dataset],
262
+ outputs=targets, # type: ignore
263
+ show_progress=False,
264
+ postprocess=False,
265
+ queue=False,
266
+ )
267
+ if self.run_on_click and not self.cache_examples:
268
+ if self.fn is None:
269
+ raise ValueError("Cannot run_on_click if no function is provided")
270
+ self.dataset.click(
271
+ self.fn,
272
+ inputs=self.inputs, # type: ignore
273
+ outputs=self.outputs, # type: ignore
274
+ )
275
+
276
+ if self.cache_examples:
277
+ await self.cache()
278
+
279
+ async def cache(self) -> None:
280
+ """
281
+ Caches all of the examples so that their predictions can be shown immediately.
282
+ """
283
+ if Path(self.cached_file).exists():
284
+ print(
285
+ f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
286
+ )
287
+ else:
288
+ if Context.root_block is None:
289
+ raise ValueError("Cannot cache examples if not in a Blocks context")
290
+
291
+ print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
292
+ cache_logger = CSVLogger()
293
+
294
+ # create a fake dependency to process the examples and get the predictions
295
+ dependency, fn_index = Context.root_block.set_event_trigger(
296
+ event_name="fake_event",
297
+ fn=self.fn,
298
+ inputs=self.inputs_with_examples, # type: ignore
299
+ outputs=self.outputs, # type: ignore
300
+ preprocess=self.preprocess and not self._api_mode,
301
+ postprocess=self.postprocess and not self._api_mode,
302
+ batch=self.batch,
303
+ )
304
+
305
+ assert self.outputs is not None
306
+ cache_logger.setup(self.outputs, self.cached_folder)
307
+ for example_id, _ in enumerate(self.examples):
308
+ processed_input = self.processed_examples[example_id]
309
+ if self.batch:
310
+ processed_input = [[value] for value in processed_input]
311
+ prediction = await Context.root_block.process_api(
312
+ fn_index=fn_index, inputs=processed_input, request=None, state={}
313
+ )
314
+ output = prediction["data"]
315
+ if self.batch:
316
+ output = [value[0] for value in output]
317
+ cache_logger.flag(output)
318
+ # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
319
+ Context.root_block.dependencies.remove(dependency)
320
+ Context.root_block.fns.pop(fn_index)
321
+
322
+ async def load_from_cache(self, example_id: int) -> List[Any]:
323
+ """Loads a particular cached example for the interface.
324
+ Parameters:
325
+ example_id: The id of the example to process (zero-indexed).
326
+ """
327
+ with open(self.cached_file, encoding="utf-8") as cache:
328
+ examples = list(csv.reader(cache))
329
+ example = examples[example_id + 1] # +1 to adjust for header
330
+ output = []
331
+ assert self.outputs is not None
332
+ for component, value in zip(self.outputs, example):
333
+ try:
334
+ value_as_dict = ast.literal_eval(value)
335
+ assert utils.is_update(value_as_dict)
336
+ output.append(value_as_dict)
337
+ except (ValueError, TypeError, SyntaxError, AssertionError):
338
+ output.append(component.serialize(value, self.cached_folder))
339
+ return output
340
+
341
+
342
+ class TrackedIterable:
343
+ def __init__(
344
+ self,
345
+ iterable: Iterable | None,
346
+ index: int | None,
347
+ length: int | None,
348
+ desc: str | None,
349
+ unit: str | None,
350
+ _tqdm=None,
351
+ progress: float | None = None,
352
+ ) -> None:
353
+ self.iterable = iterable
354
+ self.index = index
355
+ self.length = length
356
+ self.desc = desc
357
+ self.unit = unit
358
+ self._tqdm = _tqdm
359
+ self.progress = progress
360
+
361
+
362
+ @document("__call__", "tqdm")
363
+ class Progress(Iterable):
364
+ """
365
+ The Progress class provides a custom progress tracker that is used in a function signature.
366
+ To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance.
367
+ The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable.
368
+ The Progress tracker is currently only available with `queue()`.
369
+ Example:
370
+ import gradio as gr
371
+ import time
372
+ def my_function(x, progress=gr.Progress()):
373
+ progress(0, desc="Starting...")
374
+ time.sleep(1)
375
+ for i in progress.tqdm(range(100)):
376
+ time.sleep(0.1)
377
+ return x
378
+ gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch()
379
+ Demos: progress
380
+ """
381
+
382
+ def __init__(
383
+ self,
384
+ track_tqdm: bool = False,
385
+ _callback: Callable | None = None, # for internal use only
386
+ _event_id: str | None = None,
387
+ ):
388
+ """
389
+ Parameters:
390
+ track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function.
391
+ """
392
+ self.track_tqdm = track_tqdm
393
+ self._callback = _callback
394
+ self._event_id = _event_id
395
+ self.iterables: List[TrackedIterable] = []
396
+
397
+ def __len__(self):
398
+ return self.iterables[-1].length
399
+
400
+ def __iter__(self):
401
+ return self
402
+
403
+ def __next__(self):
404
+ """
405
+ Updates progress tracker with next item in iterable.
406
+ """
407
+ if self._callback:
408
+ current_iterable = self.iterables[-1]
409
+ while (
410
+ not hasattr(current_iterable.iterable, "__next__")
411
+ and len(self.iterables) > 0
412
+ ):
413
+ current_iterable = self.iterables.pop()
414
+ self._callback(
415
+ event_id=self._event_id,
416
+ iterables=self.iterables,
417
+ )
418
+ assert current_iterable.index is not None, "Index not set."
419
+ current_iterable.index += 1
420
+ try:
421
+ return next(current_iterable.iterable) # type: ignore
422
+ except StopIteration:
423
+ self.iterables.pop()
424
+ raise StopIteration
425
+ else:
426
+ return self
427
+
428
+ def __call__(
429
+ self,
430
+ progress: float | Tuple[int, int | None] | None,
431
+ desc: str | None = None,
432
+ total: int | None = None,
433
+ unit: str = "steps",
434
+ _tqdm=None,
435
+ ):
436
+ """
437
+ Updates progress tracker with progress and message text.
438
+ Parameters:
439
+ progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar.
440
+ desc: description to display.
441
+ total: estimated total number of steps.
442
+ unit: unit of iterations.
443
+ """
444
+ if self._callback:
445
+ if isinstance(progress, tuple):
446
+ index, total = progress
447
+ progress = None
448
+ else:
449
+ index = None
450
+ self._callback(
451
+ event_id=self._event_id,
452
+ iterables=self.iterables
453
+ + [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)],
454
+ )
455
+ else:
456
+ return progress
457
+
458
+ def tqdm(
459
+ self,
460
+ iterable: Iterable | None,
461
+ desc: str | None = None,
462
+ total: int | None = None,
463
+ unit: str = "steps",
464
+ _tqdm=None,
465
+ *args,
466
+ **kwargs,
467
+ ):
468
+ """
469
+ Attaches progress tracker to iterable, like tqdm.
470
+ Parameters:
471
+ iterable: iterable to attach progress tracker to.
472
+ desc: description to display.
473
+ total: estimated total number of steps.
474
+ unit: unit of iterations.
475
+ """
476
+ if self._callback:
477
+ if iterable is None:
478
+ new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm)
479
+ self.iterables.append(new_iterable)
480
+ self._callback(event_id=self._event_id, iterables=self.iterables)
481
+ return self
482
+ length = len(iterable) if hasattr(iterable, "__len__") else None # type: ignore
483
+ self.iterables.append(
484
+ TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm)
485
+ )
486
+ return self
487
+
488
+ def update(self, n=1):
489
+ """
490
+ Increases latest iterable with specified number of steps.
491
+ Parameters:
492
+ n: number of steps completed.
493
+ """
494
+ if self._callback and len(self.iterables) > 0:
495
+ current_iterable = self.iterables[-1]
496
+ assert current_iterable.index is not None, "Index not set."
497
+ current_iterable.index += n
498
+ self._callback(
499
+ event_id=self._event_id,
500
+ iterables=self.iterables,
501
+ )
502
+ else:
503
+ return
504
+
505
+ def close(self, _tqdm):
506
+ """
507
+ Removes iterable with given _tqdm.
508
+ """
509
+ if self._callback:
510
+ for i in range(len(self.iterables)):
511
+ if id(self.iterables[i]._tqdm) == id(_tqdm):
512
+ self.iterables.pop(i)
513
+ break
514
+ self._callback(
515
+ event_id=self._event_id,
516
+ iterables=self.iterables,
517
+ )
518
+ else:
519
+ return
520
+
521
+
522
+ def create_tracker(root_blocks, event_id, fn, track_tqdm):
523
+
524
+ progress = Progress(_callback=root_blocks._queue.set_progress, _event_id=event_id)
525
+ if not track_tqdm:
526
+ return progress, fn
527
+
528
+ try:
529
+ _tqdm = __import__("tqdm")
530
+ except ModuleNotFoundError:
531
+ return progress, fn
532
+ if not hasattr(root_blocks, "_progress_tracker_per_thread"):
533
+ root_blocks._progress_tracker_per_thread = {}
534
+
535
+ def init_tqdm(self, iterable=None, desc=None, *args, **kwargs):
536
+ self._progress = root_blocks._progress_tracker_per_thread.get(
537
+ threading.get_ident()
538
+ )
539
+ if self._progress is not None:
540
+ self._progress.event_id = event_id
541
+ self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
542
+ kwargs["file"] = open(os.devnull, "w")
543
+ self.__init__orig__(iterable, desc, *args, **kwargs)
544
+
545
+ def iter_tqdm(self):
546
+ if self._progress is not None:
547
+ return self._progress
548
+ else:
549
+ return self.__iter__orig__()
550
+
551
+ def update_tqdm(self, n=1):
552
+ if self._progress is not None:
553
+ self._progress.update(n)
554
+ return self.__update__orig__(n)
555
+
556
+ def close_tqdm(self):
557
+ if self._progress is not None:
558
+ self._progress.close(self)
559
+ return self.__close__orig__()
560
+
561
+ def exit_tqdm(self, exc_type, exc_value, traceback):
562
+ if self._progress is not None:
563
+ self._progress.close(self)
564
+ return self.__exit__orig__(exc_type, exc_value, traceback)
565
+
566
+ if not hasattr(_tqdm.tqdm, "__init__orig__"):
567
+ _tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__
568
+ _tqdm.tqdm.__init__ = init_tqdm
569
+ if not hasattr(_tqdm.tqdm, "__update__orig__"):
570
+ _tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update
571
+ _tqdm.tqdm.update = update_tqdm
572
+ if not hasattr(_tqdm.tqdm, "__close__orig__"):
573
+ _tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close
574
+ _tqdm.tqdm.close = close_tqdm
575
+ if not hasattr(_tqdm.tqdm, "__exit__orig__"):
576
+ _tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__
577
+ _tqdm.tqdm.__exit__ = exit_tqdm
578
+ if not hasattr(_tqdm.tqdm, "__iter__orig__"):
579
+ _tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__
580
+ _tqdm.tqdm.__iter__ = iter_tqdm
581
+ if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"):
582
+ _tqdm.auto.tqdm = _tqdm.tqdm
583
+
584
+ def tracked_fn(*args):
585
+ thread_id = threading.get_ident()
586
+ root_blocks._progress_tracker_per_thread[thread_id] = progress
587
+ response = fn(*args)
588
+ del root_blocks._progress_tracker_per_thread[thread_id]
589
+ return response
590
+
591
+ return progress, tracked_fn
592
+
593
+
594
+ def special_args(
595
+ fn: Callable,
596
+ inputs: List[Any] | None = None,
597
+ request: routes.Request | None = None,
598
+ event_data: EventData | None = None,
599
+ ):
600
+ """
601
+ Checks if function has special arguments Request or EventData (via annotation) or Progress (via default value).
602
+ If inputs is provided, these values will be loaded into the inputs array.
603
+ Parameters:
604
+ fn: function to check.
605
+ inputs: array to load special arguments into.
606
+ request: request to load into inputs.
607
+ Returns:
608
+ updated inputs, progress index, event data index.
609
+ """
610
+ signature = inspect.signature(fn)
611
+ positional_args = []
612
+ for i, param in enumerate(signature.parameters.values()):
613
+ if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
614
+ break
615
+ positional_args.append(param)
616
+ progress_index = None
617
+ event_data_index = None
618
+ for i, param in enumerate(positional_args):
619
+ if isinstance(param.default, Progress):
620
+ progress_index = i
621
+ if inputs is not None:
622
+ inputs.insert(i, param.default)
623
+ elif param.annotation == routes.Request:
624
+ if inputs is not None:
625
+ inputs.insert(i, request)
626
+ elif isinstance(param.annotation, type) and issubclass(
627
+ param.annotation, EventData
628
+ ):
629
+ event_data_index = i
630
+ if inputs is not None and event_data is not None:
631
+ inputs.insert(i, param.annotation(event_data.target, event_data._data))
632
+ if inputs is not None:
633
+ while len(inputs) < len(positional_args):
634
+ i = len(inputs)
635
+ param = positional_args[i]
636
+ if param.default == param.empty:
637
+ warnings.warn("Unexpected argument. Filling with None.")
638
+ inputs.append(None)
639
+ else:
640
+ inputs.append(param.default)
641
+ return inputs or [], progress_index, event_data_index
642
+
643
+
644
+ @document()
645
+ def update(**kwargs) -> dict:
646
+ """
647
+ Updates component properties. When a function passed into a Gradio Interface or a Blocks event returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component.
648
+ This is a shorthand for using the update method on a component.
649
+ For example, rather than using gr.Number.update(...) you can just use gr.update(...).
650
+ Note that your editor's autocompletion will suggest proper parameters
651
+ if you use the update method on the component.
652
+ Demos: blocks_essay, blocks_update, blocks_essay_update
653
+
654
+ Parameters:
655
+ kwargs: Key-word arguments used to update the component's properties.
656
+ Example:
657
+ # Blocks Example
658
+ import gradio as gr
659
+ with gr.Blocks() as demo:
660
+ radio = gr.Radio([1, 2, 4], label="Set the value of the number")
661
+ number = gr.Number(value=2, interactive=True)
662
+ radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
663
+ demo.launch()
664
+
665
+ # Interface example
666
+ import gradio as gr
667
+ def change_textbox(choice):
668
+ if choice == "short":
669
+ return gr.Textbox.update(lines=2, visible=True)
670
+ elif choice == "long":
671
+ return gr.Textbox.update(lines=8, visible=True)
672
+ else:
673
+ return gr.Textbox.update(visible=False)
674
+ gr.Interface(
675
+ change_textbox,
676
+ gr.Radio(
677
+ ["short", "long", "none"], label="What kind of essay would you like to write?"
678
+ ),
679
+ gr.Textbox(lines=2),
680
+ live=True,
681
+ ).launch()
682
+ """
683
+ kwargs["__type__"] = "generic_update"
684
+ return kwargs
685
+
686
+
687
+ def skip() -> dict:
688
+ return update()
689
+
690
+
691
+ @document()
692
+ def make_waveform(
693
+ audio: str | Tuple[int, np.ndarray],
694
+ *,
695
+ bg_color: str = "#f3f4f6",
696
+ bg_image: str | None = None,
697
+ fg_alpha: float = 0.75,
698
+ bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
699
+ bar_count: int = 50,
700
+ bar_width: float = 0.6,
701
+ ):
702
+ """
703
+ Generates a waveform video from an audio file. Useful for creating an easy-to-share audio visualization. The output should be passed into a `gr.Video` component.
704
+ Parameters:
705
+ audio: Audio file path or tuple of (sample_rate, audio_data)
706
+ bg_color: Background color of waveform (ignored if bg_image is provided)
707
+ bg_image: Background image of waveform
708
+ fg_alpha: Opacity of foreground waveform
709
+ bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
710
+ bar_count: Number of bars in waveform
711
+ bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
712
+ Returns:
713
+ A filepath to the output video.
714
+ """
715
+ if isinstance(audio, str):
716
+ audio_file = audio
717
+ audio = processing_utils.audio_from_file(audio)
718
+ else:
719
+ tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
720
+ processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
721
+ audio_file = tmp_wav.name
722
+ duration = round(len(audio[1]) / audio[0], 4)
723
+
724
+ # Helper methods to create waveform
725
+ def hex_to_RGB(hex_str):
726
+ return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
727
+
728
+ def get_color_gradient(c1, c2, n):
729
+ assert n > 1
730
+ c1_rgb = np.array(hex_to_RGB(c1)) / 255
731
+ c2_rgb = np.array(hex_to_RGB(c2)) / 255
732
+ mix_pcts = [x / (n - 1) for x in range(n)]
733
+ rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
734
+ return [
735
+ "#" + "".join([format(int(round(val * 255)), "02x") for val in item])
736
+ for item in rgb_colors
737
+ ]
738
+
739
+ # Reshape audio to have a fixed number of bars
740
+ samples = audio[1]
741
+ if len(samples.shape) > 1:
742
+ samples = np.mean(samples, 1)
743
+ bins_to_pad = bar_count - (len(samples) % bar_count)
744
+ samples = np.pad(samples, [(0, bins_to_pad)])
745
+ samples = np.reshape(samples, (bar_count, -1))
746
+ samples = np.abs(samples)
747
+ samples = np.max(samples, 1)
748
+
749
+ matplotlib.use("Agg")
750
+ plt.clf()
751
+ # Plot waveform
752
+ color = (
753
+ bars_color
754
+ if isinstance(bars_color, str)
755
+ else get_color_gradient(bars_color[0], bars_color[1], bar_count)
756
+ )
757
+ plt.bar(
758
+ np.arange(0, bar_count),
759
+ samples * 2,
760
+ bottom=(-1 * samples),
761
+ width=bar_width,
762
+ color=color,
763
+ )
764
+ plt.axis("off")
765
+ plt.margins(x=0)
766
+ tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
767
+ savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"}
768
+ if bg_image is not None:
769
+ savefig_kwargs["transparent"] = True
770
+ else:
771
+ savefig_kwargs["facecolor"] = bg_color
772
+ plt.savefig(tmp_img.name, **savefig_kwargs)
773
+ waveform_img = PIL.Image.open(tmp_img.name)
774
+ waveform_img = waveform_img.resize((1000, 200))
775
+
776
+ # Composite waveform with background image
777
+ if bg_image is not None:
778
+ waveform_array = np.array(waveform_img)
779
+ waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
780
+ waveform_img = PIL.Image.fromarray(waveform_array)
781
+
782
+ bg_img = PIL.Image.open(bg_image)
783
+ waveform_width, waveform_height = waveform_img.size
784
+ bg_width, bg_height = bg_img.size
785
+ if waveform_width != bg_width:
786
+ bg_img = bg_img.resize(
787
+ (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2))
788
+ )
789
+ bg_width, bg_height = bg_img.size
790
+ composite_height = max(bg_height, waveform_height)
791
+ composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF")
792
+ composite.paste(bg_img, (0, composite_height - bg_height))
793
+ composite.paste(
794
+ waveform_img, (0, composite_height - waveform_height), waveform_img
795
+ )
796
+ composite.save(tmp_img.name)
797
+ img_width, img_height = composite.size
798
+ else:
799
+ img_width, img_height = waveform_img.size
800
+ waveform_img.save(tmp_img.name)
801
+
802
+ # Convert waveform to video with ffmpeg
803
+ output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
804
+
805
+ ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}"""
806
+
807
+ subprocess.call(ffmpeg_cmd, shell=True)
808
+ return output_mp4.name
809
+
810
+
811
+ @document()
812
+ class EventData:
813
+ """
814
+ When a subclass of EventData is added as a type hint to an argument of an event listener method, this object will be passed as that argument.
815
+ It contains information about the event that triggered the listener, such as the target object, and other data related to the specific event that are attributes of the subclass.
816
+
817
+ Example:
818
+ table = gr.Dataframe([[1, 2, 3], [4, 5, 6]])
819
+ gallery = gr.Gallery([("cat.jpg", "Cat"), ("dog.jpg", "Dog")])
820
+ textbox = gr.Textbox("Hello World!")
821
+
822
+ statement = gr.Textbox()
823
+
824
+ def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
825
+ return f"You selected {evt.value} at {evt.index} from {evt.target}"
826
+
827
+ table.select(on_select, None, statement)
828
+ gallery.select(on_select, None, statement)
829
+ textbox.select(on_select, None, statement)
830
+ Demos: gallery_selections, tictactoe
831
+ """
832
+
833
+ def __init__(self, target: Block | None, _data: Any):
834
+ """
835
+ Parameters:
836
+ target: The target object that triggered the event. Can be used to distinguish if multiple components are bound to the same listener.
837
+ """
838
+ self.target = target
839
+ self._data = _data
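The helpers added above (Progress, update/skip, make_waveform, special_args) are all meant to be used from inside an app's prediction function. Below is a minimal usage sketch, assuming the gradio version introduced by this commit and restating the patterns from the Progress and update docstrings above; the function name `slow_echo` is illustrative only.

import time
import gradio as gr

def slow_echo(text, progress=gr.Progress(track_tqdm=True)):
    # Report coarse progress directly, then wrap an iterable the way tqdm would
    progress(0, desc="Starting...")
    for _ in progress.tqdm(range(10), desc="Working", unit="steps"):
        time.sleep(0.1)
    # Returning gr.update(...) changes output properties as well as the value
    return gr.update(value=text, lines=4)

demo = gr.Interface(slow_echo, gr.Textbox(), gr.Textbox())
demo.queue()   # the Progress tracker is only available with queue(), per the docstring above
demo.launch()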
gradio/inputs.py ADDED
@@ -0,0 +1,473 @@
1
+ # type: ignore
2
+ """
3
+ This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
4
+ `InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are
5
+ automatically added to a registry, which allows them to be easily referenced in other parts of the code.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from typing import Any, List, Optional, Tuple
12
+
13
+ from gradio import components
14
+
15
+
16
+ class Textbox(components.Textbox):
17
+ def __init__(
18
+ self,
19
+ lines: int = 1,
20
+ placeholder: Optional[str] = None,
21
+ default: str = "",
22
+ numeric: Optional[bool] = False,
23
+ type: Optional[str] = "text",
24
+ label: Optional[str] = None,
25
+ optional: bool = False,
26
+ ):
27
+ warnings.warn(
28
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
29
+ )
30
+ super().__init__(
31
+ value=default,
32
+ lines=lines,
33
+ placeholder=placeholder,
34
+ label=label,
35
+ numeric=numeric,
36
+ type=type,
37
+ optional=optional,
38
+ )
39
+
40
+
41
+ class Number(components.Number):
42
+ """
43
+ Component creates a field for user to enter numeric input. Provides a number as an argument to the wrapped function.
44
+ Input type: float
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ default: Optional[float] = None,
50
+ label: Optional[str] = None,
51
+ optional: bool = False,
52
+ ):
53
+ """
54
+ Parameters:
55
+ default (float): default value.
56
+ label (str): component name in interface.
57
+ optional (bool): If True, the interface can be submitted with no value for this component.
58
+ """
59
+ warnings.warn(
60
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
61
+ )
62
+ super().__init__(value=default, label=label, optional=optional)
63
+
64
+
65
+ class Slider(components.Slider):
66
+ """
67
+ Component creates a slider that ranges from `minimum` to `maximum`. Provides number as an argument to the wrapped function.
68
+ Input type: float
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ minimum: float = 0,
74
+ maximum: float = 100,
75
+ step: Optional[float] = None,
76
+ default: Optional[float] = None,
77
+ label: Optional[str] = None,
78
+ optional: bool = False,
79
+ ):
80
+ """
81
+ Parameters:
82
+ minimum (float): minimum value for slider.
83
+ maximum (float): maximum value for slider.
84
+ step (float): increment between slider values.
85
+ default (float): default value.
86
+ label (str): component name in interface.
87
+ optional (bool): this parameter is ignored.
88
+ """
89
+ warnings.warn(
90
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
91
+ )
92
+
93
+ super().__init__(
94
+ value=default,
95
+ minimum=minimum,
96
+ maximum=maximum,
97
+ step=step,
98
+ label=label,
99
+ optional=optional,
100
+ )
101
+
102
+
103
+ class Checkbox(components.Checkbox):
104
+ """
105
+ Component creates a checkbox that can be set to `True` or `False`. Provides a boolean as an argument to the wrapped function.
106
+ Input type: bool
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ default: bool = False,
112
+ label: Optional[str] = None,
113
+ optional: bool = False,
114
+ ):
115
+ """
116
+ Parameters:
117
+ label (str): component name in interface.
118
+ default (bool): if True, checked by default.
119
+ optional (bool): this parameter is ignored.
120
+ """
121
+ warnings.warn(
122
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
123
+ )
124
+ super().__init__(value=default, label=label, optional=optional)
125
+
126
+
127
+ class CheckboxGroup(components.CheckboxGroup):
128
+ """
129
+ Component creates a set of checkboxes of which a subset can be selected. Provides a list of strings representing the selected choices as an argument to the wrapped function.
130
+ Input type: Union[List[str], List[int]]
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ choices: List[str],
136
+ default: List[str] = [],
137
+ type: str = "value",
138
+ label: Optional[str] = None,
139
+ optional: bool = False,
140
+ ):
141
+ """
142
+ Parameters:
143
+ choices (List[str]): list of options to select from.
144
+ default (List[str]): default selected list of options.
145
+ type (str): Type of value to be returned by component. "value" returns the list of strings of the choices selected, "index" returns the list of indices of the choices selected.
146
+ label (str): component name in interface.
147
+ optional (bool): this parameter is ignored.
148
+ """
149
+ warnings.warn(
150
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
151
+ )
152
+ super().__init__(
153
+ value=default,
154
+ choices=choices,
155
+ type=type,
156
+ label=label,
157
+ optional=optional,
158
+ )
159
+
160
+
161
+ class Radio(components.Radio):
162
+ """
163
+ Component creates a set of radio buttons of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
164
+ Input type: Union[str, int]
165
+ """
166
+
167
+ def __init__(
168
+ self,
169
+ choices: List[str],
170
+ type: str = "value",
171
+ default: Optional[str] = None,
172
+ label: Optional[str] = None,
173
+ optional: bool = False,
174
+ ):
175
+ """
176
+ Parameters:
177
+ choices (List[str]): list of options to select from.
178
+ type (str): Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
179
+ default (str): the button selected by default. If None, no button is selected by default.
180
+ label (str): component name in interface.
181
+ optional (bool): this parameter is ignored.
182
+ """
183
+ warnings.warn(
184
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
185
+ )
186
+ super().__init__(
187
+ choices=choices,
188
+ type=type,
189
+ value=default,
190
+ label=label,
191
+ optional=optional,
192
+ )
193
+
194
+
195
+ class Dropdown(components.Dropdown):
196
+ """
197
+ Component creates a dropdown of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
198
+ Input type: Union[str, int]
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ choices: List[str],
204
+ type: str = "value",
205
+ default: Optional[str] = None,
206
+ label: Optional[str] = None,
207
+ optional: bool = False,
208
+ ):
209
+ """
210
+ Parameters:
211
+ choices (List[str]): list of options to select from.
212
+ type (str): Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
213
+ default (str): default value selected in dropdown. If None, no value is selected by default.
214
+ label (str): component name in interface.
215
+ optional (bool): this parameter is ignored.
216
+ """
217
+ warnings.warn(
218
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
219
+ )
220
+ super().__init__(
221
+ choices=choices,
222
+ type=type,
223
+ value=default,
224
+ label=label,
225
+ optional=optional,
226
+ )
227
+
228
+
229
+ class Image(components.Image):
230
+ """
231
+ Component creates an image upload box with editing capabilities.
232
+ Input type: Union[numpy.array, PIL.Image, file-object]
233
+ """
234
+
235
+ def __init__(
236
+ self,
237
+ shape: Tuple[int, int] = None,
238
+ image_mode: str = "RGB",
239
+ invert_colors: bool = False,
240
+ source: str = "upload",
241
+ tool: str = "editor",
242
+ type: str = "numpy",
243
+ label: str = None,
244
+ optional: bool = False,
245
+ ):
246
+ """
247
+ Parameters:
248
+ shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
249
+ image_mode (str): How to process the uploaded image. Accepts any of the PIL image modes, e.g. "RGB" for color images, "RGBA" to include the transparency mask, "L" for black-and-white images.
250
+ invert_colors (bool): whether to invert the image as a preprocessing step.
251
+ source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
252
+ tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
253
+ type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
254
+ label (str): component name in interface.
255
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
256
+ """
257
+ warnings.warn(
258
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
259
+ )
260
+ super().__init__(
261
+ shape=shape,
262
+ image_mode=image_mode,
263
+ invert_colors=invert_colors,
264
+ source=source,
265
+ tool=tool,
266
+ type=type,
267
+ label=label,
268
+ optional=optional,
269
+ )
270
+
271
+
272
+ class Video(components.Video):
273
+ """
274
+ Component creates a video file upload that is converted to a file path.
275
+
276
+ Input type: filepath
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ type: Optional[str] = None,
282
+ source: str = "upload",
283
+ label: Optional[str] = None,
284
+ optional: bool = False,
285
+ ):
286
+ """
287
+ Parameters:
288
+ type (str): Type of video format to be returned by component, such as 'avi' or 'mp4'. If set to None, video will keep uploaded format.
289
+ source (str): Source of video. "upload" creates a box where user can drop a video file, "webcam" allows user to record a video from their webcam.
290
+ label (str): component name in interface.
291
+ optional (bool): If True, the interface can be submitted with no uploaded video, in which case the input value is None.
292
+ """
293
+ warnings.warn(
294
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
295
+ )
296
+ super().__init__(format=type, source=source, label=label, optional=optional)
297
+
298
+
299
+ class Audio(components.Audio):
300
+ """
301
+ Component accepts audio input files.
302
+ Input type: Union[Tuple[int, numpy.array], file-object, numpy.array]
303
+ """
304
+
305
+ def __init__(
306
+ self,
307
+ source: str = "upload",
308
+ type: str = "numpy",
309
+ label: str = None,
310
+ optional: bool = False,
311
+ ):
312
+ """
313
+ Parameters:
314
+ source (str): Source of audio. "upload" creates a box where user can drop an audio file, "microphone" creates a microphone input.
315
+ type (str): Type of value to be returned by component. "numpy" returns a 2-element tuple of an integer sample_rate and the data as a numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
316
+ label (str): component name in interface.
317
+ optional (bool): If True, the interface can be submitted with no uploaded audio, in which case the input value is None.
318
+ """
319
+ warnings.warn(
320
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
321
+ )
322
+ super().__init__(source=source, type=type, label=label, optional=optional)
323
+
324
+
325
+ class File(components.File):
326
+ """
327
+ Component accepts generic file uploads.
328
+ Input type: Union[file-object, bytes, List[Union[file-object, bytes]]]
329
+ """
330
+
331
+ def __init__(
332
+ self,
333
+ file_count: str = "single",
334
+ type: str = "file",
335
+ label: Optional[str] = None,
336
+ keep_filename: bool = True,
337
+ optional: bool = False,
338
+ ):
339
+ """
340
+ Parameters:
341
+ file_count (str): if "single", allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be a list of files in the case of "multiple" or "directory".
342
+ type (str): Type of value to be returned by component. "file" returns a temporary file object whose path can be retrieved by file_obj.name, "binary" returns a bytes object.
343
+ label (str): component name in interface.
344
+ keep_filename (bool): DEPRECATED. Original filename always kept.
345
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
346
+ """
347
+ warnings.warn(
348
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
349
+ )
350
+ super().__init__(
351
+ file_count=file_count,
352
+ type=type,
353
+ label=label,
354
+ keep_filename=keep_filename,
355
+ optional=optional,
356
+ )
357
+
358
+
359
+ class Dataframe(components.Dataframe):
360
+ """
361
+ Component accepts 2D input through a spreadsheet interface.
362
+ Input type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ headers: Optional[List[str]] = None,
368
+ row_count: int = 3,
369
+ col_count: Optional[int] = 3,
370
+ datatype: str | List[str] = "str",
371
+ col_width: int | List[int] = None,
372
+ default: Optional[List[List[Any]]] = None,
373
+ type: str = "pandas",
374
+ label: Optional[str] = None,
375
+ optional: bool = False,
376
+ ):
377
+ """
378
+ Parameters:
379
+ headers (List[str]): Header names for the dataframe. If None, no headers are shown.
380
+ row_count (int): Limit number of rows for input.
381
+ col_count (int): Limit number of columns for input. If equal to 1, return data will be one-dimensional. Ignored if `headers` is provided.
382
+ datatype (Union[str, List[str]]): Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", and "date".
383
+ col_width (Union[int, List[int]]): Width of columns in pixels. Can be provided as single value or list of values per column.
384
+ default (List[List[Any]]): Default value
385
+ type (str): Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python array.
386
+ label (str): component name in interface.
387
+ optional (bool): this parameter is ignored.
388
+ """
389
+ warnings.warn(
390
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
391
+ )
392
+ super().__init__(
393
+ value=default,
394
+ headers=headers,
395
+ row_count=row_count,
396
+ col_count=col_count,
397
+ datatype=datatype,
398
+ col_width=col_width,
399
+ type=type,
400
+ label=label,
401
+ optional=optional,
402
+ )
403
+
404
+
405
+ class Timeseries(components.Timeseries):
406
+ """
407
+ Component accepts pandas.DataFrame uploaded as a timeseries csv file.
408
+ Input type: pandas.DataFrame
409
+ """
410
+
411
+ def __init__(
412
+ self,
413
+ x: Optional[str] = None,
414
+ y: str | List[str] = None,
415
+ label: Optional[str] = None,
416
+ optional: bool = False,
417
+ ):
418
+ """
419
+ Parameters:
420
+ x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
421
+ y (Union[str, List[str]]): Column name of y series, or list of column names if multiple series. None if csv has no headers, in which case every column after first is a y series.
422
+ label (str): component name in interface.
423
+ optional (bool): If True, the interface can be submitted with no uploaded csv file, in which case the input value is None.
424
+ """
425
+ warnings.warn(
426
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
427
+ )
428
+ super().__init__(x=x, y=y, label=label, optional=optional)
429
+
430
+
431
+ class State(components.State):
432
+ """
433
+ Special hidden component that stores state across runs of the interface.
434
+ Input type: Any
435
+ """
436
+
437
+ def __init__(
438
+ self,
439
+ label: str = None,
440
+ default: Any = None,
441
+ ):
442
+ """
443
+ Parameters:
444
+ label (str): component name in interface (not used).
445
+ default (Any): the initial value of the state.
446
+ optional (bool): this parameter is ignored.
447
+ """
448
+ warnings.warn(
449
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
450
+ )
451
+ super().__init__(value=default, label=label)
452
+
453
+
454
+ class Image3D(components.Model3D):
455
+ """
456
+ Used for uploading 3D model files.
457
+ Input type: File object of type (.obj, .glb, or .gltf)
458
+ """
459
+
460
+ def __init__(
461
+ self,
462
+ label: Optional[str] = None,
463
+ optional: bool = False,
464
+ ):
465
+ """
466
+ Parameters:
467
+ label (str): component name in interface.
468
+ optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
469
+ """
470
+ warnings.warn(
471
+ "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
472
+ )
473
+ super().__init__(label=label, optional=optional)
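Every class in this module is a thin, deprecated wrapper that forwards to its gradio.components counterpart. A minimal sketch of the migration the warnings above ask for, assuming the package layout added in this commit (the label text is illustrative):

import gradio as gr
from gradio import components, inputs

# Deprecated path: emits the deprecation warning defined above, then delegates
legacy_box = inputs.Textbox(lines=2, label="Prompt")

# Recommended path: import the component from gradio.components instead
modern_box = components.Textbox(lines=2, label="Prompt")

gr.Interface(fn=lambda s: s.upper(), inputs=modern_box, outputs="text").launch()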
gradio/interface.py ADDED
@@ -0,0 +1,888 @@
1
+ """
2
+ This is the core file in the `gradio` package, and defines the Interface class,
3
+ including various methods for constructing an interface and then launching it.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import inspect
9
+ import json
10
+ import os
11
+ import re
12
+ import warnings
13
+ import weakref
14
+ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
15
+
16
+ from gradio import Examples, interpretation, utils
17
+ from gradio.blocks import Blocks
18
+ from gradio.components import (
19
+ Button,
20
+ Interpretation,
21
+ IOComponent,
22
+ Markdown,
23
+ State,
24
+ get_component_instance,
25
+ )
26
+ from gradio.data_classes import InterfaceTypes
27
+ from gradio.documentation import document, set_documentation_group
28
+ from gradio.events import Changeable, Streamable
29
+ from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
30
+ from gradio.layouts import Column, Row, Tab, Tabs
31
+ from gradio.pipelines import load_from_pipeline
32
+ from gradio.themes import ThemeClass as Theme
33
+ from gradio.utils import GRADIO_VERSION
34
+
35
+ set_documentation_group("interface")
36
+
37
+ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
38
+ from transformers.pipelines.base import Pipeline
39
+
40
+
41
+ @document("launch", "load", "from_pipeline", "integrate", "queue")
42
+ class Interface(Blocks):
43
+ """
44
+ Interface is Gradio's main high-level class, and allows you to create a web-based GUI / demo
45
+ around a machine learning model (or any Python function) in a few lines of code.
46
+ You must specify three parameters: (1) the function to create a GUI for (2) the desired input components and
47
+ (3) the desired output components. Additional parameters can be used to control the appearance
48
+ and behavior of the demo.
49
+
50
+ Example:
51
+ import gradio as gr
52
+
53
+ def image_classifier(inp):
54
+ return {'cat': 0.3, 'dog': 0.7}
55
+
56
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
57
+ demo.launch()
58
+ Demos: hello_world, hello_world_3, gpt_j
59
+ Guides: quickstart, key_features, sharing_your_app, interface_state, reactive_interfaces, advanced_interface_features, setting_up_a_gradio_demo_for_maximum_performance
60
+ """
61
+
62
+ # stores references to all currently existing Interface instances
63
+ instances: weakref.WeakSet = weakref.WeakSet()
64
+
65
+ @classmethod
66
+ def get_instances(cls) -> List[Interface]:
67
+ """
68
+ :return: list of all current instances.
69
+ """
70
+ return list(Interface.instances)
71
+
72
+ @classmethod
73
+ def load(
74
+ cls,
75
+ name: str,
76
+ src: str | None = None,
77
+ api_key: str | None = None,
78
+ alias: str | None = None,
79
+ **kwargs,
80
+ ) -> Interface:
81
+ """
82
+ Class method that constructs an Interface from a Hugging Face repo. Can accept
83
+ model repos (if src is "models") or Space repos (if src is "spaces"). The input
84
+ and output components are automatically loaded from the repo.
85
+ Parameters:
86
+ name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
87
+ src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
88
+ api_key: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
89
+ alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
90
+ Returns:
91
+ a Gradio Interface object for the given model
92
+ Example:
93
+ import gradio as gr
94
+ description = "Story generation with GPT"
95
+ examples = [["An adventurer is approached by a mysterious stranger in the tavern for a new quest."]]
96
+ demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples)
97
+ demo.launch()
98
+ """
99
+ return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
100
+
101
+ @classmethod
102
+ def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
103
+ """
104
+ Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
105
+ The input and output components are automatically determined from the pipeline.
106
+ Parameters:
107
+ pipeline: the pipeline object to use.
108
+ Returns:
109
+ a Gradio Interface object from the given Pipeline
110
+ Example:
111
+ import gradio as gr
112
+ from transformers import pipeline
113
+ pipe = pipeline("image-classification")
114
+ gr.Interface.from_pipeline(pipe).launch()
115
+ """
116
+ interface_info = load_from_pipeline(pipeline)
117
+ kwargs = dict(interface_info, **kwargs)
118
+ interface = cls(**kwargs)
119
+ return interface
120
+
121
+ def __init__(
122
+ self,
123
+ fn: Callable,
124
+ inputs: str | IOComponent | List[str | IOComponent] | None,
125
+ outputs: str | IOComponent | List[str | IOComponent] | None,
126
+ examples: List[Any] | List[List[Any]] | str | None = None,
127
+ cache_examples: bool | None = None,
128
+ examples_per_page: int = 10,
129
+ live: bool = False,
130
+ interpretation: Callable | str | None = None,
131
+ num_shap: float = 2.0,
132
+ title: str | None = None,
133
+ description: str | None = None,
134
+ article: str | None = None,
135
+ thumbnail: str | None = None,
136
+ theme: Theme | None = None,
137
+ css: str | None = None,
138
+ allow_flagging: str | None = None,
139
+ flagging_options: List[str] | List[Tuple[str, str]] | None = None,
140
+ flagging_dir: str = "flagged",
141
+ flagging_callback: FlaggingCallback = CSVLogger(),
142
+ analytics_enabled: bool | None = None,
143
+ batch: bool = False,
144
+ max_batch_size: int = 4,
145
+ _api_mode: bool = False,
146
+ **kwargs,
147
+ ):
148
+ """
149
+ Parameters:
150
+ fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
151
+ inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
152
+ outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
153
+ examples: sample inputs for the function; if provided, they appear below the UI components and can be clicked to populate the interface. Should be a nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
154
+ cache_examples: If True, caches examples in the server so that example predictions are returned immediately at runtime. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
155
+ examples_per_page: If examples are provided, how many to display per page.
156
+ live: whether the interface should automatically rerun if any of the inputs change.
157
+ interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.
158
+ num_shap: a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
159
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
160
+ description: a description for the interface; if provided, appears above the input and output components and beneath the title in regular font. Accepts Markdown and HTML content.
161
+ article: an expanded article explaining the interface; if provided, appears below the input and output components in regular font. Accepts Markdown and HTML content.
162
+ thumbnail: path or url to image to use as display image when the web demo is shared on social media.
163
+ theme: Theme to use, loaded from gradio.themes.
164
+ css: custom css or path to custom css file to use with interface.
165
+ allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks the flag button. This parameter can be set with the environment variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
166
+ flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will be ["Flag as X", "Flag as Y"], etc.
167
+ flagging_dir: what to name the directory where flagged data is stored.
168
+ flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
169
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
170
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
171
+ max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
172
+ """
173
+ super().__init__(
174
+ analytics_enabled=analytics_enabled,
175
+ mode="interface",
176
+ css=css,
177
+ title=title or "Gradio",
178
+ theme=theme,
179
+ **kwargs,
180
+ )
181
+
182
+ if isinstance(fn, list):
183
+ raise DeprecationWarning(
184
+ "The `fn` parameter only accepts a single function, support for a list "
185
+ "of functions has been deprecated. Please use gradio.mix.Parallel "
186
+ "instead."
187
+ )
188
+
189
+ self.interface_type = InterfaceTypes.STANDARD
190
+ if (inputs is None or inputs == []) and (outputs is None or outputs == []):
191
+ raise ValueError("Must provide at least one of `inputs` or `outputs`")
192
+ elif outputs is None or outputs == []:
193
+ outputs = []
194
+ self.interface_type = InterfaceTypes.INPUT_ONLY
195
+ elif inputs is None or inputs == []:
196
+ inputs = []
197
+ self.interface_type = InterfaceTypes.OUTPUT_ONLY
198
+
199
+ assert isinstance(inputs, (str, list, IOComponent))
200
+ assert isinstance(outputs, (str, list, IOComponent))
201
+
202
+ if not isinstance(inputs, list):
203
+ inputs = [inputs]
204
+ if not isinstance(outputs, list):
205
+ outputs = [outputs]
206
+
207
+ if self.is_space and cache_examples is None:
208
+ self.cache_examples = True
209
+ else:
210
+ self.cache_examples = cache_examples or False
211
+
212
+ state_input_indexes = [
213
+ idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
214
+ ]
215
+ state_output_indexes = [
216
+ idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
217
+ ]
218
+
219
+ if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
220
+ pass
221
+ elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
222
+ raise ValueError(
223
+ "If using 'state', there must be exactly one state input and one state output."
224
+ )
225
+ else:
226
+ state_input_index = state_input_indexes[0]
227
+ state_output_index = state_output_indexes[0]
228
+ if inputs[state_input_index] == "state":
229
+ default = utils.get_default_args(fn)[state_input_index]
230
+ state_variable = State(value=default) # type: ignore
231
+ else:
232
+ state_variable = inputs[state_input_index]
233
+
234
+ inputs[state_input_index] = state_variable
235
+ outputs[state_output_index] = state_variable
236
+
237
+ if cache_examples:
238
+ warnings.warn(
239
+ "Cache examples cannot be used with state inputs and outputs."
240
+ "Setting cache_examples to False."
241
+ )
242
+ self.cache_examples = False
243
+
244
+ self.input_components = [
245
+ get_component_instance(i, render=False) for i in inputs
246
+ ]
247
+ self.output_components = [
248
+ get_component_instance(o, render=False) for o in outputs
249
+ ]
250
+
251
+ for component in self.input_components + self.output_components:
252
+ if not (isinstance(component, IOComponent)):
253
+ raise ValueError(
254
+ f"{component} is not a valid input/output component for Interface."
255
+ )
256
+
257
+ if len(self.input_components) == len(self.output_components):
258
+ same_components = [
259
+ i is o for i, o in zip(self.input_components, self.output_components)
260
+ ]
261
+ if all(same_components):
262
+ self.interface_type = InterfaceTypes.UNIFIED
263
+
264
+ if self.interface_type in [
265
+ InterfaceTypes.STANDARD,
266
+ InterfaceTypes.OUTPUT_ONLY,
267
+ ]:
268
+ for o in self.output_components:
269
+ assert isinstance(o, IOComponent)
270
+ o.interactive = False # Force output components to be non-interactive
271
+
272
+ if (
273
+ interpretation is None
274
+ or isinstance(interpretation, list)
275
+ or callable(interpretation)
276
+ ):
277
+ self.interpretation = interpretation
278
+ elif isinstance(interpretation, str):
279
+ self.interpretation = [
280
+ interpretation.lower() for _ in self.input_components
281
+ ]
282
+ else:
283
+ raise ValueError("Invalid value for parameter: interpretation")
284
+
285
+ self.api_mode = _api_mode
286
+ self.fn = fn
287
+ self.fn_durations = [0, 0]
288
+ self.__name__ = getattr(fn, "__name__", "fn")
289
+ self.live = live
290
+ self.title = title
291
+
292
+ CLEANER = re.compile("<.*?>")
293
+
294
+ def clean_html(raw_html):
295
+ cleantext = re.sub(CLEANER, "", raw_html)
296
+ return cleantext
297
+
298
+ md = utils.get_markdown_parser()
299
+ simple_description = None
300
+ if description is not None:
301
+ description = md.render(description)
302
+ simple_description = clean_html(description)
303
+ self.simple_description = simple_description
304
+ self.description = description
305
+ if article is not None:
306
+ article = utils.readme_to_html(article)
307
+ article = md.render(article)
308
+ self.article = article
309
+
310
+ self.thumbnail = thumbnail
311
+ self.theme = theme
312
+
313
+ self.examples = examples
314
+ self.num_shap = num_shap
315
+ self.examples_per_page = examples_per_page
316
+
317
+ self.simple_server = None
318
+
319
+ # For allow_flagging: (1) first check for parameter,
320
+ # (2) check for env variable, (3) default to True/"manual"
321
+ if allow_flagging is None:
322
+ allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
323
+ if allow_flagging is True:
324
+ warnings.warn(
325
+ "The `allow_flagging` parameter in `Interface` now"
326
+ "takes a string value ('auto', 'manual', or 'never')"
327
+ ", not a boolean. Setting parameter to: 'manual'."
328
+ )
329
+ self.allow_flagging = "manual"
330
+ elif allow_flagging == "manual":
331
+ self.allow_flagging = "manual"
332
+ elif allow_flagging is False:
333
+ warnings.warn(
334
+ "The `allow_flagging` parameter in `Interface` now"
335
+ "takes a string value ('auto', 'manual', or 'never')"
336
+ ", not a boolean. Setting parameter to: 'never'."
337
+ )
338
+ self.allow_flagging = "never"
339
+ elif allow_flagging == "never":
340
+ self.allow_flagging = "never"
341
+ elif allow_flagging == "auto":
342
+ self.allow_flagging = "auto"
343
+ else:
344
+ raise ValueError(
345
+ "Invalid value for `allow_flagging` parameter."
346
+ "Must be: 'auto', 'manual', or 'never'."
347
+ )
348
+
349
+ if flagging_options is None:
350
+ self.flagging_options = [("Flag", "")]
351
+ elif not (isinstance(flagging_options, list)):
352
+ raise ValueError(
353
+ "flagging_options must be a list of strings or list of (string, string) tuples."
354
+ )
355
+ elif all([isinstance(x, str) for x in flagging_options]):
356
+ self.flagging_options = [(f"Flag as {x}", x) for x in flagging_options]
357
+ elif all([isinstance(x, tuple) for x in flagging_options]):
358
+ self.flagging_options = flagging_options
359
+ else:
360
+ raise ValueError(
361
+ "flagging_options must be a list of strings or list of (string, string) tuples."
362
+ )
363
+
364
+ self.flagging_callback = flagging_callback
365
+ self.flagging_dir = flagging_dir
366
+ self.batch = batch
367
+ self.max_batch_size = max_batch_size
368
+
369
+ self.save_to = None # Used for selenium tests
370
+ self.share = None
371
+ self.share_url = None
372
+ self.local_url = None
373
+
374
+ self.favicon_path = None
375
+
376
+ if self.analytics_enabled:
377
+ data = {
378
+ "mode": self.mode,
379
+ "fn": fn,
380
+ "inputs": inputs,
381
+ "outputs": outputs,
382
+ "live": live,
383
+ "interpretation": interpretation,
384
+ "allow_flagging": allow_flagging,
385
+ "custom_css": self.css is not None,
386
+ "theme": self.theme,
387
+ "version": GRADIO_VERSION,
388
+ }
389
+ utils.initiated_analytics(data)
390
+
391
+ utils.version_check()
392
+ Interface.instances.add(self)
393
+
394
+ param_names = inspect.getfullargspec(self.fn)[0]
395
+ if len(param_names) > 0 and inspect.ismethod(self.fn):
396
+ param_names = param_names[1:]
397
+ for component, param_name in zip(self.input_components, param_names):
398
+ assert isinstance(component, IOComponent)
399
+ if component.label is None:
400
+ component.label = param_name
401
+ for i, component in enumerate(self.output_components):
402
+ assert isinstance(component, IOComponent)
403
+ if component.label is None:
404
+ if len(self.output_components) == 1:
405
+ component.label = "output"
406
+ else:
407
+ component.label = "output " + str(i)
408
+
409
+ if self.allow_flagging != "never":
410
+ if (
411
+ self.interface_type == InterfaceTypes.UNIFIED
412
+ or self.allow_flagging == "auto"
413
+ ):
414
+ self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
415
+ elif self.interface_type == InterfaceTypes.INPUT_ONLY:
416
+ pass
417
+ else:
418
+ self.flagging_callback.setup(
419
+ self.input_components + self.output_components, self.flagging_dir # type: ignore
420
+ )
421
+
422
+ # Render the Gradio UI
423
+ with self:
424
+ self.render_title_description()
425
+
426
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
427
+ interpretation_btn, interpretation_set = None, None
428
+ input_component_column, interpret_component_column = None, None
429
+
430
+ with Row().style(equal_height=False):
431
+ if self.interface_type in [
432
+ InterfaceTypes.STANDARD,
433
+ InterfaceTypes.INPUT_ONLY,
434
+ InterfaceTypes.UNIFIED,
435
+ ]:
436
+ (
437
+ submit_btn,
438
+ clear_btn,
439
+ stop_btn,
440
+ flag_btns,
441
+ input_component_column,
442
+ interpret_component_column,
443
+ interpretation_set,
444
+ ) = self.render_input_column()
445
+ if self.interface_type in [
446
+ InterfaceTypes.STANDARD,
447
+ InterfaceTypes.OUTPUT_ONLY,
448
+ ]:
449
+ (
450
+ submit_btn_out,
451
+ clear_btn_2_out,
452
+ stop_btn_2_out,
453
+ flag_btns_out,
454
+ interpretation_btn,
455
+ ) = self.render_output_column(submit_btn)
456
+ submit_btn = submit_btn or submit_btn_out
457
+ clear_btn = clear_btn or clear_btn_2_out
458
+ stop_btn = stop_btn or stop_btn_2_out
459
+ flag_btns = flag_btns or flag_btns_out
460
+
461
+ assert clear_btn is not None, "Clear button not rendered"
462
+
463
+ self.attach_submit_events(submit_btn, stop_btn)
464
+ self.attach_clear_events(
465
+ clear_btn, input_component_column, interpret_component_column
466
+ )
467
+ self.attach_interpretation_events(
468
+ interpretation_btn,
469
+ interpretation_set,
470
+ input_component_column,
471
+ interpret_component_column,
472
+ )
473
+
474
+ self.attach_flagging_events(flag_btns, clear_btn)
475
+ self.render_examples()
476
+ self.render_article()
477
+
478
+ self.config = self.get_config_file()
479
+
480
+ def render_title_description(self) -> None:
481
+ if self.title:
482
+ Markdown(
483
+ "<h1 style='text-align: center; margin-bottom: 1rem'>"
484
+ + self.title
485
+ + "</h1>"
486
+ )
487
+ if self.description:
488
+ Markdown(self.description)
489
+
490
+ def render_flag_btns(self) -> List[Button]:
491
+ return [Button(label) for label, _ in self.flagging_options]
492
+
493
+ def render_input_column(
494
+ self,
495
+ ) -> Tuple[
496
+ Button | None,
497
+ Button | None,
498
+ Button | None,
499
+ List[Button] | None,
500
+ Column,
501
+ Column | None,
502
+ List[Interpretation] | None,
503
+ ]:
504
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
505
+ interpret_component_column, interpretation_set = None, None
506
+
507
+ with Column(variant="panel"):
508
+ input_component_column = Column()
509
+ with input_component_column:
510
+ for component in self.input_components:
511
+ component.render()
512
+ if self.interpretation:
513
+ interpret_component_column = Column(visible=False)
514
+ interpretation_set = []
515
+ with interpret_component_column:
516
+ for component in self.input_components:
517
+ interpretation_set.append(Interpretation(component))
518
+ with Row():
519
+ if self.interface_type in [
520
+ InterfaceTypes.STANDARD,
521
+ InterfaceTypes.INPUT_ONLY,
522
+ ]:
523
+ clear_btn = Button("Clear")
524
+ if not self.live:
525
+ submit_btn = Button("Submit", variant="primary")
526
+ # Stopping jobs only works if the queue is enabled
527
+ # We don't know if the queue is enabled when the interface
528
+ # is created. We use whether a generator function is provided
529
+ # as a proxy for whether the queue will be enabled.
530
+ # Using a generator function without the queue will raise an error.
531
+ if inspect.isgeneratorfunction(self.fn):
532
+ stop_btn = Button("Stop", variant="stop", visible=False)
533
+ elif self.interface_type == InterfaceTypes.UNIFIED:
534
+ clear_btn = Button("Clear")
535
+ submit_btn = Button("Submit", variant="primary")
536
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
537
+ stop_btn = Button("Stop", variant="stop")
538
+ if self.allow_flagging == "manual":
539
+ flag_btns = self.render_flag_btns()
540
+ elif self.allow_flagging == "auto":
541
+ flag_btns = [submit_btn]
542
+ return (
543
+ submit_btn,
544
+ clear_btn,
545
+ stop_btn,
546
+ flag_btns,
547
+ input_component_column,
548
+ interpret_component_column,
549
+ interpretation_set,
550
+ )
551
+
552
+ def render_output_column(
553
+ self,
554
+ submit_btn_in: Button | None,
555
+ ) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
556
+ submit_btn = submit_btn_in
557
+ interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
558
+
559
+ with Column(variant="panel"):
560
+ for component in self.output_components:
561
+ if not (isinstance(component, State)):
562
+ component.render()
563
+ with Row():
564
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
565
+ clear_btn = Button("Clear")
566
+ submit_btn = Button("Generate", variant="primary")
567
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
568
+ # Stopping jobs only works if the queue is enabled
569
+ # We don't know if the queue is enabled when the interface
570
+ # is created. We use whether a generator function is provided
571
+ # as a proxy for whether the queue will be enabled.
572
+ # Using a generator function without the queue will raise an error.
573
+ stop_btn = Button("Stop", variant="stop", visible=False)
574
+ if self.allow_flagging == "manual":
575
+ flag_btns = self.render_flag_btns()
576
+ elif self.allow_flagging == "auto":
577
+ assert submit_btn is not None, "Submit button not rendered"
578
+ flag_btns = [submit_btn]
579
+ if self.interpretation:
580
+ interpretation_btn = Button("Interpret")
581
+
582
+ return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
583
+
584
+ def render_article(self):
585
+ if self.article:
586
+ Markdown(self.article)
587
+
588
+ def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
589
+ if self.live:
590
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
591
+ assert submit_btn is not None, "Submit button not rendered"
592
+ super().load(self.fn, None, self.output_components)
593
+ # For output-only interfaces, the user probably still wants a "generate"
594
+ # button even if the Interface is live
595
+ submit_btn.click(
596
+ self.fn,
597
+ None,
598
+ self.output_components,
599
+ api_name="predict",
600
+ preprocess=not (self.api_mode),
601
+ postprocess=not (self.api_mode),
602
+ batch=self.batch,
603
+ max_batch_size=self.max_batch_size,
604
+ )
605
+ else:
606
+ for component in self.input_components:
607
+ if isinstance(component, Streamable) and component.streaming:
608
+ component.stream(
609
+ self.fn,
610
+ self.input_components,
611
+ self.output_components,
612
+ api_name="predict",
613
+ preprocess=not (self.api_mode),
614
+ postprocess=not (self.api_mode),
615
+ )
616
+ continue
617
+ if isinstance(component, Changeable):
618
+ component.change(
619
+ self.fn,
620
+ self.input_components,
621
+ self.output_components,
622
+ api_name="predict",
623
+ preprocess=not (self.api_mode),
624
+ postprocess=not (self.api_mode),
625
+ )
626
+ else:
627
+ assert submit_btn is not None, "Submit button not rendered"
628
+ fn = self.fn
629
+ extra_output = []
630
+ if stop_btn:
631
+
632
+ # Wrap the original function to show/hide the "Stop" button
633
+ def fn(*args):
634
+ # The main idea here is to call the original function
635
+ # and append some updates to keep the "Submit" button
636
+ # hidden and the "Stop" button visible.
638
+ # The 'finally' block hides the "Stop" button and
639
+ # shows the "Submit" button. Having a 'finally' block
640
+ # ensures the UI is "reset" even if an exception is raised.
640
+ try:
641
+ for output in self.fn(*args):
642
+ if len(self.output_components) == 1 and not self.batch:
643
+ output = [output]
644
+ output = [o for o in output]
645
+ yield output + [
646
+ Button.update(visible=False),
647
+ Button.update(visible=True),
648
+ ]
649
+ finally:
650
+ yield [
651
+ {"__type__": "generic_update"}
652
+ for _ in self.output_components
653
+ ] + [Button.update(visible=True), Button.update(visible=False)]
654
+
655
+ extra_output = [submit_btn, stop_btn]
656
+ pred = submit_btn.click(
657
+ fn,
658
+ self.input_components,
659
+ self.output_components + extra_output,
660
+ api_name="predict",
661
+ scroll_to_output=True,
662
+ preprocess=not (self.api_mode),
663
+ postprocess=not (self.api_mode),
664
+ batch=self.batch,
665
+ max_batch_size=self.max_batch_size,
666
+ )
667
+ if stop_btn:
668
+ submit_btn.click(
669
+ lambda: (
670
+ submit_btn.update(visible=False),
671
+ stop_btn.update(visible=True),
672
+ ),
673
+ inputs=None,
674
+ outputs=[submit_btn, stop_btn],
675
+ queue=False,
676
+ )
677
+ stop_btn.click(
678
+ lambda: (
679
+ submit_btn.update(visible=True),
680
+ stop_btn.update(visible=False),
681
+ ),
682
+ inputs=None,
683
+ outputs=[submit_btn, stop_btn],
684
+ cancels=[pred],
685
+ queue=False,
686
+ )
687
+
688
+ def attach_clear_events(
689
+ self,
690
+ clear_btn: Button,
691
+ input_component_column: Column | None,
692
+ interpret_component_column: Column | None,
693
+ ):
694
+ clear_btn.click(
695
+ None,
696
+ [],
697
+ (
698
+ self.input_components
699
+ + self.output_components
700
+ + ([input_component_column] if input_component_column else [])
701
+ + ([interpret_component_column] if self.interpretation else [])
702
+ ), # type: ignore
703
+ _js=f"""() => {json.dumps(
704
+ [getattr(component, "cleared_value", None)
705
+ for component in self.input_components + self.output_components] + (
706
+ [Column.update(visible=True)]
707
+ if self.interface_type
708
+ in [
709
+ InterfaceTypes.STANDARD,
710
+ InterfaceTypes.INPUT_ONLY,
711
+ InterfaceTypes.UNIFIED,
712
+ ]
713
+ else []
714
+ )
715
+ + ([Column.update(visible=False)] if self.interpretation else [])
716
+ )}
717
+ """,
718
+ )
719
+
720
+ def attach_interpretation_events(
721
+ self,
722
+ interpretation_btn: Button | None,
723
+ interpretation_set: List[Interpretation] | None,
724
+ input_component_column: Column | None,
725
+ interpret_component_column: Column | None,
726
+ ):
727
+ if interpretation_btn:
728
+ interpretation_btn.click(
729
+ self.interpret_func,
730
+ inputs=self.input_components + self.output_components,
731
+ outputs=(interpretation_set or []) + [input_component_column, interpret_component_column], # type: ignore
732
+ preprocess=False,
733
+ )
734
+
735
+ def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button):
736
+ if flag_btns:
737
+ if self.interface_type in [
738
+ InterfaceTypes.STANDARD,
739
+ InterfaceTypes.OUTPUT_ONLY,
740
+ InterfaceTypes.UNIFIED,
741
+ ]:
742
+ if self.allow_flagging == "auto":
743
+ flag_method = FlagMethod(
744
+ self.flagging_callback, "", "", visual_feedback=False
745
+ )
746
+ flag_btns[0].click( # flag_btns[0] is just the "Submit" button
747
+ flag_method,
748
+ inputs=self.input_components,
749
+ outputs=None,
750
+ preprocess=False,
751
+ queue=False,
752
+ )
753
+ return
754
+
755
+ if self.interface_type == InterfaceTypes.UNIFIED:
756
+ flag_components = self.input_components
757
+ else:
758
+ flag_components = self.input_components + self.output_components
759
+
760
+ for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
761
+ assert isinstance(value, str)
762
+ flag_method = FlagMethod(self.flagging_callback, label, value)
763
+ flag_btn.click(
764
+ lambda: Button.update(value="Saving...", interactive=False),
765
+ None,
766
+ flag_btn,
767
+ queue=False,
768
+ )
769
+ flag_btn.click(
770
+ flag_method,
771
+ inputs=flag_components,
772
+ outputs=flag_btn,
773
+ preprocess=False,
774
+ queue=False,
775
+ )
776
+ clear_btn.click(
777
+ flag_method.reset,
778
+ None,
779
+ flag_btn,
780
+ queue=False,
781
+ )
782
+
783
+ def render_examples(self):
784
+ if self.examples:
785
+ non_state_inputs = [
786
+ c for c in self.input_components if not isinstance(c, State)
787
+ ]
788
+ non_state_outputs = [
789
+ c for c in self.output_components if not isinstance(c, State)
790
+ ]
791
+ self.examples_handler = Examples(
792
+ examples=self.examples,
793
+ inputs=non_state_inputs, # type: ignore
794
+ outputs=non_state_outputs, # type: ignore
795
+ fn=self.fn,
796
+ cache_examples=self.cache_examples,
797
+ examples_per_page=self.examples_per_page,
798
+ _api_mode=self.api_mode,
799
+ batch=self.batch,
800
+ )
801
+
802
+ def __str__(self):
803
+ return self.__repr__()
804
+
805
+ def __repr__(self):
806
+ repr = f"Gradio Interface for: {self.__name__}"
807
+ repr += "\n" + "-" * len(repr)
808
+ repr += "\ninputs:"
809
+ for component in self.input_components:
810
+ repr += "\n|-{}".format(str(component))
811
+ repr += "\noutputs:"
812
+ for component in self.output_components:
813
+ repr += "\n|-{}".format(str(component))
814
+ return repr
815
+
816
+ async def interpret_func(self, *args):
817
+ return await self.interpret(list(args)) + [
818
+ Column.update(visible=False),
819
+ Column.update(visible=True),
820
+ ]
821
+
822
+ async def interpret(self, raw_input: List[Any]) -> List[Any]:
823
+ return [
824
+ {"original": raw_value, "interpretation": interpretation}
825
+ for interpretation, raw_value in zip(
826
+ (await interpretation.run_interpret(self, raw_input))[0], raw_input
827
+ )
828
+ ]
829
+
830
+ def test_launch(self) -> None:
831
+ """
832
+ Deprecated.
833
+ """
834
+ warnings.warn("The Interface.test_launch() function is deprecated.")
835
+
836
+
837
+ @document()
838
+ class TabbedInterface(Blocks):
839
+ """
840
+ A TabbedInterface is created by providing a list of Interfaces, each of which gets
841
+ rendered in a separate tab.
842
+ Demos: stt_or_tts
843
+ """
844
+
845
+ def __init__(
846
+ self,
847
+ interface_list: List[Interface],
848
+ tab_names: List[str] | None = None,
849
+ title: str | None = None,
850
+ theme: Theme | None = None,
851
+ analytics_enabled: bool | None = None,
852
+ css: str | None = None,
853
+ ):
854
+ """
855
+ Parameters:
856
+ interface_list: a list of interfaces to be rendered in tabs.
857
+ tab_names: a list of tab names. If None, the tab names default to "Tab 0", "Tab 1", etc.
858
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
859
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
860
+ css: custom css or path to custom css file to apply to entire Blocks
861
+ Returns:
862
+ a Gradio Tabbed Interface for the given interfaces
863
+ """
864
+ super().__init__(
865
+ title=title or "Gradio",
866
+ theme=theme,
867
+ analytics_enabled=analytics_enabled,
868
+ mode="tabbed_interface",
869
+ css=css,
870
+ )
871
+ if tab_names is None:
872
+ tab_names = ["Tab {}".format(i) for i in range(len(interface_list))]
873
+ with self:
874
+ if title:
875
+ Markdown(
876
+ "<h1 style='text-align: center; margin-bottom: 1rem'>"
877
+ + title
878
+ + "</h1>"
879
+ )
880
+ with Tabs():
881
+ for (interface, tab_name) in zip(interface_list, tab_names):
882
+ with Tab(label=tab_name):
883
+ interface.render()
884
+
885
+
886
+ def close_all(verbose: bool = True) -> None:
887
+ for io in Interface.get_instances():
888
+ io.close(verbose)
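
As a rough illustration of how the Interface, TabbedInterface, and close_all definitions shown in this file are meant to be used, here is a minimal usage sketch. It is not part of the commit: the functions greet and shout, the flagging options, and the tab names are made-up placeholders, and only the public API visible in this file (Interface, TabbedInterface, close_all, launch) is assumed.

import gradio as gr

def greet(name):
    # Hypothetical text-to-text function used as the prediction callable.
    return f"Hello, {name}!"

def shout(name):
    # Second hypothetical function, so the tabbed demo has two tabs.
    return name.upper() + "!"

# flagging_options given as a list of strings is converted internally to
# ("Flag as <option>", <option>) tuples, as handled in Interface.__init__ above.
greet_demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
    allow_flagging="manual",
    flagging_options=["incorrect", "offensive"],
)
shout_demo = gr.Interface(fn=shout, inputs="text", outputs="text")

# Each Interface is rendered in its own tab; if tab_names were omitted,
# auto-generated names would be used instead.
demo = gr.TabbedInterface([greet_demo, shout_demo], tab_names=["Greet", "Shout"])

if __name__ == "__main__":
    demo.launch()
    # gr.close_all() would shut down every running Interface instance,
    # which is handy in long-lived sessions such as notebooks.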