|
import gc |
|
import os |
|
|
|
import torch |
|
from PIL import Image |
|
from torch import nn, optim |
|
from torch.utils.data import DataLoader, Dataset, random_split |
|
from torchvision import models, transforms |
|
|
|
|
|
# Preprocessing applied to every image before it reaches the network:
# resize to 224x224, convert to a float tensor, then normalise each RGB channel.
# The mean/std values are the ImageNet channel statistics that torchvision's
# pretrained ResNet weights expect.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]

transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
|
class ChessPieceDataset(Dataset):
    """Image-classification dataset read from a folder-per-class directory tree.

    Every immediate sub-directory of ``root_dir`` is treated as one class; class
    indices follow the sorted order of the directory names (``self.classes``).
    Files with a known image extension are verified with Pillow at construction
    time and unreadable ones are skipped with a printed message.

    Attributes:
        root_dir: The directory passed to the constructor.
        transform: Optional callable applied to each loaded RGB image.
        classes: Sorted list of class (sub-directory) names.
        image_paths: Paths of all images that passed verification.
        labels: Class index for each entry of ``image_paths``.
    """

    # File extensions (lower-case) that are considered images.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")
    # Size of the black image substituted when a file cannot be loaded.
    _PLACEHOLDER_SIZE = (224, 224)
    # Tensor shape the diagnostic check in __getitem__ expects after transform.
    _EXPECTED_SHAPE = (3, 224, 224)

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images and subdirectories (class labels).
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical from run to run.
            for image_name in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, image_name)
                if not img_path.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                try:
                    with Image.open(img_path) as img:
                        img.verify()
                except Exception as e:
                    # Best effort: one unreadable file must not abort the scan.
                    print(f"Skipping corrupted image {img_path}: {e}")
                else:
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was given. If loading fails, or the transform raises, a black
        placeholder image is used instead so a single bad file does not stop
        iteration; the problem is reported with ``print``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the underlying file as soon as the
            # converted copy exists (matters for multi-frame files such as GIF).
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            image = self._placeholder()

        if self.transform is None:
            return image, label

        try:
            transformed = self.transform(image)
        except Exception as e:
            print(f"Error applying transform to {img_path}: {e}")
            transformed = self.transform(self._placeholder())
        else:
            # Diagnostic only. It lives outside the try body so that a transform
            # returning an object without ``.shape`` is not mistaken for a
            # failed transform and replaced by the placeholder.
            shape = getattr(transformed, "shape", None)
            if shape is not None and tuple(shape) != self._EXPECTED_SHAPE:
                print(
                    f"Unexpected image size after transform for {img_path}: {shape}"
                )

        return transformed, label

    def _placeholder(self):
        """Return a new all-black RGB image used in place of an unreadable one."""
        return Image.new("RGB", self._PLACEHOLDER_SIZE, (0, 0, 0))
|
|
|
|
|
|
|
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=10,
    device="cpu",
    save_path="best_chess_piece_model.pth",
):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one pass over ``train_loader`` (with gradient updates) and
    one pass over ``val_loader`` (without gradients), prints a summary line and
    saves ``model.state_dict()`` to ``save_path`` whenever the validation
    accuracy improves on the best seen so far.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of passes over the training data.
        device: Device the batches are moved to; the model is expected to
            already live there.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        float: Best validation accuracy in percent (``0.0`` if no epoch ran).
    """
    # Start below any reachable accuracy so the first epoch always produces a
    # checkpoint, even when validation accuracy is 0%.
    best_accuracy = -1.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        # Guard every ratio: an empty loader must report 0 rather than raise
        # ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_train_accuracy = 100 * correct / total if total else 0.0
        epoch_val_accuracy = 100 * val_correct / val_total if val_total else 0.0

        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )

        if epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), save_path)

    print("Training completed.")
    # Hide the internal -1.0 sentinel from callers when no epoch was run.
    return max(best_accuracy, 0.0)
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Training script: build the dataset, fine-tune a pretrained ResNet-18 on it,
# reload the best checkpoint and release the model's memory.
# ---------------------------------------------------------------------------

# Root folder that holds one sub-directory per class.
dataset_path = "train"

full_dataset = ChessPieceDataset(dataset_path, transform=transform)

if len(full_dataset) == 0:
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# 80/20 train/validation split.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# With a single image int(0.8 * 1) == 0, which leaves the training split empty
# and only fails later with an opaque sampler error; fail early instead.
if train_size == 0:
    raise ValueError(
        "Dataset is too small to split into train and validation sets: "
        f"found {len(full_dataset)} valid image(s), need at least 2."
    )

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-18 whose final fully connected layer is replaced so that it
# outputs one score per dataset class.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
)

# train_model only writes a checkpoint when validation accuracy improves, so
# the file may legitimately be missing; keep the final-epoch weights then.
checkpoint_path = "best_chess_piece_model.pth"
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"No checkpoint found at {checkpoint_path}; keeping final-epoch weights.")
model.eval()

# Drop the reference first, then collect, then return cached CUDA blocks:
# collecting before the name is deleted cannot reclaim the model.
del model
gc.collect()
torch.cuda.empty_cache()
|
|