Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
Execution: | |
python3 -W ignore model.py | |
""" | |
import numpy as np | |
import torch | |
# Meta data: class names per label group. The position of a name inside its
# list is the class index used when a prediction is translated to text.
feedrate = [
    "477",
    "955",
    "1242",
]  # [mm/min]
depth_of_cut = [
    "0.1",
    "0.25",
    "0.5",
]  # [mm]
condition = [
    "new",
    "slightly used",
    "heavily used",
]

# Trained model; stays None until load_model() has assigned it.
model = None
def load_model(pathToModel: str):
    """
    Load a TorchScript model and store it in the module-level ``model``.

    The model is mapped onto the CPU while it is being loaded and is put
    into evaluation mode before it is published.

    Args:
        pathToModel (str): (Full-)Path to the TorchScript model file.

    Raises:
        Whatever ``torch.jit.load`` raises for a missing or invalid file;
        in that case the previously stored ``model`` is left untouched.
    """
    global model
    # map_location="cpu" remaps the stored tensors during loading, so a model
    # that was saved on a GPU can still be loaded on a CPU-only host.
    loaded = torch.jit.load(pathToModel, map_location="cpu")
    loaded.eval()
    # Publish only once the model is fully prepared.
    model = loaded
def classify(data: torch.Tensor):
    """
    Classify data with the loaded model.

    If no model has been loaded yet, the default model
    "processing/model_d2_v0.7.pt" is loaded first. A model that is already
    loaded is reused instead of being read from disk again on every call.

    Args:
        data (torch.Tensor): Data to classify. Structure [#tasks, 1, 500].

    Returns:
        A tuple ``(summed_pred, res_label, res_con)``:

        summed_pred: numpy array with the sigmoid scores averaged over all
            tasks. It has dim [9], separated by indices [0,3), [3,6), [6,9),
            representing the classes "feedrate", "depth of cut" and
            "condition".
                "feedrate": 477, 955, 1242 mm/min
                "depth of cut": 0.1, 0.25, 0.5 mm
                "condition": new, slightly used, heavily used
        res_label: torch.Tensor of the same length holding a 1 at the
            winning index of each of the three groups and 0 elsewhere.
        res_con: Human-readable string naming the three winning classes.

        Based on data separation, the "condition" was equally compressed
        among the data.

    Raises:
        RuntimeError: If no model is loaded and the default model cannot be
            loaded.
        ValueError: If ``data`` contains no elements.
    """
    if model is None:
        try:
            load_model("processing/model_d2_v0.7.pt")
        except Exception as err:
            # Surface the real cause instead of failing later with
            # "'NoneType' object is not callable".
            raise RuntimeError(
                "No model loaded and the default model "
                "'processing/model_d2_v0.7.pt' could not be loaded."
            ) from err

    with torch.no_grad():
        # Data transformation and prediction. as_tensor converts lists and
        # arrays as well, and does not emit the copy-construct warning that
        # torch.tensor() raises when it is handed a tensor.
        data = torch.as_tensor(data, dtype=torch.float, device="cpu")
        if data.numel() == 0:
            raise ValueError("data must contain at least one task")
        prediction = torch.sigmoid(model(data))

        # Summed classification: average the scores over the task axis.
        summed_pred = (prediction.sum(dim=0) / len(prediction)).numpy()[0]

        # Prediction class for label set: winner inside each group of three.
        feed_idx = int(np.argmax(summed_pred[0:3]))
        depth_idx = int(np.argmax(summed_pred[3:6]))
        cond_idx = int(np.argmax(summed_pred[6:len(summed_pred)]))

        res_con = f"Feedrate [mm/min]: {feedrate[feed_idx]}, "\
                  f"Depth of cut [mm]: {depth_of_cut[depth_idx]}, "\
                  f"Condition: {condition[cond_idx]}"

        # Label representation: shift the group-local winners back to their
        # positions in the full score vector and mark them with 1.
        res_label = torch.zeros(prediction.shape[-1])
        res_label[[feed_idx, depth_idx + 3, cond_idx + 6]] = 1

    return summed_pred, res_label, res_con