"""Model definition, prediction and plotting helpers executed by the Streamlit app."""

import functools

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image  # NOTE(review): not used in this module
from torchvision import transforms  # NOTE(review): not used in this module

# Default location of the trained weights (relative to the working directory).
MODEL_PATH = "processing/prediction_model.pth"

# Network / output geometry. These must match the saved checkpoint.
N_FEATURES = 3  # feedrate, depth of cut, tool wear
HIDDEN_UNITS = 64
IMAGE_HEIGHT = 744
IMAGE_WIDTH = 554


class CustomModel(nn.Module):
    """Fully connected network mapping 3 process parameters to a 1x744x554 image.

    The attribute name ``fc_layers`` and the layer order determine the
    state-dict keys stored in the checkpoint, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(N_FEATURES, HIDDEN_UNITS),
            nn.ReLU(),
            nn.Linear(HIDDEN_UNITS, IMAGE_HEIGHT * IMAGE_WIDTH),
        )

    def forward(self, x_features):
        """Return predictions of shape ``(batch, 1, 744, 554)``.

        ``x_features`` must end in a dimension of size 3; any leading
        dimensions are folded into the batch by ``view(-1, ...)``.
        """
        flat = self.fc_layers(x_features)
        return flat.view(-1, 1, IMAGE_HEIGHT, IMAGE_WIDTH)


@functools.lru_cache(maxsize=1)
def _load_model(model_path):
    """Load the trained ``CustomModel`` from *model_path* once and reuse it.

    The model is cached for the lifetime of the process, so a checkpoint
    replaced on disk is only picked up after a restart.
    """
    model = CustomModel()
    # map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_image(parameters, model_path=MODEL_PATH):
    """Predict an image from process parameters (feedrate, depth of cut, tool wear).

    Args:
        parameters: Three numbers, or a batch of them shaped ``(N, 3)``
            (list, numpy array or tensor).
        model_path: Path of the trained state dict.

    Returns:
        ``numpy.ndarray`` of shape ``(N, 1, 744, 554)``; ``N`` is 1 for a
        single parameter triple.
    """
    model = _load_model(model_path)
    # as_tensor avoids the copy-construct warning torch.tensor() emits for
    # tensor inputs; reshape normalises flat and batched input to (N, 3).
    features = torch.as_tensor(parameters, dtype=torch.float32)
    features = features.reshape(-1, N_FEATURES)
    with torch.no_grad():
        predicted_image = model(features)
    return predicted_image.numpy()


def image_to_ts(image):
    """Plot the first row of a predicted image as a surface profile.

    Args:
        image: 2-D array, e.g. ``predict_image(...)[0, 0]``. Only the first
            row, ``image[0, :]``, is used.

    Returns:
        The ``matplotlib.figure.Figure`` holding the plot. It is created via
        ``plt.subplots`` and stays registered with pyplot; callers that
        produce many figures should ``plt.close(fig)`` after rendering it.
    """
    # Shift by -0.5, presumably the inverse of the image encoding used in
    # training. TODO(review): the original author flagged this inverse
    # transformation with "check again for mistakes" - confirm.
    values = image[0, :] - 0.5
    # Spread the samples evenly from 5 mm (inclusive) towards 30 mm.
    x = np.arange(len(values)) / len(values) * 25 + 5

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.set_xlim(5, 30)
    ax.set_ylim(-0.25, 0.25)

    # Fonts are set per element instead of through global rcParams so other
    # figures in the app are not affected.
    ax.set_xlabel("Bauteillänge", fontname="Arial", fontsize=16, labelpad=7)
    ax.set_ylabel(
        "Normalisierte Oberfläche", fontname="Arial", fontsize=16, labelpad=7
    )

    # x ticks every 5 mm; the second-to-last label doubles as the unit label.
    tick_positions = list(range(5, 31, 5))
    tick_labels = [str(pos) for pos in tick_positions]
    tick_labels[-2] = "mm"
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontname="Arial", fontsize=14, color="black")

    # Keep the y tick marks (and their grid lines) but hide the numbers.
    ax.tick_params(axis="y", labelleft=False)

    ax.grid(True, linewidth=0.75, color="black")
    spine_width = 1.5
    for spine in ax.spines.values():
        spine.set_linewidth(spine_width)
        spine.set_color("black")

    ax.plot(x, values, color="#00509b", linewidth=2)
    # TODO: optionally shade a +/-0.085 tolerance band (ax.fill_between) and
    # report whether the whole profile stays inside it.
    return fig


def NN_prediction(feed, plaindepth, wear):
    """Predict the surface for one parameter set and return its profile plot.

    Args:
        feed: Feedrate.
        plaindepth: Depth of cut.
        wear: Tool wear.

    Returns:
        The figure produced by ``image_to_ts``.
    """
    predicted_image = predict_image([feed, plaindepth, wear])
    return image_to_ts(predicted_image[0, 0])