Spaces:
Sleeping
Sleeping
from datetime import datetime | |
from flask import Flask | |
from flask import make_response | |
from flask import request | |
from flask import send_from_directory, redirect | |
from typing import Literal | |
import json | |
import logging | |
import numpy as np | |
import os | |
import portpicker | |
import requests | |
import shutil | |
import sys | |
import threading | |
import traceback | |
import urllib.parse | |
import zipfile | |
# Version stamp of the prebuilt web-app bundle that Server() downloads.
_VISUAL_BLOCKS_BUNDLE_VERSION = "1716228179"

# Werkzeug would otherwise log Flask's "development server" warning. Running
# the dev server is acceptable here, so its logger is switched off entirely.
_werkzeug_logger = logging.getLogger("werkzeug")
_werkzeug_logger.disabled = True

# Per-type registries of user functions, each mapping a function's
# __name__ to the function object itself.
GENERIC_FNS, TEXT_TO_TEXT_FNS, TEXT_TO_TENSORS_FNS = {}, {}, {}
def register_vb_fn(
    type: Literal["generic", "text_to_text", "text_to_tensors"] = "generic"
):
    """Return a decorator that registers a Python function with Visual Blocks.

    The decorated function is stored in the registry for ``type`` under its
    ``__name__`` and returned unchanged. A later function with the same name
    replaces an earlier one.

    It may also be applied bare (``@register_vb_fn`` without parentheses), in
    which case the function is registered as ``"generic"``.

    Args:
      type:
        The kind of model-runner block the function implements:
        generic:
          Takes a single argument, the input tensors as an iterable of
          numpy.ndarrays; runs inference; returns the output tensors as a
          list of numpy.ndarrays.
        text_to_text:
          Takes a string and returns a string.
        text_to_tensors:
          Takes a string and returns the output tensors as a list of
          numpy.ndarrays.

    Returns:
      The decorator. When applied bare, the registered function is returned.

    Raises:
      ValueError: if ``type`` is not one of the supported values. Without
        this check an unknown type would silently register nothing.
    """
    if callable(type):
        # Bare usage: `type` is actually the function being decorated.
        return register_vb_fn("generic")(type)

    supported = ("generic", "text_to_text", "text_to_tensors")
    if type not in supported:
        raise ValueError(
            "Unsupported Visual Blocks function type %r; expected one of: %s"
            % (type, ", ".join(supported))
        )

    def decorator_register_vb_fn(func):
        # Resolve the registry at decoration time; the registries are the
        # module-level dicts keyed by function name.
        registry = {
            "generic": GENERIC_FNS,
            "text_to_text": TEXT_TO_TEXT_FNS,
            "text_to_tensors": TEXT_TO_TENSORS_FNS,
        }[type]
        registry[func.__name__] = func
        return func

    return decorator_register_vb_fn
def _json_to_ndarray(json_tensor):
    """Convert a tensor dict from the web app into an np.ndarray.

    Args:
      json_tensor: A dict with "tensorValues" (the tensor's values, presumably
        flattened in row-major order to mirror `_ndarray_to_json` — confirm
        against the web app) and "tensorShape" (the target shape).

    Returns:
      A new np.ndarray holding the values reshaped to "tensorShape".

    Raises:
      KeyError: if either key is missing.
      ValueError: if the number of values does not fit the shape.
    """
    values = np.array(json_tensor["tensorValues"])
    # `reshape` rather than assigning `.shape` in place: NumPy discourages
    # setting the shape attribute and recommends `reshape` instead.
    return values.reshape(json_tensor["tensorShape"])
def _ndarray_to_json(array):
    """Serialize ``array`` into the tensor dict the web app consumes.

    The values are flattened in row-major order (``ravel``) into a plain
    Python list. The shape is kept as the array's shape tuple, which JSON
    encoders emit as a list.
    """
    return dict(
        tensorValues=array.ravel().tolist(),
        tensorShape=array.shape,
    )
def _make_json_response(obj):
    """Build a Flask response whose body is ``obj`` encoded with ``json.dumps``.

    The Content-Type header is set to exactly "application/json".
    """
    response = make_response(json.dumps(obj))
    response.headers["Content-Type"] = "application/json"
    return response
def _ensure_iterable(x):
    """Normalize None, one function, or an iterable of functions to an iterable.

    Returns:
      ``()`` for None; ``(x,)`` when ``x`` is callable or has no
      ``__iter__``; otherwise ``x`` itself.

    Callables are checked before iterability. An object that is both
    callable and iterable is therefore treated as a single function, rather
    than being unpacked into its elements.
    """
    if x is None:
        return ()
    if callable(x) or not hasattr(x, "__iter__"):
        return (x,)
    return x
