import torch
import torch.nn as nn
from transformers import LongformerPreTrainedModel, LongformerModel
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.energy = nn.Linear(hidden_dim, 1)
        
        # Initialize weights (note: when this module is built inside a
        # PreTrainedModel, post_init() may re-initialize these linear layers
        # with the backbone's default scheme)
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.energy.weight)
        self.query.bias.data.zero_()
        self.energy.bias.data.zero_()

    def forward(self, hidden_states, attention_mask=None):
        # Compute attention scores
        transformed = torch.tanh(self.query(hidden_states))  # (batch_size, seq_len, hidden_dim)
        scores = self.energy(transformed).squeeze(-1)  # (batch_size, seq_len)
        
        # Apply attention mask if provided
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        
        # Compute attention weights
        weights = F.softmax(scores, dim=-1)  # (batch_size, seq_len)
        
        # Apply attention pooling
        pooled = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)  # (batch_size, hidden_dim)
        return pooled
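
# Shape sketch (illustrative sizes, not taken from the original file): with
# hidden_dim=768, AttentionPooling(768) maps hidden_states of shape
# (batch, seq_len, 768), plus an optional attention_mask of shape
# (batch, seq_len), to one pooled vector per sequence of shape (batch, 768);
# positions where attention_mask == 0 receive zero attention weight.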



class CustomLongformerForSequenceClassification(LongformerPreTrainedModel):
    """Longformer model with attention pooling for sequence classification.
    
    Uses attention pooling over the last four hidden layers instead of CLS token pooling.
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        
        # Longformer backbone
        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # Attention pooling for each layer
        self.attention_poolers = nn.ModuleList([
            AttentionPooling(config.hidden_size) for _ in range(4)
        ])
        
        # Final classifier
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)
        
        # Initialize weights
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Any extra Longformer arguments (e.g. global_attention_mask) are
        # forwarded to the backbone via **kwargs
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # outputs.hidden_states[0] is the embedding output, so the slice below
        # selects the outputs of the last four encoder layers
        last_four_layers = outputs.hidden_states[-4:]
        
        # Apply attention pooling to each layer
        pooled = []
        for layer, pooler in zip(last_four_layers, self.attention_poolers):
            pooled.append(pooler(layer, attention_mask=attention_mask))
        
        # Concatenate pooled representations
        concatenated = torch.cat(pooled, dim=1)
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        # Compute loss if labels are provided; an externally attached criterion
        # (model.loss_fct = ...) takes precedence, otherwise fall back to MSE
        # for regression-style targets
        loss = None
        if labels is not None:
            if hasattr(self, 'loss_fct'):
                loss = self.loss_fct(logits, labels)
            else:
                loss = F.mse_loss(logits.view(-1), labels.float().view(-1))

        return {'loss': loss, 'logits': logits}
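

if __name__ == "__main__":
    # Minimal usage sketch: a smoke test on a tiny random LongformerConfig
    # (sizes below are illustrative, not the allenai/longformer-base-4096
    # checkpoint). For real training, load pretrained weights instead, e.g.
    # CustomLongformerForSequenceClassification.from_pretrained(
    #     "allenai/longformer-base-4096", num_labels=1)
    from transformers import LongformerConfig

    config = LongformerConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=512,
        attention_window=16,
        num_labels=1,
    )
    model = CustomLongformerForSequenceClassification(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 64))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.rand(2)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # Expected: logits of shape (2, 1) and a scalar MSE loss
    print(out["logits"].shape, out["loss"].item())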
