GAP_SMALL_PROJECT / utils /chart_vis.py
fatimaxa's picture
Upload app files for plant disease detection
c514c69 verified
import matplotlib.pyplot as plt
import numpy as np
def create_prediction_plot(top_preds):
    """Create a horizontal bar chart of predictions.

    Parameters
    ----------
    top_preds : iterable of (label, probability) pairs
        Predictions to plot, in display order: the first pair is drawn at the
        top of the chart. Any iterable is accepted, including one-shot
        iterators such as ``zip(...)`` or generators. Probabilities are
        plotted on a fixed 0-1 x-axis and also drive the bar colour
        (red = low, green = high on the ``RdYlGn`` colormap).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the chart. It is created via ``pyplot`` and is
        not closed here, so the caller is responsible for its lifetime.
    """
    # Materialize once: the input is read twice below (labels and probs), and
    # a one-shot iterator would otherwise be exhausted after the first pass.
    pairs = list(top_preds)
    labels = [label for label, _ in pairs]
    # Force float dtype: matplotlib colormaps treat *integer* arrays as raw
    # lookup-table indices rather than values in [0, 1], which would mis-colour
    # e.g. a one-hot [1, 0, 0] input.
    probs = np.asarray([prob for _, prob in pairs], dtype=float)

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 6))

    # Create horizontal bar chart, coloured by confidence
    y_pos = np.arange(len(labels))
    colors = plt.cm.RdYlGn(probs)
    ax.barh(y_pos, probs, color=colors, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # Top prediction at the top
    ax.set_xlabel('Confidence Score', fontsize=12)
    ax.set_title('Top Disease Predictions', fontsize=14, fontweight='bold')
    ax.set_xlim([0, 1])

    # Add value labels just to the right of each bar's end
    for y, prob in zip(y_pos, probs):
        ax.text(prob + 0.01, y, f'{prob:.3f}', va='center', fontsize=10)

    # Lay out the figure we built, not whatever pyplot considers "current"
    fig.tight_layout()
    return fig