File size: 7,567 Bytes
ae1d0b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import gc

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned classifier: a ResNet-18 whose final fully connected layer
# is replaced by a 13-way head (12 piece types plus 'empty').
# weights=None: load_state_dict (strict by default) overwrites every parameter
# and buffer, so fetching the ImageNet weights first would only be a wasted
# download.
try:
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 13)  # 13 classes including 'empty'
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code when loaded.
    model.load_state_dict(
        torch.load("best_chess_piece_model.pth", map_location=device, weights_only=True)
    )
    model.eval()
    model = model.to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    # SystemExit is what the site-provided exit() raises, but it does not
    # depend on the `site` module being loaded (e.g. `python -S`, frozen apps).
    raise SystemExit(1) from e

# Class names indexed by the classifier's output position. The list is in
# alphabetical order — presumably the class order used at training time
# (e.g. by a folder-per-class dataset loader); confirm against the trainer.
piece_labels = sorted(
    [
        f"{color}_{kind}"
        for color in ("black", "white")
        for kind in ("bishop", "king", "knight", "pawn", "queen", "rook")
    ]
    + ["empty"]
)

# All 64 (row, col) board positions in row-major order: (0, 0) is the
# top-left square (a8) and (7, 7) the bottom-right (h1) — this assumes the
# board image is oriented with White at the bottom.
coordinates = [divmod(square, 8) for square in range(64)]

# Preprocessing applied to every square crop before classification: resize to
# 224x224, convert to a float tensor, then normalize per channel with the
# standard ImageNet mean/std (presumably the same preprocessing the
# checkpoint was trained with — confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# Classify a single cropped square with the loaded model.
def predict_piece(image, model, device):
    """Return the predicted label for one board-square crop.

    Args:
        image: OpenCV image array — single-channel grayscale (2-D, or 3-D with
            one channel) or a colour image in OpenCV's BGR channel order.
        model: classifier whose arg-max output index is looked up in
            ``piece_labels``.
        device: torch device the input batch is moved to before inference.

    Returns:
        The winning ``piece_labels`` entry, or ``"unknown"`` if any step
        raises; the error is printed rather than propagated.
    """
    try:
        dims = image.shape
        grayscale = len(dims) == 2 or dims[2] == 1
        code = cv2.COLOR_GRAY2RGB if grayscale else cv2.COLOR_BGR2RGB
        rgb = cv2.cvtColor(image, code)

        # Build a batch of one on the target device.
        batch = transform(Image.fromarray(rgb)).unsqueeze(0).to(device)
        with torch.no_grad():
            scores = model(batch)
            _, best = torch.max(scores, 1)
        return piece_labels[best.item()]
    except Exception as e:
        print(f"Error predicting piece: {e}")
        return "unknown"


# Helpers and entry point for detecting the chessboard grid with edge
# detection and Hough lines.
def _collapse_line_cluster(cluster):
    """Merge a cluster of ``(pos, start, end)`` segments into one int segment."""
    pos = int(round(sum(seg[0] for seg in cluster) / len(cluster)))
    start = min(min(seg[1], seg[2]) for seg in cluster)
    end = max(max(seg[1], seg[2]) for seg in cluster)
    return pos, int(start), int(end)


def _merge_close_lines(segments, tol):
    """Group segments whose positions lie within ``tol`` px of their neighbour.

    ``segments`` holds ``(pos, start, end)`` tuples: ``pos`` is the coordinate
    across the line (y for horizontal lines, x for vertical ones) and
    ``start``/``end`` its extent along the line. Returns one merged segment
    per group, sorted by position.
    """
    merged = []
    cluster = []
    for seg in sorted(segments):
        if cluster and seg[0] - cluster[-1][0] > tol:
            merged.append(_collapse_line_cluster(cluster))
            cluster = []
        cluster.append(seg)
    if cluster:
        merged.append(_collapse_line_cluster(cluster))
    return merged


