---
license: apache-2.0
metrics:
- perplexity
pipeline_tag: text-generation
---

Trained on 30B bytes. Model size: 353M. See Table 2 in [MambaByte](https://arxiv.org/abs/2401.13660).

To use:

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import numpy as np

model=MambaLMHeadModel.from_pretrained("JunxiongWang/MambaByte_Code", device='cuda', dtype=torch.float32)

text = "import torch"
text_byte = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
input_ids = torch.from_numpy(text_byte[None, :]).long().cuda()

sample = model.generate(
    input_ids=input_ids,
    max_length=2048,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    temperature=1,
    top_k=256,
    top_p=0.9,
)

print(bytes(sample.sequences[0].tolist()).decode('utf-8'))
```

Sample output:

```
import torch import numpy as np import torch.nn.functional as F from torch.autograd import Variable from networkx.states import TransientState def extract_data(num_epochs, epochs, is_last_epoch): def get_data(num_features, num_classes): data_features = num_features data_classes = num_classes data_labels = num_epochs if num_features == 0 or num_classes == 0: return data_features, data_classes if is_last_epoch: data_features = num_features data_classes = num_classes data_labels = num_epochs return data_features, data_classes data_features, data_classes = get_data(num_epochs, epochs, is_last_epoch) data_labels = num_epochs * 2 return data_features, data_classes class NumChannel: def __init__(self, x, y, dx=1, dy=1, idx=1, data_size=2, epoch=None): """idx is the channel index with data feature in the first epoch. x is the channel of the input data. y is the element of the input data. dx is the element of the data feature of the input data. data_size is the size of the element of the data. epoch is the channel of the element of the data.
""" self.x = x self.y = y self.dx = dx self.data_size = data_size self.epoch = epoch self.reference_count = 0 self.data_features = {} self.data_classes = {} self._initialize() if idx is not None: self._start_time = time.time() def _initialize(self): """idx is the channel index with data feature in the first epoch. x is the channel of the input data. y is the element of the input data. dx is the element of the data feature of the input data. data_size is the size of the element of the data. epoch is the channel of the element of the data. """ self.idx = idx ```