|
import gc |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torchvision import models, transforms |
|
|
|
|
|
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
# Build a ResNet-18 with a 13-output classification head and load the
# fine-tuned chess-piece weights from disk. Any failure is reported and the
# script terminates with exit status 1.
try:
    # weights=None: load_state_dict() below (strict by default) overwrites
    # every parameter and buffer, so fetching the pretrained weights first
    # would only add a needless download.
    model = models.resnet18(weights=None)
    # Replace the final fully connected layer with one that has 13 outputs.
    model.fc = torch.nn.Linear(model.fc.in_features, 13)
    # NOTE(review): torch.load() unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
    model.eval()
    model = model.to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    # `exit` is a helper injected by the `site` module for interactive use;
    # raising SystemExit is the equivalent that works in every interpreter
    # configuration.
    raise SystemExit(1) from e
|
|
|
|
|
# Class names for the 13 classifier outputs, listed in alphabetical order.
# The list index is the class index used to look up a prediction.
piece_labels = [
    "black_bishop", "black_king", "black_knight",
    "black_pawn", "black_queen", "black_rook",
    "empty",
    "white_bishop", "white_king", "white_knight",
    "white_pawn", "white_queen", "white_rook",
]
|
|
|
|
|
# All 64 (row, col) index pairs of an 8x8 grid in row-major order:
# (0, 0), (0, 1), ..., (7, 7).
coordinates = [divmod(square, 8) for square in range(64)]
|
|
|
|
|
# Preprocessing applied to every square crop before classification:
# resize to 224x224, convert the PIL image to a tensor, then normalize each
# channel with the mean/std constants below (these match the ImageNet
# statistics commonly used with torchvision models).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
|
|
|
|
def predict_piece(image, model, device):
    """Classify a single board-square crop and return its label.

    Args:
        image: Crop exposing ``.shape``. A 2-D or single-channel crop is
            expanded from grayscale to RGB; anything else is converted from
            BGR to RGB.
        model: Classifier that is called on the preprocessed batch of one.
        device: Device the input tensor is moved to before inference.

    Returns:
        The ``piece_labels`` entry at the index of the highest model output,
        or ``"unknown"`` if any step raises (the error is printed).
    """
    try:
        shape = image.shape
        is_grayscale = len(shape) == 2 or shape[2] == 1
        conversion = cv2.COLOR_GRAY2RGB if is_grayscale else cv2.COLOR_BGR2RGB
        rgb = cv2.cvtColor(image, conversion)

        # PIL image -> preprocessing pipeline -> batch of one on the device.
        batch = transform(Image.fromarray(rgb)).unsqueeze(0).to(device)
        with torch.no_grad():
            scores = model(batch)
        best_index = torch.max(scores, 1)[1]
        return piece_labels[best_index.item()]
    except Exception as e:
        # Best effort: a failed square is reported, not fatal.
        print(f"Error predicting piece: {e}")
        return "unknown"
