File size: 6,724 Bytes
443d045 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
"""
Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import os
import socket
import threading
import time
import warnings
from typing import TYPE_CHECKING, Tuple
import requests
import uvicorn
from gradio.routes import App
from gradio.tunneling import Tunnel
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
# Defaults for the local server, each overridable through the environment
# variable named below. With the defaults, a server started without an
# explicit port binds on 127.0.0.1 and tries port 7860 first, then scans
# upward through 7861, 7862, ... 7959 (TRY_NUM_PORTS ports in total).
INITIAL_PORT_VALUE = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.environ.get("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
# Endpoint queried by setup_tunnel() for share-link tunnel details.
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
class Server(uvicorn.Server):
    """A uvicorn server that is started in, and stopped from, a background thread."""

    def install_signal_handlers(self):
        # Python only allows installing signal handlers from the main thread,
        # and this server is run from a background thread (see run_in_thread),
        # so uvicorn's default handler installation is deliberately skipped.
        pass

    def run_in_thread(self):
        """
        Starts the server in a daemon thread and blocks until it reports started.
        Raises:
            RuntimeError: if the server thread exits before `self.started` is
                set (e.g. uvicorn could not bind its socket), instead of
                waiting forever.
        """
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        while not self.started:
            # Re-check `started` after the liveness test so a server that
            # started and then exited immediately is not reported as a failure.
            if not self.thread.is_alive() and not self.started:
                raise RuntimeError("Server failed to start.")
            time.sleep(1e-3)

    def close(self):
        """Signals the server loop to exit and waits for its thread to finish."""
        self.should_exit = True
        self.thread.join()
def get_first_available_port(initial: int, final: int) -> int:
    """
    Gets the first open port in a specified range of port numbers
    Parameters:
        initial: the initial value in the range of port numbers
        final: final (exclusive) value in the range of port numbers, should be greater than `initial`
    Returns:
        port: the first open port in the range
    Raises:
        OSError: if no port in [initial, final) can be bound on LOCALHOST_NAME
    """
    for port in range(initial, final):
        try:
            # `with` closes the probe socket on every path, including when
            # bind() fails because the port is taken (previously leaked).
            with socket.socket() as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind((LOCALHOST_NAME, port))
        except OSError:
            continue
        return port
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(
            initial, final - 1
        )
    )
def configure_app(app: App, blocks: Blocks) -> App:
    """
    Copies settings from `blocks` onto `app` and returns the same app object.
    Parameters:
        app: the app to configure; its attributes are assigned in place
        blocks: the Blocks whose `auth` and `favicon_path` are read
    Returns:
        app: the same object that was passed in
    """
    credentials = blocks.auth
    if credentials is None:
        app.auth = None
    elif callable(credentials):
        # A callable authenticator is stored unchanged.
        app.auth = credentials
    else:
        # Otherwise each entry is indexed as (username, password); a later
        # duplicate username overwrites an earlier one.
        app.auth = dict((pair[0], pair[1]) for pair in credentials)
    app.blocks = blocks
    app.cwd = os.getcwd()
    app.favicon_path = blocks.favicon_path
    app.tokens = {}
    return app
def start_server(
    blocks: Blocks,
    server_name: str | None = None,
    server_port: int | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
) -> Tuple[str, int, str, App, Server]:
    """Launches a local server running the provided Blocks
    Parameters:
        blocks: The Blocks object to run on the server
        server_name: to make app accessible on local network, set this to "0.0.0.0". Defaults to LOCALHOST_NAME, which can be set by environment variable GRADIO_SERVER_NAME.
        server_port: will start gradio app on this port (if available). If None, the first free port starting at INITIAL_PORT_VALUE (environment variable GRADIO_SERVER_PORT) is used.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
    Returns:
        server_name: the host name the server is bound to
        port: the port number the server is running on
        path_to_local_server: the complete address that the local server can be accessed at
        app: the FastAPI app object
        server: the server object that is a subclass of uvicorn.Server (used to close the server)
    Raises:
        OSError: if `server_port` is in use, or if no port in the default search range is free
        ValueError: if `ssl_keyfile` is provided without `ssl_certfile`
    """
    server_name = server_name or LOCALHOST_NAME
    # if port is not specified, search for first available port
    if server_port is None:
        port = get_first_available_port(
            INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
        )
    else:
        try:
            # `with` closes the probe socket even when bind() fails
            # (previously the socket leaked on that path).
            with socket.socket() as s:
                s.bind((LOCALHOST_NAME, server_port))
        except OSError as err:
            raise OSError(
                "Port {} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all().".format(
                    server_port
                )
            ) from err
        port = server_port

    # "0.0.0.0" is a bind-all address, not a browsable host, so the URL uses localhost.
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name

    if ssl_keyfile is not None:
        if ssl_certfile is None:
            raise ValueError(
                "ssl_certfile must be provided if ssl_keyfile is provided."
            )
        path_to_local_server = "https://{}:{}/".format(url_host_name, port)
    else:
        path_to_local_server = "http://{}:{}/".format(url_host_name, port)

    app = App.create_app(blocks)

    if blocks.save_to is not None:  # Used for selenium tests
        blocks.save_to["port"] = port
    config = uvicorn.Config(
        app=app,
        port=port,
        host=server_name,
        log_level="warning",
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        ssl_keyfile_password=ssl_keyfile_password,
        ws_max_size=1024 * 1024 * 1024,  # Setting max websocket size to be 1 GB
    )
    server = Server(config=config)
    server.run_in_thread()
    return server_name, port, path_to_local_server, app, server
def setup_tunnel(local_host: str, local_port: int) -> str:
    """
    Requests tunnel details from GRADIO_API_SERVER and starts a tunnel to the local server.
    Parameters:
        local_host: host of the local server to expose
        local_port: port of the local server to expose
    Returns:
        address: the value returned by `Tunnel.start_tunnel()`
    Raises:
        RuntimeError: if the API server does not answer with HTTP 200, or if
            reading its response or starting the tunnel fails (the original
            exception is attached as `__cause__`)
    """
    response = requests.get(GRADIO_API_SERVER)
    # A 200 status already implies `response.ok`, so no separate truthiness check is needed.
    if response.status_code != 200:
        raise RuntimeError("Could not get share link from Gradio API Server.")
    try:
        payload = response.json()[0]
        remote_host, remote_port = payload["host"], int(payload["port"])
        tunnel = Tunnel(remote_host, remote_port, local_host, local_port)
        return tunnel.start_tunnel()
    except Exception as e:
        # Normalize any failure to RuntimeError while keeping the traceback.
        raise RuntimeError(str(e)) from e
def url_ok(url: str) -> bool:
    """
    Checks whether `url` answers an HTTP HEAD request, trying up to 5 times.
    Parameters:
        url: the URL to probe
    Returns:
        True as soon as a response has status 200, 302 or 401; False if 5
        attempts never do, or if a request fails to connect or times out.
    """
    try:
        for _ in range(5):
            with warnings.catch_warnings():
                # Presumably silences urllib3's insecure-request warning caused
                # by verify=False; all warnings are ignored within this block.
                warnings.filterwarnings("ignore")
                r = requests.head(url, timeout=3, verify=False)
            if r.status_code in (200, 401, 302):  # 401 or 302 if auth is set
                return True
            time.sleep(0.500)
    except (
        ConnectionError,
        requests.exceptions.ConnectionError,
        # ReadTimeout is not a ConnectionError subclass; without this a slow
        # server would raise out of url_ok instead of yielding False.
        requests.exceptions.Timeout,
    ):
        return False
    return False
|