import gc
import os

import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms

# Preprocessing applied to every image (the same pipeline is used for the
# training and the validation split).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Ensure all images are 224x224
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # Standard for ResNet
    ]
)


class ChessPieceDataset(Dataset):
    """Image-folder dataset: one sub-directory of ``root_dir`` per class.

    Class indices follow the alphabetical order of the sub-directory names.
    Files that do not carry a known image extension, or that fail Pillow's
    integrity check, are skipped while the index is built.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images and subdirectories
                (class labels).
            transform (callable, optional): Optional transform to be applied
                on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            for image_name in os.listdir(class_folder):
                img_path = os.path.join(class_folder, image_name)
                # Only include valid image files
                if not img_path.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                ):
                    continue
                try:
                    # Verify the image can be opened
                    with Image.open(img_path) as img:
                        img.verify()  # Verify image integrity
                except Exception as e:
                    # Deliberately broad: one unreadable file must not abort
                    # indexing of the whole dataset.
                    print(f"Skipping corrupted image {img_path}: {e}")
                else:
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        If the file cannot be read, or the transform raises, a black
        224x224 RGB image is substituted (and the error printed) so that a
        single bad sample does not stop iteration. The label is always the
        one recorded for ``idx``.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the underlying file; convert()
            # returns an independent in-memory copy that outlives it.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image and label to avoid crashing
            image = Image.new("RGB", (224, 224), (0, 0, 0))

        if self.transform:
            try:
                image = self.transform(image)
                # Verify the image size after transformation
                if image.shape != (3, 224, 224):
                    print(
                        f"Unexpected image size after transform for {img_path}: {image.shape}"
                    )
            except Exception as e:
                print(f"Error applying transform to {img_path}: {e}")
                image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))

        return image, label


def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device="cpu"
):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free evaluation pass over ``val_loader``, then prints the mean
    batch loss and both accuracies. Whenever the validation accuracy beats
    the best seen so far, ``model.state_dict()`` is written to
    ``best_chess_piece_model.pth`` in the current working directory.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the parameters of ``model``.
        num_epochs (int): Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when validation accuracy is 0%. Starting at 0.0 with a
    # strict ">" could finish training without ever saving a file.
    best_accuracy = float("-inf")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError;
        # the corresponding metric is then reported as 0.
        epoch_loss = running_loss / max(len(train_loader), 1)
        epoch_train_accuracy = 100 * correct / max(total, 1)
        epoch_val_accuracy = 100 * val_correct / max(val_total, 1)

        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )

        if epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), "best_chess_piece_model.pth")

    print("Training completed.")


# Path to dataset folder
dataset_path = "train"  # Ensure this path is correct

# Create dataset
full_dataset = ChessPieceDataset(dataset_path, transform=transform)

# Check if dataset is empty
if len(full_dataset) == 0:
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# Split the dataset into training and validation sets (80% / 20%)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
if train_size == 0:
    # A single-image dataset rounds down to an empty training split, which
    # would otherwise fail later inside the shuffling DataLoader.
    raise ValueError(
        f"Dataset has only {len(full_dataset)} image(s); "
        "at least 2 are needed for a train/validation split."
    )
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet18 model and replace the final layer so that it
# has one output per class found in the dataset
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
)

# After training, load the best model for inference
model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
model.eval()

# Cleanup: drop the reference first, then collect, then release cached CUDA
# blocks. NOTE(review): `optimizer` still references the parameters, so the
# memory may not actually be freed until it is dropped as well.
del model
gc.collect()
torch.cuda.empty_cache()