| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class Conv2d_TRGP(nn.Conv2d):
    """2-D convolution whose kernel can be rescaled inside given subspaces.

    The layer behaves like ``nn.Conv2d`` until ``enable_scale`` is called.
    Afterwards, for every subspace matrix ``S`` (rows index the flattened
    kernel dimension, columns index the ``k`` basis vectors) there is a
    learnable square matrix ``Q`` and the effective kernel is updated in turn:

        W <- W + W_flat @ S @ (Q[:k, :k] - I_k) @ S.T

    With ``Q`` equal to the identity the kernel is left unchanged.

    Note that, unlike ``nn.Conv2d``, ``padding`` precedes ``stride`` in the
    positional argument order.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        # dilation and groups must be forwarded explicitly; otherwise the
        # layer silently falls back to nn.Conv2d's defaults (1 and 1).
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias)

        # Size of one flattened filter: (in_channels / groups) * kH * kW.
        size = self.weight.shape[1] * self.weight.shape[2] * self.weight.shape[3]
        # Plain attribute (not a buffer): it does not follow module.to(), so
        # it is moved to the weight's device wherever it is used.
        self.identity_matrix = torch.eye(size, device=self.weight.device)

        self.space = []
        self.scale_param = nn.ParameterList()

    def enable_scale(self, space):
        """Attach one identity-initialised scale matrix per subspace.

        Args:
            space: sequence of 2-D tensors; each is right-multiplied onto the
                flattened kernel in ``forward``, so it needs
                ``identity_matrix.shape[0]`` rows.
        """
        self.space = space
        device = self.weight.device
        # clone() gives every parameter its own storage. nn.Parameter(t)
        # shares memory with t, so without the copy all scales would alias
        # identity_matrix and an in-place optimiser step would corrupt it
        # (keeping `scale - identity` at zero).
        self.scale_param = nn.ParameterList([
            nn.Parameter(self.identity_matrix.detach().to(device).clone())
            for _ in self.space
        ])

    def disable_scale(self):
        """Drop all subspaces and their scale parameters."""
        self.space = []
        self.scale_param = nn.ParameterList()

    def forward(self, input, compute_input_matrix=False):
        """Convolve ``input`` with the (optionally rescaled) kernel.

        Args:
            input: tensor accepted by ``F.conv2d``.
            compute_input_matrix: when true, the raw input is kept on
                ``self.input_matrix`` for later inspection by the caller.

        Returns:
            The result of ``F.conv2d`` with the effective kernel.
        """
        if compute_input_matrix:
            self.input_matrix = input

        out_channels = self.weight.shape[0]
        masked_weight = self.weight

        # Updates are cumulative: each subspace is applied to the kernel
        # already modified by the previous ones.
        for scale, space in zip(self.scale_param, self.space):
            k = space.shape[1]
            cropped_scale = scale[:k, :k]
            cropped_identity = self.identity_matrix[:k, :k].to(self.weight.device)

            flat = masked_weight.view(out_channels, -1)
            delta = flat @ space @ (cropped_scale - cropped_identity) @ space.T
            masked_weight = masked_weight + delta.view(masked_weight.shape)

        return F.conv2d(input, masked_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
|
class Linear_TRGP(nn.Linear):
    """Linear layer whose weight can be rescaled inside given subspaces.

    The layer behaves like ``nn.Linear`` until ``enable_scale`` is called.
    Afterwards, for every subspace matrix ``S`` (``in_features`` rows, ``k``
    columns) there is a learnable square matrix ``Q`` and the effective weight
    is updated in turn:

        W <- W + W @ S @ (Q[:k, :k] - I_k) @ S.T

    With ``Q`` equal to the identity the weight is left unchanged.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

        # Plain attribute (not a buffer): it does not follow module.to(), so
        # it is moved to the weight's device wherever it is used.
        self.identity_matrix = torch.eye(self.weight.shape[1], device=self.weight.device)

        self.space = []
        self.scale_param = nn.ParameterList()

    def enable_scale(self, space):
        """Attach one identity-initialised scale matrix per subspace.

        Args:
            space: sequence of 2-D tensors; each is right-multiplied onto the
                weight in ``forward``, so it needs ``in_features`` rows.
        """
        self.space = space
        device = self.weight.device
        # clone() gives every parameter its own storage. nn.Parameter(t)
        # shares memory with t, so without the copy all scales would alias
        # identity_matrix and an in-place optimiser step would corrupt it
        # (keeping `scale - identity` at zero).
        self.scale_param = nn.ParameterList([
            nn.Parameter(self.identity_matrix.detach().to(device).clone())
            for _ in self.space
        ])

    def disable_scale(self):
        """Drop all subspaces and their scale parameters."""
        self.space = []
        self.scale_param = nn.ParameterList()

    def forward(self, input, compute_input_matrix=False):
        """Apply the (optionally rescaled) linear map to ``input``.

        Args:
            input: tensor accepted by ``F.linear``.
            compute_input_matrix: when true, the raw input is kept on
                ``self.input_matrix`` for later inspection by the caller.

        Returns:
            The result of ``F.linear`` with the effective weight.
        """
        if compute_input_matrix:
            self.input_matrix = input

        masked_weight = self.weight

        # Updates are cumulative: each subspace is applied to the weight
        # already modified by the previous ones.
        for scale, space in zip(self.scale_param, self.space):
            k = space.shape[1]
            cropped_scale = scale[:k, :k]
            cropped_identity = self.identity_matrix[:k, :k].to(self.weight.device)

            masked_weight = masked_weight + masked_weight @ space @ (cropped_scale - cropped_identity) @ space.T

        return F.linear(input, masked_weight, self.bias)
