nyust-eb210
commited on
Commit
•
6fbea0e
1
Parent(s):
bd09b82
update usage in readme.md
Browse files
README.md
CHANGED
@@ -23,4 +23,29 @@ Training Arguments:
|
|
23 |
|
24 |
- epoch: 3
|
25 |
|
26 |
-
[Colab for detailed](https://colab.research.google.com/drive/1kZv7ZRmvUdCKEhQg8MBrKljGWvV2X3CP?usp=sharing)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
- epoch: 3
|
25 |
|
26 |
+
[Colab for detailed](https://colab.research.google.com/drive/1kZv7ZRmvUdCKEhQg8MBrKljGWvV2X3CP?usp=sharing)
|
27 |
+
|
28 |
+
|
29 |
+
## Usage
|
30 |
+
|
31 |
+
### In Transformers
|
32 |
+
|
33 |
+
```python
|
34 |
+
text = "馬雲是我爸爸。"
|
35 |
+
query = "我爸爸是誰?"
|
36 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
37 |
+
tokenizer = BertTokenizerFast.from_pretrained("bert_drcd_384")
|
38 |
+
model = BertForQuestionAnswering.from_pretrained("bert_drcd_384").to(device)
|
39 |
+
encoded_input = tokenizer(text, query, return_tensors="pt").to(device)
|
40 |
+
qa_outputs = model(**encoded_input)
|
41 |
+
|
42 |
+
start = torch.argmax(qa_outputs.start_logits).item()
|
43 |
+
end = torch.argmax(qa_outputs.end_logits).item()
|
44 |
+
answer = encoded_input.input_ids.tolist()[0][start : end + 1]
|
45 |
+
answer = "".join(tokenizer.decode(answer).split())
|
46 |
+
|
47 |
+
start_prob = torch.max(torch.nn.Softmax(dim=-1)(qa_outputs.start_logits)).item()
|
48 |
+
end_prob = torch.max(torch.nn.Softmax(dim=-1)(qa_outputs.end_logits)).item()
|
49 |
+
confidence = (start_prob + end_prob) / 2
|
50 |
+
print(answer, confidence) # 馬雲, 0.98
|
51 |
+
```
|