# AK47-M4A4's picture
# v1
# ae1d0b9
import gc
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a ResNet-18 whose classifier head has 13 outputs (12 piece classes plus
# 'empty') and load the fine-tuned checkpoint into it.
try:
    # weights=None: the checkpoint loaded below replaces every parameter and
    # buffer (load_state_dict is strict by default), so fetching the pretrained
    # weights first would only add a needless download that fails offline.
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 13)  # 13 classes including 'empty'
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (consider weights_only=True where the installed PyTorch
    # version supports it).
    state_dict = torch.load("best_chess_piece_model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    # `exit` is a helper injected by the `site` module and is not guaranteed to
    # exist (e.g. under `python -S`); raising SystemExit is the portable form.
    raise SystemExit(1) from e
# Class names indexed by the model's predicted class index.
_BLACK_PIECES = [
    "black_bishop",
    "black_king",
    "black_knight",
    "black_pawn",
    "black_queen",
    "black_rook",
]
_WHITE_PIECES = [
    "white_bishop",
    "white_king",
    "white_knight",
    "white_pawn",
    "white_queen",
    "white_rook",
]
# "empty" sits between the black and the white pieces, at index 6.
piece_labels = [*_BLACK_PIECES, "empty", *_WHITE_PIECES]

# Board squares as (row, column) pairs in row-major order: (0, 0) is the
# top-left square (a8) and (7, 7) is the bottom-right square (h1).
coordinates = [(row, col) for row in range(8) for col in range(8)]
# Preprocessing applied to every square crop before it reaches the model:
# resize to the network input size, convert to a tensor, then normalize each
# channel (these are the mean/std values torchvision documents for its
# ImageNet-pretrained models).
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Classify a single board-square crop with the piece-recognition model
def predict_piece(image, model, device):
    """Return the predicted piece label for one square crop.

    ``image`` is an OpenCV-style array: 2-D or single-channel input is treated
    as grayscale, anything else as BGR. The crop is converted to RGB, passed
    through the module-level ``transform`` and classified by ``model`` on
    ``device``.

    Returns the matching entry of ``piece_labels``, or ``"unknown"`` when any
    step raises (the error is printed, not propagated).
    """
    try:
        is_grayscale = len(image.shape) == 2 or image.shape[2] == 1
        color_code = cv2.COLOR_GRAY2RGB if is_grayscale else cv2.COLOR_BGR2RGB
        rgb_image = Image.fromarray(cv2.cvtColor(image, color_code))
        # Add a batch dimension: the model is called on a batch of one image.
        batch = transform(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch)
            best_class = torch.max(logits, 1)[1]
        return piece_labels[best_class.item()]
    except Exception as e:
        print(f"Error predicting piece: {e}")
        return "unknown"
# Helper: collapse the many Hough segments that lie on one physical line
def _merge_line_segments(segments, merge_tol):
    """Merge near-duplicate axis-aligned segments into one line per cluster.

    Each segment is ``(position, start, end)``. Segments whose position lies
    within ``merge_tol`` pixels of the first segment of the current cluster are
    treated as the same line. Returns ``(position, start, end)`` tuples of
    Python ints sorted by position, where ``position`` is the rounded cluster
    mean and ``start``/``end`` span every merged segment.
    """
    clusters = []  # each entry: [first_pos, pos_sum, count, lo, hi]
    for pos, start, end in sorted((int(p), int(a), int(b)) for p, a, b in segments):
        lo, hi = min(start, end), max(start, end)
        if clusters and pos - clusters[-1][0] <= merge_tol:
            cluster = clusters[-1]
            cluster[1] += pos
            cluster[2] += 1
            cluster[3] = min(cluster[3], lo)
            cluster[4] = max(cluster[4], hi)
        else:
            clusters.append([pos, pos, 1, lo, hi])
    return [
        (round(pos_sum / count), lo, hi)
        for _first, pos_sum, count, lo, hi in clusters
    ]


# Function to detect chessboard grid using edge detection and Hough lines
def detect_chessboard_grid(image, *, merge_tol=10):
    """Estimate the 8x8 square grid of a chessboard from straight edges.

    The BGR ``image`` is converted to grayscale, contrast-boosted, blurred and
    passed through Canny; line segments are then found with the probabilistic
    Hough transform and split into horizontal and vertical groups. Segments
    within ``merge_tol`` pixels of each other are merged into one line, and
    the nine lines with the smallest coordinate in each direction are taken as
    the board's grid lines.

    Returns ``(corners, edges)``. ``corners`` is a float32 array of shape
    ``(8, 8, 2)`` in which ``corners[r, c]`` is the ``(x, y)`` top-left corner
    of the square in row ``r`` and column ``c``; it is ``None`` when no lines,
    or fewer than nine distinct lines in either direction, are found.
    ``edges`` is the Canny edge map and is always returned. On success the
    selected lines are drawn onto a copy of the image and written to
    ``lines_debug.png``.

    NOTE(review): when more than nine distinct lines are found, the nine with
    the smallest coordinate are kept, so frame or image-border lines may be
    mistaken for grid lines.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Enhance contrast
    gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
    # Apply Gaussian blur to reduce noise
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    # Edge detection with Canny
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
    # Detect lines using Hough Transform
    lines = cv2.HoughLinesP(
        edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=10
    )
    if lines is None:
        print("No lines detected.")
        return None, edges

    # Separate horizontal and vertical segments by their dominant direction.
    h_segments = []
    v_segments = []
    for x1, y1, x2, y2 in lines.reshape(-1, 4):
        if abs(x2 - x1) > abs(y2 - y1):  # Horizontal line
            h_segments.append((y1, x1, x2))
        else:  # Vertical line
            v_segments.append((x1, y1, y2))

    # Hough usually reports several segments per physical line, so merge
    # duplicates before keeping nine lines per direction.
    h_lines = _merge_line_segments(h_segments, merge_tol)[:9]
    v_lines = _merge_line_segments(v_segments, merge_tol)[:9]
    if len(h_lines) < 9 or len(v_lines) < 9:
        print(
            f"Insufficient lines detected: {len(h_lines)} horizontal, {len(v_lines)} vertical"
        )
        return None, edges

    # Nine lines bound eight squares: the first eight lines in each direction
    # give the top-left corner of every square (8 x 8 = 64 corners).
    row_ys = [y for y, _, _ in h_lines[:8]]
    col_xs = [x for x, _, _ in v_lines[:8]]
    corners = np.array(
        [[(x, y) for x in col_xs] for y in row_ys], dtype=np.float32
    )

    # Visualize detected lines for debugging
    debug_image = image.copy()
    for y, x1, x2 in h_lines:
        cv2.line(debug_image, (x1, y), (x2, y), (0, 255, 0), 2)
    for x, y1, y2 in v_lines:
        cv2.line(debug_image, (x, y1), (x, y2), (0, 0, 255), 2)
    cv2.imwrite("lines_debug.png", debug_image)
    return corners, edges
# Function to extract coordinates of chess pieces from an image
def extract_chessboard_coordinates(image_path):
    """Locate the board in the image at ``image_path`` and classify each square.

    The grid is located with ``cv2.findChessboardCorners`` first; if that
    fails, ``detect_chessboard_grid`` is used instead. Each grid corner is
    treated as the top-left corner of a square crop whose side is the smaller
    of the mean horizontal and mean vertical corner spacing, and every crop is
    classified with ``predict_piece`` using the module-level ``model`` and
    ``device``.

    Returns a list of ``((row, col), label)`` tuples for all 64 squares in
    row-major order; ``label`` is ``"unknown"`` for squares that could not be
    cropped or processed. Returns ``[]`` when the image cannot be loaded or no
    grid is found.

    Side effects: writes ``grid_debug.png`` with the detected corners, and
    ``edges_debug.png`` when the edge-based fallback finds no grid.
    """
    try:
        image = cv2.imread(image_path)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return []
    if image is None:
        print(f"Failed to load image: {image_path}")
        return []

    # Try OpenCV's chessboard detection first
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # NOTE(review): OpenCV documents patternSize as the number of *inner*
    # corners per row and column, so (8, 8) describes a 9x9-square pattern; a
    # standard 8x8 board has 7x7 inner corners. Confirm this is intended.
    ret, corners = cv2.findChessboardCorners(gray, (8, 8), None)
    if ret:
        corners = cv2.cornerSubPix(
            gray,
            corners,
            (11, 11),
            (-1, -1),
            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1),
        )
        corners = corners.reshape(8, 8, 2)
    else:
        print("OpenCV chessboard detection failed. Attempting edge-based detection.")
        corners, edges = detect_chessboard_grid(image)
        if corners is None:
            # Save edges for debugging
            cv2.imwrite("edges_debug.png", edges)
            print("Saved edge detection output to edges_debug.png")
            return []

    # Save debug image with detected corners
    debug_image = image.copy()
    for corner_x, corner_y in corners.reshape(-1, 2):
        cv2.circle(debug_image, (int(corner_x), int(corner_y)), 5, (0, 255, 0), -1)
    cv2.imwrite("grid_debug.png", debug_image)
    print("Saved grid detection debug image to grid_debug.png")

    # Calculate square size dynamically: mean distance between horizontally
    # (axis 1) and vertically (axis 0) adjacent corners.
    square_width = np.linalg.norm(np.diff(corners, axis=1), axis=2).mean()
    square_height = np.linalg.norm(np.diff(corners, axis=0), axis=2).mean()
    square_size = int(min(square_width, square_height))

    # Iterate the 8x8 squares locally rather than through the module-level
    # `coordinates` name, which the example usage rebinds to the result list.
    piece_coordinates = []
    for row in range(8):
        for col in range(8):
            try:
                left = max(0, int(corners[row, col, 0]))
                top = max(0, int(corners[row, col, 1]))
                # Clamp the crop to the image bounds.
                right = min(image.shape[1], left + square_size)
                bottom = min(image.shape[0], top + square_size)
                roi = image[top:bottom, left:right]
                if roi.size == 0:
                    print(f"Invalid ROI at square ({row}, {col})")
                    piece_coordinates.append(((row, col), "unknown"))
                    continue
                predicted_piece = predict_piece(roi, model, device)
                piece_coordinates.append(((row, col), predicted_piece))
            except Exception as e:
                # Best effort: one bad square must not abort the whole board.
                print(f"Error processing square ({row}, {col}): {e}")
                piece_coordinates.append(((row, col), "unknown"))
    return piece_coordinates
# Example usage
IMAGE_PATH = "test.png"
# Keep the results under their own name: assigning them to `coordinates` would
# overwrite the module-level list of board squares.
piece_positions = extract_chessboard_coordinates(IMAGE_PATH)
for coord, piece in piece_positions:
    print(f"Piece at {coord}: {piece}")

# Release the model. Garbage is collected before the CUDA cache is emptied so
# that memory freed by the collector can be returned in the same pass.
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()