|
from transformers import AutoConfig, AutoModel |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
import torch.nn as nn |
|
import torch |
|
|
|
class ArchitectureConfig(PretrainedConfig):
    """Configuration for ``Architecture``, an eight-layer fully-connected network.

    Holds the width of every layer: ``input_size`` features in, seven hidden
    widths (``hidden_size_1`` .. ``hidden_size_7``), and ``output_size``
    features out. Every width defaults to 9.
    """

    # Name under which this config is identified by the transformers Auto API.
    model_type = "architecture"

    def __init__(
        self,
        input_size=9,
        hidden_size_1=9,
        hidden_size_2=9,
        hidden_size_3=9,
        hidden_size_4=9,
        hidden_size_5=9,
        hidden_size_6=9,
        hidden_size_7=9,
        output_size=9,
        **kwargs,
    ):
        """Store the layer widths; remaining kwargs go to ``PretrainedConfig``.

        Args:
            input_size: Number of input features (in-features of ``fc1``).
            hidden_size_1 .. hidden_size_7: Widths of the seven hidden layers.
            output_size: Number of output features (out-features of ``fc8``).
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        self.hidden_size_3 = hidden_size_3
        self.hidden_size_4 = hidden_size_4
        self.hidden_size_5 = hidden_size_5
        self.hidden_size_6 = hidden_size_6
        self.hidden_size_7 = hidden_size_7
        self.output_size = output_size
|
|
|
class Architecture(PreTrainedModel):
    """Eight-layer fully-connected network with ReLU between layers.

    Layer widths come from an ``ArchitectureConfig``:
    ``input_size -> hidden_size_1 -> ... -> hidden_size_7 -> output_size``.
    The layers are kept as the attributes ``fc1`` .. ``fc8`` so that the
    state-dict keys (``fc1.weight``, ``fc1.bias``, ...) stay stable for
    previously saved weights.
    """

    config_class = ArchitectureConfig

    def __init__(self, config: ArchitectureConfig):
        """Build the linear layers from the widths stored in ``config``."""
        super().__init__(config)
        # Mirror the layer widths on the model for convenient access.
        self.input_size = config.input_size
        self.hidden_size_1 = config.hidden_size_1
        self.hidden_size_2 = config.hidden_size_2
        self.hidden_size_3 = config.hidden_size_3
        self.hidden_size_4 = config.hidden_size_4
        self.hidden_size_5 = config.hidden_size_5
        self.hidden_size_6 = config.hidden_size_6
        self.hidden_size_7 = config.hidden_size_7
        self.output_size = config.output_size

        # Attribute names determine state-dict keys; do not rename them.
        self.fc1 = nn.Linear(self.input_size, self.hidden_size_1)
        self.fc2 = nn.Linear(self.hidden_size_1, self.hidden_size_2)
        self.fc3 = nn.Linear(self.hidden_size_2, self.hidden_size_3)
        self.fc4 = nn.Linear(self.hidden_size_3, self.hidden_size_4)
        self.fc5 = nn.Linear(self.hidden_size_4, self.hidden_size_5)
        self.fc6 = nn.Linear(self.hidden_size_5, self.hidden_size_6)
        self.fc7 = nn.Linear(self.hidden_size_6, self.hidden_size_7)
        self.fc8 = nn.Linear(self.hidden_size_7, self.output_size)

        # Shared, parameter-free activation used after fc1 .. fc7.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ``fc1`` .. ``fc7`` each followed by ReLU, then ``fc8``.

        Args:
            x: Input tensor whose last dimension equals ``input_size``
                (``nn.Linear`` operates on the last dimension).

        Returns:
            The raw output of ``fc8`` (no final activation), whose last
            dimension equals ``output_size``.
        """
        for layer in (
            self.fc1,
            self.fc2,
            self.fc3,
            self.fc4,
            self.fc5,
            self.fc6,
            self.fc7,
        ):
            x = self.relu(layer(x))
        return self.fc8(x)

    def inference(self, x):
        """Run the model on ``x`` and return the ``forward`` output.

        Goes through ``nn.Module.__call__`` (instead of calling ``forward``
        directly) so registered forward pre-hooks and hooks are applied.
        Autograd is not disabled here; wrap the call in ``torch.no_grad()``
        if gradients are not needed.
        """
        return self(x)
|
|
|
|
|
def load_model(weights_path="./model_weights.pth", map_location="cpu"):
    """Register the custom classes with the Auto API and load saved weights.

    Args:
        weights_path: Path to a ``state_dict`` saved with ``torch.save``.
            The default is relative, so it resolves against the current
            working directory.
        map_location: Device the checkpoint tensors are materialized on before
            being copied into the model. Defaults to ``"cpu"`` so a checkpoint
            saved on GPU also loads on a CPU-only machine. The model itself is
            built on CPU, and ``load_state_dict`` copies into its existing
            parameters either way.

    Returns:
        An ``Architecture`` built from a default ``ArchitectureConfig`` with
        the loaded weights.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match the model.
    """
    # Both register calls raise ValueError when the type/config is already
    # registered, which made a second call to this function crash. Tolerate
    # that so the function can be called repeatedly.
    try:
        AutoConfig.register(ArchitectureConfig.model_type, ArchitectureConfig)
    except ValueError:
        pass
    try:
        AutoModel.register(ArchitectureConfig, Architecture)
    except ValueError:
        pass

    model = Architecture(ArchitectureConfig())

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code.
    state_dict = torch.load(
        weights_path, map_location=map_location, weights_only=True
    )
    model.load_state_dict(state_dict)
    return model
|
|
|
# Runs at import time: registers ArchitectureConfig/Architecture with the
# transformers Auto classes and loads ./model_weights.pth (relative to the
# current working directory). The returned model is discarded, so the lasting
# effect of importing this module is the Auto-class registration; the import
# fails if the weights file cannot be loaded.
# NOTE(review): consider an `if __name__ == "__main__":` guard, but only after
# confirming no importer relies on the import-time registration.
load_model()
|
|