import gradio as gr
import sys
import torch
import requests
import socket
from torchvision import transforms
import os
import asyncio

# Workaround: skip torch.hub's forked-repo validation, whose GitHub API check
# can fail (e.g. when rate-limited) before the model download below.
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Load a pretrained ResNet-18 in eval mode and fetch the 1000 ImageNet class labels.
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")


INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"


if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    # "Any thread" and "selector" should be orthogonal, but there's not a clean
    # interface for composing policies so pick the right base.
    _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
else:
    _BasePolicy = asyncio.DefaultEventLoopPolicy


class AnyThreadEventLoopPolicy(_BasePolicy):  # type: ignore
    """Event loop policy that allows loop creation on any thread.
    The default `asyncio` event loop policy only automatically creates
    event loops in the main threads. Other threads must create event
    loops explicitly or `asyncio.get_event_loop` (and therefore
    `.IOLoop.current`) will fail. Installing this policy allows event
    loops to be created automatically on any thread, matching the
    behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
    Usage::
        asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
    .. versionadded:: 5.0
    """

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        try:
            return super().get_event_loop()
        except (RuntimeError, AssertionError):
            # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
            # and changed to a RuntimeError in 3.4.3.
            # "There is no current event loop in thread %r"
            loop = self.new_event_loop()
            self.set_event_loop(loop)
            return loop


asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
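# Illustration only (not executed here): with the policy installed, worker threads
# can rely on asyncio.get_event_loop() creating a loop on demand, e.g.:
#
#   import threading
#
#   def _worker():
#       loop = asyncio.get_event_loop()  # created by the policy for this thread
#       loop.run_until_complete(asyncio.sleep(0))
#       loop.close()
#
#   threading.Thread(target=_worker).start()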

def get_first_available_port(initial: int, final: int) -> int:
    """
    Gets the first open port in a specified range of port numbers
    Parameters:
    initial: the initial value in the range of port numbers
    final: final (exclusive) value in the range of port numbers, should be greater than `initial`
    Returns:
    port: the first open port in the range
    """
    for port in range(initial, final):
        s = socket.socket()  # create a socket object
        try:
            s.bind((LOCALHOST_NAME, port))  # try to bind to the port
            return port
        except OSError:
            continue
        finally:
            s.close()
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(
            initial, final
        )
    )
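
# Sketch of how this helper pairs with the settings above (not called by default;
# see the commented-out launch alternatives in run() below):
#   port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
#   demo.launch(server_name=LOCALHOST_NAME, server_port=port)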

def predict(inp):
    # Convert the PIL image to a batched tensor and run a single forward pass.
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        # Map each of the 1000 ImageNet classes to its predicted probability.
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    return confidences
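
# Quick local sanity check for predict() (a sketch: assumes Pillow is installed and
# "example.jpg" is any local image; neither is part of this Space):
#   from PIL import Image
#   scores = predict(Image.open("example.jpg").convert("RGB"))
#   print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:3])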


def run():
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=3),
    )

    demo.launch()
    # Alternative launch configurations:
    # demo.launch(server_name="0.0.0.0", server_port=7861)
    # demo.launch(
    #     server_name=LOCALHOST_NAME,
    #     server_port=get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS),
    #     share=True,
    # )


if __name__ == "__main__":
    run()