File size: 6,361 Bytes
ae1d0b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gc
import os

import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms

# Per-channel mean / std conventionally used with torchvision's pretrained
# ResNet weights (the ImageNet statistics).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by the training and validation splits (no augmentation):
# force every image to 224x224, convert it to a tensor, then normalise it with
# the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


# Custom dataset class for loading chess piece images
class ChessPieceDataset(Dataset):
    """Image-folder dataset for chess-piece classification.

    ``root_dir`` is expected to contain one sub-directory per class. The sorted
    sub-directory names form ``self.classes`` and each name's index is its
    integer label. Only files with a recognised image extension are considered.
    Files that PIL cannot open/verify are skipped (with a printed message) at
    construction time.
    """

    # Lower-case file extensions treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory whose sub-directories are the class labels
                and hold that class's images.
            transform (callable, optional): Transform applied to each RGB PIL
                image in ``__getitem__``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            # Sort so sample order (and hence any seeded split) does not depend
            # on the filesystem's directory-listing order.
            for image_name in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, image_name)
                if not img_path.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                try:
                    # Verify image integrity up front. PIL can raise several
                    # exception types here, so this best-effort check catches
                    # broadly and skips the file.
                    with Image.open(img_path) as img:
                        img.verify()
                except Exception as e:
                    print(f"Skipping corrupted image {img_path}: {e}")
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If the file cannot be loaded, or the transform fails, a black 224x224
        RGB image is substituted (and a message printed) so one bad file does
        not abort an epoch. The original label is always returned.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the underlying file handle. convert()
            # returns a new, loaded image that remains valid after the `with`.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Dummy image so the DataLoader does not crash.
            image = Image.new("RGB", (224, 224), (0, 0, 0))

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"Error applying transform to {img_path}: {e}")
                image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))
            # Sanity-check the size only when the transform produced something
            # with a shape (e.g. a tensor). Do not discard other valid outputs.
            shape = getattr(image, "shape", None)
            if shape is not None and shape != (3, 224, 224):
                print(
                    f"Unexpected image size after transform for {img_path}: {shape}"
                )

        return image, label


# Define training function
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=10,
    device="cpu",
    save_path="best_chess_piece_model.pth",
):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` and then an
    evaluation pass over ``val_loader`` under ``torch.no_grad()``. It then
    prints the mean training loss and the train/validation accuracies. Whenever
    validation accuracy strictly improves, ``model.state_dict()`` is written to
    ``save_path``. Because the running best starts at ``-inf``, the first
    epoch always writes one.

    Args:
        model (nn.Module): Classifier. Presumably returns logits of shape
            ``(batch, num_classes)``; predictions are the argmax over dim 1.
        train_loader: Iterable of ``(inputs, labels)`` batches used for training.
        val_loader: Iterable of ``(inputs, labels)`` batches used for validation.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs (int): Number of training epochs.
        device (str | torch.device): Device the batches are moved to.
        save_path (str): File the best ``state_dict`` is saved to.

    Returns:
        float: Best validation accuracy in percent (``-inf`` if ``num_epochs``
        is 0). Empty loaders report 0% accuracy instead of raising.
    """
    # Start below any achievable accuracy. With 0.0 and a strict ">", a model
    # scoring 0% every epoch would never be checkpointed, and a later load of
    # save_path would fail.
    best_accuracy = float("-inf")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Weight by batch size so a short final batch doesn't skew the
            # epoch mean. Assumes a mean-reduced criterion (CrossEntropyLoss's
            # default) — TODO confirm if other criteria are used.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        # Guard the divisions so an empty loader does not raise ZeroDivisionError.
        epoch_loss = running_loss / total if total else 0.0
        epoch_train_accuracy = 100 * correct / total if total else 0.0
        epoch_val_accuracy = 100 * val_correct / val_total if val_total else 0.0

        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )

        if epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), save_path)

    print("Training completed.")
    return best_accuracy


# Path to dataset folder: one sub-directory per class, holding that class's images.
dataset_path = "train"  # Ensure this path is correct

# File the best weights are checkpointed to; must match the path train_model
# saves to.
checkpoint_path = "best_chess_piece_model.pth"

# Fixed seed for the train/validation split so the same images form the
# validation set on every run and validation accuracies are comparable.
split_seed = 42

# Create dataset
full_dataset = ChessPieceDataset(dataset_path, transform=transform)

# Check if dataset is empty
if len(full_dataset) == 0:
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# 80/20 split; val_size takes the remainder so the two sizes always sum to
# len(full_dataset).
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet18 with its final layer replaced by one output per class.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
)

# Reload the best checkpoint. train_model only writes it when validation
# accuracy improves, so guard against the file being absent instead of crashing.
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"No checkpoint found at {checkpoint_path}; keeping final-epoch weights.")
model.eval()

# Release (GPU) memory. The optimizer also references the model's parameters
# and their optimizer state, so both names must be dropped. Collect garbage
# *before* empty_cache() so memory freed by the collector can be returned
# rather than staying cached.
del model, optimizer
gc.collect()
torch.cuda.empty_cache()