Spaces:
Build error
Build error
| import gradio | |
| import torch | |
| import torchvision.transforms as T | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import os | |
| import random | |
| import PIL.Image as Image | |
| import time | |
| from model import create_fasterrcnn_model | |
# Label table for the detector. Each entry's "id" equals its position in the
# list, and predict() looks names up with categories[label]["name"].
# Entry 0 is the parent "creatures" category; every other entry lists
# "creatures" as its supercategory.
_CREATURE_NAMES = (
    "fish",
    "jellyfish",
    "penguin",
    "puffin",
    "shark",
    "starfish",
    "stingray",
)
categories = [{"id": 0, "name": "creatures", "supercategory": "none"}] + [
    {"id": idx, "name": creature, "supercategory": "creatures"}
    for idx, creature in enumerate(_CREATURE_NAMES, start=1)
]
# 1. Create title and description strings
title = "Ocean creatures detection Faster-R-CNN"
description = "A Faster-RCNN-ResNet-50 backbone feature extractor computer vision model to classify images of fish, penguin, shark, etc"

# Size the model head from the label table, so the number of classes and the
# names predict() looks up via categories[label] always agree (8 today).
faster_rcnn = create_fasterrcnn_model(
    num_classes=len(categories),
)
# Load the saved weights, mapping every tensor to the CPU.
faster_rcnn.load_state_dict(
    torch.load(
        f="./third_train.pth",
        map_location=torch.device("cpu"),
    )
)
# The app only runs inference, so switch to eval mode once at start-up.
faster_rcnn.eval()
# Create predict function
def predict(img):
    """Run the Faster R-CNN detector on ``img`` and render its detections.

    Args:
        img: The image delivered by the ``gradio.Image(type="numpy")`` input.
            It is presumably an H x W x C uint8 array (TODO confirm for
            grayscale/RGBA uploads).

    Returns:
        A ``(img_path, pred_time)`` tuple:

        - ``img_path`` is the path of a uniquely named PNG in the current
          directory. It shows ``img`` with every detection scoring above 0.5
          outlined in red and labelled "<name>: <score>".
        - ``pred_time`` is the elapsed wall-clock seconds for inference,
          drawing and saving, rounded to 5 decimal places and returned as a
          string.
    """
    import tempfile  # local import: the file's import header is outside this block

    start_time = time.time()
    device = "cpu"

    # ToPILImage -> ToTensor converts the numpy image into a float CHW tensor.
    transform = T.Compose([T.ToPILImage(), T.ToTensor()])
    image_tensor = transform(img).to(device).unsqueeze(0)  # add a batch dimension

    faster_rcnn.eval()
    with torch.no_grad():
        predictions = faster_rcnn(image_tensor)

    # A single image was sent, so only the first result dict is relevant.
    result = predictions[0]
    pred_boxes = result["boxes"].cpu().numpy()
    pred_scores = result["scores"].cpu().numpy()
    pred_labels = result["labels"].cpu().numpy()
    label_names = [categories[label]["name"] for label in pred_labels]

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(img)
        for box, score, label_name in zip(pred_boxes, pred_scores, label_names):
            if score > 0.5:
                x1, y1, x2, y2 = box
                w, h = x2 - x1, y2 - y1
                rect = plt.Rectangle((x1, y1), w, h, fill=False, edgecolor="red", linewidth=2)
                ax.add_patch(rect)
                ax.text(
                    x1,
                    y1,
                    f"{label_name}: {score:.2f}",
                    fontsize=5,
                    color="white",
                    bbox=dict(facecolor="red", alpha=0.2),
                )

        # mkstemp creates a file with a guaranteed-unique name. The previous
        # randint(0, 99) name let concurrent or repeated requests overwrite
        # each other's output.
        fd, img_path = tempfile.mkstemp(suffix=".png", dir=".")
        os.close(fd)  # savefig reopens the path itself
        fig.savefig(img_path)
    finally:
        # Release the pyplot figure; otherwise every request leaks one.
        plt.close(fig)

    pred_time = round(time.time() - start_time, 5)
    return img_path, str(pred_time)
### 4. Gradio app ###
# Collect example images from the "examples/" folder:
# - the extension filter skips stray files (e.g. .DS_Store) that the image
#   input cannot load;
# - sorting keeps the example order stable between runs.
_EXAMPLE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff")
example_list = [
    [os.path.join("examples", name)]
    for name in sorted(os.listdir("examples"))
    if name.lower().endswith(_EXAMPLE_SUFFIXES)
]

# Create the Gradio demo.
# gradio.outputs.* was deprecated in Gradio 3 and removed in Gradio 4 (possibly
# the cause of the Space's build error). The top-level components replace it,
# and gradio.Label takes no `type` argument.
demo = gradio.Interface(
    fn=predict,  # maps the input image to (annotated image path, time string)
    inputs=gradio.Image(type="numpy"),
    outputs=[
        gradio.Image(type="filepath", label="Image with Bounding Boxes"),
        gradio.Label(label="Prediction Time"),
    ],
    examples=example_list or None,  # omit the examples section if none were found
    title=title,
    description=description,
)

# Launch the demo only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)