# Source: qwen-110b / 情感.py
# Uploaded by bokobring — commit 9f90d6f ("Upload 4 files")
# coding=utf-8
import os
import re
import torch
import torch.nn as nn
import numpy as np
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
class Config(object):
    """Hyper-parameters, file paths and tokenizer for the BERT sentiment model.

    Creating an instance loads the tokenizer from ``bert_path`` through
    ``BertTokenizer.from_pretrained``. Several fields (``model_name``,
    ``require_improvement``, ``num_epochs``, ``batch_size``,
    ``learning_rate``) are not read by the inference code in this file;
    presumably they are kept to match the training setup.
    """

    def __init__(self):
        # Label names: neutral, positive, negative. Predicted class indices
        # are looked up in this list, so the order presumably has to match
        # the one used when the checkpoint was trained.
        labels = ['中性', '积极', '消极']
        self.model_name = 'bert'
        self.class_list = labels
        self.num_classes = len(labels)
        # Checkpoint file holding the fine-tuned weights.
        self.save_path = 'bert.ckpt'
        use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_gpu else 'cpu')
        self.require_improvement = 1000
        self.num_epochs = 3
        self.batch_size = 128
        self.learning_rate = 5e-5
        # Token sequences (including [CLS]) are padded/truncated to this length.
        self.pad_size = 32
        # Input width of the classifier head, which receives BERT's pooled output.
        self.hidden_size = 768
        self.bert_path = 'bert-base-chinese'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
class Model(nn.Module):
    """BERT encoder followed by a linear classification head on the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        # Every encoder weight stays trainable (full fine-tuning).
        for weight in self.bert.parameters():
            weight.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        """Return unnormalized class scores for a batch.

        Args:
            x: indexable batch where ``x[0]`` holds the token ids and ``x[2]``
                the attention mask; ``x[1]`` (sequence lengths, as built by
                ``DatasetIterater._to_tensor``) is not used.

        Returns:
            Output of the linear head; presumably shaped
            ``(batch, num_classes)`` — confirm against the BERT output.
        """
        token_ids, attention_mask = x[0], x[2]
        encoded = self.bert(token_ids, attention_mask=attention_mask)
        # Element 1 of the BERT output is the pooled representation.
        return self.fc(encoded[1])
# Compiled once at import time because clean() runs on every input line.
# The inline (?i) already makes the pattern case-insensitive.
_URL_REGEX = re.compile(
    r'(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))')
_WHITESPACE_REGEX = re.compile(r"\s+")


def clean(text):
    """Normalize one raw comment line before tokenization.

    Removes URLs (``http(s)://``, ``www``-prefixed and bare ``domain.tld/``
    forms), deletes the literal repost marker "转发微博", collapses every
    whitespace run (including newlines) into a single space, and strips both
    ends.

    Args:
        text: raw comment text.

    Returns:
        The cleaned string (possibly empty).
    """
    text = _URL_REGEX.sub("", text)
    text = text.replace("转发微博", "")
    text = _WHITESPACE_REGEX.sub(" ", text)
    return text.strip()
def load_dataset(data, config):
    """Convert raw comment lines into model-ready sample tuples.

    Each line is cleaned with ``clean``, tokenized with ``config.tokenizer``
    and prefixed with ``[CLS]``. When ``config.pad_size`` is truthy, the id
    sequence is right-padded with id 0 or truncated to exactly ``pad_size``.

    Args:
        data: iterable of raw text lines.
        config: object providing ``pad_size`` and ``tokenizer``.

    Returns:
        One ``(token_ids, label, seq_len, mask)`` tuple per input line, where:
        - ``label`` is a placeholder 0.
        - ``seq_len`` is the token count, capped at ``pad_size``.
        - ``mask`` is 1 for real tokens and 0 for padding, or empty when
          ``pad_size`` is falsy.
    """
    limit = config.pad_size
    tokenizer = config.tokenizer
    samples = []
    for raw in data:
        # Only [CLS] is added (no [SEP]); NOTE(review): presumably this
        # mirrors the preprocessing used at training time — confirm.
        tokens = ['[CLS]'] + tokenizer.tokenize(clean(raw))
        ids = tokenizer.convert_tokens_to_ids(tokens)
        length = len(tokens)
        attention = []
        if limit:
            shortfall = limit - length
            if shortfall > 0:
                attention = [1] * len(ids) + [0] * shortfall
                ids += [0] * shortfall
            else:
                attention = [1] * limit
                ids = ids[:limit]
                length = limit
        samples.append((ids, 0, length, attention))
    return samples
class DatasetIterater(object):
    """Iterate over a list of samples in fixed-size batches of tensors.

    Each sample is a ``(token_ids, label, seq_len, mask)`` tuple. Every batch
    is returned as ``((x, seq_len, mask), y)`` of ``LongTensor``s on
    ``device``. A final, smaller batch is produced when the sample count is
    not a multiple of ``batch_size``. The iterator rewinds itself after
    raising ``StopIteration``, so it can be iterated again.
    """

    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        # Leftover samples exist exactly when the count is not a multiple of
        # batch_size. Taking the remainder modulo n_batches instead would
        # divide by zero when len(batches) < batch_size and could miss a real
        # remainder (e.g. 12 samples in batches of 5).
        self.residue = len(batches) % batch_size != 0
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        """Stack a list of sample tuples into ``((x, seq_len, mask), y)`` tensors."""
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            # Trailing partial batch.
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index >= self.n_batches:
            # Exhausted: rewind so the object can be iterated again.
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        """Number of batches, counting the trailing partial batch if any."""
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches
def build_iterator(dataset, config):
    """Wrap ``dataset`` in a ``DatasetIterater`` that places tensors on ``config.device``.

    The batch size is fixed at 1 (``config.batch_size`` is not consulted),
    so every yielded batch holds a single sample.
    """
    # Returned directly instead of via a local named ``iter``, which shadowed the builtin.
    return DatasetIterater(dataset, 1, config.device)
def match_label(pred, config):
    """Return the label name at index ``pred`` of ``config.class_list``."""
    return config.class_list[pred]
def load_data(file_path, config):
    """Read a text file of comments (one per line) and preprocess it.

    Args:
        file_path: path to the UTF-8 input file.
        config: forwarded to ``load_dataset``.

    Returns:
        ``(data, lines)``, where ``lines`` are the raw lines as read (newlines
        included) and ``data`` is ``load_dataset(lines, config)``, one sample
        per line.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
            It subclasses ``Exception``, so existing broad handlers still
            catch it.
    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    # 'utf-8-sig' drops a leading byte-order mark (common in files saved by
    # Windows editors) and otherwise decodes exactly like 'utf-8'.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        lines = f.readlines()
    data = load_dataset(lines, config)
    return data, lines
