|
import gradio as gr |
|
from huggingface_hub import hf_hub_url |
|
import requests |
|
import torch |
|
from torchvision.io import read_image |
|
from torchvision import transforms as T |
|
import torchvision.models as models |
|
|
|
|
|
|
|
# Build the pretrained ResNet-34 once at import time and switch it to
# evaluation mode right away (``Module.eval()`` returns the module itself,
# so the call can be chained onto the constructor).
resnet34 = models.resnet34(pretrained=True).eval()
|
|
|
|
|
|
|
# Fetch the class-name list from the Hugging Face Hub and keep a local copy in
# ``labels.txt``. ``labels[i]`` is later used as the name for model output
# index ``i``.
url = hf_hub_url(repo_id="Selma/pytorch-resnet34", filename="imagenet_classes.txt")

# The timeout keeps start-up from hanging forever on a stalled connection, and
# raise_for_status() prevents an HTTP error page from being saved as labels.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Context managers guarantee the file is flushed and closed before it is
# reopened for reading below.
with open("labels.txt", "wb") as f:
    f.write(response.content)

# NOTE(review): assumes the label file is ASCII/UTF-8 text, one class per line.
with open("labels.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
|
|
|
|
|
def classify(image_raw):
    """
    Classify an image with the module-level ``resnet34`` model.

    The image is resized (shorter side to 256), center-cropped to 224x224,
    converted to a tensor and normalized with the mean/std constants below,
    then passed through the network as a batch of one.

    :param image_raw: The image to be classified/labeled
    :type image_raw: PIL image

    :returns: a sentence naming the highest-scoring label together with its
        softmax score rounded to a whole percentage, or a short notice when
        no image was supplied
    :rtype: str
    """
    # The UI can submit an empty input; answer with text instead of crashing.
    if image_raw is None:
        return "No image provided."

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize is given three-channel statistics, so force RGB to keep
    # grayscale or RGBA images from failing on a channel-count mismatch.
    image_transformed = transform(image_raw.convert("RGB"))

    # The model expects a leading batch dimension.
    batch_image_transformed = torch.unsqueeze(image_transformed, 0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = resnet34(batch_image_transformed)

    # Softmax over the class dimension, scaled to percentages; take the first
    # (and only) item of the batch.
    percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100
    score, index = torch.max(percentage, 0)

    return (
        f"The image depicts: {labels[index.item()]}"
        f" with a score of {round(score.item())}%"
    )
|
|
|
# The legacy ``gr.inputs`` namespace was deprecated in Gradio 3 and removed in
# Gradio 4. Prefer the top-level ``gr.Image`` component and fall back to the
# legacy namespace only on releases that do not provide it.
if hasattr(gr, "Image"):
    image_input = gr.Image(type="pil")
else:
    image_input = gr.inputs.Image(type="pil")

# Wire the classifier to a simple image-in / text-out web UI and start serving.
iface = gr.Interface(fn=classify, inputs=image_input, outputs="text")
iface.launch()