nathbns committed on
Commit
6e2dee6
·
verified ·
1 Parent(s): 14a99d7

Upload 11 files

Browse files
Files changed (11) hide show
  1. app.py +187 -0
  2. deps/__init__.py +2 -0
  3. deps/geometry.py +1164 -0
  4. deps/laps.py +178 -0
  5. llr.py +307 -0
  6. preprocess.py +74 -0
  7. requirements_hf.txt +7 -0
  8. rescale.py +48 -0
  9. slid.py +211 -0
  10. train_tensorflow.py +182 -0
  11. utils.py +92 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Chess Board Analyzer
3
+ EXACTLY uses main.py logic - no modifications
4
+ """
5
+
6
+ import gradio as gr
7
+ import chess
8
+ import chess.svg
9
+ import io
10
+ import numpy as np
11
+ import cv2
12
+ import os
13
+ import tempfile
14
+ from pathlib import Path
15
+
16
+ # Import EXACT SAME functions from main.py
17
+ from preprocess import preprocess_image
18
+ from train_tensorflow import create_model
19
+
20
# Class names produced by the classifier.
# NOTE: the list is sorted so indices line up with the label ordering used
# during training — presumably alphabetical directory order; TODO confirm
# against train_tensorflow.py.
PIECES = ['Empty', 'Rook_White', 'Rook_Black', 'Knight_White', 'Knight_Black', 'Bishop_White',
          'Bishop_Black', 'Queen_White', 'Queen_Black', 'King_White', 'King_Black', 'Pawn_White', 'Pawn_Black']
PIECES.sort()

# Map classifier label -> FEN piece letter
# (upper-case = white, lower-case = black, '.' = empty square).
LABELS = {
    'Empty': '.',
    'Rook_White': 'R',
    'Rook_Black': 'r',
    'Knight_White': 'N',
    'Knight_Black': 'n',
    'Bishop_White': 'B',
    'Bishop_Black': 'b',
    'Queen_White': 'Q',
    'Queen_Black': 'q',
    'King_White': 'K',
    'King_Black': 'k',
    'Pawn_White': 'P',
    'Pawn_Black': 'p',
}

# Load model at startup (EXACT SAME as main.py)
# Weights file must sit next to this script.
print("⏳ Loading model...")
model = create_model()
model.load_weights('./model_weights.h5')
print("✅ Model loaded!")
45
+
46
+
47
def classify_image(img):
    """Classify one square crop and return its PIECES label.

    (EXACT COPY from main.py — the model expects a 300x150 RGB crop.)
    """
    probabilities = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
    best = probabilities.argmax()
    return PIECES[best]
52
+
53
+
54
def analyze_board(img):
    """Split the rectified board image into an 8x8 grid, classify every
    square, and return an 8x8 list of FEN piece letters.

    Each crop is two cells tall (2*M x N) so tall pieces that overflow their
    square upwards stay visible to the classifier; the top rank is zero-padded.

    Based on main.py, with one fix: the row loop bounded itself with the image
    *width* (``img.shape[1]``) instead of its *height* (``img.shape[0]``) —
    harmless for square boards, wrong for non-square input.
    """
    arr = []
    M = img.shape[0] // 8   # cell height
    N = img.shape[1] // 8   # cell width
    # Bottom edge of each rank of cells: M-1, 2M-1, ..., 8M-1.
    for y in range(M - 1, img.shape[0], M):
        row = []
        for x in range(0, img.shape[1], N):
            # Crop two cell-heights above the bottom edge of this rank.
            sub_img = img[max(0, y - 2 * M):y, x:x + N]
            if y - 2 * M < 0:
                # Top rank: zero-pad so every crop is exactly 2*M tall.
                sub_img = np.concatenate(
                    (np.zeros((2 * M - y, N, 3)), sub_img))
            sub_img = sub_img.astype(np.uint8)

            piece = classify_image(sub_img)
            row.append(LABELS[piece])
        arr.append(row)

    # King/Queen heuristic: the model sometimes mistakes a king for a queen.
    # If a side has a queen but no king, relabel that queen as the king.
    blackKing = False
    whiteKing = False
    whitePos = (-1, -1)   # position of a detected white queen
    blackPos = (-1, -1)   # position of a detected black queen
    for i in range(8):
        for j in range(8):
            if arr[i][j] == 'K':
                whiteKing = True
            if arr[i][j] == 'k':
                blackKing = True
            if arr[i][j] == 'Q':
                whitePos = (i, j)
            if arr[i][j] == 'q':
                blackPos = (i, j)
    if not whiteKing and whitePos[0] >= 0:
        arr[whitePos[0]][whitePos[1]] = 'K'
    if not blackKing and blackPos[0] >= 0:
        arr[blackPos[0]][blackPos[1]] = 'k'

    return arr
93
+
94
+
95
def board_to_fen(board):
    """Serialise an 8x8 array of piece letters ('.' = empty) into FEN.

    Side to move, castling rights and move counters are fixed defaults,
    since they cannot be inferred from a single image.
    """
    ranks = []
    for rank in board:
        parts = []
        run = 0
        for cell in rank:
            if cell == '.':
                run += 1
            else:
                if run:
                    parts.append(str(run))
                    run = 0
                parts.append(cell)
        if run:
            parts.append(str(run))
        ranks.append(''.join(parts))
    return '/'.join(ranks) + ' w KQkq - 0 1'
114
+
115
+
116
def analyze_chess_image(image_input):
    """Gradio wrapper around main.py logic.

    Accepts a PIL image or numpy array, runs the pipeline
    (LAPS preprocessing -> per-square CNN -> FEN) and returns a
    ``(status message, board SVG)`` pair.

    Fix over the original: the temporary file was only unlinked on the
    success path, so every failed analysis leaked a temp file.  Cleanup now
    happens in ``finally``.
    """
    if image_input is None:
        return "❌ No image provided", None

    temp_path = None
    try:
        # Save to temp file (needed for preprocess_image which expects file path)
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            if isinstance(image_input, np.ndarray):
                # Gradio delivers RGB; OpenCV writes BGR.
                cv2.imwrite(tmp.name, cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
            else:
                image_input.save(tmp.name)
            temp_path = tmp.name

        # EXACT SAME as main.py: preprocess_image() uses LAPS!
        img = preprocess_image(temp_path, save=False)

        # EXACT SAME as main.py
        arr = analyze_board(img)
        fen = board_to_fen(arr)

        # Generate board visualization
        board = chess.Board(fen)
        board_svg = chess.svg.board(board=board, size=400)

        return f"✅ FEN: {fen}", board_svg

    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"❌ Error: {str(e)}", None
    finally:
        # Remove the temp file on every path (it leaked on errors before).
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
150
+
151
+
152
# Build Gradio interface.
# Layout: left column = upload + button, right column = status text + SVG board.
with gr.Blocks(title="Chess Board Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ♟️ Chess Board Analyzer

    Upload a chess board image to automatically detect all pieces and get the FEN notation.

    **Uses EXACT SAME preprocessing (LAPS) and model as main.py**
    """)

    with gr.Row():
        with gr.Column():
            # type="pil" so analyze_chess_image receives a PIL image (or ndarray).
            image_input = gr.Image(label="📸 Upload chess board image", type="pil")
            submit_btn = gr.Button("🔍 Analyze Board", size="lg", variant="primary")

        with gr.Column():
            status_output = gr.Textbox(label="Result", interactive=False, lines=2)
            # Receives the SVG string rendered by chess.svg.board().
            board_output = gr.HTML(label="Board Visualization")

    # Wire the button to the full analysis pipeline.
    submit_btn.click(
        fn=analyze_chess_image,
        inputs=image_input,
        outputs=[status_output, board_output]
    )

    gr.Markdown("""
    ### Model Info:
    - **Preprocessing**: LAPS (Lattice Point Detection + Perspective Correction)
    - **Architecture**: CNN with 5 convolutional layers
    - **Accuracy**: ~96% on test set
    - **Classes**: 13 types (Empty + 6 White + 6 Black pieces)
    """)


if __name__ == "__main__":
    demo.launch()
