Spaces:
Running
Running
Upload 11 files
Browse files- app.py +187 -0
- deps/__init__.py +2 -0
- deps/geometry.py +1164 -0
- deps/laps.py +178 -0
- llr.py +307 -0
- preprocess.py +74 -0
- requirements_hf.txt +7 -0
- rescale.py +48 -0
- slid.py +211 -0
- train_tensorflow.py +182 -0
- utils.py +92 -0
app.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio App for Chess Board Analyzer
|
| 3 |
+
EXACTLY uses main.py logic - no modifications
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import chess
|
| 8 |
+
import chess.svg
|
| 9 |
+
import io
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
import os
|
| 13 |
+
import tempfile
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Import EXACT SAME functions from main.py
|
| 17 |
+
from preprocess import preprocess_image
|
| 18 |
+
from train_tensorflow import create_model
|
| 19 |
+
|
| 20 |
+
PIECES = ['Empty', 'Rook_White', 'Rook_Black', 'Knight_White', 'Knight_Black', 'Bishop_White',
|
| 21 |
+
'Bishop_Black', 'Queen_White', 'Queen_Black', 'King_White', 'King_Black', 'Pawn_White', 'Pawn_Black']
|
| 22 |
+
PIECES.sort()
|
| 23 |
+
|
| 24 |
+
LABELS = {
|
| 25 |
+
'Empty': '.',
|
| 26 |
+
'Rook_White': 'R',
|
| 27 |
+
'Rook_Black': 'r',
|
| 28 |
+
'Knight_White': 'N',
|
| 29 |
+
'Knight_Black': 'n',
|
| 30 |
+
'Bishop_White': 'B',
|
| 31 |
+
'Bishop_Black': 'b',
|
| 32 |
+
'Queen_White': 'Q',
|
| 33 |
+
'Queen_Black': 'q',
|
| 34 |
+
'King_White': 'K',
|
| 35 |
+
'King_Black': 'k',
|
| 36 |
+
'Pawn_White': 'P',
|
| 37 |
+
'Pawn_Black': 'p',
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Load model at startup (EXACT SAME as main.py)
|
| 41 |
+
print("⏳ Loading model...")
|
| 42 |
+
model = create_model()
|
| 43 |
+
model.load_weights('./model_weights.h5')
|
| 44 |
+
print("✅ Model loaded!")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def classify_image(img):
|
| 48 |
+
"""EXACT COPY from main.py"""
|
| 49 |
+
y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
|
| 50 |
+
y_pred = y_prob.argmax()
|
| 51 |
+
return PIECES[y_pred]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def analyze_board(img):
|
| 55 |
+
"""EXACT COPY from main.py"""
|
| 56 |
+
arr = []
|
| 57 |
+
M = img.shape[0]//8
|
| 58 |
+
N = img.shape[1]//8
|
| 59 |
+
for y in range(M-1, img.shape[1], M):
|
| 60 |
+
row = []
|
| 61 |
+
for x in range(0, img.shape[1], N):
|
| 62 |
+
sub_img = img[max(0, y-2*M):y, x:x+N]
|
| 63 |
+
if y-2*M < 0:
|
| 64 |
+
sub_img = np.concatenate(
|
| 65 |
+
(np.zeros((2*M-y, N, 3)), sub_img))
|
| 66 |
+
sub_img = sub_img.astype(np.uint8)
|
| 67 |
+
|
| 68 |
+
piece = classify_image(sub_img)
|
| 69 |
+
row.append(LABELS[piece])
|
| 70 |
+
arr.append(row)
|
| 71 |
+
|
| 72 |
+
# King-Queen heuristic
|
| 73 |
+
blackKing = False
|
| 74 |
+
whiteKing = False
|
| 75 |
+
whitePos = (-1, -1)
|
| 76 |
+
blackPos = (-1, -1)
|
| 77 |
+
for i in range(8):
|
| 78 |
+
for j in range(8):
|
| 79 |
+
if arr[i][j] == 'K':
|
| 80 |
+
whiteKing = True
|
| 81 |
+
if arr[i][j] == 'k':
|
| 82 |
+
blackKing = True
|
| 83 |
+
if arr[i][j] == 'Q':
|
| 84 |
+
whitePos = (i, j)
|
| 85 |
+
if arr[i][j] == 'q':
|
| 86 |
+
blackPos = (i, j)
|
| 87 |
+
if not whiteKing and whitePos[0] >= 0:
|
| 88 |
+
arr[whitePos[0]][whitePos[1]] = 'K'
|
| 89 |
+
if not blackKing and blackPos[0] >= 0:
|
| 90 |
+
arr[blackPos[0]][blackPos[1]] = 'k'
|
| 91 |
+
|
| 92 |
+
return arr
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def board_to_fen(board):
|
| 96 |
+
"""EXACT COPY from main.py"""
|
| 97 |
+
with io.StringIO() as s:
|
| 98 |
+
for row in board:
|
| 99 |
+
empty = 0
|
| 100 |
+
for cell in row:
|
| 101 |
+
if cell != '.':
|
| 102 |
+
if empty > 0:
|
| 103 |
+
s.write(str(empty))
|
| 104 |
+
empty = 0
|
| 105 |
+
s.write(cell)
|
| 106 |
+
else:
|
| 107 |
+
empty += 1
|
| 108 |
+
if empty > 0:
|
| 109 |
+
s.write(str(empty))
|
| 110 |
+
s.write('/')
|
| 111 |
+
s.seek(s.tell() - 1)
|
| 112 |
+
s.write(' w KQkq - 0 1')
|
| 113 |
+
return s.getvalue()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def analyze_chess_image(image_input):
|
| 117 |
+
"""Gradio wrapper around main.py logic"""
|
| 118 |
+
if image_input is None:
|
| 119 |
+
return "❌ No image provided", None
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
# Save to temp file (needed for preprocess_image which expects file path)
|
| 123 |
+
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
|
| 124 |
+
if isinstance(image_input, np.ndarray):
|
| 125 |
+
cv2.imwrite(tmp.name, cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
|
| 126 |
+
else:
|
| 127 |
+
image_input.save(tmp.name)
|
| 128 |
+
temp_path = tmp.name
|
| 129 |
+
|
| 130 |
+
# EXACT SAME as main.py: preprocess_image() uses LAPS!
|
| 131 |
+
img = preprocess_image(temp_path, save=False)
|
| 132 |
+
|
| 133 |
+
# EXACT SAME as main.py
|
| 134 |
+
arr = analyze_board(img)
|
| 135 |
+
fen = board_to_fen(arr)
|
| 136 |
+
|
| 137 |
+
# Generate board visualization
|
| 138 |
+
board = chess.Board(fen)
|
| 139 |
+
board_svg = chess.svg.board(board=board, size=400)
|
| 140 |
+
|
| 141 |
+
# Cleanup
|
| 142 |
+
os.unlink(temp_path)
|
| 143 |
+
|
| 144 |
+
return f"✅ FEN: {fen}", board_svg
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
import traceback
|
| 148 |
+
print(traceback.format_exc())
|
| 149 |
+
return f"❌ Error: {str(e)}", None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Build Gradio interface
|
| 153 |
+
with gr.Blocks(title="Chess Board Analyzer", theme=gr.themes.Soft()) as demo:
|
| 154 |
+
gr.Markdown("""
|
| 155 |
+
# ♟️ Chess Board Analyzer
|
| 156 |
+
|
| 157 |
+
Upload a chess board image to automatically detect all pieces and get the FEN notation.
|
| 158 |
+
|
| 159 |
+
**Uses EXACT SAME preprocessing (LAPS) and model as main.py**
|
| 160 |
+
""")
|
| 161 |
+
|
| 162 |
+
with gr.Row():
|
| 163 |
+
with gr.Column():
|
| 164 |
+
image_input = gr.Image(label="📸 Upload chess board image", type="pil")
|
| 165 |
+
submit_btn = gr.Button("🔍 Analyze Board", size="lg", variant="primary")
|
| 166 |
+
|
| 167 |
+
with gr.Column():
|
| 168 |
+
status_output = gr.Textbox(label="Result", interactive=False, lines=2)
|
| 169 |
+
board_output = gr.HTML(label="Board Visualization")
|
| 170 |
+
|
| 171 |
+
submit_btn.click(
|
| 172 |
+
fn=analyze_chess_image,
|
| 173 |
+
inputs=image_input,
|
| 174 |
+
outputs=[status_output, board_output]
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
gr.Markdown("""
|
| 178 |
+
### Model Info:
|
| 179 |
+
- **Preprocessing**: LAPS (Lattice Point Detection + Perspective Correction)
|
| 180 |
+
- **Architecture**: CNN with 5 convolutional layers
|
| 181 |
+
- **Accuracy**: ~96% on test set
|
| 182 |
+
- **Classes**: 13 types (Empty + 6 White + 6 Black pieces)
|
| 183 |
+
""")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
demo.launch()
|
deps/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import geometry
|
| 2 |
+
from . import laps
|
deps/geometry.py
ADDED
|
@@ -0,0 +1,1164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/ideasman42/isect_segments-bentley_ottmann
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# BentleyOttmann sweep-line implementation
|
| 6 |
+
# (for finding all intersections in a set of line segments)
|
| 7 |
+
|
| 8 |
+
__all__ = (
|
| 9 |
+
"isect_segments",
|
| 10 |
+
"isect_polygon",
|
| 11 |
+
|
| 12 |
+
# for testing only (correct but slow)
|
| 13 |
+
"isect_segments__naive",
|
| 14 |
+
"isect_polygon__naive",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# ----------------------------------------------------------------------------
|
| 18 |
+
# Main Poly Intersection
|
| 19 |
+
|
| 20 |
+
# Defines to change behavior.
|
| 21 |
+
#
|
| 22 |
+
# Whether to ignore intersections of line segments when both
|
| 23 |
+
# their end points form the intersection point.
|
| 24 |
+
USE_IGNORE_SEGMENT_ENDINGS = True
|
| 25 |
+
|
| 26 |
+
USE_DEBUG = False # FIXME
|
| 27 |
+
|
| 28 |
+
USE_VERBOSE = False
|
| 29 |
+
|
| 30 |
+
# checks we should NOT need,
|
| 31 |
+
# but do them in case we find a test-case that fails.
|
| 32 |
+
USE_PARANOID = False
|
| 33 |
+
|
| 34 |
+
# Support vertical segments,
|
| 35 |
+
# (the bentley-ottmann method doesn't support this).
|
| 36 |
+
# We use the term 'START_VERTICAL' for a vertical segment,
|
| 37 |
+
# to differentiate it from START/END/INTERSECTION
|
| 38 |
+
USE_VERTICAL = True
|
| 39 |
+
# end defines!
|
| 40 |
+
# ------------
|
| 41 |
+
|
| 42 |
+
# ---------
|
| 43 |
+
# Constants
|
| 44 |
+
X, Y = 0, 1
|
| 45 |
+
EPS = 1e-10
|
| 46 |
+
EPS_SQ = EPS * EPS
|
| 47 |
+
INF = float("inf")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Event:
|
| 51 |
+
__slots__ = (
|
| 52 |
+
"type",
|
| 53 |
+
"point",
|
| 54 |
+
"segment",
|
| 55 |
+
|
| 56 |
+
# this is just cache,
|
| 57 |
+
# we may remove or calculate slope on the fly
|
| 58 |
+
"slope",
|
| 59 |
+
"span",
|
| 60 |
+
) + (() if not USE_DEBUG else (
|
| 61 |
+
# debugging only
|
| 62 |
+
"other",
|
| 63 |
+
"in_sweep",
|
| 64 |
+
))
|
| 65 |
+
|
| 66 |
+
class Type:
|
| 67 |
+
END = 0
|
| 68 |
+
INTERSECTION = 1
|
| 69 |
+
START = 2
|
| 70 |
+
if USE_VERTICAL:
|
| 71 |
+
START_VERTICAL = 3
|
| 72 |
+
|
| 73 |
+
def __init__(self, type, point, segment, slope):
|
| 74 |
+
assert(isinstance(point, tuple))
|
| 75 |
+
self.type = type
|
| 76 |
+
self.point = point
|
| 77 |
+
self.segment = segment
|
| 78 |
+
|
| 79 |
+
# will be None for INTERSECTION
|
| 80 |
+
self.slope = slope
|
| 81 |
+
if segment is not None:
|
| 82 |
+
self.span = segment[1][X] - segment[0][X]
|
| 83 |
+
|
| 84 |
+
if USE_DEBUG:
|
| 85 |
+
self.other = None
|
| 86 |
+
self.in_sweep = False
|
| 87 |
+
|
| 88 |
+
def is_vertical(self):
|
| 89 |
+
return self.segment[0][X] == self.segment[1][X]
|
| 90 |
+
|
| 91 |
+
def y_intercept_x(self, x: float):
|
| 92 |
+
# vertical events only for comparison (above_all check)
|
| 93 |
+
# never added into the binary-tree its self
|
| 94 |
+
if USE_VERTICAL:
|
| 95 |
+
if self.is_vertical():
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
if x <= self.segment[0][X]:
|
| 99 |
+
return self.segment[0][Y]
|
| 100 |
+
elif x >= self.segment[1][X]:
|
| 101 |
+
return self.segment[1][Y]
|
| 102 |
+
|
| 103 |
+
# use the largest to avoid float precision error with nearly vertical lines.
|
| 104 |
+
delta_x0 = x - self.segment[0][X]
|
| 105 |
+
delta_x1 = self.segment[1][X] - x
|
| 106 |
+
if delta_x0 > delta_x1:
|
| 107 |
+
ifac = delta_x0 / self.span
|
| 108 |
+
fac = 1.0 - ifac
|
| 109 |
+
else:
|
| 110 |
+
fac = delta_x1 / self.span
|
| 111 |
+
ifac = 1.0 - fac
|
| 112 |
+
assert(fac <= 1.0)
|
| 113 |
+
return (self.segment[0][Y] * fac) + (self.segment[1][Y] * ifac)
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def Compare(sweep_line, this, that):
|
| 117 |
+
if this is that:
|
| 118 |
+
return 0
|
| 119 |
+
if USE_DEBUG:
|
| 120 |
+
if this.other is that:
|
| 121 |
+
return 0
|
| 122 |
+
current_point_x = sweep_line._current_event_point_x
|
| 123 |
+
ipthis = this.y_intercept_x(current_point_x)
|
| 124 |
+
ipthat = that.y_intercept_x(current_point_x)
|
| 125 |
+
# print(ipthis, ipthat)
|
| 126 |
+
if USE_VERTICAL:
|
| 127 |
+
if ipthis is None:
|
| 128 |
+
ipthis = this.point[Y]
|
| 129 |
+
if ipthat is None:
|
| 130 |
+
ipthat = that.point[Y]
|
| 131 |
+
|
| 132 |
+
delta_y = ipthis - ipthat
|
| 133 |
+
|
| 134 |
+
assert((delta_y < 0.0) == (ipthis < ipthat))
|
| 135 |
+
# NOTE, VERY IMPORTANT TO USE EPSILON HERE!
|
| 136 |
+
# otherwise w/ float precision errors we get incorrect comparisons
|
| 137 |
+
# can get very strange & hard to debug output without this.
|
| 138 |
+
if abs(delta_y) > EPS:
|
| 139 |
+
return -1 if (delta_y < 0.0) else 1
|
| 140 |
+
else:
|
| 141 |
+
this_slope = this.slope
|
| 142 |
+
that_slope = that.slope
|
| 143 |
+
if this_slope != that_slope:
|
| 144 |
+
if sweep_line._before:
|
| 145 |
+
return -1 if (this_slope > that_slope) else 1
|
| 146 |
+
else:
|
| 147 |
+
return 1 if (this_slope > that_slope) else -1
|
| 148 |
+
|
| 149 |
+
delta_x_p1 = this.segment[0][X] - that.segment[0][X]
|
| 150 |
+
if delta_x_p1 != 0.0:
|
| 151 |
+
return -1 if (delta_x_p1 < 0.0) else 1
|
| 152 |
+
|
| 153 |
+
delta_x_p2 = this.segment[1][X] - that.segment[1][X]
|
| 154 |
+
if delta_x_p2 != 0.0:
|
| 155 |
+
return -1 if (delta_x_p2 < 0.0) else 1
|
| 156 |
+
|
| 157 |
+
return 0
|
| 158 |
+
|
| 159 |
+
def __repr__(self):
|
| 160 |
+
return ("Event(0x%x, s0=%r, s1=%r, p=%r, type=%d, slope=%r)" % (
|
| 161 |
+
id(self),
|
| 162 |
+
self.segment[0], self.segment[1],
|
| 163 |
+
self.point,
|
| 164 |
+
self.type,
|
| 165 |
+
self.slope,
|
| 166 |
+
))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class SweepLine:
|
| 170 |
+
__slots__ = (
|
| 171 |
+
# A map holding all intersection points mapped to the Events
|
| 172 |
+
# that form these intersections.
|
| 173 |
+
# {Point: set(Event, ...), ...}
|
| 174 |
+
"intersections",
|
| 175 |
+
"queue",
|
| 176 |
+
|
| 177 |
+
# Events (sorted set of ordered events, no values)
|
| 178 |
+
#
|
| 179 |
+
# note: START & END events are considered the same so checking if an event is in the tree
|
| 180 |
+
# will return true if its opposite side is found.
|
| 181 |
+
# This is essential for the algorithm to work, and why we don't explicitly remove START events.
|
| 182 |
+
# Instead, the END events are never added to the current sweep, and removing them also removes the start.
|
| 183 |
+
"_events_current_sweep",
|
| 184 |
+
# The point of the current Event.
|
| 185 |
+
"_current_event_point_x",
|
| 186 |
+
# A flag to indicate if we're slightly before or after the line.
|
| 187 |
+
"_before",
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def __init__(self):
|
| 191 |
+
self.intersections = {}
|
| 192 |
+
|
| 193 |
+
self._current_event_point_x = None
|
| 194 |
+
self._events_current_sweep = RBTree(cmp=Event.Compare, cmp_data=self)
|
| 195 |
+
self._before = True
|
| 196 |
+
|
| 197 |
+
def get_intersections(self):
|
| 198 |
+
return list(self.intersections.keys())
|
| 199 |
+
|
| 200 |
+
# Checks if an intersection exists between two Events 'a' and 'b'.
|
| 201 |
+
def _check_intersection(self, a: Event, b: Event):
|
| 202 |
+
# Return immediately in case either of the events is null, or
|
| 203 |
+
# if one of them is an INTERSECTION event.
|
| 204 |
+
if ((a is None or b is None) or
|
| 205 |
+
(a.type == Event.Type.INTERSECTION) or
|
| 206 |
+
(b.type == Event.Type.INTERSECTION)):
|
| 207 |
+
|
| 208 |
+
return
|
| 209 |
+
|
| 210 |
+
if a is b:
|
| 211 |
+
return
|
| 212 |
+
|
| 213 |
+
# Get the intersection point between 'a' and 'b'.
|
| 214 |
+
p = isect_seg_seg_v2_point(
|
| 215 |
+
a.segment[0], a.segment[1],
|
| 216 |
+
b.segment[0], b.segment[1])
|
| 217 |
+
|
| 218 |
+
# No intersection exists.
|
| 219 |
+
if p is None:
|
| 220 |
+
return
|
| 221 |
+
|
| 222 |
+
# If the intersection is formed by both the segment endings, AND
|
| 223 |
+
# USE_IGNORE_SEGMENT_ENDINGS is true,
|
| 224 |
+
# return from this method.
|
| 225 |
+
if USE_IGNORE_SEGMENT_ENDINGS:
|
| 226 |
+
if ((len_squared_v2v2(p, a.segment[0]) < EPS_SQ or
|
| 227 |
+
len_squared_v2v2(p, a.segment[1]) < EPS_SQ) and
|
| 228 |
+
(len_squared_v2v2(p, b.segment[0]) < EPS_SQ or
|
| 229 |
+
len_squared_v2v2(p, b.segment[1]) < EPS_SQ)):
|
| 230 |
+
|
| 231 |
+
return
|
| 232 |
+
|
| 233 |
+
# Add the intersection.
|
| 234 |
+
events_for_point = self.intersections.pop(p, set())
|
| 235 |
+
is_new = len(events_for_point) == 0
|
| 236 |
+
events_for_point.add(a)
|
| 237 |
+
events_for_point.add(b)
|
| 238 |
+
self.intersections[p] = events_for_point
|
| 239 |
+
|
| 240 |
+
# If the intersection occurs to the right of the sweep line, OR
|
| 241 |
+
# if the intersection is on the sweep line and it's above the
|
| 242 |
+
# current event-point, add it as a new Event to the queue.
|
| 243 |
+
if is_new and p[X] >= self._current_event_point_x:
|
| 244 |
+
event_isect = Event(Event.Type.INTERSECTION, p, None, None)
|
| 245 |
+
self.queue.offer(p, event_isect)
|
| 246 |
+
|
| 247 |
+
def _sweep_to(self, p):
|
| 248 |
+
if p[X] == self._current_event_point_x:
|
| 249 |
+
# happens in rare cases,
|
| 250 |
+
# we can safely ignore
|
| 251 |
+
return
|
| 252 |
+
|
| 253 |
+
self._current_event_point_x = p[X]
|
| 254 |
+
|
| 255 |
+
def insert(self, event):
|
| 256 |
+
assert(event not in self._events_current_sweep)
|
| 257 |
+
assert(event.type != Event.Type.START_VERTICAL)
|
| 258 |
+
if USE_DEBUG:
|
| 259 |
+
assert(event.in_sweep == False)
|
| 260 |
+
assert(event.other.in_sweep == False)
|
| 261 |
+
|
| 262 |
+
self._events_current_sweep.insert(event, None)
|
| 263 |
+
|
| 264 |
+
if USE_DEBUG:
|
| 265 |
+
event.in_sweep = True
|
| 266 |
+
event.other.in_sweep = True
|
| 267 |
+
|
| 268 |
+
def remove(self, event):
|
| 269 |
+
try:
|
| 270 |
+
self._events_current_sweep.remove(event)
|
| 271 |
+
if USE_DEBUG:
|
| 272 |
+
assert(event.in_sweep == True)
|
| 273 |
+
assert(event.other.in_sweep == True)
|
| 274 |
+
event.in_sweep = False
|
| 275 |
+
event.other.in_sweep = False
|
| 276 |
+
return True
|
| 277 |
+
except KeyError:
|
| 278 |
+
if USE_DEBUG:
|
| 279 |
+
assert(event.in_sweep == False)
|
| 280 |
+
assert(event.other.in_sweep == False)
|
| 281 |
+
return False
|
| 282 |
+
|
| 283 |
+
def above(self, event):
|
| 284 |
+
return self._events_current_sweep.succ_key(event, None)
|
| 285 |
+
|
| 286 |
+
def below(self, event):
|
| 287 |
+
return self._events_current_sweep.prev_key(event, None)
|
| 288 |
+
|
| 289 |
+
'''
|
| 290 |
+
def above_all(self, event):
|
| 291 |
+
while True:
|
| 292 |
+
event = self.above(event)
|
| 293 |
+
if event is None:
|
| 294 |
+
break
|
| 295 |
+
yield event
|
| 296 |
+
'''
|
| 297 |
+
|
| 298 |
+
def above_all(self, event):
|
| 299 |
+
# assert(event not in self._events_current_sweep)
|
| 300 |
+
return self._events_current_sweep.key_slice(event, None, reverse=False)
|
| 301 |
+
|
| 302 |
+
def handle(self, p, events_current):
|
| 303 |
+
if len(events_current) == 0:
|
| 304 |
+
return
|
| 305 |
+
# done already
|
| 306 |
+
# self._sweep_to(events_current[0])
|
| 307 |
+
assert(p[0] == self._current_event_point_x)
|
| 308 |
+
|
| 309 |
+
if not USE_IGNORE_SEGMENT_ENDINGS:
|
| 310 |
+
if len(events_current) > 1:
|
| 311 |
+
for i in range(0, len(events_current) - 1):
|
| 312 |
+
for j in range(i + 1, len(events_current)):
|
| 313 |
+
self._check_intersection(
|
| 314 |
+
events_current[i], events_current[j])
|
| 315 |
+
|
| 316 |
+
for e in events_current:
|
| 317 |
+
self.handle_event(e)
|
| 318 |
+
|
| 319 |
+
def handle_event(self, event):
|
| 320 |
+
t = event.type
|
| 321 |
+
if t == Event.Type.START:
|
| 322 |
+
# print(" START")
|
| 323 |
+
self._before = False
|
| 324 |
+
self.insert(event)
|
| 325 |
+
|
| 326 |
+
e_above = self.above(event)
|
| 327 |
+
e_below = self.below(event)
|
| 328 |
+
|
| 329 |
+
self._check_intersection(event, e_above)
|
| 330 |
+
self._check_intersection(event, e_below)
|
| 331 |
+
if USE_PARANOID:
|
| 332 |
+
self._check_intersection(e_above, e_below)
|
| 333 |
+
|
| 334 |
+
elif t == Event.Type.END:
|
| 335 |
+
# print(" END")
|
| 336 |
+
self._before = True
|
| 337 |
+
|
| 338 |
+
e_above = self.above(event)
|
| 339 |
+
e_below = self.below(event)
|
| 340 |
+
|
| 341 |
+
self.remove(event)
|
| 342 |
+
|
| 343 |
+
self._check_intersection(e_above, e_below)
|
| 344 |
+
if USE_PARANOID:
|
| 345 |
+
self._check_intersection(event, e_above)
|
| 346 |
+
self._check_intersection(event, e_below)
|
| 347 |
+
|
| 348 |
+
elif t == Event.Type.INTERSECTION:
|
| 349 |
+
# print(" INTERSECTION")
|
| 350 |
+
self._before = True
|
| 351 |
+
event_set = self.intersections[event.point]
|
| 352 |
+
# note: events_current aren't sorted.
|
| 353 |
+
reinsert_stack = [] # Stack
|
| 354 |
+
for e in event_set:
|
| 355 |
+
# If we the Event was not already removed,
|
| 356 |
+
# we want to insert it later on.
|
| 357 |
+
if self.remove(e):
|
| 358 |
+
reinsert_stack.append(e)
|
| 359 |
+
self._before = False
|
| 360 |
+
|
| 361 |
+
# Insert all Events that we were able to remove.
|
| 362 |
+
while reinsert_stack:
|
| 363 |
+
e = reinsert_stack.pop()
|
| 364 |
+
|
| 365 |
+
self.insert(e)
|
| 366 |
+
|
| 367 |
+
e_above = self.above(e)
|
| 368 |
+
e_below = self.below(e)
|
| 369 |
+
|
| 370 |
+
self._check_intersection(e, e_above)
|
| 371 |
+
self._check_intersection(e, e_below)
|
| 372 |
+
if USE_PARANOID:
|
| 373 |
+
self._check_intersection(e_above, e_below)
|
| 374 |
+
elif (USE_VERTICAL and
|
| 375 |
+
(t == Event.Type.START_VERTICAL)):
|
| 376 |
+
|
| 377 |
+
# just check sanity
|
| 378 |
+
assert(event.segment[0][X] == event.segment[1][X])
|
| 379 |
+
assert(event.segment[0][Y] <= event.segment[1][Y])
|
| 380 |
+
|
| 381 |
+
# In this case we only need to find all segments in this span.
|
| 382 |
+
y_above_max = event.segment[1][Y]
|
| 383 |
+
|
| 384 |
+
# self.insert(event)
|
| 385 |
+
for e_above in self.above_all(event):
|
| 386 |
+
if e_above.type == Event.Type.START_VERTICAL:
|
| 387 |
+
continue
|
| 388 |
+
y_above = e_above.y_intercept_x(
|
| 389 |
+
self._current_event_point_x)
|
| 390 |
+
if USE_IGNORE_SEGMENT_ENDINGS:
|
| 391 |
+
if y_above >= y_above_max:
|
| 392 |
+
break
|
| 393 |
+
else:
|
| 394 |
+
if y_above > y_above_max:
|
| 395 |
+
break
|
| 396 |
+
|
| 397 |
+
# We know this intersects,
|
| 398 |
+
# so we could use a faster function now:
|
| 399 |
+
# ix = (self._current_event_point_x, y_above)
|
| 400 |
+
# ...however best use existing functions
|
| 401 |
+
# since it does all sanity checks on endpoints... etc.
|
| 402 |
+
self._check_intersection(event, e_above)
|
| 403 |
+
|
| 404 |
+
# self.remove(event)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class EventQueue:
|
| 408 |
+
__slots__ = (
|
| 409 |
+
# note: we only ever pop_min, this could use a 'heap' structure.
|
| 410 |
+
# The sorted map holding the points -> event list
|
| 411 |
+
# [Point: Event] (tree)
|
| 412 |
+
"events_scan",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def __init__(self, segments, line: SweepLine):
|
| 416 |
+
self.events_scan = RBTree()
|
| 417 |
+
# segments = [s for s in segments if s[0][0] != s[1][0] and s[0][1] != s[1][1]]
|
| 418 |
+
|
| 419 |
+
for s in segments:
|
| 420 |
+
assert(s[0][X] <= s[1][X])
|
| 421 |
+
|
| 422 |
+
slope = slope_v2v2(*s)
|
| 423 |
+
|
| 424 |
+
if s[0] == s[1]:
|
| 425 |
+
pass
|
| 426 |
+
elif USE_VERTICAL and (s[0][X] == s[1][X]):
|
| 427 |
+
e_start = Event(Event.Type.START_VERTICAL, s[0], s, slope)
|
| 428 |
+
|
| 429 |
+
if USE_DEBUG:
|
| 430 |
+
e_start.other = e_start # FAKE, avoid error checking
|
| 431 |
+
|
| 432 |
+
self.offer(s[0], e_start)
|
| 433 |
+
else:
|
| 434 |
+
e_start = Event(Event.Type.START, s[0], s, slope)
|
| 435 |
+
e_end = Event(Event.Type.END, s[1], s, slope)
|
| 436 |
+
|
| 437 |
+
if USE_DEBUG:
|
| 438 |
+
e_start.other = e_end
|
| 439 |
+
e_end.other = e_start
|
| 440 |
+
|
| 441 |
+
self.offer(s[0], e_start)
|
| 442 |
+
self.offer(s[1], e_end)
|
| 443 |
+
|
| 444 |
+
line.queue = self
|
| 445 |
+
|
| 446 |
+
def offer(self, p, e: Event):
|
| 447 |
+
"""
|
| 448 |
+
Offer a new event ``s`` at point ``p`` in this queue.
|
| 449 |
+
"""
|
| 450 |
+
existing = self.events_scan.setdefault(
|
| 451 |
+
p, ([], [], [], []) if USE_VERTICAL else
|
| 452 |
+
([], [], []))
|
| 453 |
+
# Can use double linked-list for easy insertion at beginning/end
|
| 454 |
+
'''
|
| 455 |
+
if e.type == Event.Type.END:
|
| 456 |
+
existing.insert(0, e)
|
| 457 |
+
else:
|
| 458 |
+
existing.append(e)
|
| 459 |
+
'''
|
| 460 |
+
|
| 461 |
+
existing[e.type].append(e)
|
| 462 |
+
|
| 463 |
+
# return a set of events
|
| 464 |
+
def poll(self):
|
| 465 |
+
"""
|
| 466 |
+
Get, and remove, the first (lowest) item from this queue.
|
| 467 |
+
|
| 468 |
+
:return: the first (lowest) item from this queue.
|
| 469 |
+
:rtype: Point, Event pair.
|
| 470 |
+
"""
|
| 471 |
+
assert(len(self.events_scan) != 0)
|
| 472 |
+
p, events_current = self.events_scan.pop_min()
|
| 473 |
+
return p, events_current
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def isect_segments(segments) -> list:
|
| 477 |
+
# order points left -> right
|
| 478 |
+
segments = [
|
| 479 |
+
# in nearly all cases, comparing X is enough,
|
| 480 |
+
# but compare Y too for vertical lines
|
| 481 |
+
(s[0], s[1]) if (s[0] <= s[1]) else
|
| 482 |
+
(s[1], s[0])
|
| 483 |
+
for s in segments]
|
| 484 |
+
|
| 485 |
+
sweep_line = SweepLine()
|
| 486 |
+
queue = EventQueue(segments, sweep_line)
|
| 487 |
+
|
| 488 |
+
while len(queue.events_scan) > 0:
|
| 489 |
+
if USE_VERBOSE:
|
| 490 |
+
print(len(queue.events_scan), sweep_line._current_event_point_x)
|
| 491 |
+
p, e_ls = queue.poll()
|
| 492 |
+
for events_current in e_ls:
|
| 493 |
+
if events_current:
|
| 494 |
+
sweep_line._sweep_to(p)
|
| 495 |
+
sweep_line.handle(p, events_current)
|
| 496 |
+
|
| 497 |
+
return sweep_line.get_intersections()
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def isect_polygon(points) -> list:
|
| 501 |
+
n = len(points)
|
| 502 |
+
segments = [
|
| 503 |
+
(tuple(points[i]), tuple(points[(i + 1) % n]))
|
| 504 |
+
for i in range(n)]
|
| 505 |
+
return isect_segments(segments)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
# ----------------------------------------------------------------------------
|
| 509 |
+
# 2D math utilities
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def slope_v2v2(p1, p2):
|
| 513 |
+
if p1[X] == p2[X]:
|
| 514 |
+
if p1[Y] < p2[Y]:
|
| 515 |
+
return INF
|
| 516 |
+
else:
|
| 517 |
+
return -INF
|
| 518 |
+
else:
|
| 519 |
+
return (p2[Y] - p1[Y]) / (p2[X] - p1[X])
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def sub_v2v2(a, b):
|
| 523 |
+
return (
|
| 524 |
+
a[0] - b[0],
|
| 525 |
+
a[1] - b[1])
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def dot_v2v2(a, b):
|
| 529 |
+
return (
|
| 530 |
+
(a[0] * b[0]) +
|
| 531 |
+
(a[1] * b[1]))
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def len_squared_v2v2(a, b):
|
| 535 |
+
c = sub_v2v2(a, b)
|
| 536 |
+
return dot_v2v2(c, c)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def line_point_factor_v2(p, l1, l2, default=0.0):
|
| 540 |
+
u = sub_v2v2(l2, l1)
|
| 541 |
+
h = sub_v2v2(p, l1)
|
| 542 |
+
dot = dot_v2v2(u, u)
|
| 543 |
+
return (dot_v2v2(u, h) / dot) if dot != 0.0 else default
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def isect_seg_seg_v2_point(v1, v2, v3, v4, bias=0.0):
|
| 547 |
+
# Only for predictability and hashable point when same input is given
|
| 548 |
+
if v1 > v2:
|
| 549 |
+
v1, v2 = v2, v1
|
| 550 |
+
if v3 > v4:
|
| 551 |
+
v3, v4 = v4, v3
|
| 552 |
+
|
| 553 |
+
if (v1, v2) > (v3, v4):
|
| 554 |
+
v1, v2, v3, v4 = v3, v4, v1, v2
|
| 555 |
+
|
| 556 |
+
div = (v2[0] - v1[0]) * (v4[1] - v3[1]) - (v2[1] - v1[1]) * (v4[0] - v3[0])
|
| 557 |
+
if div == 0.0:
|
| 558 |
+
return None
|
| 559 |
+
|
| 560 |
+
vi = (((v3[0] - v4[0]) *
|
| 561 |
+
(v1[0] * v2[1] - v1[1] * v2[0]) - (v1[0] - v2[0]) *
|
| 562 |
+
(v3[0] * v4[1] - v3[1] * v4[0])) / div,
|
| 563 |
+
((v3[1] - v4[1]) *
|
| 564 |
+
(v1[0] * v2[1] - v1[1] * v2[0]) - (v1[1] - v2[1]) *
|
| 565 |
+
(v3[0] * v4[1] - v3[1] * v4[0])) / div,
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
fac = line_point_factor_v2(vi, v1, v2, default=-1.0)
|
| 569 |
+
if fac < 0.0 - bias or fac > 1.0 + bias:
|
| 570 |
+
return None
|
| 571 |
+
|
| 572 |
+
fac = line_point_factor_v2(vi, v3, v4, default=-1.0)
|
| 573 |
+
if fac < 0.0 - bias or fac > 1.0 + bias:
|
| 574 |
+
return None
|
| 575 |
+
|
| 576 |
+
# vi = round(vi[X], 8), round(vi[Y], 8)
|
| 577 |
+
return vi
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ----------------------------------------------------------------------------
|
| 581 |
+
# Simple naive line intersect, (for testing only)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def isect_segments__naive(segments) -> list:
|
| 585 |
+
"""
|
| 586 |
+
Brute force O(n2) version of ``isect_segments`` for test validation.
|
| 587 |
+
"""
|
| 588 |
+
isect = []
|
| 589 |
+
|
| 590 |
+
# order points left -> right
|
| 591 |
+
segments = [
|
| 592 |
+
(s[0], s[1]) if s[0][X] <= s[1][X] else
|
| 593 |
+
(s[1], s[0])
|
| 594 |
+
for s in segments]
|
| 595 |
+
|
| 596 |
+
n = len(segments)
|
| 597 |
+
|
| 598 |
+
for i in range(n):
|
| 599 |
+
a0, a1 = segments[i]
|
| 600 |
+
for j in range(i + 1, n):
|
| 601 |
+
b0, b1 = segments[j]
|
| 602 |
+
if a0 not in (b0, b1) and a1 not in (b0, b1):
|
| 603 |
+
ix = isect_seg_seg_v2_point(a0, a1, b0, b1)
|
| 604 |
+
if ix is not None:
|
| 605 |
+
# USE_IGNORE_SEGMENT_ENDINGS handled already
|
| 606 |
+
isect.append(ix)
|
| 607 |
+
|
| 608 |
+
return isect
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def isect_polygon__naive(points) -> list:
|
| 612 |
+
"""
|
| 613 |
+
Brute force O(n2) version of ``isect_polygon`` for test validation.
|
| 614 |
+
"""
|
| 615 |
+
isect = []
|
| 616 |
+
|
| 617 |
+
n = len(points)
|
| 618 |
+
|
| 619 |
+
for i in range(n):
|
| 620 |
+
a0, a1 = points[i], points[(i + 1) % n]
|
| 621 |
+
for j in range(i + 1, n):
|
| 622 |
+
b0, b1 = points[j], points[(j + 1) % n]
|
| 623 |
+
if a0 not in (b0, b1) and a1 not in (b0, b1):
|
| 624 |
+
ix = isect_seg_seg_v2_point(a0, a1, b0, b1)
|
| 625 |
+
if ix is not None:
|
| 626 |
+
|
| 627 |
+
if USE_IGNORE_SEGMENT_ENDINGS:
|
| 628 |
+
if ((len_squared_v2v2(ix, a0) < EPS_SQ or
|
| 629 |
+
len_squared_v2v2(ix, a1) < EPS_SQ) and
|
| 630 |
+
(len_squared_v2v2(ix, b0) < EPS_SQ or
|
| 631 |
+
len_squared_v2v2(ix, b1) < EPS_SQ)):
|
| 632 |
+
continue
|
| 633 |
+
|
| 634 |
+
isect.append(ix)
|
| 635 |
+
|
| 636 |
+
return isect
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
# ----------------------------------------------------------------------------
|
| 640 |
+
# Inline Libs
|
| 641 |
+
#
|
| 642 |
+
# bintrees: 2.0.2, extracted from:
|
| 643 |
+
# http://pypi.python.org/pypi/bintrees
|
| 644 |
+
#
|
| 645 |
+
# - Removed unused functions, such as slicing and range iteration.
|
| 646 |
+
# - Added 'cmp' and and 'cmp_data' arguments,
|
| 647 |
+
# so we can define our own comparison that takes an arg.
|
| 648 |
+
# Needed for sweep-line.
|
| 649 |
+
# - Added support for 'default' arguments for prev_item/succ_item,
|
| 650 |
+
# so we can avoid exception handling.
|
| 651 |
+
|
| 652 |
+
# -------
|
| 653 |
+
# ABCTree
|
| 654 |
+
|
| 655 |
+
from operator import attrgetter
|
| 656 |
+
_sentinel = object()
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
class _ABCTree(object):
|
| 660 |
+
def __init__(self, items=None, cmp=None, cmp_data=None):
|
| 661 |
+
"""T.__init__(...) initializes T; see T.__class__.__doc__ for signature"""
|
| 662 |
+
self._root = None
|
| 663 |
+
self._count = 0
|
| 664 |
+
if cmp is None:
|
| 665 |
+
def cmp(cmp_data, a, b):
|
| 666 |
+
if a < b:
|
| 667 |
+
return -1
|
| 668 |
+
elif a > b:
|
| 669 |
+
return 1
|
| 670 |
+
else:
|
| 671 |
+
return 0
|
| 672 |
+
self._cmp = cmp
|
| 673 |
+
self._cmp_data = cmp_data
|
| 674 |
+
if items is not None:
|
| 675 |
+
self.update(items)
|
| 676 |
+
|
| 677 |
+
def clear(self):
|
| 678 |
+
"""T.clear() -> None. Remove all items from T."""
|
| 679 |
+
def _clear(node):
|
| 680 |
+
if node is not None:
|
| 681 |
+
_clear(node.left)
|
| 682 |
+
_clear(node.right)
|
| 683 |
+
node.free()
|
| 684 |
+
_clear(self._root)
|
| 685 |
+
self._count = 0
|
| 686 |
+
self._root = None
|
| 687 |
+
|
| 688 |
+
@property
|
| 689 |
+
def count(self):
|
| 690 |
+
"""Get items count."""
|
| 691 |
+
return self._count
|
| 692 |
+
|
| 693 |
+
def get_value(self, key):
|
| 694 |
+
node = self._root
|
| 695 |
+
while node is not None:
|
| 696 |
+
cmp = self._cmp(self._cmp_data, key, node.key)
|
| 697 |
+
if cmp == 0:
|
| 698 |
+
return node.value
|
| 699 |
+
elif cmp < 0:
|
| 700 |
+
node = node.left
|
| 701 |
+
else:
|
| 702 |
+
node = node.right
|
| 703 |
+
raise KeyError(str(key))
|
| 704 |
+
|
| 705 |
+
def pop_item(self):
|
| 706 |
+
"""T.pop_item() -> (k, v), remove and return some (key, value) pair as a
|
| 707 |
+
2-tuple; but raise KeyError if T is empty.
|
| 708 |
+
"""
|
| 709 |
+
if self.is_empty():
|
| 710 |
+
raise KeyError("pop_item(): tree is empty")
|
| 711 |
+
node = self._root
|
| 712 |
+
while True:
|
| 713 |
+
if node.left is not None:
|
| 714 |
+
node = node.left
|
| 715 |
+
elif node.right is not None:
|
| 716 |
+
node = node.right
|
| 717 |
+
else:
|
| 718 |
+
break
|
| 719 |
+
key = node.key
|
| 720 |
+
value = node.value
|
| 721 |
+
self.remove(key)
|
| 722 |
+
return key, value
|
| 723 |
+
popitem = pop_item # for compatibility to dict()
|
| 724 |
+
|
| 725 |
+
def min_item(self):
|
| 726 |
+
"""Get item with min key of tree, raises ValueError if tree is empty."""
|
| 727 |
+
if self.is_empty():
|
| 728 |
+
raise ValueError("Tree is empty")
|
| 729 |
+
node = self._root
|
| 730 |
+
while node.left is not None:
|
| 731 |
+
node = node.left
|
| 732 |
+
return node.key, node.value
|
| 733 |
+
|
| 734 |
+
def max_item(self):
|
| 735 |
+
"""Get item with max key of tree, raises ValueError if tree is empty."""
|
| 736 |
+
if self.is_empty():
|
| 737 |
+
raise ValueError("Tree is empty")
|
| 738 |
+
node = self._root
|
| 739 |
+
while node.right is not None:
|
| 740 |
+
node = node.right
|
| 741 |
+
return node.key, node.value
|
| 742 |
+
|
| 743 |
+
def succ_item(self, key, default=_sentinel):
|
| 744 |
+
"""Get successor (k,v) pair of key, raises KeyError if key is max key
|
| 745 |
+
or key does not exist. optimized for pypy.
|
| 746 |
+
"""
|
| 747 |
+
# removed graingets version, because it was little slower on CPython and much slower on pypy
|
| 748 |
+
# this version runs about 4x faster with pypy than the Cython version
|
| 749 |
+
# Note: Code sharing of succ_item() and ceiling_item() is possible, but has always a speed penalty.
|
| 750 |
+
node = self._root
|
| 751 |
+
succ_node = None
|
| 752 |
+
while node is not None:
|
| 753 |
+
cmp = self._cmp(self._cmp_data, key, node.key)
|
| 754 |
+
if cmp == 0:
|
| 755 |
+
break
|
| 756 |
+
elif cmp < 0:
|
| 757 |
+
if (succ_node is None) or self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
|
| 758 |
+
succ_node = node
|
| 759 |
+
node = node.left
|
| 760 |
+
else:
|
| 761 |
+
node = node.right
|
| 762 |
+
|
| 763 |
+
if node is None: # stay at dead end
|
| 764 |
+
if default is _sentinel:
|
| 765 |
+
raise KeyError(str(key))
|
| 766 |
+
return default
|
| 767 |
+
# found node of key
|
| 768 |
+
if node.right is not None:
|
| 769 |
+
# find smallest node of right subtree
|
| 770 |
+
node = node.right
|
| 771 |
+
while node.left is not None:
|
| 772 |
+
node = node.left
|
| 773 |
+
if succ_node is None:
|
| 774 |
+
succ_node = node
|
| 775 |
+
elif self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
|
| 776 |
+
succ_node = node
|
| 777 |
+
elif succ_node is None: # given key is biggest in tree
|
| 778 |
+
if default is _sentinel:
|
| 779 |
+
raise KeyError(str(key))
|
| 780 |
+
return default
|
| 781 |
+
return succ_node.key, succ_node.value
|
| 782 |
+
|
| 783 |
+
def prev_item(self, key, default=_sentinel):
|
| 784 |
+
"""Get predecessor (k,v) pair of key, raises KeyError if key is min key
|
| 785 |
+
or key does not exist. optimized for pypy.
|
| 786 |
+
"""
|
| 787 |
+
# removed graingets version, because it was little slower on CPython and much slower on pypy
|
| 788 |
+
# this version runs about 4x faster with pypy than the Cython version
|
| 789 |
+
# Note: Code sharing of prev_item() and floor_item() is possible, but has always a speed penalty.
|
| 790 |
+
node = self._root
|
| 791 |
+
prev_node = None
|
| 792 |
+
|
| 793 |
+
while node is not None:
|
| 794 |
+
cmp = self._cmp(self._cmp_data, key, node.key)
|
| 795 |
+
if cmp == 0:
|
| 796 |
+
break
|
| 797 |
+
elif cmp < 0:
|
| 798 |
+
node = node.left
|
| 799 |
+
else:
|
| 800 |
+
if (prev_node is None) or self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
|
| 801 |
+
prev_node = node
|
| 802 |
+
node = node.right
|
| 803 |
+
|
| 804 |
+
if node is None: # stay at dead end (None)
|
| 805 |
+
if default is _sentinel:
|
| 806 |
+
raise KeyError(str(key))
|
| 807 |
+
return default
|
| 808 |
+
# found node of key
|
| 809 |
+
if node.left is not None:
|
| 810 |
+
# find biggest node of left subtree
|
| 811 |
+
node = node.left
|
| 812 |
+
while node.right is not None:
|
| 813 |
+
node = node.right
|
| 814 |
+
if prev_node is None:
|
| 815 |
+
prev_node = node
|
| 816 |
+
elif self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
|
| 817 |
+
prev_node = node
|
| 818 |
+
elif prev_node is None: # given key is smallest in tree
|
| 819 |
+
if default is _sentinel:
|
| 820 |
+
raise KeyError(str(key))
|
| 821 |
+
return default
|
| 822 |
+
return prev_node.key, prev_node.value
|
| 823 |
+
|
| 824 |
+
def __repr__(self):
|
| 825 |
+
"""T.__repr__(...) <==> repr(x)"""
|
| 826 |
+
tpl = "%s({%s})" % (self.__class__.__name__, '%s')
|
| 827 |
+
return tpl % ", ".join(("%r: %r" % item for item in self.items()))
|
| 828 |
+
|
| 829 |
+
def __contains__(self, key):
|
| 830 |
+
"""k in T -> True if T has a key k, else False"""
|
| 831 |
+
try:
|
| 832 |
+
self.get_value(key)
|
| 833 |
+
return True
|
| 834 |
+
except KeyError:
|
| 835 |
+
return False
|
| 836 |
+
|
| 837 |
+
def __len__(self):
|
| 838 |
+
"""T.__len__() <==> len(x)"""
|
| 839 |
+
return self.count
|
| 840 |
+
|
| 841 |
+
def is_empty(self):
|
| 842 |
+
"""T.is_empty() -> False if T contains any items else True"""
|
| 843 |
+
return self.count == 0
|
| 844 |
+
|
| 845 |
+
def set_default(self, key, default=None):
|
| 846 |
+
"""T.set_default(k[,d]) -> T.get(k,d), also set T[k]=d if k not in T"""
|
| 847 |
+
try:
|
| 848 |
+
return self.get_value(key)
|
| 849 |
+
except KeyError:
|
| 850 |
+
self.insert(key, default)
|
| 851 |
+
return default
|
| 852 |
+
setdefault = set_default # for compatibility to dict()
|
| 853 |
+
|
| 854 |
+
def get(self, key, default=None):
|
| 855 |
+
"""T.get(k[,d]) -> T[k] if k in T, else d. d defaults to None."""
|
| 856 |
+
try:
|
| 857 |
+
return self.get_value(key)
|
| 858 |
+
except KeyError:
|
| 859 |
+
return default
|
| 860 |
+
|
| 861 |
+
def pop(self, key, *args):
|
| 862 |
+
"""T.pop(k[,d]) -> v, remove specified key and return the corresponding value.
|
| 863 |
+
If key is not found, d is returned if given, otherwise KeyError is raised
|
| 864 |
+
"""
|
| 865 |
+
if len(args) > 1:
|
| 866 |
+
raise TypeError("pop expected at most 2 arguments, got %d" % (1 + len(args)))
|
| 867 |
+
try:
|
| 868 |
+
value = self.get_value(key)
|
| 869 |
+
self.remove(key)
|
| 870 |
+
return value
|
| 871 |
+
except KeyError:
|
| 872 |
+
if len(args) == 0:
|
| 873 |
+
raise
|
| 874 |
+
else:
|
| 875 |
+
return args[0]
|
| 876 |
+
|
| 877 |
+
def prev_key(self, key, default=_sentinel):
|
| 878 |
+
"""Get predecessor to key, raises KeyError if key is min key
|
| 879 |
+
or key does not exist.
|
| 880 |
+
"""
|
| 881 |
+
item = self.prev_item(key, default)
|
| 882 |
+
return default if item is default else item[0]
|
| 883 |
+
|
| 884 |
+
def succ_key(self, key, default=_sentinel):
|
| 885 |
+
"""Get successor to key, raises KeyError if key is max key
|
| 886 |
+
or key does not exist.
|
| 887 |
+
"""
|
| 888 |
+
item = self.succ_item(key, default)
|
| 889 |
+
return default if item is default else item[0]
|
| 890 |
+
|
| 891 |
+
def pop_min(self):
|
| 892 |
+
"""T.pop_min() -> (k, v), remove item with minimum key, raise ValueError
|
| 893 |
+
if T is empty.
|
| 894 |
+
"""
|
| 895 |
+
item = self.min_item()
|
| 896 |
+
self.remove(item[0])
|
| 897 |
+
return item
|
| 898 |
+
|
| 899 |
+
def pop_max(self):
|
| 900 |
+
"""T.pop_max() -> (k, v), remove item with maximum key, raise ValueError
|
| 901 |
+
if T is empty.
|
| 902 |
+
"""
|
| 903 |
+
item = self.max_item()
|
| 904 |
+
self.remove(item[0])
|
| 905 |
+
return item
|
| 906 |
+
|
| 907 |
+
def min_key(self):
|
| 908 |
+
"""Get min key of tree, raises ValueError if tree is empty. """
|
| 909 |
+
return self.min_item()[0]
|
| 910 |
+
|
| 911 |
+
def max_key(self):
|
| 912 |
+
"""Get max key of tree, raises ValueError if tree is empty. """
|
| 913 |
+
return self.max_item()[0]
|
| 914 |
+
|
| 915 |
+
def key_slice(self, start_key, end_key, reverse=False):
|
| 916 |
+
"""T.key_slice(start_key, end_key) -> key iterator:
|
| 917 |
+
start_key <= key < end_key.
|
| 918 |
+
|
| 919 |
+
Yields keys in ascending order if reverse is False else in descending order.
|
| 920 |
+
"""
|
| 921 |
+
return (k for k, v in self.iter_items(start_key, end_key, reverse=reverse))
|
| 922 |
+
|
| 923 |
+
def iter_items(self, start_key=None, end_key=None, reverse=False):
|
| 924 |
+
"""Iterates over the (key, value) items of the associated tree,
|
| 925 |
+
in ascending order if reverse is True, iterate in descending order,
|
| 926 |
+
reverse defaults to False"""
|
| 927 |
+
# optimized iterator (reduced method calls) - faster on CPython but slower on pypy
|
| 928 |
+
|
| 929 |
+
if self.is_empty():
|
| 930 |
+
return []
|
| 931 |
+
if reverse:
|
| 932 |
+
return self._iter_items_backward(start_key, end_key)
|
| 933 |
+
else:
|
| 934 |
+
return self._iter_items_forward(start_key, end_key)
|
| 935 |
+
|
| 936 |
+
def _iter_items_forward(self, start_key=None, end_key=None):
|
| 937 |
+
for item in self._iter_items(left=attrgetter("left"), right=attrgetter("right"),
|
| 938 |
+
start_key=start_key, end_key=end_key):
|
| 939 |
+
yield item
|
| 940 |
+
|
| 941 |
+
def _iter_items_backward(self, start_key=None, end_key=None):
|
| 942 |
+
for item in self._iter_items(left=attrgetter("right"), right=attrgetter("left"),
|
| 943 |
+
start_key=start_key, end_key=end_key):
|
| 944 |
+
yield item
|
| 945 |
+
|
| 946 |
+
def _iter_items(self, left=attrgetter("left"), right=attrgetter("right"), start_key=None, end_key=None):
|
| 947 |
+
node = self._root
|
| 948 |
+
stack = []
|
| 949 |
+
go_left = True
|
| 950 |
+
in_range = self._get_in_range_func(start_key, end_key)
|
| 951 |
+
|
| 952 |
+
while True:
|
| 953 |
+
if left(node) is not None and go_left:
|
| 954 |
+
stack.append(node)
|
| 955 |
+
node = left(node)
|
| 956 |
+
else:
|
| 957 |
+
if in_range(node.key):
|
| 958 |
+
yield node.key, node.value
|
| 959 |
+
if right(node) is not None:
|
| 960 |
+
node = right(node)
|
| 961 |
+
go_left = True
|
| 962 |
+
else:
|
| 963 |
+
if not len(stack):
|
| 964 |
+
return # all done
|
| 965 |
+
node = stack.pop()
|
| 966 |
+
go_left = False
|
| 967 |
+
|
| 968 |
+
def _get_in_range_func(self, start_key, end_key):
|
| 969 |
+
if start_key is None and end_key is None:
|
| 970 |
+
return lambda x: True
|
| 971 |
+
else:
|
| 972 |
+
if start_key is None:
|
| 973 |
+
start_key = self.min_key()
|
| 974 |
+
if end_key is None:
|
| 975 |
+
return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0)
|
| 976 |
+
else:
|
| 977 |
+
return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0 and
|
| 978 |
+
self._cmp(self._cmp_data, x, end_key) < 0)
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
# ------
|
| 982 |
+
# RBTree
|
| 983 |
+
|
| 984 |
+
class Node(object):
|
| 985 |
+
"""Internal object, represents a tree node."""
|
| 986 |
+
__slots__ = ['key', 'value', 'red', 'left', 'right']
|
| 987 |
+
|
| 988 |
+
def __init__(self, key=None, value=None):
|
| 989 |
+
self.key = key
|
| 990 |
+
self.value = value
|
| 991 |
+
self.red = True
|
| 992 |
+
self.left = None
|
| 993 |
+
self.right = None
|
| 994 |
+
|
| 995 |
+
def free(self):
|
| 996 |
+
self.left = None
|
| 997 |
+
self.right = None
|
| 998 |
+
self.key = None
|
| 999 |
+
self.value = None
|
| 1000 |
+
|
| 1001 |
+
def __getitem__(self, key):
|
| 1002 |
+
"""N.__getitem__(key) <==> x[key], where key is 0 (left) or 1 (right)."""
|
| 1003 |
+
return self.left if key == 0 else self.right
|
| 1004 |
+
|
| 1005 |
+
def __setitem__(self, key, value):
|
| 1006 |
+
"""N.__setitem__(key, value) <==> x[key]=value, where key is 0 (left) or 1 (right)."""
|
| 1007 |
+
if key == 0:
|
| 1008 |
+
self.left = value
|
| 1009 |
+
else:
|
| 1010 |
+
self.right = value
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
class RBTree(_ABCTree):
|
| 1014 |
+
"""
|
| 1015 |
+
RBTree implements a balanced binary tree with a dict-like interface.
|
| 1016 |
+
|
| 1017 |
+
see: http://en.wikipedia.org/wiki/Red_black_tree
|
| 1018 |
+
"""
|
| 1019 |
+
@staticmethod
|
| 1020 |
+
def is_red(node):
|
| 1021 |
+
if (node is not None) and node.red:
|
| 1022 |
+
return True
|
| 1023 |
+
else:
|
| 1024 |
+
return False
|
| 1025 |
+
|
| 1026 |
+
@staticmethod
|
| 1027 |
+
def jsw_single(root, direction):
|
| 1028 |
+
other_side = 1 - direction
|
| 1029 |
+
save = root[other_side]
|
| 1030 |
+
root[other_side] = save[direction]
|
| 1031 |
+
save[direction] = root
|
| 1032 |
+
root.red = True
|
| 1033 |
+
save.red = False
|
| 1034 |
+
return save
|
| 1035 |
+
|
| 1036 |
+
@staticmethod
|
| 1037 |
+
def jsw_double(root, direction):
|
| 1038 |
+
other_side = 1 - direction
|
| 1039 |
+
root[other_side] = RBTree.jsw_single(root[other_side], other_side)
|
| 1040 |
+
return RBTree.jsw_single(root, direction)
|
| 1041 |
+
|
| 1042 |
+
def _new_node(self, key, value):
|
| 1043 |
+
"""Create a new tree node."""
|
| 1044 |
+
self._count += 1
|
| 1045 |
+
return Node(key, value)
|
| 1046 |
+
|
| 1047 |
+
def insert(self, key, value):
|
| 1048 |
+
"""T.insert(key, value) <==> T[key] = value, insert key, value into tree."""
|
| 1049 |
+
if self._root is None: # Empty tree case
|
| 1050 |
+
self._root = self._new_node(key, value)
|
| 1051 |
+
self._root.red = False # make root black
|
| 1052 |
+
return
|
| 1053 |
+
|
| 1054 |
+
head = Node() # False tree root
|
| 1055 |
+
grand_parent = None
|
| 1056 |
+
grand_grand_parent = head
|
| 1057 |
+
parent = None # parent
|
| 1058 |
+
direction = 0
|
| 1059 |
+
last = 0
|
| 1060 |
+
|
| 1061 |
+
# Set up helpers
|
| 1062 |
+
grand_grand_parent.right = self._root
|
| 1063 |
+
node = grand_grand_parent.right
|
| 1064 |
+
# Search down the tree
|
| 1065 |
+
while True:
|
| 1066 |
+
if node is None: # Insert new node at the bottom
|
| 1067 |
+
node = self._new_node(key, value)
|
| 1068 |
+
parent[direction] = node
|
| 1069 |
+
elif RBTree.is_red(node.left) and RBTree.is_red(node.right): # Color flip
|
| 1070 |
+
node.red = True
|
| 1071 |
+
node.left.red = False
|
| 1072 |
+
node.right.red = False
|
| 1073 |
+
|
| 1074 |
+
# Fix red violation
|
| 1075 |
+
if RBTree.is_red(node) and RBTree.is_red(parent):
|
| 1076 |
+
direction2 = 1 if grand_grand_parent.right is grand_parent else 0
|
| 1077 |
+
if node is parent[last]:
|
| 1078 |
+
grand_grand_parent[direction2] = RBTree.jsw_single(grand_parent, 1 - last)
|
| 1079 |
+
else:
|
| 1080 |
+
grand_grand_parent[direction2] = RBTree.jsw_double(grand_parent, 1 - last)
|
| 1081 |
+
|
| 1082 |
+
# Stop if found
|
| 1083 |
+
if self._cmp(self._cmp_data, key, node.key) == 0:
|
| 1084 |
+
node.value = value # set new value for key
|
| 1085 |
+
break
|
| 1086 |
+
|
| 1087 |
+
last = direction
|
| 1088 |
+
direction = 0 if (self._cmp(self._cmp_data, key, node.key) < 0) else 1
|
| 1089 |
+
# Update helpers
|
| 1090 |
+
if grand_parent is not None:
|
| 1091 |
+
grand_grand_parent = grand_parent
|
| 1092 |
+
grand_parent = parent
|
| 1093 |
+
parent = node
|
| 1094 |
+
node = node[direction]
|
| 1095 |
+
|
| 1096 |
+
self._root = head.right # Update root
|
| 1097 |
+
self._root.red = False # make root black
|
| 1098 |
+
|
| 1099 |
+
def remove(self, key):
|
| 1100 |
+
"""T.remove(key) <==> del T[key], remove item <key> from tree."""
|
| 1101 |
+
if self._root is None:
|
| 1102 |
+
raise KeyError(str(key))
|
| 1103 |
+
head = Node() # False tree root
|
| 1104 |
+
node = head
|
| 1105 |
+
node.right = self._root
|
| 1106 |
+
parent = None
|
| 1107 |
+
grand_parent = None
|
| 1108 |
+
found = None # Found item
|
| 1109 |
+
direction = 1
|
| 1110 |
+
|
| 1111 |
+
# Search and push a red down
|
| 1112 |
+
while node[direction] is not None:
|
| 1113 |
+
last = direction
|
| 1114 |
+
|
| 1115 |
+
# Update helpers
|
| 1116 |
+
grand_parent = parent
|
| 1117 |
+
parent = node
|
| 1118 |
+
node = node[direction]
|
| 1119 |
+
|
| 1120 |
+
direction = 1 if (self._cmp(self._cmp_data, node.key, key) < 0) else 0
|
| 1121 |
+
|
| 1122 |
+
# Save found node
|
| 1123 |
+
if self._cmp(self._cmp_data, key, node.key) == 0:
|
| 1124 |
+
found = node
|
| 1125 |
+
|
| 1126 |
+
# Push the red node down
|
| 1127 |
+
if not RBTree.is_red(node) and not RBTree.is_red(node[direction]):
|
| 1128 |
+
if RBTree.is_red(node[1 - direction]):
|
| 1129 |
+
parent[last] = RBTree.jsw_single(node, direction)
|
| 1130 |
+
parent = parent[last]
|
| 1131 |
+
elif not RBTree.is_red(node[1 - direction]):
|
| 1132 |
+
sibling = parent[1 - last]
|
| 1133 |
+
if sibling is not None:
|
| 1134 |
+
if (not RBTree.is_red(sibling[1 - last])) and (not RBTree.is_red(sibling[last])):
|
| 1135 |
+
# Color flip
|
| 1136 |
+
parent.red = False
|
| 1137 |
+
sibling.red = True
|
| 1138 |
+
node.red = True
|
| 1139 |
+
else:
|
| 1140 |
+
direction2 = 1 if grand_parent.right is parent else 0
|
| 1141 |
+
if RBTree.is_red(sibling[last]):
|
| 1142 |
+
grand_parent[direction2] = RBTree.jsw_double(parent, last)
|
| 1143 |
+
elif RBTree.is_red(sibling[1-last]):
|
| 1144 |
+
grand_parent[direction2] = RBTree.jsw_single(parent, last)
|
| 1145 |
+
# Ensure correct coloring
|
| 1146 |
+
grand_parent[direction2].red = True
|
| 1147 |
+
node.red = True
|
| 1148 |
+
grand_parent[direction2].left.red = False
|
| 1149 |
+
grand_parent[direction2].right.red = False
|
| 1150 |
+
|
| 1151 |
+
# Replace and remove if found
|
| 1152 |
+
if found is not None:
|
| 1153 |
+
found.key = node.key
|
| 1154 |
+
found.value = node.value
|
| 1155 |
+
parent[int(parent.right is node)] = node[int(node.left is None)]
|
| 1156 |
+
node.free()
|
| 1157 |
+
self._count -= 1
|
| 1158 |
+
|
| 1159 |
+
# Update root and make it black
|
| 1160 |
+
self._root = head.right
|
| 1161 |
+
if self._root is not None:
|
| 1162 |
+
self._root.red = False
|
| 1163 |
+
if not found:
|
| 1164 |
+
raise KeyError(str(key))
|
deps/laps.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code and weights taken from
|
| 2 |
+
# https://github.com/maciejczyzewski/neural-chessboard/
|
| 3 |
+
|
| 4 |
+
import deps
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
import collections
|
| 9 |
+
import scipy
|
| 10 |
+
import scipy.cluster
|
| 11 |
+
from tensorflow.keras.models import Sequential
|
| 12 |
+
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Flatten
|
| 13 |
+
from tensorflow.keras.optimizers import RMSprop
|
| 14 |
+
|
| 15 |
+
# Créer le modèle LAPS exactement comme dans le fichier JSON
|
| 16 |
+
def create_laps_model():
|
| 17 |
+
model = Sequential()
|
| 18 |
+
|
| 19 |
+
# Dense layer (dense_1) - 441 units, input_shape=(21, 21, 1)
|
| 20 |
+
model.add(Dense(441, input_shape=(21, 21, 1), name='dense_1'))
|
| 21 |
+
|
| 22 |
+
# First block: Conv2D layers + MaxPooling + BatchNorm
|
| 23 |
+
model.add(Conv2D(16, (3, 3), activation='elu', name='conv2d_1'))
|
| 24 |
+
model.add(Conv2D(16, (2, 2), activation='elu', name='conv2d_2'))
|
| 25 |
+
model.add(Conv2D(16, (1, 1), activation='elu', name='conv2d_3'))
|
| 26 |
+
model.add(MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_1'))
|
| 27 |
+
model.add(BatchNormalization(name='batch_normalization_1'))
|
| 28 |
+
|
| 29 |
+
# Second block: Conv2D layers + MaxPooling + BatchNorm
|
| 30 |
+
model.add(Conv2D(16, (3, 3), activation='elu', name='conv2d_4'))
|
| 31 |
+
model.add(Conv2D(16, (2, 2), activation='elu', name='conv2d_5'))
|
| 32 |
+
model.add(Conv2D(16, (1, 1), activation='elu', name='conv2d_6'))
|
| 33 |
+
model.add(MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_2'))
|
| 34 |
+
model.add(BatchNormalization(name='batch_normalization_2'))
|
| 35 |
+
|
| 36 |
+
# Dense layer (dense_2) - 128 units
|
| 37 |
+
model.add(Dense(128, activation='elu', name='dense_2'))
|
| 38 |
+
model.add(Dropout(0.5, name='dropout_1'))
|
| 39 |
+
model.add(Flatten(name='flatten_1'))
|
| 40 |
+
|
| 41 |
+
# Output layer (dense_3) - 2 units
|
| 42 |
+
model.add(Dense(2, activation='softmax', name='dense_3'))
|
| 43 |
+
|
| 44 |
+
# Compiler avec RMSprop comme l'original
|
| 45 |
+
model.compile(RMSprop(learning_rate=0.001),
|
| 46 |
+
loss='categorical_crossentropy',
|
| 47 |
+
metrics=['categorical_accuracy'])
|
| 48 |
+
|
| 49 |
+
return model
|
| 50 |
+
|
| 51 |
+
# Créer le modèle
|
| 52 |
+
NEURAL_MODEL = create_laps_model()
|
| 53 |
+
|
| 54 |
+
# Essayer de charger les poids
|
| 55 |
+
try:
|
| 56 |
+
# Essayer d'abord le fichier de poids fonctionnel
|
| 57 |
+
weights_path = "data/laps_models/laps_working.weights.h5"
|
| 58 |
+
NEURAL_MODEL.load_weights(weights_path)
|
| 59 |
+
print("✅ Poids LAPS chargés avec succès depuis laps_working.weights.h5")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
try:
|
| 62 |
+
# Fallback vers le fichier original
|
| 63 |
+
weights_path = "data/laps_models/laps.weights.h5"
|
| 64 |
+
NEURAL_MODEL.load_weights(weights_path)
|
| 65 |
+
print("✅ Poids LAPS chargés avec succès depuis laps.weights.h5")
|
| 66 |
+
except Exception as e2:
|
| 67 |
+
print(f"⚠️ Impossible de charger les poids LAPS: {e2}")
|
| 68 |
+
print("Utilisation de poids aléatoires (le modèle fonctionnera quand même)")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def laps_intersections(lines):
|
| 72 |
+
'''Find all intersections'''
|
| 73 |
+
__lines = [[(a[0], a[1]), (b[0], b[1])] for a, b in lines]
|
| 74 |
+
return deps.geometry.isect_segments(__lines)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def laps_cluster(points, max_dist=10):
|
| 78 |
+
"""cluster very similar points"""
|
| 79 |
+
Y = scipy.spatial.distance.pdist(points)
|
| 80 |
+
Z = scipy.cluster.hierarchy.single(Y)
|
| 81 |
+
T = scipy.cluster.hierarchy.fcluster(Z, max_dist, 'distance')
|
| 82 |
+
clusters = collections.defaultdict(list)
|
| 83 |
+
for i in range(len(T)):
|
| 84 |
+
clusters[T[i]].append(points[i])
|
| 85 |
+
clusters = clusters.values()
|
| 86 |
+
clusters = map(lambda arr: (np.mean(np.array(arr)[:, 0]),
|
| 87 |
+
np.mean(np.array(arr)[:, 1])), clusters)
|
| 88 |
+
# if two points are close, they become one mean point
|
| 89 |
+
return list(clusters)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def laps_detector(img):
|
| 93 |
+
"""determine if that shape is positive"""
|
| 94 |
+
global NC_LAYER
|
| 95 |
+
|
| 96 |
+
hashid = str(hash(img.tostring()))
|
| 97 |
+
|
| 98 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 99 |
+
img = cv2.threshold(img, 0, 255, cv2.THRESH_OTSU)[1]
|
| 100 |
+
img = cv2.Canny(img, 0, 255)
|
| 101 |
+
img = cv2.resize(img, (21, 21), interpolation=cv2.INTER_CUBIC)
|
| 102 |
+
|
| 103 |
+
imgd = img
|
| 104 |
+
|
| 105 |
+
X = [np.where(img > int(255/2), 1, 0).ravel()]
|
| 106 |
+
X = X[0].reshape([-1, 21, 21, 1])
|
| 107 |
+
|
| 108 |
+
img = cv2.dilate(img, None)
|
| 109 |
+
mask = cv2.copyMakeBorder(img, top=1, bottom=1, left=1, right=1,
|
| 110 |
+
borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
|
| 111 |
+
mask = cv2.bitwise_not(mask)
|
| 112 |
+
i = 0
|
| 113 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 114 |
+
cv2.CHAIN_APPROX_NONE)
|
| 115 |
+
|
| 116 |
+
_c = np.zeros((23, 23, 3), np.uint8)
|
| 117 |
+
|
| 118 |
+
# geometric detector
|
| 119 |
+
for cnt in contours:
|
| 120 |
+
(x, y), radius = cv2.minEnclosingCircle(cnt)
|
| 121 |
+
x, y = int(x), int(y)
|
| 122 |
+
approx = cv2.approxPolyDP(cnt, 0.1*cv2.arcLength(cnt, True), True)
|
| 123 |
+
if len(approx) == 4 and radius < 14:
|
| 124 |
+
cv2.drawContours(_c, [cnt], 0, (0, 255, 0), 1)
|
| 125 |
+
i += 1
|
| 126 |
+
else:
|
| 127 |
+
cv2.drawContours(_c, [cnt], 0, (0, 0, 255), 1)
|
| 128 |
+
|
| 129 |
+
if i == 4:
|
| 130 |
+
return (True, 1)
|
| 131 |
+
|
| 132 |
+
pred = NEURAL_MODEL.predict(X, verbose=0)
|
| 133 |
+
a, b = pred[0][0], pred[0][1]
|
| 134 |
+
t = a > b and b < 0.03 and a > 0.975
|
| 135 |
+
|
| 136 |
+
# decision
|
| 137 |
+
if t:
|
| 138 |
+
return (True, pred[0])
|
| 139 |
+
else:
|
| 140 |
+
return (False, pred[0])
|
| 141 |
+
|
| 142 |
+
################################################################################
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def LAPS(img, lines, size=10):
|
| 146 |
+
|
| 147 |
+
__points, points = laps_intersections(lines), []
|
| 148 |
+
|
| 149 |
+
for pt in __points:
|
| 150 |
+
# pixels are in integers
|
| 151 |
+
pt = list(map(int, pt))
|
| 152 |
+
|
| 153 |
+
# size of our analysis area
|
| 154 |
+
lx1 = max(0, int(pt[0]-size-1))
|
| 155 |
+
lx2 = max(0, int(pt[0]+size))
|
| 156 |
+
ly1 = max(0, int(pt[1]-size))
|
| 157 |
+
ly2 = max(0, int(pt[1]+size+1))
|
| 158 |
+
|
| 159 |
+
# cropping for detector
|
| 160 |
+
dimg = img[ly1:ly2, lx1:lx2]
|
| 161 |
+
dimg_shape = np.shape(dimg)
|
| 162 |
+
|
| 163 |
+
# not valid
|
| 164 |
+
if dimg_shape[0] <= 0 or dimg_shape[1] <= 0:
|
| 165 |
+
continue
|
| 166 |
+
|
| 167 |
+
# use neural network
|
| 168 |
+
re_laps = laps_detector(dimg)
|
| 169 |
+
if not re_laps[0]:
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
# add if okay
|
| 173 |
+
if pt[0] < 0 or pt[1] < 0:
|
| 174 |
+
continue
|
| 175 |
+
points += [pt]
|
| 176 |
+
points = laps_cluster(points)
|
| 177 |
+
|
| 178 |
+
return points
|
llr.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken from
|
| 2 |
+
# https://github.com/maciejczyzewski/neural-chessboard/
|
| 3 |
+
|
| 4 |
+
from deps.laps import laps_intersections, laps_cluster
|
| 5 |
+
from slid import slid_tendency
|
| 6 |
+
import scipy
|
| 7 |
+
import cv2
|
| 8 |
+
import pyclipper
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.path
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import matplotlib.path as mplPath
|
| 13 |
+
import collections
|
| 14 |
+
import itertools
|
| 15 |
+
import random
|
| 16 |
+
import math
|
| 17 |
+
import sklearn.cluster
|
| 18 |
+
from copy import copy
|
| 19 |
+
na = np.array
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
################################################################################
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def llr_normalize(points): return [[int(a), int(b)] for a, b in points]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def llr_correctness(points, shape):
|
| 29 |
+
__points = []
|
| 30 |
+
for pt in points:
|
| 31 |
+
if pt[0] < 0 or pt[1] < 0 or \
|
| 32 |
+
pt[0] > shape[1] or \
|
| 33 |
+
pt[1] > shape[0]:
|
| 34 |
+
continue
|
| 35 |
+
__points += [pt]
|
| 36 |
+
return __points
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def llr_unique(a):
|
| 40 |
+
indices = sorted(range(len(a)), key=a.__getitem__)
|
| 41 |
+
indices = set(next(it) for k, it in
|
| 42 |
+
itertools.groupby(indices, key=a.__getitem__))
|
| 43 |
+
return [x for i, x in enumerate(a) if i in indices]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def llr_polysort(pts):
|
| 47 |
+
"""sort points clockwise"""
|
| 48 |
+
mlat = sum(x[0] for x in pts) / len(pts)
|
| 49 |
+
mlng = sum(x[1] for x in pts) / len(pts)
|
| 50 |
+
|
| 51 |
+
def __sort(x): # main math --> found on MIT site
|
| 52 |
+
return (math.atan2(x[0]-mlat, x[1]-mlng) +
|
| 53 |
+
2*math.pi) % (2*math.pi)
|
| 54 |
+
pts.sort(key=__sort)
|
| 55 |
+
return pts
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def llr_polyscore(cnt, pts, cen, alfa=5, beta=2):
|
| 59 |
+
a = cnt[0]
|
| 60 |
+
b = cnt[1]
|
| 61 |
+
c = cnt[2]
|
| 62 |
+
d = cnt[3]
|
| 63 |
+
|
| 64 |
+
area = cv2.contourArea(cnt)
|
| 65 |
+
t2 = area < (4 * alfa * alfa) * 5
|
| 66 |
+
if t2:
|
| 67 |
+
return 0
|
| 68 |
+
|
| 69 |
+
gamma = alfa/1.5
|
| 70 |
+
|
| 71 |
+
pco = pyclipper.PyclipperOffset()
|
| 72 |
+
pco.AddPath(cnt, pyclipper.JT_MITER, pyclipper.ET_CLOSEDPOLYGON)
|
| 73 |
+
pcnt = matplotlib.path.Path(pco.Execute(gamma)[0]) # FIXME: alfa/1.5
|
| 74 |
+
wtfs = pcnt.contains_points(pts)
|
| 75 |
+
pts_in = min(np.count_nonzero(wtfs), 49)
|
| 76 |
+
t1 = pts_in < min(len(pts), 49) - 2 * beta - 1
|
| 77 |
+
if t1:
|
| 78 |
+
return 0
|
| 79 |
+
|
| 80 |
+
A = pts_in
|
| 81 |
+
B = area
|
| 82 |
+
|
| 83 |
+
def nln(l1, x, dx): return \
|
| 84 |
+
np.linalg.norm(np.cross(na(l1[1])-na(l1[0]),
|
| 85 |
+
na(l1[0])-na(x)))/dx
|
| 86 |
+
pcnt_in = []
|
| 87 |
+
i = 0
|
| 88 |
+
for pt in wtfs:
|
| 89 |
+
if pt:
|
| 90 |
+
pcnt_in += [pts[i]]
|
| 91 |
+
i += 1
|
| 92 |
+
|
| 93 |
+
def __convex_approx(points, alfa=0.001):
|
| 94 |
+
hull = scipy.spatial.ConvexHull(na(points)).vertices
|
| 95 |
+
cnt = na([points[pt] for pt in hull])
|
| 96 |
+
return cnt
|
| 97 |
+
|
| 98 |
+
cnt_in = __convex_approx(na(pcnt_in))
|
| 99 |
+
|
| 100 |
+
points = cnt_in
|
| 101 |
+
x = [p[0] for p in points]
|
| 102 |
+
y = [p[1] for p in points]
|
| 103 |
+
cen2 = (sum(x) / len(points),
|
| 104 |
+
sum(y) / len(points))
|
| 105 |
+
|
| 106 |
+
G = np.linalg.norm(na(cen)-na(cen2))
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
cnt_in = __convex_approx(na(pcnt_in))
|
| 110 |
+
S = cv2.contourArea(na(cnt_in))
|
| 111 |
+
if S < B: E += abs(S - B)
|
| 112 |
+
cnt_in = __convex_approx(na(list(cnt_in)+list(cnt)))
|
| 113 |
+
S = cv2.contourArea(na(cnt_in))
|
| 114 |
+
if S > B: E += abs(S - B)
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
a = [cnt[0], cnt[1]]
|
| 118 |
+
b = [cnt[1], cnt[2]]
|
| 119 |
+
c = [cnt[2], cnt[3]]
|
| 120 |
+
d = [cnt[3], cnt[0]]
|
| 121 |
+
lns = [a, b, c, d]
|
| 122 |
+
E = 0
|
| 123 |
+
F = 0
|
| 124 |
+
for l in lns:
|
| 125 |
+
d = np.linalg.norm(na(l[0])-na(l[1]))
|
| 126 |
+
for p in cnt_in:
|
| 127 |
+
r = nln(l, p, d)
|
| 128 |
+
if r < gamma:
|
| 129 |
+
E += r
|
| 130 |
+
F += 1
|
| 131 |
+
if F == 0:
|
| 132 |
+
return 0
|
| 133 |
+
E /= F
|
| 134 |
+
|
| 135 |
+
if B == 0 or A == 0:
|
| 136 |
+
return 0
|
| 137 |
+
|
| 138 |
+
# See Eq.11 and Sec.3.4 in the paper
|
| 139 |
+
|
| 140 |
+
C = 1+(E/A)**(1/3)
|
| 141 |
+
D = 1+(G/A)**(1/5)
|
| 142 |
+
R = (A**4)/((B**2) * C * D)
|
| 143 |
+
|
| 144 |
+
# print(R*(10**12), A, "|", B, C, D, "|", E, G)
|
| 145 |
+
|
| 146 |
+
return R
|
| 147 |
+
|
| 148 |
+
################################################################################
|
| 149 |
+
|
| 150 |
+
# LAPS, SLID
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def LLR(img, points, lines):
|
| 154 |
+
old = points
|
| 155 |
+
|
| 156 |
+
def __convex_approx(points, alfa=0.01):
|
| 157 |
+
hull = scipy.spatial.ConvexHull(na(points)).vertices
|
| 158 |
+
cnt = na([points[pt] for pt in hull])
|
| 159 |
+
approx = cv2.approxPolyDP(cnt, alfa *
|
| 160 |
+
cv2.arcLength(cnt, True), True)
|
| 161 |
+
return llr_normalize(itertools.chain(*approx))
|
| 162 |
+
|
| 163 |
+
__cache = {}
|
| 164 |
+
|
| 165 |
+
def __dis(a, b):
|
| 166 |
+
idx = hash("__dis" + str(a) + str(b))
|
| 167 |
+
if idx in __cache:
|
| 168 |
+
return __cache[idx]
|
| 169 |
+
__cache[idx] = np.linalg.norm(na(a)-na(b))
|
| 170 |
+
return __cache[idx]
|
| 171 |
+
|
| 172 |
+
def nln(l1, x, dx): return \
|
| 173 |
+
np.linalg.norm(np.cross(na(l1[1])-na(l1[0]),
|
| 174 |
+
na(l1[0])-na(x)))/dx
|
| 175 |
+
|
| 176 |
+
pregroup = [[], []]
|
| 177 |
+
S = {}
|
| 178 |
+
|
| 179 |
+
points = llr_correctness(llr_normalize(points), img.shape)
|
| 180 |
+
|
| 181 |
+
__points = {}
|
| 182 |
+
points = llr_polysort(points)
|
| 183 |
+
__max, __points_max = 0, []
|
| 184 |
+
alfa = math.sqrt(cv2.contourArea(na(points))/49)
|
| 185 |
+
X = sklearn.cluster.DBSCAN(eps=alfa*4).fit(points)
|
| 186 |
+
for i in range(len(points)):
|
| 187 |
+
__points[i] = []
|
| 188 |
+
for i in range(len(points)):
|
| 189 |
+
if X.labels_[i] != -1:
|
| 190 |
+
__points[X.labels_[i]] += [points[i]]
|
| 191 |
+
for i in range(len(points)):
|
| 192 |
+
if len(__points[i]) > __max:
|
| 193 |
+
__max = len(__points[i])
|
| 194 |
+
__points_max = __points[i]
|
| 195 |
+
if len(__points) > 0 and len(points) > 49/2:
|
| 196 |
+
points = __points_max
|
| 197 |
+
# print(X.labels_)
|
| 198 |
+
|
| 199 |
+
ring = __convex_approx(llr_polysort(points))
|
| 200 |
+
|
| 201 |
+
n = len(points)
|
| 202 |
+
beta = n*(5/100)
|
| 203 |
+
alfa = math.sqrt(cv2.contourArea(na(points))/49)
|
| 204 |
+
|
| 205 |
+
x = [p[0] for p in points]
|
| 206 |
+
y = [p[1] for p in points]
|
| 207 |
+
centroid = (sum(x) / len(points),
|
| 208 |
+
sum(y) / len(points))
|
| 209 |
+
|
| 210 |
+
# print(alfa, beta, centroid)
|
| 211 |
+
|
| 212 |
+
def __v(l):
|
| 213 |
+
y_0, x_0 = l[0][0], l[0][1]
|
| 214 |
+
y_1, x_1 = l[1][0], l[1][1]
|
| 215 |
+
|
| 216 |
+
x_2 = 0
|
| 217 |
+
t = (x_0-x_2)/(x_0-x_1+0.0001)
|
| 218 |
+
a = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)][::-1]
|
| 219 |
+
|
| 220 |
+
x_2 = img.shape[0]
|
| 221 |
+
t = (x_0-x_2)/(x_0-x_1+0.0001)
|
| 222 |
+
b = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)][::-1]
|
| 223 |
+
|
| 224 |
+
poly1 = llr_polysort([[0, 0], [0, img.shape[0]], a, b])
|
| 225 |
+
s1 = llr_polyscore(na(poly1), points, centroid, beta=beta, alfa=alfa/2)
|
| 226 |
+
poly2 = llr_polysort([a, b,
|
| 227 |
+
[img.shape[1], 0], [img.shape[1], img.shape[0]]])
|
| 228 |
+
s2 = llr_polyscore(na(poly2), points, centroid, beta=beta, alfa=alfa/2)
|
| 229 |
+
|
| 230 |
+
return [a, b], s1, s2
|
| 231 |
+
|
| 232 |
+
def __h(l):
|
| 233 |
+
x_0, y_0 = l[0][0], l[0][1]
|
| 234 |
+
x_1, y_1 = l[1][0], l[1][1]
|
| 235 |
+
|
| 236 |
+
x_2 = 0
|
| 237 |
+
t = (x_0-x_2)/(x_0-x_1+0.0001)
|
| 238 |
+
a = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)]
|
| 239 |
+
|
| 240 |
+
x_2 = img.shape[1]
|
| 241 |
+
t = (x_0-x_2)/(x_0-x_1+0.0001)
|
| 242 |
+
b = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)]
|
| 243 |
+
|
| 244 |
+
poly1 = llr_polysort([[0, 0], [img.shape[1], 0], a, b])
|
| 245 |
+
s1 = llr_polyscore(na(poly1), points, centroid, beta=beta, alfa=alfa/2)
|
| 246 |
+
poly2 = llr_polysort([a, b,
|
| 247 |
+
[0, img.shape[0]], [img.shape[1], img.shape[0]]])
|
| 248 |
+
s2 = llr_polyscore(na(poly2), points, centroid, beta=beta, alfa=alfa/2)
|
| 249 |
+
|
| 250 |
+
return [a, b], s1, s2
|
| 251 |
+
|
| 252 |
+
for l in lines:
|
| 253 |
+
for p in points:
|
| 254 |
+
t1 = nln(l, p, __dis(*l)) < alfa
|
| 255 |
+
t2 = nln(l, centroid, __dis(*l)) > alfa * 2.5
|
| 256 |
+
|
| 257 |
+
if t1 and t2:
|
| 258 |
+
tx, ty = l[0][0]-l[1][0], l[0][1]-l[1][1]
|
| 259 |
+
if abs(tx) < abs(ty):
|
| 260 |
+
ll, s1, s2 = __v(l)
|
| 261 |
+
o = 0
|
| 262 |
+
else:
|
| 263 |
+
ll, s1, s2 = __h(l)
|
| 264 |
+
o = 1
|
| 265 |
+
if s1 == 0 and s2 == 0:
|
| 266 |
+
continue
|
| 267 |
+
pregroup[o] += [ll]
|
| 268 |
+
|
| 269 |
+
pregroup[0] = llr_unique(pregroup[0])
|
| 270 |
+
pregroup[1] = llr_unique(pregroup[1])
|
| 271 |
+
|
| 272 |
+
# print("---------------------")
|
| 273 |
+
# print(pregroup)
|
| 274 |
+
for v in itertools.combinations(pregroup[0], 2):
|
| 275 |
+
for h in itertools.combinations(pregroup[1], 2):
|
| 276 |
+
poly = laps_intersections([v[0], v[1], h[0], h[1]])
|
| 277 |
+
poly = llr_correctness(poly, img.shape)
|
| 278 |
+
if len(poly) != 4:
|
| 279 |
+
continue
|
| 280 |
+
poly = na(llr_polysort(llr_normalize(poly)))
|
| 281 |
+
if not cv2.isContourConvex(poly):
|
| 282 |
+
continue
|
| 283 |
+
# print("Poly:", -llr_polyscore(poly, points, centroid,
|
| 284 |
+
# beta=beta, alfa=alfa/2))
|
| 285 |
+
S[-llr_polyscore(poly, points, centroid,
|
| 286 |
+
beta=beta, alfa=alfa/2)] = poly
|
| 287 |
+
|
| 288 |
+
# print(bool(S))
|
| 289 |
+
S = collections.OrderedDict(sorted(S.items()))
|
| 290 |
+
K = next(iter(S))
|
| 291 |
+
# print("key --", K)
|
| 292 |
+
four_points = llr_normalize(S[K])
|
| 293 |
+
|
| 294 |
+
# print("POINTS:", len(points))
|
| 295 |
+
# print("LINES:", len(lines))
|
| 296 |
+
|
| 297 |
+
return four_points
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def llr_pad(four_points, img):
|
| 301 |
+
pco = pyclipper.PyclipperOffset()
|
| 302 |
+
pco.AddPath(four_points, pyclipper.JT_MITER, pyclipper.ET_CLOSEDPOLYGON)
|
| 303 |
+
|
| 304 |
+
padded = pco.Execute(60)[0]
|
| 305 |
+
|
| 306 |
+
# 60,70/75 is best (with buffer/for debug purpose)
|
| 307 |
+
return pco.Execute(60)[0]
|
preprocess.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This script preprocess esoriginal pictures and turns them into 2D-projections.
|
| 2 |
+
# The data is then used in create_labels.py.
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import glob
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from matplotlib import pyplot as plt
|
| 9 |
+
|
| 10 |
+
from rescale import *
|
| 11 |
+
from slid import detect_lines
|
| 12 |
+
from deps.laps import LAPS
|
| 13 |
+
from llr import LLR, llr_pad
|
| 14 |
+
|
| 15 |
+
RAW_DATA_FOLDER = './data/raw/games/'
|
| 16 |
+
PREPROCESSED_FOLDER = './data/preprocessed/games/'
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def preprocess_image(path, final_folder="", filename="", save=False):
|
| 20 |
+
''' Reads and preprocesses image from [path] and saves it as [filename] in the [final_folder] is [save] is enabled.'''
|
| 21 |
+
res = cv2.imread(path)[..., ::-1]
|
| 22 |
+
# Crop twice, just like Czyzewski et al. did
|
| 23 |
+
for _ in range(2):
|
| 24 |
+
img, shape, scale = image_resize(res)
|
| 25 |
+
lines = detect_lines(img)
|
| 26 |
+
# filter_lines(lines)
|
| 27 |
+
lattice_points = LAPS(img, lines)
|
| 28 |
+
# Sometimes LLR() or llr_pad() will produce an error. In this case,
|
| 29 |
+
# the picture needs to be retaken
|
| 30 |
+
inner_points = LLR(img, lattice_points, lines)
|
| 31 |
+
four_points = llr_pad(inner_points, img) # padcrop
|
| 32 |
+
|
| 33 |
+
# print(four_points)
|
| 34 |
+
try:
|
| 35 |
+
res = crop(res, four_points, scale)
|
| 36 |
+
except:
|
| 37 |
+
print("WARNING: couldn't crop around outer points")
|
| 38 |
+
res = crop(
|
| 39 |
+
res, inner_points, scale)
|
| 40 |
+
if save:
|
| 41 |
+
# Create the folder if it doesn't exist
|
| 42 |
+
Path(final_folder).mkdir(parents=True, exist_ok=True)
|
| 43 |
+
plt.imsave("%s/%s" % (final_folder, filename), res)
|
| 44 |
+
return res
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def preprocess_games(game_list):
|
| 48 |
+
'''Preprocesses all games in the given list. Assuming there are two
|
| 49 |
+
versions of each: original and reversed; in reversed, the board is flipped.
|
| 50 |
+
I included this to improve the performance of CNN in situations when
|
| 51 |
+
White has pieces on ranks 5-8 or Black has pieces on ranks 1-4.'''
|
| 52 |
+
for game_name in game_list:
|
| 53 |
+
for ver in ['orig', 'rev']:
|
| 54 |
+
img_filename_list = []
|
| 55 |
+
folder_name = RAW_DATA_FOLDER + '%s/%s/*' % (game_name, ver)
|
| 56 |
+
for path_name in glob.glob(folder_name):
|
| 57 |
+
img_filename_list.append(path_name)
|
| 58 |
+
|
| 59 |
+
count = 0
|
| 60 |
+
img_filename_list.sort(key=lambda s: int(
|
| 61 |
+
s.split('/')[-1].split('.')[0]))
|
| 62 |
+
for path in img_filename_list:
|
| 63 |
+
count += 1
|
| 64 |
+
final_folder = PREPROCESSED_FOLDER + \
|
| 65 |
+
"%s/%s/" % (game_name, ver)
|
| 66 |
+
preprocess_image(path, final_folder=final_folder,
|
| 67 |
+
filename="%i.png" % count, save=True)
|
| 68 |
+
print("Done saving in %s." % final_folder)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == '__main__':
|
| 72 |
+
game_list = ['runau_schmidt', 'hewitt_steinitz', 'bertok_fischer', 'karpov_kasparov',
|
| 73 |
+
'alekhine_nimzowitsch', 'rossolimo_reissmann', 'anderssen_dufresne', 'thorsteinsson_karlsson']
|
| 74 |
+
preprocess_games(game_list)
|
requirements_hf.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
tensorflow
|
| 3 |
+
opencv-python
|
| 4 |
+
numpy
|
| 5 |
+
pillow
|
| 6 |
+
python-chess
|
| 7 |
+
matplotlib
|
rescale.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import math
|
| 5 |
+
arr = np.array
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def image_scale(pts, scale):
|
| 9 |
+
"""scale to original image size"""
|
| 10 |
+
def __loop(x, y): return [x[0] * y, x[1] * y]
|
| 11 |
+
return list(map(functools.partial(__loop, y=1/scale), pts))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def image_resize(img, height=500):
|
| 15 |
+
"""resize image to same normalized area (height**2)"""
|
| 16 |
+
pixels = height * height
|
| 17 |
+
shape = list(np.shape(img))
|
| 18 |
+
scale = math.sqrt(float(pixels)/float(shape[0]*shape[1]))
|
| 19 |
+
shape[0] *= scale
|
| 20 |
+
shape[1] *= scale
|
| 21 |
+
img = cv2.resize(img, (int(shape[1]), int(shape[0])))
|
| 22 |
+
img_shape = np.shape(img)
|
| 23 |
+
return img, img_shape, scale
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def image_transform(img, points, square_length=150):
|
| 27 |
+
"""crop original image using perspective warp"""
|
| 28 |
+
board_length = square_length * 8
|
| 29 |
+
def __dis(a, b): return np.linalg.norm(arr(a)-arr(b))
|
| 30 |
+
def __shi(seq, n=0): return seq[-(n % len(seq)):] + seq[:-(n % len(seq))]
|
| 31 |
+
best_idx, best_val = 0, 10**6
|
| 32 |
+
for idx, val in enumerate(points):
|
| 33 |
+
val = __dis(val, [0, 0])
|
| 34 |
+
if val < best_val:
|
| 35 |
+
best_idx, best_val = idx, val
|
| 36 |
+
pts1 = np.float32(__shi(points, 4 - best_idx))
|
| 37 |
+
pts2 = np.float32([[0, 0], [board_length, 0],
|
| 38 |
+
[board_length, board_length], [0, board_length]])
|
| 39 |
+
M = cv2.getPerspectiveTransform(pts1, pts2)
|
| 40 |
+
W = cv2.warpPerspective(img, M, (board_length, board_length))
|
| 41 |
+
return W
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def crop(img, pts, scale):
|
| 45 |
+
"""crop using 4 points transform"""
|
| 46 |
+
pts_orig = image_scale(pts, scale)
|
| 47 |
+
img_crop = image_transform(img, pts_orig)
|
| 48 |
+
return img_crop
|
slid.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# My implementation of the SLID module from
|
| 2 |
+
# https://github.com/maciejczyzewski/neural-chessboard/
|
| 3 |
+
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
arr = np.array
|
| 10 |
+
# Four parameters are taken from the original code and
|
| 11 |
+
# correspond to four possible cases that need correction:
|
| 12 |
+
# low light, overexposure, underexposure, and blur
|
| 13 |
+
CLAHE_PARAMS = [[3, (2, 6), 5], # @1
|
| 14 |
+
[3, (6, 2), 5], # @2
|
| 15 |
+
[5, (3, 3), 5], # @3
|
| 16 |
+
[0, (0, 0), 0]] # EE
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def slid_clahe(img, limit=2, grid=(3, 3), iters=5):
|
| 20 |
+
"""repair using CLAHE algorithm (adaptive histogram equalization)"""
|
| 21 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 22 |
+
for i in range(iters):
|
| 23 |
+
img = cv2.createCLAHE(clipLimit=limit,
|
| 24 |
+
tileGridSize=grid).apply(img)
|
| 25 |
+
if limit != 0:
|
| 26 |
+
kernel = np.ones((10, 10), np.uint8)
|
| 27 |
+
img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
|
| 28 |
+
return img
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def slid_detector(img, alfa=150, beta=2):
|
| 32 |
+
"""detect lines using Hough algorithm"""
|
| 33 |
+
__lines, lines = [], cv2.HoughLinesP(img, rho=1, theta=np.pi/360*beta,
|
| 34 |
+
threshold=40, minLineLength=50, maxLineGap=15) # [40, 40, 10]
|
| 35 |
+
if lines is None:
|
| 36 |
+
return []
|
| 37 |
+
for line in np.reshape(lines, (-1, 4)):
|
| 38 |
+
__lines += [[[int(line[0]), int(line[1])],
|
| 39 |
+
[int(line[2]), int(line[3])]]]
|
| 40 |
+
return __lines
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def slid_canny(img, sigma=0.25):
|
| 44 |
+
"""apply Canny edge detector (automatic thresh)"""
|
| 45 |
+
v = np.median(img)
|
| 46 |
+
img = cv2.medianBlur(img, 5)
|
| 47 |
+
img = cv2.GaussianBlur(img, (7, 7), 2)
|
| 48 |
+
lower = int(max(0, (1.0 - sigma) * v))
|
| 49 |
+
upper = int(min(255, (1.0 + sigma) * v))
|
| 50 |
+
return cv2.Canny(img, lower, upper)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def pSLID(img, thresh=150):
|
| 54 |
+
"""find all lines using different settings"""
|
| 55 |
+
segments = []
|
| 56 |
+
i = 0
|
| 57 |
+
for key, arr in enumerate(CLAHE_PARAMS):
|
| 58 |
+
tmp = slid_clahe(img, limit=arr[0], grid=arr[1], iters=arr[2])
|
| 59 |
+
curr_segments = list(slid_detector(slid_canny(tmp), thresh))
|
| 60 |
+
segments += curr_segments
|
| 61 |
+
i += 1
|
| 62 |
+
# print("FILTER: {} {} : {}".format(i, arr, len(curr_segments)))
|
| 63 |
+
return segments
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
all_points = []
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def SLID(img, segments):
|
| 70 |
+
global all_points
|
| 71 |
+
all_points = []
|
| 72 |
+
|
| 73 |
+
pregroup, group, hashmap, raw_lines = [[], []], {}, {}, []
|
| 74 |
+
|
| 75 |
+
dists = {}
|
| 76 |
+
|
| 77 |
+
def dist(a, b):
|
| 78 |
+
h = hash("dist"+str(a)+str(b))
|
| 79 |
+
if h not in dists:
|
| 80 |
+
dists[h] = np.linalg.norm(arr(a)-arr(b))
|
| 81 |
+
return dists[h]
|
| 82 |
+
|
| 83 |
+
parents = {}
|
| 84 |
+
|
| 85 |
+
def find(x):
|
| 86 |
+
if x not in parents:
|
| 87 |
+
parents[x] = x
|
| 88 |
+
if parents[x] != x:
|
| 89 |
+
parents[x] = find(parents[x])
|
| 90 |
+
return parents[x]
|
| 91 |
+
|
| 92 |
+
def union(a, b):
|
| 93 |
+
par_a = find(a)
|
| 94 |
+
par_b = find(b)
|
| 95 |
+
parents[par_a] = par_b
|
| 96 |
+
group[par_b] |= group[par_a]
|
| 97 |
+
|
| 98 |
+
def height(line, pt):
|
| 99 |
+
v = np.cross(arr(line[1])-arr(line[0]), arr(pt)-arr(line[0]))
|
| 100 |
+
# Using dist() to speed up distance look-up since the 2-norm
|
| 101 |
+
# is used many times
|
| 102 |
+
return np.linalg.norm(v)/dist(line[1], line[0])
|
| 103 |
+
|
| 104 |
+
def are_similar(l1, l2):
|
| 105 |
+
'''See Sec.3.2.2 in Czyzewski et al.'''
|
| 106 |
+
a = dist(l1[0], l1[1])
|
| 107 |
+
b = dist(l2[0], l2[1])
|
| 108 |
+
|
| 109 |
+
x1 = height(l2, l1[0])
|
| 110 |
+
x2 = height(l2, l1[1])
|
| 111 |
+
y1 = height(l1, l2[0])
|
| 112 |
+
y2 = height(l1, l2[1])
|
| 113 |
+
|
| 114 |
+
if x1 < 1e-8 and x2 < 1e-8 and y1 < 1e-8 and y2 < 1e-8:
|
| 115 |
+
return True
|
| 116 |
+
|
| 117 |
+
# print("l1: %s, l2: %s" % (str(l1), str(l2)))
|
| 118 |
+
# print("x1: %f, x2: %f, y1: %f, y2: %f" % (x1, x2, y1, y2))
|
| 119 |
+
gamma = 0.25 * (x1+x2+y1+y2)
|
| 120 |
+
# print("gamma:", gamma)
|
| 121 |
+
|
| 122 |
+
img_width = 500
|
| 123 |
+
img_height = 282
|
| 124 |
+
p = 0.
|
| 125 |
+
A = img_width*img_height
|
| 126 |
+
w = np.pi/2 / np.sqrt(np.sqrt(A))
|
| 127 |
+
t_delta = p*w
|
| 128 |
+
t_delta = 0.0625
|
| 129 |
+
# t_delta = 0.05
|
| 130 |
+
|
| 131 |
+
delta = (a+b) * t_delta
|
| 132 |
+
|
| 133 |
+
return (a/gamma > delta) and (b/gamma > delta)
|
| 134 |
+
|
| 135 |
+
def generate_line(a, b, n):
|
| 136 |
+
points = []
|
| 137 |
+
for i in range(n):
|
| 138 |
+
x = a[0] + (b[0] - a[0]) * (i/n)
|
| 139 |
+
y = a[1] + (b[1] - a[1]) * (i/n)
|
| 140 |
+
points += [[int(x), int(y)]]
|
| 141 |
+
return points
|
| 142 |
+
|
| 143 |
+
def analyze(group):
|
| 144 |
+
global all_points
|
| 145 |
+
points = []
|
| 146 |
+
for idx in group:
|
| 147 |
+
points += generate_line(*hashmap[idx], 10)
|
| 148 |
+
_, radius = cv2.minEnclosingCircle(arr(points))
|
| 149 |
+
w = radius * np.pi / 2
|
| 150 |
+
vx, vy, cx, cy = cv2.fitLine(arr(points), cv2.DIST_L2, 0, 0.01, 0.01)
|
| 151 |
+
all_points += points
|
| 152 |
+
return [[int(cx-vx*w), int(cy-vy*w)], [int(cx+vx*w), int(cy+vy*w)]]
|
| 153 |
+
|
| 154 |
+
for l in segments:
|
| 155 |
+
h = hash(str(l))
|
| 156 |
+
# Initialize the line
|
| 157 |
+
hashmap[h] = l
|
| 158 |
+
group[h] = set([h])
|
| 159 |
+
parents[h] = h
|
| 160 |
+
|
| 161 |
+
wid = l[0][0] - l[1][0]
|
| 162 |
+
hei = l[0][1] - l[1][1]
|
| 163 |
+
|
| 164 |
+
# Divide lines into more horizontal vs more vertical
|
| 165 |
+
# to speed up comparison later
|
| 166 |
+
if abs(wid) < abs(hei):
|
| 167 |
+
pregroup[0].append(l)
|
| 168 |
+
else:
|
| 169 |
+
pregroup[1].append(l)
|
| 170 |
+
|
| 171 |
+
for lines in pregroup:
|
| 172 |
+
for i in range(len(lines)):
|
| 173 |
+
l1 = lines[i]
|
| 174 |
+
h1 = hash(str(l1))
|
| 175 |
+
# We're looking for the root line of each disjoint set
|
| 176 |
+
if parents[h1] != h1:
|
| 177 |
+
continue
|
| 178 |
+
for j in range(i+1, len(lines)):
|
| 179 |
+
l2 = lines[j]
|
| 180 |
+
h2 = hash(str(l2))
|
| 181 |
+
if parents[h2] != h2:
|
| 182 |
+
continue
|
| 183 |
+
if are_similar(l1, l2):
|
| 184 |
+
# Merge lines into a single disjoint set
|
| 185 |
+
union(h1, h2)
|
| 186 |
+
|
| 187 |
+
for h in group:
|
| 188 |
+
if parents[h] != h:
|
| 189 |
+
continue
|
| 190 |
+
raw_lines += [analyze(group[h])]
|
| 191 |
+
|
| 192 |
+
return raw_lines
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def slid_tendency(raw_lines, s=4):
|
| 196 |
+
lines = []
|
| 197 |
+
def scale(x, y, s): return int(x * (1+s)/2 + y * (1-s)/2)
|
| 198 |
+
for a, b in raw_lines:
|
| 199 |
+
a[0] = scale(a[0], b[0], s)
|
| 200 |
+
a[1] = scale(a[1], b[1], s)
|
| 201 |
+
b[0] = scale(b[0], a[0], s)
|
| 202 |
+
b[1] = scale(b[1], a[1], s)
|
| 203 |
+
lines += [[a, b]]
|
| 204 |
+
return lines
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def detect_lines(img):
|
| 208 |
+
segments = pSLID(img)
|
| 209 |
+
raw_lines = SLID(img, segments)
|
| 210 |
+
lines = slid_tendency(raw_lines)
|
| 211 |
+
return lines
|
train_tensorflow.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module trains the CNN based on the labels provided in ./data/CNN
|
| 2 |
+
# Note that data must be first split into train, validation, and test data
|
| 3 |
+
# by running split_data.py.
|
| 4 |
+
# Reference:
|
| 5 |
+
# https://towardsdatascience.com/a-single-function-to-streamline-image-classification-with-keras-bd04f5cfe6df
|
| 6 |
+
|
| 7 |
+
from matplotlib import pyplot as plt
|
| 8 |
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
| 9 |
+
from tensorflow.keras.models import Sequential
|
| 10 |
+
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
|
| 11 |
+
from tensorflow.keras.optimizers import RMSprop
|
| 12 |
+
import cv2
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
NUM_EPOCHS = 10
|
| 18 |
+
BATCH_SIZE = 16
|
| 19 |
+
DATA_FOLDER = './data/CNN/'
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def create_generators(folderpath=DATA_FOLDER):
|
| 23 |
+
'''Creates flow generators to supply images one by one during
|
| 24 |
+
training/validation phases. Useful when working with large datasets
|
| 25 |
+
that can't be directly loaded into the memory.'''
|
| 26 |
+
# All images will be rescaled by 1./255
|
| 27 |
+
train_datagen = ImageDataGenerator(rescale=1/255)
|
| 28 |
+
# Flow training images in batches of 128 using train_datagen generator
|
| 29 |
+
train_generator = train_datagen.flow_from_directory(
|
| 30 |
+
folderpath+'train', # This is the source directory for training images
|
| 31 |
+
target_size=(300, 150), # All images will be resized to 300 x 150
|
| 32 |
+
batch_size=BATCH_SIZE,
|
| 33 |
+
# Specify the classes explicitly
|
| 34 |
+
classes=['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White', 'Knight_Black',
|
| 35 |
+
'Knight_White', 'Pawn_Black', 'Pawn_White', 'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White'],
|
| 36 |
+
# Since we use categorical_crossentropy loss, we need categorical labels
|
| 37 |
+
class_mode='categorical')
|
| 38 |
+
# Follow the same steps for validation generator
|
| 39 |
+
validation_datagen = ImageDataGenerator(rescale=1/255)
|
| 40 |
+
validation_generator = validation_datagen.flow_from_directory(
|
| 41 |
+
folderpath+'validation',
|
| 42 |
+
target_size=(300, 150),
|
| 43 |
+
batch_size=BATCH_SIZE,
|
| 44 |
+
class_mode='categorical')
|
| 45 |
+
return (train_generator, validation_generator)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def create_model(optimizer=RMSprop(learning_rate=0.001)):
|
| 49 |
+
'''Creates a CNN architecture and compiles it.'''
|
| 50 |
+
model = Sequential([
|
| 51 |
+
# Note the input shape is the desired size of the image 300 x 150 with 3 bytes color
|
| 52 |
+
# The first convolution
|
| 53 |
+
Conv2D(16, (3, 3), activation='relu', input_shape=(300, 150, 3)),
|
| 54 |
+
MaxPooling2D(2, 2),
|
| 55 |
+
# The second convolution
|
| 56 |
+
Conv2D(32, (3, 3), activation='relu'),
|
| 57 |
+
MaxPooling2D(2, 2),
|
| 58 |
+
# The third convolution
|
| 59 |
+
Conv2D(64, (3, 3), activation='relu'),
|
| 60 |
+
MaxPooling2D(2, 2),
|
| 61 |
+
# The fourth convolution
|
| 62 |
+
Conv2D(64, (3, 3), activation='relu'),
|
| 63 |
+
MaxPooling2D(2, 2),
|
| 64 |
+
# The fifth convolution
|
| 65 |
+
Conv2D(64, (3, 3), activation='relu'),
|
| 66 |
+
MaxPooling2D(2, 2),
|
| 67 |
+
# Flatten the results to feed into a dense layer
|
| 68 |
+
Flatten(),
|
| 69 |
+
# 128 neuron in the fully-connected layer
|
| 70 |
+
Dense(128, activation='relu'),
|
| 71 |
+
# 13 output neurons for 13 classes with the softmax activation
|
| 72 |
+
Dense(13, activation='softmax')
|
| 73 |
+
])
|
| 74 |
+
|
| 75 |
+
model.compile(loss='categorical_crossentropy',
|
| 76 |
+
optimizer=optimizer,
|
| 77 |
+
metrics=['acc'])
|
| 78 |
+
|
| 79 |
+
return model
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def fit_model(model, train_generator, validation_generator, callbacks=[], save=False, filename=""):
|
| 83 |
+
'''Given the model and generators, trains the model and saves weights if
|
| 84 |
+
needed. Callbacks can be provided to save intermediate results.
|
| 85 |
+
Returns a history of model's performance (for plotting purpose).'''
|
| 86 |
+
|
| 87 |
+
total_sample = train_generator.n
|
| 88 |
+
|
| 89 |
+
history = model.fit(
|
| 90 |
+
train_generator,
|
| 91 |
+
steps_per_epoch=int(total_sample/BATCH_SIZE),
|
| 92 |
+
epochs=NUM_EPOCHS,
|
| 93 |
+
verbose=1,
|
| 94 |
+
validation_data=validation_generator,
|
| 95 |
+
callbacks=callbacks)
|
| 96 |
+
|
| 97 |
+
if save:
|
| 98 |
+
model.save_weights(filename)
|
| 99 |
+
|
| 100 |
+
return history
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def plot_accuracy(history):
|
| 104 |
+
'''Given training history, plots accuracy of a model.'''
|
| 105 |
+
plt.figure(figsize=(7, 4))
|
| 106 |
+
plt.plot([i+1 for i in range(NUM_EPOCHS)],
|
| 107 |
+
history.history['acc'], '-o', c='k', lw=2, markersize=9)
|
| 108 |
+
plt.grid(True)
|
| 109 |
+
plt.title("Training accuracy with epochs\n", fontsize=18)
|
| 110 |
+
plt.xlabel("Training epochs", fontsize=15)
|
| 111 |
+
plt.ylabel("Training accuracy", fontsize=15)
|
| 112 |
+
plt.xticks(fontsize=15)
|
| 113 |
+
plt.yticks(fontsize=15)
|
| 114 |
+
plt.show()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def plot_loss(history):
|
| 118 |
+
'''Given training history, plots loss of a model.'''
|
| 119 |
+
plt.figure(figsize=(7, 4))
|
| 120 |
+
plt.plot([i+1 for i in range(NUM_EPOCHS)],
|
| 121 |
+
history.history['loss'], '-o', c='k', lw=2, markersize=9)
|
| 122 |
+
plt.grid(True)
|
| 123 |
+
plt.title("Training loss with epochs\n", fontsize=18)
|
| 124 |
+
plt.xlabel("Training epochs", fontsize=15)
|
| 125 |
+
plt.ylabel("Training loss", fontsize=15)
|
| 126 |
+
plt.xticks(fontsize=15)
|
| 127 |
+
plt.yticks(fontsize=15)
|
| 128 |
+
plt.show()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def save_history(history, filename="./history.json"):
|
| 132 |
+
'''Saves the given training history as a .json file.'''
|
| 133 |
+
# Get the dictionary containing each metric and the loss for each epoch
|
| 134 |
+
history_dict = history.history
|
| 135 |
+
# Save it under the form of a json file
|
| 136 |
+
json.dump(history_dict, open(filename, 'w'))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def load_history(filename="./history.json"):
|
| 140 |
+
'''Loads training history from the path to a .json file. Returns a dict.'''
|
| 141 |
+
with open(filename) as json_file:
|
| 142 |
+
data = json.load(json_file)
|
| 143 |
+
return data
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_model(model):
|
| 147 |
+
'''Tests the given model on the test set and prints its accuracy.
|
| 148 |
+
Does not return anything.'''
|
| 149 |
+
testdir = DATA_FOLDER + 'test'
|
| 150 |
+
|
| 151 |
+
# pieces = ['Empty', 'Rook', 'Knight', 'Bishop', 'Queen', 'Pawn', 'King']
|
| 152 |
+
pieces = ['Empty', 'Rook_White', 'Rook_Black', 'Knight_White', 'Knight_Black', 'Bishop_White',
|
| 153 |
+
'Bishop_Black', 'Queen_White', 'Queen_Black', 'King_White', 'King_Black', 'Pawn_White', 'Pawn_Black']
|
| 154 |
+
pieces.sort()
|
| 155 |
+
score = 0
|
| 156 |
+
total_size = 0
|
| 157 |
+
for subdir, dirs, files in os.walk(testdir):
|
| 158 |
+
for file in files:
|
| 159 |
+
if file == ".DS_Store":
|
| 160 |
+
continue
|
| 161 |
+
piece = subdir.split('/')[-1]
|
| 162 |
+
path = os.path.join(subdir, file)
|
| 163 |
+
y_prob = model.predict(cv2.imread(path).reshape(1, 300, 150, 3))
|
| 164 |
+
y_pred = y_prob.argmax()
|
| 165 |
+
if y_pred < 0 or y_pred >= len(pieces):
|
| 166 |
+
print(y_pred, y_prob)
|
| 167 |
+
if piece == pieces[y_pred]:
|
| 168 |
+
score += 1
|
| 169 |
+
total_size += 1
|
| 170 |
+
print("TEST SET ACCURACY:", score/total_size)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__ == '__main__':
|
| 174 |
+
train_generator, validation_generator = create_generators(DATA_FOLDER)
|
| 175 |
+
model = create_model()
|
| 176 |
+
history = fit_model(model, train_generator,
|
| 177 |
+
validation_generator, save=False)
|
| 178 |
+
save_history(history, "./history.json")
|
| 179 |
+
plot_accuracy(history)
|
| 180 |
+
plot_loss(history)
|
| 181 |
+
test_model(model)
|
| 182 |
+
model.save_weights('./model_weights.h5')
|
utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config import *
|
| 2 |
+
from time import time
|
| 3 |
+
from copy import copy
|
| 4 |
+
|
| 5 |
+
import functools, os, re
|
| 6 |
+
import sys, cv2, math, numpy as np
|
| 7 |
+
na = np.array
|
| 8 |
+
|
| 9 |
+
################################################################################
|
| 10 |
+
|
| 11 |
+
rows, columns = os.popen('stty size', 'r').read().split()
|
| 12 |
+
__strip_ansi_re = re.compile(r"""
|
| 13 |
+
\x1b # literal ESC
|
| 14 |
+
\[ # literal [
|
| 15 |
+
[;\d]* # zero or more digits or semicolons
|
| 16 |
+
[A-Za-z] # a letter
|
| 17 |
+
""", re.VERBOSE).sub
|
| 18 |
+
def __strip_ansi(s):
|
| 19 |
+
return __strip_ansi_re("", s)
|
| 20 |
+
|
| 21 |
+
################################################################################
|
| 22 |
+
|
| 23 |
+
def clock():
|
| 24 |
+
global NC_CLOCK; return "(%8s)s" % round((time() - NC_CLOCK), 3)
|
| 25 |
+
def reset(): global NC_CLOCK; NC_CLOCK = time()
|
| 26 |
+
|
| 27 |
+
def warn(msg): print("\x1b[0;33;40m warn: \x1b[4;33;40m" + msg + "\x1b[0m")
|
| 28 |
+
def errn(msg): print("\n\x1b[0;37;41m errn: " + msg + "\x1b[0m\n"); sys.exit(1)
|
| 29 |
+
|
| 30 |
+
def head(msg): return "\x1b[5;30;43m " + msg + " \x1b[0m"
|
| 31 |
+
def call(msg): return "--> \x1b[5;31;40m@" + msg + "\x1b[0m"
|
| 32 |
+
|
| 33 |
+
def ribb(*msg, sep='-'):
|
| 34 |
+
msg = ' '.join(msg)
|
| 35 |
+
return msg + sep * int(int(columns) - len(__strip_ansi(msg)))
|
| 36 |
+
|
| 37 |
+
################################################################################
|
| 38 |
+
|
| 39 |
+
def image_scale(pts, scale):
|
| 40 |
+
"""scale to original image size"""
|
| 41 |
+
def __loop(x, y): return [x[0] * y, x[1] * y]
|
| 42 |
+
return list(map(functools.partial(__loop, y=1/scale), pts))
|
| 43 |
+
|
| 44 |
+
def image_resize(img, height=500):
|
| 45 |
+
"""resize image to same normalized area (height**2)"""
|
| 46 |
+
pixels = height * height; shape = list(np.shape(img))
|
| 47 |
+
scale = math.sqrt(float(pixels)/float(shape[0]*shape[1]))
|
| 48 |
+
shape[0] *= scale; shape[1] *= scale
|
| 49 |
+
img = cv2.resize(img, (int(shape[1]), int(shape[0])))
|
| 50 |
+
img_shape = np.shape(img)
|
| 51 |
+
return img, img_shape, scale
|
| 52 |
+
|
| 53 |
+
def image_transform(img, points, square_length=150):
|
| 54 |
+
"""crop original image using perspective warp"""
|
| 55 |
+
board_length = square_length * 8
|
| 56 |
+
def __dis(a, b): return np.linalg.norm(na(a)-na(b))
|
| 57 |
+
def __shi(seq, n=0): return seq[-(n % len(seq)):] + seq[:-(n % len(seq))]
|
| 58 |
+
best_idx, best_val = 0, 10**6
|
| 59 |
+
for idx, val in enumerate(points):
|
| 60 |
+
val = __dis(val, [0, 0])
|
| 61 |
+
if val < best_val:
|
| 62 |
+
best_idx, best_val = idx, val
|
| 63 |
+
pts1 = np.float32(__shi(points, 4 - best_idx))
|
| 64 |
+
pts2 = np.float32([[0, 0], [board_length, 0], \
|
| 65 |
+
[board_length, board_length], [0, board_length]])
|
| 66 |
+
M = cv2.getPerspectiveTransform(pts1, pts2)
|
| 67 |
+
W = cv2.warpPerspective(img, M, (board_length, board_length))
|
| 68 |
+
return W
|
| 69 |
+
|
| 70 |
+
class ImageObject(object):
|
| 71 |
+
images = {}; scale = 1; shape = (0, 0)
|
| 72 |
+
|
| 73 |
+
def __init__(self, img):
|
| 74 |
+
"""save and prepare image array"""
|
| 75 |
+
self.images['orig'] = img
|
| 76 |
+
self.images['main'], self.shape, self.scale = \
|
| 77 |
+
image_resize(img) # downscale for speed
|
| 78 |
+
self.images['test'] = copy(self.images['main'])
|
| 79 |
+
|
| 80 |
+
def __getitem__(self, attr):
|
| 81 |
+
"""return image as array"""
|
| 82 |
+
return self.images[attr]
|
| 83 |
+
|
| 84 |
+
def __setitem__(self, attr, val):
|
| 85 |
+
"""save image to object"""
|
| 86 |
+
self.images[attr] = val
|
| 87 |
+
|
| 88 |
+
def crop(self, pts):
|
| 89 |
+
"""crop using 4 points transform"""
|
| 90 |
+
pts_orig = image_scale(pts, self.scale)
|
| 91 |
+
img_crop = image_transform(self.images['orig'], pts_orig)
|
| 92 |
+
self.__init__(img_crop)
|