Spaces:
Running
Running
from ultralytics import YOLO | |
import cv2 | |
from stockfish import Stockfish | |
import os | |
import numpy as np | |
import streamlit as st | |
import requests | |
# Constants
# YOLO class names are "<colour>-<piece>"; FEN uses lowercase letters for
# black pieces and uppercase for white ones.
_PIECE_LETTERS = {
    "pawn": "p",
    "rook": "r",
    "knight": "n",
    "bishop": "b",
    "queen": "q",
    "king": "k",
}
FEN_MAPPING = {
    f"{colour}-{piece}": (letter.upper() if colour == "white" else letter)
    for colour in ("black", "white")
    for piece, letter in _PIECE_LETTERS.items()
}

GRID_BORDER = 10  # Margin in pixels around the playing area
GRID_SIZE = 204  # Side of the playing area in pixels (from 10 px to 214 px)
BLOCK_SIZE = GRID_SIZE // 8  # Integer width of one square (25 px)
X_LABELS = list("abcdefgh")  # File labels, left to right
Y_LABELS = list(range(8, 0, -1))  # Rank labels 8 down to 1
# Functions
def get_grid_coordinate(pixel_x, pixel_y, *, border=10, grid_size=204):
    """Map a pixel on the resized board crop to an algebraic square name.

    The playing area is a square ``grid_size`` pixels wide that starts
    ``border`` pixels in from the origin. The defaults match the 224x224 crop
    used by the app: a 10 px border on each side and a 204 px playing area.
    Column 0 (leftmost) is file ``a`` and row 0 (smallest ``pixel_y``) is
    rank 1. Callers should therefore pass ``pixel_y`` measured upward from
    the bottom edge of the image.

    Args:
        pixel_x: Horizontal pixel coordinate.
        pixel_y: Vertical pixel coordinate, increasing towards rank 8.
        border: Width in pixels of the margin around the playing area.
        grid_size: Side length in pixels of the playing area.

    Returns:
        A square name such as ``"e4"``, or the string
        ``"Pixel outside grid bounds"`` when the pixel lies in the border or
        beyond the board.
    """
    adjusted_x = int(pixel_x) - border
    adjusted_y = int(pixel_y) - border
    if not (0 <= adjusted_x < grid_size and 0 <= adjusted_y < grid_size):
        return "Pixel outside grid bounds"
    # Scale proportionally rather than dividing by ``grid_size // 8``.
    # 204 px is not a multiple of 8, so a fixed 25 px square width drifts
    # every boundary and leaves the last 4 px of the board unmapped.
    file_index = adjusted_x * 8 // grid_size
    rank_index = adjusted_y * 8 // grid_size
    return f"{'abcdefgh'[file_index]}{rank_index + 1}"
