File size: 4,962 Bytes
947e9b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
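"""
Ways to combine existing Interfaces into new Interfaces: Parallel runs several
Interfaces side by side on the same inputs, and Series chains Interfaces so the
output of one is fed as the input to the next.
"""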
import asyncio
import warnings
import gradio
from gradio.documentation import document, set_documentation_group
# Presumably registers the @document()-decorated classes below under the
# "mix_interface" group of gradio's generated documentation
# (NOTE(review): confirm against gradio.documentation.set_documentation_group).
set_documentation_group("mix_interface")
@document()
class Parallel(gradio.Interface):
    """
    Creates a new Interface that runs multiple Interfaces in parallel on the same inputs (comparing their outputs).
    The Interfaces to put in Parallel must share the same input components (their output components may differ).
    Demos: interface_parallel, interface_parallel_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects whose outputs are to be compared in parallel
            options: additional kwargs forwarded to the new Interface object to customize it
        Returns:
            an Interface object comparing the given models
        """
        # The combined Interface shows every member's outputs, in member order.
        # Warning and collecting happen in the same pass so that a non-Interface
        # member still triggers its warning before any attribute access fails.
        combined_outputs = []
        for candidate in interfaces:
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Parallel requires all inputs to be of type Interface. "
                    "May not work as expected."
                )
            combined_outputs.extend(candidate.output_components)

        async def parallel_fn(*args):
            # Launch every member's prediction concurrently; each gets its own
            # copy of the argument list, and gather() keeps the member order.
            pending = (member.call_function(0, list(args)) for member in interfaces)
            results = await asyncio.gather(*pending)
            predictions = [result["prediction"] for result in results]

            # Single-output members contribute one value; multi-output members
            # contribute each of their values, flattened into one list.
            flattened = []
            for member, prediction in zip(interfaces, predictions):
                if len(member.output_components) == 1:
                    flattened.append(prediction)
                else:
                    flattened.extend(prediction)

            # With exactly one output overall, return the bare value, not a list.
            return flattened[0] if len(combined_outputs) == 1 else flattened

        parallel_fn.__name__ = " | ".join(member.__name__ for member in interfaces)

        # All members share inputs, so the first member's input components are used;
        # caller-supplied options override these defaults.
        super().__init__(
            **{
                "fn": parallel_fn,
                "inputs": interfaces[0].input_components,
                "outputs": combined_outputs,
                **options,
            }
        )
@document()
class Series(gradio.Interface):
    """
    Creates a new Interface from multiple Interfaces in series (the output of one is fed as the input to the next,
    and so the input and output components must agree between the interfaces).
    Demos: interface_series, interface_series_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be connected in series
            options: additional kwargs that are passed into the new Interface object to customize it
        Returns:
            an Interface object connecting the given models
        """

        async def connected_fn(*data):
            # Index of the final interface; loop-invariant, so computed once.
            last_index = len(interfaces) - 1
            for idx, interface in enumerate(interfaces):
                # Skip preprocessing for the first interface, since the Series
                # interface itself preprocesses its inputs.
                if idx > 0 and not interface.api_mode:
                    data = [
                        input_component.preprocess(data[i])
                        for i, input_component in enumerate(interface.input_components)
                    ]

                # Run the predictions sequentially; each stage's output feeds the next.
                data = (await interface.call_function(0, list(data)))["prediction"]
                if len(interface.output_components) == 1:
                    # Normalize a single output to a one-element list so the
                    # indexing below (and in the next stage) works uniformly.
                    data = [data]

                # Skip postprocessing for the final interface, since the Series
                # interface itself postprocesses its outputs.
                if idx < last_index and not interface.api_mode:
                    data = [
                        output_component.postprocess(data[i])
                        for i, output_component in enumerate(interface.output_components)
                    ]

            # Refer to the final interface explicitly rather than relying on the
            # loop variable leaking out of the `for` loop above.
            if len(interfaces[-1].output_components) == 1:
                return data[0]
            return data

        for interface in interfaces:
            if not isinstance(interface, gradio.Interface):
                warnings.warn(
                    "Series requires all inputs to be of type Interface. May "
                    "not work as expected."
                )

        connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces])
        kwargs = {
            "fn": connected_fn,
            # The chain takes the first interface's inputs and produces the last one's outputs.
            "inputs": interfaces[0].input_components,
            "outputs": interfaces[-1].output_components,
            "_api_mode": interfaces[0].api_mode,  # TODO: set api_mode per-interface
        }
        kwargs.update(options)
        super().__init__(**kwargs)
|