def final_predict(config, model, data_iter):
    """Load the trained checkpoint into ``model`` and label every sample.

    Args:
        config: provides ``save_path`` (checkpoint file) and the label list
            used by ``match_label``.
        model: module whose state dict is replaced by the checkpoint; it is
            called as ``model(texts)`` on each batch's input part.
        data_iter: iterable of ``(inputs, labels)`` batches; labels are ignored.

    Returns:
        1-D numpy string array with one predicted label per sample, in
        iteration order.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if not os.path.isfile(config.save_path):
        raise FileNotFoundError(f"File not found: {config.save_path}")
    # Keep deserialized storages on the CPU; load_state_dict then copies the
    # values into the model's parameters wherever they live.
    map_location = lambda storage, loc: storage
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints (newer PyTorch versions offer weights_only=True).
    model.load_state_dict(torch.load(config.save_path, map_location=map_location))
    model.eval()
    labels = []
    with torch.no_grad():
        for texts, _ in tqdm(data_iter, desc="Predicting"):
            outputs = model(texts)
            pred = torch.max(outputs, 1)[1].cpu().numpy()
            labels.extend(match_label(i, config) for i in pred)
    # Build the array once at the end: np.append inside the loop copied the
    # whole array on every batch, which is quadratic in the sample count.
    return np.array(labels, dtype=str)
def output_results(file_path, results, lines):
    """Write a sentiment summary followed by one labelled line per comment.

    The file starts with the total count and the counts of positive (积极),
    negative (消极) and neutral (中性) results. It then lists each comment
    (stripped of surrounding whitespace) with its label. ``lines`` and
    ``results`` are paired positionally; extra items in the longer one are
    ignored.

    Args:
        file_path: output path; the file is overwritten and encoded as UTF-8.
        results: sized sequence of label strings (e.g. the numpy array from
            ``final_predict``).
        lines: the original comment lines, in the same order as ``results``.
    """
    from collections import Counter

    # One pass over the results; labels that never occur count as 0.
    counts = Counter(results)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(f'评论数量:{len(results)}\n')
        f.write(f'积极评论数量:{counts["积极"]}\n')
        f.write(f'消极评论数量:{counts["消极"]}\n')
        f.write(f'中性评论数量:{counts["中性"]}\n')
        for line, result in zip(lines, results):
            f.write(f'评论:{line.strip()} ,情感:{result}\n')
def predict(input_file, output_file):
    """Classify every comment in ``input_file`` and write the report to ``output_file``.

    The input is read and preprocessed before the BERT model is built. A
    missing or unreadable input file is therefore reported without first
    paying for model loading.
    """
    config = Config()
    test_data, lines = load_data(input_file, config)
    model = Model(config).to(config.device)
    test_iter = build_iterator(test_data, config)
    results = final_predict(config, model, test_iter)
    output_results(output_file, results, lines)
if __name__ == '__main__':
    import sys

    # The prompts ask for paths relative to this program file, so resolve
    # them against the script's directory rather than the current working
    # directory. os.path.join leaves absolute paths unchanged. Only these two
    # user-entered paths are resolved here; the checkpoint path in Config is
    # used as-is.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_file = os.path.join(script_dir, input("请输入评论文件的名字(相对于程序文件的路径):"))
    output_file = os.path.join(script_dir, input("请输入结果输出文件的名字(相对于程序文件的路径):"))
    try:
        predict(input_file, output_file)
    except Exception as e:
        # Top-level boundary: sys.exit with a string prints it to stderr and
        # exits with status 1, so callers can detect the failure.
        sys.exit(f"Error: {e}")