guru001's picture
model files added
b388df6
raw history blame
No virus
2.61 kB
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# from matplotlib import image as mpimg
# from matplotlib import pyplot as plt
class api():
    """Classify hand-sign images with a pre-trained TensorFlow (1.x API) CNN.

    The graph and weights are loaded from ``<model_dir>/<model_name>.meta`` and
    ``<model_dir>/<model_name>.ckpt`` on every call to :meth:`predict`.
    """

    # Size the input image is resized to before being fed to the network.
    height = 64
    width = 64
    channels = 3  # predict() converts images to RGB, so this must stay 3
    model_name = 'cnn_model'
    # Directory holding the exported model files; overridable per instance.
    model_dir = 'signs_api'
    # Predicted class index -> human-readable label.
    classes = {0: 'Zero', 1: 'One', 2: 'Two', 3: 'Three', 4: 'Four', 5: 'Five'}

    def reset_graph(self, seed=42):
        """Clear the default TF graph and seed the TF and NumPy RNGs.

        Called before each prediction so results are stable across runs.
        """
        tf.reset_default_graph()
        tf.set_random_seed(seed)
        np.random.seed(seed)

    def __init__(self, upload_path='uploads', model_dir='signs_api'):
        """Store configuration.

        Args:
            upload_path: directory for uploaded files (stored only; not read here).
            model_dir: directory containing ``<model_name>.meta`` / ``.ckpt``.
        """
        self.upload_path = upload_path
        self.model_dir = model_dir
        print('print', self._model_path('.meta'))

    def _model_path(self, suffix):
        """Return ``<model_dir>/<model_name><suffix>`` (e.g. suffix ``'.meta'``)."""
        return os.path.join(self.model_dir, self.model_name + suffix)

    def predict(self, im):
        """Predict the sign class of a PIL image.

        Args:
            im: a PIL ``Image`` in any mode; it is converted to RGB first.

        Returns:
            ``{'value': index, 'class': label}`` on success, or
            ``{'error': True}`` if an ``OSError``/``IOError`` occurs (e.g. an
            unreadable image or missing model file).
        """
        try:
            # Pillow's resize takes (width, height), not (height, width).
            size = (self.width, self.height)
            # Force 3 channels: RGBA / grayscale / palette images would yield
            # 4 or 1 values per pixel and make the reshape below fail.
            out = im.convert('RGB').resize(size)
            test_image = np.array(out.getdata())
            test_image = test_image.reshape((-1, self.height, self.width, self.channels))
            # Fresh, seeded graph so repeated predictions don't accumulate ops.
            self.reset_graph()
            import_meta = tf.train.import_meta_graph(self._model_path('.meta'))
            with tf.Session() as sess:
                import_meta.restore(sess, self._model_path('.ckpt'))
                # Tensor names depend on how the graph was exported;
                # presumably 'Placeholder:0' is the input and 'ArgMax:0' the
                # predicted class index — verify against the training script.
                arg_max = sess.graph.get_tensor_by_name('ArgMax:0')
                arg_max_val = arg_max.eval({'Placeholder:0': test_image})
            print('ArgMax', arg_max_val)
            index = arg_max_val.tolist()[0]
            class_val = self.classes[index]
            return {'value': index, 'class': class_val}
        except (OSError, IOError) as e:
            print('error', e)
            return {'error': True}