# proki-demo/processing/prediction.py
# Repository page metadata: user ifw-arz, commit "init" (bcad657), 3.81 kB.
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib as mpl
"""The functions in this module are executed by the Streamlit app."""
class CustomModel(nn.Module):
    """Fully connected network that maps a feature vector to a one-channel image.

    The layer layout (``fc_layers`` = Linear -> ReLU -> Linear) is unchanged, so
    state dicts saved from the default configuration still load.

    Args:
        in_features: Number of input features per sample (default 3).
        hidden_features: Width of the hidden layer (default 64).
        height: Size of the third output dimension (default 744).
        width: Size of the fourth output dimension (default 554).
    """

    def __init__(self, in_features=3, hidden_features=64, height=744, width=554):
        super().__init__()
        # Kept as attributes so forward() reshapes to the size the last layer
        # was actually built for.
        self.height = height
        self.width = width
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, height * width),
        )

    def forward(self, x_features):
        """Return a tensor of shape ``(-1, 1, height, width)`` for ``x_features``.

        Args:
            x_features: Tensor whose last dimension has ``in_features`` entries.
        """
        flat_pixels = self.fc_layers(x_features)
        # All leading dimensions are folded into the batch dimension.
        return flat_pixels.view(-1, 1, self.height, self.width)
def predict_image(parameters, model_path="processing/prediction_model.pth"):
    """Predict an image based on parameters (feedrate, depth of cut, toolwear).

    Args:
        parameters: The three process parameters, either as a flat sequence or
            as a tensor/nested sequence whose last dimension has three entries.
        model_path: Path of the saved state dict for ``CustomModel``. Defaults
            to the checkpoint location that was previously hard-coded.

    Returns:
        The model output as a NumPy array; with the default ``CustomModel``
        its shape is ``(N, 1, 744, 554)``.
    """
    # Load the trained model.
    model = CustomModel()
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; the result is converted with .numpy() below, which needs CPU data.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (consider weights_only=True where the installed PyTorch
    # version supports it).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    with torch.no_grad():
        # as_tensor accepts lists and existing tensors alike and avoids the
        # copy-construct warning torch.tensor() raises for tensor input.
        input_features = torch.as_tensor(parameters, dtype=torch.float32)
        predicted_image = model(input_features.unsqueeze(0))
    return predicted_image.numpy()
def image_to_ts(image):
    """Transform an image to a time series and return the plot.

    Only the first row of ``image`` is used. Its values are shifted by -0.5
    and drawn as a line over an x axis that starts at 5 and spans 25 units.

    Args:
        image: Two-dimensional array-like that supports ``image[0, :]``.

    Returns:
        The matplotlib figure containing the plot. It is created through
        pyplot, so the caller should close it when it is no longer needed.
    """
    from matplotlib.ticker import NullFormatter

    # Centre the pixel values around zero. An earlier inverse transformation
    # (vmin=-0.1, vmax=0.1 over a 0..255 range) and a +/-0.085 tolerance band
    # were already disabled in this function and are not applied.
    profile = image[0, :] - 0.5

    # Spread the samples evenly over the x range starting at 5.
    # TODO(review): the original author asked to re-check this reverse
    # transformation for mistakes.
    sample_count = len(profile)
    positions = np.arange(sample_count) / sample_count * 25 + 5

    # NOTE: this changes matplotlib's global defaults, as the function always
    # did. It is done before the figure is created so the first call uses the
    # same defaults as later calls.
    mpl.rcParams['font.family'] = 'Arial'
    mpl.rcParams['font.size'] = 30

    # Plot the time series; everything goes through ``ax`` so the result does
    # not depend on pyplot's notion of the current axes.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.set_ylim(-0.25, 0.25)
    ax.set_xlim(5, 30)
    ax.set_xlabel("Bauteillänge", fontname="Arial", fontsize=16, labelpad=7)
    ax.set_ylabel("Normalisierte Oberfläche", fontname="Arial", fontsize=16, labelpad=7)

    # Fixed x ticks every 5 units; the second-to-last tick shows the unit
    # instead of its number.
    x_ticks = list(range(5, 31, 5))
    unit_tick = x_ticks[-2]
    x_labels = ["mm" if tick == unit_tick else str(tick) for tick in x_ticks]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels, fontname="Arial", fontsize=14, color="black")

    # The y axis keeps its automatic ticks (for the grid) but shows no labels.
    # A NullFormatter does this without pairing fixed labels with a non-fixed
    # locator.
    ax.yaxis.set_major_formatter(NullFormatter())

    ax.grid(axis="both", linewidth=0.75, color="black")

    frame_width = 1.5
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_linewidth(frame_width)
        ax.spines[side].set_color('black')

    ax.plot(positions, profile, color="#00509b", linewidth=2)
    return fig
def NN_prediction(feed, plaindepth, wear):
    """Predict the image for one parameter set and return it as a plot.

    Args:
        feed: First model input (feedrate, per ``predict_image``).
        plaindepth: Second model input (depth of cut, per ``predict_image``).
        wear: Third model input (toolwear, per ``predict_image``).

    Returns:
        The matplotlib figure produced by ``image_to_ts`` for the predicted
        image.
    """
    # Pass plain numbers: predict_image builds the float32 tensor itself, and
    # handing it a ready-made tensor only triggers a copy-construct warning.
    predicted_image = predict_image([feed, plaindepth, wear])
    # Drop the batch and channel dimensions to get the 2-D image.
    return image_to_ts(predicted_image[0, 0])