File size: 3,932 Bytes
889c722
bf445ef
 
61fd762
bf445ef
61fd762
 
4bd6084
889c722
 
7a5a4d6
889c722
718e39a
889c722
7a5a4d6
889c722
7a5a4d6
 
 
 
 
889c722
7a5a4d6
889c722
7a5a4d6
 
e320ad7
889c722
7a5a4d6
 
 
889c722
7a5a4d6
 
 
889c722
7a5a4d6
 
 
 
 
 
 
 
 
889c722
7a5a4d6
889c722
7a5a4d6
889c722
7a5a4d6
 
889c722
7a5a4d6
 
da9daca
 
 
 
7a5a4d6
 
889c722
 
7a5a4d6
da9daca
 
 
 
7a5a4d6
 
889c722
 
7a5a4d6
 
da9daca
 
 
 
7a5a4d6
 
889c722
7a5a4d6
 
 
 
 
889c722
7a5a4d6
889c722
7a5a4d6
 
 
 
 
 
e2deeaa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
---
language:
- ru
license: apache-2.0
library_name: transformers
datasets:
- hivaze/ru-AAQG-QA-QG
pipeline_tag: text2text-generation
---

## Description

This is **ai-forever/FRED-T5-1.7B** model trained on **Question-Answering**, **Question-Generation** and **Answer-Aware Question Generation** tasks on russian dataset (**hivaze/ru-AAQG-QA-QG**)

### Prompts

```python
AAQG_PROMPT = "Сгенерируй вопрос по тексту, используя известный ответ. Текст: '{context}'. Ответ: '{answer}'."
QG_PROMPT = "Сгенерируй вопрос по тексту. Текст: '{context}'."
QA_PROMPT = "Сгенерируй ответ на вопрос по тексту. Текст: '{context}'. Вопрос: '{question}'."
```

### Examples and code

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration
from functools import partial

saved_checkpoint = 'hivaze/AAQG-QA-QG-FRED-T5-1.7B'
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(saved_checkpoint).cuda()

def generate_text(prompt, tokenizer, model, n=1, temperature=0.8, num_beams=3):
  encoded_input = tokenizer.encode_plus(prompt, return_tensors='pt')
  encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}

  resulted_tokens = model.generate(**encoded_input,
                                   max_new_tokens=64,
                                   do_sample=True,
                                   num_beams=num_beams,
                                   num_return_sequences=n,
                                   temperature=temperature,
                                   top_p=0.9,
                                   top_k=50)
  resulted_texts = tokenizer.batch_decode(resulted_tokens, skip_special_tokens=True)

  return resulted_texts

generate_text = partial(generate_text, tokenizer=tokenizer, model=model)

test_context = "Путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"
```

#### AAQG
```python
generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='на паралёте'
), n=1)
```
> "На чем путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета?"


```python
generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='рекорд высоты полета'
), n=1)
```
> "Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?"


#### QA
```python
generate_text(QA_PROMPT.format(
  context=test_context,
  question='Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?'
), n=1)
```
> "Мировой рекорд высоты полета на паралёте"

#### QG
```python
generate_text(QG_PROMPT.format(context=test_context), n=1)
```
> "Кто установил мировой рекорд высоты полета на паралёте?"

## Metrics

| Step | Training Loss | Validation Loss | Sbleu | Chr F | Rouge1 | Rouge2 | Rougel |
|---|---|---|---|---|---|---|---|
| 500 | 1.020500 | 1.059296 | 41.556000 | 66.391100 | 0.104200 | 0.033700 | 0.104200 |
| 1000 | 1.050200 | 0.998357 | 43.035900 | 66.376800 | 0.105100 | 0.034100 | 0.105200 |
| 1500 | 0.994000 | 0.966051 | 43.692200 | 66.597600 | 0.106300 | 0.034400 | 0.106400 |
| 2000 | 0.947800 | 0.953637 | 44.012400 | 66.711100 | 0.106600 | 0.034900 | 0.106800 |
| 2500 | 0.978200 | 0.944621 | 44.027900 | 66.657400 | 0.106500 | 0.034600 | 0.106500 |

## Authors
- Sergei Bratchikov (https://t.me/nlpwanderer)