|
from langchain_community.embeddings.sentence_transformer import ( |
|
SentenceTransformerEmbeddings, |
|
) |
|
from langchain_community.vectorstores import Chroma |
|
|
|
import time |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling |
|
from trl import SFTTrainer |
|
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training |
|
|
|
|
|
# --- Training corpus: every document stored in the persisted Chroma store ---
# NOTE(review): this is presumably the same embedding model that populated
# ./chroma_db — confirm. Only the stored document texts are used below; the
# embeddings themselves are never queried in this script.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

db = Chroma(embedding_function=embedding_function, persist_directory="./chroma_db")

# Use the public ``Chroma.get`` API instead of the private ``_collection``
# handle. Called without ids/limit it returns every record; restricting
# ``include`` to documents avoids fetching metadata that is not needed.
docs = db.get(include=["documents"])

print(f"There are {len(docs['documents'])} docs in the collection")

# Chroma may return None for records stored without document text. Such
# entries (and empty strings) cannot be tokenized for training, so drop them.
dataset = [text for text in docs["documents"] if text]

# Fail here with a clear message rather than deep inside SFTTrainer.
if not dataset:
    raise RuntimeError(
        "No documents found in ./chroma_db; populate the vector store before training."
    )
|
|
|
# flash_attention_2 (requested when loading the model below) only runs on
# CUDA GPUs, and the bitsandbytes 4-bit loading used below is GPU-oriented as
# well. Without CUDA the script cannot succeed, so stop immediately with a
# clear message instead of failing later inside model loading.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is not available; this script requires a CUDA-capable GPU."
    )
print("Cuda is available")
|
|
|
# --- Tokenizer for the base model ---------------------------------------------
base_model_id = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Batching/padding needs a pad token. When the tokenizer does not define one,
# reuse the end-of-sequence token for padding.
needs_pad_token = tokenizer.pad_token is None
if needs_pad_token:
    tokenizer.pad_token = tokenizer.eos_token
    print("pad_token was missing and has been set to eos_token")
|
|
|
|
|
# --- Quantized base model -------------------------------------------------------
# Weights are stored in 4-bit NF4; the quantized layers compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Load the non-quantized modules in bfloat16 as well, so they agree with the
# bf16 compute dtype above and with bf16 mixed-precision training. The previous
# torch_dtype="auto" picks the checkpoint's own dtype instead (presumably
# float16 for phi-2 — confirm), which mixes fp16 weights with bf16 compute.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    attn_implementation="flash_attention_2",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
print(model)

# Freeze the base weights and prepare the quantized model for adapter
# training. With use_gradient_checkpointing=True this call also enables
# gradient checkpointing, so a separate model.gradient_checkpointing_enable()
# beforehand is redundant.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
|
|
# --- LoRA adapter settings -----------------------------------------------------
# Adapters go on the attention projections (q/k/v and the "dense" output
# projection) and on the MLP layers (fc1/fc2). Because lora_alpha == r, the
# LoRA scaling factor alpha / r is 1. SFTTrainer applies this config to the
# model via its ``peft_config`` argument.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "dense",
        "fc2",
        "fc1",
    ],
)
|
|
|
# --- Trainer configuration -------------------------------------------------------
# Logging, checkpointing and evaluation all run once per epoch. The matching
# logging_steps / save_steps / eval_steps arguments are only read by the
# "steps" strategies, so they had no effect here and were removed. Re-add them
# if any strategy is switched to "steps".
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    # Per device, each optimizer step sees 2 * 5 = 10 sequences.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=5,
    # Trade compute for memory; use the non-reentrant checkpoint implementation.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # LR schedule: 10 warmup steps, then linear decay over 20 epochs.
    num_train_epochs=20,
    learning_rate=5e-5,
    warmup_steps=10,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="paged_adamw_8bit",
    bf16=True,
    logging_dir="./logs",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # NOTE(review): newer transformers releases rename this to eval_strategy;
    # update it if upgrading.
    evaluation_strategy="epoch",
    # Requires save_strategy and evaluation_strategy to match (both "epoch").
    load_best_model_at_end=True,
)
|
|
|
def formatting_func(doc):
    """Return the text to train on for a single dataset example.

    The dataset entries are already plain document strings, so each example
    is handed back unchanged.
    """
    text = doc
    return text
|
|
|
# --- Supervised fine-tuning -----------------------------------------------------
# SFTTrainer wraps the model with the LoRA adapters from ``peft_config`` and
# packs the raw document strings into 1024-token training sequences.
#
# The tokenizer configured earlier in this script is passed explicitly. When
# none is given, SFTTrainer loads a fresh tokenizer itself, so the pad-token
# setup done above would be ignored and that separate instance would be the
# one saved with the model.
#
# NOTE(review): eval_dataset is the training data itself. Eval loss (and
# therefore the checkpoint chosen by load_best_model_at_end) tracks training
# loss rather than generalization; consider holding out a split.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=dataset,
    peft_config=peft_config,
    args=training_args,
    max_seq_length=1024,
    packing=True,
    formatting_func=formatting_func,
)
|
|
|
# The KV cache is only useful for generation, and it conflicts with gradient
# checkpointing during training, so disable it.
model.config.use_cache = False

# Time the run with perf_counter: it is monotonic and meant for measuring
# intervals. time.time() follows the system clock and can jump when the clock
# is adjusted.
start_time = time.perf_counter()
trainer.train()
end_time = time.perf_counter()

training_time = end_time - start_time

# Save the final model into the same directory as the epoch checkpoints.
# Presumably this writes only the LoRA adapter weights, because SFTTrainer
# wrapped the model as a PEFT model — confirm.
trainer.save_model("./results")
print(f"Training completed in {training_time} seconds.")