Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,131 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
---
|
4 |
+
|
5 |
+
# Memformers
|
6 |
+
|
7 |
+
Memformers utilize a external dynamic memory to store history information.
|
8 |
+
This repo contains implementation of the pre-trained model MemBART and its training code.
|
9 |
+
|
10 |
+
Check the repo [memformers](https://github.com/qywu/memformers) fpr details.
|
11 |
+
|
12 |
+
## Install
|
13 |
+
|
14 |
+
Download this repo and install it with:
|
15 |
+
```bash
|
16 |
+
git clone https://github.com/qywu/memformers
|
17 |
+
cd memformers
|
18 |
+
pip install -e .
|
19 |
+
```
|
20 |
+
|
21 |
+
## Usage
|
22 |
+
|
23 |
+
|
24 |
+
### Inference and Generation
|
25 |
+
|
26 |
+
Our implementation is based on huggingface [transformers](https://github.com/huggingface/transformers). Currently, we provide two checkpoints `"qywu/membart-large"` [(checkpooint)](https://huggingface.co/qywu/membart-large) and `"qywu/membart-base"`[(checkpooint)](https://huggingface.co/qywu/membart-base).
|
27 |
+
You can directly load the checkpoint with:
|
28 |
+
|
29 |
+
```python
|
30 |
+
import torch
|
31 |
+
from transformers import AutoTokenizer
|
32 |
+
from memformers.models.membart import MemBartForConditionalGeneration
|
33 |
+
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
|
35 |
+
# load the large model in huggingface way
|
36 |
+
membart = MemBartForConditionalGeneration.from_pretrained("qywu/membart-large")
|
37 |
+
|
38 |
+
|
39 |
+
text1 = "Barack Obama served as the 44th President of the United States."
|
40 |
+
text2 = "<mask> served as the 44th President of the United States."
|
41 |
+
|
42 |
+
# construct the initial memory
|
43 |
+
memory_states = membart.construct_memory(batch_size=1)
|
44 |
+
|
45 |
+
# t = 0
|
46 |
+
input_ids1 = torch.LongTensor([tokenizer.encode(text1)])
|
47 |
+
# only run the encoder to get memory states
|
48 |
+
encoder_outputs = membart.model.encoder(input_ids=input_ids1, memory_states=memory_states, attention_mask=None)
|
49 |
+
memory_states = encoder_outputs.memory_states
|
50 |
+
|
51 |
+
|
52 |
+
# t = 1
|
53 |
+
input_ids2 = torch.LongTensor([tokenizer.encode(text2)])
|
54 |
+
|
55 |
+
encoder_outputs2 = membart.model.encoder(input_ids=input_ids2, memory_states=memory_states, attention_mask=None)
|
56 |
+
|
57 |
+
outputs = membart.generate(
|
58 |
+
encoder_outputs=encoder_outputs2,
|
59 |
+
decoder_start_token_id=tokenizer.bos_token_id,
|
60 |
+
max_length=64,
|
61 |
+
num_beams=1,
|
62 |
+
do_sample=False,
|
63 |
+
return_dict_in_generate=True,
|
64 |
+
)
|
65 |
+
|
66 |
+
print(tokenizer.decode(outputs.sequences[0]))
|
67 |
+
# Barack Obama served as the 44th President of the United States.
|
68 |
+
```
|
69 |
+
|
70 |
+
|
71 |
+
Note that due to [BART](https://arxiv.org/abs/1910.13461) denosing pre-training, it needs to further fine-tune the model on the downstream tasks to get better performance.
|
72 |
+
|
73 |
+
### Training
|
74 |
+
|
75 |
+
Training requires to install [TorchFly](https://github.com/qywu/TorchFly).
|
76 |
+
```bash
|
77 |
+
git clone https://github.com/qywu/TorchFly
|
78 |
+
cd TorchFly
|
79 |
+
pip install -e .
|
80 |
+
```
|
81 |
+
|
82 |
+
Then, you can refer to the code in `examples/finetune_dialog` for details about finetuning or further pre-training MemBart on your tasks.
|
83 |
+
|
84 |
+
```python
|
85 |
+
python train.py
|
86 |
+
```
|
87 |
+
|
88 |
+
For details, see `examples/training_msc`.
|
89 |
+
|
90 |
+
## Citations
|
91 |
+
|
92 |
+
Memformer: A Memory-Augmented Transformer for Sequence Modeling
|
93 |
+
```bib
|
94 |
+
@inproceedings{DBLP:conf/ijcnlp/WuLQGGY22,
|
95 |
+
author = {Qingyang Wu and
|
96 |
+
Zhenzhong Lan and
|
97 |
+
Kun Qian and
|
98 |
+
Jing Gu and
|
99 |
+
Alborz Geramifard and
|
100 |
+
Zhou Yu},
|
101 |
+
title = {Memformer: {A} Memory-Augmented Transformer for Sequence Modeling},
|
102 |
+
booktitle = {Findings of the Association for Computational Linguistics: {AACL-IJCNLP}
|
103 |
+
2022, Online only, November 20-23, 2022},
|
104 |
+
pages = {308--318},
|
105 |
+
publisher = {Association for Computational Linguistics},
|
106 |
+
year = {2022},
|
107 |
+
url = {https://aclanthology.org/2022.findings-aacl.29},
|
108 |
+
timestamp = {Tue, 29 Nov 2022 14:53:03 +0100},
|
109 |
+
biburl = {https://dblp.org/rec/conf/ijcnlp/WuLQGGY22.bib},
|
110 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
111 |
+
}
|
112 |
+
```
|
113 |
+
|
114 |
+
Stateful Memory-Augmented Transformers for Dialogue Modeling
|
115 |
+
```bib
|
116 |
+
@article{DBLP:journals/corr/abs-2209-07634,
|
117 |
+
author = {Qingyang Wu and
|
118 |
+
Zhou Yu},
|
119 |
+
title = {Stateful Memory-Augmented Transformers for Dialogue Modeling},
|
120 |
+
journal = {CoRR},
|
121 |
+
volume = {abs/2209.07634},
|
122 |
+
year = {2022},
|
123 |
+
url = {https://doi.org/10.48550/arXiv.2209.07634},
|
124 |
+
doi = {10.48550/arXiv.2209.07634},
|
125 |
+
eprinttype = {arXiv},
|
126 |
+
eprint = {2209.07634},
|
127 |
+
timestamp = {Tue, 27 Sep 2022 16:29:43 +0200},
|
128 |
+
biburl = {https://dblp.org/rec/journals/corr/abs-2209-07634.bib},
|
129 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
130 |
+
}
|
131 |
+
```
|