holiday_testing / test_models /create_setfit_model.py
svystun-taras's picture
fixed a bug
ca57df8
raw
history blame contribute delete
No virus
4.35 kB
import numpy as np
import torch
from datasets import load_dataset
from safetensors.torch import load_model
from sentence_transformers import SentenceTransformer
from setfit.__init__ import SetFitModel
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
# Target device for the assembled model: the first CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
class MLP(nn.Module):
    """Classification head: one linear layer whose output goes through dropout.

    Args:
        input_size: Width of the input feature vectors.
        output_size: Number of output scores (one per class).
        dropout_rate: Drop probability of the dropout layer.
        class_weights: Optional per-class weights handed to the loss created
            by ``get_loss_fn``.
    """

    def __init__(self, input_size=768, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the class scores for ``x``.

        Dropout is applied to the output of the linear layer, so in training
        mode individual scores are randomly zeroed and the rest rescaled.
        """
        # NOTE(review): dropout-after-linear is kept exactly as found; an
        # earlier dropout-before-linear variant had been commented out.
        # Confirm the order matches how the stored weights were trained.
        return self.dropout(self.linear(x))

    def _inference_scores(self, x):
        """Run ``forward`` with dropout disabled, then restore the previous mode."""
        was_training = self.training
        # A freshly constructed nn.Module is in training mode; without this
        # switch dropout would randomly perturb the scores at inference time.
        self.eval()
        try:
            return self.forward(x)
        finally:
            self.train(was_training)

    def predict(self, x):
        """Return the index of the highest-scoring class for each row of ``x``."""
        with torch.no_grad():
            _, predicted = torch.max(self._inference_scores(x), 1)
        return predicted

    def predict_proba(self, x):
        """Return the per-class scores for ``x`` with dropout disabled.

        NOTE(review): despite the name, the values are the raw outputs of
        ``forward`` (no softmax is applied) - confirm what callers expect.
        """
        return self._inference_scores(x)

    def get_loss_fn(self):
        """Build a mean-reduced cross-entropy loss using the stored class weights."""
        weights = self.class_weights
        if isinstance(weights, torch.Tensor):
            # class_weights is a plain attribute, so Module.to() does not move
            # it; align it with the layer's device to avoid a device mismatch.
            weights = weights.to(self.linear.weight.device)
        return nn.CrossEntropyLoss(weight=weights, reduction='mean')
# Dataset whose training labels drive the class weights and whose label
# feature names become the labels of the assembled model.
dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")

# Square root of scikit-learn's "balanced" class weights. `classes` is passed
# as an ndarray: that is the documented type, and newer scikit-learn releases
# validate the parameter and reject a plain list.
class_weights = torch.tensor(
    compute_class_weight('balanced', classes=np.array([0, 1, 2]), y=dataset['train']['labels']),
    dtype=torch.float,
) ** .5
model_head = MLP(class_weights=class_weights)

# The encoder and head weights are located differently depending on how this
# module was loaded (run/imported from its own directory vs. imported as part
# of the `test_models` package).
if __name__ == '__main__' or __name__ == 'create_setfit_model':
    model_body = SentenceTransformer('financial-roberta')
    load_model(model_head, 'models/linear_head.safetensors')
elif __name__ == 'test_models.create_setfit_model':
    model_body = SentenceTransformer('test_models/financial-roberta')
    load_model(model_head, '/home/user/app/test_models/models/linear_head.safetensors')
else:
    # Previously this fell through and failed later with an opaque NameError
    # on `model_body`; fail here with an actionable message instead.
    raise ImportError(
        f'create_setfit_model was loaded as {__name__!r}; model paths are only '
        "known for '__main__', 'create_setfit_model' and "
        "'test_models.create_setfit_model'"
    )

# nn.Module instances start in training mode; switch the loaded head to eval
# mode so its dropout layer is inactive when the model is used for inference.
model_head.eval()

model = SetFitModel(model_body=model_body,
                    model_head=model_head,
                    labels=dataset['train'].features['labels'].names).to(DEVICE)
def format_elapsed(seconds):
    """Render a duration given in seconds as ``'It took me: M mins S secs'``.

    The duration is rounded to whole seconds *before* it is split, so the
    seconds part can never be displayed as 60.
    """
    minutes, secs = divmod(round(seconds), 60)
    return f'It took me: {minutes} mins {secs} secs'


if __name__ == '__main__':
    # Smoke test: classify three sample passages and report how long it took.
    from time import perf_counter
    start = perf_counter()
    test_sentences = [
        """Two thousand and six was a very good year for The Coca-Cola Company. We achieved our 52nd
consecutive year of unit case volume growth. Volume reached a record high of 2.4 billion unit cases.
Net operating revenues grew 4 percent to $24.billion, and operating income grew
4 percent to $6.3 billion. Our total return to shareowners was 23 percent, outperforming the Dow
Jones Industrial Average and the S&P 500. By virtually every measure, we met or exceeded our
objectives—a strong ending for the year with great momentum for entering 2007.""",
        """
The secret formula to our success in 2006? There is no one answer. Our inspiration comes from
many sources—our bottling partners, retail customers and consumers, as well as our critics. And the
men and women of The Coca-Cola Company have a passion for what they do that ignites this
inspiration every day, everywhere we do business. We remain fresh, relevant and original by knowing
what
to change without changing what we know. We are asking more questions, listening more closely and
collaborating more effectively with our bottling partners, suppliers and retail customers to give
consumers what they want.
""",
        """
And we continue to strengthen our bench, nurturing leaders and promoting from within our
organization. As 2006 came to a close, our Board of Directors elected Muhtar Kent as president and
chief operating officer of our Company. Muhtar is a 28-year veteran of the Coca-Cola system (the
Company and our bottling partners). Muhtar’s close working relationships with our bottling partners
will enable us to continue capturing marketplace opportunities and improving our business. Other
system veterans promoted and now leading operating groups include Ahmet Bozer, Eurasia; Sandy
Douglas, North America; and Glenn Jordan, Pacific. Combined, these leaders have 65 years of Coca-
Cola system experience.
"""
    ]
    # All passages are classified in a single batched call.
    print(model(test_sentences))
    # Take one timestamp so minutes and seconds describe the same instant.
    elapsed = perf_counter() - start
    print(format_elapsed(elapsed))