1.projekat / app.py
skatanic9421rn
komit
e07ca76
raw history blame
No virus
1.33 kB
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
from models import ResNet18
from datasets import HandGestureDataset
# Root directory of the hand-gesture dataset (resolved relative to the
# current working directory).
data_dir = 'dataset'

# Preprocessing pipeline handed to the dataset: resize every image to
# 224x224, convert it to a tensor, then normalize each channel with the
# standard ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The dataset is given `transform`, so it is expected to apply it per sample.
# NOTE(review): presumably it yields (image_tensor, label) pairs - confirm
# against datasets.HandGestureDataset.
dataset = HandGestureDataset(data_dir, transform=transform)

# This script only runs inference, so keep the dataset order. With
# shuffle=True the processing order (and therefore any per-image numbering)
# would change on every run and could not be traced back to dataset entries.
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# NOTE(review): presumably `pretrained=True` loads pretrained weights -
# confirm in models.ResNet18.
model = ResNet18(pretrained=True)

# Swap the final fully connected layer for one sized to this task.
# NOTE(review): 7 is presumably the number of gesture classes - confirm it
# matches the dataset.
num_classes = 7
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# NOTE(review): the new `fc` layer is freshly initialised and is neither
# trained nor loaded from a checkpoint in this script, so its predictions are
# not meaningful until trained weights are loaded.

# Switch layers such as batch-norm/dropout to inference behaviour.
model.eval()
# Run the model over the whole dataset and print one prediction per image.
#
# Each batch from `dataloader` is already a stacked tensor of transformed
# images: the dataset applies `transform` and DataLoader's default collate
# stacks the samples. Applying `transform` again (ToTensor rejects tensors)
# or adding another batch dimension would break the forward pass, so the
# batch is fed to the model as-is.
image_number = 0
# Gradients are never needed for inference; disable them for the whole pass.
with torch.no_grad():
    for images, _labels in dataloader:
        outputs = model(images)
        # argmax over the class dimension gives one class index per image;
        # without `dim` it would index into the flattened batch output.
        predictions = torch.argmax(outputs, dim=1)
        for prediction in predictions.tolist():
            image_number += 1
            print(f'Image {image_number}, Predicted note: {prediction}')