|
|
class AlexNet_TRGP(nn.Module):
    """AlexNet-style feature extractor built from ``Conv2d_TRGP`` / ``Linear_TRGP``.

    Three conv stages (conv -> batch norm -> ReLU -> dropout -> 2x2 max-pool)
    are followed by two fully connected stages (fc -> batch norm -> ReLU ->
    dropout). The output is a ``feat_dim``-wide (2048) feature vector; there
    is no classification head in this module.

    ``fc1`` expects 1024 flattened features; with the unpadded kernels used
    here that works out to 3x32x32 inputs (256 * 2 * 2).

    Args:
        dropout_rate_1: dropout probability after the first two conv stages.
        dropout_rate_2: dropout probability after the third conv stage and
            both fully connected stages.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dropout_rate_1=0.2, dropout_rate_2=0.5, **kwargs):
        super().__init__()

        # Batch-norm layers are created with track_running_stats=False, so no
        # running mean/variance buffers are kept.
        self.conv1 = Conv2d_TRGP(in_channels=3, out_channels=64, kernel_size=4, bias=False)
        self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)

        self.conv2 = Conv2d_TRGP(in_channels=64, out_channels=128, kernel_size=3, bias=False)
        self.bn2 = nn.BatchNorm2d(128, track_running_stats=False)

        self.conv3 = Conv2d_TRGP(in_channels=128, out_channels=256, kernel_size=2, bias=False)
        self.bn3 = nn.BatchNorm2d(256, track_running_stats=False)

        self.fc1 = Linear_TRGP(in_features=1024, out_features=2048, bias=False)
        self.bn4 = nn.BatchNorm1d(2048, track_running_stats=False)

        self.fc2 = Linear_TRGP(in_features=2048, out_features=2048, bias=False)
        self.bn5 = nn.BatchNorm1d(2048, track_running_stats=False)

        # Width of the feature vector returned by forward().
        self.feat_dim = 2048

        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate_1)
        self.dropout2 = nn.Dropout(dropout_rate_2)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, compute_input_matrix=False):
        """Return the 2048-wide features for a batch of images.

        Args:
            x: image batch; the first conv layer expects 3 channels.
            compute_input_matrix: forwarded to every TRGP layer, which then
                stores its own input on ``layer.input_matrix``. Defaults to
                ``False``, matching the layers' own default.

        Returns:
            Tensor of shape ``(batch, 2048)``.
        """
        x = self.conv1(x, compute_input_matrix)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.maxpool(x)

        x = self.conv2(x, compute_input_matrix)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.maxpool(x)

        x = self.conv3(x, compute_input_matrix)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.dropout2(x)
        x = self.maxpool(x)

        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)

        x = self.fc1(x, compute_input_matrix)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.dropout2(x)

        x = self.fc2(x, compute_input_matrix)
        x = self.bn5(x)
        x = self.relu(x)
        x = self.dropout2(x)

        return x
|
|
| |
|
|
class Conv2d_API(nn.Conv2d):
    """2-D convolution whose input can be widened with projected channels.

    ``extra_ws`` holds projection matrices applied along the channel axis of
    the input; ``forward`` concatenates the first ``t`` projections to the
    input before convolving. ``expand`` records how many input channels each
    ``duplicate`` call added.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         bias=bias, dilation=dilation, groups=groups,
                         padding_mode=padding_mode)

        self.extra_ws = nn.ParameterList([])
        self.expand = []

    def forward(self, input, t, compute_input_matrix=False):
        """Convolve ``input`` augmented with its first ``t`` projections.

        Args:
            input: 4-D tensor, channels on dim 1 (it is permuted to
                channels-last for the projection).
            t: number of entries of ``extra_ws`` to use; ``0`` leaves the
                input untouched.
            compute_input_matrix: when true, the augmented input is kept on
                ``self.input_matrix``.

        Returns:
            The result of ``F.conv2d`` using only the kernel slices that match
            the channels actually present in the augmented input.
        """
        # Every projection is computed from the original input, then all are
        # stacked after it along the channel axis.
        projected = [
            (input.permute(0, 2, 3, 1) @ self.extra_ws[i]).permute(0, 3, 1, 2)
            for i in range(t)
        ]
        input = torch.cat([input] + projected, dim=1)

        if compute_input_matrix:
            self.input_matrix = input

        return F.conv2d(input, self.weight[:, :input.shape[1]], bias=self.bias,
                        stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=self.groups)

    def duplicate(self, in_channels, extra_w):
        """Return a copy of this layer widened by ``in_channels`` input channels.

        Args:
            in_channels: number of input channels to add.
            extra_w: projection appended to ``extra_ws``.

        Returns:
            A new ``Conv2d_API`` whose first ``self.in_channels`` kernel
            slices and bias equal this layer's; the added kernel slices keep
            their fresh initialisation.

        Note:
            The returned layer shares this layer's ``extra_ws`` list, so the
            append is visible on both.
        """
        dup = Conv2d_API(
            self.in_channels + in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode
        )
        # Keep the copy on the same device / dtype as the layer it replaces.
        dup.to(device=self.weight.device, dtype=self.weight.dtype)

        dup.extra_ws = self.extra_ws
        dup.extra_ws.append(extra_w)
        dup.expand = self.expand + [in_channels]

        with torch.no_grad():
            dup.weight[:, :self.in_channels].copy_(self.weight)
            if self.bias is not None:
                # The bias is 1-D (out_channels,) and independent of the
                # number of input channels, so it is copied whole.
                dup.bias.copy_(self.bias)

        return dup
