"""Utility function for gradio/external.py"""

import base64
import json
import math
import operator
import re
import warnings
from typing import Any, Dict, List, Tuple

import requests
import websockets
import yaml
from packaging import version
from websockets.legacy.protocol import WebSocketCommonProtocol

from gradio import components, exceptions

##################
# Helper functions for processing tabular data
##################


def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
    """Pulls example input data for a tabular model from the YAML front matter of its
    README.md on the Hugging Face Hub."""
    readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
        example_data = {}
    else:
        yaml_regex = re.search(
            r"(?:^|[\r\n])---[\n\r]+([\S\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is None:
            example_data = {}
        else:
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
            )
            example_data = example_yaml.get("widget", {}).get("structuredData", {})
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference API
    for data in example_data.values():
        for i, val in enumerate(data):
            if isinstance(val, float) and math.isnan(val):
                data[i] = "NaN"
    return example_data
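
# Illustrative sketch (not part of the original module): get_tabular_examples expects
# the model's README.md to carry YAML front matter roughly shaped like the hypothetical
# snippet below, from which it pulls the "widget" -> "structuredData" mapping.
#
#   ---
#   widget:
#     structuredData:
#       age: [22, 38]
#       fare: [7.25, 71.28]
#   ---
#
# With that front matter, the function would return
# {"age": [22, 38], "fare": [7.25, 71.28]}.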


def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    headers = list(example_data.keys())
    n_rows = max(len(example_data[header] or []) for header in headers)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
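
# Illustrative usage (hypothetical values, not part of the original file):
# cols_to_rows pads shorter columns with "NaN" so every row has an entry for
# every header.
#
#   cols_to_rows({"age": [22, 38], "fare": [7.25]})
#   # -> (["age", "fare"], [[22, 7.25], [38, "NaN"]])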


def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    data_column_wise = {}
    for i, header in enumerate(incoming_data["headers"]):
        data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
    return {"inputs": {"data": data_column_wise}}


##################
# Helper functions for processing other kinds of data
##################


def postprocess_label(scores: Dict) -> Dict:
    sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
    return {
        "label": sorted_pred[0][0],
        "confidences": [
            {"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
        ],
    }
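
# Illustrative usage (hypothetical scores): postprocess_label reshapes a
# {label: score} mapping into the dictionary format consumed by the Label component.
#
#   postprocess_label({"cat": 0.3, "dog": 0.7})
#   # -> {"label": "dog",
#   #     "confidences": [{"label": "dog", "confidence": 0.7},
#   #                     {"label": "cat", "confidence": 0.3}]}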


def encode_to_base64(r: requests.Response) -> str:
    # Handles the different ways HF API returns the prediction
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    data_prefix = ";base64,"
    # Case 1: base64 representation already includes data prefix
    if data_prefix in base64_repr:
        return base64_repr
    else:
        content_type = r.headers.get("content-type")
        # Case 2: the data prefix is a key in the response
        if content_type == "application/json":
            try:
                content_type = r.json()[0]["content-type"]
                base64_repr = r.json()[0]["blob"]
            except KeyError:
                raise ValueError(
                    "Cannot determine content type returned by external API."
                )
        # Case 3: the data prefix is included in the response headers
        else:
            pass
        new_base64 = "data:{};base64,".format(content_type) + base64_repr
        return new_base64
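
# Illustrative note (hypothetical example): for a response whose body is raw bytes,
# such as an image, encode_to_base64 would fall through to Case 3 and return a data
# URI like "data:image/png;base64,iVBOR..." built from the content-type header.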


##################
# Helper functions for connecting to websockets
##################


async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
    """Sends the session hash and input data over the websocket when the server
    requests them, then waits until processing completes and returns the output."""
    completed = False
    resp = {}
    while not completed:
        msg = await websocket.recv()
        resp = json.loads(msg)
        if resp["msg"] == "queue_full":
            raise exceptions.Error("Queue is full! Please try again.")
        if resp["msg"] == "send_hash":
            await websocket.send(hash_data)
        elif resp["msg"] == "send_data":
            await websocket.send(data)
        completed = resp["msg"] == "process_completed"
    return resp["output"]


def get_ws_fn(ws_url, headers):
    async def ws_fn(data, hash_data):
        async with websockets.connect(  # type: ignore
            ws_url, open_timeout=10, extra_headers=headers
        ) as websocket:
            return await get_pred_from_ws(websocket, data, hash_data)

    return ws_fn


def use_websocket(config, dependency):
    """Returns True if the Space should be queried over a websocket queue: queueing
    is enabled, the gradio version is >= 3.2, and the dependency uses the queue."""
    queue_enabled = config.get("enable_queue", False)
    queue_uses_websocket = version.parse(
        config.get("version", "2.0")
    ) >= version.Version("3.2")
    dependency_uses_queue = dependency.get("queue", False) is not False
    return queue_enabled and queue_uses_websocket and dependency_uses_queue
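
# Illustrative usage (hypothetical config/dependency dicts): the websocket queue
# protocol is only used when the Space enables queueing, runs gradio >= 3.2, and
# the dependency has not opted out of the queue.
#
#   use_websocket({"version": "3.4", "enable_queue": True}, {"queue": None})   # -> True
#   use_websocket({"version": "3.4", "enable_queue": True}, {"queue": False})  # -> False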


##################
# Helper function for cleaning up an Interface loaded from HF Spaces
##################


def streamline_spaces_interface(config: Dict) -> Dict:
    """Streamlines the interface config dictionary to remove unnecessary keys."""
    config["inputs"] = [
        components.get_component_instance(component)
        for component in config["input_components"]
    ]
    config["outputs"] = [
        components.get_component_instance(component)
        for component in config["output_components"]
    ]
    parameters = {
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "title",
    }
    config = {k: config[k] for k in parameters}
    return config