This repository provides a Chinese GPT-2 language model trained with SimCTG on the LCCC benchmark [(Wang et al., 2020)](https://arxiv.org/pdf/2008.03946v2.pdf), following our paper [_A Contrastive Framework for Neural Text Generation_](https://arxiv.org/abs/2202.06417).

A detailed tutorial on applying SimCTG and Contrastive Search is available in our [project repo](https://github.com/yxuansu/SimCTG#4-huggingface-style-tutorials-back-to-top). The brief tutorial below shows how to use our approach for text generation.

## 1. Installation of SimCTG:
```shell
pip install simctg --upgrade
```

## 2. Initialize SimCTG Model:
```python
import torch
# load SimCTG language model
from simctg.simctggpt import SimCTGGPT
model_name = r'cambridgeltl/simctg_lccc_dialogue'
model = SimCTGGPT(model_name)
model.eval()
tokenizer = model.tokenizer
eos_token = '[SEP]'
eos_token_id = tokenizer.convert_tokens_to_ids([eos_token])[0]
```

## 3. Prepare the Text Prefix:
```python
context_list = ['刺猬很可爱!以前别人送了只没养,味儿太大!', '是很可爱但是非常臭', '是啊,没办法养', '那个怎么养哦不会扎手吗']
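# join the turns with [SEP] and end with one trailing [SEP]
# (str.strip(eos_token) would strip the characters '[', 'S', 'E', 'P', ']', not the token)
prefix_text = eos_token.join(context_list) + eos_token
print('Prefix is: {}'.format(prefix_text))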
tokens = tokenizer.tokenize(prefix_text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)
```
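The prefix format used above (dialogue turns joined by `[SEP]`, with one trailing `[SEP]`) can be sketched as a small helper. This is only an illustration: `build_prefix` is a hypothetical function, not part of the `simctg` package, and trimming whitespace from each turn is an assumption of this sketch. It does not rely on `str.strip`, which treats its argument as a set of characters rather than a literal token.

```python
SEP = '[SEP]'  # separator token used in this tutorial


def build_prefix(utterances, sep=SEP):
    """Join dialogue turns with `sep` and end with exactly one trailing `sep`.

    Hypothetical helper for illustration; not part of the simctg package.
    """
    # Drop empty turns and surrounding whitespace (an assumption of this sketch).
    turns = [u.strip() for u in utterances if u.strip()]
    return sep.join(turns) + sep


prefix = build_prefix(['你好', '你好呀', '今天天气怎么样'])
print(prefix)  # 你好[SEP]你好呀[SEP]今天天气怎么样[SEP]
```

The resulting string can then be tokenized exactly as in the block above.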

## 4. Generate Text with Contrastive Search:
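As described in the paper, contrastive search chooses, among the top-`beam_width` candidate tokens, the one that maximizes `(1 - alpha) * model_confidence - alpha * degeneration_penalty`, where the penalty is the maximum cosine similarity between the candidate's representation and those of the preceding tokens. The following is a minimal pure-Python sketch of one such selection step under toy assumptions. It is not the library's implementation: `contrastive_step`, `cosine`, and all numbers are made up for illustration.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def contrastive_step(candidates, context_states, alpha):
    """Select one token for a single decoding step (illustrative only).

    candidates: list of (token, probability, hidden_state) for the top-k tokens.
    context_states: hidden states of the tokens in the context so far.
    """
    best_token, best_score = None, -math.inf
    for token, prob, state in candidates:
        # Degeneration penalty: similarity to the most similar previous token.
        penalty = max(cosine(state, h) for h in context_states)
        score = (1.0 - alpha) * prob - alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token, best_score


# Toy values, invented for this example.
context = [[1.0, 0.0], [0.0, 1.0]]
cands = [('A', 0.6, [1.0, 0.0]), ('B', 0.4, [1.0, 1.0])]

print(contrastive_step(cands, context, alpha=0.6)[0])  # B
print(contrastive_step(cands, context, alpha=0.0)[0])  # A
```

With `alpha=0.0` the step reduces to picking the most probable candidate. A larger `alpha` favors candidates whose representations differ more from the context.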
```python
beam_width, alpha, decoding_len = 5, 0.6, 64
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, alpha=alpha,
                                       decoding_len=decoding_len, end_of_sequence_token_id=eos_token_id,
                                       early_stop=True)

print("Output:\n" + 100 * '-')
print(''.join(tokenizer.decode(output)))
'''
  Prefix is: 刺猬很可爱!以前别人送了只没养,味儿太大![SEP]是很可爱但是非常臭[SEP]是啊,没办法养[SEP]那个怎么养哦不会扎手吗[SEP]
  Output:
  ----------------------------------------------------------------------------------------------------
  刺猬很可爱!以前别人送了只没养,味儿太大![SEP]是很可爱但是非常臭[SEP]是啊,没办法养[SEP]那个怎么养哦不会扎手吗[SEP]我觉得还好,就是有点臭
'''
```
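The decoded output above contains the full prefix followed by the generated reply. To keep only the reply, one option is to split on the separator. The sketch below assumes the decoded string contains the literal text `[SEP]` between turns. `last_response` is a hypothetical helper, not part of the `simctg` package.

```python
def last_response(decoded, sep='[SEP]'):
    """Return the text after the final separator in a decoded string.

    Hypothetical helper for illustration; assumes `sep` appears literally
    in `decoded`.
    """
    text = decoded.strip()
    # Generation may end with a separator; drop it first.
    if text.endswith(sep):
        text = text[:-len(sep)].rstrip()
    return text.rsplit(sep, 1)[-1].strip()


decoded = '那个怎么养哦不会扎手吗[SEP]我觉得还好,就是有点臭'
print(last_response(decoded))  # 我觉得还好,就是有点臭
```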

For more details on our work, please see our main [project repo](https://github.com/yxuansu/SimCTG).

## 5. Citation:
If you find our paper and resources useful, please consider starring the repo and citing our paper. Thanks!

```bibtex
@article{su2022contrastive,
  title={A Contrastive Framework for Neural Text Generation},
  author={Su, Yixuan and Lan, Tian and Wang, Yan and Yogatama, Dani and Kong, Lingpeng and Collier, Nigel},
  journal={arXiv preprint arXiv:2202.06417},
  year={2022}
}
```