import os import cv2 import torch from urllib.request import urlretrieve def read_image(file): """Reads the image file Returns the numpy array. Args: file : path to the image Returns: (numpy.ndarray): image read as numpy array """ image = cv2.imread(file) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image def accuracy(predictions, ground_truth): """Funtion to calculate accuracy of the model. """ _, preds = torch.max(predictions, dim=1) score = (preds == ground_truth).float().mean() return score.item() def download_weights(url): cd = os.getcwd() fname = url.split('/')[-1] fname = os.path.join(cd, fname) if not os.path.exists(fname): urlretrieve(url, fname) return fname