Add model and update README
Commit 29ebdee, committed by John David Pressman
1 parent: 167d940
- README.md +123 -1
- adapter_config.json +29 -0
- adapter_model.safetensors +3 -0
- decoder/adapter_config.json +29 -0
- decoder/adapter_model.safetensors +3 -0
- encoder/adapter_config.json +29 -0
- encoder/adapter_model.safetensors +3 -0
- router/adapter_config.json +29 -0
- router/adapter_model.safetensors +3 -0
- state.json +1 -0
- vae.safetensors +3 -0
README.md
CHANGED
@@ -1,3 +1,125 @@
|
|
1 |
---
|
2 |
-
|
3 |
---
---
library_name: peft
---

BigVAE is an [AdaVAE](https://arxiv.org/abs/2205.05862) trained as a pair of LoRA finetunes on [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1).
It is meant to be used with the [MiniHF VAE inference code](https://github.com/JD-P/minihf/blob/adavae-moe/vae_infer.py) and will not work if you try to load it
as an ordinary language model checkpoint and perform inference. AdaVAE is an encoder-decoder model trained by taking an existing GPT-N, designating one LoRA the
encoder and the other its decoder, and then tuning with a latent attention mechanism. This model is the encoder and router decoder head for BigVAE, a planned
Mixture-of-Experts system based on LoRA retrieval rather than gating. It is usable in and of itself as a model for embedding and retrieval, as well as planning
and guided sampling. Here is an example of a sampling procedure for BigVAE which distills its autoregressive pretraining task into its autoassociative
reconstruction task by averaging together multiple completions. It takes the topic sentence of a paragraph (prompt), guides the next sentences by weighting
them towards the topic, while averaging together multiple completions on each sentence to improve generation quality:

```
import torch

# `tokenizer` and `device` are expected to be defined by the calling script.
def bigvae_generate_avg(vae_model, router, prompt, context, n_steps, n_avg):
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        # Tokenize the running context and the topic sentence (prompt).
        context_toks = tokenizer(context, return_tensors="pt")
        context_ids = context_toks["input_ids"].to(device)
        context_mask = context_toks["attention_mask"].to(device)
        embed_toks = tokenizer(prompt, return_tensors="pt")
        embed_ids = embed_toks["input_ids"].to(device)
        embed_mask = embed_toks["attention_mask"].to(device)
        # Latent of the topic sentence, used as guidance at every step.
        mean = vae_model.encode(embed_ids, embed_mask)
        prompt_embed = vae_model.vae.sample(mean)
        for _ in range(n_steps):
            # Latent of the most recent sentence.
            mean = vae_model.encode(embed_ids, embed_mask)
            z = vae_model.vae.sample(mean)
            embeds = []
            for _ in range(n_avg):
                # Draft a completion guided half by z, half by the topic.
                output_ids = router.generate(z * 0.5 + prompt_embed * 0.5,
                                             context_ids,
                                             context_mask,
                                             256,
                                             tau=0.9)
                # Re-encode the last 128 tokens of the draft.
                intermediate_embed_ids = output_ids[:,-128:]
                intermediate_embed_mask = context_mask.new_ones(
                    [1, intermediate_embed_ids.shape[1]]
                )
                mean = vae_model.encode(intermediate_embed_ids, intermediate_embed_mask)
                embeds.append(vae_model.vae.sample(mean))
            # Generate from the averaged draft latents, weighted towards the topic.
            output_ids = router.generate((sum(embeds) / n_avg * 0.7) + prompt_embed * 0.3,
                                         context_ids,
                                         context_mask,
                                         256,
                                         tau=0.9)
            # Append the previous sentence to the context and advance.
            context_ids = torch.cat([context_ids, embed_ids], dim=1)
            context_mask = torch.cat([context_mask, embed_mask], dim=1)
            embed_ids = output_ids[:,-256:-128]
            embed_mask = context_mask.new_ones([1, embed_ids.shape[1]])
        out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in context_ids]
        return out_texts
```

Here is an example of an output from this process:

```
Then it asked the network to reconstruct the input and the original embedding. The network had to learn to match the
embedding to the original input, therefore matching the inference by consuming the embedding. This was key because
the embedding had to be able to match the text with the text it was consumed with. 'Here's how you do it,' Boru told Mu,
'Just impute the mean and variance.' This Mu did, transforming not words but entire paragraphs into vectors and then
inferring the next paragraph. It took some tweaks and tuning to get the initial performance but the second arago spot
had been found. To make sure the network was learning the right thing, Boru had to check the first value in the vector.
If the first value was below 0, the network had failed to learn the first value. If the value was above 0, the network
had been able to learn the first value.
‘What have you called this, Boru?’ asked Mu. ‘Latent variable regression.’ ‘It looks like a mixture of density network
and autoencoder,’ said Nayaf. ‘It’s an autoencoder but it’s using latent variables, but we’re using the mean and variance
of Grade had a difficult time seeing it, but he could tell it was close. 'So you've found the second arago,' he said.
'Yes,' Rin replied. 'We just have to figure out how to use it.'
'How?' Rin asked.
'You can move the second word in, right?'
'Possibly.' Rin thought for a moment.
'The second word will be the first word of the next arago,' Mu said. 'We just need to find it.'
'True,' Rin agreed. 'Well, I'll let you know what a Gaussian.’ ‘Let’s see if we can get it to work.’ ‘Arago the second
spot?’ ‘We’re here,’ Arago said.
The second spot was located in the middle of the text. Arago had to read it again to find the proper signal. ‘I’m going
to have to tweak some of the weights,’ said Arago. ‘I’ve had to change the input to the next layer from an input to
output.’ ‘You’re making a mistake again,’ said Mu to Arago. ‘It’s a mistake.’ The network had been learning I find out.'
'That's the second arago,' Rin said.
'The second arago?' Argo asked.
'Rin has found the second arago.'
Argo stared at Rin. 'Argo, is there something wrong?'
'I thought so.'
'What?' Rin said.
'I don't know,' Argo said. 'I thought I was the smartest person in the world but, well, I only had a certain amount of
energy. I didn't know how to do the second arago until now, but I can't
```
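
The latent arithmetic inside `bigvae_generate_avg` can be looked at in isolation. The following is a minimal sketch, not the MiniHF implementation: it uses plain Python lists in place of tensors, and the helper names `average_latents` and `guide` are hypothetical. It shows the two operations the sampler relies on, averaging several draft latents and then blending the result with the topic latent (the `0.7` / `0.3` weights in the code above).

```python
from typing import List, Sequence

Vector = List[float]


def average_latents(latents: Sequence[Vector]) -> Vector:
    """Element-wise mean of several latents, i.e. sum(embeds) / n_avg."""
    if not latents:
        raise ValueError("need at least one latent to average")
    n = len(latents)
    return [sum(column) / n for column in zip(*latents)]


def guide(latent: Vector, topic: Vector, topic_weight: float) -> Vector:
    """Blend a latent with the topic latent: latent * (1 - w) + topic * w."""
    return [
        (1.0 - topic_weight) * x + topic_weight * t
        for x, t in zip(latent, topic)
    ]


# Toy 2-dimensional latents standing in for encoded draft completions.
drafts = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
topic = [4.0, 4.0]

averaged = average_latents(drafts)
mixed = guide(averaged, topic, topic_weight=0.3)
```

Raising `topic_weight` pulls each generated sentence closer to the topic sentence, while lowering it lets the averaged drafts dominate.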

This generation method is slow, but retrieval could be used to speed up inference and make it converge closer and closer
to normal sampling speed as the model becomes able to call upon more and more relevant sentences that it has generated before.
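
One way to picture the retrieval idea is a cache keyed by latent vectors. This is a hypothetical sketch only: the model card does not describe a retrieval implementation, and the class `LatentCache`, its cosine-similarity criterion, and the `threshold` value are all assumptions made for illustration. When a stored latent is close enough to the query, its previously generated text is reused instead of sampling again.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


class LatentCache:
    """Hypothetical store of (latent, generated text) pairs."""

    def __init__(self, threshold: float = 0.95) -> None:
        self.threshold = threshold
        self.entries: List[Tuple[Vector, str]] = []

    def add(self, latent: Vector, text: str) -> None:
        self.entries.append((list(latent), text))

    def lookup(self, query: Vector) -> Optional[str]:
        """Return the text of the most similar entry, or None on a miss."""
        best_score, best_text = -1.0, None
        for latent, text in self.entries:
            score = cosine(query, latent)
            if score > best_score:
                best_score, best_text = score, text
        return best_text if best_score >= self.threshold else None


cache = LatentCache(threshold=0.95)
cache.add([1.0, 0.0], "a previously generated sentence")
hit = cache.lookup([0.99, 0.01])   # close to the stored latent
miss = cache.lookup([0.0, 1.0])    # orthogonal to the stored latent
```

On a miss the caller would fall back to the slow averaged sampling path and add the new result to the cache.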

Because BigVAE combines guided sampling with the ability to merge representations, it becomes possible to formulate plans and
cognitive strategies for the model to follow. The inference policy can adjudicate between an expected plan or series of steps and
the specific context the model is responding to.

This model is also highly interpretable. Because it is an encoder-decoder, every sentence generated by the model has a latent representation
that can be tracked along with its behavioral token sequence. Our hope is that BigVAE will shed light on the latent operations performed by
autoregressive language models and be useful to alignment and interpretability researchers.

## Training procedure

This model was trained on [a 1 billion token sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) of RedPajama
on 8x H100 GPUs for roughly 24 hours. The difference from v0.1 is that the KL weight was turned up to 0.1 over 50k steps.
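
A KL weight that is "turned up over 50k steps" is usually implemented as a warmup schedule. The sketch below assumes a linear ramp from zero, which the model card does not state; the function name `kl_weight` and its defaults are illustrative. Note that `state.json` in this commit records `last_kl_weight` as 0.01, so the target value is left as a parameter rather than hard-coded.

```python
def kl_weight(step: int, max_weight: float = 0.1, warmup_steps: int = 50_000) -> float:
    """Linearly ramp the KL weight from 0 to max_weight, then hold it.

    The linear shape is an assumption; only the target value and the
    number of warmup steps come from the text above.
    """
    if warmup_steps <= 0:
        return max_weight
    clamped = min(max(step, 0), warmup_steps)
    return max_weight * (clamped / warmup_steps)


# The weighted objective would then be: loss = reconstruction + kl_weight(step) * kl
halfway = kl_weight(25_000)
final = kl_weight(75_144)
```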

Using the scripts in the MiniHF repo as they exist now, the training commands were:

accelerate launch train_vae_overlap.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --context 64 --output vae_64_overlap_mistral_2 --batch-size 24

accelerate launch train_vae_router.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --vae-context 64 --start-from vae_64_overlap_mistral_2 --output vae_64_overlap_router_mistral_2 --lr 1e-4 --batch-size 1


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16
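
For readers who want to reproduce the quantization settings in code, the list above maps onto the keyword arguments of `BitsAndBytesConfig` in `transformers`. This is a sketch under the assumption that `torch`, `transformers`, and `bitsandbytes` are installed; the names `BNB_SETTINGS` and `make_quantization_config` are hypothetical and are not part of the MiniHF scripts.

```python
from typing import Any, Dict

# Values copied from the list above; the dtype is kept as a string here
# so the settings can be inspected without importing torch.
BNB_SETTINGS: Dict[str, Any] = {
    "load_in_8bit": False,
    "load_in_4bit": True,
    "llm_int8_threshold": 6.0,
    "llm_int8_skip_modules": None,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": "bfloat16",
}


def make_quantization_config():
    """Build a transformers BitsAndBytesConfig from BNB_SETTINGS."""
    # Imported lazily: these packages are only needed when loading a model.
    import torch
    from transformers import BitsAndBytesConfig

    kwargs = dict(BNB_SETTINGS)
    kwargs["bnb_4bit_compute_dtype"] = getattr(torch, kwargs["bnb_4bit_compute_dtype"])
    return BitsAndBytesConfig(**kwargs)
```

The resulting object would typically be passed as `quantization_config` when loading the base model with `from_pretrained`.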

### Framework versions

- PEFT 0.4.0
adapter_config.json
ADDED
@@ -0,0 +1,29 @@
{
  "auto_mapping": {
    "base_model_class": "MistralForCausalLM",
    "parent_library": "transformers.models.mistral.modeling_mistral"
  },
  "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 8,
  "lora_dropout": 0.0,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 32,
  "revision": null,
  "target_modules": [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj"
  ],
  "task_type": null
}
adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1600a224ba30755eed087dba66d1e75264203159a08c7e5b064e0da7b00c2f8
size 335604696
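
As a sanity check on the adapter above, the LoRA parameter count can be worked out from `r` and `target_modules`. This is a back-of-the-envelope sketch: the layer shapes are assumptions taken from the published Mistral 7B v0.1 configuration (hidden size 4096, MLP size 14336, 8 key/value heads of dimension 128, 32 layers), not from this commit. Each targeted linear layer of shape `(in, out)` adds `r * (in + out)` LoRA parameters.

```python
# Assumed Mistral 7B v0.1 shapes (not stated in this commit).
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 8 * 128  # 8 key/value heads of dimension 128
LAYERS = 32

# Values from adapter_config.json.
RANK = 32
LORA_ALPHA = 8

# (in_features, out_features) of each targeted linear layer in one block.
TARGETS = {
    "self_attn.q_proj": (HIDDEN, HIDDEN),
    "self_attn.k_proj": (HIDDEN, KV_DIM),
    "self_attn.v_proj": (HIDDEN, KV_DIM),
    "self_attn.o_proj": (HIDDEN, HIDDEN),
    "mlp.gate_proj": (HIDDEN, INTERMEDIATE),
    "mlp.up_proj": (HIDDEN, INTERMEDIATE),
    "mlp.down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(rank: int = RANK) -> int:
    """Total LoRA parameters: rank * (in + out) per target, over all layers."""
    per_layer = sum(rank * (fan_in + fan_out) for fan_in, fan_out in TARGETS.values())
    return per_layer * LAYERS


params = lora_param_count()
fp32_bytes = params * 4          # 4 bytes per float32 weight
scaling = LORA_ALPHA / RANK      # factor applied to the LoRA update
```

Under these assumptions `fp32_bytes` falls about 60 KB short of the 335604696-byte pointer size above, which is consistent with float32 weights plus file metadata, although the commit itself does not state the storage dtype.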
decoder/adapter_config.json
ADDED
@@ -0,0 +1,29 @@
{
  "auto_mapping": {
    "base_model_class": "MistralForCausalLM",
    "parent_library": "transformers.models.mistral.modeling_mistral"
  },
  "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 8,
  "lora_dropout": 0.0,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 32,
  "revision": null,
  "target_modules": [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj"
  ],
  "task_type": null
}
decoder/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ec65ce2dfb0ac65bc57b1bb16908986d899bbb2e08a200ced19ea24d6583134
size 335604696
encoder/adapter_config.json
ADDED
@@ -0,0 +1,29 @@
{
  "auto_mapping": {
    "base_model_class": "MistralForCausalLM",
    "parent_library": "transformers.models.mistral.modeling_mistral"
  },
  "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 8,
  "lora_dropout": 0.0,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 32,
  "revision": null,
  "target_modules": [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj"
  ],
  "task_type": null
}
encoder/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e10e83292c8dd14a50d6cc780518ee1c6c4fb8baf80975320c0f034f477e229c
size 335604696
router/adapter_config.json
ADDED
@@ -0,0 +1,29 @@
{
  "auto_mapping": {
    "base_model_class": "MistralForCausalLM",
    "parent_library": "transformers.models.mistral.modeling_mistral"
  },
  "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 8,
  "lora_dropout": 0.0,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 32,
  "revision": null,
  "target_modules": [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj"
  ],
  "task_type": null
}
router/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c9efeec948e5059cfdf6b59b903f250623b65204e0fbb535820bb3330225380
size 335604696
state.json
ADDED
@@ -0,0 +1 @@
{"step": 75144, "last_kl_weight": 0.01}
vae.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5301993cb6161f263cf35c4dcba67b8f4cce1053aa3d1d923e51cd6ead8f891
size 25202116