---
license: apache-2.0
language:
- en
datasets:
- chargoddard/Open-Platypus-Chat
tags:
- axolotl
base_model: ai21labs/Jamba-v0.1
---

![image/webp](https://cdn-uploads.huggingface.co/production/uploads/61b8e2ba285851687028d395/efmF8RtLLeKgQ9OwPqfD8.webp)

# Jambatypus-v0.1

This model is a QLoRA fine-tuned version of [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1) on the [chargoddard/Open-Platypus-Chat](https://huggingface.co/datasets/chargoddard/Open-Platypus-Chat) dataset.

It was trained on 2x A100 80 GB GPUs using my [LazyAxolotl - Jamba](https://colab.research.google.com/drive/1alsgwZFvLPPAwIgkAxeMKHQSJYfW7DeZ?usp=sharing) notebook.

This repo contains both the QLoRA adapter and the merged model in FP16 precision.

I recommend using the ChatML chat template with this model.

[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)
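For reference, a ChatML prompt for this model looks like the following, mirroring the formatting used in the Usage section below (a minimal sketch; the system prompt and question are only illustrative):

```python
# Illustrative single-turn ChatML prompt for Jambatypus (no chat history).
prompt = (
    "<|im_start|>system\n"
    "Perform the task to the best of your ability.\n<|im_end|>\n"
    "\n<|im_start|>user\n"
    "Who was the first person to walk on the Moon?\n<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```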
<details><summary>See axolotl config</summary>

axolotl version: `0.4.0`

```yaml
base_model: ai21labs/Jamba-v0.1
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: chargoddard/Open-Platypus-Chat
    type: sharegpt
    chat_template: chatml
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false

use_wandb: true
wandb_project: axolotl
wandb_entity:
wandb_watch:
wandb_name: Jambatypus-v0.1
wandb_log_model:

adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

low_cpu_mem_usage: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
save_total_limit: 2
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
```

</details>
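If you want to reproduce the run outside of the notebook, the config above can be saved to a file and passed to axolotl's CLI. A minimal sketch, assuming the file is named `jambatypus.yaml` (the filename and the install-from-source step are assumptions, not part of the original setup):

```python
# Install axolotl from source and launch training with the config above.
!git clone https://github.com/OpenAccess-AI-Collective/axolotl
!pip install -qqq -e "./axolotl[flash-attn,deepspeed]"
!accelerate launch -m axolotl.cli.train jambatypus.yaml
```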

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.0002
- train_batch_size: 1
- eval_batch_size: 1
- seed: 42
- distributed_type: multi-GPU
- num_devices: 2
- gradient_accumulation_steps: 8
- total_train_batch_size: 16
- total_eval_batch_size: 2
- optimizer: Adam with betas=(0.9,0.95) and epsilon=1e-05
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 10
- num_epochs: 1
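For reference, the effective batch size follows directly from the config above: micro_batch_size (1) × gradient_accumulation_steps (8) × num_devices (2) = total_train_batch_size (16).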
### Training results

| Training Loss | Epoch | Step | Validation Loss |
|:-------------:|:-----:|:----:|:---------------:|
| 0.6274        | 0.01  | 1    | 1.0298          |
| 0.44          | 0.25  | 42   | 0.9770          |
| 0.4406        | 0.5   | 84   | 0.9653          |
| 0.4445        | 0.75  | 126  | 0.9645          |
| 0.4609        | 1.0   | 168  | 0.9641          |

### Framework versions

- PEFT 0.10.0
- Transformers 4.40.0.dev0
- Pytorch 2.1.2+cu118
- Datasets 2.18.0
- Tokenizers 0.15.0

## 💻 Usage

The following code creates a Gradio chat interface with Jambatypus.

```python
!pip install -qqq -U git+https://github.com/huggingface/transformers
!pip install -qqq mamba-ssm "causal-conv1d>=1.2.0"
!pip install -qqq accelerate bitsandbytes torch datasets peft gradio
!pip install -qqq flash-attn --no-build-isolation

import torch
import gradio as gr
from threading import Thread
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

STOP_TOKEN = "<|im_end|>"

def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format the system prompt, chat history, and current message with the ChatML template
    instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
    for human, assistant in history:
        instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
    instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'

    # Stream tokens from a background generation thread
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    generate_kwargs = dict(
        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the partial response, stopping at the ChatML end-of-turn token
    outputs = []
    for new_token in streamer:
        if STOP_TOKEN in new_token:
            outputs.append(new_token[:-len(STOP_TOKEN)-1])
            yield "".join(outputs)
            break
        outputs.append(new_token)
        yield "".join(outputs)

# Load the tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# 8-bit quantization config (Mamba modules are skipped and kept in higher precision)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mamba"]
)

# Load the base model and apply the Jambatypus QLoRA adapter
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
    quantization_config=quantization_config
)
config = PeftConfig.from_pretrained("mlabonne/Jambatypus-v0.1")
model = PeftModel.from_pretrained(model, "mlabonne/Jambatypus-v0.1")

# Create Gradio interface
gr.ChatInterface(
    predict,
    title="Jambatypus",
    description="Chat with Jambatypus!",
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue="green"),
).queue().launch(share=True)
```
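Since this repo also hosts the merged FP16 weights, you can skip the adapter entirely and load them directly. A minimal sketch, assuming the merged repo ships its own tokenizer and that you have enough GPU memory to hold the full model in FP16:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged FP16 checkpoint directly (no PEFT adapter needed).
# If the merged repo does not include a tokenizer, fall back to ai21labs/Jamba-v0.1.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Jambatypus-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mlabonne/Jambatypus-v0.1",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
```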