File size: 4,757 Bytes
2215b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import types
from typing import List, Callable

import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.resnet import BasicBlock


def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
    """Compute the TRP training loss from the main head plus a chain of replicas.

    Replica ``k`` applies ``trp_blocks[k]`` to the previous replica's hidden
    state (the first one starts from ``hidden_state``) and classifies the result
    with ``shared_head``. Every ``criterion`` call returns ``(loss, rewards)``;
    the rewards of one stage are passed as the mask argument of the next stage.

    Args:
        trp_blocks: One module per replica; indexed as ``trp_blocks[k]`` for
            ``k in range(len(lambdas))``, so it must hold at least that many.
        shared_head: Maps a hidden state to logits.
        criterion: ``criterion(logits, targets[, mask]) -> (loss, rewards)``.
        lambdas: ``lambdas[k]`` weights the rewards produced by the stage
            preceding replica ``k`` in the per-sample ``returns``.
        hidden_state: Input to the first replica block.
        logits: Main-head logits.
        targets: Forwarded unchanged to every ``criterion`` call.
        loss_normalization: If True, multiply the final loss by the detached
            factor ``exp(mean(main_loss) - loss)``.

    Returns:
        Tensor: Scalar loss, ``mean(sum_of_stage_losses * returns)`` where
        ``returns = 1 + sum_k lambdas[k] * rewards_k``.
    """
    losses, rewards = criterion(logits, targets)
    # Per-sample weights, accumulated as 1 + sum_k lambdas[k] * rewards_k.
    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
    if loss_normalization:
        coeff = torch.mean(losses).detach()

    embed = hidden_state
    for k, w in enumerate(lambdas):
        embed = trp_blocks[k](embed)
        prediction = shared_head(embed)
        returns = returns + w * rewards
        # The previous stage's rewards act as the mask for this replica's loss.
        replica_losses, rewards = criterion(prediction, targets, rewards)
        losses = losses + replica_losses
    loss = torch.mean(losses * returns)

    if loss_normalization:
        # exp(a) / exp(b) overflows to inf / inf = nan in float32 once a, b > ~88;
        # exp(a - b) is the same quantity computed stably. Both operands are
        # detached, so the factor carries no gradient.
        loss = torch.exp(coeff - loss.detach()) * loss

    return loss


class TPBlock(nn.Module):
    """A stack of ``depths`` torchvision ``BasicBlock``s used as one TRP replica.

    After construction, every parameter whose name contains ``"conv"`` or
    ``"downsample"`` is set to zero (presumably so the replica starts close to an
    identity-like mapping of its input — verify against torchvision's
    ``BasicBlock``).

    NOTE(review): every block is built with the same ``inplanes``, so the stack
    only chains cleanly when ``inplanes == planes``.
    """

    def __init__(self, depths: int, inplanes: int, planes: int):
        super().__init__()

        self.blocks = nn.Sequential(
            *(BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths))
        )

        # Substring match on the parameter name; whole tensors are zeroed,
        # regardless of whether they are weights or biases.
        zeroed_tags = ("conv", "downsample")
        for pname, tensor in self.blocks.named_parameters():
            if any(tag in pname for tag in zeroed_tags):
                nn.init.zeros_(tensor)

    def forward(self, x):
        """Pass ``x`` through the stacked blocks in order."""
        return self.blocks(x)


