File size: 3,251 Bytes
0fcf2a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import json
import time

import numpy as np
import torch
from pyspark.sql import SparkSession
from rocketmq.client import PushConsumer
from rocketmq.ffi import _CConsumeStatus
from torch.utils.data import DataLoader, Dataset

from BILSTM_Att import BiLSTMModelWithAttention, LOLDataset
# Create (or reuse) the SparkSession used by this prediction service.
# NOTE(review): "spark.num.executors" does not look like a documented Spark
# property ("spark.executor.instances" may have been intended) -- confirm.
spark = (
    SparkSession.builder
    .appName("BiLSTM_Predict")
    .master("spark://master:7077")
    .config("spark.executor.memory", "2g")
    .config("spark.executor.cores", "2")
    .config("spark.num.executors", "4")
    .getOrCreate()
)
# Modified LOLDataset: yields feature tensors only (no labels).
# This definition shadows the LOLDataset imported from BILSTM_Att.
class LOLDataset(Dataset):
    """Dataset exposing each input row as a ``(1, features)`` float32 tensor."""

    def __init__(self, data):
        """Convert ``data`` to a float32 tensor and build its 3-D view.

        Args:
            data: 2-D array-like of numbers, one sample per row.
        """
        features = torch.tensor(data, dtype=torch.float32)
        n_rows, n_cols = features.size(0), features.size(1)
        self.data = features
        # (rows, features) -> (rows, 1, features); the length-1 axis is
        # presumably the sequence dimension the model expects -- confirm.
        self.X = features.view(n_rows, 1, n_cols)

    def __len__(self):
        """Return the number of samples (rows)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(1, features)`` tensor stored at position ``idx``."""
        return self.X[idx]
# Loaded models keyed by input feature count, so the checkpoint is read from
# disk once instead of on every consumed message.
_MODEL_CACHE = {}


def _load_model(input_size, hidden_size=1024, num_layers=2, output_size=1):
    """Return the evaluation-mode model for ``input_size``, loading it on first use."""
    model = _MODEL_CACHE.get(input_size)
    if model is None:
        model = BiLSTMModelWithAttention(input_size, hidden_size, num_layers, output_size)
        # NOTE(security): torch.load unpickles the checkpoint file; only load
        # 'BILSTM_Att.pt' from a trusted location.
        state_dict = torch.load('BILSTM_Att.pt', map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()  # evaluation mode
        _MODEL_CACHE[input_size] = model
    return model


# Prediction function
def predict(data):
    """Run the BiLSTM model on one feature vector and print the outcome.

    ``data`` is wrapped into a single-row batch, fed through the model, and
    the first output value is printed as team A's win percentage, with its
    complement as team B's and a label chosen by a 0.5 threshold.
    (Assumes the model emits a probability in [0, 1] -- TODO confirm.)

    Args:
        data: flat sequence of numeric features for one match.

    Returns:
        None. Results and diagnostics are printed. Input that is empty,
        non-numeric, or not a flat sequence is reported as
        "Received invalid data" and skipped without touching the model.
    """
    try:
        predict_data = np.array([data])
    except ValueError:
        # Recent NumPy versions reject ragged nested sequences outright.
        print("Received invalid data")
        return

    # Debug output
    print(f"Received data: {predict_data}")

    # Require a non-empty (1, features) numeric array; anything else (None,
    # dict, scalar, strings, nested lists) would fail further down.
    if (
        predict_data.ndim != 2
        or predict_data.size == 0
        or predict_data.dtype.kind not in "biuf"
    ):
        print("Received invalid data")
        return

    batch_size, input_size = predict_data.shape

    # Wrap the sample so it reaches the model as a (batch, 1, features) tensor.
    test_dataset = LOLDataset(predict_data)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = _load_model(input_size)

    # Run inference without tracking gradients.
    predictions = []
    with torch.no_grad():
        for X_batch in test_loader:
            outputs = model(X_batch)
            predictions.extend(outputs.cpu().numpy())

    # Binarise: above 0.5 is a win for team A, otherwise for team B.
    predictions = np.array(predictions).flatten()
    predictions_binary = np.where(predictions > 0.5, "A队胜利", "B队胜利")
    A_win = predictions[0] * 100
    B_win = (1 - predictions[0]) * 100

    # Print the prediction result.
    print(f"A队胜率:{A_win:.2f}%, B队胜率:{B_win:.2f}%, 胜利情况:{predictions_binary[0]}")
# Set up the RocketMQ push consumer.
# NOTE(review): the string passed here is 'LOLProducerGroup' even though this
# is a consumer -- presumably the consumer group id; confirm the name is intended.
consumer = PushConsumer('LOLProducerGroup')
# Point the consumer at the name server on host "master", port 9876.
consumer.set_namesrv_addr('master:9876')
def callback(msg):
    """Handle one consumed message: decode its JSON body and run ``predict``.

    An empty body is reported and skipped. Any exception raised while
    decoding or predicting is printed rather than propagated.

    Args:
        msg: message object passed in by the consumer; only ``msg.body``
            is read here.

    Returns:
        ``_CConsumeStatus.CONSUME_SUCCESS`` in every case, including failures.
        (Presumably so that bad messages are not redelivered -- confirm.)
    """
    try:
        body = msg.body
        if body:
            data = json.loads(body.decode('utf-8'))
            # Debug output
            print(f"Received message: {data}")
            # The decoded payload is assumed to be the feature data itself.
            predict(data)
        else:
            print("Received empty message body")
    except Exception as err:
        print(f"Error processing message: {err}")
    return _CConsumeStatus.CONSUME_SUCCESS
# Register the handler and start consuming. The main thread has nothing else
# to do afterwards (presumably the client delivers messages to callback on
# its own threads -- confirm), so it only needs to stay alive.
consumer.subscribe('LOLPredictTopic', callback)
consumer.start()
try:
    while True:
        # Sleep instead of spinning: `while True: pass` keeps a CPU core
        # fully busy while doing no work.
        time.sleep(3600)
except KeyboardInterrupt:
    pass
finally:
    # Release the consumer and the Spark session on any exit path, not only
    # on Ctrl+C.
    consumer.shutdown()
    spark.stop()
|