# =============================================================================
# training/loss.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional

class MambaLoss(nn.Module):
    """Loss functions for Mamba training"""
    
    def __init__(self, config, vocab_size: int):
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        
        # Primary loss
        self.lm_loss = nn.CrossEntropyLoss(ignore_index=-100)
        
        # Auxiliary losses
        self.diversity_weight = 0.01
        self.specialist_balance_weight = 0.001
        
    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                specialist_weights: Optional[Dict] = None) -> Dict[str, torch.Tensor]:
        """
        Compute total loss

        Args:
            logits: [batch, seq_len, vocab_size]
            targets: [batch, seq_len]
            specialist_weights: Dict of specialist activation weights

        Returns:
            Dict with loss components
        """
        losses = {}
        
        # Primary language modeling loss
        lm_loss = self.lm_loss(
            logits.view(-1, logits.size(-1)),
            targets.view(-1)
        )
        losses['lm_loss'] = lm_loss
        
        # Diversity loss to encourage specialist specialization
        if specialist_weights is not None:
            diversity_loss = self._compute_diversity_loss(specialist_weights)
            losses['diversity_loss'] = diversity_loss
            
            # Balance loss to prevent specialist dominance
            balance_loss = self._compute_balance_loss(specialist_weights)
            losses['balance_loss'] = balance_loss
        else:
            losses['diversity_loss'] = torch.tensor(0.0, device=logits.device)
            losses['balance_loss'] = torch.tensor(0.0, device=logits.device)
        
        # Total loss
        total_loss = (
            lm_loss + 
            self.diversity_weight * losses['diversity_loss'] +
            self.specialist_balance_weight * losses['balance_loss']
        )
        losses['total_loss'] = total_loss
        
        return losses
    
    def _compute_diversity_loss(self, specialist_weights: Dict) -> torch.Tensor:
        """Encourage specialists to be diverse"""
        if len(specialist_weights) < 2:
            # Nothing to compare; return a zero on the same device as the weights, if any
            device = next(iter(specialist_weights.values())).device if specialist_weights else None
            return torch.tensor(0.0, device=device)
        
        # Convert weights to tensor
        weights = torch.stack(list(specialist_weights.values()))
        
        # Compute pairwise similarities
        normalized_weights = F.normalize(weights, dim=-1)
        similarity_matrix = torch.mm(normalized_weights, normalized_weights.t())
        
        # Penalize high similarities (encourage diversity)
        diversity_loss = similarity_matrix.triu(diagonal=1).mean()
        
        return diversity_loss
    
    def _compute_balance_loss(self, specialist_weights: Dict) -> torch.Tensor:
        """Encourage balanced specialist usage"""
        if not specialist_weights:
            return torch.tensor(0.0)
        
        # Get activation frequencies
        activations = torch.stack(list(specialist_weights.values()))
        
        # Compute variance in activations (lower is more balanced)
        balance_loss = activations.var()
        
        return balance_loss
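
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The config object, vocab size, and
# specialist names below are hypothetical stand-ins; MambaLoss only stores the
# config, so any object works here. This demonstrates the expected tensor
# shapes and the dictionary of loss components returned by forward().
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    batch, seq_len, vocab_size = 2, 16, 1000
    criterion = MambaLoss(config=SimpleNamespace(), vocab_size=vocab_size)

    logits = torch.randn(batch, seq_len, vocab_size)
    targets = torch.randint(0, vocab_size, (batch, seq_len))

    # Optional per-specialist activation weights, keyed by specialist name
    specialist_weights = {
        "specialist_0": torch.rand(8),
        "specialist_1": torch.rand(8),
    }

    losses = criterion(logits, targets, specialist_weights)
    for name, value in losses.items():
        print(f"{name}: {value.item():.4f}")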