# Hugging Face Spaces status when this file was captured: Runtime error
"""Train a small CNN on the TensorFlow flowers photos and serve it with Gradio."""
import os
import pathlib

import gradio as gr
import numpy as np
import PIL
import PIL.Image
import tensorflow as tf
# NOTE(review): tfds is not used in this file; if the Space fails at import
# time, confirm the package is installed or drop this import.
import tensorflow_datasets as tfds

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
# Downloads (and caches) the archive, then extracts it; returns the local path.
data_dir = tf.keras.utils.get_file(
    origin=dataset_url,
    fname='flower_photos',
    untar=True,
)
data_dir = pathlib.Path(data_dir)

batch_size = 32
img_height = 180
img_width = 180

# Both subsets MUST use the same split fraction and seed; otherwise the
# training and validation files would overlap.
validation_split = 0.2
split_seed = 123


def _load_split(subset):
    """Load the ``subset`` ("training" or "validation") part of ``data_dir``.

    Labels are inferred by Keras from the sub-directory names under
    ``data_dir``; images are resized to ``(img_height, img_width)``.
    """
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset=subset,
        seed=split_seed,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = _load_split("training")
val_ds = _load_split("validation")
# Index i of this list is the name of integer label i produced by the loader.
class_names = train_ds.class_names

# No separate normalization pass is needed here: the model's first layer is
# Rescaling(1./255), so the pipeline keeps raw 0-255 pixel values.

AUTOTUNE = tf.data.AUTOTUNE
# cache() keeps decoded batches after the first pass; prefetch() overlaps
# input preparation with training steps.
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# Derive the head width from the data rather than hard-coding it, so the
# output layer always matches the number of label indices.
num_classes = len(class_names)
epochs = 5

model = tf.keras.Sequential([
    # Inputs arrive as raw 0-255 pixels; scaling inside the model means the
    # same preprocessing applies at training and inference time.
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    # Softmax output: predictions are per-class probabilities.
    tf.keras.layers.Dense(num_classes, activation="softmax"),
])

model.compile(
    optimizer='adam',
    # Integer labels vs. softmax probabilities (not logits).
    loss='SparseCategoricalCrossentropy',
    metrics=['accuracy'],
)

# Trains at startup, before the UI is launched.
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
)
def predict_input_image(img):
    """Classify one uploaded image and return per-class confidences.

    Args:
        img: image array from the Gradio ``Image`` component, or ``None`` when
            the user submits without uploading. Presumably an ``(H, W, 3)``
            RGB array (the component is configured with ``image_mode="RGB"``)
            -- TODO confirm for unusual uploads.

    Returns:
        dict mapping every entry of ``class_names`` to its softmax probability,
        or an empty dict when no image was supplied.
    """
    if img is None:
        return {}
    # Resize arbitrary uploads to the training resolution; a plain reshape
    # only works when the image is already exactly img_height x img_width.
    resized = tf.image.resize(img, (img_height, img_width))
    batch = tf.expand_dims(resized, axis=0)
    # The model rescales 0-255 pixels itself, so no normalization here.
    probabilities = model.predict(batch, verbose=0)[0]
    # Softmax outputs are already probabilities; report them unscaled.
    return {name: float(p) for name, p in zip(class_names, probabilities)}


# gr.inputs / gr.outputs were removed in Gradio 4; use the top-level
# components. Resizing is done in predict_input_image instead of via `shape`.
image = gr.Image(type="numpy", image_mode="RGB")
label = gr.Label(num_top_classes=len(class_names))
demo = gr.Interface(
    fn=predict_input_image,
    inputs=image,
    outputs=label,
    title="Flowers Image classification",
)
demo.launch()