# nlp_project_gpt_team / funcs / nastya_funcs.py
# initial commit 8fb2bb2 (Seppukku)
import time
import joblib
import re
import string
import pymorphy3
import torch
from transformers import BertModel, BertTokenizer
from torch import nn
# Pretrained Russian tiny-BERT checkpoint used as the frozen text encoder
# for MyTinyBERT below; tokenizer and encoder are downloaded/cached at import.
model_name = "cointegrated/rubert-tiny2"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
class MyTinyBERT(nn.Module):
    """Frozen rubert-tiny2 encoder with a small trainable 6-class head."""

    def __init__(self):
        super().__init__()
        # Reuse the module-level pretrained encoder and freeze all of its
        # parameters so that only the classification head is trainable.
        self.bert = bert_model
        self.bert.requires_grad_(False)
        # NOTE: attribute name and layer order must stay exactly like this —
        # the saved state_dict keys (linear.0.*, linear.3.*) depend on them.
        self.linear = nn.Sequential(
            nn.Linear(312, 256),
            nn.Sigmoid(),
            nn.Dropout(),
            nn.Linear(256, 6),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return raw class logits of shape (batch, 6)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # L2-normalize the [CLS] token embedding before classification.
        cls_vec = nn.functional.normalize(encoded.last_hidden_state[:, 0, :])
        return self.linear(cls_vec)
# Load the fine-tuned classifier weights and pin inference to CPU.
weights_path = "models/clf_rewievs_bert.pt"
model = MyTinyBERT()
# NOTE(review): torch.load unpickles arbitrary objects — acceptable for a
# trusted local checkpoint; consider weights_only=True if the installed
# torch version supports it.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
model.to('cpu')
# tokenizer = transformers.AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
# bert_model = transformers.AutoModel.from_pretrained("cointegrated/rubert-tiny2")
# weights_path = "./model_weights.pt" # Replace with your .pt file path
# bert_model.load_state_dict(torch.load('models/clf_rewievs_bert.pt', map_location=torch.device('cpu')))
# bert_model.to('cpu')
# Shared pymorphy3 morphological analyzer used by lemmatize() below.
morph = pymorphy3.MorphAnalyzer()
def lemmatize(text):
    """Reduce every whitespace-separated word of *text* to its normal form."""
    normalized = (morph.parse(token)[0].normal_form for token in text.split())
    return " ".join(normalized)
# Pre-fitted TF-IDF + logistic-regression artifacts for the baseline model.
logreg = joblib.load('models/logregmodel_restaurants.pkl')
vectorizer = joblib.load('models/tfidf_vectorizer_restaurants.pkl')
# Russian stopword list (whitespace-separated file), consumed by clean().
with open(
    "funcs/stopwords-ru.txt", "r", encoding="utf-8"
) as file:
    stop_words = set(file.read().split())
# Human-readable Russian labels for the 1..5 star ratings.
rating_dict = {
    1: "Отвратительно",       # terrible
    2: "Плохо",               # bad
    3: "Удовлетворительно",   # satisfactory
    4: "Хорошо",              # good
    5: "Великолепно",         # excellent
}
# Unicode code-point ranges covering common emoji and pictographs; joined
# into a single character class below.
_EMOJI_RANGES = (
    "\U0001F600-\U0001F64F",  # Emoticons
    "\U0001F300-\U0001F5FF",  # Symbols & Pictographs
    "\U0001F680-\U0001F6FF",  # Transport & Map Symbols
    "\U0001F1E0-\U0001F1FF",  # Flags (iOS)
    "\U00002700-\U000027BF",  # Dingbats
    "\U0001F900-\U0001F9FF",  # Supplemental Symbols and Pictographs
    "\U00002600-\U000026FF",  # Miscellaneous Symbols
    "\U00002B50-\U00002B55",  # Stars and related symbols
    "\U0001FA70-\U0001FAFF",  # Symbols and Pictographs Extended-A
    "\U0001F700-\U0001F77F",  # Alchemical Symbols
    "\U0001F780-\U0001F7FF",  # Geometric Shapes Extended
    "\U0001F800-\U0001F8FF",  # Supplemental Arrows-C
    "\U0001F900-\U0001F9FF",  # Supplemental Symbols and Pictographs (repeated in original)
    "\U0001FA00-\U0001FA6F",  # Chess Symbols
)
emoji_pattern = re.compile("[" + "".join(_EMOJI_RANGES) + "]+", flags=re.UNICODE)
def clean(text, stopwords):
    """Normalize a raw review for the TF-IDF baseline.

    Lowercases the text, strips URLs, @mentions, #hashtags, digits,
    punctuation, HTML tags, emojis and any non-Cyrillic characters, then
    removes the given stopwords.

    Args:
        text: raw review string.
        stopwords: collection (ideally a set) of words to drop.

    Returns:
        A single space-separated string of lowercase Cyrillic words.
    """
    text = text.lower()  # lowercase once, up front
    text = re.sub(r"http\S+", " ", text)  # remove links
    text = re.sub(r"@\w+", " ", text)  # remove user mentions
    text = re.sub(r"#\w+", " ", text)  # remove hashtags
    text = re.sub(r"\d+", " ", text)  # remove numbers
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"<.*?>", " ", text)  # remove HTML tags
    text = re.sub(r"[️«»—]", " ", text)  # remove quotes, em-dash, variation selector
    text = re.sub(r"[^а-яё ]", " ", text)  # keep only Cyrillic letters and spaces
    # BUG FIX: redundant second text.lower() removed — the text was already
    # lowercased at the top and no step since reintroduces uppercase.
    text = emoji_pattern.sub(r"", text)
    text = " ".join([word for word in text.split() if word not in stopwords])
    return text
