---
library_name: transformers
tags: []
---
# Model Card for Model ID
<!-- Provide a quick summary of what the model is/does. -->
1. Download the model weights file `model.safetensors` and save it to a local path, e.g., `./weights/model.safetensors`.
2. Build the original T5-small model, replace its `lm_head`, and then load the downloaded weights:
```python
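import torch
from safetensors.torch import load_file
from transformers import AutoModelForSeq2SeqLM

NUM_LABELS = 6 * 32  # output size of the replacement head

# Load the downloaded weights as a dict of tensors
weight = load_file('./weights/model.safetensors')
# Start from the original T5-small architecture
model = AutoModelForSeq2SeqLM.from_pretrained('google-t5/t5-small')
# Swap the vocabulary head for a linear head with NUM_LABELS outputs
model.lm_head = torch.nn.Linear(model.config.hidden_size, NUM_LABELS, bias=False)
model.load_state_dict(weight)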
```
3. Predict the token patterns:
```python
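bs = 2        # batch size
seq_len = 10  # prompt length in tokens
# Random token ids stand in for a real tokenized prompt
prompt_ids = torch.randint(0, 100, (bs, seq_len))
attention_mask = torch.ones(bs, seq_len)
# One decoder input token per sequence
decode_ids = torch.randint(0, 100, (bs, 1))
decode_token_patterns = model(input_ids=prompt_ids, attention_mask=attention_mask, decoder_input_ids=decode_ids)
# The predictions are in decode_token_patterns.logits, shape (bs, 1, NUM_LABELS)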
```
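
The forward pass above returns a model output object, not a plain tensor. The following is a minimal sketch of reading the predictions from it. So that it runs without downloading anything, it assumes a tiny, randomly initialised `T5Config` (the config values are hypothetical and are not the real T5-small settings). With the actual checkpoint, the same `.logits` access applies to the model built in step 2.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

NUM_LABELS = 6 * 32

# Hypothetical tiny config with random weights, used only so the sketch runs offline
config = T5Config(vocab_size=100, d_model=32, d_kv=8, d_ff=64, num_layers=1, num_heads=4)
model = T5ForConditionalGeneration(config)
# Same head replacement as in step 2
model.lm_head = torch.nn.Linear(config.d_model, NUM_LABELS, bias=False)
model.eval()  # disable dropout for inference

bs, seq_len = 2, 10
prompt_ids = torch.randint(0, 100, (bs, seq_len))
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
decode_ids = torch.randint(0, 100, (bs, 1))

# No gradients are needed for prediction
with torch.no_grad():
    outputs = model(input_ids=prompt_ids, attention_mask=attention_mask, decoder_input_ids=decode_ids)

# One score per label for each decoder position: shape (bs, 1, NUM_LABELS)
logits = outputs.logits
print(logits.shape)
```

Calling `model.eval()` and wrapping the call in `torch.no_grad()` is optional for a quick check, but it keeps the outputs deterministic for a given input and avoids building the autograd graph.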