# simple_chat_bot / app.py
# Author: Kiet2302 — "Update app.py" (commit 85af8eb, verified)
import streamlit as st
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Pull the school-math Q/A corpus from the Hugging Face Hub and flatten the
# train split into (question, answer) tuples for the Dataset below.
ds = load_dataset("higgsfield/school-math-questions")
qa_pairs = [(row['prompt'], row['completion']) for row in ds['train']]
class MathDataset(torch.utils.data.Dataset):
    """Torch dataset that formats (question, answer) pairs for causal-LM fine-tuning.

    Each item is a dict with ``input_ids`` and ``labels`` tensors of shape
    ``(max_length,)``, ready for ``transformers.Trainer``.
    """

    def __init__(self, qa_pairs, tokenizer, max_length=128):
        # qa_pairs: sequence of (question, answer) string tuples.
        # tokenizer: must expose .encode(...) and .pad_token_id (GPT-2 style).
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        question, answer = self.qa_pairs[idx]
        # BUG FIX: the original tokenized the question ("Q: ... A:") as
        # input_ids and the answer as labels — two independently padded,
        # positionally unaligned sequences. A decoder-only causal LM needs
        # labels aligned token-for-token with input_ids, so prompt and answer
        # must live in ONE sequence.
        full_text = f"Q: {question} A: {answer.strip()}"
        input_ids = self.tokenizer.encode(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        ).squeeze(0)
        # Labels mirror the inputs; -100 tells the loss to ignore padding.
        # NOTE(review): pad_token is set to eos_token at module level, so a
        # genuine EOS inside the text would also be masked — acceptable here.
        labels = input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": input_ids,
            "labels": labels,
        }
# --- Model / tokenizer setup (runs at import time; downloads from the Hub) ---
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 ships with no pad token; reuse EOS so padding="max_length" works.
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_name)
# Wrap the (question, answer) pairs for the Trainer below.
math_dataset = MathDataset(qa_pairs, tokenizer)
from transformers import Trainer, TrainingArguments

# Set training arguments.
# NOTE(review): this fine-tuning runs at import time on every Streamlit
# rerun — expensive; consider caching or training offline. Confirm intended.
training_args = TrainingArguments(
    output_dir="./results",          # checkpoints land here
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=10,                   # NOTE(review): checkpoint every 10 steps is very frequent — confirm
    save_total_limit=2,              # keep only the 2 newest checkpoints
)

# Create a Trainer (no eval dataset: train-loss-only fine-tuning).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=math_dataset,
)

# Fine-tune the model in place (updates the `model` object above).
trainer.train()
class MathChatBot:
    """Thin Q/A wrapper around a GPT-2 model: formats the prompt, generates,
    and extracts the answer text."""

    def __init__(self, model_name="gpt2", model=None, tokenizer=None):
        # BUG FIX (backward-compatible): the original always reloaded the
        # *base* pretrained checkpoint, silently discarding any fine-tuned
        # weights produced earlier in this file. Callers may now inject the
        # fine-tuned model/tokenizer; omitting them keeps the old behavior
        # (load `model_name` from the Hub).
        self.tokenizer = tokenizer if tokenizer is not None else GPT2Tokenizer.from_pretrained(model_name)
        self.model = model if model is not None else GPT2LMHeadModel.from_pretrained(model_name)

    def get_response(self, question):
        """Return the model's generated answer string for `question`."""
        input_text = f"Q: {question} A:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
        output = self.model.generate(input_ids, max_length=50, num_return_sequences=1)
        answer = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Decoded text echoes the prompt; keep only what follows the last "A:".
        return answer.split("A:")[-1].strip()
# Usage
# Usage: minimal Streamlit UI around the chatbot.
if __name__ == "__main__":
    bot = MathChatBot()
    user_input = st.text_area("Enter your question:")
    # ROBUSTNESS FIX: Streamlit reruns the script with an empty text area on
    # first render; the original called generate() on "" every time. Only
    # query the model once the user has typed something.
    if user_input.strip():
        response = bot.get_response(user_input)
        st.write(f"Bot: {response}")