File size: 4,962 Bytes
11b6e82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Ways to transform interfaces to produce new interfaces
"""
import asyncio
import warnings

import gradio
from gradio.documentation import document, set_documentation_group

# Presumably selects the documentation group that the @document()-decorated
# classes in this module are filed under — confirm against gradio.documentation.
set_documentation_group("mix_interface")


@document()
class Parallel(gradio.Interface):
    """
    Builds a single Interface that runs several Interfaces side by side so that their outputs can be compared.
    Every Interface passed in must use the same input components, while their output components may differ.

    Demos: interface_parallel, interface_parallel_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: the Interface objects whose outputs should be shown side by side
            options: extra keyword arguments forwarded to the new Interface object to customize it
        Returns:
            an Interface object comparing the given models
        """
        # Output components of the combined Interface: each member's outputs, in
        # the order the members were given.
        merged_outputs = []
        for candidate in interfaces:
            # A non-Interface argument only triggers a warning; it is still used below.
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Parallel requires all inputs to be of type Interface. "
                    "May not work as expected."
                )
            merged_outputs.extend(candidate.output_components)

        async def parallel_fn(*args):
            # Every member receives its own copy of the same argument list and all
            # calls are awaited together; gather keeps results in member order.
            pending = [member.call_function(0, list(args)) for member in interfaces]
            outcomes = await asyncio.gather(*pending)
            predictions = [outcome["prediction"] for outcome in outcomes]

            # A member with exactly one output component contributes its prediction
            # as one item; any other member's prediction is spliced in item by item.
            flattened = []
            for member, prediction in zip(interfaces, predictions):
                if len(member.output_components) != 1:
                    flattened.extend(prediction)
                else:
                    flattened.append(prediction)

            # With a single output component overall, hand back the bare value.
            return flattened[0] if len(merged_outputs) == 1 else flattened

        parallel_fn.__name__ = " | ".join(member.__name__ for member in interfaces)

        defaults = {
            "fn": parallel_fn,
            "inputs": interfaces[0].input_components,
            "outputs": merged_outputs,
        }
        # Caller-supplied options override the defaults assembled above.
        super().__init__(**{**defaults, **options})


@document()
class Series(gradio.Interface):
    """
    Builds a single Interface that chains several Interfaces one after another (each Interface receives the output
    of the one before it, so the output and input components of neighbouring Interfaces must agree).

    Demos: interface_series, interface_series_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: the Interface objects to chain together, in the order they should run
            options: extra keyword arguments forwarded to the new Interface object to customize it
        Returns:
            an Interface object connecting the given models
        """
        # A non-Interface argument only triggers a warning; construction continues.
        for candidate in interfaces:
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Series requires all inputs to be of type Interface. May "
                    "not work as expected."
                )

        last_index = len(interfaces) - 1

        async def connected_fn(*data):
            for position, stage in enumerate(interfaces):
                # The first stage is not preprocessed here since the Series interface
                # will include that step; stages in api_mode are never preprocessed.
                if not (position == 0 or stage.api_mode):
                    data = [
                        component.preprocess(data[slot])
                        for slot, component in enumerate(stage.input_components)
                    ]

                # Stages run strictly one after another, each fed the previous result.
                outcome = await stage.call_function(0, list(data))
                data = outcome["prediction"]
                # Normalise a lone output to a one-element list so it can be indexed.
                if len(stage.output_components) == 1:
                    data = [data]

                # The final stage is not postprocessed here since the Series interface
                # will include that step; stages in api_mode are never postprocessed.
                if not (position >= last_index or stage.api_mode):
                    data = [
                        component.postprocess(data[slot])
                        for slot, component in enumerate(stage.output_components)
                    ]

            # `stage` still refers to the last interface once the loop has finished;
            # unwrap the one-element list when that interface has a single output.
            return data[0] if len(stage.output_components) == 1 else data  # type: ignore

        connected_fn.__name__ = " => ".join(member.__name__ for member in interfaces)

        defaults = {
            "fn": connected_fn,
            "inputs": interfaces[0].input_components,
            "outputs": interfaces[-1].output_components,
            "_api_mode": interfaces[0].api_mode,  # TODO: set api_mode per-interface
        }
        # Caller-supplied options override the defaults assembled above.
        super().__init__(**{**defaults, **options})