import os
import socket

import gradio as gr
import requests
import torch
from torchvision import transforms

# Skip torch.hub's forked-repo validation; a common workaround for hub.load
# failing the GitHub repo check (e.g. due to API rate limits).
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Pretrained ResNet-18 (ImageNet) in eval mode.
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()

# Fetch the 1000 ImageNet class labels, one per line.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
# Server settings, overridable via environment variables.
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"  # unused in this script


def get_first_available_port(initial: int, final: int) -> int:
    """
    Gets the first open port in a specified range of port numbers.

    Parameters:
        initial: the initial value in the range of port numbers
        final: final (exclusive) value in the range of port numbers; should be greater than `initial`
    Returns:
        port: the first open port in the range
    """
    for port in range(initial, final):
        try:
            # If we can bind to the port, it is free; the socket closes on exit.
            with socket.socket() as s:
                s.bind((LOCALHOST_NAME, port))
            return port
        except OSError:
            pass
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(
            initial, final
        )
    )
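
# A quick usage sketch of the port scan above (values come from the env-var
# defaults defined earlier); commented out so it does not run at import time:
#   port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
#   print("first free port:", port)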


def predict(inp):
    """Classify a PIL image and return a {label: confidence} mapping."""
    # To a [1, C, H, W] tensor; no resizing/normalization is applied.
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    # ResNet-18 outputs scores for the 1000 ImageNet classes.
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    return confidences
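
# Usage sketch (hypothetical image path): run a single prediction outside Gradio
# and print the most confident label; commented out to keep import side effects out.
#   from PIL import Image
#   scores = predict(Image.open("example.jpg"))  # "example.jpg" is a placeholder
#   print(max(scores, key=scores.get))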


def run():
    demo = gr.Interface(
        fn=predict,
        # gr.Image / gr.Label are the current top-level components
        # (the older gr.inputs / gr.outputs namespaces are deprecated).
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=3),
    )
    # Launch on the first free port in the configured range.
    demo.launch(
        server_name=LOCALHOST_NAME,
        server_port=get_first_available_port(
            INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
        ),
    )
    # demo.launch(server_name="0.0.0.0", server_port=7861)


if __name__ == "__main__":
    run()
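
# To run locally (sketch): the host/port range can be tuned with the environment
# variables read at the top of this file, e.g.
#   GRADIO_SERVER_PORT=7860 GRADIO_NUM_PORTS=100 python main.py
# then open http://127.0.0.1:<chosen port> in a browser.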