deps/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import geometry
2
+ from . import laps
deps/geometry.py ADDED
@@ -0,0 +1,1164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/ideasman42/isect_segments-bentley_ottmann
3
+ """
4
+
5
+ # BentleyOttmann sweep-line implementation
6
+ # (for finding all intersections in a set of line segments)
7
+
8
+ __all__ = (
9
+ "isect_segments",
10
+ "isect_polygon",
11
+
12
+ # for testing only (correct but slow)
13
+ "isect_segments__naive",
14
+ "isect_polygon__naive",
15
+ )
16
+
17
+ # ----------------------------------------------------------------------------
18
+ # Main Poly Intersection
19
+
20
+ # Defines to change behavior.
21
+ #
22
+ # Whether to ignore intersections of line segments when both
23
+ # their end points form the intersection point.
24
+ USE_IGNORE_SEGMENT_ENDINGS = True
25
+
26
+ USE_DEBUG = False # FIXME
27
+
28
+ USE_VERBOSE = False
29
+
30
+ # checks we should NOT need,
31
+ # but do them in case we find a test-case that fails.
32
+ USE_PARANOID = False
33
+
34
+ # Support vertical segments,
35
+ # (the bentley-ottmann method doesn't support this).
36
+ # We use the term 'START_VERTICAL' for a vertical segment,
37
+ # to differentiate it from START/END/INTERSECTION
38
+ USE_VERTICAL = True
39
+ # end defines!
40
+ # ------------
41
+
42
+ # ---------
43
+ # Constants
44
+ X, Y = 0, 1
45
+ EPS = 1e-10
46
+ EPS_SQ = EPS * EPS
47
+ INF = float("inf")
48
+
49
+
50
class Event:
    """One sweep-line event: a segment START or END, an INTERSECTION point,
    or (when USE_VERTICAL) the start of a vertical segment."""

    __slots__ = (
        "type",
        "point",
        "segment",

        # this is just cache,
        # we may remove or calculate slope on the fly
        "slope",
        "span",
    ) + (() if not USE_DEBUG else (
        # debugging only
        "other",
        "in_sweep",
    ))

    class Type:
        # NOTE(review): the numeric values double as bucket indices in
        # EventQueue.offer, so END events are handled before INTERSECTION
        # before START at the same point.
        END = 0
        INTERSECTION = 1
        START = 2
        if USE_VERTICAL:
            START_VERTICAL = 3

    def __init__(self, type, point, segment, slope):
        # point must be a tuple: it is used as a (hashable) dict key.
        assert(isinstance(point, tuple))
        self.type = type
        self.point = point
        self.segment = segment

        # will be None for INTERSECTION
        self.slope = slope
        if segment is not None:
            # cached horizontal extent, used by y_intercept_x interpolation
            self.span = segment[1][X] - segment[0][X]

        if USE_DEBUG:
            self.other = None
            self.in_sweep = False

    def is_vertical(self):
        # True when both endpoints share the same X coordinate.
        return self.segment[0][X] == self.segment[1][X]

    def y_intercept_x(self, x: float):
        """Y of this event's segment at sweep position *x*, clamped to the
        endpoints; None for vertical segments."""
        # vertical events only for comparison (above_all check)
        # never added into the binary-tree its self
        if USE_VERTICAL:
            if self.is_vertical():
                return None

        if x <= self.segment[0][X]:
            return self.segment[0][Y]
        elif x >= self.segment[1][X]:
            return self.segment[1][Y]

        # use the largest to avoid float precision error with nearly vertical lines.
        delta_x0 = x - self.segment[0][X]
        delta_x1 = self.segment[1][X] - x
        if delta_x0 > delta_x1:
            ifac = delta_x0 / self.span
            fac = 1.0 - ifac
        else:
            fac = delta_x1 / self.span
            ifac = 1.0 - fac
        assert(fac <= 1.0)
        return (self.segment[0][Y] * fac) + (self.segment[1][Y] * ifac)

    @staticmethod
    def Compare(sweep_line, this, that):
        """Total order of events in the sweep tree: by Y at the current
        sweep X, breaking ties by slope, then by segment endpoints."""
        if this is that:
            return 0
        if USE_DEBUG:
            if this.other is that:
                return 0
        current_point_x = sweep_line._current_event_point_x
        ipthis = this.y_intercept_x(current_point_x)
        ipthat = that.y_intercept_x(current_point_x)
        # print(ipthis, ipthat)
        if USE_VERTICAL:
            # vertical segments have no single intercept; fall back to
            # the event's own point
            if ipthis is None:
                ipthis = this.point[Y]
            if ipthat is None:
                ipthat = that.point[Y]

        delta_y = ipthis - ipthat

        assert((delta_y < 0.0) == (ipthis < ipthat))
        # NOTE, VERY IMPORTANT TO USE EPSILON HERE!
        # otherwise w/ float precision errors we get incorrect comparisons
        # can get very strange & hard to debug output without this.
        if abs(delta_y) > EPS:
            return -1 if (delta_y < 0.0) else 1
        else:
            # (nearly) equal Y: order by slope; the sense flips depending
            # on whether the sweep is just before or after the event point.
            this_slope = this.slope
            that_slope = that.slope
            if this_slope != that_slope:
                if sweep_line._before:
                    return -1 if (this_slope > that_slope) else 1
                else:
                    return 1 if (this_slope > that_slope) else -1

        # Identical slope and Y: fall back to comparing endpoints so the
        # ordering stays deterministic.
        delta_x_p1 = this.segment[0][X] - that.segment[0][X]
        if delta_x_p1 != 0.0:
            return -1 if (delta_x_p1 < 0.0) else 1

        delta_x_p2 = this.segment[1][X] - that.segment[1][X]
        if delta_x_p2 != 0.0:
            return -1 if (delta_x_p2 < 0.0) else 1

        return 0

    def __repr__(self):
        return ("Event(0x%x, s0=%r, s1=%r, p=%r, type=%d, slope=%r)" % (
            id(self),
            self.segment[0], self.segment[1],
            self.point,
            self.type,
            self.slope,
        ))
167
+
168
+
169
class SweepLine:
    """State of the Bentley-Ottmann sweep: the set of segments crossing the
    current sweep position (ordered by Y) plus all intersections found."""

    __slots__ = (
        # A map holding all intersection points mapped to the Events
        # that form these intersections.
        # {Point: set(Event, ...), ...}
        "intersections",
        "queue",

        # Events (sorted set of ordered events, no values)
        #
        # note: START & END events are considered the same so checking if an event is in the tree
        # will return true if its opposite side is found.
        # This is essential for the algorithm to work, and why we don't explicitly remove START events.
        # Instead, the END events are never added to the current sweep, and removing them also removes the start.
        "_events_current_sweep",
        # The point of the current Event.
        "_current_event_point_x",
        # A flag to indicate if we're slightly before or after the line.
        "_before",
    )

    def __init__(self):
        self.intersections = {}

        self._current_event_point_x = None
        # Tree ordered by Event.Compare, which reads this SweepLine's
        # current X via cmp_data.
        self._events_current_sweep = RBTree(cmp=Event.Compare, cmp_data=self)
        self._before = True

    def get_intersections(self):
        # Intersection points only; the per-point event sets stay internal.
        return list(self.intersections.keys())

    # Checks if an intersection exists between two Events 'a' and 'b'.
    def _check_intersection(self, a: Event, b: Event):
        # Return immediately in case either of the events is null, or
        # if one of them is an INTERSECTION event.
        if ((a is None or b is None) or
                (a.type == Event.Type.INTERSECTION) or
                (b.type == Event.Type.INTERSECTION)):

            return

        if a is b:
            return

        # Get the intersection point between 'a' and 'b'.
        p = isect_seg_seg_v2_point(
            a.segment[0], a.segment[1],
            b.segment[0], b.segment[1])

        # No intersection exists.
        if p is None:
            return

        # If the intersection is formed by both the segment endings, AND
        # USE_IGNORE_SEGMENT_ENDINGS is true,
        # return from this method.
        if USE_IGNORE_SEGMENT_ENDINGS:
            if ((len_squared_v2v2(p, a.segment[0]) < EPS_SQ or
                 len_squared_v2v2(p, a.segment[1]) < EPS_SQ) and
                    (len_squared_v2v2(p, b.segment[0]) < EPS_SQ or
                     len_squared_v2v2(p, b.segment[1]) < EPS_SQ)):

                return

        # Add the intersection.
        events_for_point = self.intersections.pop(p, set())
        is_new = len(events_for_point) == 0
        events_for_point.add(a)
        events_for_point.add(b)
        self.intersections[p] = events_for_point

        # If the intersection occurs to the right of the sweep line, OR
        # if the intersection is on the sweep line and it's above the
        # current event-point, add it as a new Event to the queue.
        if is_new and p[X] >= self._current_event_point_x:
            event_isect = Event(Event.Type.INTERSECTION, p, None, None)
            self.queue.offer(p, event_isect)

    def _sweep_to(self, p):
        # Advance the sweep position to p's X (no-op for repeated X).
        if p[X] == self._current_event_point_x:
            # happens in rare cases,
            # we can safely ignore
            return

        self._current_event_point_x = p[X]

    def insert(self, event):
        assert(event not in self._events_current_sweep)
        # vertical segments are never stored in the sweep tree
        assert(event.type != Event.Type.START_VERTICAL)
        if USE_DEBUG:
            assert(event.in_sweep == False)
            assert(event.other.in_sweep == False)

        self._events_current_sweep.insert(event, None)

        if USE_DEBUG:
            event.in_sweep = True
            event.other.in_sweep = True

    def remove(self, event):
        # Returns True when the event was present and removed, False otherwise.
        try:
            self._events_current_sweep.remove(event)
            if USE_DEBUG:
                assert(event.in_sweep == True)
                assert(event.other.in_sweep == True)
                event.in_sweep = False
                event.other.in_sweep = False
            return True
        except KeyError:
            if USE_DEBUG:
                assert(event.in_sweep == False)
                assert(event.other.in_sweep == False)
            return False

    def above(self, event):
        # Neighbour just above 'event' in the sweep order, or None.
        return self._events_current_sweep.succ_key(event, None)

    def below(self, event):
        # Neighbour just below 'event' in the sweep order, or None.
        return self._events_current_sweep.prev_key(event, None)

    '''
    def above_all(self, event):
        while True:
            event = self.above(event)
            if event is None:
                break
            yield event
    '''

    def above_all(self, event):
        # assert(event not in self._events_current_sweep)
        return self._events_current_sweep.key_slice(event, None, reverse=False)

    def handle(self, p, events_current):
        """Process every event in one type-bucket at point *p*."""
        if len(events_current) == 0:
            return
        # done already
        # self._sweep_to(events_current[0])
        assert(p[0] == self._current_event_point_x)

        if not USE_IGNORE_SEGMENT_ENDINGS:
            # All events at the same point pairwise intersect there.
            if len(events_current) > 1:
                for i in range(0, len(events_current) - 1):
                    for j in range(i + 1, len(events_current)):
                        self._check_intersection(
                            events_current[i], events_current[j])

        for e in events_current:
            self.handle_event(e)

    def handle_event(self, event):
        """Dispatch one event: update the sweep tree and check newly
        adjacent segments for intersections."""
        t = event.type
        if t == Event.Type.START:
            # print("  START")
            self._before = False
            self.insert(event)

            e_above = self.above(event)
            e_below = self.below(event)

            self._check_intersection(event, e_above)
            self._check_intersection(event, e_below)
            if USE_PARANOID:
                self._check_intersection(e_above, e_below)

        elif t == Event.Type.END:
            # print("  END")
            self._before = True

            e_above = self.above(event)
            e_below = self.below(event)

            self.remove(event)

            # The segments above and below become adjacent once this one leaves.
            self._check_intersection(e_above, e_below)
            if USE_PARANOID:
                self._check_intersection(event, e_above)
                self._check_intersection(event, e_below)

        elif t == Event.Type.INTERSECTION:
            # print("  INTERSECTION")
            # Segments crossing here swap order: remove, then reinsert
            # just after the point so the tree reflects the new order.
            self._before = True
            event_set = self.intersections[event.point]
            # note: events_current aren't sorted.
            reinsert_stack = []  # Stack
            for e in event_set:
                # If the Event was not already removed,
                # we want to insert it later on.
                if self.remove(e):
                    reinsert_stack.append(e)
            self._before = False

            # Insert all Events that we were able to remove.
            while reinsert_stack:
                e = reinsert_stack.pop()

                self.insert(e)

                e_above = self.above(e)
                e_below = self.below(e)

                self._check_intersection(e, e_above)
                self._check_intersection(e, e_below)
                if USE_PARANOID:
                    self._check_intersection(e_above, e_below)
        elif (USE_VERTICAL and
              (t == Event.Type.START_VERTICAL)):

            # just check sanity
            assert(event.segment[0][X] == event.segment[1][X])
            assert(event.segment[0][Y] <= event.segment[1][Y])

            # In this case we only need to find all segments in this span.
            y_above_max = event.segment[1][Y]

            # self.insert(event)
            for e_above in self.above_all(event):
                if e_above.type == Event.Type.START_VERTICAL:
                    continue
                y_above = e_above.y_intercept_x(
                    self._current_event_point_x)
                if USE_IGNORE_SEGMENT_ENDINGS:
                    if y_above >= y_above_max:
                        break
                else:
                    if y_above > y_above_max:
                        break

                # We know this intersects,
                # so we could use a faster function now:
                # ix = (self._current_event_point_x, y_above)
                # ...however best use existing functions
                # since it does all sanity checks on endpoints... etc.
                self._check_intersection(event, e_above)

            # self.remove(event)
405
+
406
+
407
class EventQueue:
    """Priority queue of event points ordered left-to-right, each point
    holding one event list per event type."""

    __slots__ = (
        # note: we only ever pop_min, this could use a 'heap' structure.
        # The sorted map holding the points -> event list
        # [Point: Event] (tree)
        "events_scan",
    )

    def __init__(self, segments, line: SweepLine):
        self.events_scan = RBTree()
        # segments = [s for s in segments if s[0][0] != s[1][0] and s[0][1] != s[1][1]]

        for s in segments:
            # caller must have ordered each segment left-to-right
            assert(s[0][X] <= s[1][X])

            slope = slope_v2v2(*s)

            if s[0] == s[1]:
                # degenerate zero-length segment: skip entirely
                pass
            elif USE_VERTICAL and (s[0][X] == s[1][X]):
                e_start = Event(Event.Type.START_VERTICAL, s[0], s, slope)

                if USE_DEBUG:
                    e_start.other = e_start  # FAKE, avoid error checking

                self.offer(s[0], e_start)
            else:
                e_start = Event(Event.Type.START, s[0], s, slope)
                e_end = Event(Event.Type.END, s[1], s, slope)

                if USE_DEBUG:
                    e_start.other = e_end
                    e_end.other = e_start

                self.offer(s[0], e_start)
                self.offer(s[1], e_end)

        # Back-reference so the sweep line can push INTERSECTION events
        # into this queue as they are discovered.
        line.queue = self

    def offer(self, p, e: Event):
        """
        Offer a new event ``s`` at point ``p`` in this queue.
        """
        # One bucket per event type (bucket index == Event.Type value),
        # so ENDs are processed before INTERSECTIONs before STARTs.
        existing = self.events_scan.setdefault(
            p, ([], [], [], []) if USE_VERTICAL else
            ([], [], []))
        # Can use double linked-list for easy insertion at beginning/end
        '''
        if e.type == Event.Type.END:
            existing.insert(0, e)
        else:
            existing.append(e)
        '''

        existing[e.type].append(e)

    # return a set of events
    def poll(self):
        """
        Get, and remove, the first (lowest) item from this queue.

        :return: the first (lowest) item from this queue.
        :rtype: Point, Event pair.
        """
        assert(len(self.events_scan) != 0)
        p, events_current = self.events_scan.pop_min()
        return p, events_current
474
+
475
+
476
def isect_segments(segments) -> list:
    """Return all intersection points among *segments* (Bentley-Ottmann)."""
    # Normalise every segment so its endpoints are ordered left-to-right;
    # tuple comparison also orders vertical segments bottom-to-top.
    ordered = [
        (s[0], s[1]) if (s[0] <= s[1]) else (s[1], s[0])
        for s in segments
    ]

    sweep = SweepLine()
    queue = EventQueue(ordered, sweep)

    # Drain the event queue left-to-right, handling each non-empty
    # per-type bucket at every event point.
    while len(queue.events_scan) > 0:
        if USE_VERBOSE:
            print(len(queue.events_scan), sweep._current_event_point_x)
        point, buckets = queue.poll()
        for bucket in buckets:
            if bucket:
                sweep._sweep_to(point)
                sweep.handle(point, bucket)

    return sweep.get_intersections()
498
+
499
+
500
def isect_polygon(points) -> list:
    """Return the self-intersections of the closed polygon *points*."""
    count = len(points)
    # Build one segment per edge, wrapping the last vertex back to the first.
    edges = []
    for i in range(count):
        edges.append((tuple(points[i]), tuple(points[(i + 1) % count])))
    return isect_segments(edges)
506
+
507
+
508
+ # ----------------------------------------------------------------------------
509
+ # 2D math utilities
510
+
511
+
512
def slope_v2v2(p1, p2):
    """Slope of the segment p1->p2.

    Vertical segments return +inf when pointing up (p1 below p2),
    -inf otherwise (including zero-length segments).
    """
    run = p2[0] - p1[0]
    if run == 0:
        return float("inf") if p1[1] < p2[1] else float("-inf")
    return (p2[1] - p1[1]) / run
520
+
521
+
522
def sub_v2v2(a, b):
    """Component-wise 2D vector subtraction ``a - b``."""
    return (a[0] - b[0], a[1] - b[1])
526
+
527
+
528
def dot_v2v2(a, b):
    """2D dot product of vectors *a* and *b*."""
    return a[0] * b[0] + a[1] * b[1]
532
+
533
+
534
def len_squared_v2v2(a, b):
    """Squared Euclidean distance between points *a* and *b*."""
    delta = sub_v2v2(a, b)
    return dot_v2v2(delta, delta)
537
+
538
+
539
def line_point_factor_v2(p, l1, l2, default=0.0):
    """Parametric position of *p* projected onto the line l1->l2
    (0.0 at l1, 1.0 at l2); returns *default* when l1 == l2."""
    direction = sub_v2v2(l2, l1)
    denom = dot_v2v2(direction, direction)
    if denom == 0.0:
        return default
    return dot_v2v2(direction, sub_v2v2(p, l1)) / denom
544
+
545
+
546
def isect_seg_seg_v2_point(v1, v2, v3, v4, bias=0.0):
    """Intersection point of segments (v1, v2) and (v3, v4), or None.

    *bias* relaxes the inclusion test at the segment ends.
    """
    # Only for predictability and hashable point when same input is given:
    # canonicalize the argument order so the same pair of segments always
    # yields a bit-identical point (points are used as dict keys upstream).
    if v1 > v2:
        v1, v2 = v2, v1
    if v3 > v4:
        v3, v4 = v4, v3

    if (v1, v2) > (v3, v4):
        v1, v2, v3, v4 = v3, v4, v1, v2

    # Denominator of the line-line intersection formula;
    # zero means parallel (or degenerate) lines.
    div = (v2[0] - v1[0]) * (v4[1] - v3[1]) - (v2[1] - v1[1]) * (v4[0] - v3[0])
    if div == 0.0:
        return None

    # Intersection of the two infinite lines (Cramer's rule).
    vi = (((v3[0] - v4[0]) *
           (v1[0] * v2[1] - v1[1] * v2[0]) - (v1[0] - v2[0]) *
           (v3[0] * v4[1] - v3[1] * v4[0])) / div,
          ((v3[1] - v4[1]) *
           (v1[0] * v2[1] - v1[1] * v2[0]) - (v1[1] - v2[1]) *
           (v3[0] * v4[1] - v3[1] * v4[0])) / div,
          )

    # Reject points outside either segment's parametric extent [0, 1].
    fac = line_point_factor_v2(vi, v1, v2, default=-1.0)
    if fac < 0.0 - bias or fac > 1.0 + bias:
        return None

    fac = line_point_factor_v2(vi, v3, v4, default=-1.0)
    if fac < 0.0 - bias or fac > 1.0 + bias:
        return None

    # vi = round(vi[X], 8), round(vi[Y], 8)
    return vi
578
+
579
+
580
+ # ----------------------------------------------------------------------------
581
+ # Simple naive line intersect, (for testing only)
582
+
583
+
584
def isect_segments__naive(segments) -> list:
    """
    Brute force O(n2) version of ``isect_segments`` for test validation.
    """
    found = []

    # order points left -> right
    ordered = [
        (s[0], s[1]) if s[0][X] <= s[1][X] else (s[1], s[0])
        for s in segments
    ]

    total = len(ordered)

    for i in range(total):
        a0, a1 = ordered[i]
        for j in range(i + 1, total):
            b0, b1 = ordered[j]
            # Segments sharing an endpoint are skipped outright.
            if a0 in (b0, b1) or a1 in (b0, b1):
                continue
            hit = isect_seg_seg_v2_point(a0, a1, b0, b1)
            if hit is not None:
                # USE_IGNORE_SEGMENT_ENDINGS handled already
                found.append(hit)

    return found
609
+
610
+
611
def isect_polygon__naive(points) -> list:
    """
    Brute force O(n2) version of ``isect_polygon`` for test validation.
    """
    found = []
    count = len(points)

    for i in range(count):
        a0 = points[i]
        a1 = points[(i + 1) % count]
        for j in range(i + 1, count):
            b0 = points[j]
            b1 = points[(j + 1) % count]
            # Edges that share a vertex cannot form a proper crossing.
            if a0 in (b0, b1) or a1 in (b0, b1):
                continue
            ix = isect_seg_seg_v2_point(a0, a1, b0, b1)
            if ix is None:
                continue

            if USE_IGNORE_SEGMENT_ENDINGS:
                # Drop hits that merely touch both edges at their endpoints.
                if ((len_squared_v2v2(ix, a0) < EPS_SQ or
                     len_squared_v2v2(ix, a1) < EPS_SQ) and
                        (len_squared_v2v2(ix, b0) < EPS_SQ or
                         len_squared_v2v2(ix, b1) < EPS_SQ)):
                    continue

            found.append(ix)

    return found
637
+
638
+
639
+ # ----------------------------------------------------------------------------
640
+ # Inline Libs
641
+ #
642
+ # bintrees: 2.0.2, extracted from:
643
+ # http://pypi.python.org/pypi/bintrees
644
+ #
645
+ # - Removed unused functions, such as slicing and range iteration.
646
+ # - Added 'cmp' and and 'cmp_data' arguments,
647
+ # so we can define our own comparison that takes an arg.
648
+ # Needed for sweep-line.
649
+ # - Added support for 'default' arguments for prev_item/succ_item,
650
+ # so we can avoid exception handling.
651
+
652
+ # -------
653
+ # ABCTree
654
+
655
from operator import attrgetter
# Unique sentinel distinguishing "no default supplied" from an explicit None.
_sentinel = object()
657
+
658
+
659
+ class _ABCTree(object):
660
    def __init__(self, items=None, cmp=None, cmp_data=None):
        """T.__init__(...) initializes T; see T.__class__.__doc__ for signature

        *cmp* is a 3-argument comparator ``cmp(cmp_data, a, b)`` returning
        -1/0/1; when omitted, the keys' natural ``<``/``>`` ordering is used.
        """
        self._root = None
        self._count = 0
        if cmp is None:
            # Default comparator: natural key ordering.
            def cmp(cmp_data, a, b):
                if a < b:
                    return -1
                elif a > b:
                    return 1
                else:
                    return 0
        self._cmp = cmp
        self._cmp_data = cmp_data
        if items is not None:
            self.update(items)
676
+
677
+ def clear(self):
678
+ """T.clear() -> None. Remove all items from T."""
679
+ def _clear(node):
680
+ if node is not None:
681
+ _clear(node.left)
682
+ _clear(node.right)
683
+ node.free()
684
+ _clear(self._root)
685
+ self._count = 0
686
+ self._root = None
687
+
688
    @property
    def count(self):
        """Number of items currently stored in the tree."""
        return self._count
692
+
693
    def get_value(self, key):
        """Return the value stored under *key*; raise KeyError if absent."""
        # Standard BST descent driven by the tree's comparator.
        node = self._root
        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                return node.value
            elif cmp < 0:
                node = node.left
            else:
                node = node.right
        raise KeyError(str(key))
704
+
705
    def pop_item(self):
        """T.pop_item() -> (k, v), remove and return some (key, value) pair as a
        2-tuple; but raise KeyError if T is empty.
        """
        if self.is_empty():
            raise KeyError("pop_item(): tree is empty")
        # Walk down to an arbitrary leaf; removing a leaf key keeps the
        # subsequent remove() cheap.
        node = self._root
        while True:
            if node.left is not None:
                node = node.left
            elif node.right is not None:
                node = node.right
            else:
                break
        key = node.key
        value = node.value
        self.remove(key)
        return key, value
    popitem = pop_item  # for compatibility to dict()
724
+
725
+ def min_item(self):
726
+ """Get item with min key of tree, raises ValueError if tree is empty."""
727
+ if self.is_empty():
728
+ raise ValueError("Tree is empty")
729
+ node = self._root
730
+ while node.left is not None:
731
+ node = node.left
732
+ return node.key, node.value
733
+
734
+ def max_item(self):
735
+ """Get item with max key of tree, raises ValueError if tree is empty."""
736
+ if self.is_empty():
737
+ raise ValueError("Tree is empty")
738
+ node = self._root
739
+ while node.right is not None:
740
+ node = node.right
741
+ return node.key, node.value
742
+
743
    def succ_item(self, key, default=_sentinel):
        """Get successor (k,v) pair of key, raises KeyError if key is max key
        or key does not exist. optimized for pypy.
        """
        # removed graingets version, because it was little slower on CPython and much slower on pypy
        # this version runs about 4x faster with pypy than the Cython version
        # Note: Code sharing of succ_item() and ceiling_item() is possible, but has always a speed penalty.
        node = self._root
        succ_node = None
        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                break
            elif cmp < 0:
                # Every node passed on the way left is a successor candidate.
                if (succ_node is None) or self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
                    succ_node = node
                node = node.left
            else:
                node = node.right

        if node is None:  # stay at dead end
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        # found node of key
        if node.right is not None:
            # find smallest node of right subtree
            node = node.right
            while node.left is not None:
                node = node.left
            if succ_node is None:
                succ_node = node
            elif self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
                succ_node = node
        elif succ_node is None:  # given key is biggest in tree
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        return succ_node.key, succ_node.value
782
+
783
    def prev_item(self, key, default=_sentinel):
        """Get predecessor (k,v) pair of key, raises KeyError if key is min key
        or key does not exist. optimized for pypy.
        """
        # removed graingets version, because it was little slower on CPython and much slower on pypy
        # this version runs about 4x faster with pypy than the Cython version
        # Note: Code sharing of prev_item() and floor_item() is possible, but has always a speed penalty.
        node = self._root
        prev_node = None

        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                break
            elif cmp < 0:
                node = node.left
            else:
                # Every node passed on the way right is a predecessor candidate.
                if (prev_node is None) or self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
                    prev_node = node
                node = node.right

        if node is None:  # stay at dead end (None)
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        # found node of key
        if node.left is not None:
            # find biggest node of left subtree
            node = node.left
            while node.right is not None:
                node = node.right
            if prev_node is None:
                prev_node = node
            elif self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
                prev_node = node
        elif prev_node is None:  # given key is smallest in tree
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        return prev_node.key, prev_node.value
823
+
824
+ def __repr__(self):
825
+ """T.__repr__(...) <==> repr(x)"""
826
+ tpl = "%s({%s})" % (self.__class__.__name__, '%s')
827
+ return tpl % ", ".join(("%r: %r" % item for item in self.items()))
828
+
829
+ def __contains__(self, key):
830
+ """k in T -> True if T has a key k, else False"""
831
+ try:
832
+ self.get_value(key)
833
+ return True
834
+ except KeyError:
835
+ return False
836
+
837
    def __len__(self):
        """T.__len__() <==> len(x), number of items in the tree."""
        return self.count
840
+
841
    def is_empty(self):
        """T.is_empty() -> False if T contains any items else True"""
        return self.count == 0
844
+
845
+ def set_default(self, key, default=None):
846
+ """T.set_default(k[,d]) -> T.get(k,d), also set T[k]=d if k not in T"""
847
+ try:
848
+ return self.get_value(key)
849
+ except KeyError:
850
+ self.insert(key, default)
851
+ return default
852
+ setdefault = set_default # for compatibility to dict()
853
+
854
+ def get(self, key, default=None):
855
+ """T.get(k[,d]) -> T[k] if k in T, else d. d defaults to None."""
856
+ try:
857
+ return self.get_value(key)
858
+ except KeyError:
859
+ return default
860
+
861
+ def pop(self, key, *args):
862
+ """T.pop(k[,d]) -> v, remove specified key and return the corresponding value.
863
+ If key is not found, d is returned if given, otherwise KeyError is raised
864
+ """
865
+ if len(args) > 1:
866
+ raise TypeError("pop expected at most 2 arguments, got %d" % (1 + len(args)))
867
+ try:
868
+ value = self.get_value(key)
869
+ self.remove(key)
870
+ return value
871
+ except KeyError:
872
+ if len(args) == 0:
873
+ raise
874
+ else:
875
+ return args[0]
876
+
877
+ def prev_key(self, key, default=_sentinel):
878
+ """Get predecessor to key, raises KeyError if key is min key
879
+ or key does not exist.
880
+ """
881
+ item = self.prev_item(key, default)
882
+ return default if item is default else item[0]
883
+
884
+ def succ_key(self, key, default=_sentinel):
885
+ """Get successor to key, raises KeyError if key is max key
886
+ or key does not exist.
887
+ """
888
+ item = self.succ_item(key, default)
889
+ return default if item is default else item[0]
890
+
891
+ def pop_min(self):
892
+ """T.pop_min() -> (k, v), remove item with minimum key, raise ValueError
893
+ if T is empty.
894
+ """
895
+ item = self.min_item()
896
+ self.remove(item[0])
897
+ return item
898
+
899
+ def pop_max(self):
900
+ """T.pop_max() -> (k, v), remove item with maximum key, raise ValueError
901
+ if T is empty.
902
+ """
903
+ item = self.max_item()
904
+ self.remove(item[0])
905
+ return item
906
+
907
    def min_key(self):
        """Get min key of tree, raises ValueError if tree is empty. """
        return self.min_item()[0]
910
+
911
    def max_key(self):
        """Get max key of tree, raises ValueError if tree is empty. """
        return self.max_item()[0]
914
+
915
+ def key_slice(self, start_key, end_key, reverse=False):
916
+ """T.key_slice(start_key, end_key) -> key iterator:
917
+ start_key <= key < end_key.
918
+
919
+ Yields keys in ascending order if reverse is False else in descending order.
920
+ """
921
+ return (k for k, v in self.iter_items(start_key, end_key, reverse=reverse))
922
+
923
    def iter_items(self, start_key=None, end_key=None, reverse=False):
        """Iterates over the (key, value) items of the associated tree,
        in ascending order if reverse is False; if reverse is True, iterate
        in descending order. reverse defaults to False"""
        # optimized iterator (reduced method calls) - faster on CPython but slower on pypy

        if self.is_empty():
            return []
        if reverse:
            return self._iter_items_backward(start_key, end_key)
        else:
            return self._iter_items_forward(start_key, end_key)
935
+
936
+ def _iter_items_forward(self, start_key=None, end_key=None):
937
+ for item in self._iter_items(left=attrgetter("left"), right=attrgetter("right"),
938
+ start_key=start_key, end_key=end_key):
939
+ yield item
940
+
941
+ def _iter_items_backward(self, start_key=None, end_key=None):
942
+ for item in self._iter_items(left=attrgetter("right"), right=attrgetter("left"),
943
+ start_key=start_key, end_key=end_key):
944
+ yield item
945
+
946
    def _iter_items(self, left=attrgetter("left"), right=attrgetter("right"), start_key=None, end_key=None):
        # Non-recursive in-order traversal with an explicit stack.
        # Passing swapped accessors (left=right, right=left) yields the items
        # in reverse order -- see _iter_items_backward().
        node = self._root
        stack = []
        go_left = True
        in_range = self._get_in_range_func(start_key, end_key)

        while True:
            if left(node) is not None and go_left:
                stack.append(node)
                node = left(node)
            else:
                # emit only keys inside [start_key, end_key)
                if in_range(node.key):
                    yield node.key, node.value
                if right(node) is not None:
                    node = right(node)
                    go_left = True
                else:
                    if not len(stack):
                        return  # all done
                    node = stack.pop()
                    go_left = False
967
+
968
    def _get_in_range_func(self, start_key, end_key):
        # Build a predicate testing start_key <= x < end_key; bounds that are
        # None are treated as open (start defaults to the minimum key).
        if start_key is None and end_key is None:
            return lambda x: True
        else:
            if start_key is None:
                start_key = self.min_key()
            if end_key is None:
                return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0)
            else:
                return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0 and
                        self._cmp(self._cmp_data, x, end_key) < 0)
979
+
980
+
981
+ # ------
982
+ # RBTree
983
+
984
class Node(object):
    """Internal object, represents a tree node."""
    # __slots__ keeps per-node memory small; a tree allocates many of these.
    __slots__ = ['key', 'value', 'red', 'left', 'right']

    def __init__(self, key=None, value=None):
        self.key = key
        self.value = value
        self.red = True  # freshly inserted nodes are red
        self.left = None
        self.right = None

    def free(self):
        """Drop all payload/child references so the node can be collected."""
        self.left = self.right = None
        self.key = self.value = None

    def __getitem__(self, key):
        """N.__getitem__(key) <==> x[key], where key is 0 (left) or 1 (right)."""
        return getattr(self, 'left' if key == 0 else 'right')

    def __setitem__(self, key, value):
        """N.__setitem__(key, value) <==> x[key]=value, where key is 0 (left) or 1 (right)."""
        setattr(self, 'left' if key == 0 else 'right', value)
1011
+
1012
+
1013
class RBTree(_ABCTree):
    """
    RBTree implements a balanced binary tree with a dict-like interface.

    see: http://en.wikipedia.org/wiki/Red_black_tree
    """
    # insert()/remove() use Julienne Walker's top-down red-black algorithms
    # (the jsw_* helpers): the tree is rebalanced while descending, so no
    # parent pointers or post-hoc backtracking are needed.

    @staticmethod
    def is_red(node):
        # None children count as black
        if (node is not None) and node.red:
            return True
        else:
            return False

    @staticmethod
    def jsw_single(root, direction):
        # single rotation toward `direction` (0=left, 1=right); recolors so
        # the subtree's new root is black and the demoted old root is red
        other_side = 1 - direction
        save = root[other_side]
        root[other_side] = save[direction]
        save[direction] = root
        root.red = True
        save.red = False
        return save

    @staticmethod
    def jsw_double(root, direction):
        # double rotation: rotate the child the other way first, then the root
        other_side = 1 - direction
        root[other_side] = RBTree.jsw_single(root[other_side], other_side)
        return RBTree.jsw_single(root, direction)

    def _new_node(self, key, value):
        """Create a new tree node."""
        self._count += 1
        return Node(key, value)

    def insert(self, key, value):
        """T.insert(key, value) <==> T[key] = value, insert key, value into tree."""
        if self._root is None:  # Empty tree case
            self._root = self._new_node(key, value)
            self._root.red = False  # make root black
            return

        head = Node()  # False tree root
        grand_parent = None
        grand_grand_parent = head
        parent = None  # parent
        direction = 0
        last = 0

        # Set up helpers
        grand_grand_parent.right = self._root
        node = grand_grand_parent.right
        # Search down the tree
        while True:
            if node is None:  # Insert new node at the bottom
                node = self._new_node(key, value)
                parent[direction] = node
            elif RBTree.is_red(node.left) and RBTree.is_red(node.right):  # Color flip
                node.red = True
                node.left.red = False
                node.right.red = False

            # Fix red violation
            if RBTree.is_red(node) and RBTree.is_red(parent):
                direction2 = 1 if grand_grand_parent.right is grand_parent else 0
                if node is parent[last]:
                    grand_grand_parent[direction2] = RBTree.jsw_single(grand_parent, 1 - last)
                else:
                    grand_grand_parent[direction2] = RBTree.jsw_double(grand_parent, 1 - last)

            # Stop if found
            if self._cmp(self._cmp_data, key, node.key) == 0:
                node.value = value  # set new value for key
                break

            last = direction
            direction = 0 if (self._cmp(self._cmp_data, key, node.key) < 0) else 1
            # Update helpers
            if grand_parent is not None:
                grand_grand_parent = grand_parent
            grand_parent = parent
            parent = node
            node = node[direction]

        self._root = head.right  # Update root
        self._root.red = False  # make root black

    def remove(self, key):
        """T.remove(key) <==> del T[key], remove item <key> from tree."""
        if self._root is None:
            raise KeyError(str(key))
        head = Node()  # False tree root
        node = head
        node.right = self._root
        parent = None
        grand_parent = None
        found = None  # Found item
        direction = 1

        # Search and push a red down
        while node[direction] is not None:
            last = direction

            # Update helpers
            grand_parent = parent
            parent = node
            node = node[direction]

            direction = 1 if (self._cmp(self._cmp_data, node.key, key) < 0) else 0

            # Save found node
            if self._cmp(self._cmp_data, key, node.key) == 0:
                found = node

            # Push the red node down
            if not RBTree.is_red(node) and not RBTree.is_red(node[direction]):
                if RBTree.is_red(node[1 - direction]):
                    parent[last] = RBTree.jsw_single(node, direction)
                    parent = parent[last]
                elif not RBTree.is_red(node[1 - direction]):
                    sibling = parent[1 - last]
                    if sibling is not None:
                        if (not RBTree.is_red(sibling[1 - last])) and (not RBTree.is_red(sibling[last])):
                            # Color flip
                            parent.red = False
                            sibling.red = True
                            node.red = True
                        else:
                            direction2 = 1 if grand_parent.right is parent else 0
                            if RBTree.is_red(sibling[last]):
                                grand_parent[direction2] = RBTree.jsw_double(parent, last)
                            elif RBTree.is_red(sibling[1-last]):
                                grand_parent[direction2] = RBTree.jsw_single(parent, last)
                            # Ensure correct coloring
                            grand_parent[direction2].red = True
                            node.red = True
                            grand_parent[direction2].left.red = False
                            grand_parent[direction2].right.red = False

        # Replace and remove if found
        if found is not None:
            # the search ended on found's in-order neighbour: copy its payload
            # into `found`, then unlink that neighbour node
            found.key = node.key
            found.value = node.value
            parent[int(parent.right is node)] = node[int(node.left is None)]
            node.free()
            self._count -= 1

        # Update root and make it black
        self._root = head.right
        if self._root is not None:
            self._root.red = False
        if not found:
            raise KeyError(str(key))
deps/laps.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code and weights taken from
2
+ # https://github.com/maciejczyzewski/neural-chessboard/
3
+
4
+ import deps
5
+
6
+ import numpy as np
7
+ import cv2
8
+ import collections
9
+ import scipy
10
+ import scipy.cluster
11
+ from tensorflow.keras.models import Sequential
12
+ from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Flatten
13
+ from tensorflow.keras.optimizers import RMSprop
14
+
15
# Build the LAPS model exactly as described in the original JSON config
def create_laps_model():
    """Build the LAPS lattice-point classifier (binary CNN on 21x21 binary
    patches). Layer names and shapes must match the saved weight file."""
    model = Sequential()

    # Dense layer (dense_1) - 441 units, input_shape=(21, 21, 1)
    model.add(Dense(441, input_shape=(21, 21, 1), name='dense_1'))

    # First block: Conv2D layers + MaxPooling + BatchNorm
    model.add(Conv2D(16, (3, 3), activation='elu', name='conv2d_1'))
    model.add(Conv2D(16, (2, 2), activation='elu', name='conv2d_2'))
    model.add(Conv2D(16, (1, 1), activation='elu', name='conv2d_3'))
    model.add(MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_1'))
    model.add(BatchNormalization(name='batch_normalization_1'))

    # Second block: Conv2D layers + MaxPooling + BatchNorm
    model.add(Conv2D(16, (3, 3), activation='elu', name='conv2d_4'))
    model.add(Conv2D(16, (2, 2), activation='elu', name='conv2d_5'))
    model.add(Conv2D(16, (1, 1), activation='elu', name='conv2d_6'))
    model.add(MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_2'))
    model.add(BatchNormalization(name='batch_normalization_2'))

    # Dense layer (dense_2) - 128 units
    model.add(Dense(128, activation='elu', name='dense_2'))
    model.add(Dropout(0.5, name='dropout_1'))
    model.add(Flatten(name='flatten_1'))

    # Output layer (dense_3) - 2 units: (lattice point, not a lattice point)
    model.add(Dense(2, activation='softmax', name='dense_3'))

    # Compile with RMSprop, like the original
    model.compile(RMSprop(learning_rate=0.001),
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

    return model
50
+
51
# Create the model
NEURAL_MODEL = create_laps_model()

# Try to load the pre-trained weights; fall back to random weights on failure
try:
    # Try the known-good weight file first
    weights_path = "data/laps_models/laps_working.weights.h5"
    NEURAL_MODEL.load_weights(weights_path)
    print("✅ Poids LAPS chargés avec succès depuis laps_working.weights.h5")
except Exception as e:
    try:
        # Fall back to the original weight file
        weights_path = "data/laps_models/laps.weights.h5"
        NEURAL_MODEL.load_weights(weights_path)
        print("✅ Poids LAPS chargés avec succès depuis laps.weights.h5")
    except Exception as e2:
        # Detection quality will suffer, but the pipeline still runs
        print(f"⚠️ Impossible de charger les poids LAPS: {e2}")
        print("Utilisation de poids aléatoires (le modèle fonctionnera quand même)")
69
+
70
+
71
def laps_intersections(lines):
    '''Find all intersections'''
    segments = [[(p1[0], p1[1]), (p2[0], p2[1])] for p1, p2 in lines]
    return deps.geometry.isect_segments(segments)
75
+
76
+
77
def laps_cluster(points, max_dist=10):
    """Cluster very similar points: single-linkage clustering merges points
    closer than ``max_dist`` and each cluster is replaced by its mean point.

    Returns a list of (mean_x, mean_y) tuples, one per cluster.
    """
    # Explicit import: `import scipy` alone does not guarantee scipy.spatial
    # is loaded (the original relied on scipy.cluster.hierarchy importing it).
    import scipy.spatial.distance

    Y = scipy.spatial.distance.pdist(points)
    Z = scipy.cluster.hierarchy.single(Y)
    T = scipy.cluster.hierarchy.fcluster(Z, max_dist, 'distance')

    clusters = collections.defaultdict(list)
    for i, label in enumerate(T):
        clusters[label].append(points[i])

    # if two points are close, they become one mean point
    return [(np.mean(np.array(group)[:, 0]), np.mean(np.array(group)[:, 1]))
            for group in clusters.values()]
90
+
91
+
92
def laps_detector(img):
    """Determine whether a crop around a line intersection is a chessboard
    lattice point.

    Returns (True, score) for a positive detection, else (False, score).
    A cheap geometric contour test is tried first; the CNN is consulted
    only when that test is inconclusive.

    Removed vs. original: unused `hashid` (it called np.ndarray.tostring(),
    which is removed in NumPy 2.0), unused `imgd` alias, and an unused
    `global NC_LAYER` declaration referencing an undefined name.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.threshold(img, 0, 255, cv2.THRESH_OTSU)[1]
    img = cv2.Canny(img, 0, 255)
    img = cv2.resize(img, (21, 21), interpolation=cv2.INTER_CUBIC)

    # binarized 21x21 patch, shaped as a single-sample CNN batch
    X = [np.where(img > int(255/2), 1, 0).ravel()]
    X = X[0].reshape([-1, 21, 21, 1])

    img = cv2.dilate(img, None)
    mask = cv2.copyMakeBorder(img, top=1, bottom=1, left=1, right=1,
                              borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
    mask = cv2.bitwise_not(mask)
    i = 0
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_NONE)

    _c = np.zeros((23, 23, 3), np.uint8)

    # geometric detector: count small, roughly quadrilateral contours
    for cnt in contours:
        (x, y), radius = cv2.minEnclosingCircle(cnt)
        x, y = int(x), int(y)
        approx = cv2.approxPolyDP(cnt, 0.1*cv2.arcLength(cnt, True), True)
        if len(approx) == 4 and radius < 14:
            cv2.drawContours(_c, [cnt], 0, (0, 255, 0), 1)
            i += 1
        else:
            cv2.drawContours(_c, [cnt], 0, (0, 0, 255), 1)

    # a lattice point splits its neighbourhood into exactly 4 quads
    if i == 4:
        return (True, 1)

    pred = NEURAL_MODEL.predict(X, verbose=0)
    a, b = pred[0][0], pred[0][1]
    # accept only confident positives
    t = a > b and b < 0.03 and a > 0.975

    # decision
    if t:
        return (True, pred[0])
    else:
        return (False, pred[0])
