"""@package docstring
@file msecnn.py
@brief Group of functions and classes that contribute directly to the implementation of the loss function and the MSE-CNN
@section libraries_MSECNN Libraries
- torch
- train_model_utils
- math
- numpy
@section classes_MSECNN Classes
- QP_half_mask
- MseCnnStg1
- MseCnnStgX
- LossFunctionMSE
- LossFunctionMSE_Ratios
@section functions_MSECNN Functions
- init_weights_seq(m)
- init_weights_single(m)
@section global_vars_MSECNN Global Variables
- None
@section todo_MSECNN TODO
- None
@section license License
MIT License
Copyright (c) 2022 Raul Kevin do Espirito Santo Viana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@section author_MSECNN Author(s)
- Created by Raul Kevin Viana
- Last modified: 2023-01-29 22:22:04.142369
"""
# ==============================================================
# Imports
# ==============================================================
import torch
from torch import nn
import math
import numpy as np
import train_model_utils
# ==============================================================
# Functions
# ==============================================================
def init_weights_seq(m):
    """!
    @brief Initializes every Linear and Conv2d layer of a given sequential model with the Xavier Initialization (with uniform distribution)
    @param [in] m: Model to initialize the weights of
    """
    # Walk all sub-modules; PReLU slopes are 1-D parameters and Xavier requires
    # at least 2-D tensors, so only Linear and Conv2d layers are initialized
    for layer in m.modules():
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(layer.weight)
def init_weights_single(m):
    """!
    @brief Initializes a given layer with the Xavier Initialization (with uniform distribution)
    @param [in] m: Layer to initialize the weights of
    """
    nn.init.xavier_uniform_(m.weight)
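# Example (a sketch with illustrative names, not part of the original pipeline):
# the helpers above are applied to the sub-networks built in the classes below.
#
#   probe = nn.Sequential(nn.Conv2d(1, 16, 3), nn.PReLU(), nn.Flatten(), nn.Linear(16, 6))
#   init_weights_seq(probe)        # Xavier-uniform init of the Conv2d and Linear weights
#   init_weights_single(probe[0])  # re-initializes just the first convolution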
# ==============================================================
# Classes
# ==============================================================
class QP_half_mask(nn.Module):
def __init__(self, QP=32):
super(QP_half_mask, self).__init__()
        # Initialize variable
self.QP = QP
def normalize_QP(self, QP):
"""!
@brief Normalize the QP value
@param [in] QP: QP value not normalized
@param [out] q_tilde: Normalized value of QP
"""
q_tilde = QP / 51 + 0.5
return q_tilde
def forward(self, feature_maps):
"""!
@brief This function implements the QP half mask operation
@param [in] feature_maps: Variable with the feature maps
        @param [out] new_feature_maps: Output feature maps, equal to the input except that half of the channels are scaled by the normalized QP
"""
# Normalize QP
q_tilde = self.normalize_QP(self.QP)
# Multiply half of the feature_maps by the normalized QP
new_feature_maps = feature_maps
new_feature_maps_size = new_feature_maps.size()
dim_change = 1
half_num = new_feature_maps_size[dim_change] // 2
half_1, half_2 = torch.split(new_feature_maps, (half_num, half_num), dim_change)
half_2 = half_2 * q_tilde
new_feature_maps = torch.cat((half_1, half_2), dim=dim_change)
return new_feature_maps
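# Worked example (sketch): with the default QP=32, q_tilde = 32/51 + 0.5 ≈ 1.127.
# For a (N, 16, H, W) input, channels 0-7 pass through unchanged while channels
# 8-15 are scaled by q_tilde:
#
#   mask = QP_half_mask(QP=32)
#   fm = torch.ones(1, 16, 8, 8)
#   out = mask(fm)   # out[:, :8] stays 1.0, out[:, 8:] becomes ~1.127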
# Model for stage 1
class MseCnnStg1(nn.Module):
def __init__(self, device="cpu", QP=32):
super(MseCnnStg1, self).__init__()
# Initializing variables
self.first_simple_conv = nn.Sequential(
            # nn.BatchNorm2d(1),  # Consider normalizing the input, even though it is not mentioned in the paper
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
# Conditional convolution stg 1
self.simple_conv_stg1 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation_stg1 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
self.simple_conv2_stg1 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation2_stg1 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
self.activation_PRelu_stg1 = nn.PReLU(num_parameters=1, init=0.2, device=device) # Parametric
self.activation_PRelu2_stg1 = nn.PReLU(num_parameters=1, init=0.2, device=device) # Parametric
# Conditional convolution stg 2
self.simple_conv_stg2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation_stg2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
self.simple_conv2_stg2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation2_stg2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
self.activation_PRelu_stg2 = nn.PReLU(num_parameters=1, init=0.2, device=device) # Parametric
self.activation_PRelu2_stg2 = nn.PReLU(num_parameters=1, init=0.2, device=device) # Parametric
        # Initialize networks with Xavier initialization
init_weights_seq(self.first_simple_conv)
init_weights_seq(self.simple_conv_stg1)
init_weights_seq(self.simple_conv_no_activation_stg1)
init_weights_seq(self.simple_conv_stg2)
init_weights_seq(self.simple_conv_no_activation_stg2)
init_weights_seq(self.simple_conv2_stg1)
init_weights_seq(self.simple_conv_no_activation2_stg1)
init_weights_seq(self.simple_conv2_stg2)
init_weights_seq(self.simple_conv_no_activation2_stg2)
# Sub-networks
self.sub_net = nn.Sequential(QP_half_mask(QP),
nn.Conv2d(in_channels=16, out_channels=8, kernel_size=4, stride=4,
padding='valid', device=device),
nn.PReLU(num_parameters=8, init=0.2, device=device),
nn.Conv2d(in_channels=8, out_channels=8, kernel_size=4, stride=4, padding='valid', device=device),
nn.PReLU(num_parameters=8, init=0.2, device=device),
QP_half_mask(QP),
nn.Flatten(),
nn.Linear(128, 8, device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device),
nn.Linear(8, 6, device=device),
nn.Softmax(dim=1))
# Initialize weights
init_weights_seq(self.sub_net)
def residual_unit_stg1(self, x, nr):
"""!
@brief Generic residual unit
@param [in] x: Input of the network
@param [in] nr: Number of residual units
"""
x_shortcut = x # Copy the initial value
if nr == 1:
# Residual unit 1
x = self.simple_conv_stg1(x)
x = self.simple_conv_no_activation_stg1(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu_stg1(x)
elif nr == 2:
# Residual unit 1
x = self.simple_conv_stg1(x)
x = self.simple_conv_no_activation_stg1(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu_stg1(x)
# Residual unit 2
x_shortcut = x
x = self.simple_conv2_stg1(x)
x = self.simple_conv_no_activation2_stg1(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu2_stg1(x)
else:
pass
return x
def residual_unit_stg2(self, x, nr):
"""!
@brief Generic residual unit
@param [in] x: Input of the network
@param [in] nr: Number of residual units
"""
x_shortcut = x # Copy the initial value
if nr == 1:
# Residual unit 1
x = self.simple_conv_stg2(x)
x = self.simple_conv_no_activation_stg2(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu_stg2(x)
elif nr == 2:
# Residual unit 1
x = self.simple_conv_stg2(x)
x = self.simple_conv_no_activation_stg2(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu_stg2(x)
# Residual unit 2
x_shortcut = x
x = self.simple_conv2_stg2(x)
x = self.simple_conv_no_activation2_stg2(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu2_stg2(x)
else:
pass
return x
    def nr_calc(self, ac, ap):
        """!
        @brief Calculate the number of residual units
        @param [in] ac: Minimum value of the current input axes
        @param [in] ap: Minimum value of the parent input axes
        @param [out] nr: Number of residual units
        """
        nr = 0
        if ac != 0:
            if ac == 128:
                nr = 1
            elif ap != 0:
                if 4 <= ac <= 64:
                    nr = int(math.log2(ap / ac))
                else:
                    raise Exception(f"ac with invalid number! ac = {ac}")
            else:
                raise Exception(f"ap with invalid number! ap = {ap}")
        else:
            raise Exception(f"ac with invalid number! ac = {ac}")
        return nr
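    # Worked examples (sketch): nr_calc(128, ap) -> 1 for any ap;
    # nr_calc(64, 128) -> log2(128/64) = 1; nr_calc(32, 128) -> 2;
    # nr_calc(16, 64) -> 2; nr = 0 only when ap == ac.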
    def split(self, cu, coords, sizes, split):
        """!
        @brief Splits feature maps in a specific way
        @param [in] cu: Input to the model
        @param [in] coords: Coordinates of the new CUs
        @param [in] sizes: Sizes of the new CUs
        @param [in] split: Way to split the CU
        @param [out] cu_out: New feature maps
        """
        # Initialize list
cus_list = []
for i in range(coords.shape[0]):
if split[i] == 0: # Non-split
cus_list.append(cu[i, :, :, :].unsqueeze(dim=0))
elif split[i] == 1: # Quad-tree
# Split CU and add to list
cus_list.append(cu[i, :, coords[i, 0]: coords[i, 0] + sizes[i, 0], coords[i, 1]: coords[i, 1]+sizes[i, 1]].unsqueeze(dim=0))
elif split[i] == 2: # HBT
# Split CU and add to list
cus_list.append(cu[i, :, coords[i, 0]: coords[i, 0] + sizes[i, 0], :].unsqueeze(dim=0))
elif split[i] == 3: # VBT
# Split CU and add to list
cus_list.append(cu[i, :, :, coords[i, 1]: coords[i, 1] + sizes[i, 1]].unsqueeze(dim=0))
elif split[i] == 4: # HTT
# Split CU and add to list
cus_list.append(cu[i, :, coords[i, 0]: coords[i, 0] + sizes[i, 0], :].unsqueeze(dim=0))
elif split[i] == 5: # VTT
# Split CU and add to list
cus_list.append(cu[i, :, :, coords[i, 1]: coords[i, 1] + sizes[i, 1]].unsqueeze(dim=0))
else:
raise Exception("This can't happen! Wrong split mode number: ", str(split[i]))
cu_out = torch.cat(cus_list)
return cu_out
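    # Example (sketch): horizontal binary split of a batch of two (16, 64, 64)
    # feature maps into their top and bottom (16, 32, 64) halves:
    #
    #   coords = torch.tensor([[0, 0], [32, 0]])
    #   sizes = torch.tensor([[32, 64], [32, 64]])
    #   out = model.split(cu, coords, sizes, [2, 2])   # out: (2, 16, 32, 64)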
def forward(self, cu, sizes=None, coords=None):
"""!
        @brief This function propagates the input to the output
        @param [in] cu: Input to the model
        @param [in] sizes: Sizes of the new CUs (optional, used together with coords)
        @param [in] coords: Coordinates of the new CUs (optional)
        @param [out] logits: Vector of raw predictions that a classification model generates
"""
        ## First layer: Overlapping convolution, 3x3 kernel with zero-padding
cu = self.first_simple_conv(cu)
##Conditional Convolution stg 1
# Number of residual units
ac = min(cu.shape[-1], cu.shape[-2]) # Getting current minimum axis value
nr = self.nr_calc(ac, 128) # Number of residual units, possible values are 0, 1 and 2
cu = self.residual_unit_stg1(cu, nr)
# Split CU and get specific
        if coords is not None:
cu = self.split(cu, coords, sizes, np.ones(coords.shape[0])) # Split in Quad tree
##Conditional Convolution stg 2
# Number of residual units
ac = min(cu.shape[-1], cu.shape[-2]) # Getting current minimum axis value
nr = self.nr_calc(ac, 128) # Number of residual units, possible values are 0, 1 and 2
cu = self.residual_unit_stg2(cu, nr)
# Sub-networks
logits = self.sub_net(cu)
return logits, cu, ac
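# Usage (sketch): one stage-1/stage-2 pass over a random 128x128 luma CTU,
# extracting the top-left 64x64 quad-tree child before the sub-network:
#
#   stg1 = MseCnnStg1(device="cpu", QP=32)
#   ctu = torch.rand(1, 1, 128, 128)
#   logits, fm, ac = stg1(ctu, sizes=torch.tensor([[64, 64]]), coords=torch.tensor([[0, 0]]))
#   # logits: (1, 6) split-mode probabilities, fm: (1, 16, 64, 64), ac = 64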
# Model for stage 3, 4, 5 and 6
class MseCnnStgX(MseCnnStg1):
def __init__(self, device="cpu", QP=32):
super(MseCnnStgX, self).__init__()
# Hide not needed variables
# Conditional convolution stg 1
self.simple_conv_stg1 = None # Simple convolution with activation and padding
self.simple_conv_no_activation_stg1 = None # Simple convolution with no activation and padding
self.simple_conv2_stg1 = None # Simple convolution with activation and padding
self.simple_conv_no_activation2_stg1 = None # Simple convolution with no activation and padding
self.activation_PRelu_stg1 = None # Parametric
self.activation_PRelu2_stg1 = None # Parametric
# Conditional convolution stg 2
self.simple_conv_stg2 = None # Simple convolution with activation and padding
self.simple_conv_no_activation_stg2 = None # Simple convolution with no activation and padding
self.simple_conv2_stg2 = None # Simple convolution with activation and padding
self.simple_conv_no_activation2_stg2 = None # Simple convolution with no activation and padding
self.activation_PRelu_stg2 = None # Parametric
self.activation_PRelu2_stg2 = None # Parametric
self.first_simple_conv = None
self.sub_net = None
        ## Sub-networks
        # Convolutional layers: kernel size and stride are derived from the CU size,
        # so each branch reduces its input to a fixed spatial size
        # Min 32
        self.conv_32_64 = nn.Sequential(QP_half_mask(QP),
                                        nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(32 // 8, 64 // 8),
                                                  stride=(32 // 8, 64 // 8), padding='valid', device=device))
        self.conv_32_32 = nn.Sequential(QP_half_mask(QP),
                                        nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(32 // 8, 32 // 8),
                                                  stride=(32 // 8, 32 // 8), padding='valid', device=device))
        # Min 16
        self.conv_16_16 = nn.Sequential(QP_half_mask(QP),
                                        nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(16 // 4, 16 // 4),
                                                  stride=(16 // 4, 16 // 4), padding='valid', device=device))
        self.conv_16_32 = nn.Sequential(QP_half_mask(QP),
                                        nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(16 // 4, 32 // 4),
                                                  stride=(16 // 4, 32 // 4), padding='valid', device=device))
        self.conv_16_64 = nn.Sequential(QP_half_mask(QP),
                                        nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(16 // 4, 64 // 4),
                                                  stride=(16 // 4, 64 // 4), padding='valid', device=device))
        # Min 8
        self.conv_8_8 = nn.Sequential(QP_half_mask(QP),
                                      nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(8 // 2, 8 // 2),
                                                stride=(8 // 2, 8 // 2), padding='valid', device=device))
        self.conv_8_16 = nn.Sequential(QP_half_mask(QP),
                                       nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(8 // 2, 16 // 2),
                                                 stride=(8 // 2, 16 // 2), padding='valid', device=device))
        self.conv_8_32 = nn.Sequential(QP_half_mask(QP),
                                       nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(8 // 2, 32 // 2),
                                                 stride=(8 // 2, 32 // 2), padding='valid', device=device))
        self.conv_8_64 = nn.Sequential(QP_half_mask(QP),
                                       nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(8 // 2, 64 // 2),
                                                 stride=(8 // 2, 64 // 2), padding='valid', device=device))
        # Min 4
        self.conv_4_32 = nn.Sequential(QP_half_mask(QP),
                                       nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(4 // 2, 32 // 2),
                                                 stride=(4 // 2, 32 // 2), padding='valid', device=device))
        self.conv_4_16 = nn.Sequential(QP_half_mask(QP),
                                       nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(4 // 2, 16 // 2),
                                                 stride=(4 // 2, 16 // 2), padding='valid', device=device))
        self.conv_4_8 = nn.Sequential(QP_half_mask(QP),
                                      nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(4 // 2, 8 // 2),
                                                stride=(4 // 2, 8 // 2), padding='valid', device=device))
        self.conv_4_4 = nn.Sequential(QP_half_mask(QP),
                                      nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(4 // 2, 4 // 2),
                                                stride=(4 // 2, 4 // 2), padding='valid', device=device))
# Sub-networks
self.sub_net_min_32 = nn.Sequential(nn.PReLU(num_parameters=16, init=0.2, device=device),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=4,
padding='valid', device=device),
nn.PReLU(num_parameters=32, init=0.2, device=device),
nn.Conv2d(in_channels=32, out_channels=128, kernel_size=2, stride=2,
padding='valid', device=device),
nn.PReLU(num_parameters=128, init=0.2, device=device),
QP_half_mask(QP),
nn.Flatten(),
nn.Linear(128, 64, device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device),
nn.Linear(64, 6, device=device),
nn.Softmax(dim=1))
self.sub_net_min_16 = nn.Sequential(nn.PReLU(num_parameters=16, init=0.2, device=device),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2,
padding='valid', device=device),
nn.PReLU(num_parameters=32, init=0.2, device=device),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2,
padding='valid', device=device),
nn.PReLU(num_parameters=64, init=0.2, device=device),
QP_half_mask(QP),
nn.Flatten(),
nn.Linear(64, 64, device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device),
nn.Linear(64, 6, device=device),
nn.Softmax(dim=1))
        self.sub_net_min_8 = nn.Sequential(nn.PReLU(num_parameters=16, init=0.2, device=device),
                                           nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2,
                                                     padding='valid', device=device),
                                           nn.PReLU(num_parameters=32, init=0.2, device=device),
                                           QP_half_mask(QP),
                                           nn.Flatten(),
                                           nn.Linear(32, 16, device=device),
                                           nn.PReLU(num_parameters=1, init=0.2, device=device),
                                           nn.Linear(16, 6, device=device),
                                           nn.Softmax(dim=1))
        self.sub_net_min_4 = nn.Sequential(nn.PReLU(num_parameters=16, init=0.2, device=device),
                                           nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2,
                                                     padding='valid', device=device),
                                           nn.PReLU(num_parameters=32, init=0.2, device=device),
                                           QP_half_mask(QP),
                                           nn.Flatten(),
                                           nn.Linear(32, 16, device=device),
                                           nn.PReLU(num_parameters=1, init=0.2, device=device),
                                           nn.Linear(16, 6, device=device),
                                           nn.Softmax(dim=1))
# Initialize weights
init_weights_seq(self.sub_net_min_32)
init_weights_seq(self.conv_32_64)
init_weights_seq(self.conv_32_32)
init_weights_seq(self.sub_net_min_16)
init_weights_seq(self.conv_16_64)
init_weights_seq(self.conv_16_32)
init_weights_seq(self.conv_16_16)
init_weights_seq(self.sub_net_min_8)
init_weights_seq(self.conv_8_64)
init_weights_seq(self.conv_8_32)
init_weights_seq(self.conv_8_16)
init_weights_seq(self.conv_8_8)
init_weights_seq(self.sub_net_min_4)
init_weights_seq(self.conv_4_32)
init_weights_seq(self.conv_4_16)
init_weights_seq(self.conv_4_8)
init_weights_seq(self.conv_4_4)
# Residual units
self.simple_conv = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
        self.activation_PRelu = nn.PReLU(num_parameters=1, init=0.2, device=device)  # Parametric ReLU activation function
self.simple_conv2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device),
nn.PReLU(num_parameters=1, init=0.2, device=device)
) # Simple convolution with activation and padding
self.simple_conv_no_activation2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same', device=device)
) # Simple convolution with no activation and padding
        self.activation_PRelu2 = nn.PReLU(num_parameters=1, init=0.2, device=device)  # Parametric ReLU activation function
        # Initialize networks with Xavier initialization (first_simple_conv is None
        # in this class, so only the shared residual-unit convolutions are initialized)
        init_weights_seq(self.simple_conv)
        init_weights_seq(self.simple_conv_no_activation)
        init_weights_seq(self.simple_conv2)
        init_weights_seq(self.simple_conv_no_activation2)
    def residual_unit_stg1(self, x, nr):
        """!
        @brief Disabled in this class; stage X uses the shared residual_unit instead
        """
        pass
    def residual_unit_stg2(self, x, nr):
        """!
        @brief Disabled in this class; stage X uses the shared residual_unit instead
        """
        pass
def residual_unit(self, x, nr):
"""!
@brief Generic residual unit
@param [in] x: Input of the network
@param [in] nr: Number of residual units
"""
x_shortcut = x # Copy the initial value
if nr == 1:
# Residual unit 1
x = self.simple_conv(x)
x = self.simple_conv_no_activation(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu(x)
elif nr == 2:
# Residual unit 1
x = self.simple_conv(x)
x = self.simple_conv_no_activation(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu(x)
# Residual unit 2
x_shortcut = x
x = self.simple_conv2(x)
x = self.simple_conv_no_activation2(x)
x = torch.add(x_shortcut, x) # Adding the initial value with the new processed value
x = self.activation_PRelu2(x)
else:
pass
return x
    def pass_through_subnet(self, x):
        """!
        @brief This function propagates its input through a specific sub-network, chosen according to the shape of the input
        @param [in] x: Input to the model
        @param [out] logits: Vector of raw predictions that a classification model generates
        """
        # Obtain input shape as (Height, Width), with Height the smaller axis
        if x.shape[-2] < x.shape[-1]:
            input_shape = (x.shape[-2], x.shape[-1])
        else:
            input_shape = (x.shape[-1], x.shape[-2])
        if min(input_shape) == 64:
            logits = self.sub_net(x)
        elif min(input_shape) == 32:
            if input_shape == (32, 64):
                logits = self.conv_32_64(x)
            else:  # input_shape == (32, 32)
                logits = self.conv_32_32(x)
            logits = self.sub_net_min_32(logits)
        elif min(input_shape) == 16:
            if input_shape == (16, 64):
                logits = self.conv_16_64(x)
            elif input_shape == (16, 32):
                logits = self.conv_16_32(x)
            else:  # input_shape == (16, 16)
                logits = self.conv_16_16(x)
            logits = self.sub_net_min_16(logits)
        elif min(input_shape) == 8:
            if input_shape == (8, 64):
                logits = self.conv_8_64(x)
            elif input_shape == (8, 32):
                logits = self.conv_8_32(x)
            elif input_shape == (8, 16):
                logits = self.conv_8_16(x)
            else:  # input_shape == (8, 8)
                logits = self.conv_8_8(x)
            logits = self.sub_net_min_8(logits)
        else:  # min(input_shape) == 4
            if input_shape == (4, 32):
                logits = self.conv_4_32(x)
            elif input_shape == (4, 16):
                logits = self.conv_4_16(x)
            elif input_shape == (4, 8):
                logits = self.conv_4_8(x)
            else:  # input_shape == (4, 4)
                logits = self.conv_4_4(x)
            logits = self.sub_net_min_4(logits)
        return logits
def forward(self, cu, ap, splits=None, sizes=None, coords=None):
"""!
@brief This functions propagates the input to the output
@param [in] cu: Input to the model
@param [out] logits: Vector of raw predictions that a classification model generates
"""
# Split CU and get specific
if coords != None:
cu = self.split(cu, coords, sizes, splits)
##Conditional Convolution
# Number of residual units
ac = min(cu.shape[-1], cu.shape[-2]) # Getting current minimum axis value
nr = self.nr_calc(ac, ap) # Number of residual units, possible values are 0, 1 and 2
cu = self.residual_unit(cu, nr)
# Transpose if needed
if not train_model_utils.right_size(cu):
cu_transposed = torch.clone(torch.transpose(cu, -1, -2))
else:
cu_transposed = torch.clone(cu)
# Sub-networks
logits = self.pass_through_subnet(cu_transposed)
return logits, cu, ac
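# Usage (sketch): a stage-X pass on the (1, 16, 64, 64) maps produced by the
# stage-1 example above, applying an HBT split (mode 2) into a 32x64 child;
# ap is the minimum axis of the parent CU (this assumes train_model_utils.right_size
# orients the maps the way pass_through_subnet expects):
#
#   stgx = MseCnnStgX(device="cpu", QP=32)
#   logits, fm, ac = stgx(fm, ap=64, splits=[2], sizes=torch.tensor([[32, 64]]),
#                         coords=torch.tensor([[0, 0]]))
#   # logits: (1, 6), fm: (1, 16, 32, 64), ac = 32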
class LossFunctionMSE(nn.Module):
def __init__(self, use_mod_cross_entropy=True, beta=1):
super(LossFunctionMSE, self).__init__()
self.beta = beta
self.use_mod_cross_entropy = use_mod_cross_entropy
# Constants
        self.MAX_RD = 1E10  # Largest value found in the dataset: 1.7E304
self.MAX_LOSS = 20 # float("inf")
self.last_loss = -1
self.last_pred = -1
self.last_RD = -1
def get_proportion_CUs(self, labels):
"""!
        @brief This function returns the proportion of CUs for each split mode
@param [in] labels: Ground truth tensor
@param [out] p_m: Tensor with the proportions
"""
p_m = torch.sum(labels, dim=0)
p_m = torch.reshape(p_m / torch.sum(p_m), shape=(1, -1))
return p_m
    def get_min_RDs(self, RDs):
        """!
        @brief Obtain the lowest non-zero value from the RDs tensor
        @param [in] RDs: Tensor with RDs
        @param [out] min_RD: Lowest non-zero value of RDs
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs == 0)  # Mask of the values equal to zero
        clone_RDs[mask] = float('inf')  # Zero entries are replaced by infinity so min ignores them
        min_RD = torch.reshape(torch.min(clone_RDs, dim=1)[0], shape=(-1, 1))  # Get minimum value of RD per row
        return min_RD
    def remove_values_lower(self, tensor, max_val, subst_val):
        """!
        @brief Replaces values in the tensor that are lower than a given threshold
        @param [in] tensor: Tensor with values
        @param [in] max_val: Threshold value
        @param [in] subst_val: Value used to replace entries below the threshold
        @param [out] tensor: New tensor
        """
        mask = (tensor < max_val)  # Mask of the values below the threshold
        tensor[mask] = subst_val  # Entries below the threshold are replaced
        return tensor
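    # Worked example (sketch):
    #   remove_values_lower(torch.tensor([-1.0, 0.5]), 0, 0) -> tensor([0.0, 0.5])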
    def remove_inf_values(self, tensor):
        """!
        @brief Replaces inf values in the tensor with zeros
        @param [in] tensor: Tensor with values
        @param [out] tensor: New tensor
        """
        clone_RDs = tensor.clone()
        mask = (clone_RDs == float('inf'))  # Mask of the values equal to infinity
        clone_RDs[mask] = 0  # Infinite entries are replaced by zero
        return clone_RDs
    def remove_zero(self, RDs):
        """!
        @brief Substitutes zero values with a large RD value
        @param [in] RDs: Tensor with RDs
        @param [out] RDs: Tensor with the large values added
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs == 0.0)  # Mask of the values equal to zero
        clone_RDs[mask] = self.MAX_RD  # Zero entries are replaced by MAX_RD
        return clone_RDs
    def remove_values_above(self, RDs, max_val):
        """!
        @brief Substitutes values above max_val with max_val
        @param [in] RDs: Tensor with RDs
        @param [in] max_val: Maximum allowed value
        @param [out] clone_RDs: Clipped tensor
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs >= max_val)  # Mask of the values at or above the maximum
        clone_RDs[mask] = max_val
        return clone_RDs
def forward(self, pred, labels, RD):
"""!
@brief This function implements the loss function
@param [in] pred: Predictions made by the model
@param [in] labels: Ground-truth tensor
@param [in] RD: Rate distortion tensor
        @param [out] loss: Total loss, i.e. the cross-entropy term plus the weighted rate-distortion term
"""
# Cross Entropy loss
if self.use_mod_cross_entropy:
# Loss
            loss_CE = torch.mul(torch.log(torch.add(pred, 1e-18)), labels)  # epsilon avoids log(0)
loss_CE = torch.sum(loss_CE, dim=1)
loss_CE = -torch.mean(loss_CE, dim=0)
else:
# Cross Entropy loss
labels_mod = torch.squeeze(train_model_utils.obtain_mode(labels))
loss_CE = nn.CrossEntropyLoss()(pred, labels_mod) # 6 classes
#print("Cross entropy loss:", loss_CE)
# Rate distortion
min_RDs = self.get_min_RDs(RD) # Get minimum values of RD
# Substitute inf RD values
RD_mod = self.remove_inf_values(RD)
# Replace zero values for a big number
RD_mod = self.remove_zero(RD_mod)
# Replace high values for a smaller number
RD_mod = self.remove_values_above(RD_mod, self.MAX_RD)
# Compute loss
# Paper loss function
loss_RD = torch.mul(pred, torch.sub(torch.div(RD_mod, min_RDs), 1))
# Remove negative
loss_RD = self.remove_values_lower(loss_RD, 0, 0)
# Mean loss
# TODO: Instead of limiting the loss, limit the gradients
loss_RD = torch.sum(loss_RD, dim=1)
loss_RD = torch.mean(loss_RD, dim=0)
        if loss_RD.item() > self.MAX_LOSS:
            # Rescale the loss so its value equals MAX_LOSS
            temp = torch.div(self.MAX_LOSS, loss_RD)
            loss_RD = torch.mul(loss_RD, temp)
#print("Rate Distortion loss:", loss_RD)
# Total loss
loss = torch.add(loss_CE, torch.mul(loss_RD, self.beta))
# Exceptions
        if torch.isnan(loss):
            raise Exception(f"Loss cannot be NaN! Last loss was: {self.last_loss}")
        elif torch.isinf(loss):
            raise Exception(f"Loss is infinite! Last loss was: {self.last_loss}")
self.last_loss = loss
self.last_pred = pred
return loss, loss_CE, loss_RD
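# Usage (sketch): evaluating the loss on a dummy batch of four CUs with one-hot
# ground-truth split modes and random RD costs:
#
#   criterion = LossFunctionMSE(beta=1)
#   pred = torch.softmax(torch.rand(4, 6), dim=1)   # stand-in for model output
#   labels = torch.eye(6)[torch.randint(0, 6, (4,))]
#   RD = torch.rand(4, 6) * 100
#   loss, loss_CE, loss_RD = criterion(pred, labels, RD)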
class LossFunctionMSE_Ratios(nn.Module):
# This version uses penalty weights for less represented classes
def __init__(self, pm, use_mod_cross_entropy=True, beta=1, alpha=0.5):
super(LossFunctionMSE_Ratios, self).__init__()
self.beta = beta
self.alpha = alpha
self.pm = pm
self.use_mod_cross_entropy = use_mod_cross_entropy
# Constants
        self.MAX_RD = 1E10  # Largest value found in the dataset: 1.7E304
self.MAX_LOSS = 20 # float("inf")
self.last_loss = -1
self.last_pred = -1
self.last_RD = -1
def get_proportion_CUs(self, labels):
"""!
        @brief This function returns the proportion of CUs for each split mode
@param [in] labels: Ground truth tensor
@param [out] p_m: Tensor with the proportions
"""
p_m = torch.sum(labels, dim=0)
p_m = torch.reshape(p_m / torch.sum(p_m), shape=(1, -1))
return p_m
    def get_min_RDs(self, RDs):
        """!
        @brief Obtain the lowest non-zero value from the RDs tensor
        @param [in] RDs: Tensor with RDs
        @param [out] min_RD: Lowest non-zero value of RDs
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs == 0)  # Mask of the values equal to zero
        clone_RDs[mask] = float('inf')  # Zero entries are replaced by infinity so min ignores them
        min_RD = torch.reshape(torch.min(clone_RDs, dim=1)[0], shape=(-1, 1))  # Get minimum value of RD per row
        return min_RD
    def remove_values_lower(self, tensor, max_val, subst_val):
        """!
        @brief Replaces values in the tensor that are lower than a given threshold
        @param [in] tensor: Tensor with values
        @param [in] max_val: Threshold value
        @param [in] subst_val: Value used to replace entries below the threshold
        @param [out] tensor: New tensor
        """
        mask = (tensor < max_val)  # Mask of the values below the threshold
        tensor[mask] = subst_val  # Entries below the threshold are replaced
        return tensor
    def remove_inf_values(self, tensor):
        """!
        @brief Replaces inf values in the tensor with zeros
        @param [in] tensor: Tensor with values
        @param [out] tensor: New tensor
        """
        clone_RDs = tensor.clone()
        mask = (clone_RDs == float('inf'))  # Mask of the values equal to infinity
        clone_RDs[mask] = 0  # Infinite entries are replaced by zero
        return clone_RDs
    def remove_zero(self, RDs):
        """!
        @brief Substitutes zero values with a large RD value
        @param [in] RDs: Tensor with RDs
        @param [out] RDs: Tensor with the large values added
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs == 0.0)  # Mask of the values equal to zero
        clone_RDs[mask] = self.MAX_RD  # Zero entries are replaced by MAX_RD
        return clone_RDs
    def remove_values_above(self, RDs, max_val):
        """!
        @brief Substitutes values above max_val with max_val
        @param [in] RDs: Tensor with RDs
        @param [in] max_val: Maximum allowed value
        @param [out] clone_RDs: Clipped tensor
        """
        clone_RDs = RDs.clone()
        mask = (clone_RDs >= max_val)  # Mask of the values at or above the maximum
        clone_RDs[mask] = max_val
        return clone_RDs
def forward(self, pred, labels, RD):
"""!
@brief This function implements the loss function
@param [in] pred: Predictions made by the model
@param [in] labels: Ground-truth tensor
@param [in] RD: Rate distortion tensor
        @param [out] loss: Total loss, i.e. the weighted cross-entropy term plus the weighted rate-distortion term
"""
# Cross Entropy loss
if self.use_mod_cross_entropy:
# Loss
            loss_CE = torch.mul(torch.log(torch.add(pred, 1e-18)), labels)  # epsilon avoids log(0)
loss_CE = torch.mul(torch.pow(1/self.pm, self.alpha), loss_CE)
loss_CE = torch.sum(loss_CE, dim=1)
loss_CE = -torch.mean(loss_CE, dim=0)
else:
# Cross Entropy loss
labels_mod = torch.squeeze(train_model_utils.obtain_mode(labels))
loss_CE = nn.CrossEntropyLoss()(pred, labels_mod) # 6 classes
#print("Cross entropy loss:", loss_CE)
# Rate distortion
min_RDs = self.get_min_RDs(RD) # Get minimum values of RD
# print("min_RDs", min_RDs)
# Substitute inf RD values
RD_mod = self.remove_inf_values(RD)
# Replace zero values for a big number
RD_mod = self.remove_zero(RD_mod)
# Replace high values for a smaller number
RD_mod = self.remove_values_above(RD_mod, self.MAX_RD)
# Compute loss
# Paper loss function
loss_RD = torch.mul(pred, torch.sub(torch.div(RD_mod, min_RDs), 1))
# Remove negative
loss_RD = self.remove_values_lower(loss_RD, 0, 0)
# Mean loss
loss_RD = torch.sum(loss_RD, dim=1)
loss_RD = torch.mean(loss_RD, dim=0)
        if loss_RD.item() > self.MAX_LOSS:
            # Rescale the loss so its value equals MAX_LOSS
            temp = torch.div(self.MAX_LOSS, loss_RD)
            loss_RD = torch.mul(loss_RD, temp)
#print("Rate Distortion loss:", loss_RD)
# Total loss
loss = torch.add(loss_CE, torch.mul(loss_RD, self.beta))
# Exceptions
        if torch.isnan(loss):
            raise Exception(f"Loss cannot be NaN! Last loss was: {self.last_loss}")
        elif torch.isinf(loss):
            raise Exception(f"Loss is infinite! Last loss was: {self.last_loss}")
self.last_loss = loss
self.last_pred = pred
return loss, loss_CE, loss_RD
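# ==============================================================
# Example usage
# ==============================================================
# Minimal smoke test (a sketch, not part of the original training pipeline):
# runs a random 128x128 luma CTU through stage 1, feeds a hypothetical HBT
# child into the stage-X model and evaluates the loss on dummy labels and RD
# costs. Assumes train_model_utils is importable and that right_size orients
# the feature maps the way pass_through_subnet expects.
if __name__ == "__main__":
    stg1 = MseCnnStg1(device="cpu", QP=32)
    stgx = MseCnnStgX(device="cpu", QP=32)
    criterion = LossFunctionMSE(beta=1)
    ctu = torch.rand(1, 1, 128, 128)
    # Stage 1: quad-tree split down to the top-left 64x64 child
    logits, fm, ac = stg1(ctu, sizes=torch.tensor([[64, 64]]), coords=torch.tensor([[0, 0]]))
    # Stage X: horizontal binary split (mode 2) into a 32x64 child
    logits, fm, ac = stgx(fm, ap=ac, splits=[2], sizes=torch.tensor([[32, 64]]),
                          coords=torch.tensor([[0, 0]]))
    # Dummy one-hot ground truth and RD costs for the six split modes
    labels = torch.eye(6)[torch.randint(0, 6, (1,))]
    RD = torch.rand(1, 6) * 100
    loss, loss_CE, loss_RD = criterion(logits, labels, RD)
    print("logits:", logits, "loss:", loss.item())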