import matplotlib.pyplot as plt
import numpy as np


# Creates horizontal bar plots based on the classification results.
def horizontal_bar_plot(index, confidence, label):
    """Draw a one-row gray bar with a square marker at the predicted class.

    The bar spans 0..1 on the x-axis; the marker is placed at the left
    (0.1), middle (0.5) or right (0.9) depending on ``index``.

    Args:
        index: Predicted class index. ``0`` puts the marker on the left,
            ``1`` in the middle, any other value on the right.
        confidence: Score of the predicted class. Accepted for interface
            compatibility; it is currently not rendered in the figure.
        label: Text used as the figure title.

    Returns:
        The matplotlib figure containing the plot. It is created through
        ``plt.subplots()``, so the caller is presumably responsible for
        closing it (``plt.close(fig)``) once it has been displayed/saved.
    """
    # NOTE: this changes the global matplotlib font setting, so it also
    # affects figures created elsewhere after this call.
    plt.rcParams["font.family"] = "Arial"

    fig, ax = plt.subplots()
    fig.set_size_inches(8, 0.9)

    # Background bar covering the full 0..1 range.
    ax.barh([0], [1], color='lightgray', edgecolor='black', linewidth=3)

    if index == 0:
        marker_x = 0.1
    elif index == 1:
        marker_x = 0.5
    else:
        marker_x = 0.9

    # Set the title on this figure's axes explicitly instead of going
    # through pyplot's "current axes" state (plt.title).
    ax.set_title(label, fontsize=16, fontweight='light', color='black',
                 loc='center')
    ax.scatter([marker_x], [0], color='#00509b', marker="s", s=2000,
               edgecolors='black', linewidth=2)

    # Leave horizontal room so the large square marker is not clipped.
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.5, 0.5)

    # Hide ticks and the axes frame; only the bar and marker stay visible.
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    return fig


def find_max_index(array):
    """Return ``(index, value)`` of the largest element in ``array``.

    The index comes from ``np.argmax`` and the value is ``array[index]``.

    Raises:
        ValueError: If ``array`` is empty (raised by ``np.argmax``).
    """
    max_index = np.argmax(array)
    return max_index, array[max_index]


def check_classification(array):
    """Create one indicator figure per classification group.

    ``array`` is split into three consecutive groups of three entries:
    ``[0:3]`` (titled "Vorschub"), ``[3:6]`` (titled "Schnitttiefe") and
    ``[6:9]`` (titled "Verschleiß"). For each group the position of its
    largest entry determines where the marker is drawn.

    Args:
        array: Sliceable sequence of scores with at least one entry in
            each of the three groups.

    Returns:
        Tuple ``(fig_feed, fig_depth, fig_condition)`` of matplotlib figures.

    Raises:
        ValueError: If one of the groups is empty (raised by ``np.argmax``).
    """
    # Evaluate all three groups before creating any figure, so that an
    # input that is too short fails without leaving half-built figures
    # registered in pyplot.
    feed_index, feed_score = find_max_index(array[0:3])
    depth_index, depth_score = find_max_index(array[3:6])
    condition_index, condition_score = find_max_index(array[6:9])

    fig_feed = horizontal_bar_plot(feed_index, feed_score, "Vorschub")
    fig_depth = horizontal_bar_plot(depth_index, depth_score, "Schnitttiefe")
    fig_condition = horizontal_bar_plot(
        condition_index, condition_score, "Verschleiß"
    )
    return fig_feed, fig_depth, fig_condition