---
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
---

## Model Details

### Model Description

This is a fine-tuned version of Mistral-7B-v0.1, trained as a PEFT adapter on top of the base model loaded with 4-bit (NF4) `bitsandbytes` quantization.

- **Developed by:** Rais Kazi
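
The following is a rough, back-of-the-envelope sketch of why the base model is loaded in 4-bit. It estimates weight memory for Mistral-7B-v0.1 in bf16 versus NF4. The estimate rests on several assumptions:

- about 7.24 billion parameters;
- a quantization block size of 64;
- the scale-overhead accounting used in the QLoRA paper.

It counts weights only and ignores activations, the KV cache, and any layers kept in higher precision, so treat the numbers as lower bounds rather than measured usage.

```python
# Back-of-the-envelope weight memory for the base model (weights only).
# Assumption: Mistral-7B-v0.1 has roughly 7.24 billion parameters.
NUM_PARAMS = 7.24e9
GIB = 1024 ** 3


def weight_gib(num_params, bits_per_param):
    """Weight storage in GiB; ignores activations, KV cache and higher-precision layers."""
    return num_params * bits_per_param / 8 / GIB


# Extra bits per parameter spent on per-block absmax scales (assumed block size 64).
PLAIN_SCALE_BITS = 32 / 64  # one fp32 scale per 64 weights
# Double quantization: 8-bit scales, plus one fp32 scale per 256 blocks of scales.
DOUBLE_QUANT_SCALE_BITS = 8 / 64 + 32 / (64 * 256)

bf16_gib = weight_gib(NUM_PARAMS, 16)
nf4_gib = weight_gib(NUM_PARAMS, 4 + PLAIN_SCALE_BITS)
nf4_dq_gib = weight_gib(NUM_PARAMS, 4 + DOUBLE_QUANT_SCALE_BITS)

print(f"bf16:               {bf16_gib:.2f} GiB")
print(f"nf4:                {nf4_gib:.2f} GiB")
print(f"nf4 + double quant: {nf4_dq_gib:.2f} GiB")
```

Under these assumptions it prints roughly 13.49, 3.79 and 3.48 GiB. Most of the saving comes from 4-bit storage; double quantization trims the remaining per-block scale overhead from 0.5 to about 0.127 bits per parameter.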

### Model Sources

- Fine-tuning script: https://github.com/meetrais/LLM-Fine-Tuning/blob/main/finetune_mistral_7b.py
- Inference script: https://github.com/meetrais/LLM-Fine-Tuning/blob/main/call_finetune_mistral_7b.py

## Training procedure

The following `bitsandbytes` quantization config was used during training:

- quant_method: QuantizationMethod.BITS_AND_BYTES
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16
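
To make `bnb_4bit_quant_type: nf4` more concrete, here is a simplified pure-Python sketch of blockwise NF4 quantization. Each block of weights is scaled by its absolute maximum, and every value is snapped to the nearest of 16 NF4 levels.

- The level values are rounded from the bitsandbytes NF4 lookup table.
- The helper names (`quantize_block`, `quantize`, `dequantize`) are hypothetical.
- This only illustrates the idea; it is not how bitsandbytes actually packs or computes on 4-bit weights.
- Double quantization (`bnb_4bit_use_double_quant: True`) would additionally quantize the per-block `absmax` scales. This sketch leaves them in full precision.

```python
# Hypothetical illustration of blockwise NF4 quantization (not the bitsandbytes kernel).

# The 16 NF4 code values, rounded to 4 decimals from the bitsandbytes lookup table.
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def quantize_block(block):
    """Encode one block as 4-bit indices into NF4_LEVELS plus a single absmax scale."""
    absmax = max(abs(x) for x in block) or 1.0  # guard against an all-zero block
    codes = []
    for x in block:
        scaled = x / absmax  # normalize into [-1, 1]
        # Index of the nearest NF4 level.
        codes.append(min(range(len(NF4_LEVELS)), key=lambda i: abs(NF4_LEVELS[i] - scaled)))
    return codes, absmax


def quantize(weights, block_size=64):
    """Split a flat weight list into blocks; each block gets its own absmax scale."""
    return [quantize_block(weights[i:i + block_size]) for i in range(0, len(weights), block_size)]


def dequantize(blocks):
    """Recover approximate weights as level * absmax, block by block."""
    restored = []
    for codes, absmax in blocks:
        restored.extend(NF4_LEVELS[c] * absmax for c in codes)
    return restored


weights = [0.5, -0.25, 0.1, 0.0, -0.5, 0.3]
blocks = quantize(weights, block_size=3)  # tiny block size, just for the demo
print(blocks)
print(dequantize(blocks))
```

For this toy input, every value comes back within about 0.02 of the original. A per-block scale keeps one outlier from degrading the precision of the whole tensor.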

### Framework versions

- PEFT 0.6.2.dev0

## Code to call this model

The snippet below loads the 4-bit quantized base model, attaches the fine-tuned adapter, and generates a short completion:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_id = "meetrais/finetuned_mistral_7b"

# The adapter config records which base model the adapter was trained on.
config = PeftConfig.from_pretrained(peft_model_id)

# Same 4-bit NF4 settings that were used during training.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model, then attach the fine-tuned adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, peft_model_id)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Mistral's tokenizer defines no pad token; reuse EOS instead of adding a new
# token that the model's embedding matrix does not contain.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

text = "Capital of USA is"
# With device_map="auto", send inputs to the device holding the model's first layers.
inputs = tokenizer(text, return_tensors="pt").to(model.device)

outputs = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```