FeatherFlock_AI / test.py
roshnn24's picture
Upload 5 files
04ff542 verified
raw
history blame
No virus
2.59 kB
# Create a function to import an image and resize it to be able to be used with our model
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import pathlib
import numpy as np
# Root of the training set; each class is expected to live in its own
# sub-directory (NOTE(review): layout assumed from how the labels are derived
# here -- confirm against the training script).
data_dir = pathlib.Path("/Users/rosh/Downloads/Train_data")

# Derive the label list from the class sub-directories only, in sorted order.
# The previous code sorted *every* entry and then blindly dropped the first
# one (presumably a stray ".DS_Store"); when no such file exists that removes
# a real class and shifts every label by one. Filtering on is_dir() ignores
# stray files wherever they sort and never discards a genuine class.
class_names = sorted(item.name for item in data_dir.glob("*") if item.is_dir())
if not class_names:
    # Fail early with a clear message instead of an IndexError at predict time.
    raise FileNotFoundError(f"No class sub-directories found in {data_dir}")

# Trained Keras model whose output indices are mapped onto class_names.
loaded_model = tf.keras.models.load_model('model_4_improved_8.h5')
def load_and_prep_image(filename, img_shape=224):
    """
    Read an image from filename and turn it into a model-ready tensor.

    Args:
        filename: Path of the image file to load.
        img_shape: Side length, in pixels, of the square the image is
            resized to (default 224).

    Returns:
        A float tensor of shape (img_shape, img_shape, 3) whose values are
        rescaled to lie between 0 and 1.
    """
    # Read in target file (an image) as raw bytes
    raw_bytes = tf.io.read_file(filename)
    # Decode into a tensor & force 3 colour channels (some images have 4).
    # expand_animations=False makes an animated GIF decode to its first frame,
    # so the result is always 3-D; with the default (True) a GIF becomes a
    # 4-D tensor and the caller's expand_dims would then build a 5-D batch.
    img = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    # Resize the image (to the same size our model was trained on)
    img = tf.image.resize(img, size=[img_shape, img_shape])
    # Rescale the image (get all values between 0 and 1)
    img = img / 255.
    return img
# Adjust function to work with multi-class
def pred_and_plot(model, filename, class_names):
    """
    Predict the class of the image stored at filename and display it.

    The image is loaded and preprocessed with load_and_prep_image, passed to
    model.predict as a batch of one, and shown with matplotlib using the
    predicted class name as the plot title.

    Args:
        model: Object exposing predict(); its output is indexed via argmax.
        filename: Path of the image file to classify.
        class_names: Sequence of labels indexed by the model's output position.
    """
    # Load the picture and bring it into the format the model expects
    image = load_and_prep_image(filename)

    # The model works on batches, so wrap the single image in a batch axis
    batch = tf.expand_dims(image, axis=0)
    probabilities = model.predict(batch)

    # The highest-scoring output position selects the label
    best_index = probabilities.argmax()
    pred_class = class_names[best_index]

    # Show the image, titled with the predicted label, without axes
    plt.imshow(image)
    plt.title(f"Prediction: {pred_class}")
    plt.axis(False)
    plt.show()
# Classify one sample image with the loaded model and display the result.
pred_and_plot(
    model=loaded_model,
    filename="/Users/rosh/Downloads/egret.jpg",
    class_names=class_names,
)
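# NOTE: disabled code, kept for reference -- runs the loaded model over the
# whole validation directory and prints the predicted class indices.
# `ImageDataGenerator` is not among the imports at the top of this file;
# import it before re-enabling this section.
#
# # loaded_model.compile(loss='categorical_crossentropy',
# #                      optimizer='adam',
# #                      metrics=['accuracy'])
# # Get true labels
# valid_datagen = ImageDataGenerator(
#     rescale=1./255  # Rescaling factor
# )
# valid_dir = "/Users/rosh/Downloads/Validation_data"
# valid_data = valid_datagen.flow_from_directory(directory=valid_dir,
#                                                batch_size=32,
#                                                target_size=(224, 224),
#                                                class_mode="categorical",
#                                                seed=42)
# pred = loaded_model.predict(valid_data)
# preds = pred.argmax(axis=1)
# print(preds)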