|
"""Utility function for gradio/external.py""" |
|
|
|
import base64 |
|
import json |
|
import math |
|
import operator |
|
import re |
|
import warnings |
|
from typing import Any, Dict, List, Tuple |
|
|
|
import requests |
|
import websockets |
|
import yaml |
|
from packaging import version |
|
from websockets.legacy.protocol import WebSocketCommonProtocol |
|
|
|
from gradio import components, exceptions |
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
    """Load the example table declared in a Hugging Face model's README.

    Downloads ``README.md`` for ``model_name`` from huggingface.co, parses the
    YAML front matter delimited by ``---`` lines and returns the mapping stored
    under ``widget.structuredData``. Float NaN entries inside each column list
    are replaced in place by the string ``"NaN"``.

    Args:
        model_name: Repository id interpolated into the README URL.

    Returns:
        The ``structuredData`` mapping (column name -> list of values).

    Raises:
        ValueError: If the README request does not return status 200, the
            README has no front matter, or the front matter holds no
            non-empty ``widget.structuredData`` mapping.
    """
    example_data = {}
    readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is not None:
            # Only the first YAML document (the front matter) is of interest.
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.end()]), None
            )
            # Front matter may load as None (empty) or a non-mapping, and
            # "widget" may be a list; treat all of those as "no example data"
            # so the caller gets the explanatory ValueError below instead of
            # an AttributeError.
            widget = example_yaml.get("widget") if isinstance(example_yaml, dict) else None
            structured = widget.get("structuredData") if isinstance(widget, dict) else None
            if isinstance(structured, dict):
                example_data = structured

    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )

    for column in example_data.values():
        # cols_to_rows accepts empty/None columns, so tolerate them here too.
        if not isinstance(column, list):
            continue
        for i, val in enumerate(column):
            if isinstance(val, float) and math.isnan(val):
                column[i] = "NaN"
    return example_data
|
|
|
|
|
def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    """Transpose column-oriented example data into headers plus rows.

    Args:
        example_data: Mapping of header to its list of values. A falsy column
            (``None`` or an empty list) is treated as having no values.

    Returns:
        A ``(headers, rows)`` tuple. ``headers`` keeps the mapping's key order
        and every row holds one value per header. Columns shorter than the
        longest one are padded with the string ``"NaN"``. An empty mapping
        yields ``([], [])``.
    """
    headers = list(example_data)
    # Resolve each column once instead of on every row.
    columns = [example_data[header] or [] for header in headers]
    # default=0 keeps an empty mapping from raising ValueError in max().
    n_rows = max((len(column) for column in columns), default=0)
    data = [
        [column[row_index] if row_index < len(column) else "NaN" for column in columns]
        for row_index in range(n_rows)
    ]
    return headers, data
|
|
|
|
|
def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    """Turn row-oriented table data into a column-oriented request payload.

    Args:
        incoming_data: Mapping with a ``"headers"`` sequence and a ``"data"``
            sequence of rows; each row is indexed by header position.

    Returns:
        ``{"inputs": {"data": {header: [str(value), ...]}}}`` where every
        cell has been converted with ``str``.
    """
    columns = {
        name: [str(record[position]) for record in incoming_data["data"]]
        for position, name in enumerate(incoming_data["headers"])
    }
    payload = {"data": columns}
    return {"inputs": payload}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess_label(scores: Dict) -> Dict:
    """Rank label scores and package them with the top label.

    Args:
        scores: Mapping of label to score.

    Returns:
        A dict with ``"label"`` set to the highest-scoring label and
        ``"confidences"`` listing every ``{"label", "confidence"}`` pair in
        descending score order (ties keep the mapping's original order).

    Raises:
        IndexError: If ``scores`` is empty.
    """
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    best_label, _ = ranked[0]
    confidences = []
    for name, score in ranked:
        confidences.append({"label": name, "confidence": score})
    return {"label": best_label, "confidences": confidences}
