Chakshu123's picture
Duplicate from Chakshu123/sketch-colorization-with-hint
443d045
raw
history blame
6.72 kB
"""
Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import os
import socket
import threading
import time
import warnings
from typing import TYPE_CHECKING, Tuple
import requests
import uvicorn
from gradio.routes import App
from gradio.tunneling import Tunnel
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
# Port search window for the local server. Unless overridden through the
# environment, the search starts at port 7860 and covers 100 consecutive
# ports, i.e. 7860, 7861, ... 7959.
INITIAL_PORT_VALUE = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.environ.get("GRADIO_NUM_PORTS", "100"))
# Host name used when the caller does not supply one (and for port probing).
LOCALHOST_NAME = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
# Endpoint queried to obtain the remote host/port for a share tunnel.
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
class Server(uvicorn.Server):
    """
    A uvicorn server that is run on a background daemon thread so that the
    calling thread is not blocked while the server is serving.
    """

    def install_signal_handlers(self):
        """
        Overrides the parent hook with a no-op, so no signal handlers are installed.
        NOTE(review): presumably needed because `run` executes on a non-main
        thread (see `run_in_thread`), where signal handlers cannot be
        registered -- confirm against the uvicorn version in use.
        """
        pass

    def run_in_thread(self):
        """
        Starts `self.run` on a daemon thread and blocks until the server
        reports that it has started.
        Raises:
            RuntimeError: if the server thread exits before the server has started
        """
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        while not self.started:
            if not self.thread.is_alive():
                # The thread may have set `started` and finished between the
                # two checks above, so look at the flag once more before
                # declaring the startup a failure.
                if self.started:
                    break
                # Without this check a failed startup would leave the caller
                # polling `self.started` forever.
                raise RuntimeError(
                    "Server thread exited before the server finished starting."
                )
            time.sleep(1e-3)

    def close(self):
        """
        Asks the server to exit and waits for its thread to finish.
        Must be called after `run_in_thread`, which creates `self.thread`.
        """
        self.should_exit = True
        self.thread.join()
def get_first_available_port(initial: int, final: int) -> int:
    """
    Gets the first open port in a specified range of port numbers
    Parameters:
        initial: the initial value in the range of port numbers
        final: final (exclusive) value in the range of port numbers, should be greater than `initial`
    Returns:
        port: the first open port in the range
    Raises:
        OSError: if none of the ports in the range could be bound
    """
    for port in range(initial, final):
        try:
            # The context manager closes the probe socket on every path,
            # including when bind() fails because the port is taken.
            with socket.socket() as probe:
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind((LOCALHOST_NAME, port))  # Bind to the port
        except OSError:
            # Port is unavailable; try the next one in the range.
            continue
        return port
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(
            initial, final - 1
        )
    )
def configure_app(app: App, blocks: Blocks) -> App:
    """
    Copies the settings of a Blocks object onto the app that will serve it
    Parameters:
        app: the app object to configure (modified in place)
        blocks: the Blocks object whose `auth` and `favicon_path` are read
    Returns:
        app: the same app object that was passed in
    """
    credentials = blocks.auth
    if credentials is None or callable(credentials):
        # No auth, or a custom checker function: stored unchanged.
        app.auth = credentials
    else:
        # A sequence of accounts: index 0 becomes the key, index 1 the value.
        app.auth = {entry[0]: entry[1] for entry in credentials}
    app.blocks = blocks
    app.cwd = os.getcwd()
    app.favicon_path = blocks.favicon_path
    app.tokens = {}
    return app
def start_server(
    blocks: Blocks,
    server_name: str | None = None,
    server_port: int | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
) -> Tuple[str, int, str, App, Server]:
    """Launches a local server running the provided Interface
    Parameters:
        blocks: The Blocks object to run on the server
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
    Returns:
        server_name: the host name the server was started with
        port: the port number the server is running on
        path_to_local_server: the complete address that the local server can be accessed at
        app: the app object created for `blocks`
        server: the server object that is a subclass of uvicorn.Server (used to close the server)
    Raises:
        OSError: if `server_port` is given but cannot be bound, or if no port in the default range is available
        ValueError: if `ssl_keyfile` is provided without `ssl_certfile`
    """
    server_name = server_name or LOCALHOST_NAME
    # if port is not specified, search for first available port
    if server_port is None:
        port = get_first_available_port(
            INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
        )
    else:
        try:
            # The context manager closes the probe socket even when bind()
            # fails, so a busy port does not leak a descriptor.
            # NOTE(review): the probe binds on LOCALHOST_NAME rather than
            # `server_name`; kept as-is to match get_first_available_port.
            with socket.socket() as probe:
                probe.bind((LOCALHOST_NAME, server_port))
        except OSError as err:
            raise OSError(
                "Port {} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all().".format(
                    server_port
                )
            ) from err
        port = server_port

    # "0.0.0.0" is not a browsable address, so the URL shows localhost instead.
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name

    if ssl_keyfile is not None:
        if ssl_certfile is None:
            raise ValueError(
                "ssl_certfile must be provided if ssl_keyfile is provided."
            )
        path_to_local_server = "https://{}:{}/".format(url_host_name, port)
    else:
        path_to_local_server = "http://{}:{}/".format(url_host_name, port)

    app = App.create_app(blocks)

    if blocks.save_to is not None:  # Used for selenium tests
        blocks.save_to["port"] = port

    config = uvicorn.Config(
        app=app,
        port=port,
        host=server_name,
        log_level="warning",
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        ssl_keyfile_password=ssl_keyfile_password,
        ws_max_size=1024 * 1024 * 1024,  # Setting max websocket size to be 1 GB
    )
    server = Server(config=config)
    # Blocks until the server reports that it has started.
    server.run_in_thread()
    return server_name, port, path_to_local_server, app, server
def setup_tunnel(local_host: str, local_port: int) -> str:
    """
    Requests a tunnel endpoint from the Gradio API server and opens a tunnel
    from it to the local server
    Parameters:
        local_host: host name of the local server to expose
        local_port: port of the local server to expose
    Returns:
        address: the address returned by `Tunnel.start_tunnel()`
    Raises:
        RuntimeError: if the API server does not answer with status 200, or if parsing its response or starting the tunnel fails
    Exceptions raised by the HTTP request itself (connection failure, timeout)
    are not wrapped and propagate to the caller unchanged.
    """
    # A timeout keeps an unresponsive API server from blocking the caller forever.
    response = requests.get(GRADIO_API_SERVER, timeout=30)
    if not (response and response.status_code == 200):
        raise RuntimeError("Could not get share link from Gradio API Server.")
    try:
        # The response body is a JSON list; its first entry carries the
        # remote endpoint under the "host" and "port" keys.
        payload = response.json()[0]
        remote_host, remote_port = payload["host"], int(payload["port"])
        tunnel = Tunnel(remote_host, remote_port, local_host, local_port)
        address = tunnel.start_tunnel()
        return address
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(str(e)) from e
def url_ok(url: str) -> bool:
    """
    Checks whether a URL answers a HEAD request with an acceptable status code
    Parameters:
        url: the URL to probe
    Returns:
        True if one of up to 5 attempts returns status 200, 401 or 302;
        False if no attempt does, or if the connection fails or times out
    """
    attempts = 5
    try:
        for attempt in range(attempts):
            # Certificate verification is disabled for the probe, and any
            # warnings emitted while making the request are silenced.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                r = requests.head(url, timeout=3, verify=False)
            if r.status_code in (200, 401, 302):  # 401 or 302 if auth is set
                return True
            # Pause between attempts only; sleeping after the last one would
            # just delay the negative answer.
            if attempt < attempts - 1:
                time.sleep(0.500)
    except (
        ConnectionError,
        requests.exceptions.ConnectionError,
        # A read timeout is not a ConnectionError; treat it as "not ok"
        # instead of letting it escape from a boolean check.
        requests.exceptions.Timeout,
    ):
        return False
    return False