File size: 800 Bytes
85edc93
edc8afb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85edc93
edc8afb
85edc93
 
 
edc8afb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import cv2
import torch
from urllib.request import urlretrieve

def read_image(file):
    """Read an image file from disk and return it in RGB channel order.

    OpenCV decodes images in BGR channel order; this helper converts the
    result to RGB so it can be handed directly to code that expects RGB.

    Args:
        file (str or os.PathLike): path to the image.

    Returns:
        numpy.ndarray: the decoded image in RGB order. With ``cv2.imread``'s
        default flag this is an 8-bit, 3-channel (height, width, 3) array.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or not
            a supported image format).
    """
    # Accept pathlib.Path and other path-like objects as well as plain strings.
    path = os.fspath(file)
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the failure only shows up later as an opaque assertion error
    # inside cv2.cvtColor.
    if image is None:
        raise OSError(f"could not read image file: {path!r}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
def accuracy(predictions, ground_truth):
    """Compute the classification accuracy of a batch of predictions.

    The predicted class of each sample is the index of the largest score
    along dimension 1 of ``predictions``. Accuracy is the fraction of samples
    whose predicted class equals the corresponding label.

    Args:
        predictions (torch.Tensor): class scores with classes on dimension 1,
            e.g. logits of shape (N, C).
        ground_truth (torch.Tensor or sequence of int): integer class labels,
            one per sample (e.g. shape (N,) or (N, 1)).

    Returns:
        float: accuracy in the range [0, 1].

    Raises:
        RuntimeError: if the number of labels does not match the number of
            predicted classes.
    """
    preds = torch.argmax(predictions, dim=1)
    # Align labels with preds. as_tensor also accepts plain lists and places
    # the labels on the same device as the predictions. reshape turns (N, 1)
    # labels into (N,): without it, (N,) == (N, 1) silently broadcasts to an
    # (N, N) comparison and produces a meaningless score.
    labels = torch.as_tensor(ground_truth, device=preds.device)
    labels = labels.reshape(preds.shape)
    return (preds == labels).float().mean().item()

def download_weights(url, dest_dir=None):
    """Download a file from ``url`` unless a copy is already present locally.

    The local file name is the last component of the URL's path (any query
    string or fragment is ignored). The download goes to a temporary file in
    the destination directory and is renamed into place only after it
    completes, so an interrupted download never leaves a truncated file that
    later calls would mistake for a finished one.

    Args:
        url (str): URL of the file to fetch (any scheme ``urlretrieve``
            supports).
        dest_dir (str or os.PathLike, optional): directory to store the file
            in. Defaults to the current working directory.

    Returns:
        str: absolute path of the local file.

    Raises:
        ValueError: if no file name can be derived from ``url`` (e.g. it ends
            with ``/``).
    """
    import tempfile
    from urllib.parse import urlsplit

    directory = os.path.abspath(os.getcwd() if dest_dir is None else os.fspath(dest_dir))
    name = os.path.basename(urlsplit(url).path)
    if not name:
        # An empty name would make the "target" the directory itself, which
        # always exists, so the function would silently return a directory.
        raise ValueError(f"cannot derive a file name from URL: {url!r}")
    fname = os.path.join(directory, name)
    if os.path.exists(fname):
        return fname

    # Temp file in the same directory so os.replace is an atomic rename.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=name + ".", suffix=".part")
    os.close(fd)
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, fname)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return fname