# Load a trained Keras model, preprocess a single image so it can be used with
# the model, and plot the image with the predicted class as the title.
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import pathlib
import numpy as np

data_dir = pathlib.Path("/Users/rosh/Downloads/Train_data")

# Only sub-directories of the training folder are treated as classes. Stray
# files (e.g. macOS ".DS_Store") are filtered out explicitly instead of
# dropping whichever entry happens to sort first, which would remove a real
# class whenever no such stray file is present.
# NOTE(review): assumes the model's output indices follow the alphabetical
# order of the class folder names -- confirm against the training script.
class_names = sorted(item.name for item in data_dir.glob("*") if item.is_dir())

loaded_model = tf.keras.models.load_model('model_4_improved_8.h5')


def load_and_prep_image(filename, img_shape=224):
    """
    Read an image from filename, turn it into a tensor and resize it to
    (img_shape, img_shape, colour_channel).

    Args:
        filename: path of the image file to read.
        img_shape: height and width (in pixels) the image is resized to.

    Returns:
        The resized image tensor with its values divided by 255.
    """
    # Read in target file (an image)
    img = tf.io.read_file(filename)

    # Decode the read file into a tensor & ensure 3 colour channels
    # (our model is trained on images with 3 colour channels and sometimes
    # images have 4 colour channels).
    # expand_animations=False makes GIF input decode to a single 3-D frame
    # instead of a 4-D stack of frames, so the result always has the
    # documented (height, width, channels) layout.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)

    # Resize the image (to the same size our model was trained on)
    img = tf.image.resize(img, size=[img_shape, img_shape])

    # Rescale the image (get all values between 0 and 1)
    img = img / 255.
    return img


# Adjust function to work with multi-class
def pred_and_plot(model, filename, class_names):
    """
    Import the image located at filename, make a prediction on it with a
    trained model and plot the image with the predicted class as the title.

    Args:
        model: object whose predict() is called on a batch holding one image.
        filename: path of the image file to classify.
        class_names: sequence of labels, indexed by the position of the
            highest value in the model output.
    """
    # Import the target image and preprocess it
    img = load_and_prep_image(filename)

    # Make a prediction (the model is given a batch containing one image)
    pred = model.predict(tf.expand_dims(img, axis=0))

    # Get the predicted class: if more than one output, take the max
    pred_class = class_names[pred.argmax()]

    # Plot the image and predicted class
    plt.imshow(img)
    plt.title(f"Prediction: {pred_class}")
    plt.axis(False)
    plt.show()


pred_and_plot(loaded_model, "/Users/rosh/Downloads/egret.jpg", class_names)

# # loaded_model.compile(loss='categorical_crossentropy',
# #                      optimizer='adam',
# #                      metrics=['accuracy'])

# # Get true labels
# valid_datagen = ImageDataGenerator(
#     rescale=1./255  # Rescaling factor
# )
# valid_dir = "/Users/rosh/Downloads/Validation_data"
# valid_data = valid_datagen.flow_from_directory(directory=valid_dir,
#                                                batch_size=32,
#                                                target_size=(224, 224),
#                                                class_mode="categorical",
#                                                seed=42)
# pred = loaded_model.predict(valid_data)
# preds = pred.argmax(axis=1)
# print(preds)