Spaces:
Sleeping
Sleeping
File size: 10,415 Bytes
5ac9d87 df04fd4 5ac9d87 d238f88 73b83df 5ac9d87 d238f88 5ac9d87 4d8763f 131f827 5ac9d87 8f200cf 5ac9d87 e07e067 5ac9d87 e07e067 5ac9d87 |
|
from ultralytics import YOLO
import cv2
from stockfish import Stockfish
import os
import numpy as np
import streamlit as st
import requests
# Constants
# Maps the piece detector's class names to FEN piece letters
# (lowercase = black, uppercase = white).
FEN_MAPPING = {
    "black-pawn": "p",
    "black-rook": "r",
    "black-knight": "n",
    "black-bishop": "b",
    "black-queen": "q",
    "black-king": "k",
    "white-pawn": "P",
    "white-rook": "R",
    "white-knight": "N",
    "white-bishop": "B",
    "white-queen": "Q",
    "white-king": "K",
}
# Board-grid geometry in pixels: a 10 px margin, then a 204 px grid
# (pixels 10 to 214), split into 8 squares of 204 // 8 == 25 px each.
GRID_BORDER = 10
GRID_SIZE = 204
BLOCK_SIZE = GRID_SIZE // 8
X_LABELS = list("abcdefgh")       # file letters, left to right
Y_LABELS = list(range(8, 0, -1))  # rank numbers, 8 down to 1
# Functions
def get_grid_coordinate(pixel_x, pixel_y, *, border=10, grid_size=204):
    """
    Map a pixel position inside the board grid to a square name.

    The grid starts ``border`` pixels in from the origin and spans
    ``grid_size`` pixels in each direction (the defaults equal GRID_BORDER
    and GRID_SIZE). Coordinates must use a bottom-left origin, so the
    bottom-left square is "a1", the top-left "a8" and the top-right "h8".

    Args:
        pixel_x: Horizontal pixel position (int or float).
        pixel_y: Vertical pixel position, measured upwards from the bottom.
        border: Width of the margin before the grid starts, in pixels.
        grid_size: Side length of the 8x8 grid, in pixels.

    Returns:
        The square name, e.g. "e4", or the sentinel string
        "Pixel outside grid bounds" when the pixel is not on the grid.
    """
    adjusted_x = pixel_x - border
    adjusted_y = pixel_y - border
    if not (0 <= adjusted_x < grid_size and 0 <= adjusted_y < grid_size):
        return "Pixel outside grid bounds"
    # Scale by 8 before dividing so every square is exactly grid_size / 8
    # pixels wide (25.5 px by default). Integer 25 px squares covered only
    # 200 of the 204 pixels and reported the last 4 px of the top/right
    # edges as outside the grid.
    file_index = int(adjusted_x * 8 // grid_size)
    rank_index = int(adjusted_y * 8 // grid_size)
    return f"{'abcdefgh'[file_index]}{rank_index + 1}"
def predict_next_move(fen, stockfish):
    """
    Ask Stockfish for the best move in the given position.

    Args:
        fen: Full FEN string (placement, side to move, castling, en passant,
            clocks).
        stockfish: A ``stockfish.Stockfish`` engine wrapper.

    Returns:
        A human-readable message. If Stockfish rejects the FEN, returns
        "Invalid FEN notation!". If the engine returns no move, returns a
        "no valid move" message. Otherwise returns the suggested move after
        passing it through ``transform_string``, which mirrors the file
        letters (a<->h) and rank digits (1<->8).
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)
    best_move = stockfish.get_best_move()
    # get_best_move() yields None when there is no legal move (checkmate or
    # stalemate). Check before transforming so None never reaches
    # transform_string; previously this branch was unreachable because the
    # transform ran first and crashed on None.
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
# def download_stockfish():
# url = "https://drive.google.com/file/d/18pkwBVc13fgKP3LzrTHE4yzhjyGJexlR/view?usp=sharing" # Replace with the actual link
# file_name = "stockfish-windows-x86-64-avx2.exe"
# if not os.path.exists(file_name):
# print(f"Downloading {file_name}...")
# response = requests.get(url, stream=True)
# with open(file_name, "wb") as file:
# for chunk in response.iter_content(chunk_size=1024):
# if chunk:
# file.write(chunk)
# print(f"{file_name} downloaded successfully.")
def process_image(image_path):
    """
    Locate the chessboard in an image and return it cropped to its bounding box.

    Runs the ``segmentation.pt`` YOLO model on ``image_path`` and takes the
    first result whose top detection has confidence >= 0.8. It crops the
    original image to that detection's bounding box, saves the crop to
    ``output/cropped_segment.jpg`` and shows it on the Streamlit page.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        The cropped region of the ``cv2.imread`` array (BGR channel order).
        Returns ``None`` when nothing is detected with enough confidence, the
        image cannot be read, or the crop is empty.
    """
    os.makedirs("output", exist_ok=True)

    # Load the segmentation model and run inference.
    segmentation_model = YOLO("segmentation.pt")
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8,  # Confidence threshold
    )

    # Only the bounding box is needed. The segmentation mask was previously
    # extracted and resized but never used.
    bbox = None
    for result in results:
        # An image with no detections yields an empty set of boxes; indexing
        # conf[0] on it raised IndexError instead of reporting "not found".
        if result.boxes is None or len(result.boxes) == 0:
            continue
        if result.boxes.conf[0] >= 0.8:  # Filter results by confidence
            bbox = result.boxes.xyxy[0].cpu().numpy()
            break
    if bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None rather than raising
        print(f"Could not read image: {image_path}")
        return None

    x1, y1, x2, y2 = (int(coord) for coord in bbox)
    cropped_segment = image[y1:y2, x1:x2]
    if cropped_segment.size == 0:  # cv2.imwrite fails on an empty image
        print("Detected bounding box is empty.")
        return None

    cropped_image_path = "output/cropped_segment.jpg"
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")
    # OpenCV images are BGR while st.image assumes RGB by default, which
    # swapped the red and blue channels in the preview.
    st.image(
        cropped_segment,
        caption="Uploaded Image",
        use_column_width=True,
        channels="BGR",
    )
    return cropped_segment
def transform_string(input_str):
    """
    Rotate a UCI move string by 180 degrees on the board.

    Each file letter is mirrored (a<->h, b<->g, ...) and each rank digit is
    mirrored (1<->8, 2<->7, ...), so "e2e4" becomes "d7d5". Surrounding
    whitespace is stripped and letters are lower-cased first. An optional
    fifth UCI promotion letter (q, r, b or n), as in "e7e8q", is kept
    unchanged.

    Args:
        input_str: Move in UCI long algebraic form, e.g. "e2e4" or "e7e8q".

    Returns:
        The rotated move, or "Invalid input" if ``input_str`` is not a string
        of that form.
    """
    if not isinstance(input_str, str):
        return "Invalid input"
    move = input_str.strip().lower()
    files = "abcdefgh"
    ranks = "12345678"

    # Promotion moves carry a trailing piece letter. They were previously
    # rejected because only 4-character input was accepted.
    promotion = ""
    if len(move) == 5 and move[4] in "qrbn":
        move, promotion = move[:4], move[4]
    if len(move) != 4:
        return "Invalid input"

    from_file, from_rank, to_file, to_rank = move
    if not (from_file in files and to_file in files
            and from_rank in ranks and to_rank in ranks):
        return "Invalid input"

    def mirror(char, alphabet):
        # Reflect the character's position within its 8-symbol alphabet.
        return alphabet[len(alphabet) - 1 - alphabet.index(char)]

    return (mirror(from_file, files) + mirror(from_rank, ranks)
            + mirror(to_file, files) + mirror(to_rank, ranks) + promotion)
# Streamlit app
def _board_from_detections(boxes, class_names, height):
    """
    Place detected pieces on an 8x8 board.

    Args:
        boxes: Piece detections (``results[0].boxes``); each exposes
            ``xyxy`` and ``cls``.
        class_names: Mapping from class id to class name (``model.names``).
        height: Height in pixels of the image the boxes refer to.

    Returns:
        Eight rows of FEN piece letters, with ``None`` for empty squares.
        Row 0 is rank 8 and column 0 is file a. Detections whose class is not
        in FEN_MAPPING, or whose centre falls outside the grid, are skipped.
        A later detection on the same square overwrites an earlier one.
    """
    board = [[None] * 8 for _ in range(8)]
    for box in boxes:
        fen_piece = FEN_MAPPING.get(class_names[int(box.cls[0])])
        if not fen_piece:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        # Use the box centre. Flip y because get_grid_coordinate expects a
        # bottom-left origin, while image rows grow downwards.
        pixel_x = int((x1 + x2) / 2)
        pixel_y = int(height - (y1 + y2) / 2)
        grid_position = get_grid_coordinate(pixel_x, pixel_y)
        if grid_position == "Pixel outside grid bounds":
            continue
        file_index = ord(grid_position[0]) - ord("a")
        rank_index = int(grid_position[1]) - 1
        board[7 - rank_index][file_index] = fen_piece
    return board


def _board_to_fen_placement(board):
    """
    Encode an 8x8 board (``None`` = empty) as the FEN piece-placement field.

    Runs of empty squares become digit counts and rows are joined with "/".
    For example, an empty board gives "8/8/8/8/8/8/8/8".
    """
    rows = []
    for row in board:
        encoded = ""
        empty_run = 0
        for cell in row:
            if cell is None:
                empty_run += 1
                continue
            if empty_run:
                encoded += str(empty_run)
                empty_run = 0
            encoded += cell
        if empty_run:
            encoded += str(empty_run)
        rows.append(encoded)
    return "/".join(rows)


def main():
    """
    Streamlit entry point.

    Lets the user capture or upload a chessboard photo and crops the board
    with process_image. It then detects pieces with the ``standard.pt``
    model, displays the resulting FEN, and shows Stockfish's suggested move
    for the side chosen by the user.
    """
    st.title("Chessboard Position Detection and Move Prediction")
    # Resolve the engine binary once and use the same path for chmod and for
    # the engine. chmod previously used a hard-coded absolute path that only
    # matched the engine path when the app ran from /home/user/app.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    if os.path.exists(stockfish_path):
        os.chmod(stockfish_path, 0o755)  # make sure the binary is executable
    st.write(os.getcwd())

    # User uploads an image or captures it from their camera
    image_file = (
        st.camera_input("Capture a chessboard image")
        or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
    )
    if image_file is None:
        return

    # Save the upload so process_image can read it from disk.
    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
    with open(temp_file_path, "wb") as f:
        f.write(image_file.getbuffer())

    processed_image = process_image(temp_file_path)
    if processed_image is None:
        st.error("Failed to process the image. Please try again.")
        return

    # 224 = 10 px border + 204 px grid + 10 px border, which matches the
    # geometry get_grid_coordinate uses by default.
    processed_image = cv2.resize(processed_image, (224, 224))
    height = processed_image.shape[0]

    model = YOLO("standard.pt")  # Replace with your trained model weights file
    results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)
    board = _board_from_detections(results[0].boxes, model.names, height)
    position_fen = _board_to_fen_placement(board)

    # Ask the user for the next move side
    move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
    move_side = "w" if move_side.startswith("w") else "b"
    # Castling rights and en passant are not detected, so both are "-".
    # The full-move number starts at 1 per the FEN spec (was 0).
    fen_notation = f"{position_fen} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    stockfish = Stockfish(
        path=stockfish_path,
        depth=15,
        parameters={"Threads": 2, "Minimum Thinking Time": 30},
    )
    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()
|