def predict_review(review):
    """Classify a review with the TF-IDF + logistic-regression baseline.

    Returns:
        (label, human-readable rating, elapsed seconds) tuple; also prints
        the three values for interactive use.
    """
    started = time.time()
    # Clean and lemmatize the raw text.
    prepared = lemmatize(clean(review, stop_words))
    # Vectorize and predict with the pre-fitted pipeline.
    features = vectorizer.transform([prepared])
    prediction = logreg.predict(features)[0]
    # Map the label to a readable rating; fall back on unknown labels.
    rating = rating_dict.get(prediction, "Ошибка предсказания")
    elapsed_time = time.time() - started
    print(f"Лейбл: {prediction}")
    print(f"Оценка отзыва: {rating}")
    print(f"Затраченное время: {elapsed_time:.6f} seconds")
    return prediction, rating, elapsed_time
def preprocess_input(text):
    """Tokenize *text* into padded/truncated PyTorch tensors (max 512 tokens)."""
    return tokenizer(
        text, return_tensors='pt', max_length=512, truncation=True, padding=True
    )
def predict_bert(text):
    """Classify *text* with the fine-tuned tiny-BERT model.

    Returns:
        (predicted_class, human-readable rating, elapsed seconds) tuple.
    """
    start_time = time.time()
    model.eval()  # disable dropout for deterministic inference
    inputs = preprocess_input(text)
    # Keep tensors on CPU — the model weights were loaded there.
    inputs = {k: v.to('cpu') for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    # The model returns raw logits directly, so no .logits attribute access.
    predicted_class = outputs.argmax(dim=-1).item()
    # BUG FIX: the head has 6 outputs (labels 0-5) but rating_dict only has
    # keys 1-5, so rating_dict[predicted_class] raised KeyError for label 0.
    # Mirror predict_review's fallback instead of crashing.
    rating = rating_dict.get(predicted_class, "Ошибка предсказания")
    end_time = time.time()
    elapsed_time = end_time - start_time
    return predicted_class, rating, elapsed_time