Memformers

Memformers use an external dynamic memory to store history information. This repo contains the implementation of the pre-trained model MemBART and its training code.

See the memformers repository for details.

Install

Download this repo and install it with:

git clone https://github.com/qywu/memformers
cd memformers
pip install -e .

Usage

Inference and Generation

Our implementation is based on Hugging Face Transformers. Currently, we provide two checkpoints: "qywu/membart-large" (checkpoint) and "qywu/membart-base" (checkpoint). You can load a checkpoint directly with:

import torch
from transformers import AutoTokenizer
from memformers.models.membart import MemBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
# load the large checkpoint the Hugging Face way
membart = MemBartForConditionalGeneration.from_pretrained("qywu/membart-large")


text1 = "Barack Obama served as the 44th President of the United States."
text2 = "<mask> served as the 44th President of the United States."

# construct the initial memory
memory_states = membart.construct_memory(batch_size=1)

# t = 0
input_ids1 = torch.LongTensor([tokenizer.encode(text1)])
# only run the encoder to get memory states
encoder_outputs = membart.model.encoder(input_ids=input_ids1, memory_states=memory_states, attention_mask=None)
memory_states = encoder_outputs.memory_states


# t = 1
input_ids2 = torch.LongTensor([tokenizer.encode(text2)])

encoder_outputs2 = membart.model.encoder(input_ids=input_ids2, memory_states=memory_states, attention_mask=None)

outputs = membart.generate(
    encoder_outputs=encoder_outputs2,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=64,
    num_beams=1,
    do_sample=False,
    return_dict_in_generate=True,
)

print(tokenizer.decode(outputs.sequences[0]))
# Barack Obama served as the 44th President of the United States.

Note that because of BART's denoising pre-training objective, the model needs further fine-tuning on downstream tasks to perform well.
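The two-step example above (t = 0, t = 1) follows a general pattern: process one segment, keep the returned memory, and feed it to the next segment. The sketch below shows that loop in isolation. The helper names `split_into_segments`, `run_with_memory`, and `encode_step` are hypothetical and not part of memformers; the encoder is passed in as a plain callable so the loop can be tried with a toy stand-in.

```python
from typing import Any, Callable, List, Sequence, Tuple


def split_into_segments(token_ids: Sequence[int], segment_len: int) -> List[List[int]]:
    """Split a long token-id sequence into consecutive fixed-size segments."""
    if segment_len <= 0:
        raise ValueError("segment_len must be positive")
    return [list(token_ids[i:i + segment_len])
            for i in range(0, len(token_ids), segment_len)]


def run_with_memory(
    encode_step: Callable[[List[int], Any], Tuple[Any, Any]],
    initial_memory: Any,
    segments: Sequence[List[int]],
) -> Tuple[List[Any], Any]:
    """Run encode_step over segments in order, threading the memory through.

    encode_step(segment, memory) must return (output, new_memory).
    """
    memory = initial_memory
    outputs = []
    for segment in segments:
        output, memory = encode_step(segment, memory)
        outputs.append(output)
    return outputs, memory


# Toy stand-in (NOT MemBART): the "memory" is a running sum of token ids,
# and the per-segment "output" is just the segment length.
def toy_step(segment: List[int], memory: int) -> Tuple[int, int]:
    return len(segment), memory + sum(segment)


segments = split_into_segments(list(range(7)), segment_len=3)
outputs, final_memory = run_with_memory(toy_step, 0, segments)

# With MemBART, encode_step could wrap the encoder call from the example above
# (an assumption about how you would wire it, not a tested recipe):
#
# def encode_step(segment, memory):
#     out = membart.model.encoder(
#         input_ids=torch.LongTensor([segment]),
#         memory_states=memory,
#         attention_mask=None,
#     )
#     return out, out.memory_states
```

Passing the encoder as a callable keeps the memory-threading logic separate from the model, so the same loop works for encoding history only or for producing the `encoder_outputs` that are later handed to `generate`.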

Training

Training requires TorchFly to be installed:

git clone https://github.com/qywu/TorchFly
cd TorchFly
pip install -e .

Then refer to the code in examples/finetune_dialog for details on fine-tuning or further pre-training MemBART on your own tasks.

python train.py

For details, see examples/training_msc.

Citations

Memformer: A Memory-Augmented Transformer for Sequence Modeling

@inproceedings{DBLP:conf/ijcnlp/WuLQGGY22,
  author    = {Qingyang Wu and
               Zhenzhong Lan and
               Kun Qian and
               Jing Gu and
               Alborz Geramifard and
               Zhou Yu},
  title     = {Memformer: {A} Memory-Augmented Transformer for Sequence Modeling},
  booktitle = {Findings of the Association for Computational Linguistics: {AACL-IJCNLP}
               2022, Online only, November 20-23, 2022},
  pages     = {308--318},
  publisher = {Association for Computational Linguistics},
  year      = {2022},
  url       = {https://aclanthology.org/2022.findings-aacl.29},
  timestamp = {Tue, 29 Nov 2022 14:53:03 +0100},
  biburl    = {https://dblp.org/rec/conf/ijcnlp/WuLQGGY22.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

Stateful Memory-Augmented Transformers for Dialogue Modeling

@article{DBLP:journals/corr/abs-2209-07634,
  author    = {Qingyang Wu and
               Zhou Yu},
  title     = {Stateful Memory-Augmented Transformers for Dialogue Modeling},
  journal   = {CoRR},
  volume    = {abs/2209.07634},
  year      = {2022},
  url       = {https://doi.org/10.48550/arXiv.2209.07634},
  doi       = {10.48550/arXiv.2209.07634},
  eprinttype = {arXiv},
  eprint    = {2209.07634},
  timestamp = {Tue, 27 Sep 2022 16:29:43 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2209-07634.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}