andrewzamp commited on
Commit
9f198ef
1 Parent(s): 1364165

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -1
app.py CHANGED
@@ -1,3 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Define the Gradio interface
2
  interface = gr.Interface(
3
  fn=make_prediction, # Function to be called for predictions
@@ -11,6 +110,6 @@ interface = gr.Interface(
11
  # Launch the Gradio interface with authentication for the specified users
12
  interface.launch(auth=[
13
  ("Andrea Zampetti", "andreazampetti"),
14
- ("Luca Santini", "lucasantini"),
15
  ("Ana Ben铆tez L贸pez", "anaben铆tezl贸pez")
16
  ])
 
1
+ # Import the libraries
2
+ import numpy as np
3
+ import pandas as pd
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
6
+ from tensorflow.keras.applications.convnext import preprocess_input
7
+ import gradio as gr
8
+
9
+ # Load the model
10
+ model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')
11
+
12
+ # Load the taxonomy .csv
13
+ taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv')
14
+ taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')
15
+
16
+ # Available taxonomic levels
17
+ taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
+
19
+ # Function to map predicted class index to class name at the selected taxonomic level
20
+ def get_class_name(predicted_class, taxonomic_level):
21
+ unique_labels = sorted(taxo_df[taxonomic_level].unique())
22
+ return unique_labels[predicted_class]
23
+
24
+ # Function to aggregate predictions to a higher taxonomic level
25
+ def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
26
+ unique_labels = sorted(taxo_df[taxonomic_level].unique())
27
+ aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
28
+
29
+ for idx, row in taxo_df.iterrows():
30
+ species = row['species']
31
+ higher_level = row[taxonomic_level]
32
+
33
+ species_index = class_names.index(species) # Index of the species in the prediction array
34
+ higher_level_index = unique_labels.index(higher_level)
35
+
36
+ aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
37
+
38
+ return aggregated_predictions, unique_labels
39
+
40
+ # Function to load and preprocess the image
41
+ def load_and_preprocess_image(image, target_size=(224, 224)):
42
+ # Resize the image
43
+ img_array = img_to_array(image.resize(target_size))
44
+ # Expand the dimensions to match model input
45
+ img_array = np.expand_dims(img_array, axis=0)
46
+ # Preprocess the image
47
+ img_array = preprocess_input(img_array)
48
+ return img_array
49
+
50
+ # Function to make predictions
51
+ def make_prediction(image, taxonomic_level):
52
+ # Preprocess the image
53
+ img_array = load_and_preprocess_image(image)
54
+
55
+ # Get the class names from the 'species' column
56
+ class_names = sorted(taxo_df['species'].unique()) # Add this line to define class_names
57
+
58
+ # Make a prediction
59
+ prediction = model.predict(img_array)
60
+
61
+ # Aggregate predictions based on the selected taxonomic level
62
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
63
+
64
+ # Get the top 5 predictions
65
+ top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
66
+
67
+ # Get predicted class for the top prediction
68
+ predicted_class_index = np.argmax(aggregated_predictions)
69
+ predicted_class_name = aggregated_class_labels[predicted_class_index]
70
+
71
+ # Check if common name should be displayed (only at species level)
72
+ if taxonomic_level == "species":
73
+ predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
74
+ output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
75
+ else:
76
+ output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
77
+
78
+ # Add the top 5 predictions
79
+ output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
80
+
81
+ for i in top_indices:
82
+ class_name = aggregated_class_labels[i]
83
+
84
+ if taxonomic_level == "species":
85
+ # Display common names only at species level and make it italic
86
+ common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
87
+ confidence_percentage = aggregated_predictions[0][i] * 100
88
+ output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
+ f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
90
+ f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
91
+ else:
92
+ # No common names at higher taxonomic levels
93
+ confidence_percentage = aggregated_predictions[0][i] * 100
94
+ output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
+ f"<span>{class_name}</span>" \
96
+ f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
97
+
98
+ return output_text
99
+
100
  # Define the Gradio interface
101
  interface = gr.Interface(
102
  fn=make_prediction, # Function to be called for predictions
 
110
  # Launch the Gradio interface with authentication for the specified users
111
  interface.launch(auth=[
112
  ("Andrea Zampetti", "andreazampetti"),
113
+ ("Luca Santini", "lucasantini"),
114
  ("Ana Ben铆tez L贸pez", "anaben铆tezl贸pez")
115
  ])