---
library_name: transformers
tags: []
---

# Jamba-Small v2

This is a pruned version of AI21 Labs' Jamba-v0.1 model that is ~25% the size of Jamba-v0.1.



## Model Details
Whereas Jamba-v0.1 contains 4 Jamba blocks, Jamba-Small contains only 1 Jamba block.
Jamba-Small's single Jamba block follows the same structure as the blocks in Jamba-v0.1: a 1:7 ratio of attention to Mamba layers, with MoE applied every 2 layers.
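The sketch below spells out the layer schedule of a single block. It assumes the default layer-pattern parameters of Hugging Face's `JambaConfig` (`attn_layer_period=8`, `attn_layer_offset=4`, `expert_layer_period=2`, `expert_layer_offset=1`); this card does not state them, so treat them as assumptions. It lists the layer types for one block and compares decoder-layer counts for 1 block versus 4. The count covers decoder layers only: embeddings and the LM head are not pruned, so the parameter ratio will not be exactly the same.

```python
# Assumed layer-schedule parameters: the defaults of Hugging Face's JambaConfig,
# not values read from this repository's config.
LAYERS_PER_BLOCK = 8   # attn_layer_period
ATTN_OFFSET = 4        # attn_layer_offset
EXPERT_PERIOD = 2      # expert_layer_period
EXPERT_OFFSET = 1      # expert_layer_offset


def layer_kind(i: int) -> str:
    """Return 'attention' or 'mamba' for layer i, with ' MoE' appended for expert layers."""
    mixer = "attention" if i % LAYERS_PER_BLOCK == ATTN_OFFSET else "mamba"
    moe = " MoE" if i % EXPERT_PERIOD == EXPERT_OFFSET else ""
    return mixer + moe


def schedule(num_blocks: int) -> list[str]:
    """Layer types for a stack of num_blocks Jamba blocks."""
    return [layer_kind(i) for i in range(num_blocks * LAYERS_PER_BLOCK)]


small = schedule(1)  # Jamba-Small: 1 block
full = schedule(4)   # Jamba-v0.1: 4 blocks

for i, kind in enumerate(small):
    print(i, kind)

n_attn = sum(k.startswith("attention") for k in small)
print(f"attention:mamba = {n_attn}:{len(small) - n_attn}")
print(f"decoder layers kept = {len(small)}/{len(full)} = {len(small) / len(full):.2f}")
```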

Jamba-Small's weights are initialized from selected layers of the original Jamba-v0.1 model. For v2, the layers are mapped as follows (left: Jamba-Small layer index; right: Jamba-v0.1 layer index):
```
0: 0,  # Block 0, layer 0 (mamba)
1: 1,  # Block 0, layer 1 (mamba MoE)
2: 6,  # Block 0, layer 6 (mamba)
3: 9,  # Block 1, layer 1 (mamba MoE)
4: 12, # Block 1, layer 4 (transformer)
5: 15, # Block 1, layer 7 (mamba MoE)
6: 24, # Block 3, layer 0 (mamba)
7: 31  # Block 3, layer 7 (mamba MoE)
```
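One way to sanity-check a mapping like this is to recompute each source layer's (0-indexed) block and position with `divmod(index, 8)` and confirm that every copied layer lands in a slot of the same type, since attention, Mamba, dense-MLP and MoE layers have different weight shapes. The sketch below does that under the same assumption as before: Jamba-v0.1 is taken to use the default `JambaConfig` schedule (attention at offset 4 of each 8-layer block, MoE on odd layers). `LAYER_MAP` is simply the table above written as a Python dict.

```python
# Jamba-Small layer index -> Jamba-v0.1 layer index, transcribed from the table above.
LAYER_MAP = {0: 0, 1: 1, 2: 6, 3: 9, 4: 12, 5: 15, 6: 24, 7: 31}

# Assumption: the default Hugging Face JambaConfig layer schedule.
LAYERS_PER_BLOCK = 8   # attn_layer_period
ATTN_OFFSET = 4        # attn_layer_offset
EXPERT_PERIOD = 2      # expert_layer_period
EXPERT_OFFSET = 1      # expert_layer_offset


def layer_kind(i: int) -> str:
    """Return 'attention' or 'mamba', plus ' MoE' if layer i uses experts."""
    mixer = "attention" if i % LAYERS_PER_BLOCK == ATTN_OFFSET else "mamba"
    return mixer + (" MoE" if i % EXPERT_PERIOD == EXPERT_OFFSET else "")


rows = []
mismatches = []
for small_idx, src_idx in LAYER_MAP.items():
    # 0-indexed block number and position within that block in Jamba-v0.1.
    block, offset = divmod(src_idx, LAYERS_PER_BLOCK)
    kind = layer_kind(src_idx)
    rows.append((small_idx, src_idx, block, offset, kind))
    # Copied weights only have matching shapes if the target slot expects the same layer type.
    if layer_kind(small_idx) != kind:
        mismatches.append((small_idx, src_idx))
    print(f"{small_idx}: {src_idx:>2}  # Block {block}, layer {offset} ({kind})")

print("type mismatches:", mismatches)
```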

Note that no additional fine-tuning has been performed on this model. As a result, its output quality is very poor, and it should not be used in production without further training.

### Model Description

- **Developed by:** Nathan Brown (OxxoCodes)
- **Compute provided by:** Clemson Palmetto Cluster
- **Model type:** Joint Attention and Mamba (Jamba)
- **Language(s) (NLP):** English
- **License:** Apache 2.0
- **Original model:** [Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
- **Jamba paper:** [https://arxiv.org/pdf/2403.19887.pdf](https://arxiv.org/pdf/2403.19887.pdf)

### How to Use
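```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pruned model in bfloat16; trust_remote_code allows custom modeling code hosted in the repo to run.
model = AutoModelForCausalLM.from_pretrained("OxxoCodes/jamba-small-v2", torch_dtype=torch.bfloat16, trust_remote_code=True)
# The tokenizer is loaded from the original Jamba-v0.1 repository.
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

with torch.no_grad():
    # Tokenize the prompt and move the token IDs to the model's device.
    input_ids = tokenizer("There once was a", return_tensors="pt").to(model.device)["input_ids"]
    outputs = model.generate(input_ids, max_new_tokens=216)
    print(tokenizer.batch_decode(outputs))
```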