mlabonne committed
Commit bc8fdf3
1 Parent(s): 39d9fbf

Update README.md

Files changed (1):
  1. README.md +106 -2

README.md CHANGED
@@ -13,10 +13,14 @@ base_model: ai21labs/Jamba-v0.1

# Jambatypus-v0.1

This model is a QLoRA fine-tuned version of [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1) on the [chargoddard/Open-Platypus-Chat](https://huggingface.co/datasets/chargoddard/Open-Platypus-Chat) dataset.

It has been trained on 2xA100 80 GB using my [LazyAxolotl - Jamba](https://colab.research.google.com/drive/1alsgwZFvLPPAwIgkAxeMKHQSJYfW7DeZ?usp=sharing) notebook.

This repo contains both the adapter and the merged model in FP16 precision.

I recommend using the ChatML template with this model.
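For reference, here is a minimal sketch of the prompt layout that the ChatML template produces (the `to_chatml` helper is illustrative only, not part of this repo; it mirrors the formatting used by `predict()` in the Usage section below):

```python
# Illustrative only: ChatML wraps every turn in <|im_start|>{role} ... <|im_end|> markers.
def to_chatml(system_prompt: str, user_message: str) -> str:
    return (
        f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n"
        f"<|im_start|>user\n{user_message}\n<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

print(to_chatml("Perform the task to the best of your ability.",
                "Who was the first person to walk on the Moon?"))
```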

[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
<details><summary>See axolotl config</summary>

@@ -130,4 +134,104 @@ The following hyperparameters were used during training:

- Transformers 4.40.0.dev0
- Pytorch 2.1.2+cu118
- Datasets 2.18.0
- Tokenizers 0.15.0

## 💻 Usage

The following code creates a Gradio chat interface with Jambatypus.

```python
!pip install -qqq -U git+https://github.com/huggingface/transformers
!pip install -qqq mamba-ssm "causal-conv1d>=1.2.0"
!pip install -qqq accelerate bitsandbytes torch datasets peft gradio
!pip install -qqq flash-attn --no-build-isolation


import torch
import gradio as gr
from threading import Thread
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

STOP_TOKEN = "<|im_end|>"

def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format the whole conversation with the ChatML template
    instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
    for human, assistant in history:
        instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant + '\n<|im_end|>\n'
    instruction += '<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'

    # Generate in a background thread and stream tokens back as they arrive
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the partial answer until the stop token shows up
    outputs = []
    for new_token in streamer:
        if STOP_TOKEN in new_token:
            outputs.append(new_token[:new_token.find(STOP_TOKEN)].rstrip("\n"))
            yield "".join(outputs)
            break
        outputs.append(new_token)
        yield "".join(outputs)


# Load the base model's tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# 8-bit quantization config (the Mamba blocks are skipped and stay in torch_dtype)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mamba"]
)

# Load the quantized base model, then apply the Jambatypus adapter
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
    quantization_config=quantization_config
)
model = PeftModel.from_pretrained(model, "mlabonne/Jambatypus-v0.1")

# Create the Gradio chat interface
gr.ChatInterface(
    predict,
    title="Jambatypus",
    description="Chat with Jambatypus!",
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue="green"),
).queue().launch(share=True)
```
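
The interface runs `model.generate` in a background thread and streams tokens through `TextIteratorStreamer`, so partial answers appear in the chat as they are produced.

Since the repo also hosts the merged model in FP16, you can skip the quantized-base-plus-adapter path entirely. A minimal sketch, assuming the merged FP16 weights load directly from the `mlabonne/Jambatypus-v0.1` repo id:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the merged FP16 weights (not just the adapter) are hosted at this repo id.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Jambatypus-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mlabonne/Jambatypus-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",  # shard across available GPUs
)
```

Loading the merged weights avoids the PEFT dependency at inference time, at the cost of a much larger memory footprint than the 8-bit quantized base used above.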