File size: 1,905 Bytes
832aa0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
from torchvision.transforms.functional import InterpolationMode

from holocron import models


def main(args: argparse.Namespace) -> None:
    """Serve a Gradio demo that classifies images with a pretrained Holocron model.

    Builds the model and its preprocessing pipeline once, wraps them in a
    prediction callback and launches a Gradio interface around it.

    Args:
        args: parsed command-line arguments. Only ``args.port`` is read; it is
            forwarded to Gradio as the server port.
    """

    # Inference only: pretrained weights, eval mode
    model = models.rexnet1_3x(pretrained=True).eval()
    # All preprocessing parameters and class names come from the model's own config
    cfg = model.default_cfg

    preprocessor = Compose([
        # presumably input_shape is (C, H, W), so [1:] is the spatial size — TODO confirm
        Resize(cfg['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
        ConvertImageDtype(torch.float32),
        Normalize(cfg['mean'], cfg['std'])
    ])

    # NOTE(review): Gradio may use the callback's parameter names as labels for the
    # input components, so the name `input` is kept even though it shadows the
    # builtin — confirm against the installed Gradio version before renaming.
    def predict(input):
        """Return a ``{class_name: probability}`` mapping for one image array."""
        pil_img = Image.fromarray(input.astype('uint8'), 'RGB')
        tensor = preprocessor(pil_img)
        with torch.inference_mode():
            # Add a batch dimension for the forward pass, then drop it again
            probs = torch.nn.functional.softmax(model(tensor.unsqueeze(0))[0], dim=0)
        # One bulk conversion to Python floats instead of one float() call per class
        return dict(zip(cfg['classes'], probs.tolist()))

    image = gr.inputs.Image()
    outputs = gr.outputs.Label(num_top_classes=3)

    interface = gr.Interface(
        fn=predict,
        inputs=[image],
        outputs=outputs,
        title="Holocron: image classification demo",
        article=("<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>" "Github Repo</a> | "
                 "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"),
        live=True,
        theme="huggingface",
        layout="horizontal",
    )

    interface.launch(server_port=args.port, show_error=True)


if __name__ == '__main__':
    # Script entry point: read the server port from the command line and start the demo
    cli = argparse.ArgumentParser(
        description='Holocron image classification demo',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port on which the webserver will be run",
    )

    main(cli.parse_args())