amosfang commited on
Commit
8a60603
1 Parent(s): d053659

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import matplotlib.pyplot as plt
2
  import numpy as np
3
  from PIL import Image
4
  from skimage.transform import resize
@@ -7,6 +6,9 @@ from tensorflow.keras.models import load_model
7
 
8
  from huggingface_hub import snapshot_download
9
 
 
 
 
10
  import gradio as gr
11
  import os
12
  import io
@@ -62,13 +64,24 @@ def get_predictions(y_prediction_encoded):
62
  return predicted_label_indices
63
 
64
  def predict(image):
 
 
65
  sample_image_resized = resize_image(image)
66
  y_pred = ensemble_predict(sample_image_resized)
67
  y_pred = get_predictions(y_pred).squeeze()
68
 
 
 
 
 
 
 
 
 
 
69
  # Create a figure without saving it to a file
70
  fig, ax = plt.subplots()
71
- cax = ax.imshow(y_pred, cmap='viridis', vmin=1, vmax=7)
72
 
73
  # Convert the figure to a PIL Image
74
  image_buffer = io.BytesIO()
 
 
1
  import numpy as np
2
  from PIL import Image
3
  from skimage.transform import resize
 
6
 
7
  from huggingface_hub import snapshot_download
8
 
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.colors import ListedColormap
11
+
12
  import gradio as gr
13
  import os
14
  import io
 
64
  return predicted_label_indices
65
 
66
  def predict(image):
67
+
68
+ # Steps to get prediction
69
  sample_image_resized = resize_image(image)
70
  y_pred = ensemble_predict(sample_image_resized)
71
  y_pred = get_predictions(y_pred).squeeze()
72
 
73
+ # Define your custom colors for each label
74
+ colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
75
+ # Create a ListedColormap
76
+ cmap = ListedColormap(colors)
77
+ # Create colorbar and set ticks and ticklabels
78
+ cbar = plt.colorbar(ticks=np.arange(1, 8))
79
+ cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
80
+
81
+
82
  # Create a figure without saving it to a file
83
  fig, ax = plt.subplots()
84
+ cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=7)
85
 
86
  # Convert the figure to a PIL Image
87
  image_buffer = io.BytesIO()