brxerq commited on
Commit
612620a
1 Parent(s): 6c71344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -31
app.py CHANGED
@@ -1,45 +1,43 @@
1
- # -*- coding: utf-8 -*-
 
2
  import gradio as gr
3
  import numpy as np
4
  import tensorflow as tf
 
5
  import cv2
6
- import tensorflow_hub as hub
7
 
8
- # Load class labels from the text file
 
 
 
9
  train_info = []
10
  with open('labelwithspace.txt', 'r') as file:
11
- train_info = [line.strip() for line in file.readlines()]
12
-
13
- # Load your actual model from the .h5 file
14
- def load_real_model():
15
- try:
16
- # Register KerasLayer from TensorFlow Hub if used
17
- custom_objects = {'KerasLayer': hub.KerasLayer}
18
- model = tf.keras.models.load_model('bird_model4.h5', custom_objects=custom_objects)
19
- except Exception as e:
20
- print("Error loading the model:", e)
21
- exit()
22
- return model
23
 
24
- # Initialize the real model
25
- model = load_real_model()
26
-
27
- # Function to preprocess the image and make predictions
28
- def predict_image(image):
29
- # Resize and normalize the image
30
  img = cv2.resize(image, (224, 224))
31
- img = img / 255.0 # Normalize to [0, 1] range
 
32
 
33
- # Make predictions using the loaded model
34
- predictions = model.predict(img[np.newaxis, ...])[0]
35
- top_classes = np.argsort(predictions)[-3:][::-1] # Indices of top 3 predictions
36
- top_class = top_classes[0] # Index of the highest probability class
37
- label = train_info[top_class] # Get the corresponding label
 
 
 
 
 
 
 
38
  return label
39
 
40
  # Define the Gradio interface
41
- input_image = gr.Image()
42
- output_label = gr.Label()
43
 
44
- # Launch the Gradio interface for Hugging Face deployment
45
- gr.Interface(fn=predict_image, inputs=input_image, outputs=output_label).launch()
 
1
+ # app.py
2
+
3
  import gradio as gr
4
  import numpy as np
5
  import tensorflow as tf
6
+ from tensorflow.keras.models import load_model
7
  import cv2
 
8
 
9
+ # Load the pre-trained model
10
+ model = load_model('bird_model.h5')
11
+
12
+ # Load class labels from your text file
13
  train_info = []
14
  with open('labelwithspace.txt', 'r') as file:
15
+ train_info = [line.strip() for line in file.read().splitlines()]
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Function to preprocess the input image
18
+ def preprocess_image(image):
19
+ # Resize the image to the input size expected by the model
 
 
 
20
  img = cv2.resize(image, (224, 224))
21
+ img = img / 255.0 # Normalize the image
22
+ return img
23
 
24
+ # Prediction function
25
+ def predict_image(image):
26
+ # Preprocess the image
27
+ img = preprocess_image(image)
28
+ # Expand dimensions to match the model's input shape
29
+ img = np.expand_dims(img, axis=0)
30
+ # Get model predictions
31
+ predictions = model.predict(img)[0]
32
+ # Find the top prediction
33
+ top_class = np.argmax(predictions)
34
+ # Get the label for the top prediction
35
+ label = train_info[top_class]
36
  return label
37
 
38
  # Define the Gradio interface
39
+ input_image = gr.Image(shape=(224, 224), label="Input Image")
40
+ output_label = gr.Label(label="Predicted Bird Species")
41
 
42
+ # Launch the Gradio app
43
+ gr.Interface(fn=predict_image, inputs=input_image, outputs=output_label, capture_session=True).launch()