# AK47-M4A4's picture
# v1
# ae1d0b9
import gc
import os
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
# Preprocessing pipeline shared by the training and validation images
transform = transforms.Compose(
    [
        # Force every image to a fixed 224x224 spatial size
        transforms.Resize(size=(224, 224)),
        # Convert the PIL image into a tensor
        transforms.ToTensor(),
        # Per-channel normalization with the constants standard for ResNet
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Custom dataset class for loading chess piece images
class ChessPieceDataset(Dataset):
    """Image dataset in which every sub-directory of ``root_dir`` is one class.

    Class names are the sorted sub-directory names; the integer label of a
    sample is the index of its class in ``self.classes``. Files are indexed
    once at construction time: only files with an image extension that pass
    ``PIL.Image.verify`` are kept, unreadable ones are reported with ``print``
    and skipped.
    """

    # Lower-case file extensions accepted as images when scanning class folders
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images and subdirectories (class labels).
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d
            for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.image_paths = []
        self.labels = []
        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping reproducible between runs.
            for image_name in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, image_name)
                # Only include valid image files
                if not img_path.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                try:
                    # Verify the image can be opened
                    with Image.open(img_path) as img:
                        img.verify()  # Verify image integrity
                except Exception as e:
                    # Best effort: one broken file must not abort the scan
                    print(f"Skipping corrupted image {img_path}: {e}")
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        If the file can no longer be read, or the transform raises, a black
        224x224 RGB image is substituted (and the problem is printed) so that
        one bad sample does not crash the caller. When no transform is set
        the image is returned as a PIL image.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image that stays valid afterwards.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image and label to avoid crashing
            image = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"Error applying transform to {img_path}: {e}")
                image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))
            # The size check lives outside the try block: a transform whose
            # output has no ``shape`` (e.g. a PIL image) must not be mistaken
            # for a failed transform and replaced by the dummy image.
            shape = getattr(image, "shape", None)
            if shape is not None and tuple(shape) != (3, 224, 224):
                print(
                    f"Unexpected image size after transform for {img_path}: {image.shape}"
                )
        return image, label
# Training loop with per-epoch validation and best-checkpoint saving
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=10,
    device="cpu",
    save_path="best_chess_piece_model.pth",
):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one optimization pass over ``train_loader`` followed by a
    gradient-free evaluation pass over ``val_loader``, then prints the mean
    batch loss and both accuracies.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating the parameters of ``model``.
        num_epochs (int): Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.
        save_path (str): File that receives the best ``state_dict``.

    Returns:
        float: Best validation accuracy seen, in percent (0.0 when there was
        nothing to validate on or no epoch was run).
    """
    best_accuracy = 0.0
    # Tracks whether save_path has been written, so a checkpoint exists after
    # the first epoch even if validation accuracy never rises above 0%.
    checkpoint_saved = False
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        # Empty loaders (e.g. a tiny dataset after splitting) report 0 instead
        # of raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_train_accuracy = 100 * correct / total if total else 0.0
        epoch_val_accuracy = 100 * val_correct / val_total if val_total else 0.0
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )
        if not checkpoint_saved or epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
    print("Training completed.")
    return best_accuracy
# Folder holding the training images, one sub-directory per class
dataset_path = "train"  # Ensure this path is correct

# Index every valid image found under dataset_path
full_dataset = ChessPieceDataset(dataset_path, transform=transform)

# Fail fast when no usable image was found
if not len(full_dataset):
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# 80/20 split: training gets the truncated 80%, validation gets the remainder
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    lengths=[train_size, val_size],
)

# Batch both subsets; only the training batches are shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-trained ResNet18 whose final layer is replaced by one output per class
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=len(full_dataset.classes),
)
model = model.to(device)

# Cross-entropy loss optimized with Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

# Run the training loop
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    device=device,
)

# Reload the best checkpoint written during training and switch to eval mode
model.load_state_dict(
    torch.load("best_chess_piece_model.pth", map_location=device)
)
model.eval()

# Drop the model and release cached memory before the script ends
gc.collect()
del model
torch.cuda.empty_cache()
gc.collect()