141
+
142
+ ################################################################################
143
+
144
+
145
def LAPS(img, lines, size=10):
    # Lattice point search: classify every line/line intersection with
    # laps_detector() and cluster the accepted points (see laps_cluster).

    __points, points = laps_intersections(lines), []

    for pt in __points:
        # pixels are in integers
        pt = list(map(int, pt))

        # size of our analysis area
        # NOTE(review): the window is asymmetric (x: -size-1..+size,
        # y: -size..+size+1) -- presumably to get an odd-sized crop; confirm
        lx1 = max(0, int(pt[0]-size-1))
        lx2 = max(0, int(pt[0]+size))
        ly1 = max(0, int(pt[1]-size))
        ly2 = max(0, int(pt[1]+size+1))

        # cropping for detector
        dimg = img[ly1:ly2, lx1:lx2]
        dimg_shape = np.shape(dimg)

        # not valid
        if dimg_shape[0] <= 0 or dimg_shape[1] <= 0:
            continue

        # use neural network
        re_laps = laps_detector(dimg)
        if not re_laps[0]:
            continue

        # add if okay
        if pt[0] < 0 or pt[1] < 0:
            continue
        points += [pt]
    points = laps_cluster(points)

    return points
llr.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken from
2
+ # https://github.com/maciejczyzewski/neural-chessboard/
3
+
4
+ from deps.laps import laps_intersections, laps_cluster
5
+ from slid import slid_tendency
6
+ import scipy
7
+ import cv2
8
+ import pyclipper
9
+ import numpy as np
10
+ import matplotlib.path
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.path as mplPath
13
+ import collections
14
+ import itertools
15
+ import random
16
+ import math
17
+ import sklearn.cluster
18
+ from copy import copy
19
+ na = np.array
20
+
21
+
22
+ ################################################################################
23
+
24
+
25
def llr_normalize(points):
    """Truncate every (x, y) point to integer pixel coordinates."""
    return [[int(x), int(y)] for x, y in points]
