| import cv2 | |
| import math | |
| from enum import Enum | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from .filter import Filter, FILTER_MODULES | |
class CascadeArgumentRegressor(nn.Module):
    """Regress ``head_num`` argument vectors from a feature map with cascaded heads.

    The input feature map is global-average-pooled to one vector per sample.
    Two linear projections of that vector are computed:

    * ``f`` (``feature_channels`` wide) is shared, unchanged, by every head;
    * ``g`` (``base_channels`` wide) seeds a running state.

    Head ``i`` first updates the state from ``cat(f, state)`` and then maps the
    new state to its ``out_channels``-wide prediction. Each head is therefore
    conditioned on the state produced by the previous head.

    Args:
        in_channels: Channel count ``C`` of the input feature map.
        base_channels: Width of the cascaded state vector.
        out_channels: Width of each head's prediction.
        head_num: Number of cascaded heads, which is also the number of
            returned predictions.
        feature_channels: Width of the shared ``f`` projection. Defaults to
            160, the value that used to be hard-coded. Submodule names
            (``f``, ``g``, ``headers.<i>.<j>``) do not depend on it, so
            existing checkpoints built with the default still load.
    """

    def __init__(self, in_channels, base_channels, out_channels, head_num,
                 feature_channels=160):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.head_num = head_num
        self.feature_channels = feature_channels

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.f = nn.Linear(in_channels, feature_channels)
        self.g = nn.Linear(in_channels, base_channels)

        # One [state-update, prediction] pair per head. The nested ModuleLists
        # keep state_dict keys as ``headers.<i>.0.*`` / ``headers.<i>.1.*``.
        # Modules are created in the same order as before (f, g, then the
        # heads in order), so seeded initialisation is unchanged.
        self.headers = nn.ModuleList(
            nn.ModuleList([
                nn.Linear(feature_channels + base_channels, base_channels),
                nn.Linear(base_channels, out_channels),
            ])
            for _ in range(head_num)
        )

    def forward(self, x):
        """Predict one argument tensor per head.

        Args:
            x: Feature map of shape ``(N, in_channels, H, W)``. Four dimensions
                are required by the shape unpacking below.

        Returns:
            A list of ``head_num`` tensors, each of shape ``(N, out_channels)``,
            in head order.
        """
        x = self.pool(x)
        n, c, _, _ = x.shape
        x = x.view(n, c)  # (N, C, 1, 1) -> (N, C)

        f = self.f(x)
        state = self.g(x)

        # NOTE(review): no activation is applied between the cascaded Linear
        # layers, so every prediction is an affine function of the pooled
        # input. Confirm this is intentional.
        pred_args = []
        for update, predict in self.headers:
            state = update(torch.cat((f, state), dim=1))
            pred_args.append(predict(state))
        return pred_args
class FilterPerformer(nn.Module):
    """Apply a fixed sequence of image filters, blending each stage through a mask.

    Args:
        filter_types: Keys into ``FILTER_MODULES``. Each entry is instantiated
            with no arguments. :meth:`restore` applies the resulting filters in
            this order and :meth:`adjust` applies them in reverse order.
    """

    def __init__(self, filter_types):
        super().__init__()
        # NOTE(review): a plain list does not register its items as
        # submodules. If the filters are nn.Module instances, ``.to()``,
        # ``.eval()`` and ``state_dict()`` will not reach them. Consider
        # ``nn.ModuleList`` if they hold parameters or buffers, but check
        # checkpoint compatibility first: that would add ``filters.<i>.*`` keys.
        self.filters = [FILTER_MODULES[filter_type]() for filter_type in filter_types]

    def forward(self):
        """No-op placeholder that returns ``None``; use :meth:`restore` or :meth:`adjust`."""
        pass

    @staticmethod
    def _apply_chain(image, mask, filters, arguments):
        """Apply ``filters`` cumulatively to ``image``; return each stage blended by ``mask``."""
        outputs = []
        current = image
        for image_filter, arg in zip(filters, arguments):
            # Each filter acts on the previous *unblended* result. Only the
            # recorded output is mixed back with the untouched original.
            current = image_filter(current, arg)
            outputs.append(current * mask + image * (1 - mask))
        return outputs

    def restore(self, x, mask, arguments):
        """Apply the filters in construction order.

        Args:
            x: Image to filter. It must support ``*``/``+`` with ``mask``;
                presumably a tensor (TODO confirm).
            mask: Blend weights. A value of 1 takes the filtered value and 0
                keeps ``x``.
            arguments: One argument per filter, aligned with ``self.filters``.

        Returns:
            List with one entry per filter. Entry ``i`` is the result of
            applying filters ``0..i`` cumulatively to ``x``, blended with the
            original ``x`` by ``mask``.

        Raises:
            AssertionError: If ``len(arguments) != len(self.filters)``. The
                check is skipped when Python runs with ``-O``.
        """
        assert len(self.filters) == len(arguments), (
            f"expected {len(self.filters)} filter arguments, got {len(arguments)}"
        )
        return self._apply_chain(x, mask, self.filters, arguments)

    def adjust(self, image, mask, arguments):
        """Apply the filters in reverse construction order.

        Args:
            image: Image to filter. It must support ``*``/``+`` with ``mask``;
                presumably a tensor (TODO confirm).
            mask: Blend weights. A value of 1 takes the filtered value and 0
                keeps ``image``.
            arguments: One argument per filter, aligned with ``self.filters``
                (``arguments[i]`` goes to ``self.filters[i]``). Both sequences
                are consumed back to front.

        Returns:
            List with one entry per filter, ordered from the last filter to
            the first. Each entry is the cumulative result so far, blended with
            the original ``image`` by ``mask``.

        Raises:
            AssertionError: If ``len(arguments) != len(self.filters)``. The
                check is skipped when Python runs with ``-O``.
        """
        assert len(self.filters) == len(arguments), (
            f"expected {len(self.filters)} filter arguments, got {len(arguments)}"
        )
        return self._apply_chain(
            image, mask, reversed(self.filters), reversed(arguments)
        )