# inference: true
# widget:
# - text: 'What are Glaciers?'
---

## Description

This Question-Answering model was fine-tuned from [distilgpt2](https://huggingface.co/distilgpt2), a generative, left-to-right transformer in the style of GPT-2. It was trained on Microsoft's [Wiki-QA](https://huggingface.co/datasets/wiki_qa) dataset.

# How to run Distil-GPT2-Wiki-QA using Transformers

## Question-Answering

The following code shows how to use the Distil-GPT2-Wiki-QA checkpoint with Transformers to generate answers.
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import re

tokenizer = GPT2Tokenizer.from_pretrained("XBOT-RK/distilgpt2-wiki-qa")
model = GPT2LMHeadModel.from_pretrained("XBOT-RK/distilgpt2-wiki-qa")

# Use the GPU when available; model and inputs must live on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def infer(question):
    # Generate up to 50 new tokens continuing the question prompt.
    inputs = tokenizer(question, return_tensors="pt").to(device)
    generated_tensor = model.generate(**inputs, max_new_tokens=50)
    generated_text = tokenizer.decode(generated_tensor[0])
    return generated_text

def processAnswer(question, result):
    # Strip the echoed question, then extract the text between the
    # "<bot>:" and "<endofstring>" markers used during fine-tuning.
    answer = result.replace(question, '').strip()
    if "<bot>:" in answer:
        answer = re.search('<bot>:(.*)', answer).group(1).strip()
    if "<endofstring>" in answer:
        answer = re.search('(.*)<endofstring>', answer).group(1).strip()
    return answer

question = "What is a tropical cyclone?"
result = infer(question)
answer = processAnswer(question, result)
print('Question: ', question)
print('Answer: ', answer)
```
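The post-processing step can be tried on its own, without downloading the model. The sketch below assumes a raw generation that echoes the question and wraps the answer in the `<bot>:` and `<endofstring>` markers described above; the sample string is illustrative, not actual model output.

```python
import re

def process_answer(question, result):
    # Drop the echoed question, then keep only the text between the markers.
    answer = result.replace(question, '').strip()
    if "<bot>:" in answer:
        answer = re.search('<bot>:(.*)', answer).group(1).strip()
    if "<endofstring>" in answer:
        answer = re.search('(.*)<endofstring>', answer).group(1).strip()
    return answer

# Hypothetical raw generation, for illustration only.
raw = "What is a glacier? <bot>: A glacier is a persistent body of ice. <endofstring>"
print(process_answer("What is a glacier?", raw))
# → A glacier is a persistent body of ice.
```

If a marker is missing from the generation, the helper simply returns the stripped text as-is, so it degrades gracefully on unexpected output.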