rca_demo / main.py
ifw-arz's picture
Update main.py
92e7b0b
raw
history blame contribute delete
No virus
4.13 kB
import functools
import os
import shutil
from typing import List

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from PIL import Image
# FastAPI application instance; the route decorators in this module register
# their endpoints on it.
app = FastAPI()
class CustomModel(nn.Module):
    """Two-layer MLP that decodes a feature vector into a single-channel image.

    The defaults reproduce the architecture that ``predict_image`` loads
    ``trained_model.pth`` into (3 inputs -> 64 hidden units -> 744 x 554
    output pixels). The layers live under the attribute name ``fc_layers`` so
    that existing ``state_dict`` keys (``fc_layers.0.*``, ``fc_layers.2.*``)
    keep loading.

    Args:
        in_features: Length of the input feature vector (default 3).
        hidden_features: Width of the hidden layer (default 64).
        height: Height of the decoded image (default 744).
        width: Width of the decoded image (default 554).
    """

    def __init__(self, in_features=3, hidden_features=64, height=744, width=554):
        super().__init__()
        # Plain ints are not registered as parameters, so the state_dict is
        # unchanged by storing the output size on the module.
        self.height = height
        self.width = width
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, height * width),
        )

    def forward(self, x_features):
        """Decode ``x_features`` into images of shape ``(N, 1, height, width)``.

        The last dimension of ``x_features`` must equal ``in_features``; all
        leading dimensions are folded into the batch dimension ``N`` (a 1-D
        input yields ``N == 1``).
        """
        x_features = self.fc_layers(x_features)
        return x_features.view(-1, 1, self.height, self.width)
@functools.lru_cache(maxsize=1)
def _load_model(model_path):
    """Build a CustomModel, load weights from ``model_path``, set eval mode (cached)."""
    model = CustomModel()
    # map_location="cpu" lets weights that were saved from a GPU load on a
    # CPU-only host; load_state_dict copies them into the CPU model anyway.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model


def predict_image(parameters, model_path="trained_model.pth"):
    """Predict an image from process parameters (feedrate, depth of cut, toolwear).

    Args:
        parameters: Sequence (or tensor/array) of the three input features.
            It is converted to float32 and given a leading batch dimension.
        model_path: Path of the weights file to load (default
            ``"trained_model.pth"``). The loaded model is cached, so the file
            is read once per path rather than on every call; restart the
            process to pick up a replaced weights file.

    Returns:
        numpy.ndarray of the model output; for a single 3-element parameter
        vector and the default model this has shape ``(1, 1, 744, 554)``.
    """
    # The cached model is shared by all callers; do not mutate it.
    model = _load_model(model_path)
    with torch.no_grad():
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when ``parameters`` is already a tensor.
        input_features = torch.as_tensor(parameters, dtype=torch.float32)
        predicted_image = model(input_features.unsqueeze(0))
    return predicted_image.numpy()
def image_to_ts(image, tolerance_lower=-0.085, tolerance_upper=0.085):
    """Plot the first row of an image as a surface profile and check tolerance.

    ``image[0, :]`` is shifted down by 0.5 and plotted against x positions
    spread evenly from 5 (inclusive) to 30 (exclusive); the x axis spans 5-30
    and its second-to-last tick is relabelled "mm". A grey band marks the
    tolerance range.

    Args:
        image: 2-D array-like; only its first row is used.
        tolerance_lower: Lower edge of the tolerance band (default -0.085).
        tolerance_upper: Upper edge of the tolerance band (default 0.085).

    Returns:
        ``(fig, within_tolerance)``: the Matplotlib figure, and a bool that is
        True only if every shifted value lies in
        ``[tolerance_lower, tolerance_upper]`` (NaN counts as out of
        tolerance). The caller is responsible for closing ``fig``.
    """
    # Centre the profile on zero. NOTE(review): presumably pixel values are
    # normalised to [0, 1] -- confirm against the training data.
    values = np.asarray(image[0, :]) - 0.5
    x = np.arange(len(values)) / len(values) * 25 + 5

    # Scope the font settings to this figure instead of mutating the global
    # rcParams; previously they were set after the first figure was created,
    # so the first call rendered with different defaults than later calls.
    with mpl.rc_context({"font.family": "Arial", "font.size": 30}):
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.set_ylim(-0.2, 0.2)
        ax.set_xlim(5, 30)
        ax.set_xlabel("Workpiece length", fontname="Arial", fontsize=16, labelpad=7)
        ax.set_ylabel("Normalized surface height", fontname="Arial", fontsize=16, labelpad=7)

        # Work on ``ax`` directly rather than pyplot's implicit current axes.
        ax.set_xticks(range(5, 31, 5))
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set(fontname="Arial", fontsize=14, color="black")
        # Replace the second-to-last x tick (25) with the unit label "mm".
        xticks = ax.get_xticks()
        ax.set_xticklabels([str(int(t)) if t != xticks[-2] else "mm" for t in xticks])
        # Hide the y tick labels but keep the ticks so the y grid lines stay.
        ax.tick_params(axis="y", labelleft=False)

        ax.grid(axis="y", linewidth=0.75, color="black")
        ax.grid(axis="x", linewidth=0.75, color="black")
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_linewidth(1.5)
            ax.spines[side].set_color("black")

        ax.plot(x, values, color="#00509b", linewidth=1.5)
        ax.fill_between(x, tolerance_lower, tolerance_upper, color="gray", alpha=0.2)

    # Vectorised check; NaN fails both comparisons, so it is out of tolerance,
    # and an empty profile is trivially within tolerance (as with all()).
    within_tolerance = bool(np.all((values >= tolerance_lower) & (values <= tolerance_upper)))
    return fig, within_tolerance
