Upload 2 files
Browse files- main.py +133 -0
- trained_model.pth +3 -0
main.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from typing import List
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn as nn
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import matplotlib as mpl
|
9 |
+
from fastapi.responses import FileResponse
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
|
16 |
+
class CustomModel(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super(CustomModel, self).__init__()
|
19 |
+
|
20 |
+
self.fc_layers = nn.Sequential(
|
21 |
+
nn.Linear(3, 64),
|
22 |
+
nn.ReLU(),
|
23 |
+
nn.Linear(64, 744 * 554)
|
24 |
+
)
|
25 |
+
|
26 |
+
def forward(self, x_features):
|
27 |
+
x_features = self.fc_layers(x_features)
|
28 |
+
output = x_features.view(-1, 1, 744, 554)
|
29 |
+
return output
|
30 |
+
|
31 |
+
|
32 |
+
def predict_image(parameters):
|
33 |
+
|
34 |
+
"""Predicts an image based on parameters (feedrate, depth of cut, toolwear)"""
|
35 |
+
# Load the trained model
|
36 |
+
model = CustomModel()
|
37 |
+
model.load_state_dict(torch.load('trained_model.pth'))
|
38 |
+
model.eval()
|
39 |
+
|
40 |
+
with torch.no_grad():
|
41 |
+
input_features = torch.tensor(parameters, dtype=torch.float32)
|
42 |
+
predicted_image = model(input_features.unsqueeze(0))
|
43 |
+
return predicted_image.numpy()
|
44 |
+
|
45 |
+
def image_to_ts(image):
|
46 |
+
|
47 |
+
"""Transforms an image to a time series and returns the plot"""
|
48 |
+
|
49 |
+
# Extract the pixel values from the image
|
50 |
+
z_values = image[0, :]
|
51 |
+
reversed_values = z_values
|
52 |
+
reversed_values = (reversed_values - 0.5)
|
53 |
+
|
54 |
+
x = np.arange(len(reversed_values)) / len(reversed_values) * 25 + 5
|
55 |
+
|
56 |
+
# Plot the time series
|
57 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
58 |
+
|
59 |
+
ax.set_ylim(-0.2, 0.2)
|
60 |
+
ax.set_xlim(5, 30)
|
61 |
+
|
62 |
+
mpl.rcParams['font.family'] = 'Arial'
|
63 |
+
mpl.rcParams['font.size'] = 30
|
64 |
+
|
65 |
+
ax.set_xlabel("Workpiece length", fontname="Arial", fontsize=16, labelpad=7)
|
66 |
+
ax.set_ylabel("Normalized surface height", fontname="Arial", fontsize=16, labelpad=7)
|
67 |
+
plt.yticks(fontname="Arial", fontsize=14, color="black")
|
68 |
+
plt.xticks(range(5, 31, 5), fontname="Arial", fontsize=14, color = "black")
|
69 |
+
#plt.title("Surface geometry",fontname="Arial", fontsize=18, color="black", weight="bold", pad=10)
|
70 |
+
|
71 |
+
xticks = ax.get_xticks()
|
72 |
+
xticklabels = [str(int(x)) if x != xticks[-2] else "mm" for x in xticks]
|
73 |
+
ax.set_xticklabels(xticklabels)
|
74 |
+
|
75 |
+
|
76 |
+
yticks = ax.get_yticks()
|
77 |
+
#yticklabels = [str(int(y)) if y != yticks[-2] else "" for y in yticks]
|
78 |
+
yticklabels = yticks = ["" for y in yticks]
|
79 |
+
ax.set_yticklabels(yticklabels)
|
80 |
+
|
81 |
+
|
82 |
+
gridwidth = 1.5
|
83 |
+
plt.grid(axis="y", linewidth=0.75, color="black")
|
84 |
+
plt.grid(axis="x", linewidth=0.75, color="black")
|
85 |
+
|
86 |
+
rand = ["top", "right", "bottom", "left"]
|
87 |
+
for i in rand:
|
88 |
+
plt.gca().spines[i].set_linewidth(gridwidth)
|
89 |
+
ax.spines[i].set_color('black')
|
90 |
+
|
91 |
+
plt.plot(x, reversed_values, color="#00509b", linewidth=1.5)
|
92 |
+
|
93 |
+
# Define the tolerance range
|
94 |
+
tolerance_lower = -0.085
|
95 |
+
tolerance_upper = 0.085
|
96 |
+
|
97 |
+
ax.fill_between(x, tolerance_lower, tolerance_upper, color='gray', alpha=0.2)
|
98 |
+
|
99 |
+
# Check if the plot is within tolerance
|
100 |
+
within_tolerance = all(tolerance_lower <= val <= tolerance_upper for val in reversed_values)
|
101 |
+
|
102 |
+
tolerance = None
|
103 |
+
if within_tolerance:
|
104 |
+
tolerance = True
|
105 |
+
else:
|
106 |
+
tolerance = False
|
107 |
+
|
108 |
+
return fig, tolerance
|
109 |
+
|
110 |
+
def save_figure(fig, filename):
|
111 |
+
"""Save the figure as an image file."""
|
112 |
+
fig.savefig(filename)
|
113 |
+
plt.close(fig)
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
@app.post("/predict")
|
118 |
+
def prediction(feedrate: float, depth_of_cut: float, toolwear: float):
|
119 |
+
new_features = torch.tensor([[feedrate, depth_of_cut, toolwear]])
|
120 |
+
predicted_image = predict_image(new_features)
|
121 |
+
fig, tolerance = image_to_ts(predicted_image[0, 0])
|
122 |
+
|
123 |
+
figure_filename = "output_figure.png"
|
124 |
+
save_figure(fig, figure_filename)
|
125 |
+
|
126 |
+
figure_url = f"/get_figure/{figure_filename}"
|
127 |
+
|
128 |
+
#return {"figure": fig, "within_tolerance": tolerance}
|
129 |
+
return {"figure_url": figure_url, "within_tolerance": tolerance}
|
130 |
+
|
131 |
+
@app.get("/get_figure/{filename}")
|
132 |
+
def get_figure(filename: str):
|
133 |
+
return FileResponse(filename, media_type="image/png", filename=filename)
|
trained_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e50218a972bb12b3b3dc1c1f17a231aad7e27d47f0c28d475a492cea9ab84a5
|
3 |
+
size 107168579
|