Spaces:
Sleeping
Sleeping
File size: 10,415 Bytes
5ac9d87 df04fd4 5ac9d87 d238f88 73b83df 5ac9d87 d238f88 5ac9d87 4d8763f 131f827 5ac9d87 8f200cf 5ac9d87 e07e067 5ac9d87 e07e067 5ac9d87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
from ultralytics import YOLO
import cv2
from stockfish import Stockfish
import os
import numpy as np
import streamlit as st
import requests
# Constants
# Detector class names look like "<colour>-<piece>"; FEN uses lowercase
# letters for black pieces and uppercase letters for white pieces.
_FEN_PIECE_LETTERS = {
    "pawn": "p",
    "rook": "r",
    "knight": "n",
    "bishop": "b",
    "queen": "q",
    "king": "k",
}
FEN_MAPPING = {
    f"{colour}-{piece}": letter if colour == "black" else letter.upper()
    for colour in ("black", "white")
    for piece, letter in _FEN_PIECE_LETTERS.items()
}

GRID_BORDER = 10  # margin, in pixels, before the playable grid starts
GRID_SIZE = 204  # side of the playable grid in pixels (10px up to 214px)
BLOCK_SIZE = GRID_SIZE // 8  # pixels per square; 204 // 8 floors to 25
X_LABELS = list("abcdefgh")  # file letters, left to right
Y_LABELS = list(range(8, 0, -1))  # rank numbers listed from 8 down to 1
# Functions
def get_grid_coordinate(pixel_x, pixel_y):
    """
    Map a pixel position on the cropped board image to an algebraic square.

    The playable grid starts GRID_BORDER pixels from the origin on both axes
    and spans GRID_SIZE pixels, divided into squares of BLOCK_SIZE pixels.
    Files run 'a'..'h' as pixel_x grows and ranks run 1..8 as pixel_y grows,
    so (GRID_BORDER, GRID_BORDER) maps to "a1". Callers working in image
    coordinates (y growing downward) must flip y before calling.

    Args:
        pixel_x: Horizontal pixel coordinate (int or float).
        pixel_y: Vertical pixel coordinate, increasing toward rank 8.

    Returns:
        A square such as "e4", or the string "Pixel outside grid bounds"
        when the point does not fall on one of the 64 squares.
    """
    adjusted_x = pixel_x - GRID_BORDER
    adjusted_y = pixel_y - GRID_BORDER
    if not (0 <= adjusted_x < GRID_SIZE and 0 <= adjusted_y < GRID_SIZE):
        return "Pixel outside grid bounds"

    # int() so float coordinates can still index the label list.
    file_index = int(adjusted_x // BLOCK_SIZE)
    rank_index = int(adjusted_y // BLOCK_SIZE)
    # 8 * BLOCK_SIZE (200) is smaller than GRID_SIZE (204), so the last few
    # pixels of the grid land in a ninth "square"; treat them as off-grid.
    if file_index >= len(X_LABELS) or rank_index >= len(Y_LABELS):
        return "Pixel outside grid bounds"

    return f"{X_LABELS[file_index]}{rank_index + 1}"
def predict_next_move(fen, stockfish):
    """
    Ask Stockfish for the best move in a position and describe it.

    Args:
        fen: Full FEN string of the position to analyse.
        stockfish: A ``stockfish.Stockfish`` engine wrapper.

    Returns:
        A human-readable message: the suggested move (mirrored through
        transform_string), an invalid-FEN notice, or a no-move notice.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)

    best_move = stockfish.get_best_move()
    # Check for "no move" BEFORE transforming: transform_string calls
    # .strip() on its argument and would fail on None.
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
# def download_stockfish():
# url = "https://drive.google.com/file/d/18pkwBVc13fgKP3LzrTHE4yzhjyGJexlR/view?usp=sharing" # Replace with the actual link
# file_name = "stockfish-windows-x86-64-avx2.exe"
# if not os.path.exists(file_name):
# print(f"Downloading {file_name}...")
# response = requests.get(url, stream=True)
# with open(file_name, "wb") as file:
# for chunk in response.iter_content(chunk_size=1024):
# if chunk:
# file.write(chunk)
# print(f"{file_name} downloaded successfully.")
def process_image(image_path):
    """
    Locate the chessboard in an image and return it cropped.

    Runs the "segmentation.pt" YOLO model on the image, takes the first
    result whose top detection has confidence >= 0.8 and a segmentation
    mask, crops the image to that detection's bounding box, saves the crop
    to output/cropped_segment.jpg and displays it in the Streamlit page.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        The cropped board as an OpenCV (BGR) array, or None when no
        confident board detection was found or the image could not be read.
    """
    os.makedirs('output', exist_ok=True)

    segmentation_model = YOLO("segmentation.pt")
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    bbox = None
    for result in results:
        # With conf=0.8 the model may return no boxes at all; indexing
        # conf[0] on an empty tensor would raise. masks is None when the
        # model produced no segmentation output.
        if len(result.boxes) == 0 or result.masks is None:
            continue
        if result.boxes.conf[0] >= 0.8:
            bbox = result.boxes.xyxy[0].cpu().numpy()
            break

    if bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None for unreadable files
        print(f"Could not read image {image_path}.")
        return None

    x1, y1, x2, y2 = (int(v) for v in bbox)
    cropped_segment = image[y1:y2, x1:x2]

    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")
    # OpenCV arrays are BGR; tell Streamlit so colours are not swapped.
    st.image(cropped_segment, caption="Uploaded Image", use_column_width=True, channels="BGR")
    return cropped_segment
def transform_string(input_str):
    """
    Mirror a UCI move through the centre of the board.

    Each file letter is reflected (a<->h, b<->g, ...) and each rank digit is
    reflected (1<->8, 2<->7, ...), e.g. "e2e4" -> "d7d5". Input is stripped
    and lower-cased first. A trailing promotion piece (q, r, b or n), as
    Stockfish emits for moves like "e7e8q", is kept after the mirrored
    squares.

    Args:
        input_str: Move text such as "e2e4" or "e7e8q".

    Returns:
        The mirrored move, or "Invalid input" if the text is not a move of
        that form.
    """
    files = "abcdefgh"
    ranks = "12345678"
    move = input_str.strip().lower()
    if len(move) not in (4, 5):
        return "Invalid input"

    squares, promotion = move[:4], move[4:]
    is_square_pair = (
        squares[0] in files and squares[1] in ranks
        and squares[2] in files and squares[3] in ranks
    )
    if not is_square_pair or (promotion and promotion not in "qrbn"):
        return "Invalid input"

    # Pairing each alphabet with its reverse maps every file/rank to its mirror.
    mirror = str.maketrans(files + ranks, files[::-1] + ranks[::-1])
    return squares.translate(mirror) + promotion
# Streamlit app
def _detect_board(image, model):
    """Run the piece detector and place each piece on an 8x8 board.

    Row 0 is rank 8 and column 0 is file 'a'. Empty squares are None;
    occupied squares hold the FEN letter from FEN_MAPPING. Detections whose
    class is not in FEN_MAPPING, or whose centre is off the grid, are skipped.
    """
    height = image.shape[0]
    board = [[None] * 8 for _ in range(8)]
    results = model.predict(source=image, save=False, save_txt=False, conf=0.6)
    for box in results[0].boxes:
        fen_piece = FEN_MAPPING.get(model.names[int(box.cls[0])])
        if not fen_piece:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        pixel_x = int((x1 + x2) / 2)
        # Image y grows downward; get_grid_coordinate expects y growing
        # toward rank 8, so flip it.
        pixel_y = int(height - (y1 + y2) / 2)
        square = get_grid_coordinate(pixel_x, pixel_y)
        if square == "Pixel outside grid bounds":
            continue
        file_index = ord(square[0]) - ord('a')
        rank_index = int(square[1]) - 1
        board[7 - rank_index][file_index] = fen_piece
    return board


def _board_to_fen_placement(board):
    """Encode board rows (None = empty) as the piece-placement field of a FEN."""
    encoded_rows = []
    for row in board:
        encoded = ""
        empty_run = 0
        for cell in row:
            if cell is None:
                empty_run += 1
                continue
            if empty_run:
                encoded += str(empty_run)
                empty_run = 0
            encoded += cell
        if empty_run:
            encoded += str(empty_run)
        encoded_rows.append(encoded)
    return "/".join(encoded_rows)


def main():
    """Streamlit UI: capture/upload a board photo, show its FEN and a suggested move."""
    st.title("Chessboard Position Detection and Move Prediction")

    # Resolve the engine path once so the chmod and the engine launch agree.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    if os.path.exists(stockfish_path):
        # The binary may be checked out without its execute bit.
        os.chmod(stockfish_path, 0o755)

    # Render both inputs every run; a camera capture takes precedence.
    captured = st.camera_input("Capture a chessboard image")
    uploaded = st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
    image_file = captured or uploaded
    if image_file is None:
        return

    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
    with open(temp_file_path, "wb") as f:
        f.write(image_file.getbuffer())

    processed_image = process_image(temp_file_path)
    if processed_image is None:
        st.error("Failed to process the image. Please try again.")
        return

    # 224 = GRID_BORDER + GRID_SIZE + GRID_BORDER, the size the grid
    # constants used by get_grid_coordinate are laid out for.
    processed_image = cv2.resize(processed_image, (224, 224))
    model = YOLO("standard.pt")  # Replace with your trained model weights file
    position_fen = _board_to_fen_placement(_detect_board(processed_image, model))

    move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
    move_side = "w" if move_side.startswith("w") else "b"
    # Castling rights and en passant cannot be read from a photo. The FEN
    # fullmove number starts at 1, so use "0 1" for the move counters.
    fen_notation = f"{position_fen} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    try:
        stockfish = Stockfish(
            path=stockfish_path,
            depth=15,
            parameters={"Threads": 2, "Minimum Thinking Time": 30}
        )
    except (FileNotFoundError, PermissionError) as err:
        st.error(f"Could not start the Stockfish engine: {err}")
        return

    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()
|