# Source page metadata (appears to be copied from a Hugging Face model page):
# Text Generation
# Transformers
# English
# legal
# chat
# transformer
# CyberFuture-A1 / model.py
# SkillForge45's picture
# Create model.py
# 12ea3de verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import math
# 1. Dataset class for loading and processing data
class FullChatDataset(Dataset):
    """Several Hugging Face chat datasets concatenated and tokenized as (context, response) pairs.

    Every raw example is turned into a list of utterances. All but the last are joined with
    ``" [SEP] "`` to form the context, and the last utterance is the response. The pair is
    tokenized with the ``bert-base-uncased`` tokenizer and padded/truncated to ``max_length``.

    Args:
        dataset_names: Names passed to ``datasets.load_dataset(name, split="train")``.
            Datasets that fail to load are reported and skipped.
        max_length: Fixed token length of every returned example.
    """

    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"), max_length=128):
        # A tuple default avoids the shared-mutable-default pitfall; any iterable of names works.
        self.datasets = []
        for name in dataset_names:
            try:
                self.datasets.append(load_dataset(name, split="train"))
            except Exception as e:  # best-effort: continue with whatever datasets did load
                print(f"Failed to load dataset {name}: {e}")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        # bert-base-uncased already defines [PAD]; this call is kept as a safeguard.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def _locate(self, idx):
        """Return the raw item at global index ``idx`` across the concatenated datasets.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        pos = idx + total if idx < 0 else idx
        if not 0 <= pos < total:
            raise IndexError(f"index {idx} is out of range for a dataset of {total} items")
        for dataset in self.datasets:
            size = len(dataset)
            if pos < size:
                return dataset[pos]
            pos -= size
        # Not reached: the bounds check above guarantees a match.
        raise IndexError(f"index {idx} is out of range for a dataset of {total} items")

    @staticmethod
    def _split_dialog(item):
        """Split one raw example into ``(context, response)`` strings.

        Layouts are checked in this order:
          * ``item['dialog']``: a sequence of utterance strings (the original note says Daily Dialog);
          * ``item['messages']``: a list of ``{'text': ...}`` dicts (plain strings are accepted too);
          * otherwise every string-valued field of the item, in its key order.
        All utterances but the last are joined with " [SEP] " into the context, and the last one is
        the response. An item with no utterances yields ``("", "")`` instead of raising.
        """
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] if isinstance(msg, dict) else msg for msg in item['messages']]
        else:
            dialog = [v for v in item.values() if isinstance(v, str)]
        if len(dialog) == 0:
            return "", ""
        return " [SEP] ".join(dialog[:-1]), dialog[-1]

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx``.

        Returns:
            dict of 1-D tensors of length ``max_length``: ``input_ids``, ``attention_mask`` and
            ``labels`` (an unshifted copy of ``input_ids``).

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        context, response = self._split_dialog(self._locate(idx))
        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].flatten()
        return {
            'input_ids': input_ids,
            'attention_mask': inputs['attention_mask'].flatten(),
            # Clone so labels and input_ids do not share storage.
            'labels': input_ids.clone()
        }
# 2. Model architecture
class SimpleTransformerModel(nn.Module):
    """Small Transformer-encoder language model mapping token ids to per-position vocabulary logits.

    Args:
        vocab_size: Number of token ids (size of the embedding table and of the output layer).
        d_model: Embedding / hidden size; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True: activations are (batch, seq, d_model), the layout the embedding
        # produces and PositionalEncoding indexes. With the default (batch_first=False) the
        # encoder would treat the batch axis as the sequence axis and mix different examples.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None, *, causal=True):
        """Compute next-token logits.

        Args:
            x: LongTensor of token ids, shape (batch, seq).
            mask: Optional (batch, seq) attention mask in the tokenizer convention
                (non-zero = real token, 0 = padding). Padded positions are hidden as attention keys.
            causal: If True (default), position i attends only to positions <= i, so the model
                can be trained and sampled as a left-to-right language model.

        Returns:
            FloatTensor of shape (batch, seq, vocab_size).
        """
        seq_len = x.size(1)
        h = self.pos_encoder(self.embedding(x))
        causal_mask = None
        if causal:
            # True strictly above the diagonal = "may not attend" (future positions).
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
            )
        # Key-padding mask expects True at positions to ignore, i.e. where the attention mask is 0.
        padding_mask = None if mask is None else (mask == 0)
        h = self.transformer(h, mask=causal_mask, src_key_padding_mask=padding_mask)
        return self.fc(h)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to (batch, seq, d_model) inputs.

    Args:
        d_model: Feature size; odd values are supported.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}")
        return x + self.pe[:seq_len]
