|
import shutil
from pathlib import Path
from typing import List

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import FileResponse
from PIL import Image
|
|
|
|
|
# ASGI application object; the route decorators in this file register on it.
app = FastAPI()
|
|
|
|
|
class CustomModel(nn.Module):
    """Fully connected network mapping three input features to an image.

    The network is ``Linear(3, 64) -> ReLU -> Linear(64, height * width)``;
    the flat output is reshaped to ``(-1, 1, height, width)``.

    Args:
        height: Number of rows of the generated image.
        width: Number of columns of the generated image.

    The defaults (744 x 554) are the previously hard-coded size, so
    ``CustomModel()`` builds the same layers and the same state-dict keys
    as before.
    """

    def __init__(self, height: int = 744, width: int = 554):
        super().__init__()

        # Kept as plain attributes so forward() reshapes with the same
        # numbers that sized the last linear layer.
        self.height = height
        self.width = width

        self.fc_layers = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, height * width),
        )

    def forward(self, x_features):
        """Run the dense layers and reshape the result into image form.

        Args:
            x_features: Tensor whose last dimension has size 3.

        Returns:
            Tensor of shape ``(-1, 1, height, width)``.
        """
        flat = self.fc_layers(x_features)
        return flat.view(-1, 1, self.height, self.width)
|
|
|
|
|
_MODEL_PATH = 'trained_model.pth'
_cached_model = None


def _get_model():
    """Return the trained model, reading the checkpoint on first use only."""
    global _cached_model
    if _cached_model is None:
        model = CustomModel()
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; the prediction is converted with .numpy() anyway,
        # which requires CPU tensors.
        state_dict = torch.load(_MODEL_PATH, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        _cached_model = model
    return _cached_model


def predict_image(parameters):
    """Predict an image from parameters (feedrate, depth of cut, toolwear).

    Args:
        parameters: The three values as a sequence, array or tensor. An
            input with extra leading dimensions (e.g. shape ``(1, 3)``) is
            accepted as well; rows are treated as a batch.

    Returns:
        numpy.ndarray: The model output, shaped by ``CustomModel.forward``
        as ``(batch, 1, rows, cols)``.
    """
    model = _get_model()

    with torch.no_grad():
        # as_tensor accepts lists and tensors alike without the
        # copy-construct warning torch.tensor() emits for tensor input.
        features = torch.as_tensor(parameters, dtype=torch.float32)
        # Fold every leading dimension into a single batch axis.
        batch = features.reshape(-1, features.shape[-1])
        predicted_image = model(batch)

    return predicted_image.numpy()
|
|
|
def _style_profile_axes(ax):
    """Apply the fixed limits, labels, ticks, grid and frame styling to *ax*."""
    ax.set_ylim(-0.2, 0.2)
    ax.set_xlim(5, 30)

    ax.set_xlabel("Workpiece length", fontname="Arial", fontsize=16, labelpad=7)
    ax.set_ylabel("Normalized surface height", fontname="Arial", fontsize=16, labelpad=7)

    # Ticks every 5 units; the second-to-last tick shows the unit instead of
    # its number.
    xticks = list(range(5, 31, 5))
    xticklabels = [str(tick) for tick in xticks]
    xticklabels[-2] = "mm"
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontname="Arial", fontsize=14, color="black")

    # Keep the y ticks (they position the horizontal grid lines) but show no
    # labels for them.
    ax.tick_params(axis="y", labelleft=False)

    ax.grid(True, axis="both", linewidth=0.75, color="black")

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_linewidth(1.5)
        ax.spines[side].set_color("black")


