Spaces:
Sleeping
Sleeping
File size: 2,771 Bytes
5d92054 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# -*- coding: utf-8 -*-
"""
Z-Location Estimator Model for Deployment
Created on Mon May 23 04:55:50 2022
@author: ODD_team
Edited by our team : Sat Oct 4 11:00 PM 2024
@based on LSTM model
"""
import warnings

import torch
import torch.nn as nn

from config import CONFIG
# Target torch device, read once from the project configuration at import time.
device = CONFIG["device"]
# Define the LSTM-based Z-location estimator model
class Zloc_Estimator(nn.Module):
    """LSTM encoder followed by a fully connected regression head.

    The LSTM is built with ``batch_first=True``, so inputs are laid out as
    ``(batch, seq_len, input_dim)``. Only the LSTM output at the last time step
    is passed to the MLP head, which reduces it to one value per sample.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, hidden_layers=(306, 154, 76)):
        """
        :param input_dim: Number of features per time step.
        :param hidden_dim: Hidden size of each LSTM layer (also the head's input width).
        :param layer_dim: Number of stacked LSTM layers.
        :param hidden_layers: Widths of the ReLU-activated linear layers placed
            between the LSTM output and the final 1-unit linear layer. The default
            matches the sizes this module previously hard-coded; changing it changes
            the ``fc.*`` parameter shapes/names in the state dict.
        """
        super().__init__()
        # LSTM layer (unidirectional, batch dimension first)
        self.rnn = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True, bidirectional=False)
        # Fully connected head: [Linear -> ReLU] per hidden width, then Linear(., 1).
        # The Sequential layout determines state-dict keys (fc.0, fc.2, ...), so keep it stable.
        layerlist = []
        n_in = hidden_dim
        for width in hidden_layers:
            layerlist.append(nn.Linear(n_in, width))
            layerlist.append(nn.ReLU())
            n_in = width
        # Size the output layer from the running width rather than hidden_layers[-1]
        # so an empty head (LSTM -> Linear directly) also works.
        layerlist.append(nn.Linear(n_in, 1))
        self.fc = nn.Sequential(*layerlist)

    def forward(self, x):
        """
        :param x: Input tensor of shape (batch, seq_len, input_dim).
        :return: Tensor of shape (batch, 1) computed from the last time step.
        """
        out, _ = self.rnn(x)  # hidden/cell states are not needed
        return self.fc(out[:, -1])  # last time step's output -> regression head
# Deployment-ready class for handling the model
class LSTM_Model:
    """Inference wrapper that builds a Zloc_Estimator and loads pre-trained weights."""

    def __init__(self, model_path=None):
        """
        Initializes the LSTM model for deployment with predefined parameters
        and loads the pre-trained model weights.

        :param model_path: Path to the pre-trained model weights file (.pth).
            Defaults to ``CONFIG['lstm_model_path']`` when omitted.
        """
        self.input_dim = 15
        self.hidden_dim = 612
        self.layer_dim = 3
        # Initialize the Z-location estimator model
        self.model = Zloc_Estimator(self.input_dim, self.hidden_dim, self.layer_dim)
        if model_path is None:
            model_path = CONFIG['lstm_model_path']
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False tolerates key mismatches, but doing so silently would leave the
        # affected layers randomly initialized -- report any mismatch explicitly.
        result = self.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                "LSTM checkpoint %r does not match Zloc_Estimator: "
                "missing keys=%s, unexpected keys=%s"
                % (model_path, list(result.missing_keys), list(result.unexpected_keys)),
                RuntimeWarning,
                stacklevel=2,
            )
        self.model.to(device)  # move parameters to the configured device
        self.model.eval()  # inference mode

    def predict(self, data):
        """
        Predicts the z-location based on input data.

        :param data: Tensor, NumPy array or nested list holding ``input_dim``
            features per sample, e.g. shape (batch_size, input_dim). It is
            reshaped to (-1, 1, input_dim), i.e. each sample becomes a
            length-1 sequence.
        :return: Predicted z-location tensor of shape (batch_size, 1) on the CPU.
        """
        # Match the input to wherever the model's weights currently live (device and
        # dtype); a float64 array would otherwise fail inside the float32 LSTM.
        first_param = next(self.model.parameters())
        with torch.no_grad():  # no gradients needed for deployment
            data = torch.as_tensor(data, dtype=first_param.dtype, device=first_param.device)
            data = data.reshape(-1, 1, self.input_dim)  # (batch_size, sequence_length=1, input_dim)
            zloc = self.model(data)
        return zloc.cpu()  # hand results back in CPU memory for further processing
|