---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
  - en
datasets:
  - chtmp223/suri
---

Suri-SFT

Suri-SFT is a version of mistralai/Mistral-7B-Instruct-v0.2 trained with supervised fine-tuning (SFT) using LoRA adapters. See our paper (arXiv:2406.19371) for details of the method.

📒 Model Details

Model Description

  • Language(s): English
  • License: Apache 2.0
  • Fine-tuned from: mistralai/Mistral-7B-Instruct-v0.2

Model Sources

  • Paper: https://arxiv.org/abs/2406.19371

⚠️ Getting Started

Use the code in this repository for training and inference.

💻 Training Details

Training Data

chtmp223/suri

Training Procedure

| Configuration | Value |
|---|---|
| Hardware (training and inference) | 4xA100s |
| Tracking | wandb |
| lora_r | 16 |
| lora_alpha | 16 |
| lora_dropout | 0.05 |
| gradient_accumulation_steps | 1 |
| gradient_checkpointing | True |
| learning_rate | 5.0e-5 |
| lr_scheduler_type | cosine |
| max_length | 15024 |
| max_completion_length | 15000 |
| max_prompt_length | 5000 |
| num_train_epochs | 2 |
| optim | adamw_torch |
| per_device_train_batch_size | 1 |
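Read together, these values imply a small global batch and unit LoRA scaling. The sketch below is a hypothetical helper, not part of the training code. It assumes all four A100s act as data-parallel workers (the card does not state the distribution strategy) and that the adapter uses PEFT's default LoRA scaling of lora_alpha / lora_r (i.e. rsLoRA disabled). Under those assumptions each optimizer step sees 4 × 1 × 1 = 4 sequences, and the low-rank update is scaled by 16 / 16 = 1.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SuriSFTHyperparams:
    """Hypothetical container mirroring the training table; not the authors' code."""

    num_gpus: int = 4  # assumption: "4xA100s" used as 4 data-parallel workers
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    learning_rate: float = 5.0e-5
    num_train_epochs: int = 2

    def effective_batch_size(self) -> int:
        # Sequences contributing to one optimizer step across all workers.
        return (
            self.num_gpus
            * self.per_device_train_batch_size
            * self.gradient_accumulation_steps
        )

    def lora_scaling(self) -> float:
        # Standard LoRA multiplies the low-rank update B @ A by alpha / r.
        return self.lora_alpha / self.lora_r


hp = SuriSFTHyperparams()
print(hp.effective_batch_size())  # 4
print(hp.lora_scaling())  # 1.0
```

Because lora_alpha equals lora_r, changing the rank alone would also change the scaling; keeping the two equal is a common way to hold the update magnitude at 1.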

Software

Training code is adapted from the Alignment Handbook and TRL.

🤗 Inference

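```python
import os

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
config = PeftConfig.from_pretrained(model_name)  # adapter config (r, alpha, target modules, ...)

# Load the base model in bfloat16 (half the memory of the float32 default),
# then attach the Suri-SFT LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.bfloat16
).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Placeholder: replace with your own instruction.
user_prompt = "YOUR INSTRUCTION HERE"
prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]
# The chat template already inserts the BOS token, so encode without special tokens.
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)
# max_length counts prompt tokens plus generated tokens.
output = model.generate(
    input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()

# Decode only the newly generated tokens, dropping the prompt and special tokens.
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```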
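The inference code passes add_special_tokens=False when encoding because the chat template already emits the BOS token; encoding with special tokens would prepend a second one. As an illustration only, the hypothetical helper below approximates the string that apply_chat_template produces for Mistral-7B-Instruct-v0.2 under its original template, assuming alternating user/assistant turns and no system message. The template shipped with the tokenizer may have changed since, so rely on tokenizer.apply_chat_template in practice.

```python
from typing import Dict, List


def format_mistral_v02_prompt(
    messages: List[Dict[str, str]], bos: str = "<s>", eos: str = "</s>"
) -> str:
    """Approximate the Mistral-7B-Instruct-v0.2 chat format (hypothetical helper)."""
    text = bos  # the template itself starts with the BOS token
    for i, message in enumerate(messages):
        # Turns must alternate user, assistant, user, ... starting with user.
        expected_role = "user" if i % 2 == 0 else "assistant"
        if message["role"] != expected_role:
            raise ValueError("roles must alternate user/assistant, starting with user")
        if message["role"] == "user":
            text += "[INST] " + message["content"] + " [/INST]"
        else:
            # Completed assistant turns are closed with the EOS token.
            text += message["content"] + eos
    return text


print(format_mistral_v02_prompt([{"role": "user", "content": "Write a poem."}]))
# <s>[INST] Write a poem. [/INST]
```

In this format, add_generation_prompt=True adds nothing extra for a single user turn: the prompt simply ends at [/INST], where the model begins its reply.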

📜 Citation

@misc{pham2024surimulticonstraintinstructionfollowing,
      title={Suri: Multi-constraint Instruction Following for Long-form Text Generation}, 
      author={Chau Minh Pham and Simeng Sun and Mohit Iyyer},
      year={2024},
      eprint={2406.19371},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.19371}, 
}

βš™οΈ Framework versions

  • PEFT 0.11.1