# holiday_testing/test_models/train_head.py
# Last commit 501f2e5 by svystun-taras: "tested the model on all dataset" (4.75 kB)
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
# import torch.nn as nn
# Print tensors in fixed-point rather than scientific notation; this only
# affects how tensor values are displayed, not the values themselves.
torch.set_printoptions(sci_mode=False)
class MLP(nn.Module):
    """Linear classification head over fixed-size embedding vectors.

    Despite the name, this is a single ``nn.Linear`` layer whose output goes
    through ``nn.Dropout``; there are no hidden layers.

    Args:
        input_size: Length of each input embedding vector.
        output_size: Number of classes (length of the logit vector).
        dropout_rate: Dropout probability used in ``forward``.
        class_weights: Optional per-class weight tensor handed to
            ``nn.CrossEntropyLoss`` by ``get_loss_fn``; ``None`` means the
            loss is unweighted.
    """

    def __init__(self, input_size=768, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        # Plain attribute (not a registered buffer) so the state_dict, and
        # therefore previously saved checkpoints, keeps the same keys.
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return raw (unnormalized) class logits for ``x``."""
        # NOTE(review): dropout is applied to the output logits rather than to
        # the input features (an input-dropout variant existed previously);
        # confirm this is intended. It only has an effect in training mode.
        return self.dropout(self.linear(x))

    def _inference(self, x):
        """Run ``forward`` in eval mode without autograd, then restore the prior mode."""
        was_training = self.training
        # Without eval(), the logit dropout would randomly zero scores and make
        # predictions non-deterministic whenever the module is in training mode.
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(x)
        finally:
            self.train(was_training)

    def predict(self, x):
        """Return the index of the highest-scoring class along the last dimension of ``x``'s logits."""
        return self._inference(x).argmax(dim=-1)

    def predict_proba(self, x):
        """Return class probabilities (softmax of the logits along the last dimension)."""
        return torch.softmax(self._inference(x), dim=-1)

    def get_loss_fn(self):
        """Return a mean-reduced cross-entropy loss weighted by ``class_weights``."""
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
if __name__ == '__main__':
    # Train the linear head on precomputed embeddings, then evaluate it on the
    # test split. The test split is also evaluated once before training to
    # give an untrained baseline.
    import warnings

    import torch.optim as optim
    from datasets import load_dataset
    from safetensors.torch import load_model, save_model  # used by the commented-out save/load lines below
    from sentence_transformers import SentenceTransformer
    from sklearn.utils.class_weight import compute_class_weight
    from torch.utils.data import DataLoader, TensorDataset

    from train_classificator import (
        # MLP,
        plot_labels_distribution,
        plot_training_metrics,
        train_model,
        eval_model,
    )
    # from datetime import datetime  # needed by the commented-out save_as_filename templates
    # from collections import Counter

    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    # Seed torch so weight initialisation and DataLoader shuffling are
    # reproducible between runs.
    torch.manual_seed(SEED)
    # NOTE(review): DEVICE is not used below; the model and tensors stay on the
    # CPU unless train_model/eval_model move them -- confirm.
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")
    # plot_labels_distribution(dataset
    # # , save_as_filename=f'plots/labels_distribution_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    # )

    # Hyperparameters.
    num_classes = 3
    learning_rate = 5e-4
    weight_decay = 0
    batch_size = 128
    epochs = 40
    # Read the embedding width from one row instead of materialising the whole
    # 'embeddings' column just to look at its first element.
    input_size = len(dataset['train'][0]['embeddings'])

    def make_split(split, shuffle):
        """Return (TensorDataset, DataLoader) over one split's embeddings and labels."""
        data = TensorDataset(torch.tensor(dataset[split]['embeddings']),
                             torch.tensor(dataset[split]['labels']))
        return data, DataLoader(data, batch_size=batch_size, shuffle=shuffle)

    # sklearn's 'balanced' weights are n_samples / (n_classes * class_count);
    # the square root softens that correction. `classes` is an ndarray because
    # newer scikit-learn releases reject a plain list for this parameter.
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=np.arange(num_classes), y=dataset['train']['labels']),
        dtype=torch.float,
    ) ** .5
    model = MLP(input_size=input_size, output_size=num_classes, class_weights=class_weights)
    criterion = model.get_loss_fn()

    # Only the training split is shuffled. eval_model presumably aggregates
    # over the whole split, so evaluation order should not change the metrics.
    test_data, test_loader = make_split('test', shuffle=False)
    # Baseline evaluation of the untrained head (overwritten after training).
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False,
                                # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
                                )

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.2, patience=5, threshold=1e-4, min_lr=1e-7, verbose=True)

    train_data, train_loader = make_split('train', shuffle=True)
    val_data, val_loader = make_split('val', shuffle=False)
    losses, accuracies = train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs)
    plot_training_metrics(losses, accuracies
                          # , save_as_filename=f'plots/training_metrics_plot_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
                          )

    # Final evaluation, reusing the test split built above.
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False
                                # , save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
                                )

    # torch.save(model.state_dict(), f'models/linear_head.pth')
    # save_model(model, f'models/linear_head.safetensors')
    # load_model(model, f'models/linear_head.safetensors')
    # print(model)
    # dataset.push_to_hub(f'CabraVC/vector_dataset_stratified_ttv_split_{datetime.now().strftime("%Y-%m-%d_%H-%M")}', private=True)