Holocron / app.py
fg-mindee
feat: Added first draft of demo app
832aa0d
raw history blame
No virus
1.91 kB
import argparse
import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
from torchvision.transforms.functional import InterpolationMode
from holocron import models
def main(args):
    """Build and launch the Holocron image classification demo.

    Loads a pretrained ``rexnet1_3x`` model in eval mode, assembles the
    preprocessing pipeline from the model's ``default_cfg`` and serves a
    Gradio interface on the requested port.

    Args:
        args: parsed command-line arguments; only ``args.port`` is read.
    """
    classifier = models.rexnet1_3x(pretrained=True).eval()
    # Resize target, normalization statistics and class names all come from here
    cfg = classifier.default_cfg

    transform = Compose([
        Resize(cfg['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
        ConvertImageDtype(torch.float32),
        Normalize(cfg['mean'], cfg['std']),
    ])

    def predict(input):
        """Classify one image and return a confidence per class name.

        Args:
            input: array received from the image component; it is cast to
                uint8 and read as RGB (presumably H x W x 3 -- confirm
                against the Gradio component configuration).

        Returns:
            dict mapping every entry of ``default_cfg['classes']`` to its
            softmax score as a Python float.
        """
        # NOTE: the parameter name is part of the callback signature handed
        # to Gradio, so it is kept even though it shadows the builtin.
        pil_img = Image.fromarray(input.astype('uint8'), 'RGB')
        batch = transform(pil_img).unsqueeze(0)
        with torch.inference_mode():
            logits = classifier(batch)[0]
            scores = torch.nn.functional.softmax(logits, dim=0)
        return dict(zip(cfg['classes'], map(float, scores)))

    # Footer links displayed with the interface
    footer_html = (
        "<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>"
        "Github Repo</a> | "
        "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"
    )

    image = gr.inputs.Image()
    outputs = gr.outputs.Label(num_top_classes=3)

    demo = gr.Interface(
        fn=predict,
        inputs=[image],
        outputs=outputs,
        title="Holocron: image classification demo",
        article=footer_html,
        live=True,
        theme="huggingface",
        layout="horizontal",
    )
    demo.launch(server_port=args.port, show_error=True)
if __name__ == '__main__':
    # Command-line entry point: parse the options, then start the demo.
    arg_parser = argparse.ArgumentParser(
        description='Holocron image classification demo',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port on which the webserver will be run",
    )
    args = arg_parser.parse_args()
    main(args)