# llm-excel-plotter-agent / train_model.py
# Author: Transcendental-Programmer
# Commit d773e1b — "feat: initial project files and Docker setup"
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments
from sklearn.model_selection import train_test_split
# Load the natural-language query -> plot-arguments pairs used for fine-tuning.
data = pd.read_csv('data/train_data.csv')
queries = data['query'].tolist()
arguments = data['arguments'].tolist()

# Hold out 20% of the pairs for evaluation; fixed seed for reproducibility.
train_queries, eval_queries, train_arguments, eval_arguments = train_test_split(
    queries, arguments, test_size=0.2, random_state=42
)

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

# Tokenize the encoder inputs (the queries).
train_encodings = tokenizer(train_queries, truncation=True, padding=True)
eval_encodings = tokenizer(eval_queries, truncation=True, padding=True)

# Tokenize the targets inside the target-tokenizer context so decoder-side
# special tokens are applied correctly for seq2seq labels.
# BUG FIX: these two calls were not indented under the `with` block, which
# is a SyntaxError; they now actually run inside the context.
# NOTE(review): `as_target_tokenizer()` is deprecated in transformers >= 4.22
# in favor of `tokenizer(..., text_target=...)` — confirm the pinned version.
with tokenizer.as_target_tokenizer():
    train_labels = tokenizer(train_arguments, truncation=True, padding=True)
    eval_labels = tokenizer(eval_arguments, truncation=True, padding=True)
class PlotDataset(torch.utils.data.Dataset):
    """Pairs tokenized queries (encoder inputs) with tokenized plot
    arguments (decoder labels) for a seq2seq trainer.

    Args:
        encodings: mapping of tokenizer output lists (e.g. 'input_ids',
            'attention_mask'), one entry per example.
        labels: mapping whose 'input_ids' holds the target token ids.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Convert every field of the idx-th example to a tensor lazily,
        # and attach the matching target ids under the 'labels' key that
        # Seq2SeqTrainer expects.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        # FIX: use item access instead of attribute access so any plain
        # mapping works, consistent with __getitem__'s use of
        # self.labels['input_ids'] (BatchEncoding supports both forms).
        return len(self.encodings['input_ids'])
# Wrap the tokenized splits so the Trainer can index examples.
train_dataset = PlotDataset(train_encodings, train_labels)
eval_dataset = PlotDataset(eval_encodings, eval_labels)

# Training configuration: small per-device batches, evaluation once per
# epoch, and generation-based prediction during eval.
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,           # keep only the two most recent checkpoints
    evaluation_strategy="epoch",
    predict_with_generate=True,   # call generate() when evaluating
    generation_max_length=100,
)
# Assemble the trainer and run fine-tuning.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

# Persist the fine-tuned weights together with the tokenizer so the
# checkpoint directory is self-contained for later from_pretrained() loads.
save_dir = "fine-tuned-bart-large"
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)
print("Model and tokenizer fine-tuned and saved as 'fine-tuned-bart-large'")