
Model Card for Model ID

  1. Download the model weight file model.safetensors and save it to a local path, e.g., ./weights/model.safetensors
  2. Build the original T5-small model, replace its lm_head, and then load the downloaded weights
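import torch
from safetensors.torch import load_file
from transformers import AutoModelForSeq2SeqLM

NUM_LABELS = 6 * 64  # output size of the replacement head

# Load the downloaded checkpoint as a dict of tensors
weight = load_file('./weights/model.safetensors')
# Start from the original T5-small architecture
model = AutoModelForSeq2SeqLM.from_pretrained('google-t5/t5-small')
# Replace the vocabulary-sized lm_head with a NUM_LABELS-wide projection
model.lm_head = torch.nn.Linear(model.config.hidden_size, NUM_LABELS, bias=False)
# Overwrite the model parameters with the checkpoint weights
model.load_state_dict(weight)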
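Before loading a checkpoint into a model whose head has been swapped, it can help to compare key names and tensor shapes first, since `load_state_dict` raises on shape mismatches even with `strict=False`. The following is a minimal sketch under stated assumptions: `TinyModel` is a hypothetical stand-in for the real T5 model (so the example runs without downloading anything), and `weight` here is a copy of the stand-in's own state rather than the real checkpoint.

```python
import torch

# Hypothetical stand-in for the real model: a tiny module with a replaced head.
class TinyModel(torch.nn.Module):
    def __init__(self, hidden_size=8, num_labels=6 * 64):
        super().__init__()
        self.body = torch.nn.Linear(hidden_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, num_labels, bias=False)

model = TinyModel()

# Stand-in for the dict returned by load_file(...); assumed to match the model here.
weight = {k: v.clone() for k, v in model.state_dict().items()}

# Collect keys whose shapes disagree between checkpoint and model.
model_state = model.state_dict()
mismatched = [
    k for k, v in weight.items()
    if k in model_state and model_state[k].shape != v.shape
]

# strict=False reports missing and unexpected keys instead of raising on them.
result = model.load_state_dict(weight, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
print("shape mismatches:", mismatched)
```

With the real model, the same check would be applied to the `weight` dict and `model` built in step 2. Empty lists for all three indicate that the checkpoint lines up with the modified architecture.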
  3. Predict the token patterns
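bs = 2        # batch size
seq_len = 10  # prompt length
# Random token ids stand in for a real tokenized prompt
prompt_ids = torch.randint(0, 100, (bs, seq_len))
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
decode_ids = torch.randint(0, 100, (bs, 1))
# decode_token_patterns.logits has shape (bs, 1, NUM_LABELS)
decode_token_patterns = model(input_ids=prompt_ids, attention_mask=attention_mask, decoder_input_ids=decode_ids)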
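The card does not say how the `6 * 64` outputs should be interpreted. One plausible reading, offered only as an assumption, is 6 groups of 64 classes each. Under that assumption, the sketch below reshapes the logits and picks the highest-scoring class per group. The names `NUM_GROUPS` and `GROUP_SIZE` are hypothetical, and a random tensor stands in for `decode_token_patterns.logits`.

```python
import torch

# Assumed factorization of NUM_LABELS = 6 * 64; not confirmed by the model card.
NUM_GROUPS, GROUP_SIZE = 6, 64

bs = 2
# Random stand-in for decode_token_patterns.logits, shape (bs, 1, NUM_LABELS).
logits = torch.randn(bs, 1, NUM_GROUPS * GROUP_SIZE)

# Split the last dimension into (groups, classes per group).
grouped = logits.view(bs, 1, NUM_GROUPS, GROUP_SIZE)

# Highest-scoring class index within each group, shape (bs, 1, NUM_GROUPS).
pattern_ids = grouped.argmax(dim=-1)
print(pattern_ids.shape)
```

If the model actually treats the 384 outputs as a single flat label space, a plain `logits.argmax(dim=-1)` would be the matching choice instead.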
Safetensors · Model size: 60.7M params · Tensor type: F32