26
+
27
+
28
def llr_correctness(points, shape):
    """Keep only points lying inside the image bounds; shape is (h, w, ...)."""
    height, width = shape[0], shape[1]
    return [pt for pt in points
            if 0 <= pt[0] <= width and 0 <= pt[1] <= height]
37
+
38
+
39
def llr_unique(a):
    """Remove duplicate items, keeping the first occurrence of each value."""
    # sort indices by value, then keep the first index of each equal-value run
    order = sorted(range(len(a)), key=a.__getitem__)
    keep = set(next(group) for _key, group in
               itertools.groupby(order, key=a.__getitem__))
    return [item for idx, item in enumerate(a) if idx in keep]
44
+
45
+
46
def llr_polysort(pts):
    """sort points clockwise (in place, around the centroid); returns pts"""
    cx = sum(p[0] for p in pts) / len(pts)
    cy = sum(p[1] for p in pts) / len(pts)

    def angle(p):
        # angle around the centroid, normalized into [0, 2*pi)
        return (math.atan2(p[0] - cx, p[1] - cy) + 2 * math.pi) % (2 * math.pi)

    pts.sort(key=angle)
    return pts
56
+
57
+
58
def llr_polyscore(cnt, pts, cen, alfa=5, beta=2):
    # Score a candidate 4-point board quad `cnt` against the lattice points
    # `pts` and the point-cloud centroid `cen`. Returns 0 when the quad is
    # rejected outright; otherwise a positive score (bigger = better),
    # see Eq.11 / Sec.3.4 of the Czyzewski et al. paper.
    # NOTE(review): a/b/c/d below are unused and reassigned further down.
    a = cnt[0]
    b = cnt[1]
    c = cnt[2]
    d = cnt[3]

    # reject quads that are too small to be a board
    area = cv2.contourArea(cnt)
    t2 = area < (4 * alfa * alfa) * 5
    if t2:
        return 0

    gamma = alfa/1.5

    # inflate the quad slightly and count lattice points falling inside it
    pco = pyclipper.PyclipperOffset()
    pco.AddPath(cnt, pyclipper.JT_MITER, pyclipper.ET_CLOSEDPOLYGON)
    pcnt = matplotlib.path.Path(pco.Execute(gamma)[0])  # FIXME: alfa/1.5
    wtfs = pcnt.contains_points(pts)
    pts_in = min(np.count_nonzero(wtfs), 49)  # 49 = 7x7 inner lattice points
    t1 = pts_in < min(len(pts), 49) - 2 * beta - 1
    if t1:
        return 0

    A = pts_in   # number of captured lattice points
    B = area     # quad area

    # distance from point x to the segment l1, with dx = |l1|
    def nln(l1, x, dx): return \
        np.linalg.norm(np.cross(na(l1[1])-na(l1[0]),
                                na(l1[0])-na(x)))/dx
    pcnt_in = []
    i = 0
    for pt in wtfs:
        if pt:
            pcnt_in += [pts[i]]
        i += 1

    def __convex_approx(points, alfa=0.001):
        hull = scipy.spatial.ConvexHull(na(points)).vertices
        cnt = na([points[pt] for pt in hull])
        return cnt

    # convex hull of the captured points
    cnt_in = __convex_approx(na(pcnt_in))

    points = cnt_in
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    cen2 = (sum(x) / len(points),
            sum(y) / len(points))

    # G: distance between the expected centroid and the hull's centroid
    G = np.linalg.norm(na(cen)-na(cen2))

    """
    cnt_in = __convex_approx(na(pcnt_in))
    S = cv2.contourArea(na(cnt_in))
    if S < B: E += abs(S - B)
    cnt_in = __convex_approx(na(list(cnt_in)+list(cnt)))
    S = cv2.contourArea(na(cnt_in))
    if S > B: E += abs(S - B)
    """

    # E: mean distance of near-edge hull points to the quad edges
    a = [cnt[0], cnt[1]]
    b = [cnt[1], cnt[2]]
    c = [cnt[2], cnt[3]]
    d = [cnt[3], cnt[0]]
    lns = [a, b, c, d]
    E = 0
    F = 0
    for l in lns:
        d = np.linalg.norm(na(l[0])-na(l[1]))
        for p in cnt_in:
            r = nln(l, p, d)
            if r < gamma:
                E += r
                F += 1
    if F == 0:
        return 0
    E /= F

    if B == 0 or A == 0:
        return 0

    # See Eq.11 and Sec.3.4 in the paper

    C = 1+(E/A)**(1/3)
    D = 1+(G/A)**(1/5)
    R = (A**4)/((B**2) * C * D)

    # print(R*(10**12), A, "|", B, C, D, "|", E, G)

    return R