def predict_next_move(fen, stockfish):
    """Ask Stockfish for the best move in ``fen`` and describe it for display.

    Args:
        fen: Full FEN string (placement, side to move, castling, en passant,
            halfmove and fullmove counters).
        stockfish: A ``stockfish.Stockfish`` engine instance. Its position is
            replaced by ``fen`` when the FEN is valid.

    Returns:
        A human-readable message. It is either the predicted move (passed
        through ``transform_string``), an invalid-FEN notice, or a notice that
        no move exists.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)
    best_move = stockfish.get_best_move()
    # get_best_move() returns None when the side to move is mated. Check it
    # before transforming, since transform_string expects a string.
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
# def download_stockfish(): | |
# url = "https://drive.google.com/file/d/18pkwBVc13fgKP3LzrTHE4yzhjyGJexlR/view?usp=sharing" # Replace with the actual link | |
# file_name = "stockfish-windows-x86-64-avx2.exe" | |
# if not os.path.exists(file_name): | |
# print(f"Downloading {file_name}...") | |
# response = requests.get(url, stream=True) | |
# with open(file_name, "wb") as file: | |
# for chunk in response.iter_content(chunk_size=1024): | |
# if chunk: | |
# file.write(chunk) | |
# print(f"{file_name} downloaded successfully.") | |
def process_image(image_path):
    """Locate the chessboard in an image and return it cropped.

    The function works in four steps:

    1. Run the ``segmentation.pt`` YOLO model on ``image_path``.
    2. Take the first result whose top box has confidence >= 0.8.
    3. Crop the original image to that box and save the crop to
       ``output/cropped_segment.jpg``.
    4. Show the crop on the Streamlit page.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        The cropped board as a BGR array (as loaded by OpenCV). Returns
        ``None`` when no confident board is found or the image cannot be read.
    """
    os.makedirs('output', exist_ok=True)

    segmentation_model = YOLO("segmentation.pt")
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    bbox = None
    for result in results:
        # A result may contain no detections. Indexing boxes.conf[0] on it
        # would raise IndexError instead of reporting "no board found".
        if result.boxes is None or len(result.boxes) == 0:
            continue
        if result.boxes.conf[0] >= 0.8:
            bbox = result.boxes.xyxy[0].cpu().numpy()
            break

    if bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not read image at {image_path}.")
        return None

    # Clamp the box to the image. A negative index would otherwise make
    # numpy slicing wrap around to the far edge.
    height, width = image.shape[:2]
    x1, y1, x2, y2 = bbox
    x1, x2 = max(0, int(x1)), min(width, int(x2))
    y1, y2 = max(0, int(y1)), min(height, int(y2))
    cropped_segment = image[y1:y2, x1:x2]
    if cropped_segment.size == 0:
        print("Detected board region is empty.")
        return None

    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")
    # OpenCV arrays are BGR. Tell Streamlit so colours are not swapped.
    st.image(cropped_segment, caption="Uploaded Image", use_column_width=True, channels="BGR")
    return cropped_segment
def transform_string(input_str):
    """Rotate a UCI move by 180 degrees: files a<->h and ranks 1<->8.

    ``predict_next_move`` applies this to Stockfish's best move before
    displaying it.

    Args:
        input_str: A UCI move such as ``"e2e4"``, or ``"e7e8q"`` for a
            promotion. Surrounding whitespace and letter case are ignored.

    Returns:
        The rotated move in lowercase, with any promotion suffix kept
        unchanged. Returns ``"Invalid input"`` if ``input_str`` is not a
        well-formed UCI move.
    """
    if not isinstance(input_str, str):
        return "Invalid input"
    move = input_str.strip().lower()
    squares, promotion = move[:4], move[4:]
    # Stockfish reports promotions with a fifth character (q, r, b or n).
    if len(squares) != 4 or promotion not in ("", "q", "r", "b", "n"):
        return "Invalid input"
    files, ranks = "abcdefgh", "12345678"
    if (squares[0] not in files or squares[2] not in files
            or squares[1] not in ranks or squares[3] not in ranks):
        return "Invalid input"
    # Files and ranks use disjoint characters, so one table mirrors both.
    mirror = str.maketrans(files + ranks, files[::-1] + ranks[::-1])
    return squares.translate(mirror) + promotion
# Streamlit app
def _board_to_fen_placement(board):
    """Encode an 8x8 board (rank 8 first, ``None`` = empty) as a FEN placement."""
    fen_rows = []
    for row in board:
        fen_row = ""
        empty_count = 0
        for cell in row:
            if cell is None:
                empty_count += 1
                continue
            if empty_count:
                fen_row += str(empty_count)
                empty_count = 0
            fen_row += cell
        if empty_count:
            fen_row += str(empty_count)
        fen_rows.append(fen_row)
    return "/".join(fen_rows)


def _detect_board(processed_image):
    """Detect pieces on a board crop and return an 8x8 grid of FEN letters.

    The returned grid lists rank 8 first and file a first. Empty squares
    are ``None``.
    """
    height = processed_image.shape[0]
    model = YOLO("standard.pt")  # Replace with your trained model weights file
    results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

    board = [[None] * 8 for _ in range(8)]
    for box in results[0].boxes:
        fen_piece = FEN_MAPPING.get(model.names[int(box.cls[0])])
        if not fen_piece:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        pixel_x = int((x1 + x2) / 2)
        # Flip Y so that it grows from the bottom edge, as get_grid_coordinate expects.
        pixel_y = int(height - (y1 + y2) / 2)
        grid_position = get_grid_coordinate(pixel_x, pixel_y)
        if grid_position == "Pixel outside grid bounds":
            continue
        file_index = ord(grid_position[0]) - ord('a')
        rank_index = int(grid_position[1]) - 1
        board[7 - rank_index][file_index] = fen_piece  # FEN lists rank 8 first
    return board


def main():
    """Streamlit entry point: image -> board crop -> FEN -> Stockfish move."""
    st.title("Chessboard Position Detection and Move Prediction")

    # Use one path for both chmod and the engine so they cannot disagree.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    os.chmod(stockfish_path, 0o755)

    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
    if image_file is None:
        return

    # NOTE(review): a fixed file name is shared by all concurrent sessions.
    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
    with open(temp_file_path, "wb") as f:
        f.write(image_file.getbuffer())

    processed_image = process_image(temp_file_path)
    if processed_image is None:
        st.error("Failed to process the image. Please try again.")
        return

    processed_image = cv2.resize(processed_image, (224, 224))
    position_fen = _board_to_fen_placement(_detect_board(processed_image))

    move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
    move_side = "w" if move_side.startswith("w") else "b"
    # Castling and en-passant rights cannot be read from a photo, so they
    # are "-". FEN's fullmove counter starts at 1.
    fen_notation = f"{position_fen} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    stockfish = Stockfish(
        path=stockfish_path,
        depth=15,
        parameters={"Threads": 2, "Minimum Thinking Time": 30}
    )
    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()