# bird-species-classifier / evaluate_model.py
# Commit 62a2783 by chriamue: "adds evaluate script" (1.22 kB)
import torch
import urllib.request
from PIL import Image
from datasets import load_dataset
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification
# Classify one example picture with the "chriamue/bird-species-classifier"
# checkpoint and print the predicted class name.

dataset = load_dataset("chriamue/bird-species-dataset")

# Class names come from the "label" feature of the test split; a name's
# position in this list is the integer class id used to index it below.
labels = dataset["test"].features["label"].names

# Name <-> id lookup tables passed to the model; ids are stored as strings.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

preprocessor = EfficientNetImageProcessor.from_pretrained("google/efficientnet-b2")
model = EfficientNetForImageClassification.from_pretrained(
    "chriamue/bird-species-classifier",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Image to classify. To evaluate a dataset sample instead, assign e.g.
# `dataset["validation"][0]["image"]` to `image` and skip the download.
url = 'https://upload.wikimedia.org/wikipedia/commons/a/a9/Common_Blackbird.jpg'
try:
    # With no target filename, urlretrieve() stores the download in a
    # temporary file and returns its path as the first tuple element.
    image_path, _ = urllib.request.urlretrieve(url)
    # Image.open() is lazy; convert() decodes the pixels into a new in-memory
    # image so the file handle can be closed and the temp file removed.
    with Image.open(image_path) as downloaded:
        image = downloaded.convert("RGB")
finally:
    # Remove the temporary file left behind by urlretrieve().
    urllib.request.urlcleanup()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Preprocess into PyTorch tensors and move them to the model's device.
inputs = preprocessor(image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Inference only: no gradients are needed.
with torch.no_grad():
    logits = model(**inputs).logits

# Index of the highest-scoring class, mapped back to its name.
predicted_label = logits.argmax(-1).item()
print(labels[predicted_label])