---
library_name: transformers
tags: []
---

# Model Card for Model ID
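
This model is `google-t5/t5-small` with its `lm_head` replaced by a linear layer that outputs `NUM_LABELS = 6 * 32 = 192` logits instead of vocabulary logits, so that it predicts token patterns. To use it:
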
1. Download the model weight file `model.safetensors` and save it to a local path, e.g. `./weights/model.safetensors`.

2. Build the original T5-small model, replace its `lm_head` with a linear layer that has `NUM_LABELS` outputs, and then load the downloaded weights:
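
```python
import torch
from safetensors.torch import load_file
from transformers import AutoModelForSeq2SeqLM

NUM_LABELS = 6 * 32  # number of output classes of the replacement head

# Load the weights saved in step 1
weight = load_file('./weights/model.safetensors')

# Build the original T5-small model
model = AutoModelForSeq2SeqLM.from_pretrained('google-t5/t5-small')

# Replace the vocabulary-sized lm_head with a NUM_LABELS-way linear head
model.lm_head = torch.nn.Linear(model.config.hidden_size, NUM_LABELS, bias=False)

# Load the downloaded weights into the modified model
model.load_state_dict(weight)
model.eval()
```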

3. Predict the token patterns:

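```python
bs = 2
seq_len = 10

# Dummy encoder inputs for illustration; real prompts would be tokenized text
prompt_ids = torch.randint(0, 100, (bs, seq_len))
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)

# One decoder step per sequence
decode_ids = torch.randint(0, 100, (bs, 1))

with torch.no_grad():
    outputs = model(input_ids=prompt_ids, attention_mask=attention_mask, decoder_input_ids=decode_ids)

# Logits over the NUM_LABELS token patterns, shape (bs, 1, NUM_LABELS)
decode_token_patterns = outputs.logits
```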
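The order in step 2 matters: the saved `lm_head.weight` has shape `(NUM_LABELS, hidden_size)` (PyTorch stores `nn.Linear(in, out)` weights as `(out, in)`), so it cannot be loaded into the original vocabulary-sized head. The sketch below illustrates this with plain PyTorch so it runs without downloading T5; `ToyLM`, `HIDDEN` and `VOCAB` are hypothetical stand-ins, not part of this model.

```python
import torch
from torch import nn

HIDDEN = 16          # hypothetical hidden size (t5-small itself uses 512)
VOCAB = 100          # hypothetical vocabulary size of the original head
NUM_LABELS = 6 * 32  # output size of the replacement head, as in step 2


class ToyLM(nn.Module):
    """Hypothetical stand-in for any model exposing an `lm_head` attribute."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(HIDDEN, HIDDEN)
        self.lm_head = nn.Linear(HIDDEN, VOCAB, bias=False)

    def forward(self, x):
        return self.lm_head(self.body(x))


# Weights as they would look after training with the swapped head
trained = ToyLM()
trained.lm_head = nn.Linear(HIDDEN, NUM_LABELS, bias=False)
state = trained.state_dict()

# Loading into the unmodified model fails: lm_head.weight shapes differ
fresh = ToyLM()
try:
    fresh.load_state_dict(state)
    mismatch = False
except RuntimeError:
    mismatch = True

# Swapping the head first makes every key and shape line up
fresh.lm_head = nn.Linear(HIDDEN, NUM_LABELS, bias=False)
fresh.load_state_dict(state)
out = fresh(torch.randn(2, HIDDEN))
print(mismatch, tuple(out.shape))  # True (2, 192)
```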
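The forward pass in step 3 returns logits over the `NUM_LABELS` classes rather than over vocabulary tokens. One common way to turn such logits into predictions is a softmax followed by argmax or top-k, sketched below on random stand-in logits with the same shape as `outputs.logits`. This card does not say which token pattern each of the 192 class indices corresponds to, so the sketch leaves the indices uninterpreted.

```python
import torch

NUM_LABELS = 6 * 32
bs = 2

# Random stand-in for `outputs.logits` from step 3, shape (bs, 1, NUM_LABELS)
torch.manual_seed(0)
logits = torch.randn(bs, 1, NUM_LABELS)

# Softmax over the label dimension gives one distribution per decoder step
probs = logits.softmax(dim=-1)

# Most likely class index for each sequence, in the range [0, NUM_LABELS)
pred = probs.argmax(dim=-1).squeeze(-1)  # shape (bs,)

# Three highest-scoring classes and their probabilities, shape (bs, 1, 3)
top_p, top_idx = probs.topk(3, dim=-1)

print(pred.tolist(), tuple(top_idx.shape))
```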