import torch
import torch.nn as nn


class GroupedAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_groups):
        super().__init__()
        assert input_dim % num_groups == 0, "Input dimension must be divisible by the number of groups."
        assert hidden_dim % num_groups == 0, "Hidden dimension must be divisible by the number of groups."

        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups

        # Group-wise encoders: each input group gets its own linear projection.
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])

        # A single shared decoder maps the concatenated group codes back to the full input.
        # (Group-wise decoders mirroring the encoders would be the alternative design.)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.init_weights()

    def init_weights(self):
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x):
        # Split the input into equal-sized groups along the feature dimension.
        group_inputs = torch.split(x, self.group_input_dim, dim=1)

        # Encode each group independently with its own encoder.
        encoded_groups = [encoder(group) for group, encoder in zip(group_inputs, self.encoders)]

        # Concatenate the group codes and decode them jointly.
        reconstructed = self.decoder(torch.cat(encoded_groups, dim=1))
        return reconstructed


input_dim = 5120
hidden_dim = 320
num_groups = 40
model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).cuda()
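
# Minimal usage sketch (not part of the original code): run one reconstruction step
# on a random batch and report the MSE loss. The batch size and learning rate below
# are illustrative assumptions.
batch = torch.randn(8, input_dim, device="cuda")            # hypothetical batch of 8 samples
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer settings

optimizer.zero_grad()
reconstruction = model(batch)
loss = nn.functional.mse_loss(reconstruction, batch)
loss.backward()
optimizer.step()
print(f"reconstruction MSE: {loss.item():.4f}")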