nlp_project / task2.py
Tatiana
files added
dd3dbad
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
label_encoder = LabelEncoder()
label_encoder.fit(labels)
# Загрузка сохраненной модели и токенизатора в Streamlit
loaded_model_path = "rubert-base-cased"
loaded_tokenizer_path = BertForSequenceClassification.from_pretrained(loaded_model_path)
# Инициализация модели и токенизатора
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
# Создание модели с архитектурой BertForSequenceClassification
# Передайте в аргумент `num_labels` количество классов, для которых модель будет выполнять классификацию
model = BertForSequenceClassification(num_labels=len(labels))
# Загрузка весов из сохраненного файла
weights_path = "model_weights_epoch_8.pt"
state_dict = torch.load(weights_path, map_location='cpu') # Укажите 'cuda' вместо 'cpu', если используете GPU
model.load_state_dict(state_dict)
# Пример использования загруженной модели
user_input = "Ваш текст для классификации"
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
print(predicted_class)
# #Загрузка сохраненной модели и токенизатора в Streamlit
# loaded_model_path = "nlp_project/model"
# loaded_tokenizer_path = "nlp_project/tokenizer"
# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
if not user_input:
return "Введите текст"
def tokenize_text(text):
encoded_text = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=max_length,
pad_to_max_length=True,
return_attention_mask=True,
return_tensors='pt'
)
return encoded_text
encoded_text = tokenize_text(user_input)
with torch.no_grad():
model.eval()
input_ids = encoded_text['input_ids']
attention_mask = encoded_text['attention_mask']
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
predicted_class_index = torch.argmax(logits, dim=1).item()
# Получение названия класса
predicted_class = label_encoder.classes_[predicted_class_index]
return predicted_class