import argparse

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Normalize,
    PILToTensor,
    Resize,
)
from torchvision.transforms.functional import InterpolationMode

from holocron import models


def main(args: argparse.Namespace) -> None:
    """Build the classification demo and serve it with Gradio.

    Loads a pretrained ``rexnet1_3x`` model in eval mode, builds the matching
    preprocessing pipeline from the model's ``default_cfg`` and launches a
    Gradio interface that shows the top 3 predicted classes for an image.

    Args:
        args: parsed command-line arguments; only ``args.port`` is read and
            used as the web server port.
    """
    model = models.rexnet1_3x(pretrained=True).eval()
    cfg = model.default_cfg

    # Resize to the spatial part of the configured input shape, then convert
    # the PIL image to a float32 tensor normalized with the model statistics.
    preprocessor = Compose([
        Resize(cfg['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
        ConvertImageDtype(torch.float32),
        Normalize(cfg['mean'], cfg['std']),
    ])

    def predict(img_array):
        """Return a ``{class_name: confidence}`` mapping for one image.

        Args:
            img_array: the image passed in by Gradio; it is cast to ``uint8``
                and interpreted as RGB data.
        """
        pil_img = Image.fromarray(img_array.astype('uint8'), 'RGB')
        tensor = preprocessor(pil_img)

        with torch.inference_mode():
            # Add a batch dimension, keep the single output row, and turn the
            # scores into probabilities over the class dimension.
            logits = model(tensor.unsqueeze(0))[0]
            probs = torch.nn.functional.softmax(logits, dim=0)

        # ``tolist`` yields plain Python floats in one call.
        return dict(zip(cfg['classes'], probs.tolist()))

    image = gr.inputs.Image()
    outputs = gr.outputs.Label(num_top_classes=3)

    interface = gr.Interface(
        fn=predict,
        inputs=[image],
        outputs=outputs,
        title="Holocron: image classification demo",
        # Line breaks are written as escapes: a raw newline inside a
        # single-quoted string literal is a syntax error.
        # NOTE(review): the text carries no link markup or URLs; if HTML
        # anchors were intended here, restore them — confirm against upstream.
        article=(
            "\n\n"
            "Github Repo | "
            "Documentation\n\n"
        ),
        live=True,
        theme="huggingface",
        layout="horizontal",
    )

    interface.launch(server_port=args.port, show_error=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Holocron image classification demo',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--port", type=int, default=8001, help="Port on which the webserver will be run")
    args = parser.parse_args()

    main(args)