|
"""
Ways to transform interfaces to produce new interfaces.
"""
|
import asyncio |
|
import warnings |
|
|
|
import gradio |
|
from gradio.documentation import document, set_documentation_group |
|
|
|
set_documentation_group("mix_interface") |
|
|
|
|
|
@document()
class Parallel(gradio.Interface):
    """
    Creates a new Interface consisting of multiple Interfaces in parallel (comparing their outputs).
    The Interfaces to put in Parallel must share the same input components (but can have different output components).

    Demos: interface_parallel, interface_parallel_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be compared in parallel
            options: additional kwargs that are passed into the new Interface object to customize it
        Returns:
            an Interface object comparing the given models
        """
        # Flat list of every output component, in the order the interfaces were given.
        outputs = []

        for candidate in interfaces:
            # Only warn on a wrong type; the object is still used below, so a
            # duck-typed stand-in exposing the same attributes keeps working.
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Parallel requires all inputs to be of type Interface. "
                    "May not work as expected."
                )
            outputs += candidate.output_components

        async def parallel_fn(*args):
            # Run every wrapped interface concurrently; each gets its own copy
            # of the positional inputs as a list.
            # NOTE(review): index 0 is presumably the interface's sole function - confirm.
            gathered = await asyncio.gather(
                *(member.call_function(0, list(args)) for member in interfaces)
            )
            predictions = [entry["prediction"] for entry in gathered]

            # Flatten the per-interface predictions into one list aligned with `outputs`.
            flattened = []
            for member, prediction in zip(interfaces, predictions):
                if len(member.output_components) == 1:
                    # A single-output interface yields a bare value.
                    flattened.append(prediction)
                else:
                    # A multi-output interface yields a sequence of values.
                    flattened.extend(prediction)

            # With exactly one output overall, return the bare value, not a list.
            return flattened[0] if len(outputs) == 1 else flattened

        # Name the combined function after the joined names of its members.
        parallel_fn.__name__ = " | ".join(member.__name__ for member in interfaces)

        # Caller-supplied options are merged last so they override the defaults.
        init_kwargs = {
            "fn": parallel_fn,
            "inputs": interfaces[0].input_components,
            "outputs": outputs,
            **options,
        }
        super().__init__(**init_kwargs)
|
|
|
|
|
@document()
class Series(gradio.Interface):
    """
    Creates a new Interface from multiple Interfaces in series (the output of one is fed as the input to the next,
    and so the input and output components must agree between the interfaces).

    Demos: interface_series, interface_series_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be connected in series
            options: additional kwargs that are passed into the new Interface object to customize it
        Returns:
            an Interface object connecting the given models
        """

        async def connected_fn(*data):
            final_idx = len(interfaces) - 1

            for position, stage in enumerate(interfaces):
                # Every stage except the first preprocesses its own inputs
                # (unless that stage is in api_mode).
                if position > 0 and not stage.api_mode:
                    data = [
                        component.preprocess(data[slot])
                        for slot, component in enumerate(stage.input_components)
                    ]

                # Run this stage on the current data.
                # NOTE(review): index 0 is presumably the interface's sole function - confirm.
                result = await stage.call_function(0, list(data))
                data = result["prediction"]
                if len(stage.output_components) == 1:
                    # Wrap a lone value so later steps can index it uniformly.
                    data = [data]

                # Every stage except the last postprocesses its own outputs
                # (unless that stage is in api_mode).
                if position < final_idx and not stage.api_mode:
                    data = [
                        component.postprocess(data[slot])
                        for slot, component in enumerate(stage.output_components)
                    ]

            # Undo the single-value wrapping for the last stage's result.
            return data[0] if len(interfaces[-1].output_components) == 1 else data

        for candidate in interfaces:
            # Only warn on a wrong type; the object is still used as given.
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Series requires all inputs to be of type Interface. May "
                    "not work as expected."
                )

        # Name the chained function after the joined names of its stages.
        connected_fn.__name__ = " => ".join(stage.__name__ for stage in interfaces)

        # Inputs and api_mode come from the first stage, outputs from the last;
        # caller-supplied options are merged last so they override these defaults.
        init_kwargs = {
            "fn": connected_fn,
            "inputs": interfaces[0].input_components,
            "outputs": interfaces[-1].output_components,
            "_api_mode": interfaces[0].api_mode,
            **options,
        }
        super().__init__(**init_kwargs)
|
|