from dataclasses import dataclass, fields, asdict
import json

import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba import Mamba, MambaConfig, RMSNorm
"""
Encapsulates a Mamba model as language model. It has an embedding layer, and a LM head which maps the model output to logits.
"""
# TODO generate function : batch size != 1 ? (for now B=1)
# TODO generate function : top-p sampling

@dataclass
class MambaLMConfig(MambaConfig):
    vocab_size: int = 32000
    pad_vocab_size_multiple: int = 8

    def __post_init__(self):
        super().__post_init__()

        # round the vocab size up to a multiple of pad_vocab_size_multiple
        # (the pretrained checkpoints use a padded embedding table)
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)

    def to_mamba_config(self) -> MambaConfig:
        mamba_config_fields = {field.name for field in fields(MambaConfig)}
        filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
        return MambaConfig(**filtered_dict)
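
# Example (sketch, not executed): MambaLMConfig extends MambaConfig with LM-specific fields,
# and to_mamba_config() drops those extras so the inner Mamba block only sees hyperparameters
# it knows about:
#
#   lm_config = MambaLMConfig(d_model=768, n_layers=24, vocab_size=50277) # padded to 50280 by __post_init__
#   core_config = lm_config.to_mamba_config() # plain MambaConfig, without vocab_size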

# adapted from https://github.com/johnma2006/mamba-minimal
def from_pretrained(name: str):
    """
    Returns a model loaded with pretrained weights pulled from HuggingFace.

    Args:
        name: As of now, supports
            * 'state-spaces/mamba-2.8b-slimpj'
            * 'state-spaces/mamba-2.8b'
            * 'state-spaces/mamba-1.4b'
            * 'state-spaces/mamba-790m'
            * 'state-spaces/mamba-370m'
            * 'state-spaces/mamba-130m'

    Returns:
        model: a MambaLM model configured with the proper parameters and initialized with the pretrained weights
    """

    from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
    from transformers.utils.hub import cached_file

    def load_config_hf(model_name):
        resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
        return json.load(open(resolved_archive_file))

    def load_state_dict_hf(model_name):
        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
        return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

    # copy config data
    config_data = load_config_hf(name)
    config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])

    model = MambaLM(config)

    # copy weights
    state_dict = load_state_dict_hf(name)

    new_state_dict = {}
    for key in state_dict:
        # keys in the reference checkpoints are prefixed with 'backbone.'; here the embedding
        # and the final norm live at the top level of MambaLM, and the inner stack is called 'mamba'
        if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
            new_key = key.replace('backbone.', '')
        else:
            new_key = key.replace('backbone', 'mamba')

        new_state_dict[new_key] = state_dict[key]

    model.load_state_dict(new_state_dict)

    return model
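
# Usage sketch (not executed here). It assumes the `transformers` AutoTokenizer and the
# EleutherAI/gpt-neox-20b vocabulary, which is the tokenizer the state-spaces checkpoints
# were trained with:
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
#   model = from_pretrained('state-spaces/mamba-130m')
#   print(model.generate(tokenizer, "Mamba is a type of"))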

class MambaLM(nn.Module):
    def __init__(self, lm_config: MambaLMConfig):
        super().__init__()
        self.lm_config = lm_config
        self.config = lm_config.to_mamba_config()

        self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
        self.mamba = Mamba(self.config)
        self.norm_f = RMSNorm(self.config.d_model)

        self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight # weight tying between the embedding and the LM head

    def forward(self, tokens):
        # tokens : (B, L)

        # logits : (B, L, vocab_size)

        x = self.embedding(tokens)

        x = self.mamba(x)
        x = self.norm_f(x)

        logits = self.lm_head(x)

        return logits

    def step(self, token, caches):
        # token : (B)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        # logits : (B, vocab_size)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        x = self.embedding(token)

        x, caches = self.mamba.step(x, caches)
        x = self.norm_f(x)

        logits = self.lm_head(x)

        return logits, caches

    # TODO temperature
    # TODO process the prompt in parallel, then switch to sequential mode once the prompt has been consumed?
    def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
        self.eval()

        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, prompt_len)

        # caches is a list of cache, one per layer
        # a cache is composed of : the hidden state, and the last d_conv-1 inputs
        # the hidden state, because the recurrent update works like an RNN
        # the last d_conv-1 inputs, because they are needed by the depthwise 1d convolution (usually d_conv=4, so this is small)
        caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
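        # note : the hidden state of each layer starts out as None and is expected to be
        # allocated by the first step() call; the second cache entry is a (1, d_inner, d_conv-1)
        # rolling buffer of the most recent inputs for the convolution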

        for i in range(input_ids.size(1) + num_tokens - 1):
            with torch.no_grad():
                # forward the next token, get the updated caches
                next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches

            # sample (no sampling while the prompt is still being processed)
            if i+1 >= input_ids.size(1):
                probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)

                if top_k is not None:
                    values, _ = torch.topk(probs, k=top_k) # (1, k), sorted from biggest to lowest
                    probs[probs < values[:, -1, None]] = 0 # zero out everything below the k-th largest probability
                    probs = probs / probs.sum(dim=1, keepdim=True)

                if sample:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
                else:
                    next_token = torch.argmax(probs, dim=-1) # (1)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)

        output = [tokenizer.decode(output.tolist()) for output in input_ids][0]

        self.train()

        return output
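
# Minimal smoke test (a sketch, not part of the original module): it instantiates a small
# random MambaLM and checks the forward() shapes documented above. Only the hyperparameters
# already used by from_pretrained() (d_model, n_layers, vocab_size) are set; everything else
# is assumed to keep its MambaConfig default.
if __name__ == "__main__":
    torch.manual_seed(0)

    config = MambaLMConfig(d_model=64, n_layers=2, vocab_size=1000)
    model = MambaLM(config)

    tokens = torch.randint(0, config.vocab_size, (2, 16)) # (B, L)
    logits = model(tokens) # (B, L, vocab_size)

    assert logits.shape == (2, 16, config.vocab_size)
    print("forward OK :", tuple(logits.shape))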