|
|
class Linear_API(nn.Linear):
    """Linear layer whose input can be widened with projected features.

    ``extra_ws`` holds projection matrices applied to the input; ``forward``
    concatenates the first ``t`` projections to the input before the linear
    map. ``expand`` records how many input features each ``duplicate`` call
    added.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)

        self.extra_ws = nn.ParameterList([])
        self.expand = []

    def forward(self, input, t, compute_input_matrix=False):
        """Apply the linear map to ``input`` augmented with ``t`` projections.

        Args:
            input: 2-D tensor ``(batch, features)`` (concatenation is on dim 1).
            t: number of entries of ``extra_ws`` to use; ``0`` leaves the
                input untouched.
            compute_input_matrix: when true, the augmented input is kept on
                ``self.input_matrix``.

        Returns:
            The result of ``F.linear`` using only the weight columns that
            match the features actually present in the augmented input.
        """
        # Every projection is computed from the original input, then all are
        # appended after it along the feature axis.
        projected = [input @ self.extra_ws[i] for i in range(t)]
        input = torch.cat([input] + projected, dim=1)

        if compute_input_matrix:
            self.input_matrix = input

        return F.linear(input, self.weight[:, :input.shape[1]], bias=self.bias)

    def duplicate(self, in_features, extra_w):
        """Return a copy of this layer widened by ``in_features`` inputs.

        Args:
            in_features: number of input features to add.
            extra_w: projection appended to ``extra_ws``.

        Returns:
            A new ``Linear_API`` whose first ``self.in_features`` weight
            columns and bias equal this layer's; the added columns keep their
            fresh initialisation.

        Note:
            The returned layer shares this layer's ``extra_ws`` list, so the
            append is visible on both.
        """
        # Build the copy on the same device / dtype as the layer it replaces.
        dup = Linear_API(
            self.in_features + in_features,
            self.out_features,
            self.bias is not None,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )

        dup.extra_ws = self.extra_ws
        dup.extra_ws.append(extra_w)
        dup.expand = self.expand + [in_features]

        with torch.no_grad():
            dup.weight[:, :self.in_features].copy_(self.weight)
            if self.bias is not None:
                # The bias is 1-D (out_features,) and independent of the
                # number of input features, so it is copied whole.
                dup.bias.copy_(self.bias)

        return dup
|
|
class AlexNet_API(nn.Module):
    """AlexNet-style feature extractor built from ``Conv2d_API`` / ``Linear_API``.

    Three conv stages (conv -> batch norm -> ReLU -> dropout -> 2x2 max-pool)
    feed two fully connected stages (fc -> batch norm -> ReLU -> dropout).
    The module returns a ``feat_dim``-wide (2048) feature vector and can widen
    the inputs of all five learnable layers via ``expand``.

    Args:
        dropout_rate_1: dropout probability after the first two conv stages.
        dropout_rate_2: dropout probability after the third conv stage and
            both fully connected stages.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dropout_rate_1 = 0.2, dropout_rate_2 = 0.5, **kwargs):
        super().__init__()

        # One independent, initially empty list per learnable layer.
        for layer_idx in range(1, 6):
            setattr(self, f"select{layer_idx}", [])

        # Layers are registered in the order they are applied in forward().
        self.conv1 = Conv2d_API(3, 64, 4, bias=False)
        self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)

        self.conv2 = Conv2d_API(64, 128, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(128, track_running_stats=False)

        self.conv3 = Conv2d_API(128, 256, 2, bias=False)
        self.bn3 = nn.BatchNorm2d(256, track_running_stats=False)

        self.fc1 = Linear_API(1024, 2048, bias=False)
        self.bn4 = nn.BatchNorm1d(2048, track_running_stats=False)

        self.fc2 = Linear_API(2048, 2048, bias=False)
        self.bn5 = nn.BatchNorm1d(2048, track_running_stats=False)

        # Width of the feature vector returned by forward().
        self.feat_dim = 2048

        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate_1)
        self.dropout2 = nn.Dropout(dropout_rate_2)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, t = 0, compute_input_matrix = False):
        """Return the 2048-wide features for a batch of images.

        Args:
            x: image batch; the first conv layer expects 3 channels.
            t: passed unchanged to every API layer (number of extra
                projections each layer concatenates to its input).
            compute_input_matrix: passed unchanged to every API layer.

        Returns:
            Tensor of shape ``(batch, 2048)``.
        """
        conv_stages = (
            (self.conv1, self.bn1, self.dropout1),
            (self.conv2, self.bn2, self.dropout1),
            (self.conv3, self.bn3, self.dropout2),
        )
        for conv, norm, drop in conv_stages:
            x = conv(x, t, compute_input_matrix)
            x = self.maxpool(drop(self.relu(norm(x))))

        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)

        fc_stages = ((self.fc1, self.bn4), (self.fc2, self.bn5))
        for fc, norm in fc_stages:
            x = fc(x, t, compute_input_matrix)
            x = self.dropout2(self.relu(norm(x)))

        return x

    def expand(self, sizes, extra_ws):
        """Replace each learnable layer with a widened duplicate.

        Args:
            sizes: five entries, the number of inputs to add to conv1, conv2,
                conv3, fc1 and fc2 respectively.
            extra_ws: five entries, the projection handed to each layer's
                ``duplicate`` in the same order.
        """
        layer_names = ("conv1", "conv2", "conv3", "fc1", "fc2")
        for position, layer_name in enumerate(layer_names):
            widened = getattr(self, layer_name).duplicate(sizes[position], extra_ws[position])
            setattr(self, layer_name, widened)
|
|