File size: 2,212 Bytes
cb955fa
15d6ba8
 
7c3ad8f
15d6ba8
7c3ad8f
 
15d6ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb955fa
 
2cc79c8
9aaffac
 
 
 
 
 
 
 
 
2cc79c8
15d6ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c3ad8f
8cc9b12
cb955fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from huggingface_hub import hf_hub_url
import requests
import torch
from torchvision.io import read_image
from torchvision import transforms as T
import torchvision.models as models


# Load ResNet-34 with ImageNet-pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=models.ResNet34_Weights.IMAGENET1K_V1`; kept for compatibility with older torchvision.
resnet34 = models.resnet34(pretrained=True)
# Evaluation mode: batch-norm layers use their running statistics instead of batch statistics.
resnet34.eval()

## labeling
# Download the file containing the 1,000 labels for the ImageNet dataset classes (one per line).
url = hf_hub_url(repo_id="Selma/pytorch-resnet34", filename="imagenet_classes.txt")
# A timeout prevents the app from hanging forever on an unresponsive server.
response = requests.get(url, timeout=30)
# Fail loudly on an HTTP error rather than saving an error page as the label list.
response.raise_for_status()
# Cache the labels locally; `with` guarantees the file handle is flushed and closed.
with open("labels.txt", "wb") as f:
    f.write(response.content)
# Extract one label per line, stripping the trailing newline/whitespace.
with open("labels.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]


def classify(image_raw):
    """
    Classify an image with ResNet-34 and describe its most likely ImageNet label.

    :param image_raw: The image to be classified/labeled
    :type image_raw: PIL image

    :returns: a sentence of the form "The image depicts: <label> with a score of <N>%",
        where <N> is the softmax probability of the top class, rounded to a whole percent
    :rtype: str
    """

    ## preprocessing
    # Standard ImageNet preprocessing: resize the shorter side to 256, take the central
    # 224x224 crop, convert to a float tensor and normalise with the ImageNet mean/std.
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Force three channels: Normalize uses a 3-element mean/std, so RGBA or greyscale
    # uploads would otherwise fail. (Previously this line referenced an undefined `image`.)
    image_transformed = transform(image_raw.convert("RGB"))
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch_image_transformed = torch.unsqueeze(image_transformed, 0)

    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        output = resnet34(batch_image_transformed)

    ## predict the class
    # Index of the maximum score for the single image in the batch.
    _, index = torch.max(output, 1)
    top = index[0].item()

    # Softmax normalises the scores to [0, 1]; multiply by 100 for a percentage.
    percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100

    return f"The image depicts: {labels[top]} with a score of {round(percentage[top].item())}%"

# Build the web UI: a PIL image input routed through `classify`, with a text output.
# `gr.inputs.Image` is deprecated in Gradio 3.x and removed in 4.x; prefer the top-level
# `gr.Image` component when it exists, falling back to `gr.inputs.Image` on older releases.
_image_component = gr.Image if hasattr(gr, "Image") else gr.inputs.Image
iface = gr.Interface(fn=classify, inputs=_image_component(type="pil"), outputs="text")
# Start the Gradio server (at import time, as the script always did).
iface.launch()