def image_to_ts(image, *, tolerance_lower=-0.085, tolerance_upper=0.085):
    """Plot the first row of an image as a profile and check its tolerance.

    The first row of *image* is shifted by -0.5 and drawn against x
    positions spread evenly from 5 up to (not including) 30. A gray band
    marks the tolerance range.

    Args:
        image: 2-D array; only row 0 is used.
        tolerance_lower: Lower bound of the tolerance band.
        tolerance_upper: Upper bound of the tolerance band.

    Returns:
        tuple: ``(fig, within_tolerance)`` where ``fig`` is the matplotlib
        figure and ``within_tolerance`` is a plain ``bool`` that is True
        only if every shifted value lies inside the band (bounds included).
    """
    profile = np.asarray(image)[0, :] - 0.5
    x = np.arange(len(profile)) / len(profile) * 25 + 5

    # Scope the font defaults to this figure instead of changing the global
    # rcParams for the rest of the process.
    with mpl.rc_context({'font.family': 'Arial', 'font.size': 30}):
        fig, ax = plt.subplots(figsize=(8, 5))
        _style_profile_axes(ax)
        ax.plot(x, profile, color="#00509b", linewidth=1.5)
        ax.fill_between(x, tolerance_lower, tolerance_upper, color='gray', alpha=0.2)

    # bool() turns numpy.bool_ into a built-in bool so the result can be
    # returned in a JSON response.
    within_tolerance = bool(
        np.all((profile >= tolerance_lower) & (profile <= tolerance_upper))
    )

    return fig, within_tolerance
|
|
|
def save_figure(fig, filename):
    """Save the figure as an image file and close it.

    Args:
        fig: Matplotlib figure to write.
        filename: Destination path passed to ``fig.savefig``.

    The figure is closed even when saving raises, so a failed write does
    not leave the figure registered with pyplot.
    """
    try:
        fig.savefig(filename)
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
@app.post("/predict")
def prediction(feedrate: float, depth_of_cut: float, toolwear: float):
    """Predict the image for the given parameters and plot its first row.

    Args:
        feedrate: First model input.
        depth_of_cut: Second model input.
        toolwear: Third model input.

    Returns:
        dict: ``figure_url`` (path of the endpoint serving the saved plot)
        and ``within_tolerance`` (result of the tolerance check).
    """
    # A plain list is enough: predict_image does the tensor conversion.
    predicted_image = predict_image([feedrate, depth_of_cut, toolwear])
    fig, tolerance = image_to_ts(predicted_image[0, 0])

    figure_filename = "output_figure.png"

    # Write the plot so the URL returned below resolves to this prediction;
    # save_figure also closes the figure, so figures do not pile up across
    # requests.
    # NOTE(review): every request writes the same file name, so concurrent
    # requests overwrite each other's figure.
    save_figure(fig, figure_filename)

    figure_url = f"/get_figure/{figure_filename}"

    return {"figure_url": figure_url, "within_tolerance": bool(tolerance)}
|
|
|
@app.get("/get_figure/{filename}")
def get_figure(filename: str):
    """Serve a previously saved PNG figure from the working directory.

    Args:
        filename: Bare file name of the figure (must end in ``.png``).

    Returns:
        FileResponse: The image with media type ``image/png``.

    Raises:
        HTTPException: 404 if the name is not a bare ``.png`` file name or
            the file does not exist.
    """
    requested = Path(filename)

    # Only bare *.png names are served. Without this check the endpoint
    # would hand out any file next to the app (model weights, source code)
    # labelled as an image.
    is_bare_name = requested.name == filename
    is_png = requested.suffix.lower() == ".png"
    if not (is_bare_name and is_png and requested.is_file()):
        raise HTTPException(status_code=404, detail="Figure not found")

    return FileResponse(filename, media_type="image/png", filename=filename)
|
|
|
|
|
@app.get("/")
def root():
    """Answer requests to the root path with a placeholder value.

    NOTE(review): the returned value is a one-element set, not a dict --
    confirm this is the intended response shape.
    """
    placeholder = {"Doku hier"}
    return placeholder
|
|
|
|
|
@app.get("/test")
def test():
    """Answer requests to ``/test`` with a fixed test value.

    NOTE(review): the returned value is a one-element set, not a dict --
    confirm this is the intended response shape.
    """
    payload = {"das ist ein Test"}
    return payload