File size: 1,590 Bytes
8612eed
 
86ead23
8612eed
c1d2c43
 
 
 
 
ec87c44
c1d2c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
---

license: mit
inference: False
---


# Training logs
- Weights & Biases run: https://wandb.ai/junyu/huggingface/runs/1jg2jlgt

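# Install
- Install the FLASHQuad PyTorch implementation, which provides the `flash` module imported in the usage example below: https://github.com/JunnYu/FLASHQuad_pytorch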



# usage

```python

# Masked-language-model demo: fill every [MASK] in a sentence and print, for
# each masked position, the five most likely tokens with their probabilities.
import torch
from flash import FLASHForMaskedLM
from transformers import BertTokenizerFast

MODEL_NAME = "junnyu/flash_small_wwm_cluecorpussmall"
# The input must be padded to exactly 512 tokens, otherwise the results may
# be wrong.
MAX_LENGTH = 512
# Number of candidate tokens reported for each masked position.
TOP_K = 5

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = FLASHForMaskedLM.from_pretrained(MODEL_NAME)
model.eval()

text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    max_length=MAX_LENGTH,
    return_token_type_ids=False,
)

with torch.no_grad():
    # Keep only the first (and only) sequence of the batch.
    pt_outputs = model(**inputs).logits[0]

# Collect the pieces in a list and join once at the end instead of growing a
# string with repeated "+=".
sentence_parts = ["pytorch: "]
# Re-encode without padding so only the real tokens are visited.
# NOTE(review): assumes padding is appended on the right, so position i here
# lines up with row i of the logits -- confirm against the tokenizer config.
for position, token_id in enumerate(tokenizer.encode(text)):
    if token_id == tokenizer.mask_token_id:
        probs, candidate_ids = pt_outputs[position].softmax(-1).topk(k=TOP_K)
        candidates = tokenizer.convert_ids_to_tokens(candidate_ids.tolist())
        scored = [
            f"{token}+{round(prob, 4)}"
            for prob, token in zip(probs.tolist(), candidates)
        ]
        sentence_parts.append("[" + "||".join(scored) + "]")
    else:
        # skip_special_tokens=True yields an empty list for special tokens,
        # so they contribute nothing to the printed sentence.
        sentence_parts.append(
            "".join(
                tokenizer.convert_ids_to_tokens(
                    [token_id], skip_special_tokens=True
                )
            )
        )

pt_outputs_sentence = "".join(sentence_parts)
print(pt_outputs_sentence)

# pytorch: 天气预报说今天的天[气+0.994||天+0.0015||空+0.0014||晴+0.0005||阳+0.0003]很好,那么我[们+0.9563||就+0.0381||也+0.0032||俩+0.0004||来+0.0002]一起去公园玩吧!

```