def _add_to_registry(fns, registry):
    """Store each function from ``fns`` in ``registry``, keyed by ``__name__``.

    Existing entries with the same name are overwritten. Within ``fns``, a
    later function wins over an earlier one with the same name.
    """
    for candidate in fns:
        name = candidate.__name__
        registry[name] = candidate
def _is_list_of_nd_array(obj):
    """Return True if ``obj`` is a list or tuple whose items are all np.ndarrays.

    Used to validate what user inference functions return. Tuples are
    accepted alongside lists because the documented contract only asks for
    an iterable of ndarrays. A tuple, unlike a one-shot iterator, can be
    safely iterated again when the result is serialized. An empty sequence
    passes.
    """
    return isinstance(obj, (list, tuple)) and all(
        isinstance(elem, np.ndarray) for elem in obj
    )
def _prepare_web_app(tmp_dir, timeout=60):
    """Recreate the working dir under ``tmp_dir`` and unpack the web app into it.

    Args:
      tmp_dir: Parent directory; ``{tmp_dir}/visual-blocks-colab`` is deleted
        if present and recreated, along with any missing parent directories.
      timeout: Seconds to wait on the bundle download.

    Returns:
      ``(site_root_path, log_file_path)``: the ``build`` directory the bundle
      unpacks into, and the path of a freshly emptied log file.

    Raises:
      requests.RequestException: if the download fails or times out,
        including non-2xx responses (via ``raise_for_status``).
      zipfile.BadZipFile: if the downloaded file is not a valid zip.
    """
    base_path = os.path.join(tmp_dir, "visual-blocks-colab")
    # Start from an empty directory: anything left by a previous call goes.
    if os.path.exists(base_path):
        shutil.rmtree(base_path)
    os.makedirs(base_path)
    log_file_path = os.path.join(base_path, "log")
    with open(log_file_path, "w"):
        pass

    url = (
        "https://storage.googleapis.com/tfweb/rapsai-colab-bundles/visual_blocks_%s.zip"
        % _VISUAL_BLOCKS_BUNDLE_VERSION
    )
    response = requests.get(url, timeout=timeout)
    # Fail here with an HTTP error instead of saving an error page as the
    # zip and failing later with a confusing BadZipFile.
    response.raise_for_status()
    bundle_target_path = os.path.join(base_path, "visual_blocks.zip")
    with open(bundle_target_path, "wb") as zip_file:
        zip_file.write(response.content)

    # This unzips all files to {base_path}/build.
    with zipfile.ZipFile(bundle_target_path, "r") as zip_ref:
        zip_ref.extractall(base_path)
    return os.path.join(base_path, "build"), log_file_path


def _run_json_handler(compute):
    """Run ``compute()`` and return its result dict as a JSON response.

    Any ``Exception`` raised by ``compute`` is reported to the web app as
    ``{"error": <formatted traceback>}``. The response is returned outside a
    ``finally`` block. Non-``Exception`` errors such as ``KeyboardInterrupt``
    therefore propagate instead of being swallowed by a ``return`` in
    ``finally``.
    """
    try:
        result = compute()
    except Exception as e:
        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        result = {"error": msg}
    return _make_json_response(result)