147
+
148
+ ################################################################################
149
+
150
+ # LAPS, SLID
151
+
152
+
153
def LLR(img, points, lines):
    # Board quad search: densify/filter the lattice points, extend candidate
    # lines to the image border, and pick the vertical/horizontal line pairs
    # whose intersection quad maximizes llr_polyscore().
    old = points

    def __convex_approx(points, alfa=0.01):
        # convex hull simplified with approxPolyDP
        hull = scipy.spatial.ConvexHull(na(points)).vertices
        cnt = na([points[pt] for pt in hull])
        approx = cv2.approxPolyDP(cnt, alfa *
                                  cv2.arcLength(cnt, True), True)
        return llr_normalize(itertools.chain(*approx))

    __cache = {}

    def __dis(a, b):
        # memoized euclidean distance between two points
        idx = hash("__dis" + str(a) + str(b))
        if idx in __cache:
            return __cache[idx]
        __cache[idx] = np.linalg.norm(na(a)-na(b))
        return __cache[idx]

    # distance from point x to the segment l1, with dx = |l1|
    def nln(l1, x, dx): return \
        np.linalg.norm(np.cross(na(l1[1])-na(l1[0]),
                                na(l1[0])-na(x)))/dx

    pregroup = [[], []]  # [vertical-ish candidates, horizontal-ish candidates]
    S = {}               # score -> candidate quad

    points = llr_correctness(llr_normalize(points), img.shape)

    # keep the densest DBSCAN cluster of lattice points (drops outliers)
    __points = {}
    points = llr_polysort(points)
    __max, __points_max = 0, []
    alfa = math.sqrt(cv2.contourArea(na(points))/49)  # ~ one square side
    X = sklearn.cluster.DBSCAN(eps=alfa*4).fit(points)
    for i in range(len(points)):
        __points[i] = []
    for i in range(len(points)):
        if X.labels_[i] != -1:
            __points[X.labels_[i]] += [points[i]]
    for i in range(len(points)):
        if len(__points[i]) > __max:
            __max = len(__points[i])
            __points_max = __points[i]
    if len(__points) > 0 and len(points) > 49/2:
        points = __points_max
    # print(X.labels_)

    ring = __convex_approx(llr_polysort(points))

    n = len(points)
    beta = n*(5/100)  # allow ~5% of points outside a candidate quad
    alfa = math.sqrt(cv2.contourArea(na(points))/49)

    x = [p[0] for p in points]
    y = [p[1] for p in points]
    centroid = (sum(x) / len(points),
                sum(y) / len(points))

    # print(alfa, beta, centroid)

    def __v(l):
        # extend a vertical-ish line to the top/bottom image borders and
        # score the two half-plane quads it induces
        y_0, x_0 = l[0][0], l[0][1]
        y_1, x_1 = l[1][0], l[1][1]

        x_2 = 0
        t = (x_0-x_2)/(x_0-x_1+0.0001)  # +0.0001 avoids division by zero
        a = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)][::-1]

        x_2 = img.shape[0]
        t = (x_0-x_2)/(x_0-x_1+0.0001)
        b = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)][::-1]

        poly1 = llr_polysort([[0, 0], [0, img.shape[0]], a, b])
        s1 = llr_polyscore(na(poly1), points, centroid, beta=beta, alfa=alfa/2)
        poly2 = llr_polysort([a, b,
                              [img.shape[1], 0], [img.shape[1], img.shape[0]]])
        s2 = llr_polyscore(na(poly2), points, centroid, beta=beta, alfa=alfa/2)

        return [a, b], s1, s2

    def __h(l):
        # extend a horizontal-ish line to the left/right image borders and
        # score the two half-plane quads it induces
        x_0, y_0 = l[0][0], l[0][1]
        x_1, y_1 = l[1][0], l[1][1]

        x_2 = 0
        t = (x_0-x_2)/(x_0-x_1+0.0001)
        a = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)]

        x_2 = img.shape[1]
        t = (x_0-x_2)/(x_0-x_1+0.0001)
        b = [int((1-t)*x_0+t*x_1), int((1-t)*y_0+t*y_1)]

        poly1 = llr_polysort([[0, 0], [img.shape[1], 0], a, b])
        s1 = llr_polyscore(na(poly1), points, centroid, beta=beta, alfa=alfa/2)
        poly2 = llr_polysort([a, b,
                              [0, img.shape[0]], [img.shape[1], img.shape[0]]])
        s2 = llr_polyscore(na(poly2), points, centroid, beta=beta, alfa=alfa/2)

        return [a, b], s1, s2

    # keep lines that pass near a lattice point but not through the centroid
    for l in lines:
        for p in points:
            t1 = nln(l, p, __dis(*l)) < alfa
            t2 = nln(l, centroid, __dis(*l)) > alfa * 2.5

            if t1 and t2:
                tx, ty = l[0][0]-l[1][0], l[0][1]-l[1][1]
                if abs(tx) < abs(ty):
                    ll, s1, s2 = __v(l)
                    o = 0
                else:
                    ll, s1, s2 = __h(l)
                    o = 1
                if s1 == 0 and s2 == 0:
                    continue
                pregroup[o] += [ll]

    pregroup[0] = llr_unique(pregroup[0])
    pregroup[1] = llr_unique(pregroup[1])

    # print("---------------------")
    # print(pregroup)
    # score every (2 vertical, 2 horizontal) combination
    for v in itertools.combinations(pregroup[0], 2):
        for h in itertools.combinations(pregroup[1], 2):
            poly = laps_intersections([v[0], v[1], h[0], h[1]])
            poly = llr_correctness(poly, img.shape)
            if len(poly) != 4:
                continue
            poly = na(llr_polysort(llr_normalize(poly)))
            if not cv2.isContourConvex(poly):
                continue
            # print("Poly:", -llr_polyscore(poly, points, centroid,
            #                               beta=beta, alfa=alfa/2))
            # negative score so that sorting ascending puts the best first
            S[-llr_polyscore(poly, points, centroid,
                             beta=beta, alfa=alfa/2)] = poly

    # print(bool(S))
    # NOTE(review): raises StopIteration/KeyError when no candidate survives;
    # callers are expected to retake the picture in that case (see preprocess)
    S = collections.OrderedDict(sorted(S.items()))
    K = next(iter(S))
    # print("key --", K)
    four_points = llr_normalize(S[K])

    # print("POINTS:", len(points))
    # print("LINES:", len(lines))

    return four_points
