File size: 6,244 Bytes
a6668db
93cd37e
a6668db
93cd37e
 
 
596b12e
 
a6668db
cda61f2
 
 
 
 
 
c7072f7
cda61f2
 
 
 
 
 
 
 
 
 
 
 
 
c7072f7
cda61f2
 
 
 
 
 
c7072f7
cda61f2
 
 
 
 
 
 
 
 
3bb1d41
 
 
 
 
 
 
 
 
 
cda61f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
---
language: ru
license: cc-by-nc-4.0
tags:
- paraphrasing
- seq2seq
datasets:
- inkoziev/paraphrases
---

## Поэтический перефразировщик

Это генеративная модель на основе ```sberbank-ai/rugpt3large_based_on_gpt2```, дообученной
на датасете перефразировок [inkoziev/paraphrases](https://huggingface.co/datasets/inkoziev/paraphrases).
Она разработана для использования в проекте [генеративной поэзии](https://github.com/Koziev/verslibre).
Код для тренировки и использования перефразировщика доступен в репозитрии [https://github.com/Koziev/paraphraser](https://github.com/Koziev/paraphraser).


### Особенности перефразировки

Обращаю внимание, что модель **не предназначена** для использования там, где требуется
особо аккуратная работа с именованными сущностями. Так как в стихах не возникает особых проблем (более того,
в некоторых сценариях использования это даже желательно), если перефразировки теряют или добавляют некоторую семантику в исходный текст, то обучающий датасет
и модель на его основе может путать дни недели, имена, добавлять что-то от себя, быть метафоричной или иносказательной.

### Методика файнтюна

В обучающем датасете есть негативные примеры перефразировок, и я использую их вместе с правильными примерами в ходе файнтюна,
подавая на классификационную голову в [GPT2DoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2DoubleHeadsModel).
Код, выполняющий файнтюн, доступен [тут](https://github.com/Koziev/paraphraser/blob/main/train_paraphraser_with_gpt2doublehead.py).

Такой подход к файнтюну оказался лучше, чем два других подхода:

1) дефолтный способ файнтюна, когда GPT дообучается просто на текстах, состоящих из исходного текста и перефразировки,
разделенных специальным токеном. В этом подходе модель обучается также на токенах затравки, что может быть нежелательным.
2) вариация первого способа, в котором токены затравки (исходного текста) исключаются из обратного распространения с помощью
задания labels=-100 ([код](https://github.com/Koziev/paraphraser/blob/main/finetune_paraphraser_with_prompt_masking.py)).

В качестве метрики для сравнения подходов и для подбора числа неверных вариантов перефразировки в GPT2DoubleHeadsModel
использована комбинация из:
1) близость векторов эмбеддингов исходного текста и сгенерированной перефразировки. Векторы получаются с помощью
модели ```sberbank-ai/sbert_large_mt_nlu_ru```. Я не стал использовать [модель-критик](https://huggingface.co/inkoziev/sbert_synonymy),
поскольку она обучалась на таком же датасете.
2) дисконтируем результаты п.1 символьной близостью (3-граммы) по коэффициенту Жаккара. Это штрафует перестановочные
перефразировки, воспроизведение исходного текста и небольшие переписывания.

### Формат входных данных

На вход модели подается исходный текст с добавлением токенов ```<s>``` в начале и ```<sep>``` в конце, например:

```
input_text = '<s>Мороз и солнце, день чудесный<sep>'
```

Результат генерации будет содержать текст с токеном ```</s>``` - это конец последовательности.

### Пример использования

Следующий код позволяет ввести в консоли короткое предложение
и видеть результат ее перефразировки моделью:
```
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "inkoziev/paraphraser"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()

while True:
    seed = input(':> ').strip()
    encoded_prompt = tokenizer.encode("<s>" + seed + "<sep>", add_special_tokens=False, return_tensors="pt").to(device)
    output_sequences = model.generate(input_ids=encoded_prompt,
                                      max_length=100,
                                      typical_p=0.85,
                                      top_k=0,
                                      top_p=1.0,
                                      do_sample=True,
                                      num_return_sequences=10,
                                      pad_token_id=tokenizer.pad_token_id)

    for o in output_sequences:
        text = tokenizer.decode(o.tolist(), clean_up_tokenization_spaces=True)
        text = text[text.index('<sep>') + 5:]
        text = text[: text.find('</s>')]
        print(text)
```