## Training logs
## Install
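The usage example below depends on `torch` and `transformers` (install both with `pip install torch transformers`; exact version requirements are not pinned here). `FLASHForMaskedLM` appears to be provided by this repository's own `flash` module rather than by `transformers`, so run the script from the repository root or make sure the module is on your `PYTHONPATH`.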
## Usage
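The snippet below loads the pretrained `junnyu/flash_small_wwm_cluecorpussmall` checkpoint and, for each `[MASK]` in a Chinese test sentence, prints the top-5 candidate tokens together with their probabilities: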
```python
import torch
from transformers import BertTokenizerFast

from flash import FLASHForMaskedLM

tokenizer = BertTokenizerFast.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
model = FLASHForMaskedLM.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
model.eval()

# "The weather forecast says the weather today will be very good,
#  so let's go to the park together!"
text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"

# Pad to a fixed length, since FLASH processes the sequence in fixed-size chunks.
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    max_length=512,
    return_token_type_ids=False,
)

with torch.no_grad():
    logits = model(**inputs).logits[0]

pt_outputs_sentence = "pytorch: "
for i, token_id in enumerate(tokenizer.encode(text)):
    if token_id == tokenizer.mask_token_id:
        # Show the top-5 candidates (with probabilities) for each [MASK] position.
        values, indices = logits[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(indices.tolist())
        candidates = [f"{t}+{round(v.item(), 4)}" for v, t in zip(values.cpu(), tokens)]
        pt_outputs_sentence += "[" + "||".join(candidates) + "]"
    else:
        # Copy non-masked tokens through unchanged, dropping [CLS]/[SEP].
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)
```
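The example above runs on CPU. To use a GPU instead, move the model and the tokenized inputs to the device before the forward pass; a minimal sketch using standard PyTorch device placement, reusing `model` and `inputs` from the snippet above:

```python
# Run this before calling model(**inputs); the decoding loop works unchanged,
# since values.cpu() brings the top-k probabilities back to the host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}
```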