|
|
|
|
|
|
|
def detect_chessboard_grid(image):
    """Estimate an 8x8 grid of square corners from straight edges in an image.

    The image is converted to grayscale, contrast-boosted, blurred and passed
    through Canny edge detection. Segments found by the probabilistic Hough
    transform are classified as horizontal (``|dx| > |dy|``) or vertical
    (everything else) and sorted by the y (respectively x) coordinate of
    their first endpoint. The nine lowest-coordinate segments of each kind
    are taken as the grid lines.

    Args:
        image: BGR image array (it is converted with ``cv2.COLOR_BGR2GRAY``).

    Returns:
        A ``(corners, edges)`` tuple. ``edges`` is always the Canny edge map.
        ``corners`` is a float32 array of shape ``(8, 8, 2)`` in which
        ``corners[i, j]`` is the ``(x, y)`` crossing of horizontal line ``i``
        and vertical line ``j`` (the top-left corner of square ``(i, j)``),
        or ``None`` when no lines or fewer than nine lines per direction
        were found.

    Side effects:
        Prints a diagnostic on failure; on success writes the selected lines
        drawn over the input to ``lines_debug.png``.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Boost contrast/brightness, then blur to suppress noise before Canny.
    gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    lines = cv2.HoughLinesP(
        edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=10
    )
    if lines is None:
        print("No lines detected.")
        return None, edges

    h_lines = []
    v_lines = []
    for line in lines:
        # Plain Python ints keep the later cv2.line() calls independent of
        # how the binding treats NumPy scalar types.
        x1, y1, x2, y2 = (int(value) for value in line[0])
        if abs(x2 - x1) > abs(y2 - y1):
            h_lines.append((y1, x1, x2))
        else:
            v_lines.append((x1, y1, y2))

    # NOTE(review): nothing here merges duplicate segments of the same
    # physical line or checks spacing, so the nine lowest-coordinate segments
    # are not guaranteed to be the nine grid lines - consider clustering.
    h_lines = sorted(h_lines, key=lambda entry: entry[0])[:9]
    v_lines = sorted(v_lines, key=lambda entry: entry[0])[:9]

    if len(h_lines) < 9 or len(v_lines) < 9:
        print(
            f"Insufficient lines detected: {len(h_lines)} horizontal, {len(v_lines)} vertical"
        )
        return None, edges

    # Nine lines per direction bound eight squares. Crossing all nine with
    # all nine yields 81 points, which could never satisfy the 64-corner
    # requirement, so only the first eight lines of each direction are used:
    # their crossings are the top-left corners of the 64 squares.
    corners = np.array(
        [[x, y] for y, _, _ in h_lines[:8] for x, _, _ in v_lines[:8]],
        dtype=np.float32,
    ).reshape(8, 8, 2)

    # Debug overlay: horizontal lines in green, vertical lines in red (BGR).
    debug_image = image.copy()
    for y, x1, x2 in h_lines:
        cv2.line(debug_image, (x1, y), (x2, y), (0, 255, 0), 2)
    for x, y1, y2 in v_lines:
        cv2.line(debug_image, (x, y1), (x, y2), (0, 0, 255), 2)
    cv2.imwrite("lines_debug.png", debug_image)

    return corners, edges
|
|
|
|
|
|
|
def extract_chessboard_coordinates(image_path):
    """Locate the board in an image and classify each of its 64 squares.

    Corner detection first tries ``cv2.findChessboardCorners`` and falls back
    to ``detect_chessboard_grid`` when that fails. Each entry of the resulting
    ``(8, 8, 2)`` corner grid is used as the top-left corner of a square crop
    whose side is the smaller of the mean horizontal and mean vertical corner
    spacing; every crop is classified with ``predict_piece``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A list of ``((row, col), label)`` tuples for all 64 squares in
        row-major order, where ``label`` is ``"unknown"`` for a square that
        could not be cropped or processed. An empty list is returned when the
        image cannot be loaded or no grid is found.

    Side effects:
        Prints diagnostics. When the fallback detector fails,
        ``edges_debug.png`` is written; when it succeeds, ``grid_debug.png``
        is written.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to load image: {image_path}")
            return []
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return []

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # NOTE(review): the pattern size passed to findChessboardCorners counts
    # inner corners; an ordinary 8x8-square board has 7x7 of them, so (8, 8)
    # may never match such a board - confirm the intended pattern size.
    ret, corners = cv2.findChessboardCorners(gray, (8, 8), None)

    if ret:
        corners = cv2.cornerSubPix(
            gray,
            corners,
            (11, 11),
            (-1, -1),
            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1),
        )
        corners = corners.reshape(8, 8, 2)
    else:
        print("OpenCV chessboard detection failed. Attempting edge-based detection.")
        corners, edges = detect_chessboard_grid(image)
        if corners is None:
            # Keep the edge map so the failed detection can be inspected.
            cv2.imwrite("edges_debug.png", edges)
            print("Saved edge detection output to edges_debug.png")
            return []

        # Mark every detected grid corner on a copy of the input.
        debug_image = image.copy()
        for row in range(8):
            for col in range(8):
                x, y = int(corners[row, col, 0]), int(corners[row, col, 1])
                cv2.circle(debug_image, (x, y), 5, (0, 255, 0), -1)
        cv2.imwrite("grid_debug.png", debug_image)
        print("Saved grid detection debug image to grid_debug.png")

    # Side length of a square: the smaller of the mean spacing between
    # horizontally adjacent and vertically adjacent corners.
    square_width = np.mean(
        [
            np.linalg.norm(corners[i, j] - corners[i, j + 1])
            for i in range(8)
            for j in range(7)
        ]
    )
    square_height = np.mean(
        [
            np.linalg.norm(corners[i, j] - corners[i + 1, j])
            for i in range(7)
            for j in range(8)
        ]
    )
    square_size = int(min(square_width, square_height))

    image_height, image_width = image.shape[:2]
    piece_coordinates = []

    # The 8x8 indices are generated here instead of being read from a
    # module-level name, so the function keeps working even if that global
    # is rebound by the calling script.
    for i in range(8):
        for j in range(8):
            try:
                # Clamp the crop rectangle to the image bounds.
                x = max(0, int(corners[i, j, 0]))
                y = max(0, int(corners[i, j, 1]))
                x_end = min(image_width, x + square_size)
                y_end = min(image_height, y + square_size)
                roi = image[y:y_end, x:x_end]

                if roi.shape[0] == 0 or roi.shape[1] == 0:
                    print(f"Invalid ROI at square ({i}, {j})")
                    piece_coordinates.append(((i, j), "unknown"))
                    continue

                predicted_piece = predict_piece(roi, model, device)
                piece_coordinates.append(((i, j), predicted_piece))
            except Exception as e:
                # Best effort: one bad square must not abort the whole board.
                print(f"Error processing square ({i}, {j}): {e}")
                piece_coordinates.append(((i, j), "unknown"))

    return piece_coordinates
|
|
|
|
|
|
|
IMAGE_PATH = "test.png"


if __name__ == "__main__":
    # The results get their own name: binding them to `coordinates` would
    # replace the module-level list of (row, col) grid indices.
    detected_pieces = extract_chessboard_coordinates(IMAGE_PATH)
    for coord, piece in detected_pieces:
        print(f"Piece at {coord}: {piece}")

    # Drop the model and release cached GPU memory once the run is finished.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
|
|