class ResNetConfig:
    """Factories for the closures used to run TRP on a torchvision-style ResNet.

    Every method is a ``@staticmethod`` returning a closure; the class holds no
    state of its own.
    """

    @staticmethod
    def gen_criterion(label_smoothing=0.0, top_k=1):
        """Build a masked cross-entropy criterion.

        Args:
            label_smoothing (float): Passed to ``F.cross_entropy``. When > 0 the
                loss is computed against hard class indices ([B, C] targets are
                reduced with ``argmax`` first).
            top_k (int): A sample keeps its mask entry only if its label is
                among the ``top_k`` largest logits.

        Returns:
            Callable ``func(input, target, mask=None) -> (loss, mask)``.
        """
        def func(input, target, mask=None):
            """
            Args:
                input (Tensor): Logits of shape [B, C].
                target (Tensor): Class indices of shape [B] or soft/one-hot
                    targets of shape [B, C].
                mask (Tensor, optional): Per-sample weights of shape [B];
                    defaults to all ones.

            Returns:
                loss (Tensor): Scalar, the mask-weighted mean cross-entropy.
                mask (Tensor): Shape [B]; the input mask zeroed wherever the
                    label is not among the top-k predictions (no gradient).
            """
            # Hard labels are always needed for the top-k check below. Decide
            # from the target's rank (not from label_smoothing) so that index
            # targets work with smoothing and soft targets work without it.
            hard_label = torch.argmax(target, dim=1) if target.dim() > 1 else target
            # With smoothing the loss uses hard labels; otherwise the target is
            # passed to cross_entropy unchanged (indices or probabilities).
            ce_target = hard_label if label_smoothing > 0.0 else target

            unmasked_loss = F.cross_entropy(input, ce_target, reduction="none", label_smoothing=label_smoothing)
            if mask is None:
                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
            # The 1e-6 guards against division by zero when every sample is masked out.
            loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)

            with torch.no_grad():
                _, topk_indices = torch.topk(input, top_k, dim=-1)
                hit = torch.eq(topk_indices, hard_label[:, None]).any(dim=-1)
                mask = mask * hit.to(input.dtype)

            return loss, mask
        return func

    @staticmethod
    def gen_shared_head(self):
        """Return a head mapping ``layer3``-level features to logits.

        NOTE: despite the parameter name this is a staticmethod; ``self`` is the
        ResNet model whose ``layer4``, ``avgpool`` and ``fc`` are reused (shared)
        by every TRP replica.
        """
        def func(x):
            """
            Args:
                x (Tensor): Hidden states tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = self.layer4(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
        """Build a replacement ``forward`` for a torchvision-style ResNet.

        The returned function is meant to be bound to the model. In eval mode it
        returns the logits; in training mode it returns ``(logits, loss)``, where
        ``loss`` is computed by ``trp_criterion`` from the ``layer3`` output.
        """
        # The criterion depends only on these generator arguments, so build it
        # once instead of on every forward pass.
        criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)

        def func(self, x: Tensor, targets=None):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            # The layer3 output feeds both the main head and the TRP replicas.
            hidden_states = self.layer3(x)
            x = self.layer4(hidden_states)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                if targets is None:
                    raise TypeError("targets must be provided when the model is in training mode")
                shared_head = ResNetConfig.gen_shared_head(self)

                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, targets, loss_normalization=loss_normalization)

                return logits, loss

            return logits

        return func
    

def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
    """Attach TRP replica blocks to ``model`` and replace its ``forward``.

    Args:
        model: A torchvision-style ResNet (the new forward uses ``conv1``,
            ``bn1``, ``relu``, ``maxpool``, ``layer1``-``layer4``, ``avgpool``
            and ``fc``). Mutated in place.
        depths: Number of ``BasicBlock``s in each replica; one ``TPBlock`` is
            created per entry.
        planes: Channel count of the ``layer3`` output that the replicas
            consume (used as both ``inplanes`` and ``planes``).
        lambdas: Per-replica reward weights. Training indexes
            ``trp_blocks[k]`` for every ``k < len(lambdas)``, so ``lambdas``
            should not be longer than ``depths``.
        **kwargs: ``label_smoothing`` (default 0.0) and ``top_k`` (default 1)
            for the training criterion.

    Returns:
        The same ``model`` object.
    """
    print("✅ Applying TRP to ResNet for Image Classification...")
    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
    forward = ResNetConfig.gen_forward(
        lambdas,
        True,
        label_smoothing=kwargs.get("label_smoothing", 0.0),
        top_k=kwargs.get("top_k", 1),
    )
    model.forward = types.MethodType(forward, model)
    return model