| from transformers import PreTrainedModel |
| from .configuration import MoLMConfig |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from transformers.utils import ModelOutput |
| from .gpt import GPTBase |
| from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss |
| from typing import Optional, List |
| from dataclasses import dataclass |
| import tiktoken |
|
|
|
|
@dataclass
class Output(ModelOutput):
    """Return container for MoLM.forward (extends transformers.ModelOutput)."""

    # Stacked per-expert logits, shape (num_experts, batch, seq, vocab) — see forward().
    logits: torch.FloatTensor = None
    # Training loss (NLL over the expert mixture, plus router aux losses when
    # routing is enabled and the model is in training mode); None when no targets.
    loss: Optional[torch.FloatTensor] = None
    # Per-expert `loss_to_log` scalars collected from each expert's forward pass.
    expert_losses: Optional[List] = None
    # Scalar NLL (before aux losses) for logging; None when no targets.
    loss_to_log: Optional[float] = None
    # Raw router scores, shape (batch, num_experts); None when routing is disabled.
    router_logits: Optional[torch.FloatTensor] = None
    # Top-k expert indices chosen by the router; None when routing is disabled.
    selected_experts: Optional[torch.LongTensor] = None
    # Log-probabilities of the combined expert mixture, shape (batch, seq, vocab).
    combined_log_probs: Optional[torch.FloatTensor] = None
