Spaces:
Sleeping
Sleeping
import functools

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
""" These functions are executed in the streamlit app """ | |
class CustomModel(nn.Module):
    """Fully connected generator that maps process parameters to a 1-channel image.

    Architecture: Linear(in_features -> hidden_features) -> ReLU ->
    Linear(hidden_features -> height * width). The flat output is reshaped
    to ``(batch, 1, height, width)``.

    The defaults reproduce the original hard-coded model (3 inputs, 64 hidden
    units, 744 x 554 output), so existing checkpoints load unchanged. The
    ``state_dict`` keys are ``fc_layers.{0,2}.{weight,bias}``.

    Args:
        in_features: Number of input parameters per sample. The app passes
            three (feed rate, depth of cut, tool wear).
        hidden_features: Width of the hidden layer.
        height: Output image height in pixels.
        width: Output image width in pixels.
    """

    def __init__(self, in_features=3, hidden_features=64, height=744, width=554):
        super().__init__()
        # Keep the image size on the instance so forward() reshapes with the
        # same dimensions the final Linear layer was built with.
        self.height = height
        self.width = width
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, height * width),
        )

    def forward(self, x_features):
        """Run the MLP and reshape the result into images.

        Args:
            x_features: Tensor whose last dimension has ``in_features``
                entries. All leading dimensions are flattened into the output
                batch dimension by ``view(-1, ...)``.

        Returns:
            Tensor of shape ``(batch, 1, height, width)``.
        """
        x_features = self.fc_layers(x_features)
        return x_features.view(-1, 1, self.height, self.width)
@functools.lru_cache(maxsize=4)
def _load_model(model_path):
    """Build a CustomModel, load its weights from *model_path*, and cache it.

    Cached per path, so repeated predictions do not re-read and
    re-deserialize the checkpoint on every call. The returned model is
    shared between callers and left in eval mode. Do not train or modify
    it in place.
    """
    model = CustomModel()
    # map_location="cpu": the model runs on CPU here, and this also lets a
    # checkpoint that was saved from a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_image(parameters, model_path="processing/prediction_model.pth"):
    """Predict an image from process parameters (feed rate, depth of cut, tool wear).

    Args:
        parameters: Array-like or tensor whose last dimension holds the three
            process parameters. It is converted to float32.
        model_path: Path to a saved CustomModel ``state_dict``. Defaults to
            the path that was previously hard-coded here.

    Returns:
        numpy.ndarray of shape ``(batch, 1, height, width)`` from the model's
        forward pass (744 x 554 images for the default model).
    """
    model = _load_model(model_path)
    with torch.no_grad():
        # as_tensor (unlike torch.tensor) does not warn when `parameters` is
        # already a tensor, and it still casts to float32.
        input_features = torch.as_tensor(parameters, dtype=torch.float32)
        # Add a leading batch axis. CustomModel.forward flattens all leading
        # dims into the output batch dimension via view(-1, ...).
        predicted_image = model(input_features.unsqueeze(0))
    return predicted_image.numpy()
def _image_row_to_series(image):
    """Return ``(x, y)`` for the profile plot.

    ``y`` is the first image row shifted by -0.5. ``x`` is evenly spaced
    over [5, 30) with one point per value.
    """
    # TODO(review): the original author flagged this inverse transform as
    # unverified ("Check again for mistakes"). An earlier, disabled variant
    # rescaled pixel values into [vmin, vmax] = [-0.1, 0.1] instead.
    values = image[0, :] - 0.5
    count = len(values)
    x = np.arange(count) / count * 25 + 5
    return x, values


def image_to_ts(image):
    """Transform an image into a time series (surface profile) and return the plot.

    Only the first row of *image* is used (see ``_image_row_to_series``).
    All drawing goes through the returned figure's own Axes rather than
    pyplot's global "current axes". With concurrent sessions (e.g. Streamlit
    threads), pyplot state calls could otherwise draw onto another
    session's figure.

    Args:
        image: 2-D array supporting ``image[0, :]``. NN_prediction passes
            one predicted image, ``predict_image(...)[0, 0]``.

    Returns:
        matplotlib.figure.Figure holding the profile plot. The figure is
        created via ``plt.subplots`` and so stays registered with pyplot.
        Callers that render many figures should ``plt.close(fig)`` afterwards.
    """
    x, profile = _image_row_to_series(image)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.set_ylim(-0.25, 0.25)
    ax.set_xlim(5, 30)

    # NOTE(review): these rcParams writes are process-global and affect every
    # figure created afterwards. They are kept unchanged for compatibility.
    mpl.rcParams['font.family'] = 'Arial'
    mpl.rcParams['font.size'] = 30

    ax.set_xlabel("Bauteillänge", fontname="Arial", fontsize=16, labelpad=7)
    ax.set_ylabel("Normalisierte Oberfläche", fontname="Arial", fontsize=16, labelpad=7)

    # x ticks every 5 units; the second-to-last label is replaced by the unit "mm".
    tick_positions = list(range(5, 31, 5))
    tick_labels = [str(pos) for pos in tick_positions]
    tick_labels[-2] = "mm"
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontname="Arial", fontsize=14, color="black")

    # Hide the y tick labels but keep tick marks and gridlines. This uses
    # tick_params rather than set_yticklabels, which warns when the locator
    # is not fixed.
    ax.tick_params(axis="y", labelleft=False)

    ax.grid(True, axis="both", linewidth=0.75, color="black")

    spine_width = 1.5
    for spine in ax.spines.values():
        spine.set_linewidth(spine_width)
        spine.set_color('black')

    ax.plot(x, profile, color="#00509b", linewidth=2)
    return fig
def NN_prediction(feed, plaindepth, wear):
    """Predict a surface image for one parameter set and plot its profile.

    Args:
        feed: Feed rate.
        plaindepth: Depth of cut.
        wear: Tool wear.

    Returns:
        The matplotlib figure produced by ``image_to_ts`` for the predicted image.
    """
    # Pass a plain nested list. Pre-building a tensor here made
    # predict_image's torch.tensor(...) emit a copy-construct UserWarning.
    new_features = [[feed, plaindepth, wear]]
    predicted_image = predict_image(new_features)
    # predict_image returns (batch, 1, height, width) with batch == 1 here;
    # [0, 0] selects the single 2-D image.
    return image_to_ts(predicted_image[0, 0])