---

language:
- en
tags:
- pytorch
- causal-lm
- dcformer
- dcmha 
license: mit
---

DCPythia-6.9B is a language model pretrained on 300B tokens of the Pile. By comparing it with Pythia-6.9B, we validate the scaling performance of Dynamically
Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that addresses the shortcomings of multi-head attention (MHA) and increases the model's expressive power
by dynamically composing attention heads. Please see the paper [Improving Transformers with Dynamically Composable Multi-Head Attention](https://arxiv.org/abs/2405.08553) for downstream evaluations and more details. We also open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/DCFormer/).
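To make "dynamically composing attention heads" more concrete, here is a deliberately simplified, framework-free sketch. It is not the DCFormer implementation, whose compose function differs in detail (see the paper and the GitHub code). We assume only that composition means mixing the pre-softmax attention scores of all heads with an H×H matrix that depends on the current query, built here as the identity plus a rank-1 query-dependent term. All names (`compose_heads`, `dynamic_compose_matrix`, `proj_a`, `proj_b`) are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def compose_heads(scores, compose_matrix):
    """Mix per-head attention scores across heads (hypothetical helper).

    scores: H lists of length S (scores of one query position over S keys).
    compose_matrix: H x H; row h gives how much of each head g flows into head h.
    """
    num_heads, num_keys = len(scores), len(scores[0])
    return [
        [sum(compose_matrix[h][g] * scores[g][s] for g in range(num_heads))
         for s in range(num_keys)]
        for h in range(num_heads)
    ]


def dynamic_compose_matrix(query, proj_a, proj_b):
    """Hypothetical query-dependent composition: I + a(q) b(q)^T,
    where a(q) = proj_a @ q and b(q) = proj_b @ q (proj_* are H x D lists)."""
    a = [sum(w * x for w, x in zip(row, query)) for row in proj_a]
    b = [sum(w * x for w, x in zip(row, query)) for row in proj_b]
    n = len(a)
    return [[(1.0 if h == g else 0.0) + a[h] * b[g] for g in range(n)]
            for h in range(n)]


def dcmha_like_attention_weights(scores, query, proj_a, proj_b):
    """Compose scores across heads for this query, then softmax per head."""
    w = dynamic_compose_matrix(query, proj_a, proj_b)
    return [softmax(row) for row in compose_heads(scores, w)]
```

With all-zero projections the composition matrix is the identity, so the result reduces to ordinary per-head softmax attention (plain MHA); nonzero projections let each query decide how heads borrow from one another.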

For faster inference, we recommend the <strong>compiled version</strong> of DCPythia built with *torch.compile*. See the Generation section below for how to compile it.

# Usage

## Env

Upgrade transformers to version 4.40.2 or later to avoid [loading problems](https://github.com/huggingface/transformers/pull/29175). Quote the requirement so the shell does not treat `>` as output redirection:

```
pip install "transformers>=4.40.2"
```
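To confirm that the upgrade took effect in the environment you will run the model from, a small standard-library check like the one below may help. It is a rough sketch: the helper names are hypothetical, and it compares only the leading numeric release components (so pre-release suffixes such as `.dev0` are ignored) rather than implementing full PEP 440 ordering.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = (4, 40, 2)  # minimum version stated in this model card


def parse_release(ver):
    """Turn '4.40.2' or '4.41.0.dev0' into (4, 40, 2) / (4, 41, 0), padded to 3 parts."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:  # stop at the first non-numeric component, e.g. 'dev0'
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def transformers_is_recent_enough():
    """True if an installed transformers release is >= MIN_TRANSFORMERS."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= MIN_TRANSFORMERS


if __name__ == "__main__":
    print("transformers recent enough:", transformers_is_recent_enough())
```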


## Generation 

```
import os

# Disable tokenizer parallelism (avoids fork-related warnings); set before loading the tokenizer
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# trust_remote_code=True is needed because DCPythia ships its own (DCMHA) model code
tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCPythia-6.9B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCPythia-6.9B", trust_remote_code=True)

device = torch.device('cuda')
MAX_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 2048          # sequence length the caches are allocated for
NUM_TOKENS_TO_GENERATE = 100
COMPILE = True                 # compile the per-token decode step with torch.compile

# Move the model to the GPU in half precision
_ = model.to(device=device, dtype=torch.float16)

# Allocate the KV caches directly on the target device
with torch.device(device):
    model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, set_kv_cache=True)

def decode_one_token(model, cur_token, input_pos):
    # Greedy decoding: keep the highest-scoring token at the last position
    logits = model(cur_token, input_pos=input_pos, return_tensor=True)
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

prompt = "Beijing is the capital of China. London is the capital of"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

# mode="reduce-overhead" uses CUDA graphs to cut per-step overhead
compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None

with torch.no_grad():
    generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
    text = tokenizer.decode(generated_ids[0])
    print('generated text:', text)
```
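The `decode_one_token` helper performs greedy decoding: it takes the logits at the last position and keeps the argmax token. The loop that calls it lives in DCPythia's custom `generate` (remote code) and is not shown in this card, so the following is only a framework-free sketch of a greedy loop of that kind. `toy_logits` is a hypothetical stand-in for the model, and the length check reflects an assumption that the caches allocated by `setup_caches` cannot hold more than `max_seq_length` positions in total.

```python
def greedy_generate(next_token_logits, prompt_ids, num_tokens_to_generate, max_seq_length):
    """Append the highest-scoring token num_tokens_to_generate times.

    next_token_logits: hypothetical model stand-in; maps the ids so far to a
    list of scores over the vocabulary for the next position.
    """
    # Assumed cache budget: prompt plus new tokens must fit in max_seq_length
    if len(prompt_ids) + num_tokens_to_generate > max_seq_length:
        raise ValueError("prompt plus generated tokens exceed max_seq_length")
    ids = list(prompt_ids)
    for _ in range(num_tokens_to_generate):
        logits = next_token_logits(ids)
        # argmax over the vocabulary (first index wins on ties)
        best = max(range(len(logits)), key=logits.__getitem__)
        ids.append(best)
    return ids


def toy_logits(ids, vocab_size=5):
    """Hypothetical toy 'model' that always prefers (last token + 1) mod vocab_size."""
    logits = [0.0] * vocab_size
    logits[(ids[-1] + 1) % vocab_size] = 1.0
    return logits


if __name__ == "__main__":
    print(greedy_generate(toy_logits, [0], 4, max_seq_length=16))
```

Greedy decoding is deterministic, which keeps the compiled single-token step simple; sampling strategies would replace only the argmax line.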