|
|
|
|
class MoLM(PreTrainedModel):
    """Mixture of Language Models: an ensemble of frozen GPT experts whose
    per-token distributions are combined in log space, either uniformly or via
    a learned top-k router. Experts can be masked per sample by a `date` index
    so only experts "available" up to that date contribute."""

    config_class = MoLMConfig

    def __init__(self, config, expert_weights=None, dropout=0.1):
        """
        Constructor for the MoLM (Mixture of Language Models) class.

        :param config: The configuration of the model (should be a PretrainedConfig object);
            must provide num_experts, expert_configs, use_router, n_embd, and
            optionally top_k_experts.
        :param expert_weights: (Optional) A list of state dicts, one per expert, to load
            pre-trained weights (length must match the number of experts). When given,
            all expert parameters are frozen after loading.
        :param dropout: Dropout rate for the model (currently unused in this constructor).
        """
        super(MoLM, self).__init__(config)

        self.num_experts = config.num_experts

        assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
        self.expert_configs = config.expert_configs

        self.use_router = config.use_router

        # Router: a single linear layer mapping a pooled embedding (n_embd)
        # to one score per expert.
        self.router = nn.Sequential(
            nn.Linear(config.n_embd, self.num_experts),
        )
        # Number of experts the router keeps per sample; defaults to all experts
        # when the config does not define top_k_experts.
        self.top_k = config.top_k_experts if hasattr(config, "top_k_experts") else self.num_experts

        self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
        self.tokenizer = tiktoken.get_encoding("gpt2")

        if expert_weights is not None:
            for i, expert in enumerate(self.experts):
                expert.load_state_dict(expert_weights[i], strict=False)
                # Re-wrap the token embedding in a fresh Parameter via clone() —
                # presumably to break weight tying / tensor sharing between the
                # loaded embedding and the LM head (TODO confirm against GPTBase).
                expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone())
                # Experts are frozen: only the router (if used) is trainable.
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
        """
        Forward pass for the MoLM model, passing input through all experts and averaging their outputs.

        :param input_ids: Input token IDs (batch_size, seq_len)
        :param attention_mask: Attention mask (batch_size, seq_len) — accepted but not
            used in this method; NOTE(review): it is never forwarded to the experts.
        :param targets: Target labels for calculating loss (batch_size, seq_len)
        :param date: Selects which experts are active per sample. May be None
            (defaults to index 6, i.e. all experts with index <= 6), a single int
            year (mapped to an expert index; see below), or a LongTensor of
            per-sample indices with shape (batch_size,).
        :param masking_enabled: Whether or not to perform expert masking (True/False);
            NOTE(review): this flag is accepted but never read in this method.
        :param kwargs: Additional arguments forwarded to each expert's forward pass
        :return: Output with the per-expert logits and the combined (mixture)
            log-probabilities over active experts for each sample in the batch
        """
        device = input_ids.device
        b, t = input_ids.size()

        assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"

        if date is None:
            # Default: expert index 6 for every sample (activates experts 0..6).
            date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
        elif isinstance(date, int):
            # Map a calendar year to an expert index — assumes two-year buckets
            # anchored at 2013 (2013/14 -> 1, 2015/16 -> 2, ...); TODO confirm
            # this bucketing against how the experts were trained.
            date = (date - 2013) // 2 + 1
            date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
        elif isinstance(date, torch.Tensor):
            # Per-sample expert indices.
            assert date.size(0) == b, "The size of date tensor must match the batch size."
            date = date.to(device)

        expert_outputs = []
        expert_losses = []

        # How many experts are active for each sample (used for uniform weighting).
        active_experts_count = torch.zeros(b, dtype=torch.long, device=device)

        # Experts are frozen, so their forward passes run without autograd.
        with torch.no_grad():
            for i, expert in enumerate(self.experts):
                # Expert i is active for a sample iff its index <= that sample's date.
                expert_mask = date >= i

                # Broadcast mask to (batch, 1, 1) for elementwise logit masking.
                expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()

                expert_output = expert(input_ids, targets=targets, date=date, **kwargs, get_logits=True)

                logits = expert_output["logits"]
                loss_to_log = expert_output["loss_to_log"]

                # Zero out logits of inactive experts. NOTE(review): zeroed logits
                # still produce a *uniform* distribution after log_softmax below,
                # so inactive experts are not truly removed from the mixture in
                # the non-router branch — verify this is intended.
                logits = logits * expert_mask_expanded

                expert_outputs.append(logits)
                expert_losses.append(loss_to_log)

                active_experts_count += expert_mask.long()

        # (num_experts, batch, seq, vocab)
        expert_outputs = torch.stack(expert_outputs, dim=0)

        # Per-expert log-probabilities over the vocabulary.
        log_probs = F.log_softmax(expert_outputs, dim=-1)

        if self.use_router:
            # Pool expert 0's token embeddings (mean over the sequence) as the
            # router input — the router never sees the other experts' embeddings.
            hidden = self.experts[0].transformer.wte(input_ids)
            pooled_hidden = hidden.mean(dim=1)
            router_logits = self.router(pooled_hidden)

            # Forbid routing to experts beyond each sample's date by masking
            # their scores to -inf before the softmax.
            expert_ids = torch.arange(self.num_experts, device=input_ids.device)
            router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0)
            masked_router_logits = router_logits.masked_fill(~router_mask, float("-inf"))

            # Keep the top-k experts and renormalize their probabilities to sum to 1.
            topk_probs, topk_indices = torch.topk(F.softmax(masked_router_logits, dim=-1), self.top_k, dim=-1)
            sparse_probs = torch.zeros_like(router_logits)
            sparse_probs.scatter_(1, topk_indices, topk_probs)
            sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True)

            # Small epsilon avoids log(0) for the non-selected experts.
            log_weights = torch.log(sparse_probs + 1e-9)

            # Reshape weights to (num_experts, batch, 1, 1) to broadcast over
            # sequence and vocabulary dimensions.
            log_weights_exp = log_weights.transpose(0, 1).unsqueeze(-1).unsqueeze(-1)
            weighted_log_probs = log_probs + log_weights_exp

            # Mixture of experts in log space: log sum_i w_i * p_i(token).
            combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0)

        else:
            # Uniform weighting 1/active_count per sample (clamped to avoid
            # division by zero when no expert is active). NOTE(review): the
            # logsumexp below still runs over *all* experts, including inactive
            # ones whose masked logits give uniform distributions — see note above.
            log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1)
            weighted_log_probs = log_probs + log_weights
            combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0)

        if targets is not None:
            # NLL on the mixture log-probs; positions labeled -1 are ignored.
            loss = F.nll_loss(combined_log_probs.view(-1, combined_log_probs.size(-1)), targets.view(-1), ignore_index=-1)
            # Logged value is the plain NLL, *before* aux losses are added below.
            loss_to_log = loss.item()

            if self.use_router and self.training:
                flat_router_logits = router_logits.view(-1, router_logits.size(-1))
                flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1))

                # Router regularizers: entropy bonus, load balancing, and z-loss.
                entropy = entropy_reg(flat_router_logits)
                lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts)
                zloss = router_z_loss(flat_router_logits)

                # Fixed aux-loss coefficients (hard-coded; not exposed via config).
                loss = (
                    loss
                    + 0.01 *entropy
                    + 0.01 * lb_loss
                    + 0.0001 * zloss
                )
        else:
            loss = None
            loss_to_log = None

        return Output(
            logits=expert_outputs,
            loss=loss,
            combined_log_probs=combined_log_probs,
            loss_to_log=loss_to_log,
            expert_losses=expert_losses,
            router_logits=router_logits if self.use_router else None,
            selected_experts=topk_indices if self.use_router else None,
        )

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        NOTE(review): the `top_k` sampling parameter is accepted but never applied
        below. The `.item()` stop check also assumes batch size 1 — it raises for
        larger batches.
        """
        idx = input_ids
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )

            # Mixture log-probs for the last position only.
            log_probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]

            if temperature == 0:
                # Greedy decoding.
                idx_next = torch.argmax(log_probs, dim=-1, keepdim=True)
            else:
                # Temperature scaling applied to log-probs. The exp() of scaled
                # log-probs is unnormalized for temperature != 1, but
                # torch.multinomial normalizes its weights internally.
                scaled_log_probs = log_probs / temperature

                probs = torch.exp(scaled_log_probs)

                idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)

            # Stop on this token id. NOTE(review): GPT-2's <|endoftext|> is 50256;
            # 50526 looks like a possible digit transposition — confirm intent.
            if idx_next.item() == 50526:
                break

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
        """Encode `in_str` with the GPT-2 tokenizer, generate, and decode back to text."""
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str)
            )
            .view(1, -1)
            .to(next(self.parameters()).device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, date, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)
| |
|
|