John David Pressman committed
Commit 29ebdee
1 Parent(s): 167d940

Add model and update README

README.md CHANGED
@@ -1,3 +1,125 @@
 ---
-license: apache-2.0
+library_name: peft
 ---
+
+ BigVAE is an [AdaVAE](https://arxiv.org/abs/2205.05862) trained as a pair of LoRA finetunes on [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1).
+ It is meant to be used with the [MiniHF VAE inference code](https://github.com/JD-P/minihf/blob/adavae-moe/vae_infer.py) and will not work if you try to load it
+ as an ordinary language checkpoint and perform inference. AdaVAE is an encoder-decoder model trained by taking an existing GPT-N, designating one LoRA the
+ encoder and the other its decoder, and then tuning them with a latent attention mechanism. This model is the encoder and router decoder head for BigVAE, a planned
+ Mixture-of-Experts system based on LoRA retrieval rather than gating. It is usable in its own right as a model for embedding, retrieval, planning,
+ and guided sampling.
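+
+ As a minimal sketch of the embedding and reconstruction use case (assuming `vae_model`, `router`, `tokenizer`, and `device` have been set up by the MiniHF inference code; the call signatures here mirror the sampling procedure below rather than a documented API):
+
+ ```python
+ import torch
+
+ # Encode one sentence into the VAE latent space, then decode a continuation from it.
+ toks = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
+ ids, mask = toks["input_ids"].to(device), toks["attention_mask"].to(device)
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+     mean = vae_model.encode(ids, mask)   # posterior mean for the sentence
+     z = vae_model.vae.sample(mean)       # sampled latent vector
+     out_ids = router.generate(z, ids, mask, 64, tau=0.9)
+ print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
+ ```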
+
+ Here is an example of a sampling procedure for BigVAE which distills its autoregressive pretraining task into its autoassociative
+ reconstruction task by averaging together multiple completions. It takes the topic sentence of a paragraph (the prompt), guides the following sentences by weighting
+ them toward the topic, and averages together multiple completions of each sentence to improve generation quality:
+
+ ```python
+ def bigvae_generate_avg(vae_model, router, prompt, context, n_steps, n_avg):
+     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+         # Tokenize the shared context and the topic-sentence prompt.
+         context_toks = tokenizer(context, return_tensors="pt")
+         context_ids = context_toks["input_ids"].to(device)
+         context_mask = context_toks["attention_mask"].to(device)
+         embed_toks = tokenizer(prompt, return_tensors="pt")
+         embed_ids = embed_toks["input_ids"].to(device)
+         embed_mask = embed_toks["attention_mask"].to(device)
+         # Latent for the topic sentence; generation stays weighted toward it throughout.
+         mean = vae_model.encode(embed_ids, embed_mask)
+         prompt_embed = vae_model.vae.sample(mean)
+         for i in range(n_steps):
+             mean = vae_model.encode(embed_ids, embed_mask)
+             z = vae_model.vae.sample(mean)
+             embeds = []
+             for _ in range(n_avg):  # renamed from `i` to avoid shadowing the outer loop
+                 output_ids = router.generate(z * 0.5 + prompt_embed * 0.5,
+                                              context_ids,
+                                              context_mask,
+                                              256,
+                                              tau=0.9)
+                 # Re-encode the completion's final 128 tokens and collect its latent.
+                 intermediate_embed_ids = output_ids[:, -128:]
+                 intermediate_embed_mask = context_mask.new_ones(
+                     [1, intermediate_embed_ids.shape[1]]
+                 )
+                 mean = vae_model.encode(intermediate_embed_ids, intermediate_embed_mask)
+                 embeds.append(vae_model.vae.sample(mean))
+             # Generate from the averaged completion latents, still weighted toward the topic.
+             output_ids = router.generate((sum(embeds) / n_avg * 0.7) + prompt_embed * 0.3,
+                                          context_ids,
+                                          context_mask,
+                                          256,
+                                          tau=0.9)
+             context_ids = torch.cat([context_ids, embed_ids], dim=1)
+             context_mask = torch.cat([context_mask, embed_mask], dim=1)
+             # The next sentence to encode is the first 128 of the 256 newly generated tokens.
+             embed_ids = output_ids[:, -256:-128]
+             embed_mask = context_mask.new_ones([1, embed_ids.shape[1]])
+         out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in context_ids]
+         return out_texts
+ ```
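+
+ A hypothetical invocation (the prompt, context, and step counts below are illustrative, not from the original):
+
+ ```python
+ out_texts = bigvae_generate_avg(vae_model, router,
+                                 prompt="Then it asked the network to reconstruct the input.",
+                                 context="Mu and Boru were training an autoencoder.",
+                                 n_steps=8, n_avg=4)
+ print(out_texts[0])
+ ```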
+
+ Here is an example of an output from this process:
+
+ ```
+ Then it asked the network to reconstruct the input and the original embedding. The network had to learn to match the
+ embedding to the original input, therefore matching the inference by consuming the embedding. This was key because
+ the embedding had to be able to match the text with the text it was consumed with. 'Here's how you do it,' Boru told Mu,
+ 'Just impute the mean and variance.' This Mu did, transforming not words but entire paragraphs into vectors and then
+ inferring the next paragraph. It took some tweaks and tuning to get the initial performance but the second arago spot
+ had been found. To make sure the network was learning the right thing, Boru had to check the first value in the vector.
+ If the first value was below 0, the network had failed to learn the first value. If the value was above 0, the network
+ had been able to learn the first value.
+ ‘What have you called this, Boru?’ asked Mu. ‘Latent variable regression.’ ‘It looks like a mixture of density network
+ and autoencoder,’ said Nayaf. ‘It’s an autoencoder but it’s using latent variables, but we’re using the mean and variance
+ of Grade had a difficult time seeing it, but he could tell it was close. 'So you've found the second arago,' he said.
+ 'Yes,' Rin replied. 'We just have to figure out how to use it.'
+ 'How?' Rin asked.
+ 'You can move the second word in, right?'
+ 'Possibly.' Rin thought for a moment.
+ 'The second word will be the first word of the next arago,' Mu said. 'We just need to find it.'
+ 'True,' Rin agreed. 'Well, I'll let you know what a Gaussian.’ ‘Let’s see if we can get it to work.’ ‘Arago the second
+ spot?’ ‘We’re here,’ Arago said.
+ The second spot was located in the middle of the text. Arago had to read it again to find the proper signal. ‘I’m going
+ to have to tweak some of the weights,’ said Arago. ‘I’ve had to change the input to the next layer from an input to
+ output.’ ‘You’re making a mistake again,’ said Mu to Arago. ‘It’s a mistake.’ The network had been learning I find out.'
+ 'That's the second arago,' Rin said.
+ 'The second arago?' Argo asked.
+ 'Rin has found the second arago.'
+ Argo stared at Rin. 'Argo, is there something wrong?'
+ 'I thought so.'
+ 'What?' Rin said.
+ 'I don't know,' Argo said. 'I thought I was the smartest person in the world but, well, I only had a certain amount of
+ energy. I didn't know how to do the second arago until now, but I can't
+ ```
+
+ This generation method is slow, but retrieval could be used to speed up inference and make it converge closer and closer
+ to normal sampling speed as the model becomes able to call upon more and more relevant sentences that it has generated before.
+
+ Because BigVAE combines guided sampling with the ability to merge representations, it becomes possible to formulate plans and
+ cognitive strategies for the model to follow. The inference policy can adjudicate between an expected plan or series of steps and
+ the specific context the model is responding to, as in the sketch below.
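+
+ A sketch of that adjudication under the same assumptions as above (the plan text, `z_context`, and the 0.6/0.4 mixing weights are illustrative):
+
+ ```python
+ plan_toks = tokenizer("Next, summarize the results of the experiment.", return_tensors="pt")
+ plan_mean = vae_model.encode(plan_toks["input_ids"].to(device),
+                              plan_toks["attention_mask"].to(device))
+ z_plan = vae_model.vae.sample(plan_mean)
+ # z_context: latent of the most recently generated sentence, as in the loop above.
+ z = z_plan * 0.6 + z_context * 0.4  # weigh the expected plan against the live context
+ output_ids = router.generate(z, context_ids, context_mask, 256, tau=0.9)
+ ```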
+
+ This model is also highly interpretable. Because it is an encoder-decoder, every sentence generated by the model has a latent representation
+ that can be tracked along with its behavioral token sequence. Our hope is that BigVAE will shed light on the latent operations performed by
+ autoregressive language models and be useful to alignment and interpretability researchers.
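+
+ For example, one might record each sentence's latent next to its decoded text during generation (hypothetical bookkeeping, not part of the MiniHF API):
+
+ ```python
+ trace = []  # (decoded sentence, latent) pairs for later analysis
+ mean = vae_model.encode(embed_ids, embed_mask)
+ z = vae_model.vae.sample(mean)
+ trace.append((tokenizer.decode(embed_ids[0], skip_special_tokens=True),
+               z.detach().float().cpu()))
+ ```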
+
+ ## Training procedure
+
+ This model was trained on [a 1 billion token sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) of RedPajama
+ on 8x H100 GPUs for roughly 24 hours. The difference from v0.1 is that the KL weight was turned up to 0.1 over 50k steps.
+
+ Using the scripts in the MiniHF repo as they exist now, the training commands were:
+
+ ```
+ accelerate launch train_vae_overlap.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --context 64 --output vae_64_overlap_mistral_2 --batch-size 24
+
+ accelerate launch train_vae_router.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --vae-context 64 --start-from vae_64_overlap_mistral_2 --output vae_64_overlap_router_mistral_2 --lr 1e-4 --batch-size 1
+ ```
+
+ The following `bitsandbytes` quantization config was used during training:
+ - quant_method: bitsandbytes
+ - load_in_8bit: False
+ - load_in_4bit: True
+ - llm_int8_threshold: 6.0
+ - llm_int8_skip_modules: None
+ - llm_int8_enable_fp32_cpu_offload: False
+ - llm_int8_has_fp16_weight: False
+ - bnb_4bit_quant_type: nf4
+ - bnb_4bit_use_double_quant: True
+ - bnb_4bit_compute_dtype: bfloat16
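+
+ Expressed as a `transformers` `BitsAndBytesConfig`, the list above corresponds to roughly the following (a sketch of the equivalent object; `quant_method` is implied by using bitsandbytes rather than passed explicitly):
+
+ ```python
+ import torch
+ from transformers import BitsAndBytesConfig
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     llm_int8_threshold=6.0,
+     llm_int8_skip_modules=None,
+     llm_int8_enable_fp32_cpu_offload=False,
+     llm_int8_has_fp16_weight=False,
+ )
+ ```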
+
+ ### Framework versions
+
+ - PEFT 0.4.0
adapter_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "auto_mapping": {
+     "base_model_class": "MistralForCausalLM",
+     "parent_library": "transformers.models.mistral.modeling_mistral"
+   },
+   "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "lora_alpha": 8,
+   "lora_dropout": 0.0,
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "revision": null,
+   "target_modules": [
+     "self_attn.q_proj",
+     "self_attn.k_proj",
+     "self_attn.v_proj",
+     "self_attn.o_proj",
+     "mlp.gate_proj",
+     "mlp.up_proj",
+     "mlp.down_proj"
+   ],
+   "task_type": null
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1600a224ba30755eed087dba66d1e75264203159a08c7e5b064e0da7b00c2f8
+ size 335604696
decoder/adapter_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "auto_mapping": {
+     "base_model_class": "MistralForCausalLM",
+     "parent_library": "transformers.models.mistral.modeling_mistral"
+   },
+   "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "lora_alpha": 8,
+   "lora_dropout": 0.0,
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "revision": null,
+   "target_modules": [
+     "self_attn.q_proj",
+     "self_attn.k_proj",
+     "self_attn.v_proj",
+     "self_attn.o_proj",
+     "mlp.gate_proj",
+     "mlp.up_proj",
+     "mlp.down_proj"
+   ],
+   "task_type": null
+ }
decoder/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ec65ce2dfb0ac65bc57b1bb16908986d899bbb2e08a200ced19ea24d6583134
+ size 335604696
encoder/adapter_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "auto_mapping": {
+     "base_model_class": "MistralForCausalLM",
+     "parent_library": "transformers.models.mistral.modeling_mistral"
+   },
+   "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "lora_alpha": 8,
+   "lora_dropout": 0.0,
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "revision": null,
+   "target_modules": [
+     "self_attn.q_proj",
+     "self_attn.k_proj",
+     "self_attn.v_proj",
+     "self_attn.o_proj",
+     "mlp.gate_proj",
+     "mlp.up_proj",
+     "mlp.down_proj"
+   ],
+   "task_type": null
+ }
encoder/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e10e83292c8dd14a50d6cc780518ee1c6c4fb8baf80975320c0f034f477e229c
+ size 335604696
router/adapter_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "auto_mapping": {
+     "base_model_class": "MistralForCausalLM",
+     "parent_library": "transformers.models.mistral.modeling_mistral"
+   },
+   "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "lora_alpha": 8,
+   "lora_dropout": 0.0,
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "revision": null,
+   "target_modules": [
+     "self_attn.q_proj",
+     "self_attn.k_proj",
+     "self_attn.v_proj",
+     "self_attn.o_proj",
+     "mlp.gate_proj",
+     "mlp.up_proj",
+     "mlp.down_proj"
+   ],
+   "task_type": null
+ }
router/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c9efeec948e5059cfdf6b59b903f250623b65204e0fbb535820bb3330225380
+ size 335604696
state.json ADDED
@@ -0,0 +1 @@
+ {"step": 75144, "last_kl_weight": 0.01}
vae.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d5301993cb6161f263cf35c4dcba67b8f4cce1053aa3d1d923e51cd6ead8f891
+ size 25202116