def save_figure(fig, filename):
    """Save ``fig`` to ``filename`` and close it.

    The figure is closed even when saving raises, so a failed write does not
    leave the figure registered with pyplot (which would otherwise accumulate
    in a long-running server).
    """
    try:
        fig.savefig(filename)
    finally:
        plt.close(fig)
@app.post("/predict")
def prediction(feedrate: float, depth_of_cut: float, toolwear: float):
    """Predict a surface profile for the given process parameters.

    FastAPI reads ``feedrate``, ``depth_of_cut`` and ``toolwear`` as query
    parameters. They are fed to ``predict_image``; the first row of the
    predicted image is plotted by ``image_to_ts``, written to
    ``output_figure.png`` and made downloadable via ``/get_figure/{filename}``.

    Returns:
        dict with ``figure_url`` (relative URL of the saved plot) and
        ``within_tolerance`` (bool reported by ``image_to_ts``).
    """
    # predict_image documents a plain parameter sequence; pre-wrapping it in a
    # tensor only made it copy-construct a tensor from a tensor.
    new_features = [feedrate, depth_of_cut, toolwear]
    predicted_image = predict_image(new_features)
    fig, tolerance = image_to_ts(predicted_image[0, 0])
    figure_filename = "output_figure.png"
    # Saving also closes the figure. Previously this call was commented out,
    # so figure_url pointed at a missing or stale file and every request left
    # one Matplotlib figure open.
    # NOTE(review): the filename is shared by all requests, so concurrent
    # calls can overwrite each other's plot.
    save_figure(fig, figure_filename)
    figure_url = f"/get_figure/{figure_filename}"
    return {"figure_url": figure_url, "within_tolerance": tolerance}
@app.get("/get_figure/{filename}")
def get_figure(filename: str):
    """Serve a previously saved PNG figure from the working directory.

    ``filename`` comes straight from the request URL, so only plain ``*.png``
    file names (no directory components) that exist are served. Previously any
    file name in the working directory, such as the model weights loaded by
    ``predict_image``, could be downloaded through this endpoint.

    Raises:
        HTTPException: 404 if ``filename`` is not a plain ``.png`` name or the
            file does not exist.
    """
    is_plain_name = os.path.basename(filename) == filename
    if not (is_plain_name and filename.endswith(".png") and os.path.isfile(filename)):
        raise HTTPException(status_code=404, detail="Figure not found")
    return FileResponse(filename, media_type="image/png", filename=filename)
@app.get("/")
def root():
    """Return a placeholder body for the service root.

    "Doku hier" is German for "documentation here".
    """
    # The original returned the set literal {"Doku hier"}, which reads like a
    # dict with a missing value. FastAPI encodes sets as JSON arrays, so an
    # explicit list keeps the response body (["Doku hier"]) unchanged.
    return ["Doku hier"]
@app.get("/test")
def test():
    """Smoke-test endpoint returning a fixed placeholder body.

    "das ist ein Test" is German for "this is a test".
    """
    # The original returned a one-element set literal; FastAPI encodes sets as
    # JSON arrays, so an explicit list keeps the response body unchanged.
    return ["das ist ein Test"]