def Server(
    host="0.0.0.0",
    port=7860,
    generic=None,
    text_to_text=None,
    text_to_tensors=None,
    height=900,
    tmp_dir="/tmp",
    read_saved_pipeline=True,
):
    """Creates a server that serves the Visual Blocks web app.

    Besides serving the web app's static files, the server defines handlers
    for requests sent by the web app. Each handler reads the request body and
    calls the matching function that users registered with VB, either through
    the '@register_vb_fn' decorator or passed in when creating the server.

    Args:
      host:
        Host address passed to Flask's ``app.run``.
      port:
        Port passed to Flask's ``app.run``.
      generic:
        A function or iterable of functions that Visual Blocks can call to
        implement a generic model runner block. A generic inference function
        must take a single argument, the input tensors as an iterable of
        numpy.ndarrays; run inference; and return the output tensors as a
        list of numpy.ndarrays (validated with `_is_list_of_nd_array`).
      text_to_text:
        A function or iterable of functions that Visual Blocks can call to
        implement a text-to-text model runner block. A text_to_text function
        must take a string and return a string.
      text_to_tensors:
        A function or iterable of functions that Visual Blocks can call to
        implement a text-to-tensors model runner block. A text_to_tensors
        function must take a string and return the output tensors as a list
        of numpy.ndarrays.
      height:
        The height of the embedded iFrame. NOTE(review): not read anywhere
        in this function.
      tmp_dir:
        Directory under which the web app's static resources are stored.
        ``{tmp_dir}/visual-blocks-colab`` is wiped and recreated on each call.
      read_saved_pipeline:
        Whether to read the saved pipeline in the notebook or not.
        NOTE(review): not read anywhere in this function.

    Returns:
      An object whose ``run()`` method prints the address and calls
      ``app.run(host=host, port=port)``.

    Raises:
      requests.RequestException: if the web-app bundle cannot be downloaded.
      zipfile.BadZipFile: if the downloaded bundle is not a valid zip.
    """
    _add_to_registry(_ensure_iterable(generic), GENERIC_FNS)
    _add_to_registry(_ensure_iterable(text_to_text), TEXT_TO_TEXT_FNS)
    _add_to_registry(_ensure_iterable(text_to_tensors), TEXT_TO_TENSORS_FNS)

    app = Flask(__name__)

    # Disable startup messages.
    cli = sys.modules["flask.cli"]
    cli.show_server_banner = lambda *x: None

    site_root_path, log_file_path = _prepare_web_app(tmp_dir)

    def log(msg):
        """Append ``msg`` with a timestamp to the log file (not called here)."""
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        with open(log_file_path, "a") as log_file:
            log_file.write("{}: {}\n".format(dt_string, msg))

    # NOTE(review): none of the handlers below are registered on `app`; this
    # function has no `@app.route(...)` or `app.add_url_rule(...)`. As
    # written, the web app cannot reach them. Register each one at the URL
    # the downloaded bundle calls. TODO confirm those paths against the
    # bundle.

    def list_inference_functions():
        """Handler listing registered function names per type, sorted.

        Types with no registered functions are omitted from the response.
        """
        result = {}
        for fn_type, registry in (
            ("generic", GENERIC_FNS),
            ("text_to_text", TEXT_TO_TEXT_FNS),
            ("text_to_tensors", TEXT_TO_TENSORS_FNS),
        ):
            if registry:
                result[fn_type] = sorted(registry)
        return _make_json_response(result)

    # Note: using "/api/..." for POST requests is not allowed.
    def inference_generic():
        """Handler for the generic api endpoint.

        Reads "function" and "tensors" from the JSON body. Responds with
        {"tensors": [...]} or {"error": msg}.
        """

        def compute():
            func_name = request.json["function"]
            inference_fn = GENERIC_FNS[func_name]
            input_tensors = [_json_to_ndarray(x) for x in request.json["tensors"]]
            output_tensors = inference_fn(input_tensors)
            if not _is_list_of_nd_array(output_tensors):
                return {
                    "error": "The returned value from %s is not a list of ndarray"
                    % func_name
                }
            return {"tensors": [_ndarray_to_json(x) for x in output_tensors]}

        return _run_json_handler(compute)

    # Note: using "/api/..." for POST requests is not allowed.
    def inference_text_to_text():
        """Handler for the text_to_text api endpoint.

        Reads "function" and "text" from the JSON body. Responds with
        {"text": ...} or {"error": msg}.
        """

        def compute():
            func_name = request.json["function"]
            inference_fn = TEXT_TO_TEXT_FNS[func_name]
            ret = inference_fn(request.json["text"])
            if not isinstance(ret, str):
                return {
                    "error": "The returned value from %s is not a string" % func_name
                }
            return {"text": ret}

        return _run_json_handler(compute)

    # Note: using "/api/..." for POST requests is not allowed.
    def inference_text_to_tensors():
        """Handler for the text_to_tensors api endpoint.

        Reads "function" and "text" from the JSON body. Responds with
        {"tensors": [...]} or {"error": msg}.
        """

        def compute():
            func_name = request.json["function"]
            inference_fn = TEXT_TO_TENSORS_FNS[func_name]
            output_tensors = inference_fn(request.json["text"])
            if not _is_list_of_nd_array(output_tensors):
                return {
                    "error": "The returned value from %s is not a list of ndarray"
                    % func_name
                }
            return {"tensors": [_ndarray_to_json(x) for x in output_tensors]}

        return _run_json_handler(compute)

    def redirect_to_edit_new():
        """Redirect root URL to /#/edit/new/"""
        return redirect("/#/edit/new/")

    def get_static(path):
        """Handler for serving static resources; "" maps to index.html."""
        if path == "":
            path = "index.html"
        # send_from_directory refuses paths that escape site_root_path.
        return send_from_directory(site_root_path, path)

    class _Server:
        """Thin wrapper exposing a ``run()`` method that starts the Flask app."""

        def run(self):
            print("Visual Blocks server started at http://%s:%s" % (host, port))
            app.run(host=host, port=port)

    return _Server()