|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' |
|
import tensorflow as tf |
|
import numpy as np |
|
|
|
from .helper_functions import * |
|
from .helper_image_loading import * |
|
from .chessboard_finder import * |
|
|
|
def load_graph(frozen_graph_filepath):
    """Load a frozen TensorFlow graph file and import it into a new graph.

    Args:
        frozen_graph_filepath: Path to the serialized ``GraphDef`` file.

    Returns:
        A new ``tf.Graph`` holding the imported definition, with every
        imported node placed under the ``tcb`` name prefix.
    """
    graph_def = tf.compat.v1.GraphDef()

    # Read the serialized protobuf bytes, then parse them into the GraphDef.
    with tf.io.gfile.GFile(frozen_graph_filepath, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    # Import into a dedicated graph rather than the process-wide default one.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="tcb")
    return graph
|
|
|
class ChessboardPredictor(object):
    """Predict chess piece placement from board tiles using a frozen graph.

    The graph is loaded once in ``__init__`` and evaluated through a
    ``tf.compat.v1.Session`` that stays open until ``close()`` is called.
    Instances may also be used as context managers, in which case the
    session is closed when the ``with`` block exits.
    """

    # Class label index -> FEN symbol. Label 0 is an empty square and is
    # written as '1' so every rank of the long-form FEN has 8 characters.
    _LABEL_SYMBOLS = '1KQRBNPkqrbnp'

    def __init__(self, frozen_graph_path='./saved_models/frozen_graph.pb'):
        """Load the frozen graph and look up the tensors used for inference.

        Args:
            frozen_graph_path: Path to the frozen ``.pb`` graph file.
        """
        print("\t Loading model '%s'" % frozen_graph_path)
        graph = load_graph(frozen_graph_path)
        self.sess = tf.compat.v1.Session(graph=graph)

        # Tensors are looked up by name under the "tcb/" prefix.
        self.x = graph.get_tensor_by_name('tcb/Input:0')
        self.keep_prob = graph.get_tensor_by_name('tcb/KeepProb:0')
        self.prediction = graph.get_tensor_by_name('tcb/prediction:0')
        self.probabilities = graph.get_tensor_by_name('tcb/probabilities:0')
        print("\t Model restored.")

    def __enter__(self):
        """Return the predictor itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving a ``with`` block; never suppress errors."""
        self.close()
        return False

    def getPrediction(self, tiles):
        """Run the trained network on the tiles and build a long-form FEN.

        Args:
            tiles: Tile pixel data that can be reshaped to ``(32*32, 64)``;
                presumably a 32x32x64 grayscale array produced by
                ``chessboard_finder`` (TODO confirm against the caller).

        Returns:
            A ``(fen, tile_certainties)`` tuple. ``fen`` is the piece
            placement string with one character per square ('1' for an
            empty square) and ranks joined by '/'. ``tile_certainties`` is
            an 8x8 array holding, for each tile, the probability of its
            predicted label, with rows reversed in the same way the FEN
            ranks are. Returns ``(None, 0.0)`` when ``tiles`` is ``None``
            or empty.
        """
        if tiles is None or len(tiles) == 0:
            print("Couldn't parse chessboard")
            return None, 0.0

        # One flattened 1024-value row per tile: (1024, 64) -> (64, 1024).
        validation_set = np.swapaxes(np.reshape(tiles, [32 * 32, 64]), 0, 1)

        # KeepProb is fed 1.0 for inference (presumably disables dropout).
        guess_prob, guessed = self.sess.run(
            [self.probabilities, self.prediction],
            feed_dict={self.x: validation_set, self.keep_prob: 1.0})

        # Pick, for every tile, the probability of the label that was chosen.
        guessed = np.asarray(guessed)
        chosen_prob = np.asarray(guess_prob)[np.arange(len(guessed)), guessed]
        tile_certainties = chosen_prob.reshape([8, 8])[::-1, :]

        # Ranks are emitted last-to-first, matching the row flip above.
        piece_names = [self._LABEL_SYMBOLS[label] for label in guessed]
        fen = '/'.join(
            ''.join(piece_names[rank * 8:(rank + 1) * 8])
            for rank in reversed(range(8)))
        return fen, tile_certainties

    def makePrediction(self, url):
        """Try to return a FEN prediction and certainty for the image at ``url``.

        Args:
            url: URL of the image to analyse.

        Returns:
            ``[fen, certainty, visualize_link]`` on success, where
            ``certainty`` is the lowest per-tile certainty. Returns
            ``[None, None, None]`` if the image cannot be loaded or resized,
            no chessboard is found, or the tiles cannot be parsed.
        """
        img, url = helper_image_loading.loadImageFromURL(url, max_size_bytes=2000000)
        result = [None, None, None]

        if img is None:
            print('Couldn\'t load URL: "%s"' % url)
            return result

        img = helper_image_loading.resizeAsNeeded(img)
        if img is None:
            print('Image too large to resize: "%s"' % url)
            return result

        tiles, corners = chessboard_finder.findGrayscaleTilesInImage(img)
        if tiles is None:
            print('Couldn\'t find chessboard in image')
            return result

        fen, tile_certainties = self.getPrediction(tiles)
        if fen is None:
            # getPrediction signals failure with (None, 0.0); the float has
            # no .min(), so bail out instead of raising AttributeError.
            return result

        # The worst tile decides the overall certainty.
        certainty = tile_certainties.min()
        visualize_link = helper_image_loading.getVisualizeLink(corners, url)
        return [fen, certainty, visualize_link]

    def close(self):
        """Close the TensorFlow session held by this predictor."""
        print("Closing session.")
        self.sess.close()
|
|
|
|
|
|
|
|
|
def main(args):
    """Predict and print the FEN for the image named by the CLI arguments.

    Args:
        args: Namespace with ``filepath``, ``url``, ``unflip`` and ``active``
            attributes. ``args.url`` is overwritten: it is set to ``None``
            when a file path is used, otherwise it is replaced by the URL
            returned from ``loadImageFromURL``.

    Raises:
        Exception: If the image cannot be loaded, no chessboard is found in
            it, or the tiles cannot be turned into a prediction.
    """
    # A local file takes precedence over the URL.
    if args.filepath:
        img = helper_image_loading.loadImageFromPath(args.filepath)
        args.url = None
    else:
        img, args.url = helper_image_loading.loadImageFromURL(args.url)

    if img is None:
        # Name the source that actually failed instead of always "URL".
        if args.filepath:
            raise Exception('Couldn\'t load file: "%s"' % args.filepath)
        raise Exception('Couldn\'t load URL: "%s"' % args.url)

    # Locate the board and split it into tiles.
    tiles, corners = chessboard_finder.findGrayscaleTilesInImage(img)
    if tiles is None:
        raise Exception('Couldn\'t find chessboard in image')

    if args.url:
        # The visualizer link is only built for URL sources.
        viz_link = helper_image_loading.getVisualizeLink(corners, args.url)
        print('---\nVisualize tiles link:\n %s\n---' % viz_link)
        print("\n--- Prediction on url %s ---" % args.url)
    else:
        print("\n--- Prediction on file %s ---" % args.filepath)

    # Always release the session, even if inference raises.
    predictor = ChessboardPredictor()
    try:
        fen, tile_certainties = predictor.getPrediction(tiles)
    finally:
        predictor.close()

    if fen is None:
        # getPrediction returns (None, 0.0) when it cannot use the tiles.
        raise Exception('Couldn\'t parse chessboard')

    if args.unflip:
        fen = unflipFEN(fen)
    short_fen = shortenFEN(fen)

    # The worst tile decides the final certainty score.
    certainty = tile_certainties.min()

    print('Per-tile certainty:')
    print(tile_certainties)
    print("Certainty range [%g - %g], Avg: %g" % (
        tile_certainties.min(), tile_certainties.max(), tile_certainties.mean()))

    active = args.active
    print("---\nPredicted FEN:\n%s %s - - 0 1" % (short_fen, active))
    print("Final Certainty: %.1f%%" % (certainty * 100))
|
|
|
if __name__ == '__main__':
    import argparse

    # Compact, non-scientific float formatting for printed numpy arrays.
    np.set_printoptions(suppress=True, precision=3)

    cli = argparse.ArgumentParser(
        description='Predict a chessboard FEN from supplied local image link or URL')
    cli.add_argument(
        '--url',
        default='http://imgur.com/u4zF5Hj.png',
        help='URL of image (ex. http://imgur.com/u4zF5Hj.png)')
    cli.add_argument(
        '--filepath',
        help='filepath to image (ex. u4zF5Hj.png)')
    cli.add_argument(
        '--unflip',
        default=False,
        action='store_true',
        help='revert the image of a flipped chessboard')
    cli.add_argument('--active', default='w')

    main(cli.parse_args())
|
|
|
|
|
|