qywu commited on
Commit
6025ab3
·
1 Parent(s): b790351

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +128 -0
README.md CHANGED
@@ -1,3 +1,131 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ # Memformers
6
+
7
+ Memformers utilize a external dynamic memory to store history information.
8
+ This repo contains implementation of the pre-trained model MemBART and its training code.
9
+
10
+ Check the repo [memformers](https://github.com/qywu/memformers) fpr details.
11
+
12
+ ## Install
13
+
14
+ Download this repo and install it with:
15
+ ```bash
16
+ git clone https://github.com/qywu/memformers
17
+ cd memformers
18
+ pip install -e .
19
+ ```
20
+
21
+ ## Usage
22
+
23
+
24
+ ### Inference and Generation
25
+
26
+ Our implementation is based on huggingface [transformers](https://github.com/huggingface/transformers). Currently, we provide two checkpoints `"qywu/membart-large"` [(checkpooint)](https://huggingface.co/qywu/membart-large) and `"qywu/membart-base"`[(checkpooint)](https://huggingface.co/qywu/membart-base).
27
+ You can directly load the checkpoint with:
28
+
29
+ ```python
30
+ import torch
31
+ from transformers import AutoTokenizer
32
+ from memformers.models.membart import MemBartForConditionalGeneration
33
+
34
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
35
+ # load the large model in huggingface way
36
+ membart = MemBartForConditionalGeneration.from_pretrained("qywu/membart-large")
37
+
38
+
39
+ text1 = "Barack Obama served as the 44th President of the United States."
40
+ text2 = "<mask> served as the 44th President of the United States."
41
+
42
+ # construct the initial memory
43
+ memory_states = membart.construct_memory(batch_size=1)
44
+
45
+ # t = 0
46
+ input_ids1 = torch.LongTensor([tokenizer.encode(text1)])
47
+ # only run the encoder to get memory states
48
+ encoder_outputs = membart.model.encoder(input_ids=input_ids1, memory_states=memory_states, attention_mask=None)
49
+ memory_states = encoder_outputs.memory_states
50
+
51
+
52
+ # t = 1
53
+ input_ids2 = torch.LongTensor([tokenizer.encode(text2)])
54
+
55
+ encoder_outputs2 = membart.model.encoder(input_ids=input_ids2, memory_states=memory_states, attention_mask=None)
56
+
57
+ outputs = membart.generate(
58
+ encoder_outputs=encoder_outputs2,
59
+ decoder_start_token_id=tokenizer.bos_token_id,
60
+ max_length=64,
61
+ num_beams=1,
62
+ do_sample=False,
63
+ return_dict_in_generate=True,
64
+ )
65
+
66
+ print(tokenizer.decode(outputs.sequences[0]))
67
+ # Barack Obama served as the 44th President of the United States.
68
+ ```
69
+
70
+
71
+ Note that due to [BART](https://arxiv.org/abs/1910.13461) denosing pre-training, it needs to further fine-tune the model on the downstream tasks to get better performance.
72
+
73
+ ### Training
74
+
75
+ Training requires to install [TorchFly](https://github.com/qywu/TorchFly).
76
+ ```bash
77
+ git clone https://github.com/qywu/TorchFly
78
+ cd TorchFly
79
+ pip install -e .
80
+ ```
81
+
82
+ Then, you can refer to the code in `examples/finetune_dialog` for details about finetuning or further pre-training MemBart on your tasks.
83
+
84
+ ```python
85
+ python train.py
86
+ ```
87
+
88
+ For details, see `examples/training_msc`.
89
+
90
+ ## Citations
91
+
92
+ Memformer: A Memory-Augmented Transformer for Sequence Modeling
93
+ ```bib
94
+ @inproceedings{DBLP:conf/ijcnlp/WuLQGGY22,
95
+ author = {Qingyang Wu and
96
+ Zhenzhong Lan and
97
+ Kun Qian and
98
+ Jing Gu and
99
+ Alborz Geramifard and
100
+ Zhou Yu},
101
+ title = {Memformer: {A} Memory-Augmented Transformer for Sequence Modeling},
102
+ booktitle = {Findings of the Association for Computational Linguistics: {AACL-IJCNLP}
103
+ 2022, Online only, November 20-23, 2022},
104
+ pages = {308--318},
105
+ publisher = {Association for Computational Linguistics},
106
+ year = {2022},
107
+ url = {https://aclanthology.org/2022.findings-aacl.29},
108
+ timestamp = {Tue, 29 Nov 2022 14:53:03 +0100},
109
+ biburl = {https://dblp.org/rec/conf/ijcnlp/WuLQGGY22.bib},
110
+ bibsource = {dblp computer science bibliography, https://dblp.org}
111
+ }
112
+ ```
113
+
114
+ Stateful Memory-Augmented Transformers for Dialogue Modeling
115
+ ```bib
116
+ @article{DBLP:journals/corr/abs-2209-07634,
117
+ author = {Qingyang Wu and
118
+ Zhou Yu},
119
+ title = {Stateful Memory-Augmented Transformers for Dialogue Modeling},
120
+ journal = {CoRR},
121
+ volume = {abs/2209.07634},
122
+ year = {2022},
123
+ url = {https://doi.org/10.48550/arXiv.2209.07634},
124
+ doi = {10.48550/arXiv.2209.07634},
125
+ eprinttype = {arXiv},
126
+ eprint = {2209.07634},
127
+ timestamp = {Tue, 27 Sep 2022 16:29:43 +0200},
128
+ biburl = {https://dblp.org/rec/journals/corr/abs-2209-07634.bib},
129
+ bibsource = {dblp computer science bibliography, https://dblp.org}
130
+ }
131
+ ```