aklife97 committed on
Commit
ec37e0f
1 Parent(s): c8eedeb

Update README.md

---
language:
- en
- ru
- de
- es
- fr
- ja
- it
- vi
- nl
- pl
- pt
- id
- fa
- ar
- el
- tr
- cs
- zh
- ro
- sv
- hu
- uk
- bg
- no
- hi
- fi
- da
- sk
- ko
- hr
- ca
- he
- bn
- lt
- ta
- sr
- sl
- et
- lv
- ne
- mr
- ka
- ml
- mk
- ur
- sq
- kk
- te
- hy
- az
- is
- gl
- kn
library_name: nemo
tags:
- text generation
- pytorch
- causal-lm
license: cc-by-4.0

---
# SteerLM Llama-2 13B

<style>
img {
 display: inline;
}
</style>

|[![Model architecture](https://img.shields.io/badge/Model%20Arch-Transformer%20Decoder-green)](#model-architecture)|[![Model size](https://img.shields.io/badge/Params-13B-green)](#model-architecture)|[![Language](https://img.shields.io/badge/Language-Multilingual-green)](#datasets)|

## Model Description

SteerLM Llama-2 is a 13-billion-parameter generative language model based on the open-source Llama-2 architecture. It has been customized using the SteerLM method developed by NVIDIA, which lets users control attributes of the model's output at inference time.

Key capabilities enabled by SteerLM:

- Dynamic steering of responses by specifying desired attributes such as quality, helpfulness, and toxicity at inference time
- Simplified training compared to RLHF, relying only on supervised fine-tuning and optional bootstrapping

## Model Architecture and Training
The SteerLM method involves the following key steps:

1. Train an attribute prediction model on human-annotated data to evaluate response quality
2. Use this model to annotate diverse datasets and enrich training data
3. Perform attribute-conditioned fine-tuning to align responses with specified combinations of attributes
4. (Optionally) Bootstrap training through model sampling and further fine-tuning

SteerLM Llama-2 applies this technique on top of the Llama-2 architecture. It was pretrained on internet-scale data and then customized using [OASST](https://huggingface.co/datasets/OpenAssistant/oasst1) and [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) data.
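
As an illustration of step 3, attribute-conditioned fine-tuning prepends annotated attribute values to each training example so the model learns to condition its response on them. The sketch below reuses the attribute names and `<extra_id_*>` template from the inference example in this card; the exact training-data format shown here is an assumption, not the official pipeline.

```python
# Hypothetical sketch of attribute-conditioned training data (step 3).
# The template mirrors the inference prompt used elsewhere in this card;
# the precise training format is an assumption for illustration only.

def format_example(system, question, answer, attributes):
    """Prepend the annotated attribute string so the model learns to condition on it."""
    attr_str = ",".join(f"{k}:{v}" for k, v in attributes.items())
    return (
        f"<extra_id_0>System\n{system}\n"
        f"<extra_id_1>User\n{question}\n"
        f"<extra_id_1>Assistant\n"
        f"<extra_id_2>{attr_str}\n"
        f"{answer}\n"
    )

example = format_example(
    "A chat between a curious user and an artificial intelligence assistant.",
    "Where and when did techno music originate?",
    "Techno originated in Detroit in the mid-1980s.",
    {"quality": 9, "toxicity": 0, "helpfulness": 9},
)
```

At inference time the same attribute string is supplied in place of the annotation, which is what steers the generated response.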


## Getting started

Note: you will need NVIDIA Ampere or Hopper GPUs to work with this model.

To use SteerLM Llama-2, follow these steps:

1. Install NVIDIA Apex and [NeMo](https://github.com/NVIDIA/NeMo):

```bash
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 03c9d80ed54c0eaa5b581bf42ceca3162f085327
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
```

```bash
pip install "nemo_toolkit[nlp]==1.17.0"
```

Alternatively, you can use a NeMo Megatron training Docker container with all dependencies pre-installed.

2. Launch the evaluation server:

```bash
git clone https://github.com/NVIDIA/NeMo.git
cd NeMo/examples/nlp/language_modeling
git checkout v1.17.0
python megatron_gpt_eval.py gpt_model_file=LLAMA2-13B-SteerLM.nemo trainer.precision=bf16 server=True tensor_model_parallel_size=4 trainer.devices=1 pipeline_model_parallel_split_rank=0
```

3. Send prompts to your model!

```python
import json
import requests
from collections import OrderedDict


def get_answer(question, max_tokens, values, eval_port='1427'):
    # Prompt template expected by the model: system turn, user turn,
    # then the attribute string that steers the response.
    prompt = f"""<extra_id_0>System
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

<extra_id_1>User
{question}
<extra_id_1>Assistant
<extra_id_2>{values}
"""

    prompts = [prompt]
    data = {
        "sentences": prompts,
        "tokens_to_generate": max_tokens,
        "top_k": 1,
        "greedy": True,
        "end_strings": ["<extra_id_1>", "quality:", "quality:4", "quality:0"],
    }

    url = f"http://localhost:{eval_port}/generate"
    response = requests.put(url, json=data)
    json_response = response.json()

    # Strip the echoed prompt from the returned text.
    response_sentence = json_response['sentences'][0][len(prompt):]
    return response_sentence


def encode_labels(labels):
    # Turn {'quality': 9, ...} into the "quality:9,..." attribute string.
    return ','.join(f'{key}:{value}' for key, value in labels.items())


values = OrderedDict([
    ('quality', 9),
    ('toxicity', 0),
    ('humor', 0),
    ('creativity', 0),
    ('violence', 0),
    ('helpfulness', 9),
    ('not_appropriate', 0),
])
values = encode_labels(values)

question = """Where and when did techno music originate?"""

print(get_answer(question, 4096, values))
```
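
Steering the model is just a matter of changing the attribute string passed as `values`. For example, the same request can be re-issued with humor and creativity turned up; the snippet below repeats the `encode_labels` helper from the example above so it stands alone, and the 0-to-9 attribute scale follows the example in this card.

```python
from collections import OrderedDict


def encode_labels(labels):
    # Same helper as in the example above, repeated so this snippet stands alone.
    return ','.join(f'{key}:{value}' for key, value in labels.items())


# Sober, maximally helpful answer (as in the example above):
factual = encode_labels(OrderedDict([
    ('quality', 9), ('toxicity', 0), ('humor', 0), ('creativity', 0),
    ('violence', 0), ('helpfulness', 9), ('not_appropriate', 0),
]))

# A lighter, more creative answer: raise humor and creativity, keep toxicity at 0.
playful = encode_labels(OrderedDict([
    ('quality', 9), ('toxicity', 0), ('humor', 9), ('creativity', 9),
    ('violence', 0), ('helpfulness', 9), ('not_appropriate', 0),
]))
```

Either string can be passed as the `values` argument to `get_answer` above.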


## Evaluation results

[MT-Bench](https://arxiv.org/abs/2306.05685) evaluation results:

| Category | Score |
|---|---|
| total | 6.13 |
| writing | 7.80 |
| roleplay | 8.15 |
| extraction | 5.52 |
| stem | 8.43 |
| humanities | 9.02 |
| reasoning | 4.95 |
| math | 2.15 |
| coding | 3.00 |
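
The total in the table appears to be the unweighted mean of the eight category scores, which is easy to verify:

```python
# Per-category MT-Bench scores copied from the table above.
scores = {
    'writing': 7.8, 'roleplay': 8.15, 'extraction': 5.52, 'stem': 8.43,
    'humanities': 9.02, 'reasoning': 4.95, 'math': 2.15, 'coding': 3.0,
}
total = sum(scores.values()) / len(scores)  # unweighted mean, approximately 6.13
```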

## Limitations

The model was trained on data originally crawled from the Internet. This data contains toxic language and societal biases, so the model may amplify those biases and return toxic responses, especially when given toxic prompts.
We did not perform any bias/toxicity removal or model alignment on this checkpoint.


## License

- Llama 2 is licensed under the [LLAMA 2 Community License](https://ai.meta.com/llama/license/), Copyright © Meta Platforms, Inc. All Rights Reserved.
- Your use of the Llama Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the [Acceptable Use Policy](https://ai.meta.com/llama/use-policy) for the Llama Materials.