def detect_chessboard_grid(image, line_merge_tol=10):
    """Locate the 8x8 square grid of a board image via Canny edges + Hough lines.

    Args:
        image: BGR image, as returned by ``cv2.imread``.
        line_merge_tol: segments whose positions differ by at most this many
            pixels are treated as the same grid line. HoughLinesP typically
            reports several segments per physical line; without merging, the
            first nine "lines" could all come from the same board edge.

    Returns:
        ``(corners, edges)``. ``corners`` is a float32 array of shape
        ``(8, 8, 2)`` where ``corners[row, col]`` is the (x, y) top-left
        corner of that square, or ``None`` when the grid was not found.
        ``edges`` is the Canny edge map, returned in both cases for debugging.

    Side effects:
        On success writes ``lines_debug.png`` showing the selected lines.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Enhance contrast, then blur to suppress noise before edge detection.
    gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    # Detect line segments using the probabilistic Hough transform.
    lines = cv2.HoughLinesP(
        edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=10
    )
    if lines is None:
        print("No lines detected.")
        return None, edges

    # Split segments by dominant direction. The position is taken at the
    # segment midpoint so slightly tilted lines are not biased to one end.
    h_segments = []
    v_segments = []
    for line in lines:
        x1, y1, x2, y2 = (int(v) for v in line[0])
        if abs(x2 - x1) > abs(y2 - y1):  # Horizontal line
            h_segments.append(((y1 + y2) / 2, x1, x2))
        else:  # Vertical line
            v_segments.append(((x1 + x2) / 2, y1, y2))

    # Nine distinct lines per axis bound the eight rows/columns of squares.
    h_lines = _merge_close_lines(h_segments, line_merge_tol)[:9]
    v_lines = _merge_close_lines(v_segments, line_merge_tol)[:9]

    if len(h_lines) < 9 or len(v_lines) < 9:
        print(
            f"Insufficient lines detected: {len(h_lines)} horizontal, {len(v_lines)} vertical"
        )
        return None, edges

    # Each square is addressed by its top-left corner, so only the first 8
    # lines of each axis are intersected; the 9th is just the far border.
    # (Crossing all 9x9 lines yields 81 points, which cannot form an 8x8 grid.)
    corners = np.array(
        [[(x, y) for x, _, _ in v_lines[:8]] for y, _, _ in h_lines[:8]],
        dtype=np.float32,
    )  # shape (8, 8, 2): corners[row, col] = (x, y)

    # Visualize the selected lines for debugging.
    debug_image = image.copy()
    for y, x1, x2 in h_lines:
        cv2.line(debug_image, (x1, y), (x2, y), (0, 255, 0), 2)
    for x, y1, y2 in v_lines:
        cv2.line(debug_image, (x, y1), (x, y2), (0, 0, 255), 2)
    cv2.imwrite("lines_debug.png", debug_image)

    return corners, edges


# Helpers and entry point for extracting the piece on every square of an image.
def _orient_lattice(points):
    """Reorder an (N, N, 2) (x, y) lattice so rows run top-to-bottom and
    columns left-to-right in image coordinates.

    findChessboardCorners may start from any outer corner and walk along
    either axis, so its raw ordering is not guaranteed to be image-aligned.
    """
    span = points[0, -1] - points[0, 0]
    if abs(span[1]) > abs(span[0]):  # first row runs vertically: swap axes
        points = points.transpose(1, 0, 2)
    if points[0, -1, 0] < points[0, 0, 0]:  # columns run right-to-left
        points = points[:, ::-1]
    if points[-1, 0, 1] < points[0, 0, 1]:  # rows run bottom-to-top
        points = points[::-1]
    return points


def _pad_lattice(points):
    """Grow an (R, C, 2) point lattice by one extrapolated step on every side.

    Linear extrapolation: exact for a fronto-parallel board, approximate
    under perspective distortion.
    """
    rows = np.concatenate(
        [2 * points[:1] - points[1:2], points, 2 * points[-1:] - points[-2:-1]],
        axis=0,
    )
    return np.concatenate(
        [2 * rows[:, :1] - rows[:, 1:2], rows, 2 * rows[:, -1:] - rows[:, -2:-1]],
        axis=1,
    )


def _find_square_corners_opencv(gray):
    """Return (8, 8, 2) top-left square corners via OpenCV, or None.

    ``patternSize`` counts *inner* corners; an 8x8 board has 7x7 of them.
    The inner lattice is refined, oriented, and extrapolated to the full 9x9
    corner lattice, whose first 8x8 points are the squares' top-left corners.
    """
    found, inner = cv2.findChessboardCorners(gray, (7, 7), None)
    if not found:
        return None
    inner = cv2.cornerSubPix(
        gray,
        inner,
        (11, 11),
        (-1, -1),
        criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1),
    )
    lattice = _pad_lattice(_orient_lattice(inner.reshape(7, 7, 2)))
    return np.ascontiguousarray(lattice[:8, :8], dtype=np.float32)


def _estimate_square_size(corners):
    """Return the smaller of the mean horizontal/vertical corner spacings (int px)."""
    square_width = np.linalg.norm(np.diff(corners, axis=1), axis=-1).mean()
    square_height = np.linalg.norm(np.diff(corners, axis=0), axis=-1).mean()
    return int(min(square_width, square_height))


def extract_chessboard_coordinates(image_path):
    """Classify the contents of every square in a chessboard image.

    The grid is located with OpenCV's chessboard-corner detector first and,
    if that fails, with ``detect_chessboard_grid``. Each square is cropped
    from its top-left corner with one uniform side length and passed to
    ``predict_piece``.

    Args:
        image_path: path handed to ``cv2.imread``.

    Returns:
        A list of 64 ``((row, col), label)`` tuples in row-major order, where
        ``label`` is a ``piece_labels`` entry or ``"unknown"`` when a square
        could not be cropped/classified; an empty list if the image cannot be
        loaded or no grid is found.

    Side effects:
        May write ``edges_debug.png`` or ``grid_debug.png`` for debugging.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to load image: {image_path}")
            return []
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return []

    # Try OpenCV's chessboard detection first.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    corners = _find_square_corners_opencv(gray)

    if corners is None:
        print("OpenCV chessboard detection failed. Attempting edge-based detection.")
        corners, edges = detect_chessboard_grid(image)
        if corners is None:
            # Save edges for debugging.
            cv2.imwrite("edges_debug.png", edges)
            print("Saved edge detection output to edges_debug.png")
            return []
        # Save debug image with detected corners.
        debug_image = image.copy()
        for x, y in corners.reshape(-1, 2):
            cv2.circle(debug_image, (int(x), int(y)), 5, (0, 255, 0), -1)
        cv2.imwrite("grid_debug.png", debug_image)
        print("Saved grid detection debug image to grid_debug.png")

    square_size = _estimate_square_size(corners)
    height, width = image.shape[:2]

    piece_coordinates = []
    # Iterate the 8x8 squares directly (same row-major order as the
    # module-level ``coordinates`` list) instead of reading that global, so
    # this function keeps working even if the name is rebound elsewhere.
    for i in range(8):
        for j in range(8):
            try:
                x = max(0, int(corners[i, j, 0]))
                y = max(0, int(corners[i, j, 1]))
                x_end = min(width, x + square_size)
                y_end = min(height, y + square_size)
                roi = image[y:y_end, x:x_end]

                if roi.shape[0] == 0 or roi.shape[1] == 0:
                    print(f"Invalid ROI at square ({i}, {j})")
                    piece_coordinates.append(((i, j), "unknown"))
                    continue

                predicted_piece = predict_piece(roi, model, device)
                piece_coordinates.append(((i, j), predicted_piece))
            except Exception as e:
                print(f"Error processing square ({i}, {j}): {e}")
                piece_coordinates.append(((i, j), "unknown"))

    return piece_coordinates


# Example usage: classify every square of a sample image, then release the
# model. Guarded so that importing this module neither runs the example nor
# deletes the global ``model`` that extract_chessboard_coordinates relies on.
if __name__ == "__main__":
    IMAGE_PATH = "test.png"
    # Use a distinct name so the module-level ``coordinates`` grid is not
    # overwritten with the (coord, piece) results.
    results = extract_chessboard_coordinates(IMAGE_PATH)
    for coord, piece in results:
        print(f"Piece at {coord}: {piece}")

    # Free the model and any cached GPU memory once inference is done.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()