298
+
299
+
300
def llr_pad(four_points, img):
    """Inflate the detected board quad outward by a fixed 60px margin so the
    crop keeps a safety border around the outer squares.

    `img` is unused but kept for interface compatibility with callers.
    Raises IndexError when the offset produces no polygon.
    """
    pco = pyclipper.PyclipperOffset()
    pco.AddPath(four_points, pyclipper.JT_MITER, pyclipper.ET_CLOSEDPOLYGON)

    # 60,70/75 is best (with buffer/for debug purpose)
    # Original computed Execute(60) twice and discarded the first result;
    # run the offset once.
    return pco.Execute(60)[0]
preprocess.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script preprocesses original pictures and turns them into 2D-projections.
2
+ # The data is then used in create_labels.py.
3
+
4
+ import numpy as np
5
+ import cv2
6
+ import glob
7
+ from pathlib import Path
8
+ from matplotlib import pyplot as plt
9
+
10
+ from rescale import *
11
+ from slid import detect_lines
12
+ from deps.laps import LAPS
13
+ from llr import LLR, llr_pad
14
+
15
+ RAW_DATA_FOLDER = './data/raw/games/'
16
+ PREPROCESSED_FOLDER = './data/preprocessed/games/'
17
+
18
+
19
def preprocess_image(path, final_folder="", filename="", save=False):
    '''Reads and preprocesses the image at [path]; if [save] is enabled,
    saves the result as [filename] inside [final_folder].

    Returns the cropped, perspective-corrected board image (RGB).
    '''
    res = cv2.imread(path)[..., ::-1]  # OpenCV loads BGR; flip to RGB
    # Crop twice, just like Czyzewski et al. did
    for _ in range(2):
        img, shape, scale = image_resize(res)
        lines = detect_lines(img)
        lattice_points = LAPS(img, lines)
        # Sometimes LLR() or llr_pad() will produce an error. In this case,
        # the picture needs to be retaken
        inner_points = LLR(img, lattice_points, lines)
        four_points = llr_pad(inner_points, img)  # padcrop

        try:
            res = crop(res, four_points, scale)
        except Exception:
            # padded quad can fall outside the image; fall back to the
            # unpadded inner quad (was a bare `except:`, which also
            # swallowed KeyboardInterrupt/SystemExit)
            print("WARNING: couldn't crop around outer points")
            res = crop(res, inner_points, scale)
    if save:
        # Create the folder if it doesn't exist
        Path(final_folder).mkdir(parents=True, exist_ok=True)
        plt.imsave("%s/%s" % (final_folder, filename), res)
    return res
+
46
+
47
+ def preprocess_games(game_list):
48
+ '''Preprocesses all games in the given list. Assuming there are two
49
+ versions of each: original and reversed; in reversed, the board is flipped.
50
+ I included this to improve the performance of CNN in situations when
51
+ White has pieces on ranks 5-8 or Black has pieces on ranks 1-4.'''
52
+ for game_name in game_list:
53
+ for ver in ['orig', 'rev']:
54
+ img_filename_list = []
55
+ folder_name = RAW_DATA_FOLDER + '%s/%s/*' % (game_name, ver)
56
+ for path_name in glob.glob(folder_name):
57
+ img_filename_list.append(path_name)
58
+
59
+ count = 0
60
+ img_filename_list.sort(key=lambda s: int(
61
+ s.split('/')[-1].split('.')[0]))
62
+ for path in img_filename_list:
63
+ count += 1
64
+ final_folder = PREPROCESSED_FOLDER + \
65
+ "%s/%s/" % (game_name, ver)
66
+ preprocess_image(path, final_folder=final_folder,
67
+ filename="%i.png" % count, save=True)
68
+ print("Done saving in %s." % final_folder)
69
+
70
+
71
if __name__ == '__main__':
    # Batch-preprocess both orientations of each recorded game.
    game_list = ['runau_schmidt', 'hewitt_steinitz', 'bertok_fischer', 'karpov_kasparov',
                 'alekhine_nimzowitsch', 'rossolimo_reissmann', 'anderssen_dufresne', 'thorsteinsson_karlsson']
    preprocess_games(game_list)
requirements_hf.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tensorflow
3
+ opencv-python
4
+ numpy
5
+ pillow
6
+ python-chess
7
+ matplotlib
rescale.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import numpy as np
3
+ import cv2
4
+ import math
5
+ arr = np.array
6
+
7
+
8
def image_scale(pts, scale):
    """scale to original image size"""
    inv = 1 / scale
    return [[x * inv, y * inv] for x, y in pts]
