File size: 6,361 Bytes
ae1d0b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import gc
import os
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
# Preprocessing shared by the training and validation splits (no augmentation):
# resize every image to 224x224, convert it to a tensor, then normalize each
# channel with the mean/std values commonly used for ImageNet-pretrained ResNets.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
# Custom dataset class for loading chess piece images
class ChessPieceDataset(Dataset):
    """Image-folder dataset with one subdirectory of ``root_dir`` per class.

    Each immediate subdirectory of ``root_dir`` is a class: its name goes into
    ``self.classes`` and its position in the sorted list of subdirectory names
    is the integer label. Files whose name ends in one of
    ``IMAGE_EXTENSIONS`` (case-insensitive) are checked with PIL's
    ``verify()`` at construction time; files that fail are reported and
    skipped. Other files, and files directly inside ``root_dir``, are ignored.
    """

    # Lower-case filename suffixes accepted as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images and subdirectories (class labels).
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.image_paths = []
        self.labels = []
        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            # Sort so sample order (and any index-based split such as
            # random_split) does not depend on the filesystem's listdir order.
            for image_name in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, image_name)
                if not img_path.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                try:
                    # Verify the image can be opened and is not truncated/corrupt.
                    with Image.open(img_path) as img:
                        img.verify()
                except Exception as e:
                    print(f"Skipping corrupted image {img_path}: {e}")
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one is set. If loading fails, a black 224x224 placeholder is used
        instead; if the transform fails, the transform is applied to a black
        placeholder. In both cases the sample's real label is returned.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the underlying file handle;
            # convert() returns a new, fully loaded image independent of it.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Substitute a black placeholder image to avoid crashing the loader.
            image = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform is not None:
            try:
                image = self.transform(image)
                # Verify the image size after transformation
                if image.shape != (3, 224, 224):
                    print(
                        f"Unexpected image size after transform for {img_path}: {image.shape}"
                    )
            except Exception as e:
                print(f"Error applying transform to {img_path}: {e}")
                image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))
        return image, label
# Training loop with best-validation-accuracy checkpointing.
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=10,
    device="cpu",
    save_path="best_chess_piece_model.pth",
):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one optimization pass over ``train_loader`` and then an
    evaluation pass over ``val_loader`` under ``torch.no_grad()``, printing the
    mean batch loss and the train/validation accuracy. The model's
    ``state_dict`` is saved to ``save_path`` after the first epoch and again
    whenever validation accuracy strictly improves. A checkpoint therefore
    exists after any run with ``num_epochs >= 1``, even if accuracy stays at 0%.

    Args:
        model: ``nn.Module`` to train; already on ``device``.
        train_loader: iterable of ``(inputs, labels)`` batches for training.
        val_loader: iterable of ``(inputs, labels)`` batches for validation.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs (int): number of passes over ``train_loader``.
        device: device (str or ``torch.device``) batches are moved to.
        save_path (str): file the best ``state_dict`` is written to.

    Returns:
        float | None: best validation accuracy in percent, or ``None`` when
        ``num_epochs`` is 0 (nothing was trained or saved).
    """
    best_accuracy = None
    for epoch in range(num_epochs):
        # Training pass.
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # Validation pass.
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        # Guard every ratio: an empty loader would otherwise raise ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_train_accuracy = 100 * correct / total if total else 0.0
        epoch_val_accuracy = 100 * val_correct / val_total if val_total else 0.0
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )
        # Always save on the first epoch so a checkpoint exists even when
        # validation accuracy never rises above 0%; afterwards only on improvement.
        if best_accuracy is None or epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), save_path)
    print("Training completed.")
    return best_accuracy
# Path to dataset folder (one subdirectory per chess-piece class)
dataset_path = "train"  # Ensure this path is correct
# Checkpoint file written by train_model (its default save path) and reloaded below.
checkpoint_path = "best_chess_piece_model.pth"
# Seed for the train/validation split so the split is reproducible across runs.
split_seed = 42
# NOTE(review): everything below runs at import time; consider moving it under
# an `if __name__ == "__main__":` guard if this module is ever imported.

# Create dataset
full_dataset = ChessPieceDataset(dataset_path, transform=transform)
# Check if dataset is empty
if len(full_dataset) == 0:
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# Split the dataset 80/20 into training and validation sets.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
if train_size == 0:
    # int(0.8 * 1) == 0: a single image would leave the training split empty.
    raise ValueError(
        f"Dataset has only {len(full_dataset)} image(s); at least 2 are needed "
        "for a non-empty training split."
    )
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet18 model and replace the final layer with one
# output per class found in the dataset.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
)

# After training, load the best model for inference. A checkpoint may be
# missing if no epoch ever saved one; in that case keep the final weights.
if os.path.isfile(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"No checkpoint found at {checkpoint_path}; keeping final-epoch weights.")
model.eval()

# Release the model. The optimizer also references the parameters (and holds
# Adam state tensors), so it must be dropped too before memory can be freed.
# Collect first so the tensors are actually released, then return cached
# GPU blocks to the driver.
del model, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
|