import gradio as gr
from huggingface_hub import hf_hub_url
import requests
import torch
from torchvision.io import read_image
from torchvision import transforms as T
import torchvision.models as models

# Load ResNet-34 with ImageNet-pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=models.ResNet34_Weights.DEFAULT`; kept here for
# compatibility with whatever torchvision version this app was written against.
resnet34 = models.resnet34(pretrained=True)
# Evaluation mode: disables dropout and makes batch-norm use running statistics.
resnet34.eval()

## labeling
# Download the file containing the 1,000 labels for the ImageNet dataset classes.
url = hf_hub_url(repo_id="Selma/pytorch-resnet34", filename="imagenet_classes.txt")
response = requests.get(url, timeout=30)
# Fail loudly rather than silently caching an HTTP error page as the label file.
response.raise_for_status()

# Cache the labels locally; the context manager guarantees the handle is closed.
with open("labels.txt", "wb") as label_file:
    label_file.write(response.content)

# Extract the labels from the file: one label per line, surrounding whitespace stripped.
with open("labels.txt", "r", encoding="utf-8") as label_file:
    labels = [line.strip() for line in label_file]

## preprocessing
# Resize / centre-crop to 224x224 and normalise with the ImageNet channel
# statistics. Built once at import instead of on every request.
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify(image_raw):
    """Return the most likely ImageNet label for an image, with its ResNet-34 score.

    :param image_raw: The image to be classified/labeled.
    :type image_raw: PIL image
    :returns: A sentence of the form
        "The image depicts: <label> with a score of <N>%", where <N> is the
        softmax probability of the top class as a whole percentage, or
        "No image provided." when no image was submitted.
    :rtype: str
    """
    # Gradio may call the function with None when the user submits no image.
    if image_raw is None:
        return "No image provided."

    # T.Normalize expects exactly 3 channels; RGBA, greyscale and palette
    # uploads would otherwise raise, so force RGB first.
    image_transformed = transform(image_raw.convert("RGB"))
    # Add a batch dimension of size 1: the model expects a batch of images.
    batch_image_transformed = torch.unsqueeze(image_transformed, 0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = resnet34(batch_image_transformed)

    ## predict the class
    # Softmax normalises the scores of the single image in the batch to [0, 1];
    # multiplying by 100 turns them into percentages. The arg-max of the
    # softmax is the same class as the arg-max of the raw scores.
    percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100
    score, index = torch.max(percentage, dim=0)

    return (
        "The image depicts: "
        + labels[index.item()]
        + " with a score of "
        + str(round(score.item()))
        + "%"
    )


# NOTE(review): `gr.inputs.Image` is the legacy Gradio input API (removed in
# Gradio 4, where `gr.Image(type="pil")` replaces it); kept for compatibility
# with the Gradio version this app targets.
iface = gr.Interface(fn=classify, inputs=gr.inputs.Image(type="pil"), outputs="text")
iface.launch()