# NOTE(review): the original first lines ("Spaces:" / "Runtime error" x2) were
# paste residue from a code-hosting platform's output pane, not Python source.
# Standard library
import os

# Third-party
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

# Project-local
from datasets import HandGestureDataset
from models import ResNet18
# Directory containing the hand-gesture image dataset.
data_dir = 'dataset'

# Preprocessing pipeline: resize to the 224x224 ResNet input size, convert
# PIL images to tensors, and normalize with the standard ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The dataset applies `transform` to every image it yields, so batches coming
# out of the loader are already preprocessed tensors.
dataset = HandGestureDataset(data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Load the pre-trained backbone and swap its classification head for one
# sized to the 7 gesture classes.
# NOTE(review): the new `fc` layer is randomly initialized — predictions are
# meaningless until the head is fine-tuned on the gesture data.
num_classes = 7
model = ResNet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Evaluation mode: disables dropout and uses running batch-norm statistics.
model.eval()
# Run inference over the whole dataset, one batch at a time.
#
# Bug fixes relative to the original loop:
#   * `transform` was re-applied to batches the dataset has already
#     transformed — `ToTensor()` raises TypeError on tensor input and
#     `Normalize` would otherwise run twice, distorting the inputs.
#   * `unsqueeze(0)` added a second batch dimension to an already-batched
#     (B, C, H, W) tensor, producing invalid 5-D input for the model.
#   * `torch.argmax(output)` with no `dim` flattens across the whole batch;
#     per-sample predictions need `dim=1`.
#   * The message said "Predicted note" (copy-paste error) and labeled a
#     whole 32-image batch as a single image.
image_idx = 0
for images, labels in dataloader:
    # No gradients needed for inference.
    with torch.no_grad():
        outputs = model(images)  # (B, num_classes) logits
    predictions = torch.argmax(outputs, dim=1)  # per-sample class index
    for pred in predictions:
        image_idx += 1
        print(f'Image {image_idx}, Predicted gesture: {pred.item()}')