---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->


1. Download the model weight file `model.safetensors` and save it to a local path, e.g., `./weights/model.safetensors`.
2. Build the original T5-small model, replace its `lm_head` with a new output layer, and then load the downloaded weights:
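```python
import torch
from safetensors.torch import load_file
from transformers import AutoModelForSeq2SeqLM

# Size of the replacement output layer (value from this card).
NUM_LABELS = 6 * 32

# Load the weights downloaded in step 1.
weight = load_file('./weights/model.safetensors')

# Start from the pretrained T5-small architecture.
model = AutoModelForSeq2SeqLM.from_pretrained('google-t5/t5-small')

# Swap the vocabulary projection for a NUM_LABELS-way head so its shape matches the saved weights.
model.lm_head = torch.nn.Linear(model.config.hidden_size, NUM_LABELS, bias=False)

# Load the downloaded weights into the modified model and switch to inference mode.
model.load_state_dict(weight)
model.eval()
```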
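If `load_state_dict` in step 2 raises an error, it can help to compare the keys and tensor shapes in the weight file against the modified model before loading. The helper below, `diff_state_dict`, is a hypothetical sketch (it is not part of `transformers`, `torch`, or `safetensors`); it is demonstrated on a small toy module standing in for T5 so that it runs without downloading anything.

```python
import torch
from torch import nn


def diff_state_dict(model: nn.Module, weights: dict) -> dict:
    """Hypothetical helper: compare a loaded state dict against a model.

    Returns the keys the model expects but the file lacks, the keys the file
    has but the model does not expect, and the keys whose tensor shapes differ.
    """
    expected = model.state_dict()
    missing = sorted(k for k in expected if k not in weights)
    unexpected = sorted(k for k in weights if k not in expected)
    mismatched = sorted(
        k for k in expected
        if k in weights and tuple(expected[k].shape) != tuple(weights[k].shape)
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Toy stand-in for T5: an embedding plus an lm_head (sizes are illustrative).
class ToyModel(nn.Module):
    def __init__(self, hidden: int = 8, vocab: int = 50):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


NUM_LABELS = 6 * 32
model = ToyModel()
# Replace the head the same way step 2 does for T5-small.
model.lm_head = nn.Linear(8, NUM_LABELS, bias=False)

# Weights from an unmodified model: lm_head still has the original shape.
old_weights = ToyModel().state_dict()
print(diff_state_dict(model, old_weights))
# -> {'missing': [], 'unexpected': [], 'mismatched': ['lm_head.weight']}
```

With the real files you would call `diff_state_dict(model, weight)` after step 2's `load_file`; if all three lists are empty, strict loading with `model.load_state_dict(weight)` should succeed.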
3. Run a forward pass to predict the token patterns:

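```python
import torch

# Dummy inputs for illustration: a batch of 2 prompts with 10 token IDs each.
bs = 2
seq_len = 10
prompt_ids = torch.randint(0, 100, (bs, seq_len))
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
# Decoder input: a single (random) token ID per sequence.
decode_ids = torch.randint(0, 100, (bs, 1))

with torch.no_grad():
    decode_token_patterns = model(input_ids=prompt_ids, attention_mask=attention_mask, decoder_input_ids=decode_ids)
# Scores over the NUM_LABELS outputs, shape (bs, 1, NUM_LABELS).
logits = decode_token_patterns.logits
```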
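The forward pass returns raw scores over the `NUM_LABELS` outputs rather than label indices. This card does not say how the 6 × 32 labels should be decoded, so the sketch below simply assumes the predicted pattern at each decoder step is the highest-scoring label (greedy argmax). The function name `predict_labels` is hypothetical, and a random tensor stands in for `decode_token_patterns.logits`.

```python
import torch

NUM_LABELS = 6 * 32


def predict_labels(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: pick the top-scoring label at each decoder step.

    logits: (batch, decoder_steps, NUM_LABELS) scores from the model.
    Returns (label_ids, probabilities), both of shape (batch, decoder_steps).
    """
    probs = torch.softmax(logits, dim=-1)
    # max over the label dimension returns (values, indices).
    top_probs, label_ids = probs.max(dim=-1)
    return label_ids, top_probs


# Stand-in for decode_token_patterns.logits with bs=2 and one decoder step.
torch.manual_seed(0)
fake_logits = torch.randn(2, 1, NUM_LABELS)
label_ids, confidence = predict_labels(fake_logits)
print(label_ids.shape, confidence.shape)  # torch.Size([2, 1]) torch.Size([2, 1])
```

Softmax is monotonic, so the chosen label is the same as taking `argmax` on the raw logits; computing probabilities only adds a per-prediction confidence score.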