ifw-arz commited on
Commit
e382984
1 Parent(s): 64fdaee

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +133 -0
  2. trained_model.pth +3 -0
main.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from typing import List
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib as mpl
9
+ from fastapi.responses import FileResponse
10
+ import shutil
11
+
12
+
13
+ app = FastAPI()
14
+
15
+
16
+ class CustomModel(nn.Module):
17
+ def __init__(self):
18
+ super(CustomModel, self).__init__()
19
+
20
+ self.fc_layers = nn.Sequential(
21
+ nn.Linear(3, 64),
22
+ nn.ReLU(),
23
+ nn.Linear(64, 744 * 554)
24
+ )
25
+
26
+ def forward(self, x_features):
27
+ x_features = self.fc_layers(x_features)
28
+ output = x_features.view(-1, 1, 744, 554)
29
+ return output
30
+
31
+
32
+ def predict_image(parameters):
33
+
34
+ """Predicts an image based on parameters (feedrate, depth of cut, toolwear)"""
35
+ # Load the trained model
36
+ model = CustomModel()
37
+ model.load_state_dict(torch.load('trained_model.pth'))
38
+ model.eval()
39
+
40
+ with torch.no_grad():
41
+ input_features = torch.tensor(parameters, dtype=torch.float32)
42
+ predicted_image = model(input_features.unsqueeze(0))
43
+ return predicted_image.numpy()
44
+
45
+ def image_to_ts(image):
46
+
47
+ """Transforms an image to a time series and returns the plot"""
48
+
49
+ # Extract the pixel values from the image
50
+ z_values = image[0, :]
51
+ reversed_values = z_values
52
+ reversed_values = (reversed_values - 0.5)
53
+
54
+ x = np.arange(len(reversed_values)) / len(reversed_values) * 25 + 5
55
+
56
+ # Plot the time series
57
+ fig, ax = plt.subplots(figsize=(8, 5))
58
+
59
+ ax.set_ylim(-0.2, 0.2)
60
+ ax.set_xlim(5, 30)
61
+
62
+ mpl.rcParams['font.family'] = 'Arial'
63
+ mpl.rcParams['font.size'] = 30
64
+
65
+ ax.set_xlabel("Workpiece length", fontname="Arial", fontsize=16, labelpad=7)
66
+ ax.set_ylabel("Normalized surface height", fontname="Arial", fontsize=16, labelpad=7)
67
+ plt.yticks(fontname="Arial", fontsize=14, color="black")
68
+ plt.xticks(range(5, 31, 5), fontname="Arial", fontsize=14, color = "black")
69
+ #plt.title("Surface geometry",fontname="Arial", fontsize=18, color="black", weight="bold", pad=10)
70
+
71
+ xticks = ax.get_xticks()
72
+ xticklabels = [str(int(x)) if x != xticks[-2] else "mm" for x in xticks]
73
+ ax.set_xticklabels(xticklabels)
74
+
75
+
76
+ yticks = ax.get_yticks()
77
+ #yticklabels = [str(int(y)) if y != yticks[-2] else "" for y in yticks]
78
+ yticklabels = yticks = ["" for y in yticks]
79
+ ax.set_yticklabels(yticklabels)
80
+
81
+
82
+ gridwidth = 1.5
83
+ plt.grid(axis="y", linewidth=0.75, color="black")
84
+ plt.grid(axis="x", linewidth=0.75, color="black")
85
+
86
+ rand = ["top", "right", "bottom", "left"]
87
+ for i in rand:
88
+ plt.gca().spines[i].set_linewidth(gridwidth)
89
+ ax.spines[i].set_color('black')
90
+
91
+ plt.plot(x, reversed_values, color="#00509b", linewidth=1.5)
92
+
93
+ # Define the tolerance range
94
+ tolerance_lower = -0.085
95
+ tolerance_upper = 0.085
96
+
97
+ ax.fill_between(x, tolerance_lower, tolerance_upper, color='gray', alpha=0.2)
98
+
99
+ # Check if the plot is within tolerance
100
+ within_tolerance = all(tolerance_lower <= val <= tolerance_upper for val in reversed_values)
101
+
102
+ tolerance = None
103
+ if within_tolerance:
104
+ tolerance = True
105
+ else:
106
+ tolerance = False
107
+
108
+ return fig, tolerance
109
+
110
+ def save_figure(fig, filename):
111
+ """Save the figure as an image file."""
112
+ fig.savefig(filename)
113
+ plt.close(fig)
114
+
115
+
116
+
117
+ @app.post("/predict")
118
+ def prediction(feedrate: float, depth_of_cut: float, toolwear: float):
119
+ new_features = torch.tensor([[feedrate, depth_of_cut, toolwear]])
120
+ predicted_image = predict_image(new_features)
121
+ fig, tolerance = image_to_ts(predicted_image[0, 0])
122
+
123
+ figure_filename = "output_figure.png"
124
+ save_figure(fig, figure_filename)
125
+
126
+ figure_url = f"/get_figure/{figure_filename}"
127
+
128
+ #return {"figure": fig, "within_tolerance": tolerance}
129
+ return {"figure_url": figure_url, "within_tolerance": tolerance}
130
+
131
+ @app.get("/get_figure/{filename}")
132
+ def get_figure(filename: str):
133
+ return FileResponse(filename, media_type="image/png", filename=filename)
trained_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e50218a972bb12b3b3dc1c1f17a231aad7e27d47f0c28d475a492cea9ab84a5
3
+ size 107168579