|
|
|
|
|
def encode_to_base64(r: requests.Response) -> str:
    """Build a ``data:<content-type>;base64,<payload>`` string from a response.

    For most responses the raw body is base64-encoded and labelled with the
    ``content-type`` header. When that header is exactly
    ``"application/json"``, the body is parsed as JSON and the
    ``"content-type"`` and ``"blob"`` fields of its first element are used
    instead.

    Args:
        r: Response exposing ``content``, ``headers`` and ``json()``.

    Returns:
        The data-URI style string.

    Raises:
        ValueError: If a JSON response does not contain a first element with
            ``"content-type"`` and ``"blob"`` fields.
    """
    # Base64 text never contains ';' or ',', so the encoded body can never
    # already carry a ";base64," prefix; it always has to be added here.
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    content_type = r.headers.get("content-type")

    if content_type == "application/json":
        try:
            # Parse the body once and read both fields from the same entry.
            payload = r.json()[0]
            content_type = payload["content-type"]
            base64_repr = payload["blob"]
        except (KeyError, IndexError, TypeError) as err:
            # An empty list or a non-mapping entry is the same failure as a
            # missing key: the payload does not describe its content.
            raise ValueError(
                "Cannot determine content type returned by external API."
            ) from err

    return f"data:{content_type};base64," + base64_repr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
    """Exchange queue messages over ``websocket`` until a prediction arrives.

    Each received message is decoded as JSON and dispatched on its ``"msg"``
    field: ``"send_hash"`` is answered with ``hash_data``, ``"send_data"``
    with ``data``, and any other value is ignored until
    ``"process_completed"`` is seen.

    Args:
        websocket: Open connection providing awaitable ``recv``/``send``.
        data: Payload sent when the server asks for the data.
        hash_data: Payload sent when the server asks for the hash.

    Returns:
        The ``"output"`` field of the ``"process_completed"`` message.

    Raises:
        exceptions.Error: If the server reports ``"queue_full"``.
    """
    while True:
        message = json.loads(await websocket.recv())
        status = message["msg"]
        if status == "queue_full":
            raise exceptions.Error("Queue is full! Please try again.")
        if status == "process_completed":
            return message["output"]
        if status == "send_hash":
            await websocket.send(hash_data)
        elif status == "send_data":
            await websocket.send(data)
|
|
|
|
|
def get_ws_fn(ws_url, headers):
    """Create a coroutine function that fetches one prediction over a websocket.

    Args:
        ws_url: URL passed to ``websockets.connect``.
        headers: Extra headers passed to ``websockets.connect``.

    Returns:
        An ``async`` function ``ws_fn(data, hash_data)`` that opens a new
        connection (10 second open timeout) on every call, delegates to
        ``get_pred_from_ws`` and returns its result.
    """

    async def ws_fn(data, hash_data):
        connection = websockets.connect(
            ws_url, open_timeout=10, extra_headers=headers
        )
        async with connection as socket:
            prediction = await get_pred_from_ws(socket, data, hash_data)
        return prediction

    return ws_fn
|
|
|
|
|
def use_websocket(config, dependency):
    """Tell whether a dependency should be called through the websocket queue.

    All three conditions must hold: ``config["enable_queue"]`` is truthy, the
    config's ``"version"`` (``"2.0"`` when absent) is at least 3.2, and the
    dependency's ``"queue"`` entry is present and not the literal ``False``.

    Args:
        config: Mapping read for ``"enable_queue"`` and ``"version"``.
        dependency: Mapping read for ``"queue"``.

    Returns:
        The first falsy condition, or the result of the last one. This is the
        same value an ``and`` chain over the three conditions produces.
    """
    queue_flag = config.get("enable_queue", False)
    app_version = version.parse(config.get("version", "2.0"))
    websocket_capable = app_version >= version.Version("3.2")
    # Identity check on purpose: only the literal False (also the default for
    # a missing key) disables the queue; a value such as None does not.
    dependency_queued = dependency.get("queue", False) is not False

    if not queue_flag:
        return queue_flag
    if not websocket_capable:
        return websocket_capable
    return dependency_queued
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def streamline_spaces_interface(config: Dict) -> Dict:
    """Reduce an interface config to the keys needed to rebuild the interface.

    Component instances are created from ``config["input_components"]`` and
    ``config["output_components"]`` via ``components.get_component_instance``
    and stored back on ``config`` under ``"inputs"`` and ``"outputs"`` (the
    passed-in dict is modified). A new dict holding only the retained keys is
    then returned.

    Args:
        config: Interface config; must contain ``"input_components"``,
            ``"output_components"``, ``"article"``, ``"description"``,
            ``"flagging_options"``, ``"theme"`` and ``"title"``.

    Returns:
        A new dict with exactly the keys ``article``, ``description``,
        ``flagging_options``, ``inputs``, ``outputs``, ``theme`` and ``title``.

    Raises:
        KeyError: If any required key is missing from ``config``.
    """
    for target, source in (
        ("inputs", "input_components"),
        ("outputs", "output_components"),
    ):
        instances = []
        for spec in config[source]:
            instances.append(components.get_component_instance(spec))
        config[target] = instances

    kept_keys = {
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "theme",
        "title",
    }
    streamlined = {}
    for key in kept_keys:
        streamlined[key] = config[key]
    return streamlined
|
|