12
+
13
+
14
def image_resize(img, height=500):
    """resize image to same normalized area (height**2)

    Returns (resized image, its shape, the scale factor applied); the scale
    preserves the aspect ratio while targeting height*height total pixels.
    """
    pixels = height * height
    shape = list(np.shape(img))
    scale = math.sqrt(float(pixels)/float(shape[0]*shape[1]))
    shape[0] *= scale
    shape[1] *= scale
    img = cv2.resize(img, (int(shape[1]), int(shape[0])))
    img_shape = np.shape(img)
    return img, img_shape, scale
24
+
25
+
26
def image_transform(img, points, square_length=150):
    """crop original image using perspective warp

    Maps the 4-point quad `points` onto an axis-aligned 8x8 board of
    square_length-pixel squares and returns the warped crop.
    """
    board_length = square_length * 8  # 8x8 chessboard
    def __dis(a, b): return np.linalg.norm(arr(a)-arr(b))
    def __shi(seq, n=0): return seq[-(n % len(seq)):] + seq[:-(n % len(seq))]
    # rotate the quad so the corner nearest the image origin comes first,
    # keeping the warp's orientation stable
    best_idx, best_val = 0, 10**6
    for idx, val in enumerate(points):
        val = __dis(val, [0, 0])
        if val < best_val:
            best_idx, best_val = idx, val
    pts1 = np.float32(__shi(points, 4 - best_idx))
    pts2 = np.float32([[0, 0], [board_length, 0],
                       [board_length, board_length], [0, board_length]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    W = cv2.warpPerspective(img, M, (board_length, board_length))
    return W
42
+
43
+
44
def crop(img, pts, scale):
    """crop using 4 points transform"""
    # map the quad back to original-image coordinates, then warp
    return image_transform(img, image_scale(pts, scale))
slid.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # My implementation of the SLID module from
2
+ # https://github.com/maciejczyzewski/neural-chessboard/
3
+
4
+ from typing import Tuple
5
+ import numpy as np
6
+ import cv2
7
+
8
+
9
+ arr = np.array
10
+ # Four parameters are taken from the original code and
11
+ # correspond to four possible cases that need correction:
12
+ # low light, overexposure, underexposure, and blur
13
+ CLAHE_PARAMS = [[3, (2, 6), 5], # @1
14
+ [3, (6, 2), 5], # @2
15
+ [5, (3, 3), 5], # @3
16
+ [0, (0, 0), 0]] # EE
17
+
18
+
19
def slid_clahe(img, limit=2, grid=(3, 3), iters=5):
    """repair using CLAHE algorithm (adaptive histogram equalization)

    Converts to grayscale, then applies CLAHE `iters` times with the given
    clip `limit` and tile `grid`.  A `limit` of 0 (the "EE" parameter set)
    effectively skips both equalization strength and the closing step.
    Returns a single-channel uint8 image.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    for i in range(iters):
        img = cv2.createCLAHE(clipLimit=limit,
                              tileGridSize=grid).apply(img)
    if limit != 0:
        # Morphological closing smooths out small dark holes that the
        # repeated equalization tends to amplify.
        kernel = np.ones((10, 10), np.uint8)
        img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
    return img
29
+
30
+
31
def slid_detector(img, alfa=150, beta=2):
    """Detect line segments on an edge image with probabilistic Hough.

    beta: angular-resolution divisor (theta step is pi/360 * beta).
    NOTE(review): `alfa` is currently unused — the Hough threshold is the
    hard-coded 40 below; kept for interface compatibility with callers.
    Returns a list of segments [[x1, y1], [x2, y2]] (ints), or [] if none.
    """
    detected = cv2.HoughLinesP(img, rho=1, theta=np.pi/360*beta,
                               threshold=40, minLineLength=50, maxLineGap=15)  # [40, 40, 10]
    if detected is None:
        return []
    return [[[int(x1), int(y1)], [int(x2), int(y2)]]
            for x1, y1, x2, y2 in np.reshape(detected, (-1, 4))]
41
+
42
+
43
def slid_canny(img, sigma=0.25):
    """Canny edge detection with thresholds derived from the image median.

    The lower/upper thresholds are (1 - sigma) and (1 + sigma) times the
    median pixel value, clamped to [0, 255].  The image is median- and
    Gaussian-blurred first to suppress noise.
    """
    median = np.median(img)  # computed before blurring, as in the original
    smoothed = cv2.GaussianBlur(cv2.medianBlur(img, 5), (7, 7), 2)
    lower = int(max(0, (1.0 - sigma) * median))
    upper = int(min(255, (1.0 + sigma) * median))
    return cv2.Canny(smoothed, lower, upper)
51
+
52
+
53
def pSLID(img, thresh=150):
    """Find all candidate line segments under several CLAHE presets.

    Runs the CLAHE -> Canny -> Hough pipeline once per parameter set in
    CLAHE_PARAMS (low light, overexposure, underexposure, blur) and
    concatenates the detected segments.

    Fix: the original loop variable was named `arr`, shadowing the
    module-level alias `arr = np.array` inside this function; it also kept
    unused `key`/`i` counters.  Renamed and removed — behavior unchanged.
    """
    segments = []
    for limit, grid, iters in CLAHE_PARAMS:
        repaired = slid_clahe(img, limit=limit, grid=grid, iters=iters)
        segments += list(slid_detector(slid_canny(repaired), thresh))
    return segments
64
+
65
+
66
+ all_points = []
67
+
68
+
69
def SLID(img, segments):
    """Merge near-collinear Hough segments into long straight lines.

    Groups `segments` into disjoint sets of mutually "similar" segments
    (union-find), then fits one representative line per group with
    cv2.fitLine.  Side effect: repopulates the module-level `all_points`
    list with the sampled points of every group.
    Returns a list of merged lines [[x1, y1], [x2, y2]].
    """
    global all_points
    all_points = []

    # pregroup[0]: mostly-vertical segments, pregroup[1]: mostly-horizontal.
    # group: set of segment hashes per disjoint-set root.
    # hashmap: segment hash -> segment endpoints.
    pregroup, group, hashmap, raw_lines = [[], []], {}, {}, []

    # Memoized endpoint distances (the 2-norm is looked up many times).
    dists = {}

    def dist(a, b):
        h = hash("dist"+str(a)+str(b))
        if h not in dists:
            dists[h] = np.linalg.norm(arr(a)-arr(b))
        return dists[h]

    # Union-find over segment hashes, with path compression in find().
    parents = {}

    def find(x):
        if x not in parents:
            parents[x] = x
        if parents[x] != x:
            parents[x] = find(parents[x])
        return parents[x]

    def union(a, b):
        # Merge a's group into b's root and pool their member sets.
        par_a = find(a)
        par_b = find(b)
        parents[par_a] = par_b
        group[par_b] |= group[par_a]

    def height(line, pt):
        # Perpendicular distance from `pt` to the infinite line through `line`.
        v = np.cross(arr(line[1])-arr(line[0]), arr(pt)-arr(line[0]))
        return np.linalg.norm(v)/dist(line[1], line[0])

    def are_similar(l1, l2):
        '''See Sec.3.2.2 in Czyzewski et al.'''
        a = dist(l1[0], l1[1])
        b = dist(l2[0], l2[1])

        # Cross heights: each endpoint of one segment against the other line.
        x1 = height(l2, l1[0])
        x2 = height(l2, l1[1])
        y1 = height(l1, l2[0])
        y2 = height(l1, l2[1])

        # Effectively collinear — avoid dividing by a ~zero gamma below.
        if x1 < 1e-8 and x2 < 1e-8 and y1 < 1e-8 and y2 < 1e-8:
            return True

        gamma = 0.25 * (x1+x2+y1+y2)

        # NOTE(review): the p/w computation below is dead — t_delta is
        # immediately overwritten with the empirical constant 0.0625.
        img_width = 500
        img_height = 282
        p = 0.
        A = img_width*img_height
        w = np.pi/2 / np.sqrt(np.sqrt(A))
        t_delta = p*w
        t_delta = 0.0625

        delta = (a+b) * t_delta

        return (a/gamma > delta) and (b/gamma > delta)

    def generate_line(a, b, n):
        # Sample n points evenly along segment a->b (excludes endpoint b).
        points = []
        for i in range(n):
            x = a[0] + (b[0] - a[0]) * (i/n)
            y = a[1] + (b[1] - a[1]) * (i/n)
            points += [[int(x), int(y)]]
        return points

    def analyze(group):
        # Fit one representative line through all sampled points of a group,
        # extended to roughly the group's enclosing-circle diameter.
        global all_points
        points = []
        for idx in group:
            points += generate_line(*hashmap[idx], 10)
        _, radius = cv2.minEnclosingCircle(arr(points))
        w = radius * np.pi / 2
        vx, vy, cx, cy = cv2.fitLine(arr(points), cv2.DIST_L2, 0, 0.01, 0.01)
        all_points += points
        return [[int(cx-vx*w), int(cy-vy*w)], [int(cx+vx*w), int(cy+vy*w)]]

    for l in segments:
        h = hash(str(l))
        # Initialize each segment as its own singleton disjoint set.
        hashmap[h] = l
        group[h] = set([h])
        parents[h] = h

        wid = l[0][0] - l[1][0]
        hei = l[0][1] - l[1][1]

        # Split into mostly-vertical vs mostly-horizontal so similarity is
        # only checked within each bucket (quadratic loop below).
        if abs(wid) < abs(hei):
            pregroup[0].append(l)
        else:
            pregroup[1].append(l)

    for lines in pregroup:
        for i in range(len(lines)):
            l1 = lines[i]
            h1 = hash(str(l1))
            # Only compare disjoint-set roots; merged members are skipped.
            if parents[h1] != h1:
                continue
            for j in range(i+1, len(lines)):
                l2 = lines[j]
                h2 = hash(str(l2))
                if parents[h2] != h2:
                    continue
                if are_similar(l1, l2):
                    # Merge the two segments' sets into one line candidate.
                    union(h1, h2)

    for h in group:
        if parents[h] != h:
            continue
        raw_lines += [analyze(group[h])]

    return raw_lines
193
+
194
+
195
def slid_tendency(raw_lines, s=4):
    """Stretch every merged line outward by a factor controlled by `s`.

    Each endpoint is extrapolated away from the opposite endpoint; note the
    endpoint lists are mutated in place (matching the original behavior),
    and `b` is pushed relative to the already-updated `a`.
    Returns the list of stretched lines.
    """
    def push(x, y, s):
        # Weighted extrapolation: moves x away from y by (s-1)/2 spans.
        return int(x * (1 + s) / 2 + y * (1 - s) / 2)

    stretched = []
    for a, b in raw_lines:
        a[0] = push(a[0], b[0], s)
        a[1] = push(a[1], b[1], s)
        b[0] = push(b[0], a[0], s)
        b[1] = push(b[1], a[1], s)
        stretched.append([a, b])
    return stretched
205
+
206
+
207
def detect_lines(img):
    """Full SLID pipeline: detect, merge, and stretch chessboard lines.

    Runs multi-preset segment detection (pSLID), merges similar segments
    into lines (SLID), then extends them (slid_tendency).
    """
    return slid_tendency(SLID(img, pSLID(img)))
train_tensorflow.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module trains the CNN based on the labels provided in ./data/CNN
2
+ # Note that data must be first split into train, validation, and test data
3
+ # by running split_data.py.
4
+ # Reference:
5
+ # https://towardsdatascience.com/a-single-function-to-streamline-image-classification-with-keras-bd04f5cfe6df
6
+
7
+ from matplotlib import pyplot as plt
8
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
9
+ from tensorflow.keras.models import Sequential
10
+ from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
11
+ from tensorflow.keras.optimizers import RMSprop
12
+ import cv2
13
+ import json
14
+ import os
15
+
16
+
17
+ NUM_EPOCHS = 10
18
+ BATCH_SIZE = 16
19
+ DATA_FOLDER = './data/CNN/'
20
+
21
+
22
def create_generators(folderpath=DATA_FOLDER):
    '''Creates flow generators to supply images one by one during
    training/validation phases. Useful when working with large datasets
    that can't be directly loaded into the memory.

    Returns a (train_generator, validation_generator) tuple; both yield
    batches of 300x150 RGB images rescaled to [0, 1] with categorical
    labels.'''
    piece_classes = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black',
                     'King_White', 'Knight_Black', 'Knight_White',
                     'Pawn_Black', 'Pawn_White', 'Queen_Black',
                     'Queen_White', 'Rook_Black', 'Rook_White']

    def _flow(subdir, classes=None):
        # Both generators only rescale pixels by 1/255 (no augmentation).
        datagen = ImageDataGenerator(rescale=1/255)
        kwargs = dict(target_size=(300, 150),  # all images resized to 300x150
                      batch_size=BATCH_SIZE,
                      # categorical_crossentropy needs categorical labels
                      class_mode='categorical')
        if classes is not None:
            # Specify the classes explicitly for the training flow.
            kwargs['classes'] = classes
        return datagen.flow_from_directory(folderpath + subdir, **kwargs)

    return (_flow('train', piece_classes), _flow('validation'))
46
+
47
+
48
def create_model(optimizer=None):
    '''Creates a CNN architecture and compiles it.

    optimizer: a Keras optimizer instance; defaults to
        RMSprop(learning_rate=0.001) when None.

    Fix: the original signature was `optimizer=RMSprop(...)`, which builds
    the optimizer once at import time and shares that single instance
    across every call (and across models). A None sentinel defers
    construction to call time.
    '''
    if optimizer is None:
        optimizer = RMSprop(learning_rate=0.001)

    model = Sequential([
        # Note the input shape is the desired size of the image 300 x 150 with 3 bytes color
        # The first convolution
        Conv2D(16, (3, 3), activation='relu', input_shape=(300, 150, 3)),
        MaxPooling2D(2, 2),
        # The second convolution
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        # The third convolution
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        # The fourth convolution
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        # The fifth convolution
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        # Flatten the results to feed into a dense layer
        Flatten(),
        # 128 neuron in the fully-connected layer
        Dense(128, activation='relu'),
        # 13 output neurons for 13 classes with the softmax activation
        Dense(13, activation='softmax')
    ])

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['acc'])

    return model
80
+
81
+
82
def fit_model(model, train_generator, validation_generator, callbacks=None, save=False, filename=""):
    '''Given the model and generators, trains the model and saves weights if
    needed. Callbacks can be provided to save intermediate results.
    Returns a history of model's performance (for plotting purpose).

    Fix: `callbacks=[]` was a mutable default argument (one shared list
    across calls); replaced with the None-sentinel idiom.'''
    if callbacks is None:
        callbacks = []

    total_sample = train_generator.n

    history = model.fit(
        train_generator,
        steps_per_epoch=int(total_sample/BATCH_SIZE),
        epochs=NUM_EPOCHS,
        verbose=1,
        validation_data=validation_generator,
        callbacks=callbacks)

    if save:
        model.save_weights(filename)

    return history
101
+
102
+
103
def plot_accuracy(history):
    '''Given training history, plots accuracy of a model.'''
    epochs = list(range(1, NUM_EPOCHS + 1))
    plt.figure(figsize=(7, 4))
    plt.plot(epochs, history.history['acc'],
             '-o', c='k', lw=2, markersize=9)
    plt.grid(True)
    plt.title("Training accuracy with epochs\n", fontsize=18)
    plt.xlabel("Training epochs", fontsize=15)
    plt.ylabel("Training accuracy", fontsize=15)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
115
+
116
+
117
def plot_loss(history):
    '''Given training history, plots loss of a model.'''
    epochs = list(range(1, NUM_EPOCHS + 1))
    plt.figure(figsize=(7, 4))
    plt.plot(epochs, history.history['loss'],
             '-o', c='k', lw=2, markersize=9)
    plt.grid(True)
    plt.title("Training loss with epochs\n", fontsize=18)
    plt.xlabel("Training epochs", fontsize=15)
    plt.ylabel("Training loss", fontsize=15)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
129
+
130
+
131
def save_history(history, filename="./history.json"):
    '''Saves the given training history as a .json file.

    history: a Keras History-like object exposing a `.history` dict of
        per-epoch metrics and losses.

    Fix: the original passed `open(filename, 'w')` directly to json.dump,
    leaking the file handle; a context manager guarantees it is closed.'''
    with open(filename, 'w') as json_file:
        json.dump(history.history, json_file)
137
+
138
+
139
def load_history(filename="./history.json"):
    '''Loads training history from the path to a .json file. Returns a dict.'''
    with open(filename) as fh:
        return json.load(fh)
144
+
145
+
146
def test_model(model):
    '''Tests the given model on the test set and prints its accuracy.
    Does not return anything.

    Fix: images are now rescaled by 1/255 before prediction, matching the
    training pipeline (create_generators uses ImageDataGenerator(rescale=1/255));
    the original fed raw 0-255 pixels to a model trained on 0-1 inputs.
    Also guards against an empty test directory (ZeroDivisionError).'''
    testdir = DATA_FOLDER + 'test'

    # Sorted class names must match the generator's alphabetical class order.
    pieces = ['Empty', 'Rook_White', 'Rook_Black', 'Knight_White', 'Knight_Black', 'Bishop_White',
              'Bishop_Black', 'Queen_White', 'Queen_Black', 'King_White', 'King_Black', 'Pawn_White', 'Pawn_Black']
    pieces.sort()
    score = 0
    total_size = 0
    for subdir, dirs, files in os.walk(testdir):
        for file in files:
            if file == ".DS_Store":  # skip macOS metadata files
                continue
            piece = subdir.split('/')[-1]  # class label = containing folder name
            path = os.path.join(subdir, file)
            # Normalize to [0, 1] exactly as during training.
            img = cv2.imread(path).reshape(1, 300, 150, 3) / 255.0
            y_prob = model.predict(img)
            y_pred = y_prob.argmax()
            if y_pred < 0 or y_pred >= len(pieces):
                print(y_pred, y_prob)  # diagnostic: out-of-range prediction
            if piece == pieces[y_pred]:
                score += 1
            total_size += 1
    if total_size == 0:
        print("TEST SET ACCURACY: no test images found in", testdir)
        return
    print("TEST SET ACCURACY:", score/total_size)
171
+
172
+
173
if __name__ == '__main__':
    # Full training run: build data generators, train the CNN, persist the
    # history, plot metrics, evaluate on the test set, and save weights.
    train_generator, validation_generator = create_generators(DATA_FOLDER)
    model = create_model()
    # save=False here: weights are saved explicitly at the end instead.
    history = fit_model(model, train_generator,
                        validation_generator, save=False)
    save_history(history, "./history.json")
    plot_accuracy(history)
    plot_loss(history)
    test_model(model)
    model.save_weights('./model_weights.h5')
utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import *
2
+ from time import time
3
+ from copy import copy
4
+
5
+ import functools, os, re
6
+ import sys, cv2, math, numpy as np
7
+ na = np.array
8
+
9
+ ################################################################################
10
+
11
+ rows, columns = os.popen('stty size', 'r').read().split()
12
+ __strip_ansi_re = re.compile(r"""
13
+ \x1b # literal ESC
14
+ \[ # literal [
15
+ [;\d]* # zero or more digits or semicolons
16
+ [A-Za-z] # a letter
17
+ """, re.VERBOSE).sub
18
+ def __strip_ansi(s):
19
+ return __strip_ansi_re("", s)
20
+
21
+ ################################################################################
22
+
23
# Simple wall-clock stopwatch: reset() stores the start time in the
# module-level NC_CLOCK; clock() formats the elapsed seconds, e.g. "(   0.123)s".
# NOTE(review): clock() raises NameError if called before reset().
def clock():
    global NC_CLOCK; return "(%8s)s" % round((time() - NC_CLOCK), 3)
def reset(): global NC_CLOCK; NC_CLOCK = time()
26
+
27
# Console helpers: warn() prints a yellow ANSI-colored warning;
# errn() prints a red error and terminates the process with exit code 1.
def warn(msg): print("\x1b[0;33;40m warn: \x1b[4;33;40m" + msg + "\x1b[0m")
def errn(msg): print("\n\x1b[0;37;41m errn: " + msg + "\x1b[0m\n"); sys.exit(1)
29
+
30
# ANSI formatters: head() wraps msg as a highlighted section header,
# call() formats msg as an "--> @name" call marker. Both return strings.
def head(msg): return "\x1b[5;30;43m " + msg + " \x1b[0m"
def call(msg): return "--> \x1b[5;31;40m@" + msg + "\x1b[0m"
32
+
33
def ribb(*msg, sep='-'):
    """Join *msg* with spaces and pad with `sep` to the terminal width.

    ANSI escape codes are stripped before measuring the visible length, so
    colored text pads correctly. If the text already exceeds the terminal
    width the padding is empty (a negative repeat count yields '').
    """
    text = ' '.join(msg)
    pad = int(columns) - len(__strip_ansi(text))
    return text + sep * int(pad)
36
+
37
+ ################################################################################
38
+
39
def image_scale(pts, scale):
    """Scale a list of [x, y] points back to original-image coordinates.

    Divides each coordinate by `scale` (the downscaling factor) and
    returns a new list of [x, y] float pairs.
    """
    inv = 1 / scale
    return [[x * inv, y * inv] for x, y in pts]
43
+
44
def image_resize(img, height=500):
    """Resize *img* to a normalized area of height**2, keeping aspect ratio.

    Returns (resized_image, resized_shape, scale), where `scale` is the
    per-axis factor that was applied.
    """
    dims = np.shape(img)
    h, w = dims[0], dims[1]
    scale = math.sqrt(float(height * height) / float(h * w))
    resized = cv2.resize(img, (int(w * scale), int(h * scale)))
    return resized, np.shape(resized), scale
52
+
53
def image_transform(img, points, square_length=150):
    """Perspective-warp the board quadrilateral `points` to a square image.

    The four corners may be in any rotation; the list is rotated so the
    corner closest to the origin maps to (0, 0), keeping orientation
    deterministic. Output is a (8*square_length)^2 pixel image.
    """
    board_length = square_length * 8
    distances = [np.linalg.norm(na(p) - na([0, 0])) for p in points]
    k = distances.index(min(distances))
    n = (4 - k) % len(points)
    ordered = points[-n:] + points[:-n] if n else list(points)
    src = np.float32(ordered)
    dst = np.float32([[0, 0], [board_length, 0],
                      [board_length, board_length], [0, board_length]])
    warp = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img, warp, (board_length, board_length))
69
+
70
class ImageObject(object):
    """Holds an original image plus derived versions, addressable by key.

    Keys used by __init__: 'orig' (full size), 'main' (downscaled for
    speed), 'test' (a working copy of 'main').

    Fix: `images = {}` was a *class* attribute, so every ImageObject
    instance mutated one shared dict; it is now created per instance in
    __init__. `scale`/`shape` remain class-level defaults (immutable, and
    rebound per instance on assignment, so they were never shared state).
    """
    scale = 1
    shape = (0, 0)

    def __init__(self, img):
        """save and prepare image array"""
        self.images = {}  # per-instance store (was a shared class attribute)
        self.images['orig'] = img
        self.images['main'], self.shape, self.scale = \
            image_resize(img)  # downscale for speed
        self.images['test'] = copy(self.images['main'])

    def __getitem__(self, attr):
        """return image as array"""
        return self.images[attr]

    def __setitem__(self, attr, val):
        """save image to object"""
        self.images[attr] = val

    def crop(self, pts):
        """crop using 4 points transform"""
        # Re-initialize self around the cropped image so all derived
        # versions ('main', 'test', scale, shape) are rebuilt consistently.
        pts_orig = image_scale(pts, self.scale)
        img_crop = image_transform(self.images['orig'], pts_orig)
        self.__init__(img_crop)