
Asymmetric version of bigscience's mt0-large model, trimmed down in the spirit of Google's Meena chatbot. An interesting aspect of Meena is its small-encoder, big-decoder architecture.

As described in Google's Meena announcement: Meena has a single Evolved Transformer encoder block and 13 Evolved Transformer decoder blocks. The encoder is responsible for processing the conversation context to help Meena understand what has already been said in the conversation. The decoder then uses that information to formulate an actual response. Through tuning the hyper-parameters, they discovered that a more powerful decoder was the key to higher conversational quality.

This trimmed model is meant to serve as a pretrained starting point for a Meena-like architecture, for future research into this design idea.

Here's how the model is trimmed:

1. Vocab truncation

The tokenizer is run over a corpus of multilingual dialogue, and only the set of token ids actually used is kept. We then remap the tokenizer and the embedding matrix, trimming away about 40% of the embedding weights. Note that the byte-level tokens are still kept, so the tokenizer should still be able to handle unseen characters.
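The remapping step can be sketched as follows. This is a minimal illustration in plain PyTorch, not the exact script used here: `trim_embedding` and the toy sizes are hypothetical, and in practice `used_ids` would come from tokenizing the dialogue corpus (with the byte-level token ids always included).

```python
import torch

def trim_embedding(embedding: torch.nn.Embedding, used_ids):
    """Keep only the embedding rows for `used_ids`.

    Returns the smaller embedding and a mapping from old token ids
    to their new (remapped) ids.
    """
    used_ids = sorted(set(used_ids))
    old_to_new = {old: new for new, old in enumerate(used_ids)}
    new_emb = torch.nn.Embedding(len(used_ids), embedding.embedding_dim)
    with torch.no_grad():
        # Copy over the rows of the old embedding that survive the trim.
        new_emb.weight.copy_(embedding.weight[torch.tensor(used_ids)])
    return new_emb, old_to_new

# Toy usage: keep 3 of 10 vocabulary rows.
emb = torch.nn.Embedding(10, 4)
new_emb, mapping = trim_embedding(emb, used_ids=[0, 3, 7])
assert new_emb.weight.shape == (3, 4)
```

The same `old_to_new` mapping is then applied to the tokenizer's vocabulary so token ids and embedding rows stay aligned.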

2. Encoder truncation

Model restructuring

Starting from the vocabulary-trimmed mt0 model, we cut the encoder down to only 4 layers (roughly a 1:6 encoder-decoder ratio). Ideally we would keep only 2 encoder layers, but we found that to be too weak in the later stage.
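The layer cut itself is straightforward, since the mT5-style encoder stores its layers in a `ModuleList`. The sketch below uses a small randomly initialized mT5 configuration for illustration; the real run starts from the trimmed mt0-large weights, where keeping 4 of 24 encoder layers gives the ~1:6 ratio.

```python
from transformers import MT5Config, MT5ForConditionalGeneration

# Small illustrative config (the actual model is the trimmed mt0-large).
config = MT5Config(
    vocab_size=1000, d_model=64, d_kv=16, d_ff=128,
    num_layers=8, num_decoder_layers=8, num_heads=4,
)
model = MT5ForConditionalGeneration(config)

# Keep only the first 4 encoder layers; the decoder is left untouched.
model.encoder.block = model.encoder.block[:4]
model.config.num_layers = 4  # keep the config consistent with the new depth
```

After the cut, the checkpoint is saved as usual with `save_pretrained`, so it can be reloaded with the reduced layer count.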

Model reinitialization

Because the new encoder produces different output embeddings than the original one, we need to retrain it using the original encoder as the teacher. In this project, we simply treat the output features of the original encoder as the latent ground truth, and the new smaller encoder's task is to fit that latent via an MAE loss.

All models are trained until the loss curve plateaus and no longer improves.
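One step of this latent-fitting phase can be sketched as below. This is a generic, assumed shape for the training loop (the function name and optimizer choice are illustrative): the frozen original encoder provides target hidden states, and the trimmed encoder is fitted to them with an L1 (MAE) loss.

```python
import torch
import torch.nn.functional as F

def distill_step(teacher_encoder, student_encoder, optimizer,
                 input_ids, attention_mask):
    """One MAE-distillation step: fit the student's hidden states
    to the frozen teacher's hidden states on the same batch."""
    with torch.no_grad():
        target = teacher_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
    pred = student_encoder(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state
    loss = F.l1_loss(pred, target)  # MAE between latent features
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Both encoders see the same batches, so the student only has to mimic the teacher's latent space rather than relearn it from scratch.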

Results

Without the reinitialization phase:

input :what is one plus one?</s>
trimmed output : Extendeds

input :你想ηŸ₯ι“ι—œζ–Όζˆ‘ηš„δ»€ιΊΌ?</s> (Chinese: "What do you want to know about me?")
trimmed output : δ½ δΈͺζœ‹ε‹

input :こんにけは!γŠε…ƒζ°—</s> (Japanese: "Hello! How are you")
trimmed output : !- -_n_wip-------------D2

With reinitialization phase:

input :what is one plus one?</s>
trimmed output : hundred

input :你想ηŸ₯ι“ι—œζ–Όζˆ‘ηš„δ»€ιΊΌ?</s> (Chinese: "What do you want to know about me?")
trimmed output : 你們?ζˆ‘ε°ε—Ž?ι“δΈι‚£εšηš„ε°θͺͺε°ι€™εšδΊ†ε—ι—œζ²’?

input :こんにけは!γŠε…ƒζ°—</s> (Japanese: "Hello! How are you")
trimmed output : !,

Note that matching the performance of the original model is impossible, since roughly 30% of the weights were trimmed away.
