Spaces:
No application file
No application file
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from sklearn.preprocessing import LabelEncoder | |
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from sklearn.preprocessing import LabelEncoder | |
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта'] | |
label_encoder = LabelEncoder() | |
label_encoder.fit(labels) | |
# Загрузка сохраненной модели и токенизатора в Streamlit | |
loaded_model_path = "rubert-base-cased" | |
loaded_tokenizer_path = BertForSequenceClassification.from_pretrained(loaded_model_path) | |
# Инициализация модели и токенизатора | |
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path) | |
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path) | |
# Создание модели с архитектурой BertForSequenceClassification | |
# Передайте в аргумент `num_labels` количество классов, для которых модель будет выполнять классификацию | |
model = BertForSequenceClassification(num_labels=len(labels)) | |
# Загрузка весов из сохраненного файла | |
weights_path = "model_weights_epoch_8.pt" | |
state_dict = torch.load(weights_path, map_location='cpu') # Укажите 'cuda' вместо 'cpu', если используете GPU | |
model.load_state_dict(state_dict) | |
# Пример использования загруженной модели | |
user_input = "Ваш текст для классификации" | |
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder) | |
print(predicted_class) | |
# #Загрузка сохраненной модели и токенизатора в Streamlit | |
# loaded_model_path = "nlp_project/model" | |
# loaded_tokenizer_path = "nlp_project/tokenizer" | |
# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path) | |
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path) | |
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128): | |
if not user_input: | |
return "Введите текст" | |
def tokenize_text(text): | |
encoded_text = tokenizer.encode_plus( | |
text, | |
add_special_tokens=True, | |
max_length=max_length, | |
pad_to_max_length=True, | |
return_attention_mask=True, | |
return_tensors='pt' | |
) | |
return encoded_text | |
encoded_text = tokenize_text(user_input) | |
with torch.no_grad(): | |
model.eval() | |
input_ids = encoded_text['input_ids'] | |
attention_mask = encoded_text['attention_mask'] | |
outputs = model(input_ids, attention_mask=attention_mask) | |
logits = outputs.logits | |
predicted_class_index = torch.argmax(logits, dim=1).item() | |
# Получение названия класса | |
predicted_class = label_encoder.classes_[predicted_class_index] | |
return predicted_class | |