# 3. Model training
def train(model, dataloader, epochs=3, lr=3e-4, pad_token_id=0):
    """Train ``model`` as a next-token language model.

    Each batch must provide ``input_ids``, ``attention_mask`` and ``labels`` tensors of shape
    (batch, seq). The model is called as ``model(input_ids, attention_mask)`` and must return
    logits of shape (batch, seq, vocab). The logits at position t are scored against the label
    at position t + 1, so the model learns to predict the next token instead of copying its input.

    Args:
        model: Module to train; moved (in place) to CUDA when available.
        dataloader: Iterable of batches.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        pad_token_id: Label id excluded from the loss (0 is ``[PAD]`` in the bert-base-uncased
            vocabulary this script uses).

    Returns:
        List with the average training loss of each epoch (NaN for an epoch with no batches).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            inputs = batch['input_ids'].to(device)
            masks = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs, masks)
            # Shift by one: logits at t predict token t+1. reshape (not view) because the
            # slice makes the logits non-contiguous.
            logits = outputs[:, :-1, :]
            targets = labels[:, 1:]
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            pbar.set_postfix({'loss': batch_loss})
        # Guard against an empty dataloader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} - Avg loss: {avg_loss:.4f}")
    return epoch_losses
# 4. Response generation
def sample_next_token(logits, top_k=50, top_p=0.95, temperature=0.7):
    """Sample one token id per row of ``logits`` using temperature, top-k and nucleus filtering.

    Args:
        logits: FloatTensor of shape (batch, vocab) with scores for the next position.
        top_k: Keep only the ``top_k`` highest-scoring tokens (0 disables this filter).
        top_p: Keep the smallest set of best tokens whose probability mass reaches ``top_p``
            (1.0 disables this filter). The most likely token is always kept.
        temperature: Logits are divided by this before sampling; ``<= 0`` means greedy argmax.

    Returns:
        LongTensor of shape (batch, 1).
    """
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float('-inf'))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token only if the strictly-better tokens already exceed top_p. This keeps the
        # token that crosses the threshold and always keeps the best token.
        drop = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float('-inf'))
        logits = torch.full_like(logits, float('-inf')).scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def chat(model, tokenizer, prompt, max_length=50, *, top_k=50, top_p=0.95, temperature=0.7,
         context_window=128):
    """Generate a reply to ``prompt`` by sampling from ``model`` one token at a time.

    ``model`` only needs to map a (batch, seq) LongTensor of ids to (batch, seq, vocab) logits.
    No ``generate`` method is required, so it works with ``SimpleTransformerModel``. The prompt
    is encoded without padding, so generated tokens follow it directly. With the BERT tokenizer
    used in this script, the prompt becomes ``[CLS] prompt [SEP]``.

    Args:
        model: Logit-producing module whose parameters determine the device.
        tokenizer: Hugging Face-style tokenizer (callable plus ``decode``). Generation stops at
            its ``sep_token_id``, or at ``eos_token_id`` when there is no SEP token.
        prompt: User text.
        max_length: Maximum number of new tokens to generate.
        top_k, top_p, temperature: Sampling controls; see ``sample_next_token``.
        context_window: Maximum number of prompt tokens kept, and of trailing ids fed to the model
            at each step (keeps inputs within the model's positional range).

    Returns:
        The decoded reply: newly generated tokens only, with special tokens removed.
    """
    device = next(model.parameters()).device
    model.eval()
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=context_window,
        truncation=True
    )
    ids = encoded['input_ids'].to(device)
    stop_id = getattr(tokenizer, 'sep_token_id', None)
    if stop_id is None:
        stop_id = getattr(tokenizer, 'eos_token_id', None)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(ids[:, -context_window:])
            next_id = sample_next_token(
                logits[:, -1, :], top_k=top_k, top_p=top_p, temperature=temperature
            )
            token = next_id.item()
            if token == stop_id:
                break
            generated.append(token)
            ids = torch.cat([ids, next_id.to(ids.device)], dim=1)
    return tokenizer.decode(generated, skip_special_tokens=True)
# 5. Main process
if __name__ == "__main__":
    # Initialization
    dataset = FullChatDataset()
    if len(dataset) == 0:
        # DataLoader(shuffle=True) would otherwise fail with an unrelated-looking error.
        raise SystemExit("No training data could be loaded; see the messages above.")
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    # Model creation: len(tokenizer) also counts tokens added on top of the base vocabulary.
    model = SimpleTransformerModel(len(dataset.tokenizer))
    # Training
    train(model, dataloader)
    # Saving
    torch.save(model.state_dict(), "chatbot_model.pt")
    dataset.tokenizer.save_pretrained("chatbot_tokenizer")
    # Interactive loop: 'exit' / 'quit' (any case, surrounding spaces ignored), EOF or Ctrl-C stop it.
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        user_input = user_input.strip()
        if not user_input:
            continue
        if user_input.lower() in ['exit', 'quit']:
            break
        response = chat(model, dataset.tokenizer, user_input)
        print(f"Bot: {response}")