|
import torch |
|
import torch.nn as nn |
|
|
|
class ExpressionCNN(nn.Module):
    """Four-stage convolutional classifier for single-channel images.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d`` and the
    channel width grows 1 -> 32 -> 64 -> 128 -> 256.  The first three stages
    end with ``MaxPool2d(2)``; the last ends with ``AdaptiveAvgPool2d((1, 1))``
    so the flattened feature vector always has 256 entries.  A single linear
    layer maps those features to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=7):
        """Create the feature extractor (``conv``) and the head (``fc``).

        Args:
            num_classes: Number of outputs produced by the final linear layer.
        """
        super().__init__()

        widths = (1, 32, 64, 128, 256)
        last_stage = len(widths) - 2

        # Modules are appended in the exact order conv, relu, bn, pool for
        # every stage so the indices inside ``self.conv`` (and therefore the
        # ``state_dict`` keys such as ``conv.4.weight``) are unchanged.
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            stack.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            stack.append(nn.ReLU())
            stack.append(nn.BatchNorm2d(c_out))
            if stage == last_stage:
                # Collapse whatever spatial extent remains to 1x1.
                stack.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                stack.append(nn.MaxPool2d(2))
        self.conv = nn.Sequential(*stack)

        # Flatten stays inside ``fc`` so the linear layer keeps index 1.
        flatten = nn.Flatten()
        classifier = nn.Linear(widths[-1], num_classes)
        self.fc = nn.Sequential(flatten, classifier)

    def forward(self, x):
        """Run the convolutional stack and then the classification head.

        Args:
            x: Input batch fed to the first ``Conv2d``, which expects one
                input channel.

        Returns:
            The output of the final linear layer, one row per batch item with
            ``num_classes`` columns.
        """
        return self.fc(self.conv(x))
|
|
|
def load_model(model_path, device, num_classes=7):
    """Build an ``ExpressionCNN``, load saved weights into it, and return it.

    Args:
        model_path: Location of the saved ``state_dict``; passed straight to
            ``torch.load``.
        device: Used both as ``map_location`` when reading the checkpoint and
            as the target of ``model.to``.
        num_classes: Size of the classifier head.  It must match the
            checkpoint being loaded.  Defaults to 7, which is what the loader
            always used before this parameter existed.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or tensor shapes do not match the model.
    """
    model = ExpressionCNN(num_classes=num_classes)

    # NOTE(security): torch.load deserializes with pickle, so only load
    # checkpoints from a trusted source.  On PyTorch versions that support it,
    # consider passing ``weights_only=True`` here.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    # eval() makes BatchNorm use its running statistics instead of batch
    # statistics.
    model.eval()
    return model