"""Gradio demo that runs an IceVision detector to find film slates / clappers.

The model, its model type, class map and image size are restored from the
checkpoint at ``checkpoint_path``. Each uploaded image is resized/padded to
that image size and normalized, passed through ``end2end_detect``, and the
image with predicted bounding boxes drawn on it is shown in the UI.
"""
from gradio.outputs import Label
from icevision import tfms
from icevision.models.checkpoint import model_from_checkpoint
import PIL
# ``import PIL`` alone does not load the ``PIL.Image`` submodule; import it
# explicitly instead of relying on another library having imported it first.
import PIL.Image
import gradio as gr
import os

# Minimum score a prediction needs to be drawn on the output image.
detection_threshold = 0.7
checkpoint_path = "2022-10-03_resnet_slates_147.pth"
description = (
    "A faster-rcnn model that detects film slates / clappers. "
    "Upload an image of a slate or click an example below!"
)

# Load model: everything needed for inference is read from the checkpoint.
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model = checkpoint_and_model["model"]
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]

# Transforms: resize/pad to the checkpoint's image size, then normalize.
img_size = checkpoint_and_model["img_size"]
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

# Populate examples in Gradio interface (one input image per example row).
examples = [
    ["1.jpg"],
    ["2.jpg"],
    ["3.jpg"],
]


def show_preds(input_image):
    """Run the slate detector on one uploaded image.

    Args:
        input_image: Value from Gradio's ``"image"`` input — presumably an
            HxWxC uint8 numpy array (Gradio's default image type; confirm for
            the installed Gradio version), or ``None`` when the user submits
            without providing an image.

    Returns:
        The ``"img"`` entry returned by ``model_type.end2end_detect`` with
        ``return_img=True`` (the input with detected boxes drawn on it), or
        ``None`` when no image was provided.
    """
    if input_image is None:
        # Nothing uploaded: leave the output empty instead of raising inside
        # PIL.Image.fromarray.
        return None

    # Let Pillow infer the mode from the array, then force 3-channel RGB so
    # grayscale / RGBA arrays are converted rather than misread as RGB.
    img = PIL.Image.fromarray(input_image).convert("RGB")
    pred_dict = model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        display_label=False,
        display_bbox=True,
        return_img=True,
        font_size=16,
        label_color="#02D800",
    )
    return pred_dict["img"]


gr_interface = gr.Interface(
    fn=show_preds,
    inputs=["image"],
    outputs=[gr.outputs.Image(type="pil", label="Inference")],
    title="Film Slate Detector",
    description=description,
    examples=examples,
)
gr_interface.launch(inline=False, share=False, debug=True)